@@ -748,9 +748,9 @@ def OT_mapping_linear(xs, xt, reg=1e-6, ws=None,
748748 return A , b
749749
750750
751- def emd_laplace (a , b , xs , xt , M , sim , sim_param , reg , eta , alpha ,
752- numItermax , stopThr , numInnerItermax ,
753- stopInnerThr , log = False , verbose = False , ** kwargs ):
751+ def emd_laplace (a , b , xs , xt , M , sim = 'knn' , sim_param = None , reg = 'pos' , eta = 1 , alpha = .5 ,
752+ numItermax = 100 , stopThr = 1e-9 , numInnerItermax = 100000 ,
753+ stopInnerThr = 1e-9 , log = False , verbose = False ):
754754 r"""Solve the optimal transport problem (OT) with Laplacian regularization
755755
756756 .. math::
@@ -1765,15 +1765,14 @@ class EMDLaplaceTransport(BaseTransport):
17651765 in NIPS Workshop on Optimal Transport and Machine Learning OTML, 2014.
17661766 """
17671767
1768- def __init__ (self , reg_type = 'pos' , reg_lap = 1. , reg_src = 1. , alpha = 0.5 ,
1769- metric = "sqeuclidean" , norm = None , similarity = "knn" , similarity_param = None , max_iter = 100 , tol = 1e-9 ,
1768+ def __init__ (self , reg_type = 'pos' , reg_lap = 1. , reg_src = 1. , metric = "sqeuclidean" ,
1769+ norm = None , similarity = "knn" , similarity_param = None , max_iter = 100 , tol = 1e-9 ,
17701770 max_inner_iter = 100000 , inner_tol = 1e-9 , log = False , verbose = False ,
17711771 distribution_estimation = distribution_estimation_uniform ,
17721772 out_of_sample_map = 'ferradans' ):
17731773 self .reg = reg_type
17741774 self .reg_lap = reg_lap
17751775 self .reg_src = reg_src
1776- self .alpha = alpha
17771776 self .metric = metric
17781777 self .norm = norm
17791778 self .similarity = similarity
@@ -1815,8 +1814,8 @@ class label
18151814 super (EMDLaplaceTransport , self ).fit (Xs , ys , Xt , yt )
18161815
18171816 returned_ = emd_laplace (a = self .mu_s , b = self .mu_t , xs = self .xs_ ,
1818- xt = self .xt_ , M = self .cost_ , sim = self .similarity , sim_param = self .sim_param , reg = self .reg , eta = self .reg_lap , alpha = self . reg_src ,
1819- numItermax = self .max_iter , stopThr = self .tol , numInnerItermax = self .max_inner_iter ,
1817+ xt = self .xt_ , M = self .cost_ , sim = self .similarity , sim_param = self .sim_param , reg = self .reg , eta = self .reg_lap ,
1818+ alpha = self . reg_src , numItermax = self .max_iter , stopThr = self .tol , numInnerItermax = self .max_inner_iter ,
18201819 stopInnerThr = self .inner_tol , log = self .log , verbose = self .verbose )
18211820
18221821 # coupling estimation
0 commit comments