Skip to content

Commit 2a32e2e

Browse files
committed
fix log bug in gromov_wasserstein2
1 parent 65ca6bf commit 2a32e2e

File tree

2 files changed

+81
-79
lines changed

2 files changed

+81
-79
lines changed

ot/gromov.py

Lines changed: 77 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,6 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs
276276
- p : distribution in the source space
277277
- q : distribution in the target space
278278
- L : loss function to account for the misfit between the similarity matrices
279-
- H : entropy
280279
281280
Parameters
282281
----------
@@ -343,6 +342,83 @@ def df(G):
343342
return cg(p, q, 0, 1, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
344343

345344

345+
def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs):
346+
"""
347+
Returns the gromov-wasserstein discrepancy between (C1,p) and (C2,q)
348+
349+
The function solves the following optimization problem:
350+
351+
.. math::
352+
GW = \min_T \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}
353+
354+
Where :
355+
- C1 : Metric cost matrix in the source space
356+
- C2 : Metric cost matrix in the target space
357+
- p : distribution in the source space
358+
- q : distribution in the target space
359+
- L : loss function to account for the misfit between the similarity matrices
360+
361+
Parameters
362+
----------
363+
C1 : ndarray, shape (ns, ns)
364+
Metric cost matrix in the source space
365+
C2 : ndarray, shape (nt, nt)
366+
Metric cost matrix in the target space
367+
p : ndarray, shape (ns,)
368+
Distribution in the source space.
369+
q : ndarray, shape (nt,)
370+
Distribution in the target space.
371+
loss_fun : str
372+
loss function used for the solver either 'square_loss' or 'kl_loss'
373+
max_iter : int, optional
374+
Max number of iterations
375+
tol : float, optional
376+
Stop threshold on error (>0)
377+
verbose : bool, optional
378+
Print information along iterations
379+
log : bool, optional
380+
record log if True
381+
armijo : bool, optional
382+
If True the steps of the line-search is found via an armijo research. Else closed form is used.
383+
If there is convergence issues use False.
384+
385+
Returns
386+
-------
387+
gw_dist : float
388+
Gromov-Wasserstein distance
389+
log : dict
390+
convergence information and Coupling marix
391+
392+
References
393+
----------
394+
.. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
395+
"Gromov-Wasserstein averaging of kernel and distance matrices."
396+
International Conference on Machine Learning (ICML). 2016.
397+
398+
.. [13] Mémoli, Facundo. Gromov–Wasserstein distances and the
399+
metric approach to object matching. Foundations of computational
400+
mathematics 11.4 (2011): 417-487.
401+
402+
"""
403+
404+
constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
405+
406+
G0 = p[:, None] * q[None, :]
407+
408+
def f(G):
409+
return gwloss(constC, hC1, hC2, G)
410+
411+
def df(G):
412+
return gwggrad(constC, hC1, hC2, G)
413+
res, log_gw = cg(p, q, 0, 1, f, df, G0, log=True, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
414+
log_gw['gw_dist'] = gwloss(constC, hC1, hC2, res)
415+
log_gw['T'] = res
416+
if log:
417+
return log_gw['gw_dist'], log_gw
418+
else:
419+
return log_gw['gw_dist']
420+
421+
346422
def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs):
347423
"""
348424
Computes the FGW transport between two graphs see [24]
@@ -506,84 +582,6 @@ def df(G):
506582
return log['fgw_dist']
507583

508584

509-
def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs):
510-
"""
511-
Returns the gromov-wasserstein discrepancy between (C1,p) and (C2,q)
512-
513-
The function solves the following optimization problem:
514-
515-
.. math::
516-
GW = \min_T \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}
517-
518-
Where :
519-
- C1 : Metric cost matrix in the source space
520-
- C2 : Metric cost matrix in the target space
521-
- p : distribution in the source space
522-
- q : distribution in the target space
523-
- L : loss function to account for the misfit between the similarity matrices
524-
- H : entropy
525-
526-
Parameters
527-
----------
528-
C1 : ndarray, shape (ns, ns)
529-
Metric cost matrix in the source space
530-
C2 : ndarray, shape (nt, nt)
531-
Metric cost matrix in the target space
532-
p : ndarray, shape (ns,)
533-
Distribution in the source space.
534-
q : ndarray, shape (nt,)
535-
Distribution in the target space.
536-
loss_fun : str
537-
loss function used for the solver either 'square_loss' or 'kl_loss'
538-
max_iter : int, optional
539-
Max number of iterations
540-
tol : float, optional
541-
Stop threshold on error (>0)
542-
verbose : bool, optional
543-
Print information along iterations
544-
log : bool, optional
545-
record log if True
546-
armijo : bool, optional
547-
If True the steps of the line-search is found via an armijo research. Else closed form is used.
548-
If there is convergence issues use False.
549-
550-
Returns
551-
-------
552-
gw_dist : float
553-
Gromov-Wasserstein distance
554-
log : dict
555-
convergence information and Coupling marix
556-
557-
References
558-
----------
559-
.. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
560-
"Gromov-Wasserstein averaging of kernel and distance matrices."
561-
International Conference on Machine Learning (ICML). 2016.
562-
563-
.. [13] Mémoli, Facundo. Gromov–Wasserstein distances and the
564-
metric approach to object matching. Foundations of computational
565-
mathematics 11.4 (2011): 417-487.
566-
567-
"""
568-
569-
constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
570-
571-
G0 = p[:, None] * q[None, :]
572-
573-
def f(G):
574-
return gwloss(constC, hC1, hC2, G)
575-
576-
def df(G):
577-
return gwggrad(constC, hC1, hC2, G)
578-
res, log = cg(p, q, 0, 1, f, df, G0, log=True, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
579-
log['gw_dist'] = gwloss(constC, hC1, hC2, res)
580-
log['T'] = res
581-
if log:
582-
return log['gw_dist'], log
583-
else:
584-
return log['gw_dist']
585-
586-
587585
def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon,
588586
max_iter=1000, tol=1e-9, verbose=False, log=False):
589587
"""

test/test_gromov.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,14 @@ def test_gromov():
4444

4545
gw, log = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', log=True)
4646

47+
gw_val = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', log=False)
48+
4749
G = log['T']
4850

4951
np.testing.assert_allclose(gw, 0, atol=1e-1, rtol=1e-1)
5052

53+
np.testing.assert_allclose(gw, gw_val, atol=1e-1, rtol=1e-1) # cf log=False
54+
5155
# check constratints
5256
np.testing.assert_allclose(
5357
p, G.sum(1), atol=1e-04) # cf convergence gromov

0 commit comments

Comments
 (0)