@@ -1332,13 +1332,14 @@ class EMDTransport(BaseTransport):
13321332 on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1
13331333 """
13341334
1335- def __init__ (self , metric = "sqeuclidean" , norm = None ,
1335+ def __init__ (self , metric = "sqeuclidean" , norm = None ,log = False ,
13361336 distribution_estimation = distribution_estimation_uniform ,
13371337 out_of_sample_map = 'ferradans' , limit_max = 10 ,
13381338 max_iter = 100000 ):
1339-
1339+
13401340 self .metric = metric
13411341 self .norm = norm
1342+ self .log = log
13421343 self .limit_max = limit_max
13431344 self .distribution_estimation = distribution_estimation
13441345 self .out_of_sample_map = out_of_sample_map
@@ -1371,11 +1372,16 @@ class label
13711372
13721373 super (EMDTransport , self ).fit (Xs , ys , Xt , yt )
13731374
1374- # coupling estimation
1375- self .coupling_ = emd (
1376- a = self .mu_s , b = self .mu_t , M = self .cost_ , numItermax = self .max_iter
1377- )
1375+ returned_ = emd (
1376+ a = self .mu_s , b = self .mu_t , M = self .cost_ , numItermax = self .max_iter ,
1377+ log = self .log )
13781378
1379+ # coupling estimation
1380+ if self .log :
1381+ self .coupling_ , self .log_ = returned_
1382+ else :
1383+ self .coupling_ = returned_
1384+ self .log_ = dict ()
13791385 return self
13801386
13811387
0 commit comments