Skip to content

Commit 1269033

Browse files
author
ievred
committed
added regularization from [6] + fix other issues
1 parent 14fbb88 commit 1269033

File tree

1 file changed

+23
-6
lines changed

1 file changed

+23
-6
lines changed

ot/da.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
from .bregman import sinkhorn, jcpot_barycenter
1818
from .lp import emd
19-
from .utils import unif, dist, kernel, cost_normalization, label_normalization, laplacian
19+
from .utils import unif, dist, kernel, cost_normalization, label_normalization, laplacian, dots
2020
from .utils import check_params, BaseEstimator
2121
from .unbalanced import sinkhorn_unbalanced
2222
from .optim import cg
@@ -748,7 +748,7 @@ def OT_mapping_linear(xs, xt, reg=1e-6, ws=None,
748748
return A, b
749749

750750

751-
def emd_laplace(a, b, xs, xt, M, sim, eta, alpha,
751+
def emd_laplace(a, b, xs, xt, M, sim, reg, eta, alpha,
752752
numItermax, stopThr, numInnerItermax,
753753
stopInnerThr, log=False, verbose=False, **kwargs):
754754
r"""Solve the optimal transport problem (OT) with Laplacian regularization
@@ -785,6 +785,8 @@ def emd_laplace(a, b, xs, xt, M, sim, eta, alpha,
785785
samples in the target domain
786786
M : np.ndarray (ns,nt)
787787
loss matrix
788+
reg : string
789+
Type of Laplacian regularization
788790
eta : float
789791
Regularization term for Laplacian regularization
790792
alpha : float
@@ -844,6 +846,8 @@ def emd_laplace(a, b, xs, xt, M, sim, eta, alpha,
844846
sS = (sS + sS.T) / 2
845847
sT = kneighbors_graph(xt, kwargs['nn']).toarray()
846848
sT = (sT + sT.T) / 2
849+
else:
850+
raise ValueError('Unknown similarity type {sim}. Currently supported similarity types are "knn" and "gauss".'.format(sim=sim))
847851

848852
lS = laplacian(sS)
849853
lT = laplacian(sT)
@@ -852,9 +856,18 @@ def f(G):
852856
return alpha * np.trace(np.dot(xt.T, np.dot(G.T, np.dot(lS, np.dot(G, xt))))) \
853857
+ (1 - alpha) * np.trace(np.dot(xs.T, np.dot(G, np.dot(lT, np.dot(G.T, xs)))))
854858

859+
ls2 = lS + lS.T
860+
lt2 = lT + lT.T
861+
xt2 = np.dot(xt, xt.T)
862+
863+
if reg == 'disp':
864+
Cs = -eta * alpha / xs.shape[0] * dots(ls2, xs, xt.T)
865+
Ct = -eta * (1 - alpha) / xt.shape[0] * dots(xs, xt.T, lt2)
866+
M = M + Cs + Ct
867+
855868
def df(G):
856-
return alpha * np.dot(lS + lS.T, np.dot(G, np.dot(xt, xt.T)))\
857-
+ (1 - alpha) * np.dot(xs, np.dot(xs.T, np.dot(G, lT + lT.T)))
869+
return alpha * np.dot(ls2, np.dot(G, xt2))\
870+
+ (1 - alpha) * np.dot(xs, np.dot(xs.T, np.dot(G, lt2)))
858871

859872
return cg(a, b, M, reg=eta, f=f, df=df, G0=None, numItermax=numItermax, numItermaxEmd=numInnerItermax,
860873
stopThr=stopThr, stopThr2=stopInnerThr, verbose=verbose, log=log)
@@ -1694,6 +1707,9 @@ class EMDLaplaceTransport(BaseTransport):
16941707
16951708
Parameters
16961709
----------
1710+
reg_type : string, optional (default='pos')
1711+
Type of the regularization term: 'pos' and 'disp' for
1712+
regularization term defined in [2] and [6], respectively.
16971713
reg_lap : float, optional (default=1)
16981714
Laplacian regularization parameter
16991715
reg_src : float, optional (default=0.5)
@@ -1737,11 +1753,12 @@ class EMDLaplaceTransport(BaseTransport):
17371753
in NIPS Workshop on Optimal Transport and Machine Learning OTML, 2014.
17381754
"""
17391755

1740-
def __init__(self, reg_lap=1., reg_src=1., alpha=0.5,
1756+
def __init__(self, reg_type='pos', reg_lap=1., reg_src=1., alpha=0.5,
17411757
metric="sqeuclidean", norm=None, similarity="knn", max_iter=100, tol=1e-9,
17421758
max_inner_iter=100000, inner_tol=1e-9, log=False, verbose=False,
17431759
distribution_estimation=distribution_estimation_uniform,
17441760
out_of_sample_map='ferradans'):
1761+
self.reg = reg_type
17451762
self.reg_lap = reg_lap
17461763
self.reg_src = reg_src
17471764
self.alpha = alpha
@@ -1785,7 +1802,7 @@ class label
17851802
super(EMDLaplaceTransport, self).fit(Xs, ys, Xt, yt)
17861803

17871804
returned_ = emd_laplace(a=self.mu_s, b=self.mu_t, xs=self.xs_,
1788-
xt=self.xt_, M=self.cost_, sim=self.similarity, eta=self.reg_lap, alpha=self.reg_src,
1805+
xt=self.xt_, M=self.cost_, reg=self.reg, sim=self.similarity, eta=self.reg_lap, alpha=self.reg_src,
17891806
numItermax=self.max_iter, stopThr=self.tol, numInnerItermax=self.max_inner_iter,
17901807
stopInnerThr=self.inner_tol, log=self.log, verbose=self.verbose)
17911808

0 commit comments

Comments
 (0)