Skip to content

Commit 9473929

Browse files
authored
Merge pull request #44 from tardyb/addLog
Using log parameter, fixes #43 Thank you @tardyb
2 parents f31d725 + 54e16a4 commit 9473929

File tree

1 file changed

+21
-8
lines changed

1 file changed

+21
-8
lines changed

ot/da.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1332,13 +1332,14 @@ class EMDTransport(BaseTransport):
13321332
on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1
13331333
"""
13341334

1335-
def __init__(self, metric="sqeuclidean", norm=None,
1335+
def __init__(self, metric="sqeuclidean", norm=None, log=False,
13361336
distribution_estimation=distribution_estimation_uniform,
13371337
out_of_sample_map='ferradans', limit_max=10,
13381338
max_iter=100000):
13391339

13401340
self.metric = metric
13411341
self.norm = norm
1342+
self.log = log
13421343
self.limit_max = limit_max
13431344
self.distribution_estimation = distribution_estimation
13441345
self.out_of_sample_map = out_of_sample_map
@@ -1371,11 +1372,16 @@ class label
13711372

13721373
super(EMDTransport, self).fit(Xs, ys, Xt, yt)
13731374

1374-
# coupling estimation
1375-
self.coupling_ = emd(
1376-
a=self.mu_s, b=self.mu_t, M=self.cost_, numItermax=self.max_iter
1377-
)
1375+
returned_ = emd(
1376+
a=self.mu_s, b=self.mu_t, M=self.cost_, numItermax=self.max_iter,
1377+
log=self.log)
13781378

1379+
# coupling estimation
1380+
if self.log:
1381+
self.coupling_, self.log_ = returned_
1382+
else:
1383+
self.coupling_ = returned_
1384+
self.log_ = dict()
13791385
return self
13801386

13811387

@@ -1432,7 +1438,7 @@ class SinkhornLpl1Transport(BaseTransport):
14321438
"""
14331439

14341440
def __init__(self, reg_e=1., reg_cl=0.1,
1435-
max_iter=10, max_inner_iter=200,
1441+
max_iter=10, max_inner_iter=200, log=False,
14361442
tol=10e-9, verbose=False,
14371443
metric="sqeuclidean", norm=None,
14381444
distribution_estimation=distribution_estimation_uniform,
@@ -1443,6 +1449,7 @@ def __init__(self, reg_e=1., reg_cl=0.1,
14431449
self.max_iter = max_iter
14441450
self.max_inner_iter = max_inner_iter
14451451
self.tol = tol
1452+
self.log = log
14461453
self.verbose = verbose
14471454
self.metric = metric
14481455
self.norm = norm
@@ -1480,12 +1487,18 @@ class label
14801487

14811488
super(SinkhornLpl1Transport, self).fit(Xs, ys, Xt, yt)
14821489

1483-
self.coupling_ = sinkhorn_lpl1_mm(
1490+
returned_ = sinkhorn_lpl1_mm(
14841491
a=self.mu_s, labels_a=ys, b=self.mu_t, M=self.cost_,
14851492
reg=self.reg_e, eta=self.reg_cl, numItermax=self.max_iter,
14861493
numInnerItermax=self.max_inner_iter, stopInnerThr=self.tol,
1487-
verbose=self.verbose)
1494+
verbose=self.verbose, log=self.log)
14881495

1496+
# deal with the value of log
1497+
if self.log:
1498+
self.coupling_, self.log_ = returned_
1499+
else:
1500+
self.coupling_ = returned_
1501+
self.log_ = dict()
14891502
return self
14901503

14911504

0 commit comments

Comments
 (0)