Skip to content

Commit d7c709e

Browse files
jakubzadroznyJakub Zadrożnyrflamary
authored
[MRG] Implement Sinkhorn in log-domain for WDA (#336)
* [MRG] Implement Sinkhorn in log-domain for WDA * for small values of the regularization parameter (reg) the current implementation runs into numerical issues (nans and infs) * this can be resolved by using log-domain implementation of the sinkhorn algorithm * Add feature to RELEASES and contributor name * Add 'sinkhorn_method' parameter to WDA * use the standard Sinkhorn solver by default (faster) * use log-domain Sinkhorn if asked by the user Co-authored-by: Jakub Zadrożny <jz@qed.ai> Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
1 parent 263c584 commit d7c709e

File tree

3 files changed

+66
-2
lines changed

3 files changed

+66
-2
lines changed

RELEASES.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
#### New features
66

77
- Better list of related examples in quick start guide with `minigallery` (PR #334)
8+
- Add optional log-domain Sinkhorn implementation in WDA to support smaller values
9+
of the regularization parameter (PR #336)
810

911
#### Closed issues
1012

ot/dr.py

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
# Author: Remi Flamary <remi.flamary@unice.fr>
1313
# Minhui Huang <mhhuang@ucdavis.edu>
14+
# Jakub Zadrozny <jakub.r.zadrozny@gmail.com>
1415
#
1516
# License: MIT License
1617

@@ -43,6 +44,28 @@ def sinkhorn(w1, w2, M, reg, k):
4344
return G
4445

4546

47+
def logsumexp(M, axis):
    r"""Log-sum-exp reduction compatible with autograd (no numpy implementation)

    Computes ``log(sum(exp(M), axis=axis))`` after subtracting the maximum
    taken along ``axis``, so that the exponentials cannot overflow.

    Parameters
    ----------
    M : ndarray
        Input array.
    axis : int
        Axis along which the reduction is performed.

    Returns
    -------
    ndarray
        Reduced array, with ``axis`` removed.
    """
    amax = np.amax(M, axis=axis, keepdims=True)
    # When the maximum is not finite (slice made only of -inf, or containing
    # +inf) `M - amax` evaluates to `inf - inf = nan`. Shifting by 0 instead
    # lets the reduction return the mathematically expected -inf / +inf.
    amax = np.where(np.isfinite(amax), amax, 0.0)
    return np.log(np.sum(np.exp(M - amax), axis=axis)) + np.squeeze(amax, axis=axis)
52+
53+
54+
def sinkhorn_log(w1, w2, M, reg, k):
    r"""Sinkhorn algorithm in log-domain with fixed number of iterations (autograd)

    Runs ``k`` alternating updates of the log-scaling vectors on the
    log-kernel ``-M / reg`` and returns the resulting coupling matrix.

    Parameters
    ----------
    w1 : ndarray
        Weights whose logarithm drives the update along the first axis of M.
    w2 : ndarray
        Weights whose logarithm drives the update along the second axis of M.
    M : ndarray
        Cost matrix (two-dimensional).
    reg : float
        Regularization value dividing the cost matrix.
    k : int
        Number of iterations.

    Returns
    -------
    G : ndarray
        Matrix ``exp(u[:, None] - M / reg + v[None, :])`` after ``k`` updates.
    """
    log_kernel = -M / reg
    n_rows, n_cols = M.shape[0], M.shape[1]
    log_u = np.zeros((n_rows,))
    log_v = np.zeros((n_cols,))
    log_a, log_b = np.log(w1), np.log(w2)
    for _ in range(k):
        # second-axis scaling first, then first-axis scaling
        log_v = log_b - logsumexp(log_kernel + log_u[:, None], 0)
        log_u = log_a - logsumexp(log_kernel + log_v[None, :], 1)
    return np.exp(log_u[:, None] + log_kernel + log_v[None, :])
67+
68+
4669
def split_classes(X, y):
4770
r"""split samples in :math:`\mathbf{X}` by classes in :math:`\mathbf{y}`
4871
"""
@@ -110,7 +133,7 @@ def proj(X):
110133
return Popt, proj
111134

112135

113-
def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None, normalize=False):
136+
def wda(X, y, p=2, reg=1, k=10, solver=None, sinkhorn_method='sinkhorn', maxiter=100, verbose=0, P0=None, normalize=False):
114137
r"""
115138
Wasserstein Discriminant Analysis :ref:`[11] <references-wda>`
116139
@@ -126,6 +149,14 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None, no
126149
- :math:`W` is entropic regularized Wasserstein distances
127150
- :math:`\mathbf{X}^i` are samples in the dataset corresponding to class i
128151
152+
**Choosing a Sinkhorn solver**
153+
154+
By default and when using a regularization parameter that is not too small
155+
the default sinkhorn solver should be enough. If you need to use a small
156+
regularization to get sparser transport plans, you should use the
157+
:py:func:`ot.dr.sinkhorn_log` solver that will avoid numerical
158+
errors, but can be slow in practice.
159+
129160
Parameters
130161
----------
131162
X : ndarray, shape (n, d)
@@ -139,6 +170,8 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None, no
139170
solver : None | str, optional
140171
None for steepest descent or 'TrustRegions' for trust regions algorithm
141172
else should be a pymanopt.solvers
173+
sinkhorn_method : str
174+
method used for the Sinkhorn solver, either 'sinkhorn' or 'sinkhorn_log'
142175
P0 : ndarray, shape (d, p)
143176
Initial starting point for projection.
144177
normalize : bool, optional
@@ -161,6 +194,13 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None, no
161194
Wasserstein Discriminant Analysis. arXiv preprint arXiv:1608.08063.
162195
""" # noqa
163196

197+
if sinkhorn_method.lower() == 'sinkhorn':
198+
sinkhorn_solver = sinkhorn
199+
elif sinkhorn_method.lower() == 'sinkhorn_log':
200+
sinkhorn_solver = sinkhorn_log
201+
else:
202+
raise ValueError("Unknown Sinkhorn method '%s'." % sinkhorn_method)
203+
164204
mx = np.mean(X)
165205
X -= mx.reshape((1, -1))
166206

@@ -193,7 +233,7 @@ def cost(P):
193233
for j, xj in enumerate(xc[i:]):
194234
xj = np.dot(xj, P)
195235
M = dist(xi, xj)
196-
G = sinkhorn(wc[i], wc[j + i], M, reg * regmean[i, j], k)
236+
G = sinkhorn_solver(wc[i], wc[j + i], M, reg * regmean[i, j], k)
197237
if j == 0:
198238
loss_w += np.sum(G * M)
199239
else:

test/test_dr.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,28 @@ def test_wda():
6060
np.testing.assert_allclose(np.sum(Pwda**2, 0), np.ones(p))
6161

6262

63+
@pytest.mark.skipif(nogo, reason="Missing modules (autograd or pymanopt)")
def test_wda_low_reg():
    """Run WDA with a small regularization using the log-domain Sinkhorn."""
    n_samples = 100  # nb samples in source and target datasets
    n_noise_dims = 8
    n_components = 2

    np.random.seed(0)

    # gaussian dataset, then padded with pure-noise features
    data, labels = ot.datasets.make_data_classif('gaussrot', n_samples)
    noise = np.random.randn(n_samples, n_noise_dims)
    data = np.hstack((data, noise))

    P, project = ot.dr.wda(data, labels, n_components, reg=0.01, maxiter=10,
                           sinkhorn_method='sinkhorn_log')

    # the returned projection function must be callable on the data
    project(data)

    # every column of the projection matrix has unit squared norm
    np.testing.assert_allclose(np.sum(P**2, 0), np.ones(n_components))
83+
84+
6385
@pytest.mark.skipif(nogo, reason="Missing modules (autograd or pymanopt)")
6486
def test_wda_normalized():
6587

0 commit comments

Comments
 (0)