 # Nicolas Courty <ncourty@irisa.fr>
 # Rémi Flamary <remi.flamary@unice.fr>
 # Titouan Vayer <titouan.vayer@irisa.fr>
+#
 # License: MIT License

 import numpy as np
@@ -351,9 +352,9 @@ def df(G):
     return cg(p, q, 0, 1, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)


-def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, **kwargs):
+def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs):
     """
-    Computes the FGW distance between two graphs see [3]
+    Computes the FGW transport between two graphs see [24]
     .. math::
         \gamma = arg\min_\gamma (1-\alpha)*<\gamma,M>_F + \alpha*\sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}
         s.t. \gamma 1 = p
@@ -377,7 +378,7 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5,
         distribution in the source space
     q : ndarray, shape (nt,)
         distribution in the target space
-    loss_fun : string,optionnal
+    loss_fun : string, optional
         loss function used for the solver
     max_iter : int, optional
         Max number of iterations
@@ -416,7 +417,86 @@ def f(G):
     def df(G):
         return gwggrad(constC, hC1, hC2, G)

-    return cg(p, q, M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
+    if log:
+        res, log = cg(p, q, M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs)
+        log['fgw_dist'] = log['loss'][::-1][0]
+        return res, log
+    else:
+        return cg(p, q, M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
+
+
+def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs):
+    """
+    Computes the FGW distance between two graphs, see [24]
+    .. math::
+        \gamma = arg\min_\gamma (1-\alpha)*<\gamma,M>_F + \alpha*\sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}
+        s.t. \gamma 1 = p
+             \gamma^T 1 = q
+             \gamma \geq 0
+    where :
+    - M is the (ns,nt) metric cost matrix
+    - :math:`f` is the regularization term (and df is its gradient)
+    - a and b are source and target weights (sum to 1)
+    - L is a loss function to account for the misfit between the similarity matrices
+    The algorithm used for solving the problem is conditional gradient as discussed in [24]_
+    Parameters
+    ----------
+    M : ndarray, shape (ns, nt)
+        Metric cost matrix between features across domains
+    C1 : ndarray, shape (ns, ns)
+        Metric cost matrix representative of the structure in the source space
+    C2 : ndarray, shape (nt, nt)
+        Metric cost matrix representative of the structure in the target space
+    p : ndarray, shape (ns,)
+        distribution in the source space
+    q : ndarray, shape (nt,)
+        distribution in the target space
+    loss_fun : string, optional
+        loss function used for the solver
+    max_iter : int, optional
+        Max number of iterations
+    tol : float, optional
+        Stop threshold on error (>0)
+    verbose : bool, optional
+        Print information along iterations
+    log : bool, optional
+        record log if True
+    armijo : bool, optional
+        If True, the step of the line-search is found via an Armijo line-search. Else closed form is used.
+        If there are convergence issues, use False.
+    **kwargs : dict
+        parameters can be directly passed to the ot.optim.cg solver
+    Returns
+    -------
+    fgw_dist : float
+        Fused Gromov-Wasserstein distance for the given parameters
+    log : dict
+        Log dictionary returned only if log==True in parameters
+    References
+    ----------
+    .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain
+        and Courty Nicolas
+        "Optimal Transport for structured data with application on graphs"
+        International Conference on Machine Learning (ICML). 2019.
+    """
+
+    constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
+
+    G0 = p[:, None] * q[None, :]
+
+    def f(G):
+        return gwloss(constC, hC1, hC2, G)
+
+    def df(G):
+        return gwggrad(constC, hC1, hC2, G)
+
+    # keep the solver log under a separate name so the boolean `log` flag is not shadowed
+    res, log_cg = cg(p, q, M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs)
+    log_cg['fgw_dist'] = log_cg['loss'][::-1][0]
+    if log:
+        log_cg['T'] = res
+        return log_cg['fgw_dist'], log_cg
+    else:
+        return log_cg['fgw_dist']


 def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs):
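
A minimal usage sketch of the new `log` option (outside the patch itself), assuming the two solvers are exposed as `ot.gromov.fused_gromov_wasserstein` and `ot.gromov.fused_gromov_wasserstein2` once this change lands; the toy sizes, random features and `alpha=0.5` below are illustration choices only:

```python
import numpy as np
import ot

rng = np.random.RandomState(0)
ns, nt, d = 5, 4, 2

# node features of the two graphs and the cross-domain feature cost M (ns x nt)
xs, xt = rng.randn(ns, d), rng.randn(nt, d)
M = ot.dist(xs, xt)

# intra-graph structure matrices and uniform node weights
C1, C2 = ot.dist(xs, xs), ot.dist(xt, xt)
p, q = ot.unif(ns), ot.unif(nt)

# coupling plus log: the final loss is stored under log['fgw_dist']
T, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, alpha=0.5, log=True)

# scalar FGW distance directly
fgw_dist = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, q, alpha=0.5)
```

With `log=False` (the default) the first call returns only the coupling, mirroring the branch added above.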
@@ -889,7 +969,7 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun,

 def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_features=False,
                     p=None, loss_fun='square_loss', max_iter=100, tol=1e-9,
-                    verbose=False, log=True, init_C=None, init_X=None):
+                    verbose=False, log=False, init_C=None, init_X=None):
     """
     Compute the fgw barycenter as presented eq (5) in [24].
     ----------
@@ -919,7 +999,8 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_
         Barycenters' features
     C : ndarray, shape (N,N)
         Barycenters' structure matrix
-    log_:
+    log_ : dict
+        Only returned when log=True
         T : list of (N,ns) transport matrices
         Ms : all distance matrices between the feature of the barycenter and the other features dist(X,Ys) shape (N,ns)
     References
@@ -1015,14 +1096,13 @@ class UndefinedParameter(Exception):
         T = [fused_gromov_wasserstein((1 - alpha) * Ms[s], C, Cs[s], p, ps[s], loss_fun, alpha, numItermax=max_iter, stopThr=1e-5, verbose=verbose) for s in range(S)]

         # T is N,ns
-
-        log_['Ts_iter'].append(T)
         err_feature = np.linalg.norm(X - Xprev.reshape(N, d))
         err_structure = np.linalg.norm(C - Cprev)

         if log:
             log_['err_feature'].append(err_feature)
             log_['err_structure'].append(err_structure)
+            log_['Ts_iter'].append(T)

         if verbose:
             if cpt % 200 == 0:
@@ -1032,11 +1112,15 @@ class UndefinedParameter(Exception):
                 print('{:5d}|{:8e}|'.format(cpt, err_feature))

         cpt += 1
-    log_['T'] = T  # from target to Ys
-    log_['p'] = p
-    log_['Ms'] = Ms  # Ms are N,ns
+    if log:
+        log_['T'] = T  # from target to Ys
+        log_['p'] = p
+        log_['Ms'] = Ms  # Ms are N,ns

-    return X, C, log_
+    if log:
+        return X, C, log_
+    else:
+        return X, C


 def update_sructure_matrix(p, lambdas, T, Cs):
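
A companion sketch (again not part of the patch, with made-up toy graphs and the assumed `ot.gromov` module path) showing the effect of the new `log=False` default on `fgw_barycenters`: without the flag only the barycenter features and structure come back, while `log=True` also returns the dictionary whose keys (`'err_feature'`, `'err_structure'`, `'Ts_iter'`, `'T'`, `'p'`, `'Ms'`) are populated in the loop above:

```python
import numpy as np
import ot

rng = np.random.RandomState(0)
N, d = 3, 2        # barycenter size and feature dimension
sizes = [4, 5]     # node counts of the two input graphs

# per-graph features, structure matrices and uniform node weights
Ys = [rng.randn(n, d) for n in sizes]
Cs = [ot.dist(Y, Y) for Y in Ys]
ps = [ot.unif(n) for n in sizes]
lambdas = [0.5, 0.5]
p = ot.unif(N)

# new default: only the barycenter features and structure are returned
X, C = ot.gromov.fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha=0.5, p=p)

# explicit log=True: the log dictionary comes back as a third value
X, C, log_ = ot.gromov.fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha=0.5, p=p, log=True)
print(log_['err_feature'][-1], len(log_['T']))
```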