@@ -1332,13 +1332,14 @@ class EMDTransport(BaseTransport):
13321332 on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1
13331333 """
13341334
1335- def __init__ (self , metric = "sqeuclidean" , norm = None ,
1335+ def __init__ (self , metric = "sqeuclidean" , norm = None , log = False ,
13361336 distribution_estimation = distribution_estimation_uniform ,
13371337 out_of_sample_map = 'ferradans' , limit_max = 10 ,
13381338 max_iter = 100000 ):
13391339
13401340 self .metric = metric
13411341 self .norm = norm
1342+ self .log = log
13421343 self .limit_max = limit_max
13431344 self .distribution_estimation = distribution_estimation
13441345 self .out_of_sample_map = out_of_sample_map
@@ -1371,11 +1372,16 @@ class label
13711372
13721373 super (EMDTransport , self ).fit (Xs , ys , Xt , yt )
13731374
1374- # coupling estimation
1375- self .coupling_ = emd (
1376- a = self .mu_s , b = self .mu_t , M = self .cost_ , numItermax = self .max_iter
1377- )
1375+ returned_ = emd (
1376+ a = self .mu_s , b = self .mu_t , M = self .cost_ , numItermax = self .max_iter ,
1377+ log = self .log )
13781378
1379+ # coupling estimation
1380+ if self .log :
1381+ self .coupling_ , self .log_ = returned_
1382+ else :
1383+ self .coupling_ = returned_
1384+ self .log_ = dict ()
13791385 return self
13801386
13811387
@@ -1432,7 +1438,7 @@ class SinkhornLpl1Transport(BaseTransport):
14321438 """
14331439
14341440 def __init__ (self , reg_e = 1. , reg_cl = 0.1 ,
1435- max_iter = 10 , max_inner_iter = 200 ,
1441+ max_iter = 10 , max_inner_iter = 200 , log = False ,
14361442 tol = 10e-9 , verbose = False ,
14371443 metric = "sqeuclidean" , norm = None ,
14381444 distribution_estimation = distribution_estimation_uniform ,
@@ -1443,6 +1449,7 @@ def __init__(self, reg_e=1., reg_cl=0.1,
14431449 self .max_iter = max_iter
14441450 self .max_inner_iter = max_inner_iter
14451451 self .tol = tol
1452+ self .log = log
14461453 self .verbose = verbose
14471454 self .metric = metric
14481455 self .norm = norm
@@ -1480,12 +1487,18 @@ class label
14801487
14811488 super (SinkhornLpl1Transport , self ).fit (Xs , ys , Xt , yt )
14821489
1483- self . coupling_ = sinkhorn_lpl1_mm (
1490+ returned_ = sinkhorn_lpl1_mm (
14841491 a = self .mu_s , labels_a = ys , b = self .mu_t , M = self .cost_ ,
14851492 reg = self .reg_e , eta = self .reg_cl , numItermax = self .max_iter ,
14861493 numInnerItermax = self .max_inner_iter , stopInnerThr = self .tol ,
1487- verbose = self .verbose )
1494+ verbose = self .verbose , log = self . log )
14881495
1496+ # deal with the value of log
1497+ if self .log :
1498+ self .coupling_ , self .log_ = returned_
1499+ else :
1500+ self .coupling_ = returned_
1501+ self .log_ = dict ()
14891502 return self
14901503
14911504
0 commit comments