Skip to content

Commit 1173353

Browse files
committed
fix fgw alpha parameter implementation
1 parent 0baf83b commit 1173353

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

ot/gromov.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -493,11 +493,11 @@ def df(G):
493493
return gwggrad(constC, hC1, hC2, G)
494494

495495
if log:
496-
res, log = cg(p, q, M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs)
496+
res, log = cg(p, q, (1-alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs)
497497
log['fgw_dist'] = log['loss'][::-1][0]
498498
return res, log
499499
else:
500-
return cg(p, q, M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
500+
return cg(p, q, (1-alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
501501

502502

503503
def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs):
@@ -573,7 +573,7 @@ def f(G):
573573
def df(G):
574574
return gwggrad(constC, hC1, hC2, G)
575575

576-
res, log = cg(p, q, M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs)
576+
res, log = cg(p, q, (1-alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs)
577577
if log:
578578
log['fgw_dist'] = log['loss'][::-1][0]
579579
log['T'] = res
@@ -1082,7 +1082,7 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_
10821082
T_temp = [t.T for t in T]
10831083
C = update_sructure_matrix(p, lambdas, T_temp, Cs)
10841084

1085-
T = [fused_gromov_wasserstein((1 - alpha) * Ms[s], C, Cs[s], p, ps[s], loss_fun, alpha,
1085+
T = [fused_gromov_wasserstein(Ms[s], C, Cs[s], p, ps[s], loss_fun, alpha,
10861086
numItermax=max_iter, stopThr=1e-5, verbose=verbose) for s in range(S)]
10871087

10881088
# T is N,ns

0 commit comments

Comments
 (0)