Skip to content

Commit c68b52d

Browse files
author
ievred
committed
remove laplace from jcpot
1 parent 2c9f992 commit c68b52d

File tree

5 files changed

+5
-403
lines changed

5 files changed

+5
-403
lines changed

examples/plot_otda_jcpot.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def print_G(G, xs, ys, xt):
115115
##############################################################################
116116
# Instantiate JCPOT adaptation algorithm and fit it
117117
# ----------------------------------------------------------------------------
118-
otda = ot.da.JCPOTTransport(reg_e=1e-2, max_iter=1000, metric='sqeuclidean', tol=1e-9, verbose=True, log=True)
118+
otda = ot.da.JCPOTTransport(reg_e=1, max_iter=1000, metric='sqeuclidean', tol=1e-9, verbose=True, log=True)
119119
otda.fit(all_Xr, all_Yr, xt)
120120

121121
ws1 = otda.proportions_.dot(otda.log_['D2'][0])
@@ -126,8 +126,8 @@ def print_G(G, xs, ys, xt):
126126
plot_ax(dec1, 'Source 1')
127127
plot_ax(dec2, 'Source 2')
128128
plot_ax(dect, 'Target')
129-
print_G(ot.bregman.sinkhorn(ws1, [], otda.log_['M'][0], reg=1e-2), xs1, ys1, xt)
130-
print_G(ot.bregman.sinkhorn(ws2, [], otda.log_['M'][1], reg=1e-2), xs2, ys2, xt)
129+
print_G(ot.bregman.sinkhorn(ws1, [], otda.log_['M'][0], reg=1e-1), xs1, ys1, xt)
130+
print_G(ot.bregman.sinkhorn(ws2, [], otda.log_['M'][1], reg=1e-1), xs2, ys2, xt)
131131
pl.scatter(xs1[:, 0], xs1[:, 1], c=ys1, s=35, marker='x', cmap='Set1', vmax=9)
132132
pl.scatter(xs2[:, 0], xs2[:, 1], c=ys2, s=35, marker='+', cmap='Set1', vmax=9)
133133
pl.scatter(xt[:, 0], xt[:, 1], c=yt, s=35, marker='o', cmap='Set1', vmax=9)
@@ -154,8 +154,8 @@ def print_G(G, xs, ys, xt):
154154
plot_ax(dec1, 'Source 1')
155155
plot_ax(dec2, 'Source 2')
156156
plot_ax(dect, 'Target')
157-
print_G(ot.bregman.sinkhorn(ws1, [], otda.log_['M'][0], reg=1e-2), xs1, ys1, xt)
158-
print_G(ot.bregman.sinkhorn(ws2, [], otda.log_['M'][1], reg=1e-2), xs2, ys2, xt)
157+
print_G(ot.bregman.sinkhorn(ws1, [], otda.log_['M'][0], reg=1e-1), xs1, ys1, xt)
158+
print_G(ot.bregman.sinkhorn(ws2, [], otda.log_['M'][1], reg=1e-1), xs2, ys2, xt)
159159
pl.scatter(xs1[:, 0], xs1[:, 1], c=ys1, s=35, marker='x', cmap='Set1', vmax=9)
160160
pl.scatter(xs2[:, 0], xs2[:, 1], c=ys2, s=35, marker='+', cmap='Set1', vmax=9)
161161
pl.scatter(xt[:, 0], xt[:, 1], c=yt, s=35, marker='o', cmap='Set1', vmax=9)

examples/plot_otda_laplacian.py

Lines changed: 0 additions & 127 deletions
This file was deleted.

ot/bregman.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1607,7 +1607,6 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
16071607

16081608
# build the cost matrix and the Gibbs kernel
16091609
Mtmp = dist(Xs[d], Xt, metric=metric)
1610-
Mtmp = Mtmp / np.median(Mtmp)
16111610
M.append(Mtmp)
16121611

16131612
Ktmp = np.empty(Mtmp.shape, dtype=Mtmp.dtype)

ot/da.py

