@@ -493,11 +493,11 @@ def df(G):
493493 return gwggrad (constC , hC1 , hC2 , G )
494494
495495 if log :
496- res , log = cg (p , q , M , alpha , f , df , G0 , armijo = armijo , C1 = C1 , C2 = C2 , constC = constC , log = True , ** kwargs )
496+ res , log = cg (p , q , ( 1 - alpha ) * M , alpha , f , df , G0 , armijo = armijo , C1 = C1 , C2 = C2 , constC = constC , log = True , ** kwargs )
497497 log ['fgw_dist' ] = log ['loss' ][::- 1 ][0 ]
498498 return res , log
499499 else :
500- return cg (p , q , M , alpha , f , df , G0 , armijo = armijo , C1 = C1 , C2 = C2 , constC = constC , ** kwargs )
500+ return cg (p , q , ( 1 - alpha ) * M , alpha , f , df , G0 , armijo = armijo , C1 = C1 , C2 = C2 , constC = constC , ** kwargs )
501501
502502
503503def fused_gromov_wasserstein2 (M , C1 , C2 , p , q , loss_fun = 'square_loss' , alpha = 0.5 , armijo = False , log = False , ** kwargs ):
@@ -573,7 +573,7 @@ def f(G):
573573 def df (G ):
574574 return gwggrad (constC , hC1 , hC2 , G )
575575
576- res , log = cg (p , q , M , alpha , f , df , G0 , armijo = armijo , C1 = C1 , C2 = C2 , constC = constC , log = True , ** kwargs )
576+ res , log = cg (p , q , ( 1 - alpha ) * M , alpha , f , df , G0 , armijo = armijo , C1 = C1 , C2 = C2 , constC = constC , log = True , ** kwargs )
577577 if log :
578578 log ['fgw_dist' ] = log ['loss' ][::- 1 ][0 ]
579579 log ['T' ] = res
@@ -1082,7 +1082,7 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_
10821082 T_temp = [t .T for t in T ]
10831083 C = update_sructure_matrix (p , lambdas , T_temp , Cs )
10841084
1085- T = [fused_gromov_wasserstein (( 1 - alpha ) * Ms [s ], C , Cs [s ], p , ps [s ], loss_fun , alpha ,
1085+ T = [fused_gromov_wasserstein (Ms [s ], C , Cs [s ], p , ps [s ], loss_fun , alpha ,
10861086 numItermax = max_iter , stopThr = 1e-5 , verbose = verbose ) for s in range (S )]
10871087
10881088 # T is N,ns
0 commit comments