@@ -748,7 +748,7 @@ def OT_mapping_linear(xs, xt, reg=1e-6, ws=None,
748748 return A , b
749749
750750
751- def emd_laplace (a , b , xs , xt , M , sim , reg , eta , alpha ,
751+ def emd_laplace (a , b , xs , xt , M , reg , eta , alpha ,
752752 numItermax , stopThr , numInnerItermax ,
753753 stopInnerThr , log = False , verbose = False , ** kwargs ):
754754 r"""Solve the optimal transport problem (OT) with Laplacian regularization
@@ -803,7 +803,11 @@ def emd_laplace(a, b, xs, xt, M, sim, reg, eta, alpha,
803803 Print information along iterations
804804 log : bool, optional
805805 record log if True
806-
806+ kwargs : dict
807+ Dictionary with attributes 'sim' ('knn' or 'gauss') and
808+ 'param' (int, float or None) for similarity type and its parameter to be used.
809+ If 'param' is None, it is computed as mean pairwise Euclidean distance over the data set
810+ or set to 3 when sim is 'gauss' or 'knn', respectively.
807811
808812 Returns
809813 -------
@@ -830,24 +834,28 @@ def emd_laplace(a, b, xs, xt, M, sim, reg, eta, alpha,
830834 ot.optim.cg : General regularized OT
831835
832836 """
833- if sim == 'gauss' :
834- if 'rbfparam' not in kwargs :
835- kwargs ['rbfparam' ] = 1 / (2 * (np .mean (dist (xs , xs , 'sqeuclidean' )) ** 2 ))
836- sS = kernel (xs , xs , method = kwargs ['sim' ], sigma = kwargs ['rbfparam' ])
837- sT = kernel (xt , xt , method = kwargs ['sim' ], sigma = kwargs ['rbfparam' ])
837+ if not isinstance (kwargs ['param' ], (int , float , type (None ))):
838+ raise ValueError (
839+ 'Similarity parameter should be an int or a float. Got {type} instead.' .format (type = type (kwargs ['param' ])))
840+
841+ if kwargs ['sim' ] == 'gauss' :
842+ if kwargs ['param' ] is None :
843+ kwargs ['param' ] = 1 / (2 * (np .mean (dist (xs , xs , 'sqeuclidean' )) ** 2 ))
844+ sS = kernel (xs , xs , method = kwargs ['sim' ], sigma = kwargs ['param' ])
845+ sT = kernel (xt , xt , method = kwargs ['sim' ], sigma = kwargs ['param' ])
838846
839- elif sim == 'knn' :
840- if 'nn' not in kwargs :
841- kwargs ['nn ' ] = 5
847+ elif kwargs [ ' sim' ] == 'knn' :
848+ if kwargs [ 'param' ] is None :
849+ kwargs ['param ' ] = 3
842850
843851 from sklearn .neighbors import kneighbors_graph
844852
845- sS = kneighbors_graph (xs , kwargs ['nn' ] ).toarray ()
853+ sS = kneighbors_graph (X = xs , n_neighbors = int ( kwargs ['param' ]) ).toarray ()
846854 sS = (sS + sS .T ) / 2
847- sT = kneighbors_graph (xt , kwargs ['nn' ] ).toarray ()
855+ sT = kneighbors_graph (xt , n_neighbors = int ( kwargs ['param' ]) ).toarray ()
848856 sT = (sT + sT .T ) / 2
849857 else :
850- raise ValueError ('Unknown similarity type {sim}. Currently supported similarity types are "knn" and "gauss".' .format (sim = sim ))
858+ raise ValueError ('Unknown similarity type {sim}. Currently supported similarity types are "knn" and "gauss".' .format (sim = kwargs [ ' sim' ] ))
851859
852860 lS = laplacian (sS )
853861 lT = laplacian (sT )
@@ -1721,6 +1729,9 @@ class EMDLaplaceTransport(BaseTransport):
17211729 can occur with large metric values.
17221730 similarity : string, optional (default="knn")
17231731 The similarity to use either knn or gaussian
1732+ similarity_param : int or float, optional (default=3)
1733+ Parameter for the similarity: number of nearest neighbors or bandwidth
1734+ if similarity="knn" or "gaussian", respectively.
17241735 max_iter : int, optional (default=100)
17251736 Max number of BCD iterations
17261737 tol : float, optional (default=1e-5)
@@ -1754,7 +1765,7 @@ class EMDLaplaceTransport(BaseTransport):
17541765 """
17551766
17561767 def __init__ (self , reg_type = 'pos' , reg_lap = 1. , reg_src = 1. , alpha = 0.5 ,
1757- metric = "sqeuclidean" , norm = None , similarity = "knn" , max_iter = 100 , tol = 1e-9 ,
1768+ metric = "sqeuclidean" , norm = None , similarity = "knn" , similarity_param = None , max_iter = 100 , tol = 1e-9 ,
17581769 max_inner_iter = 100000 , inner_tol = 1e-9 , log = False , verbose = False ,
17591770 distribution_estimation = distribution_estimation_uniform ,
17601771 out_of_sample_map = 'ferradans' ):
@@ -1765,6 +1776,7 @@ def __init__(self, reg_type='pos', reg_lap=1., reg_src=1., alpha=0.5,
17651776 self .metric = metric
17661777 self .norm = norm
17671778 self .similarity = similarity
1779+ self .sim_param = similarity_param
17681780 self .max_iter = max_iter
17691781 self .tol = tol
17701782 self .max_inner_iter = max_inner_iter
@@ -1801,10 +1813,14 @@ class label
18011813
18021814 super (EMDLaplaceTransport , self ).fit (Xs , ys , Xt , yt )
18031815
1816+ kwargs = dict ()
1817+ kwargs ["sim" ] = self .similarity
1818+ kwargs ["param" ] = self .sim_param
1819+
18041820 returned_ = emd_laplace (a = self .mu_s , b = self .mu_t , xs = self .xs_ ,
1805- xt = self .xt_ , M = self .cost_ , reg = self .reg , sim = self . similarity , eta = self .reg_lap , alpha = self .reg_src ,
1821+ xt = self .xt_ , M = self .cost_ , reg = self .reg , eta = self .reg_lap , alpha = self .reg_src ,
18061822 numItermax = self .max_iter , stopThr = self .tol , numInnerItermax = self .max_inner_iter ,
1807- stopInnerThr = self .inner_tol , log = self .log , verbose = self .verbose )
1823+ stopInnerThr = self .inner_tol , log = self .log , verbose = self .verbose , ** kwargs )
18081824
18091825 # coupling estimation
18101826 if self .log :
0 commit comments