Lines changed: 0 additions & 216 deletions
Original file line numberDiff line numberDiff line change
@@ -748,115 +748,6 @@ def OT_mapping_linear(xs, xt, reg=1e-6, ws=None,
748748
return A, b
749749

750750

751-
def emd_laplace(a, b, xs, xt, M, sim, eta, alpha,
752-
numItermax, stopThr, numInnerItermax,
753-
stopInnerThr, log=False, verbose=False, **kwargs):
754-
r"""Solve the optimal transport problem (OT) with Laplacian regularization
755-
756-
.. math::
757-
\gamma = \arg\min_\gamma <\gamma,M>_F + \eta\Omega_\alpha(\gamma)
758-
759-
s.t.\ \gamma 1 = a
760-
761-
\gamma^T 1= b
762-
763-
\gamma\geq 0
764-
765-
where:
766-
767-
- a and b are source and target weights (sum to 1)
768-
- xs and xt are source and target samples
769-
- M is the (ns,nt) metric cost matrix
770-
- :math:`\Omega_\alpha` is the Laplacian regularization term
771-
:math:`\Omega_\alpha = (1-\alpha)/n_s^2\sum_{i,j}S^s_{i,j}\|T(\mathbf{x}^s_i)-T(\mathbf{x}^s_j)\|^2+\alpha/n_t^2\sum_{i,j}S^t_{i,j}\|T(\mathbf{x}^t_i)-T(\mathbf{x}^t_j)\|^2`
772-
with :math:`S^s_{i,j}, S^t_{i,j}` denoting source and target similarity matrices and :math:`T(\cdot)` being a barycentric mapping
773-
774-
The algorithm used for solving the problem is the conditional gradient algorithm as proposed in [5].
775-
776-
Parameters
777-
----------
778-
a : np.ndarray (ns,)
779-
samples weights in the source domain
780-
b : np.ndarray (nt,)
781-
samples weights in the target domain
782-
xs : np.ndarray (ns,d)
783-
samples in the source domain
784-
xt : np.ndarray (nt,d)
785-
samples in the target domain
786-
M : np.ndarray (ns,nt)
787-
loss matrix
788-
eta : float
789-
Regularization term for Laplacian regularization
790-
alpha : float
791-
Regularization term for source domain's importance in regularization
792-
numItermax : int, optional
793-
Max number of iterations
794-
stopThr : float, optional
795-
Stop threshold on error (inner emd solver) (>0)
796-
numInnerItermax : int, optional
797-
Max number of iterations (inner CG solver)
798-
stopInnerThr : float, optional
799-
Stop threshold on error (inner CG solver) (>0)
800-
verbose : bool, optional
801-
Print information along iterations
802-
log : bool, optional
803-
record log if True
804-
805-
806-
Returns
807-
-------
808-
gamma : (ns x nt) ndarray
809-
Optimal transportation matrix for the given parameters
810-
log : dict
811-
log dictionary return only if log==True in parameters
812-
813-
814-
References
815-
----------
816-
817-
.. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy,
818-
"Optimal Transport for Domain Adaptation," in IEEE
819-
Transactions on Pattern Analysis and Machine Intelligence ,
820-
vol.PP, no.99, pp.1-1
821-
822-
See Also
823-
--------
824-
ot.lp.emd : Unregularized OT
825-
ot.optim.cg : General regularized OT
826-
827-
"""
828-
if sim == 'gauss':
829-
if 'rbfparam' not in kwargs:
830-
kwargs['rbfparam'] = 1 / (2 * (np.mean(dist(xs, xs, 'sqeuclidean')) ** 2))
831-
sS = kernel(xs, xs, method=kwargs['sim'], sigma=kwargs['rbfparam'])
832-
sT = kernel(xt, xt, method=kwargs['sim'], sigma=kwargs['rbfparam'])
833-
834-
elif sim == 'knn':
835-
if 'nn' not in kwargs:
836-
kwargs['nn'] = 5
837-
838-
from sklearn.neighbors import kneighbors_graph
839-
840-
sS = kneighbors_graph(xs, kwargs['nn']).toarray()
841-
sS = (sS + sS.T) / 2
842-
sT = kneighbors_graph(xt, kwargs['nn']).toarray()
843-
sT = (sT + sT.T) / 2
844-
845-
lS = laplacian(sS)
846-
lT = laplacian(sT)
847-
848-
def f(G):
849-
return alpha * np.trace(np.dot(xt.T, np.dot(G.T, np.dot(lS, np.dot(G, xt))))) \
850-
+ (1 - alpha) * np.trace(np.dot(xs.T, np.dot(G, np.dot(lT, np.dot(G.T, xs)))))
851-
852-
def df(G):
853-
return alpha * np.dot(lS + lS.T, np.dot(G, np.dot(xt, xt.T)))\
854-
+ (1 - alpha) * np.dot(xs, np.dot(xs.T, np.dot(G, lT + lT.T)))
855-
856-
return cg(a, b, M, reg=eta, f=f, df=df, G0=None, numItermax=numItermax, numItermaxEmd=numInnerItermax,
857-
stopThr=stopThr, stopThr2=stopInnerThr, verbose=verbose, log=log)
858-
859-
860751
def distribution_estimation_uniform(X):
861752
"""estimates a uniform distribution from an array of samples X
862753
@@ -1603,113 +1494,6 @@ class label
16031494
return self
16041495

16051496

1606-
class EMDLaplaceTransport(BaseTransport):
1607-
1608-
"""Domain Adaptation OT method based on Earth Mover's Distance with Laplacian regularization
1609-
1610-
Parameters
1611-
----------
1612-
reg_lap : float, optional (default=1)
1613-
Laplacian regularization parameter
1614-
reg_src : float, optional (default=1)
1615-
Source relative importance in regularization
1616-
metric : string, optional (default="sqeuclidean")
1617-
The ground metric for the Wasserstein problem
1618-
norm : string, optional (default=None)
1619-
If given, normalize the ground metric to avoid numerical errors that
1620-
can occur with large metric values.
1621-
similarity : string, optional (default="knn")
1622-
The similarity to use, either "knn" or "gauss"
1623-
max_iter : int, optional (default=100)
1624-
Max number of BCD iterations
1625-
tol : float, optional (default=1e-9)
1626-
Stop threshold on relative loss decrease (>0)
1627-
max_inner_iter : int, optional (default=100000)
1628-
Max number of iterations (inner CG solver)
1629-
inner_tol : float, optional (default=1e-9)
1630-
Stop threshold on error (inner CG solver) (>0)
1631-
log : bool, optional (default=False)
1632-
Controls the logs of the optimization algorithm
1633-
distribution_estimation : callable, optional (defaults to the uniform)
1634-
The kind of distribution estimation to employ
1635-
out_of_sample_map : string, optional (default="ferradans")
1636-
The kind of out of sample mapping to apply to transport samples
1637-
from a domain into another one. Currently the only possible option is
1638-
"ferradans" which uses the method proposed in [6].
1639-
1640-
Attributes
1641-
----------
1642-
coupling_ : array-like, shape (n_source_samples, n_target_samples)
1643-
The optimal coupling
1644-
1645-
References
1646-
----------
1647-
.. [1] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy,
1648-
"Optimal Transport for Domain Adaptation," in IEEE Transactions
1649-
on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1
1650-
"""
1651-
1652-
def __init__(self, reg_lap=1., reg_src=1., alpha=0.5,
1653-
metric="sqeuclidean", norm=None, similarity="knn", max_iter=100, tol=1e-9,
1654-
max_inner_iter=100000, inner_tol=1e-9, log=False, verbose=False,
1655-
distribution_estimation=distribution_estimation_uniform,
1656-
out_of_sample_map='ferradans'):
1657-
self.reg_lap = reg_lap
1658-
self.reg_src = reg_src
1659-
self.alpha = alpha
1660-
self.metric = metric
1661-
self.norm = norm
1662-
self.similarity = similarity
1663-
self.max_iter = max_iter
1664-
self.tol = tol
1665-
self.max_inner_iter = max_inner_iter
1666-
self.inner_tol = inner_tol
1667-
self.log = log
1668-
self.verbose = verbose
1669-
self.distribution_estimation = distribution_estimation
1670-
self.out_of_sample_map = out_of_sample_map
1671-
1672-
def fit(self, Xs, ys=None, Xt=None, yt=None):
1673-
"""Build a coupling matrix from source and target sets of samples
1674-
(Xs, ys) and (Xt, yt)
1675-
1676-
Parameters
1677-
----------
1678-
Xs : array-like, shape (n_source_samples, n_features)
1679-
The training input samples.
1680-
ys : array-like, shape (n_source_samples,)
1681-
The class labels
1682-
Xt : array-like, shape (n_target_samples, n_features)
1683-
The training input samples.
1684-
yt : array-like, shape (n_target_samples,)
1685-
The class labels. If some target samples are unlabeled, fill the
1686-
yt's elements with -1.
1687-
1688-
Warning: Note that, due to this convention -1 cannot be used as a
1689-
class label
1690-
1691-
Returns
1692-
-------
1693-
self : object
1694-
Returns self.
1695-
"""
1696-
1697-
super(EMDLaplaceTransport, self).fit(Xs, ys, Xt, yt)
1698-
1699-
returned_ = emd_laplace(a=self.mu_s, b=self.mu_t, xs=self.xs_,
1700-
xt=self.xt_, M=self.cost_, sim=self.similarity, eta=self.reg_lap, alpha=self.reg_src,
1701-
numItermax=self.max_iter, stopThr=self.tol, numInnerItermax=self.max_inner_iter,
1702-
stopInnerThr=self.inner_tol, log=self.log, verbose=self.verbose)
1703-
1704-
# coupling estimation
1705-
if self.log:
1706-
self.coupling_, self.log_ = returned_
1707-
else:
1708-
self.coupling_ = returned_
1709-
self.log_ = dict()
1710-
return self
1711-
1712-
17131497
class SinkhornL1l2Transport(BaseTransport):
17141498

17151499
"""Domain Adaptation OT method based on sinkhorn algorithm +

0 commit comments

Comments
 (0)