Skip to content

Commit fa06bb3

Browse files
authored
Merge pull request #133 from little-nem/fgw_fix
[MRG] Fix Fused Gromov Wasserstein implementation
2 parents 599154c + bb15cdd commit fa06bb3

File tree

1 file changed

+24
-24
lines changed

1 file changed

+24
-24
lines changed

ot/gromov.py

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -433,8 +433,7 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5,
433433
434434
where :
435435
- M is the (ns,nt) metric cost matrix
436-
- :math:`f` is the regularization term ( and df is its gradient)
437-
- a and b are source and target weights (sum to 1)
436+
- p and q are source and target weights (sum to 1)
438437
- L is a loss function to account for the misfit between the similarity matrices
439438
440439
The algorithm used for solving the problem is conditional gradient as discussed in [24]_
@@ -453,17 +452,13 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5,
453452
Distribution in the target space
454453
loss_fun : str, optional
455454
Loss function used for the solver
456-
max_iter : int, optional
457-
Max number of iterations
458-
tol : float, optional
459-
Stop threshold on error (>0)
460-
verbose : bool, optional
461-
Print information along iterations
462-
log : bool, optional
463-
record log if True
455+
alpha : float, optional
456+
Trade-off parameter (0 < alpha < 1)
464457
armijo : bool, optional
465458
If True, the step of the line search is found via an Armijo search. Otherwise the closed form is used.
466459
If there are convergence issues, use False.
460+
log : bool, optional
461+
record log if True
467462
**kwargs : dict
468463
parameters can be directly passed to the ot.optim.cg solver
469464
@@ -493,11 +488,11 @@ def df(G):
493488
return gwggrad(constC, hC1, hC2, G)
494489

495490
if log:
496-
res, log = cg(p, q, M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs)
491+
res, log = cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs)
497492
log['fgw_dist'] = log['loss'][::-1][0]
498493
return res, log
499494
else:
500-
return cg(p, q, M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
495+
return cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
501496

502497

503498
def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs):
@@ -515,8 +510,7 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5
515510
516511
where :
517512
- M is the (ns,nt) metric cost matrix
518-
- :math:`f` is the regularization term ( and df is its gradient)
519-
- a and b are source and target weights (sum to 1)
513+
- p and q are source and target weights (sum to 1)
520514
- L is a loss function to account for the misfit between the similarity matrices
521515
The algorithm used for solving the problem is conditional gradient as discussed in [1]_
522516
@@ -534,17 +528,13 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5
534528
Distribution in the target space.
535529
loss_fun : str, optional
536530
Loss function used for the solver.
537-
max_iter : int, optional
538-
Max number of iterations
539-
tol : float, optional
540-
Stop threshold on error (>0)
541-
verbose : bool, optional
542-
Print information along iterations
543-
log : bool, optional
544-
Record log if True.
531+
alpha : float, optional
532+
Trade-off parameter (0 < alpha < 1)
545533
armijo : bool, optional
546534
If True, the step of the line search is found via an Armijo search.
547535
Otherwise the closed form is used. If there are convergence issues, use False.
536+
log : bool, optional
537+
Record log if True.
548538
**kwargs : dict
549539
Parameters can be directly passed to the ot.optim.cg solver.
550540
@@ -573,7 +563,7 @@ def f(G):
573563
def df(G):
574564
return gwggrad(constC, hC1, hC2, G)
575565

576-
res, log = cg(p, q, M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs)
566+
res, log = cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs)
577567
if log:
578568
log['fgw_dist'] = log['loss'][::-1][0]
579569
log['T'] = res
@@ -994,6 +984,16 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_
994984
Whether to fix the structure of the barycenter during the updates
995985
fixed_features : bool
996986
Whether to fix the feature of the barycenter during the updates
987+
loss_fun : str
988+
Loss function used for the solver either 'square_loss' or 'kl_loss'
989+
max_iter : int, optional
990+
Max number of iterations
991+
tol : float, optional
992+
Stop threshold on error (>0).
993+
verbose : bool, optional
994+
Print information along iterations.
995+
log : bool, optional
996+
Record log if True.
997997
init_C : ndarray, shape (N,N), optional
998998
Initialization for the barycenters' structure matrix. If not set
999999
a random init is used.
@@ -1082,7 +1082,7 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_
10821082
T_temp = [t.T for t in T]
10831083
C = update_sructure_matrix(p, lambdas, T_temp, Cs)
10841084

1085-
T = [fused_gromov_wasserstein((1 - alpha) * Ms[s], C, Cs[s], p, ps[s], loss_fun, alpha,
1085+
T = [fused_gromov_wasserstein(Ms[s], C, Cs[s], p, ps[s], loss_fun, alpha,
10861086
numItermax=max_iter, stopThr=1e-5, verbose=verbose) for s in range(S)]
10871087

10881088
# T is N,ns

0 commit comments

Comments
 (0)