@@ -1310,7 +1310,7 @@ class EMDTransport(BaseTransport):
13101310 The kind of distribution estimation to employ
13111311 verbose : int, optional (default=0)
13121312 Controls the verbosity of the optimization algorithm
1313- log : bool , optional (default=False )
1313+ log : int , optional (default=0 )
13141314 Controls the logs of the optimization algorithm
13151315 limit_max: float, optional (default=10)
13161316 Controls the semi supervised mode. Transport between labeled source
@@ -1438,7 +1438,7 @@ class SinkhornLpl1Transport(BaseTransport):
14381438 """
14391439
14401440 def __init__ (self , reg_e = 1. , reg_cl = 0.1 ,
1441- max_iter = 10 , max_inner_iter = 200 ,
1441+ max_iter = 10 , max_inner_iter = 200 , lo = False ,
14421442 tol = 10e-9 , verbose = False ,
14431443 metric = "sqeuclidean" , norm = None ,
14441444 distribution_estimation = distribution_estimation_uniform ,
@@ -1449,6 +1449,7 @@ def __init__(self, reg_e=1., reg_cl=0.1,
14491449 self .max_iter = max_iter
14501450 self .max_inner_iter = max_inner_iter
14511451 self .tol = tol
1452+ self .log = log
14521453 self .verbose = verbose
14531454 self .metric = metric
14541455 self .norm = norm
@@ -1486,12 +1487,18 @@ class label
14861487
14871488 super (SinkhornLpl1Transport , self ).fit (Xs , ys , Xt , yt )
14881489
1489- self . coupling_ = sinkhorn_lpl1_mm (
1490+ returned_ = sinkhorn_lpl1_mm (
14901491 a = self .mu_s , labels_a = ys , b = self .mu_t , M = self .cost_ ,
14911492 reg = self .reg_e , eta = self .reg_cl , numItermax = self .max_iter ,
14921493 numInnerItermax = self .max_inner_iter , stopInnerThr = self .tol ,
1493- verbose = self .verbose )
1494+ verbose = self .verbose , log = self . log )
14941495
1496+ # deal with the value of log
1497+ if self .log :
1498+ self .coupling_ , self .log_ = returned_
1499+ else :
1500+ self .coupling_ = returned_
1501+ self .log_ = dict ()
14951502 return self
14961503
14971504
0 commit comments