Skip to content

Commit af9f1e3

Browse files
committed
BUG: EMDTransport parameter log unusable
1 parent f31d725 commit af9f1e3

File tree

1 file changed

+12
-6
lines changed

1 file changed

+12
-6
lines changed

ot/da.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1332,13 +1332,14 @@ class EMDTransport(BaseTransport):
13321332
on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1
13331333
"""
13341334

1335-
def __init__(self, metric="sqeuclidean", norm=None,
1335+
def __init__(self, metric="sqeuclidean", norm=None,log=False,
13361336
distribution_estimation=distribution_estimation_uniform,
13371337
out_of_sample_map='ferradans', limit_max=10,
13381338
max_iter=100000):
1339-
1339+
13401340
self.metric = metric
13411341
self.norm = norm
1342+
self.log = log
13421343
self.limit_max = limit_max
13431344
self.distribution_estimation = distribution_estimation
13441345
self.out_of_sample_map = out_of_sample_map
@@ -1371,11 +1372,16 @@ class label
13711372

13721373
super(EMDTransport, self).fit(Xs, ys, Xt, yt)
13731374

1374-
# coupling estimation
1375-
self.coupling_ = emd(
1376-
a=self.mu_s, b=self.mu_t, M=self.cost_, numItermax=self.max_iter
1377-
)
1375+
returned_ = emd(
1376+
a=self.mu_s, b=self.mu_t, M=self.cost_, numItermax=self.max_iter,
1377+
log=self.log)
13781378

1379+
# coupling estimation
1380+
if self.log:
1381+
self.coupling_, self.log_ = returned_
1382+
else:
1383+
self.coupling_ = returned_
1384+
self.log_ = dict()
13791385
return self
13801386

13811387

0 commit comments

Comments
 (0)