@@ -33,12 +33,12 @@ def init_matrix(C1, C2, p, q, loss_fun='square_loss'):
3333 * C2 : Metric cost matrix in the target space
3434 * T : A coupling between those two spaces
3535
36- The square-loss function L(a,b)=(1/2)* |a-b|^2 is read as :
36+ The square-loss function L(a,b)=|a-b|^2 is read as :
3737 L(a,b) = f1(a)+f2(b)-h1(a)*h2(b) with :
38- * f1(a)=(a^2)/2
39- * f2(b)=(b^2)/2
38+ * f1(a)=(a^2)
39+ * f2(b)=(b^2)
4040 * h1(a)=a
41- * h2(b)=b
41+ * h2(b)=2* b
4242
4343 The kl-loss function L(a,b)=a*log(a/b)-a+b is read as :
4444 L(a,b) = f1(a)+f2(b)-h1(a)*h2(b) with :
@@ -269,7 +269,7 @@ def update_kl_loss(p, lambdas, T, Cs):
269269 return np .exp (np .divide (tmpsum , ppt ))
270270
271271
272- def gromov_wasserstein (C1 , C2 , p , q , loss_fun , log = False , amijo = False , ** kwargs ):
272+ def gromov_wasserstein (C1 , C2 , p , q , loss_fun , log = False , armijo = False , ** kwargs ):
273273 """
274274 Returns the gromov-wasserstein transport between (C1,p) and (C2,q)
275275
@@ -307,8 +307,8 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, amijo=False, **kwargs)
307307 Print information along iterations
308308 log : bool, optional
309309 record log if True
310- amijo : bool, optional
311- If True the steps of the line-search is found via an amijo research. Else closed form is used.
310+ armijo : bool, optional
311+ If True the steps of the line-search is found via an armijo research. Else closed form is used.
312312 If there is convergence issues use False.
313313 **kwargs : dict
314314 parameters can be directly pased to the ot.optim.cg solver
@@ -344,14 +344,14 @@ def df(G):
344344 return gwggrad (constC , hC1 , hC2 , G )
345345
346346 if log :
347- res , log = cg (p , q , 0 , 1 , f , df , G0 , log = True , amijo = amijo , C1 = C1 , C2 = C2 , constC = constC , ** kwargs )
347+ res , log = cg (p , q , 0 , 1 , f , df , G0 , log = True , armijo = armijo , C1 = C1 , C2 = C2 , constC = constC , ** kwargs )
348348 log ['gw_dist' ] = gwloss (constC , hC1 , hC2 , res )
349349 return res , log
350350 else :
351- return cg (p , q , 0 , 1 , f , df , G0 , amijo = amijo , C1 = C1 , C2 = C2 , constC = constC , ** kwargs )
351+ return cg (p , q , 0 , 1 , f , df , G0 , armijo = armijo , C1 = C1 , C2 = C2 , constC = constC , ** kwargs )
352352
353353
354- def fused_gromov_wasserstein (M , C1 , C2 , p , q , loss_fun = 'square_loss' , alpha = 0.5 , amijo = False , ** kwargs ):
354+ def fused_gromov_wasserstein (M , C1 , C2 , p , q , loss_fun = 'square_loss' , alpha = 0.5 , armijo = False , ** kwargs ):
355355 """
356356 Computes the FGW distance between two graphs see [3]
357357 .. math::
@@ -363,6 +363,7 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5,
363363 - M is the (ns,nt) metric cost matrix
364364 - :math:`f` is the regularization term ( and df is its gradient)
365365 - a and b are source and target weights (sum to 1)
366+ - L is a loss function to account for the misfit between the similarity matrices
366367 The algorithm used for solving the problem is conditional gradient as discussed in [1]_
367368 Parameters
368369 ----------
@@ -386,8 +387,8 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5,
386387 Print information along iterations
387388 log : bool, optional
388389 record log if True
389- amijo : bool, optional
390- If True the steps of the line-search is found via an amijo research. Else closed form is used.
390+ armijo : bool, optional
391+ If True the steps of the line-search is found via an armijo research. Else closed form is used.
391392 If there is convergence issues use False.
392393 **kwargs : dict
393394 parameters can be directly pased to the ot.optim.cg solver
@@ -415,10 +416,10 @@ def f(G):
415416 def df (G ):
416417 return gwggrad (constC , hC1 , hC2 , G )
417418
418- return cg (p , q , M , alpha , f , df , G0 , amijo = amijo , C1 = C1 , C2 = C2 , constC = constC , ** kwargs )
419+ return cg (p , q , M , alpha , f , df , G0 , armijo = armijo , C1 = C1 , C2 = C2 , constC = constC , ** kwargs )
419420
420421
421- def gromov_wasserstein2 (C1 , C2 , p , q , loss_fun , log = False , amijo = False , ** kwargs ):
422+ def gromov_wasserstein2 (C1 , C2 , p , q , loss_fun , log = False , armijo = False , ** kwargs ):
422423 """
423424 Returns the gromov-wasserstein discrepancy between (C1,p) and (C2,q)
424425
@@ -456,8 +457,8 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, amijo=False, **kwargs
456457 Print information along iterations
457458 log : bool, optional
458459 record log if True
459- amijo : bool, optional
460- If True the steps of the line-search is found via an amijo research. Else closed form is used.
460+ armijo : bool, optional
461+ If True the steps of the line-search is found via an armijo research. Else closed form is used.
461462 If there is convergence issues use False.
462463 Returns
463464 -------
@@ -487,7 +488,7 @@ def f(G):
487488
488489 def df (G ):
489490 return gwggrad (constC , hC1 , hC2 , G )
490- res , log = cg (p , q , 0 , 1 , f , df , G0 , log = True , amijo = amijo , C1 = C1 , C2 = C2 , constC = constC , ** kwargs )
491+ res , log = cg (p , q , 0 , 1 , f , df , G0 , log = True , armijo = armijo , C1 = C1 , C2 = C2 , constC = constC , ** kwargs )
491492 log ['gw_dist' ] = gwloss (constC , hC1 , hC2 , res )
492493 log ['T' ] = res
493494 if log :
@@ -890,7 +891,7 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_
890891 p = None , loss_fun = 'square_loss' , max_iter = 100 , tol = 1e-9 ,
891892 verbose = False , log = True , init_C = None , init_X = None ):
892893 """
893- Compute the fgw barycenter as presented eq (5) in [3 ].
894+ Compute the fgw barycenter as presented eq (5) in [24 ].
894895 ----------
895896 N : integer
896897 Desired number of samples of the target barycenter
@@ -1065,7 +1066,7 @@ def update_sructure_matrix(p, lambdas, T, Cs):
10651066
10661067def update_feature_matrix (lambdas , Ys , Ts , p ):
10671068 """
1068- Updates the feature with respect to the S Ts couplings. See "Solving the barycenter problem with Block Coordinate Descent (BCD)" in [3 ]
1069+ Updates the feature with respect to the S Ts couplings. See "Solving the barycenter problem with Block Coordinate Descent (BCD)" in [24 ]
10691070 calculated at each iteration
10701071 Parameters
10711072 ----------
0 commit comments