Skip to content

Commit 0746328

Browse files
author
ievred
committed
added kwargs to sim + doc
1 parent 1269033 commit 0746328

File tree

1 file changed

+32
-16
lines changed

1 file changed

+32
-16
lines changed

ot/da.py

Lines changed: 32 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -748,7 +748,7 @@ def OT_mapping_linear(xs, xt, reg=1e-6, ws=None,
748748
return A, b
749749

750750

751-
def emd_laplace(a, b, xs, xt, M, sim, reg, eta, alpha,
751+
def emd_laplace(a, b, xs, xt, M, reg, eta, alpha,
752752
numItermax, stopThr, numInnerItermax,
753753
stopInnerThr, log=False, verbose=False, **kwargs):
754754
r"""Solve the optimal transport problem (OT) with Laplacian regularization
@@ -803,7 +803,11 @@ def emd_laplace(a, b, xs, xt, M, sim, reg, eta, alpha,
803803
Print information along iterations
804804
log : bool, optional
805805
record log if True
806-
806+
kwargs : dict
807+
Dictionary with attributes 'sim' ('knn' or 'gauss') and
808+
'param' (int, float or None) for similarity type and its parameter to be used.
809+
If 'param' is None, it is derived from the mean pairwise squared Euclidean distance over the source samples (1 / (2 * mean ** 2))
810+
or set to 3 when sim is 'gauss' or 'knn', respectively.
807811
808812
Returns
809813
-------
@@ -830,24 +834,28 @@ def emd_laplace(a, b, xs, xt, M, sim, reg, eta, alpha,
830834
ot.optim.cg : General regularized OT
831835
832836
"""
833-
if sim == 'gauss':
834-
if 'rbfparam' not in kwargs:
835-
kwargs['rbfparam'] = 1 / (2 * (np.mean(dist(xs, xs, 'sqeuclidean')) ** 2))
836-
sS = kernel(xs, xs, method=kwargs['sim'], sigma=kwargs['rbfparam'])
837-
sT = kernel(xt, xt, method=kwargs['sim'], sigma=kwargs['rbfparam'])
837+
if not isinstance(kwargs['param'], (int, float, type(None))):
838+
raise ValueError(
839+
'Similarity parameter should be an int or a float. Got {type} instead.'.format(type=type(kwargs['param'])))
840+
841+
if kwargs['sim'] == 'gauss':
842+
if kwargs['param'] is None:
843+
kwargs['param'] = 1 / (2 * (np.mean(dist(xs, xs, 'sqeuclidean')) ** 2))
844+
sS = kernel(xs, xs, method=kwargs['sim'], sigma=kwargs['param'])
845+
sT = kernel(xt, xt, method=kwargs['sim'], sigma=kwargs['param'])
838846

839-
elif sim == 'knn':
840-
if 'nn' not in kwargs:
841-
kwargs['nn'] = 5
847+
elif kwargs['sim'] == 'knn':
848+
if kwargs['param'] is None:
849+
kwargs['param'] = 3
842850

843851
from sklearn.neighbors import kneighbors_graph
844852

845-
sS = kneighbors_graph(xs, kwargs['nn']).toarray()
853+
sS = kneighbors_graph(X=xs, n_neighbors=int(kwargs['param'])).toarray()
846854
sS = (sS + sS.T) / 2
847-
sT = kneighbors_graph(xt, kwargs['nn']).toarray()
855+
sT = kneighbors_graph(xt, n_neighbors=int(kwargs['param'])).toarray()
848856
sT = (sT + sT.T) / 2
849857
else:
850-
raise ValueError('Unknown similarity type {sim}. Currently supported similarity types are "knn" and "gauss".'.format(sim=sim))
858+
raise ValueError('Unknown similarity type {sim}. Currently supported similarity types are "knn" and "gauss".'.format(sim=kwargs['sim']))
851859

852860
lS = laplacian(sS)
853861
lT = laplacian(sT)
@@ -1721,6 +1729,9 @@ class EMDLaplaceTransport(BaseTransport):
17211729
can occur with large metric values.
17221730
similarity : string, optional (default="knn")
17231731
The similarity to use, either "knn" or "gauss"
1732+
similarity_param : int, float or None, optional (default=None)
1733+
Parameter for the similarity: number of nearest neighbors or bandwidth
1734+
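if similarity="knn" or "gauss", respectively. If None, 3 neighbors are used for "knn" and the bandwidth is derived from the source samples for "gauss".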
17241735
max_iter : int, optional (default=100)
17251736
Max number of BCD iterations
17261737
tol : float, optional (default=1e-5)
@@ -1754,7 +1765,7 @@ class EMDLaplaceTransport(BaseTransport):
17541765
"""
17551766

17561767
def __init__(self, reg_type='pos', reg_lap=1., reg_src=1., alpha=0.5,
1757-
metric="sqeuclidean", norm=None, similarity="knn", max_iter=100, tol=1e-9,
1768+
metric="sqeuclidean", norm=None, similarity="knn", similarity_param=None, max_iter=100, tol=1e-9,
17581769
max_inner_iter=100000, inner_tol=1e-9, log=False, verbose=False,
17591770
distribution_estimation=distribution_estimation_uniform,
17601771
out_of_sample_map='ferradans'):
@@ -1765,6 +1776,7 @@ def __init__(self, reg_type='pos', reg_lap=1., reg_src=1., alpha=0.5,
17651776
self.metric = metric
17661777
self.norm = norm
17671778
self.similarity = similarity
1779+
self.sim_param = similarity_param
17681780
self.max_iter = max_iter
17691781
self.tol = tol
17701782
self.max_inner_iter = max_inner_iter
@@ -1801,10 +1813,14 @@ class label
18011813

18021814
super(EMDLaplaceTransport, self).fit(Xs, ys, Xt, yt)
18031815

1816+
kwargs = dict()
1817+
kwargs["sim"] = self.similarity
1818+
kwargs["param"] = self.sim_param
1819+
18041820
returned_ = emd_laplace(a=self.mu_s, b=self.mu_t, xs=self.xs_,
1805-
xt=self.xt_, M=self.cost_, reg=self.reg, sim=self.similarity, eta=self.reg_lap, alpha=self.reg_src,
1821+
xt=self.xt_, M=self.cost_, reg=self.reg, eta=self.reg_lap, alpha=self.reg_src,
18061822
numItermax=self.max_iter, stopThr=self.tol, numInnerItermax=self.max_inner_iter,
1807-
stopInnerThr=self.inner_tol, log=self.log, verbose=self.verbose)
1823+
stopInnerThr=self.inner_tol, log=self.log, verbose=self.verbose, **kwargs)
18081824

18091825
# coupling estimation
18101826
if self.log:

0 commit comments

Comments
 (0)