@@ -276,7 +276,6 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs
276276 - p : distribution in the source space
277277 - q : distribution in the target space
278278 - L : loss function to account for the misfit between the similarity matrices
279- - H : entropy
280279
281280 Parameters
282281 ----------
@@ -343,6 +342,83 @@ def df(G):
343342 return cg (p , q , 0 , 1 , f , df , G0 , armijo = armijo , C1 = C1 , C2 = C2 , constC = constC , ** kwargs )
344343
345344
345+ def gromov_wasserstein2 (C1 , C2 , p , q , loss_fun , log = False , armijo = False , ** kwargs ):
346+ """
347+ Returns the gromov-wasserstein discrepancy between (C1,p) and (C2,q)
348+
349+ The function solves the following optimization problem:
350+
351+ .. math::
352+ GW = \min_T \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}
353+
354+ Where :
355+ - C1 : Metric cost matrix in the source space
356+ - C2 : Metric cost matrix in the target space
357+ - p : distribution in the source space
358+ - q : distribution in the target space
359+ - L : loss function to account for the misfit between the similarity matrices
360+
361+ Parameters
362+ ----------
363+ C1 : ndarray, shape (ns, ns)
364+ Metric cost matrix in the source space
365+ C2 : ndarray, shape (nt, nt)
366+ Metric cost matrix in the target space
367+ p : ndarray, shape (ns,)
368+ Distribution in the source space.
369+ q : ndarray, shape (nt,)
370+ Distribution in the target space.
371+ loss_fun : str
372+ loss function used for the solver either 'square_loss' or 'kl_loss'
373+ max_iter : int, optional
374+ Max number of iterations
375+ tol : float, optional
376+ Stop threshold on error (>0)
377+ verbose : bool, optional
378+ Print information along iterations
379+ log : bool, optional
380+ record log if True
381+ armijo : bool, optional
382+ If True the steps of the line-search is found via an armijo research. Else closed form is used.
383+ If there is convergence issues use False.
384+
385+ Returns
386+ -------
387+ gw_dist : float
388+ Gromov-Wasserstein distance
389+ log : dict
390+ convergence information and Coupling marix
391+
392+ References
393+ ----------
394+ .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
395+ "Gromov-Wasserstein averaging of kernel and distance matrices."
396+ International Conference on Machine Learning (ICML). 2016.
397+
398+ .. [13] Mémoli, Facundo. Gromov–Wasserstein distances and the
399+ metric approach to object matching. Foundations of computational
400+ mathematics 11.4 (2011): 417-487.
401+
402+ """
403+
404+ constC , hC1 , hC2 = init_matrix (C1 , C2 , p , q , loss_fun )
405+
406+ G0 = p [:, None ] * q [None , :]
407+
408+ def f (G ):
409+ return gwloss (constC , hC1 , hC2 , G )
410+
411+ def df (G ):
412+ return gwggrad (constC , hC1 , hC2 , G )
413+ res , log_gw = cg (p , q , 0 , 1 , f , df , G0 , log = True , armijo = armijo , C1 = C1 , C2 = C2 , constC = constC , ** kwargs )
414+ log_gw ['gw_dist' ] = gwloss (constC , hC1 , hC2 , res )
415+ log_gw ['T' ] = res
416+ if log :
417+ return log_gw ['gw_dist' ], log_gw
418+ else :
419+ return log_gw ['gw_dist' ]
420+
421+
346422def fused_gromov_wasserstein (M , C1 , C2 , p , q , loss_fun = 'square_loss' , alpha = 0.5 , armijo = False , log = False , ** kwargs ):
347423 """
348424 Computes the FGW transport between two graphs see [24]
@@ -506,84 +582,6 @@ def df(G):
506582 return log ['fgw_dist' ]
507583
508584
509- def gromov_wasserstein2 (C1 , C2 , p , q , loss_fun , log = False , armijo = False , ** kwargs ):
510- """
511- Returns the gromov-wasserstein discrepancy between (C1,p) and (C2,q)
512-
513- The function solves the following optimization problem:
514-
515- .. math::
516- GW = \min_T \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}
517-
518- Where :
519- - C1 : Metric cost matrix in the source space
520- - C2 : Metric cost matrix in the target space
521- - p : distribution in the source space
522- - q : distribution in the target space
523- - L : loss function to account for the misfit between the similarity matrices
524- - H : entropy
525-
526- Parameters
527- ----------
528- C1 : ndarray, shape (ns, ns)
529- Metric cost matrix in the source space
530- C2 : ndarray, shape (nt, nt)
531- Metric cost matrix in the target space
532- p : ndarray, shape (ns,)
533- Distribution in the source space.
534- q : ndarray, shape (nt,)
535- Distribution in the target space.
536- loss_fun : str
537- loss function used for the solver either 'square_loss' or 'kl_loss'
538- max_iter : int, optional
539- Max number of iterations
540- tol : float, optional
541- Stop threshold on error (>0)
542- verbose : bool, optional
543- Print information along iterations
544- log : bool, optional
545- record log if True
546- armijo : bool, optional
547- If True the steps of the line-search is found via an armijo research. Else closed form is used.
548- If there is convergence issues use False.
549-
550- Returns
551- -------
552- gw_dist : float
553- Gromov-Wasserstein distance
554- log : dict
555- convergence information and Coupling marix
556-
557- References
558- ----------
559- .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
560- "Gromov-Wasserstein averaging of kernel and distance matrices."
561- International Conference on Machine Learning (ICML). 2016.
562-
563- .. [13] Mémoli, Facundo. Gromov–Wasserstein distances and the
564- metric approach to object matching. Foundations of computational
565- mathematics 11.4 (2011): 417-487.
566-
567- """
568-
569- constC , hC1 , hC2 = init_matrix (C1 , C2 , p , q , loss_fun )
570-
571- G0 = p [:, None ] * q [None , :]
572-
573- def f (G ):
574- return gwloss (constC , hC1 , hC2 , G )
575-
576- def df (G ):
577- return gwggrad (constC , hC1 , hC2 , G )
578- res , log = cg (p , q , 0 , 1 , f , df , G0 , log = True , armijo = armijo , C1 = C1 , C2 = C2 , constC = constC , ** kwargs )
579- log ['gw_dist' ] = gwloss (constC , hC1 , hC2 , res )
580- log ['T' ] = res
581- if log :
582- return log ['gw_dist' ], log
583- else :
584- return log ['gw_dist' ]
585-
586-
587585def entropic_gromov_wasserstein (C1 , C2 , p , q , loss_fun , epsilon ,
588586 max_iter = 1000 , tol = 1e-9 , verbose = False , log = False ):
589587 """
0 commit comments