@@ -433,8 +433,7 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5,
433433
434434 where :
435435 - M is the (ns,nt) metric cost matrix
436- - :math:`f` is the regularization term ( and df is its gradient)
437- - a and b are source and target weights (sum to 1)
436+ - p and q are source and target weights (sum to 1)
438437 - L is a loss function to account for the misfit between the similarity matrices
439438
440439 The algorithm used for solving the problem is conditional gradient as discussed in [24]_
@@ -453,17 +452,13 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5,
453452 Distribution in the target space
454453 loss_fun : str, optional
455454 Loss function used for the solver
456- max_iter : int, optional
457- Max number of iterations
458- tol : float, optional
459- Stop threshold on error (>0)
460- verbose : bool, optional
461- Print information along iterations
462- log : bool, optional
463- record log if True
455+ alpha : float, optional
456+ Trade-off parameter (0 < alpha < 1)
464457 armijo : bool, optional
465458 If True the steps of the line-search is found via an armijo research. Else closed form is used.
466459 If there is convergence issues use False.
460+ log : bool, optional
461+ record log if True
467462 **kwargs : dict
468463 parameters can be directly passed to the ot.optim.cg solver
469464
@@ -493,11 +488,11 @@ def df(G):
493488 return gwggrad (constC , hC1 , hC2 , G )
494489
495490 if log :
496- res , log = cg (p , q , M , alpha , f , df , G0 , armijo = armijo , C1 = C1 , C2 = C2 , constC = constC , log = True , ** kwargs )
491+ res , log = cg (p , q , ( 1 - alpha ) * M , alpha , f , df , G0 , armijo = armijo , C1 = C1 , C2 = C2 , constC = constC , log = True , ** kwargs )
497492 log ['fgw_dist' ] = log ['loss' ][::- 1 ][0 ]
498493 return res , log
499494 else :
500- return cg (p , q , M , alpha , f , df , G0 , armijo = armijo , C1 = C1 , C2 = C2 , constC = constC , ** kwargs )
495+ return cg (p , q , ( 1 - alpha ) * M , alpha , f , df , G0 , armijo = armijo , C1 = C1 , C2 = C2 , constC = constC , ** kwargs )
501496
502497
503498def fused_gromov_wasserstein2 (M , C1 , C2 , p , q , loss_fun = 'square_loss' , alpha = 0.5 , armijo = False , log = False , ** kwargs ):
@@ -515,8 +510,7 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5
515510
516511 where :
517512 - M is the (ns,nt) metric cost matrix
518- - :math:`f` is the regularization term ( and df is its gradient)
519- - a and b are source and target weights (sum to 1)
513+ - p and q are source and target weights (sum to 1)
520514 - L is a loss function to account for the misfit between the similarity matrices
521515 The algorithm used for solving the problem is conditional gradient as discussed in [1]_
522516
@@ -534,17 +528,13 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5
534528 Distribution in the target space.
535529 loss_fun : str, optional
536530 Loss function used for the solver.
537- max_iter : int, optional
538- Max number of iterations
539- tol : float, optional
540- Stop threshold on error (>0)
541- verbose : bool, optional
542- Print information along iterations
543- log : bool, optional
544- Record log if True.
531+ alpha : float, optional
532+ Trade-off parameter (0 < alpha < 1)
545533 armijo : bool, optional
546534 If True the steps of the line-search is found via an armijo research.
547535 Else closed form is used. If there is convergence issues use False.
536+ log : bool, optional
537+ Record log if True.
548538 **kwargs : dict
549539 Parameters can be directly pased to the ot.optim.cg solver.
550540
@@ -573,7 +563,7 @@ def f(G):
573563 def df (G ):
574564 return gwggrad (constC , hC1 , hC2 , G )
575565
576- res , log = cg (p , q , M , alpha , f , df , G0 , armijo = armijo , C1 = C1 , C2 = C2 , constC = constC , log = True , ** kwargs )
566+ res , log = cg (p , q , ( 1 - alpha ) * M , alpha , f , df , G0 , armijo = armijo , C1 = C1 , C2 = C2 , constC = constC , log = True , ** kwargs )
577567 if log :
578568 log ['fgw_dist' ] = log ['loss' ][::- 1 ][0 ]
579569 log ['T' ] = res
@@ -994,6 +984,16 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_
994984 Whether to fix the structure of the barycenter during the updates
995985 fixed_features : bool
996986 Whether to fix the feature of the barycenter during the updates
987+ loss_fun : str
988+ Loss function used for the solver either 'square_loss' or 'kl_loss'
989+ max_iter : int, optional
990+ Max number of iterations
991+ tol : float, optional
992+ Stop threshol on error (>0).
993+ verbose : bool, optional
994+ Print information along iterations.
995+ log : bool, optional
996+ Record log if True.
997997 init_C : ndarray, shape (N,N), optional
998998 Initialization for the barycenters' structure matrix. If not set
999999 a random init is used.
@@ -1082,7 +1082,7 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_
10821082 T_temp = [t .T for t in T ]
10831083 C = update_sructure_matrix (p , lambdas , T_temp , Cs )
10841084
1085- T = [fused_gromov_wasserstein (( 1 - alpha ) * Ms [s ], C , Cs [s ], p , ps [s ], loss_fun , alpha ,
1085+ T = [fused_gromov_wasserstein (Ms [s ], C , Cs [s ], p , ps [s ], loss_fun , alpha ,
10861086 numItermax = max_iter , stopThr = 1e-5 , verbose = verbose ) for s in range (S )]
10871087
10881088 # T is N,ns
0 commit comments