@@ -81,7 +81,7 @@ def entropic_gromov_wasserstein(
8181 q : array-like, shape (nt,), optional
8282 Distribution in the target space.
8383 If let to its default value None, uniform distribution is taken.
84- loss_fun : string, optional
84+ loss_fun : string, optional (default='square_loss')
8585 Loss function used for the solver either 'square_loss' or 'kl_loss'
8686 epsilon : float, optional
8787 Regularization term >0
@@ -92,8 +92,8 @@ def entropic_gromov_wasserstein(
9292 G0: array-like, shape (ns,nt), optional
9393 If None the initial transport plan of the solver is pq^T.
9494 Otherwise G0 will be used as initial transport of the solver. G0 is not
95- required to satisfy marginal constraints but we strongly recommand it
96- to correcly estimate the GW distance.
95+ required to satisfy marginal constraints but we strongly recommend it
96+ to correctly estimate the GW distance.
9797 max_iter : int, optional
9898 Max number of iterations
9999 tol : float, optional
@@ -135,6 +135,9 @@ def entropic_gromov_wasserstein(
135135 if solver not in ['PGD' , 'PPA' ]:
136136 raise ValueError ("Unknown solver '%s'. Pick one in ['PGD', 'PPA']." % solver )
137137
138+ if loss_fun not in ('square_loss' , 'kl_loss' ):
139+ raise ValueError (f"Unknown `loss_fun='{ loss_fun } '`. Use one of: { 'square_loss' , 'kl_loss' } ." )
140+
138141 C1 , C2 = list_to_array (C1 , C2 )
139142 arr = [C1 , C2 ]
140143 if p is not None :
@@ -280,7 +283,7 @@ def entropic_gromov_wasserstein2(
280283 q : array-like, shape (nt,), optional
281284 Distribution in the target space.
282285 If let to its default value None, uniform distribution is taken.
283- loss_fun : string, optional
286+ loss_fun : string, optional (default='square_loss')
284287 Loss function used for the solver either 'square_loss' or 'kl_loss'
285288 epsilon : float, optional
286289 Regularization term >0
@@ -373,8 +376,8 @@ def entropic_gromov_barycenters(
373376 lambdas : list of float, optional
374377 List of the `S` spaces' weights.
375378 If let to its default value None, uniform weights are taken.
376- loss_fun : callable , optional
377- tensor-matrix multiplication function based on specific loss function
379+ loss_fun : string , optional (default='square_loss')
380+ Loss function used for the solver either 'square_loss' or 'kl_loss'
378381 epsilon : float, optional
379382 Regularization term >0
380383 symmetric : bool, optional.
@@ -411,6 +414,9 @@ def entropic_gromov_barycenters(
411414 "Gromov-Wasserstein averaging of kernel and distance matrices."
412415 International Conference on Machine Learning (ICML). 2016.
413416 """
417+ if loss_fun not in ('square_loss' , 'kl_loss' ):
418+ raise ValueError (f"Unknown `loss_fun='{ loss_fun } '`. Use one of: { 'square_loss' , 'kl_loss' } ." )
419+
414420 Cs = list_to_array (* Cs )
415421 arr = [* Cs ]
416422 if ps is not None :
@@ -459,7 +465,6 @@ def entropic_gromov_barycenters(
459465
460466 if loss_fun == 'square_loss' :
461467 C = update_square_loss (p , lambdas , T , Cs )
462-
463468 elif loss_fun == 'kl_loss' :
464469 C = update_kl_loss (p , lambdas , T , Cs )
465470
@@ -550,21 +555,21 @@ def entropic_fused_gromov_wasserstein(
550555 q : array-like, shape (nt,), optional
551556 Distribution in the target space.
552557 If let to its default value None, uniform distribution is taken.
553- loss_fun : string, optional
558+ loss_fun : string, optional (default='square_loss')
554559 Loss function used for the solver either 'square_loss' or 'kl_loss'
555560 epsilon : float, optional
556561 Regularization term >0
557562 symmetric : bool, optional
558563 Either C1 and C2 are to be assumed symmetric or not.
559564 If let to its default None value, a symmetry test will be conducted.
560- Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymetric ).
565+ Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymmetric ).
561566 alpha : float, optional
562567 Trade-off parameter (0 < alpha < 1)
563568 G0: array-like, shape (ns,nt), optional
564569 If None the initial transport plan of the solver is pq^T.
565570 Otherwise G0 will be used as initial transport of the solver. G0 is not
566- required to satisfy marginal constraints but we strongly recommand it
567- to correcly estimate the GW distance.
571+ required to satisfy marginal constraints but we strongly recommend it
572+ to correctly estimate the GW distance.
568573 max_iter : int, optional
569574 Max number of iterations
570575 tol : float, optional
@@ -611,6 +616,9 @@ def entropic_fused_gromov_wasserstein(
611616 if solver not in ['PGD' , 'PPA' ]:
612617 raise ValueError ("Unknown solver '%s'. Pick one in ['PGD', 'PPA']." % solver )
613618
619+ if loss_fun not in ('square_loss' , 'kl_loss' ):
620+ raise ValueError (f"Unknown `loss_fun='{ loss_fun } '`. Use one of: { 'square_loss' , 'kl_loss' } ." )
621+
614622 M , C1 , C2 = list_to_array (M , C1 , C2 )
615623 arr = [M , C1 , C2 ]
616624 if p is not None :
@@ -762,7 +770,7 @@ def entropic_fused_gromov_wasserstein2(
762770 q : array-like, shape (nt,), optional
763771 Distribution in the target space.
764772 If let to its default value None, uniform distribution is taken.
765- loss_fun : string, optional
773+ loss_fun : string, optional (default='square_loss')
766774 Loss function used for the solver either 'square_loss' or 'kl_loss'
767775 epsilon : float, optional
768776 Regularization term >0
@@ -775,8 +783,8 @@ def entropic_fused_gromov_wasserstein2(
775783 G0: array-like, shape (ns,nt), optional
776784 If None the initial transport plan of the solver is pq^T.
777785 Otherwise G0 will be used as initial transport of the solver. G0 is not
778- required to satisfy marginal constraints but we strongly recommand it
779- to correcly estimate the GW distance.
786+ required to satisfy marginal constraints but we strongly recommend it
787+ to correctly estimate the GW distance.
780788 max_iter : int, optional
781789 Max number of iterations
782790 tol : float, optional
@@ -857,8 +865,8 @@ def entropic_fused_gromov_barycenters(
857865 lambdas : list of float, optional
858866 List of the `S` spaces' weights.
859867 If let to its default value None, uniform weights are taken.
860- loss_fun : callable , optional
861- tensor-matrix multiplication function based on specific loss function
868+ loss_fun : string , optional (default='square_loss')
869+ Loss function used for the solver either 'square_loss' or 'kl_loss'
862870 epsilon : float, optional
863871 Regularization term >0
864872 symmetric : bool, optional.
@@ -907,6 +915,9 @@ def entropic_fused_gromov_barycenters(
907915 "Optimal Transport for structured data with application on graphs"
908916 International Conference on Machine Learning (ICML). 2019.
909917 """
918+ if loss_fun not in ('square_loss' , 'kl_loss' ):
919+ raise ValueError (f"Unknown `loss_fun='{ loss_fun } '`. Use one of: { 'square_loss' , 'kl_loss' } ." )
920+
910921 Cs = list_to_array (* Cs )
911922 Ys = list_to_array (* Ys )
912923 arr = [* Cs , * Ys ]
@@ -977,7 +988,6 @@ def entropic_fused_gromov_barycenters(
977988
978989 if loss_fun == 'square_loss' :
979990 C = update_square_loss (p , lambdas , T , Cs )
980-
981991 elif loss_fun == 'kl_loss' :
982992 C = update_kl_loss (p , lambdas , T , Cs )
983993
0 commit comments