1515import warnings
1616
1717from ..bregman import sinkhorn
18- from ..utils import dist , list_to_array , check_random_state , unif
18+ from ..utils import dist , UndefinedParameter , list_to_array , check_random_state , unif
1919from ..backend import get_backend
2020
2121from ._utils import init_matrix , gwloss , gwggrad
@@ -345,8 +345,9 @@ def entropic_gromov_wasserstein2(
345345
346346def entropic_gromov_barycenters (
347347 N , Cs , ps = None , p = None , lambdas = None , loss_fun = 'square_loss' ,
348- epsilon = 0.1 , symmetric = True , max_iter = 1000 , tol = 1e-9 , warmstartT = False ,
349- verbose = False , log = False , init_C = None , random_state = None , ** kwargs ):
348+ epsilon = 0.1 , symmetric = True , max_iter = 1000 , tol = 1e-9 ,
349+ stop_criterion = 'barycenter' , warmstartT = False , verbose = False ,
350+ log = False , init_C = None , random_state = None , ** kwargs ):
350351 r"""
351352 Returns the Gromov-Wasserstein barycenters of `S` measured similarity matrices :math:`(\mathbf{C}_s)_{1 \leq s \leq S}`
352353 estimated using Gromov-Wasserstein transports from Sinkhorn projections.
@@ -388,6 +389,10 @@ def entropic_gromov_barycenters(
388389 Max number of iterations
389390 tol : float, optional
390391 Stop threshold on error (>0)
392+ stop_criterion : str, optional. Default is 'barycenter'.
393+ Convergence criterion taking values in ['barycenter', 'loss']. If set to 'barycenter'
394+ uses absolute norm variations of estimated barycenters. Else if set to 'loss'
395+ uses the relative variations of the loss.
391396 warmstartT: bool, optional
392397 Either to perform warmstart of transport plans in the successive
393398 gromov-wasserstein transport problems.
@@ -407,7 +412,11 @@ def entropic_gromov_barycenters(
407412 C : array-like, shape (`N`, `N`)
408413 Similarity matrix in the barycenter space (permutated arbitrarily)
409414 log : dict
410- Log dictionary of error during iterations. Return only if `log=True` in parameters.
415+ Only returned when log=True. It contains the keys:
416+
417+ - :math:`\mathbf{T}`: list of (`N`, `ns`) transport matrices
418+ - :math:`\mathbf{p}`: (`N`,) barycenter weights
419+ - values used in convergence evaluation.
411420
412421 References
413422 ----------
@@ -418,6 +427,9 @@ def entropic_gromov_barycenters(
418427 if loss_fun not in ('square_loss' , 'kl_loss' ):
419428 raise ValueError (f"Unknown `loss_fun='{ loss_fun } '`. Use one of: { 'square_loss' , 'kl_loss' } ." )
420429
430+ if stop_criterion not in ['barycenter' , 'loss' ]:
431+ raise ValueError (f"Unknown `stop_criterion='{ stop_criterion } '`. Use one of: { 'barycenter' , 'loss' } ." )
432+
421433 Cs = list_to_array (* Cs )
422434 arr = [* Cs ]
423435 if ps is not None :
@@ -446,45 +458,75 @@ def entropic_gromov_barycenters(
446458 C = init_C
447459
448460 cpt = 0
449- err = 1
450-
451- error = []
461+ err = 1e15 # either the error on 'barycenter' or 'loss'
452462
453463 if warmstartT :
454464 T = [None ] * S
455465
466+ if stop_criterion == 'barycenter' :
467+ inner_log = False
468+ else :
469+ inner_log = True
470+ curr_loss = 1e15
471+
472+ if log :
473+ log_ = {}
474+ log_ ['err' ] = []
475+ if stop_criterion == 'loss' :
476+ log_ ['loss' ] = []
477+
456478 while (err > tol ) and (cpt < max_iter ):
457- Cprev = C
479+ if stop_criterion == 'barycenter' :
480+ Cprev = C
481+ else :
482+ prev_loss = curr_loss
483+
484+ # get transport plans
458485 if warmstartT :
459- T = [entropic_gromov_wasserstein (
486+ res = [entropic_gromov_wasserstein (
460487 C , Cs [s ], p , ps [s ], loss_fun , epsilon , symmetric , T [s ],
461- max_iter , 1e-4 , verbose = verbose , log = False , ** kwargs ) for s in range (S )]
488+ max_iter , 1e-4 , verbose = verbose , log = inner_log , ** kwargs ) for s in range (S )]
462489 else :
463- T = [entropic_gromov_wasserstein (
490+ res = [entropic_gromov_wasserstein (
464491 C , Cs [s ], p , ps [s ], loss_fun , epsilon , symmetric , None ,
465- max_iter , 1e-4 , verbose = verbose , log = False , ** kwargs ) for s in range (S )]
492+ max_iter , 1e-4 , verbose = verbose , log = inner_log , ** kwargs ) for s in range (S )]
493+ if stop_criterion == 'barycenter' :
494+ T = res
495+ else :
496+ T = [output [0 ] for output in res ]
497+ curr_loss = np .sum ([output [1 ]['gw_dist' ] for output in res ])
466498
499+ # update barycenters
467500 if loss_fun == 'square_loss' :
468501 C = update_square_loss (p , lambdas , T , Cs , nx )
469502 elif loss_fun == 'kl_loss' :
470503 C = update_kl_loss (p , lambdas , T , Cs , nx )
471504
472- if cpt % 10 == 0 :
473- # we can speed up the process by checking for the error only all
474- # the 10th iterations
505+ # update convergence criterion
506+ if stop_criterion == 'barycenter' :
475507 err = nx .norm (C - Cprev )
476- error .append (err )
508+ if log :
509+ log_ ['err' ].append (err )
477510
478- if verbose :
479- if cpt % 200 == 0 :
480- print ('{:5s}|{:12s}' .format (
481- 'It.' , 'Err' ) + '\n ' + '-' * 19 )
482- print ('{:5d}|{:8e}|' .format (cpt , err ))
511+ else :
512+ err = abs (curr_loss - prev_loss ) / prev_loss if prev_loss != 0. else np .nan
513+ if log :
514+ log_ ['loss' ].append (curr_loss )
515+ log_ ['err' ].append (err )
516+
517+ if verbose :
518+ if cpt % 200 == 0 :
519+ print ('{:5s}|{:12s}' .format (
520+ 'It.' , 'Err' ) + '\n ' + '-' * 19 )
521+ print ('{:5d}|{:8e}|' .format (cpt , err ))
483522
484523 cpt += 1
485524
486525 if log :
487- return C , {"err" : error }
526+ log_ ['T' ] = T
527+ log_ ['p' ] = p
528+
529+ return C , log_
488530 else :
489531 return C
490532
@@ -838,8 +880,9 @@ def entropic_fused_gromov_wasserstein2(
838880def entropic_fused_gromov_barycenters (
839881 N , Ys , Cs , ps = None , p = None , lambdas = None , loss_fun = 'square_loss' ,
840882 epsilon = 0.1 , symmetric = True , alpha = 0.5 , max_iter = 1000 , tol = 1e-9 ,
841- warmstartT = False , verbose = False , log = False , init_C = None , init_Y = None ,
842- random_state = None , ** kwargs ):
883+ stop_criterion = 'barycenter' , warmstartT = False , verbose = False ,
884+ log = False , init_C = None , init_Y = None , fixed_structure = False ,
885+ fixed_features = False , random_state = None , ** kwargs ):
843886 r"""
844887 Returns the Fused Gromov-Wasserstein barycenters of `S` measurable networks with node features :math:`(\mathbf{C}_s, \mathbf{Y}_s, \mathbf{p}_s)_{1 \leq s \leq S}`
845888 estimated using Fused Gromov-Wasserstein transports from Sinkhorn projections.
@@ -886,6 +929,10 @@ def entropic_fused_gromov_barycenters(
886929 Max number of iterations
887930 tol : float, optional
888931 Stop threshold on error (>0)
932+ stop_criterion : str, optional. Default is 'barycenter'.
933+ Stop criterion taking values in ['barycenter', 'loss']. If set to 'barycenter'
934+ uses absolute norm variations of estimated barycenters. Else if set to 'loss'
935+ uses the relative variations of the loss.
889936 warmstartT: bool, optional
890937 Either to perform warmstart of transport plans in the successive
891938 fused gromov-wasserstein transport problems.
@@ -898,6 +945,10 @@ def entropic_fused_gromov_barycenters(
898945 init_Y : array-like, shape (N,d), optional
899946 Initialization for the barycenters' features. If not set a
900947 random init is used.
948+ fixed_structure : bool, optional
949+ Whether to fix the structure of the barycenter during the updates.
950+ fixed_features : bool, optional
951+ Whether to fix the feature of the barycenter during the updates
901952 random_state : int or RandomState instance, optional
902953 Fix the seed for reproducibility
903954 **kwargs: dict
@@ -910,7 +961,12 @@ def entropic_fused_gromov_barycenters(
910961 C : array-like, shape (`N`, `N`)
911962 Similarity matrix in the barycenter space (permutated as Y's rows)
912963 log : dict
913- Log dictionary of error during iterations. Return only if `log=True` in parameters.
964+ Only returned when log=True. It contains the keys:
965+
966+ - :math:`\mathbf{T}`: list of (`N`, `ns`) transport matrices
967+ - :math:`\mathbf{p}`: (`N`,) barycenter weights
968+ - :math:`(\mathbf{M}_s)_s`: all distance matrices between the feature of the barycenter and the other features :math:`(dist(\mathbf{X}, \mathbf{Y}_s))_s` shape (`N`, `ns`)
969+ - values used in convergence evaluation.
914970
915971 References
916972 ----------
@@ -926,6 +982,9 @@ def entropic_fused_gromov_barycenters(
926982 if loss_fun not in ('square_loss' , 'kl_loss' ):
927983 raise ValueError (f"Unknown `loss_fun='{ loss_fun } '`. Use one of: { 'square_loss' , 'kl_loss' } ." )
928984
985+ if stop_criterion not in ['barycenter' , 'loss' ]:
986+ raise ValueError (f"Unknown `stop_criterion='{ stop_criterion } '`. Use one of: { 'barycenter' , 'loss' } ." )
987+
929988 Cs = list_to_array (* Cs )
930989 Ys = list_to_array (* Ys )
931990 arr = [* Cs , * Ys ]
@@ -945,67 +1004,108 @@ def entropic_fused_gromov_barycenters(
9451004
9461005 d = Ys [0 ].shape [1 ] # dimension on the node features
9471006
948- # Initialization of C : random SPD matrix (if not provided by user)
949- if init_C is None :
950- generator = check_random_state (random_state )
951- xalea = generator .randn (N , 2 )
952- C = dist (xalea , xalea )
953- C /= C .max ()
954- C = nx .from_numpy (C , type_as = p )
1007+ # Initialization of C : random euclidean distance matrix (if not provided by user)
1008+ if fixed_structure :
1009+ if init_C is None :
1010+ raise UndefinedParameter ('If C is fixed it must be initialized' )
1011+ else :
1012+ C = init_C
9551013 else :
956- C = init_C
1014+ if init_C is None :
1015+ generator = check_random_state (random_state )
1016+ xalea = generator .randn (N , 2 )
1017+ C = dist (xalea , xalea )
1018+ C = nx .from_numpy (C , type_as = ps [0 ])
1019+ else :
1020+ C = init_C
9571021
9581022 # Initialization of Y
959- if init_Y is None :
960- Y = nx .zeros ((N , d ), type_as = ps [0 ])
1023+ if fixed_features :
1024+ if init_Y is None :
1025+ raise UndefinedParameter ('If Y is fixed it must be initialized' )
1026+ else :
1027+ Y = init_Y
9611028 else :
962- Y = init_Y
1029+ if init_Y is None :
1030+ Y = nx .zeros ((N , d ), type_as = ps [0 ])
9631031
964- if warmstartT :
965- T = [ None ] * S
1032+ else :
1033+ Y = init_Y
9661034
9671035 Ms = [dist (Y , Ys [s ]) for s in range (len (Ys ))]
9681036
1037+ if warmstartT :
1038+ T = [None ] * S
1039+
9691040 cpt = 0
970- err = 1
9711041
972- err_feature = 1
973- err_structure = 1
1042+ if stop_criterion == 'barycenter' :
1043+ inner_log = False
1044+ err_feature = 1e15
1045+ err_structure = 1e15
1046+ err_rel_loss = 0.
1047+
1048+ else :
1049+ inner_log = True
1050+ err_feature = 0.
1051+ err_structure = 0.
1052+ curr_loss = 1e15
1053+ err_rel_loss = 1e15
9741054
9751055 if log :
9761056 log_ = {}
977- log_ ['err_feature' ] = []
978- log_ ['err_structure' ] = []
979- log_ ['Ts_iter' ] = []
1057+ if stop_criterion == 'barycenter' :
1058+ log_ ['err_feature' ] = []
1059+ log_ ['err_structure' ] = []
1060+ log_ ['Ts_iter' ] = []
1061+ else :
1062+ log_ ['loss' ] = []
1063+ log_ ['err_rel_loss' ] = []
9801064
981- while (err > tol ) and (cpt < max_iter ):
982- Cprev = C
983- Yprev = Y
1065+ while ((err_feature > tol or err_structure > tol or err_rel_loss > tol ) and cpt < max_iter ):
1066+ if stop_criterion == 'barycenter' :
1067+ Cprev = C
1068+ Yprev = Y
1069+ else :
1070+ prev_loss = curr_loss
9841071
1072+ # get transport plans
9851073 if warmstartT :
986- T = [entropic_fused_gromov_wasserstein (
1074+ res = [entropic_fused_gromov_wasserstein (
9871075 Ms [s ], C , Cs [s ], p , ps [s ], loss_fun , epsilon , symmetric , alpha ,
988- T [s ], max_iter , 1e-4 , verbose = verbose , log = False , ** kwargs ) for s in range (S )]
1076+ T [s ], max_iter , 1e-4 , verbose = verbose , log = inner_log , ** kwargs ) for s in range (S )]
9891077
9901078 else :
991- T = [entropic_fused_gromov_wasserstein (
1079+ res = [entropic_fused_gromov_wasserstein (
9921080 Ms [s ], C , Cs [s ], p , ps [s ], loss_fun , epsilon , symmetric , alpha ,
993- None , max_iter , 1e-4 , verbose = verbose , log = False , ** kwargs ) for s in range (S )]
1081+ None , max_iter , 1e-4 , verbose = verbose , log = inner_log , ** kwargs ) for s in range (S )]
9941082
995- if loss_fun == 'square_loss' :
996- C = update_square_loss (p , lambdas , T , Cs , nx )
997- elif loss_fun == 'kl_loss' :
998- C = update_kl_loss (p , lambdas , T , Cs , nx )
999-
1000- Ys_temp = [y .T for y in Ys ]
1001- Y = update_feature_matrix (lambdas , Ys_temp , T , p , nx ).T
1002- Ms = [dist (Y , Ys [s ]) for s in range (len (Ys ))]
1003-
1004- if cpt % 10 == 0 :
1005- # we can speed up the process by checking for the error only all
1006- # the 10th iterations
1007- err_feature = nx .norm (Y - nx .reshape (Yprev , (N , d )))
1008- err_structure = nx .norm (C - Cprev )
1083+ if stop_criterion == 'barycenter' :
1084+ T = res
1085+ else :
1086+ T = [output [0 ] for output in res ]
1087+ curr_loss = np .sum ([output [1 ]['fgw_dist' ] for output in res ])
1088+
1089+ # update barycenters
1090+ if not fixed_features :
1091+ Ys_temp = [y .T for y in Ys ]
1092+ X = update_feature_matrix (lambdas , Ys_temp , T , p , nx ).T
1093+ Ms = [dist (X , Ys [s ]) for s in range (len (Ys ))]
1094+
1095+ if not fixed_structure :
1096+ if loss_fun == 'square_loss' :
1097+ C = update_square_loss (p , lambdas , T , Cs , nx )
1098+
1099+ elif loss_fun == 'kl_loss' :
1100+ C = update_kl_loss (p , lambdas , T , Cs , nx )
1101+
1102+ # update convergence criterion
1103+ if stop_criterion == 'barycenter' :
1104+ err_feature , err_structure = 0. , 0.
1105+ if not fixed_features :
1106+ err_feature = nx .norm (Y - Yprev )
1107+ if not fixed_structure :
1108+ err_structure = nx .norm (C - Cprev )
10091109 if log :
10101110 log_ ['err_feature' ].append (err_feature )
10111111 log_ ['err_structure' ].append (err_structure )
@@ -1017,14 +1117,25 @@ def entropic_fused_gromov_barycenters(
10171117 'It.' , 'Err' ) + '\n ' + '-' * 19 )
10181118 print ('{:5d}|{:8e}|' .format (cpt , err_structure ))
10191119 print ('{:5d}|{:8e}|' .format (cpt , err_feature ))
1120+ else :
1121+ err_rel_loss = abs (curr_loss - prev_loss ) / prev_loss if prev_loss != 0. else np .nan
1122+ if log :
1123+ log_ ['loss' ].append (curr_loss )
1124+ log_ ['err_rel_loss' ].append (err_rel_loss )
1125+
1126+ if verbose :
1127+ if cpt % 200 == 0 :
1128+ print ('{:5s}|{:12s}' .format (
1129+ 'It.' , 'Err' ) + '\n ' + '-' * 19 )
1130+ print ('{:5d}|{:8e}|' .format (cpt , err_rel_loss ))
10201131
10211132 cpt += 1
1133+
10221134 if log :
1023- log_ ['T' ] = T # from target to Ys
1135+ log_ ['T' ] = T
10241136 log_ ['p' ] = p
10251137 log_ ['Ms' ] = Ms
10261138
1027- if log :
10281139 return Y , C , log_
10291140 else :
10301141 return Y , C
0 commit comments