@@ -748,7 +748,7 @@ def OT_mapping_linear(xs, xt, reg=1e-6, ws=None,
748748 return A , b
749749
750750
751- def emd_laplace (a , b , xs , xt , M , reg , eta , alpha ,
751+ def emd_laplace (a , b , xs , xt , M , sim , sim_param , reg , eta , alpha ,
752752 numItermax , stopThr , numInnerItermax ,
753753 stopInnerThr , log = False , verbose = False , ** kwargs ):
754754 r"""Solve the optimal transport problem (OT) with Laplacian regularization
@@ -785,6 +785,11 @@ def emd_laplace(a, b, xs, xt, M, reg, eta, alpha,
785785 samples in the target domain
786786 M : np.ndarray (ns,nt)
787787 loss matrix
788+ sim : string, optional
789+ Type of similarity ('knn' or 'gauss') used to construct the Laplacian.
790+ sim_param : int or float, optional
791+ Parameter (number of the nearest neighbors for sim='knn'
792+ or bandwidth for sim='gauss' used to compute the Laplacian.
788793 reg : string
789794 Type of Laplacian regularization
790795 eta : float
@@ -803,11 +808,6 @@ def emd_laplace(a, b, xs, xt, M, reg, eta, alpha,
803808 Print information along iterations
804809 log : bool, optional
805810 record log if True
806- kwargs : dict
807- Dictionary with attributes 'sim' ('knn' or 'gauss') and
808- 'param' (int, float or None) for similarity type and its parameter to be used.
809- If 'param' is None, it is computed as mean pairwise Euclidean distance over the data set
810- or set to 3 when sim is 'gauss' or 'knn', respectively.
811811
812812 Returns
813813 -------
@@ -824,7 +824,7 @@ def emd_laplace(a, b, xs, xt, M, reg, eta, alpha,
824824 "Optimal Transport for Domain Adaptation," in IEEE
825825 Transactions on Pattern Analysis and Machine Intelligence ,
826826 vol.PP, no.99, pp.1-1
827- .. [28 ] R. Flamary, N. Courty, D. Tuia, A. Rakotomamonjy,
827+ .. [30 ] R. Flamary, N. Courty, D. Tuia, A. Rakotomamonjy,
828828 "Optimal transport with Laplacian regularization: Applications to domain adaptation and shape matching,"
829829 in NIPS Workshop on Optimal Transport and Machine Learning OTML, 2014.
830830
@@ -834,28 +834,28 @@ def emd_laplace(a, b, xs, xt, M, reg, eta, alpha,
834834 ot.optim.cg : General regularized OT
835835
836836 """
837- if not isinstance (kwargs [ 'param' ] , (int , float , type (None ))):
837+ if not isinstance (sim_param , (int , float , type (None ))):
838838 raise ValueError (
839- 'Similarity parameter should be an int or a float. Got {type} instead.' .format (type = type (kwargs [ 'param' ]) ))
839+ 'Similarity parameter should be an int or a float. Got {type} instead.' .format (type = type (sim_param ). __name__ ))
840840
841- if kwargs [ ' sim' ] == 'gauss' :
842- if kwargs [ 'param' ] is None :
843- kwargs [ 'param' ] = 1 / (2 * (np .mean (dist (xs , xs , 'sqeuclidean' )) ** 2 ))
844- sS = kernel (xs , xs , method = kwargs [ ' sim' ] , sigma = kwargs [ 'param' ] )
845- sT = kernel (xt , xt , method = kwargs [ ' sim' ] , sigma = kwargs [ 'param' ] )
841+ if sim == 'gauss' :
842+ if sim_param is None :
843+ sim_param = 1 / (2 * (np .mean (dist (xs , xs , 'sqeuclidean' )) ** 2 ))
844+ sS = kernel (xs , xs , method = sim , sigma = sim_param )
845+ sT = kernel (xt , xt , method = sim , sigma = sim_param )
846846
847- elif kwargs [ ' sim' ] == 'knn' :
848- if kwargs [ 'param' ] is None :
849- kwargs [ 'param' ] = 3
847+ elif sim == 'knn' :
848+ if sim_param is None :
849+ sim_param = 3
850850
851851 from sklearn .neighbors import kneighbors_graph
852852
853- sS = kneighbors_graph (X = xs , n_neighbors = int (kwargs [ 'param' ] )).toarray ()
853+ sS = kneighbors_graph (X = xs , n_neighbors = int (sim_param )).toarray ()
854854 sS = (sS + sS .T ) / 2
855- sT = kneighbors_graph (xt , n_neighbors = int (kwargs [ 'param' ] )).toarray ()
855+ sT = kneighbors_graph (xt , n_neighbors = int (sim_param )).toarray ()
856856 sT = (sT + sT .T ) / 2
857857 else :
858- raise ValueError ('Unknown similarity type {sim}. Currently supported similarity types are "knn" and "gauss".' .format (sim = kwargs [ ' sim' ] ))
858+ raise ValueError ('Unknown similarity type {sim}. Currently supported similarity types are "knn" and "gauss".' .format (sim = sim ))
859859
860860 lS = laplacian (sS )
861861 lT = laplacian (sT )
@@ -1729,9 +1729,10 @@ class EMDLaplaceTransport(BaseTransport):
17291729 can occur with large metric values.
17301730 similarity : string, optional (default="knn")
17311731 The similarity to use either knn or gaussian
1732- similarity_param : int or float, optional (default=3 )
1732+ similarity_param : int or float, optional (default=None )
17331733 Parameter for the similarity: number of nearest neighbors or bandwidth
1734- if similarity="knn" or "gaussian", respectively.
1734+ if similarity="knn" or "gaussian", respectively. If None is provided,
1735+ it is set to 3 or the average pairwise squared Euclidean distance, respectively.
17351736 max_iter : int, optional (default=100)
17361737 Max number of BCD iterations
17371738 tol : float, optional (default=1e-5)
@@ -1813,14 +1814,10 @@ class label
18131814
18141815 super (EMDLaplaceTransport , self ).fit (Xs , ys , Xt , yt )
18151816
1816- kwargs = dict ()
1817- kwargs ["sim" ] = self .similarity
1818- kwargs ["param" ] = self .sim_param
1819-
18201817 returned_ = emd_laplace (a = self .mu_s , b = self .mu_t , xs = self .xs_ ,
1821- xt = self .xt_ , M = self .cost_ , reg = self .reg , eta = self .reg_lap , alpha = self .reg_src ,
1818+ xt = self .xt_ , M = self .cost_ , sim = self . similarity , sim_param = self . sim_param , reg = self .reg , eta = self .reg_lap , alpha = self .reg_src ,
18221819 numItermax = self .max_iter , stopThr = self .tol , numInnerItermax = self .max_inner_iter ,
1823- stopInnerThr = self .inner_tol , log = self .log , verbose = self .verbose , ** kwargs )
1820+ stopInnerThr = self .inner_tol , log = self .log , verbose = self .verbose )
18241821
18251822 # coupling estimation
18261823 if self .log :
0 commit comments