Skip to content

Commit 94b929b

Browse files
committed
BUG: Parameter log unusable in sinkhorn classes
1 parent 1b5112c commit 94b929b

File tree

1 file changed

+11
-4
lines changed

1 file changed

+11
-4
lines changed

ot/da.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1310,7 +1310,7 @@ class EMDTransport(BaseTransport):
13101310
The kind of distribution estimation to employ
13111311
verbose : int, optional (default=0)
13121312
Controls the verbosity of the optimization algorithm
1313-
log : bool, optional (default=False)
1313+
log : int, optional (default=0)
13141314
Controls the logs of the optimization algorithm
13151315
limit_max: float, optional (default=10)
13161316
Controls the semi supervised mode. Transport between labeled source
@@ -1438,7 +1438,7 @@ class SinkhornLpl1Transport(BaseTransport):
14381438
"""
14391439

14401440
def __init__(self, reg_e=1., reg_cl=0.1,
1441-
max_iter=10, max_inner_iter=200,
1441+
max_iter=10, max_inner_iter=200, lo=False,
14421442
tol=10e-9, verbose=False,
14431443
metric="sqeuclidean", norm=None,
14441444
distribution_estimation=distribution_estimation_uniform,
@@ -1449,6 +1449,7 @@ def __init__(self, reg_e=1., reg_cl=0.1,
14491449
self.max_iter = max_iter
14501450
self.max_inner_iter = max_inner_iter
14511451
self.tol = tol
1452+
self.log = log
14521453
self.verbose = verbose
14531454
self.metric = metric
14541455
self.norm = norm
@@ -1486,12 +1487,18 @@ class label
14861487

14871488
super(SinkhornLpl1Transport, self).fit(Xs, ys, Xt, yt)
14881489

1489-
self.coupling_ = sinkhorn_lpl1_mm(
1490+
returned_ = sinkhorn_lpl1_mm(
14901491
a=self.mu_s, labels_a=ys, b=self.mu_t, M=self.cost_,
14911492
reg=self.reg_e, eta=self.reg_cl, numItermax=self.max_iter,
14921493
numInnerItermax=self.max_inner_iter, stopInnerThr=self.tol,
1493-
verbose=self.verbose)
1494+
verbose=self.verbose,log=self.log)
14941495

1496+
# deal with the value of log
1497+
if self.log:
1498+
self.coupling_, self.log_ = returned_
1499+
else:
1500+
self.coupling_ = returned_
1501+
self.log_ = dict()
14951502
return self
14961503

14971504

0 commit comments

Comments
 (0)