@@ -918,7 +918,8 @@ def barycenter(A, M, reg, weights=None, numItermax=1000,
918918 else :
919919 return geometricBar (weights , UKv )
920920
921- def convolutional_barycenter2d (A ,reg ,weights = None ,numItermax = 10000 , stopThr = 1e-9 , verbose = False , log = False ):
921+
922+ def convolutional_barycenter2d (A , reg , weights = None , numItermax = 10000 , stopThr = 1e-9 , verbose = False , log = False ):
922923 """Compute the entropic regularized wasserstein barycenter of distributions A
923924 where A is a collection of 2D images.
924925
@@ -979,51 +980,52 @@ def convolutional_barycenter2d(A,reg,weights=None,numItermax = 10000, stopThr=1e
979980 if log :
980981 log = {'err' : []}
981982
982- b = np .zeros_like (A [0 ,:, :])
983- U = np .ones_like (A )
984- KV = np .ones_like (A )
985- threshold = 1e-30 # in order to avoids numerical precision issues
983+ b = np .zeros_like (A [0 , :, :])
984+ U = np .ones_like (A )
985+ KV = np .ones_like (A )
986+ threshold = 1e-30 # in order to avoids numerical precision issues
986987
987988 cpt = 0
988- err = 1
989-
990- # build the convolution operator
991- t = np .linspace (0 ,1 ,A .shape [1 ])
992- [Y ,X ] = np .meshgrid (t ,t )
993- xi1 = np .exp (- (X - Y )** 2 / reg )
994- K = lambda x : np .dot (np .dot (xi1 ,x ),xi1 )
995-
996- while (err > stopThr and cpt < numItermax ):
997-
998- bold = b
999- cpt = cpt + 1
1000-
1001- b = np .zeros_like (A [0 ,:,:])
989+ err = 1
990+
991+ # build the convolution operator
992+ t = np .linspace (0 , 1 , A .shape [1 ])
993+ [Y , X ] = np .meshgrid (t , t )
994+ xi1 = np .exp (- (X - Y )** 2 / reg )
995+
996+ def K (x ): return np .dot (np .dot (xi1 , x ), xi1 )
997+
998+ while (err > stopThr and cpt < numItermax ):
999+
1000+ bold = b
1001+ cpt = cpt + 1
1002+
1003+ b = np .zeros_like (A [0 , :, :])
10021004 for r in range (A .shape [0 ]):
1003- KV [r ,:,:] = K (A [r ,:,:] / np .maximum (threshold ,K (U [r ,:, :])))
1004- b += weights [r ] * np .log (np .maximum (threshold , U [r ,:,:] * KV [r ,:, :]))
1005+ KV [r , :, :] = K (A [r , :, :] / np .maximum (threshold , K (U [r , :, :])))
1006+ b += weights [r ] * np .log (np .maximum (threshold , U [r , :, :] * KV [r , :, :]))
10051007 b = np .exp (b )
10061008 for r in range (A .shape [0 ]):
1007- U [r ,:,:] = b / np .maximum (threshold ,KV [r ,:, :])
1008-
1009- if cpt % 10 == 1 :
1010- err = np .sum (np .abs (bold - b ))
1009+ U [r , :, :] = b / np .maximum (threshold , KV [r , :, :])
1010+
1011+ if cpt % 10 == 1 :
1012+ err = np .sum (np .abs (bold - b ))
10111013 # log and verbose print
10121014 if log :
10131015 log ['err' ].append (err )
10141016
10151017 if verbose :
1016- if cpt % 200 == 0 :
1017- print ('{:5s}|{:12s}' .format ('It.' ,'Err' )+ '\n ' + '-' * 19 )
1018- print ('{:5d}|{:8e}|' .format (cpt ,err ))
1018+ if cpt % 200 == 0 :
1019+ print ('{:5s}|{:12s}' .format ('It.' , 'Err' ) + '\n ' + '-' * 19 )
1020+ print ('{:5d}|{:8e}|' .format (cpt , err ))
10191021
10201022 if log :
1021- log ['niter' ]= cpt
1022- log ['U' ]= U
1023- return b ,log
1023+ log ['niter' ] = cpt
1024+ log ['U' ] = U
1025+ return b , log
10241026 else :
1025- return b
1026-
1027+ return b
1028+
10271029
10281030def unmix (a , D , M , M0 , h0 , reg , reg0 , alpha , numItermax = 1000 ,
10291031 stopThr = 1e-3 , verbose = False , log = False ):
0 commit comments