Skip to content

Commit a2545b5

Browse files
Kilian FatrasKilian Fatras
authored andcommitted
add empirical sinkhorn and sinkhorn divergence functions
1 parent 2384380 commit a2545b5

File tree

4 files changed

+354
-0
lines changed

4 files changed

+354
-0
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,3 +230,5 @@ You can also post bug reports and feature requests in Github issues. Make sure t
230230
[21] Solomon, J., De Goes, F., Peyré, G., Cuturi, M., Butscher, A., Nguyen, A. & Guibas, L. (2015). [Convolutional wasserstein distances: Efficient optimal transportation on geometric domains](https://dl.acm.org/citation.cfm?id=2766963). ACM Transactions on Graphics (TOG), 34(4), 66.
231231

232232
[22] J. Altschuler, J.Weed, P. Rigollet, (2017) [Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration](https://papers.nips.cc/paper/6792-near-linear-time-approximation-algorithms-for-optimal-transport-via-sinkhorn-iteration.pdf), Advances in Neural Information Processing Systems (NIPS) 31
233+
234+
[23] Aude, G., Peyré, G., Cuturi, M., [Learning Generative Models with Sinkhorn Divergences](https://arxiv.org/abs/1706.00292), Proceedings of the Twenty-First International Conference on Artficial Intelligence and Statistics, (AISTATS) 21, 2018

examples/plot_OT_2D_samples.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
"""
1111

1212
# Author: Remi Flamary <remi.flamary@unice.fr>
13+
# Kilian Fatras <kilian.fatras@irisa.fr>
1314
#
1415
# License: MIT License
1516

@@ -100,3 +101,28 @@
100101
pl.title('OT matrix Sinkhorn with samples')
101102

102103
pl.show()
104+
105+
106+
##############################################################################
107+
# Emprirical Sinkhorn
108+
# ----------------
109+
110+
#%% sinkhorn
111+
112+
# reg term
113+
lambd = 1e-3
114+
115+
Ges = ot.bregman.empirical_sinkhorn(xs, xt, lambd)
116+
117+
pl.figure(7)
118+
pl.imshow(Ges, interpolation='nearest')
119+
pl.title('OT matrix empirical sinkhorn')
120+
121+
pl.figure(8)
122+
ot.plot.plot2D_samples_mat(xs, xt, Ges, color=[.5, .5, 1])
123+
pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
124+
pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
125+
pl.legend(loc=0)
126+
pl.title('OT matrix Sinkhorn from samples')
127+
128+
pl.show()

ot/bregman.py

Lines changed: 269 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
# Author: Remi Flamary <remi.flamary@unice.fr>
77
# Nicolas Courty <ncourty@irisa.fr>
8+
# Kilian Fatras <kilian.fatras@irisa.fr>
89
#
910
# License: MIT License
1011

@@ -1296,3 +1297,271 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000,
12961297
return np.sum(K0, axis=1), log
12971298
else:
12981299
return np.sum(K0, axis=1)
1300+
1301+
1302+
def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, verbose=False, log=False, **kwargs):
    r'''
    Solve the entropic regularization optimal transport problem and return the
    OT matrix from empirical data

    The function solves the following optimization problem:

    .. math::
        \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)

        s.t. \gamma 1 = a

             \gamma^T 1 = b

             \gamma \geq 0
    where :

    - M is the (ns, nt) metric cost matrix computed from the samples
    - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
    - a and b are source and target weights (sum to 1)


    Parameters
    ----------
    X_s : np.ndarray (ns, d)
        samples in the source domain
    X_t : np.ndarray (nt, d)
        samples in the target domain
    reg : float
        Regularization term >0
    a : np.ndarray (ns,), optional
        samples weights in the source domain (uniform if None)
    b : np.ndarray (nt,), optional
        samples weights in the target domain (uniform if None)
    metric : str, optional
        metric name passed to scipy.spatial.distance.cdist to build the cost
        matrix (default 'sqeuclidean')
    numIterMax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on error (>0)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True


    Returns
    -------
    gamma : (ns x nt) ndarray
        Regularized optimal transportation matrix for the given parameters
    log : dict
        log dictionary returned only if log==True in parameters

    Examples
    --------

    >>> n_s = 2
    >>> n_t = 2
    >>> reg = 0.1
    >>> X_s = np.reshape(np.arange(n_s), (n_s, 1))
    >>> X_t = np.reshape(np.arange(0, n_t), (n_t, 1))
    >>> emp_sinkhorn = empirical_sinkhorn(X_s, X_t, reg, verbose=False)
    >>> print(emp_sinkhorn)
    >>> [[4.99977301e-01 2.26989344e-05]
        [2.26989344e-05 4.99977301e-01]]


    References
    ----------

    .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013

    .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.

    .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
    '''
    # Local import so the module-level imports of bregman.py are unchanged.
    from scipy.spatial.distance import cdist

    # Default to uniform weights over the empirical samples.
    # NOTE: this function lives in ot.bregman, so the name `ot` is not in
    # scope here; the previous ot.unif/ot.dist/ot.sinkhorn calls raised
    # NameError at runtime.
    if a is None:
        ns = np.shape(X_s)[0]
        a = np.full(ns, 1.0 / ns)
    if b is None:
        nt = np.shape(X_t)[0]
        b = np.full(nt, 1.0 / nt)

    # Pairwise cost matrix between source and target samples.
    M = cdist(X_s, X_t, metric=metric)

    # sinkhorn is defined in this same module; call it directly.
    if log:
        pi, log_dict = sinkhorn(a, b, M, reg, numItermax=numIterMax,
                                stopThr=stopThr, verbose=verbose, log=True,
                                **kwargs)
        return pi, log_dict
    return sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr,
                    verbose=verbose, log=False, **kwargs)
1389+
1390+
1391+
def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, verbose=False, log=False, **kwargs):
    r'''
    Solve the entropic regularization optimal transport problem from empirical
    data and return the OT loss


    The function solves the following optimization problem:

    .. math::
        W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)

        s.t. \gamma 1 = a

             \gamma^T 1 = b

             \gamma \geq 0
    where :

    - M is the (ns, nt) metric cost matrix computed from the samples
    - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
    - a and b are source and target weights (sum to 1)


    Parameters
    ----------
    X_s : np.ndarray (ns, d)
        samples in the source domain
    X_t : np.ndarray (nt, d)
        samples in the target domain
    reg : float
        Regularization term >0
    a : np.ndarray (ns,), optional
        samples weights in the source domain (uniform if None)
    b : np.ndarray (nt,), optional
        samples weights in the target domain (uniform if None)
    metric : str, optional
        metric name passed to scipy.spatial.distance.cdist to build the cost
        matrix (default 'sqeuclidean')
    numIterMax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on error (>0)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True


    Returns
    -------
    W : (1,) ndarray
        Regularized optimal transport loss for the given parameters
    log : dict
        log dictionary returned only if log==True in parameters

    Examples
    --------

    >>> n_s = 2
    >>> n_t = 2
    >>> reg = 0.1
    >>> X_s = np.reshape(np.arange(n_s), (n_s, 1))
    >>> X_t = np.reshape(np.arange(0, n_t), (n_t, 1))
    >>> loss_sinkhorn = empirical_sinkhorn2(X_s, X_t, reg, verbose=False)
    >>> print(loss_sinkhorn)
    >>> [4.53978687e-05]


    References
    ----------

    .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013

    .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.

    .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
    '''
    # Local import so the module-level imports of bregman.py are unchanged.
    from scipy.spatial.distance import cdist

    # Default to uniform weights over the empirical samples.
    # NOTE: this function lives in ot.bregman, so the name `ot` is not in
    # scope here; the previous ot.unif/ot.dist/ot.sinkhorn2 calls raised
    # NameError at runtime.
    if a is None:
        ns = np.shape(X_s)[0]
        a = np.full(ns, 1.0 / ns)
    if b is None:
        nt = np.shape(X_t)[0]
        b = np.full(nt, 1.0 / nt)

    # Pairwise cost matrix between source and target samples.
    M = cdist(X_s, X_t, metric=metric)

    # sinkhorn2 is defined in this same module; call it directly.
    if log:
        sinkhorn_loss, log_dict = sinkhorn2(a, b, M, reg,
                                            numItermax=numIterMax,
                                            stopThr=stopThr, verbose=verbose,
                                            log=True, **kwargs)
        return sinkhorn_loss, log_dict
    return sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr,
                     verbose=verbose, log=False, **kwargs)
1479+
1480+
1481+
def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, verbose=False, log=False, **kwargs):
    r'''
    Compute the sinkhorn divergence loss from empirical data

    The function solves the following optimization problem:

    .. math::
        S = 2 \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) -
            \min_{\gamma_a} <\gamma_a,M_a>_F + reg\cdot\Omega(\gamma_a) -
            \min_{\gamma_b} <\gamma_b,M_b>_F + reg\cdot\Omega(\gamma_b)

        s.t. \gamma 1 = a

             \gamma^T 1 = b

             \gamma \geq 0

             \gamma_a 1 = a

             \gamma_a^T 1 = a

             \gamma_a \geq 0

             \gamma_b 1 = b

             \gamma_b^T 1 = b

             \gamma_b \geq 0
    where :

    - :math:`M` (resp. :math:`M_a`, :math:`M_b`) is the (ns, nt) metric cost matrix (resp. (ns, ns) and (nt, nt))
    - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
    - a and b are source and target weights (sum to 1)


    Parameters
    ----------
    X_s : np.ndarray (ns, d)
        samples in the source domain
    X_t : np.ndarray (nt, d)
        samples in the target domain
    reg : float
        Regularization term >0
    a : np.ndarray (ns,), optional
        samples weights in the source domain (uniform if None)
    b : np.ndarray (nt,), optional
        samples weights in the target domain (uniform if None)
    metric : str, optional
        metric name passed to scipy.spatial.distance.cdist to build the cost
        matrices
    numIterMax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on error (>0)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        kept for API symmetry; the inner solvers are always run without a log
        because the three losses are combined arithmetically


    Returns
    -------
    W : (1,) ndarray
        Sinkhorn divergence (clamped at 0) for the given parameters

    Examples
    --------

    >>> n_s = 2
    >>> n_t = 4
    >>> reg = 0.1
    >>> X_s = np.reshape(np.arange(n_s), (n_s, 1))
    >>> X_t = np.reshape(np.arange(0, n_t), (n_t, 1))
    >>> emp_sinkhorn_div = empirical_sinkhorn_divergence(X_s, X_t, reg)
    >>> print(emp_sinkhorn_div)
    >>> [2.99977435]


    References
    ----------

    .. [23] Aude Genevay, Gabriel Peyré, Marco Cuturi, Learning Generative Models with Sinkhorn Divergences, Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics, (AISTATS) 21, 2018
    '''
    # Cross term OT(a, b) plus the two autocorrelation terms OT(a, a) and
    # OT(b, b).  The (X_s, X_s) term must use weights (a, a) and the
    # (X_t, X_t) term weights (b, b) -- passing (a, b) to both was a bug
    # (and fails outright when ns != nt).
    # stopThr is forwarded instead of the previously hard-coded 1e-9, and
    # log=False is forced on the inner calls: with log=True they would
    # return (loss, dict) tuples and break the arithmetic below.
    sinkhorn_loss_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric,
                                           numIterMax=numIterMax,
                                           stopThr=stopThr, verbose=verbose,
                                           log=False, **kwargs)
    sinkhorn_loss_a = empirical_sinkhorn2(X_s, X_s, reg, a, a, metric=metric,
                                          numIterMax=numIterMax,
                                          stopThr=stopThr, verbose=verbose,
                                          log=False, **kwargs)
    sinkhorn_loss_b = empirical_sinkhorn2(X_t, X_t, reg, b, b, metric=metric,
                                          numIterMax=numIterMax,
                                          stopThr=stopThr, verbose=verbose,
                                          log=False, **kwargs)
    sinkhorn_div = 2 * sinkhorn_loss_ab - sinkhorn_loss_a - sinkhorn_loss_b
    # The entropic divergence estimate can be slightly negative numerically;
    # clamp it at zero.
    return max(0, sinkhorn_div)

test/test_bregman.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Tests for module bregman on OT with bregman projections """
22

33
# Author: Remi Flamary <remi.flamary@unice.fr>
4+
# Kilian Fatras <kilian.fatras@irisa.fr>
45
#
56
# License: MIT License
67

@@ -187,3 +188,59 @@ def test_unmix():
187188

188189
ot.bregman.unmix(a, D, M, M0, h0, reg,
189190
1, alpha=0.01, log=True, verbose=True)
191+
192+
193+
def test_empirical_sinkhorn():
    # Check that empirical_sinkhorn / empirical_sinkhorn2 match the plain
    # sinkhorn / sinkhorn2 solvers run on a precomputed cost matrix.
    n = 100
    a = ot.unif(n)
    b = ot.unif(n)

    # Samples must be defined BEFORE the cost matrices are built (the
    # original test used X_s/X_t here before assigning them: NameError).
    X_s = np.reshape(np.arange(n), (n, 1))
    X_t = np.reshape(np.arange(0, n), (n, 1))

    M = ot.dist(X_s, X_t)
    M_e = ot.dist(X_s, X_t, metric='euclidean')

    G_sqe = ot.bregman.empirical_sinkhorn(X_s, X_t, 1)
    sinkhorn_sqe = ot.sinkhorn(a, b, M, 1)

    # The euclidean comparison must actually use the euclidean metric,
    # otherwise it just repeats the sqeuclidean case.
    G_e = ot.bregman.empirical_sinkhorn(X_s, X_t, 1, metric='euclidean')
    sinkhorn_e = ot.sinkhorn(a, b, M_e, 1)

    loss_emp_sinkhorn = ot.bregman.empirical_sinkhorn2(X_s, X_t, 1)
    loss_sinkhorn = ot.sinkhorn2(a, b, M, 1)

    # check constraints
    np.testing.assert_allclose(
        sinkhorn_sqe.sum(1), G_sqe.sum(1), atol=1e-05)  # metric sqeuclidian
    np.testing.assert_allclose(
        sinkhorn_sqe.sum(0), G_sqe.sum(0), atol=1e-05)  # metric sqeuclidian
    np.testing.assert_allclose(
        sinkhorn_e.sum(1), G_e.sum(1), atol=1e-05)  # metric euclidian
    np.testing.assert_allclose(
        sinkhorn_e.sum(0), G_e.sum(0), atol=1e-05)  # metric euclidian
    np.testing.assert_allclose(loss_emp_sinkhorn, loss_sinkhorn, atol=1e-05)
225+
226+
227+
def test_empirical_sinkhorn_divergence():
    # Check empirical_sinkhorn_divergence against the explicit combination
    # of three sinkhorn2 losses.
    n = 10
    a = ot.unif(n)
    b = ot.unif(n)
    X_s = np.reshape(np.arange(n), (n, 1))
    X_t = np.reshape(np.arange(0, n * 2, 2), (n, 1))
    M = ot.dist(X_s, X_t)
    M_s = ot.dist(X_s, X_s)
    M_t = ot.dist(X_t, X_t)

    # Must be called through ot.bregman: this test module does not import
    # the function name directly (unqualified call was a NameError).
    emp_sinkhorn_div = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, 1)
    sinkhorn_div = (2 * ot.sinkhorn2(a, b, M, 1) - ot.sinkhorn2(a, a, M_s, 1) -
                    ot.sinkhorn2(b, b, M_t, 1))

    # check that the empirical estimate matches the explicit computation
    # (the original test asserted this twice verbatim; once is enough)
    np.testing.assert_allclose(
        emp_sinkhorn_div, sinkhorn_div, atol=1e-05)  # cf conv emp sinkhorn

0 commit comments

Comments
 (0)