@@ -488,11 +488,11 @@ def df(G):
488488 return gwggrad (constC , hC1 , hC2 , G )
489489
490490 if log :
491- res , log = cg (p , q , (1 - alpha ) * M , alpha , f , df , G0 , armijo = armijo , C1 = C1 , C2 = C2 , constC = constC , log = True , ** kwargs )
491+ res , log = cg (p , q , (1 - alpha ) * M , alpha , f , df , G0 , armijo = armijo , C1 = C1 , C2 = C2 , constC = constC , log = True , ** kwargs )
492492 log ['fgw_dist' ] = log ['loss' ][::- 1 ][0 ]
493493 return res , log
494494 else :
495- return cg (p , q , (1 - alpha ) * M , alpha , f , df , G0 , armijo = armijo , C1 = C1 , C2 = C2 , constC = constC , ** kwargs )
495+ return cg (p , q , (1 - alpha ) * M , alpha , f , df , G0 , armijo = armijo , C1 = C1 , C2 = C2 , constC = constC , ** kwargs )
496496
497497
498498def fused_gromov_wasserstein2 (M , C1 , C2 , p , q , loss_fun = 'square_loss' , alpha = 0.5 , armijo = False , log = False , ** kwargs ):
@@ -563,7 +563,7 @@ def f(G):
563563 def df (G ):
564564 return gwggrad (constC , hC1 , hC2 , G )
565565
566- res , log = cg (p , q , (1 - alpha ) * M , alpha , f , df , G0 , armijo = armijo , C1 = C1 , C2 = C2 , constC = constC , log = True , ** kwargs )
566+ res , log = cg (p , q , (1 - alpha ) * M , alpha , f , df , G0 , armijo = armijo , C1 = C1 , C2 = C2 , constC = constC , log = True , ** kwargs )
567567 if log :
568568 log ['fgw_dist' ] = log ['loss' ][::- 1 ][0 ]
569569 log ['T' ] = res
@@ -987,13 +987,13 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_
987987 loss_fun : str
988988 Loss function used for the solver either 'square_loss' or 'kl_loss'
989989 max_iter : int, optional
990- Max number of iterations
990+ Max number of iterations
991991 tol : float, optional
992992 Stop threshol on error (>0).
993993 verbose : bool, optional
994994 Print information along iterations.
995995 log : bool, optional
996- Record log if True.
996+ Record log if True.
997997 init_C : ndarray, shape (N,N), optional
998998 Initialization for the barycenters' structure matrix. If not set
999999 a random init is used.
0 commit comments