Commit 39cbcd3

Merge pull request #52 from kilianFatras/stochastic_OT
Add semi-dual and dual stochastic optimization for entropic regularization.
2 parents 327b0c6 + b4bc861 commit 39cbcd3

6 files changed: +1216 -1 lines changed

README.md

Lines changed: 6 additions & 0 deletions
@@ -23,6 +23,7 @@ It provides the following solvers:
 * Linear OT [14] and Joint OT matrix and mapping estimation [8].
 * Wasserstein Discriminant Analysis [11] (requires autograd + pymanopt).
 * Gromov-Wasserstein distances and barycenters ([13] and regularized [12])
+* Stochastic Optimization for Large-scale Optimal Transport (semi-dual problem [18] and dual problem [19])

 Some demonstrations (both in Python and Jupyter Notebook format) are available in the examples folder.

@@ -162,6 +163,7 @@ The contributors to this library are:
 * [Stanislas Chambon](https://slasnista.github.io/)
 * [Antoine Rolet](https://arolet.github.io/)
 * Erwan Vautier (Gromov-Wasserstein)
+* [Kilian Fatras](https://kilianfatras.github.io/) (Stochastic optimization)

 This toolbox benefits a lot from open source research and we would like to thank the following persons for providing some code (in various languages):

@@ -219,3 +221,7 @@ You can also post bug reports and feature requests in Github issues. Make sure t
 [16] Agueh, M., & Carlier, G. (2011). [Barycenters in the Wasserstein space](https://hal.archives-ouvertes.fr/hal-00637399/document). SIAM Journal on Mathematical Analysis, 43(2), 904-924.

 [17] Blondel, M., Seguy, V., & Rolet, A. (2018). [Smooth and Sparse Optimal Transport](https://arxiv.org/abs/1710.06276). Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics (AISTATS).
+
+[18] Genevay, A., Cuturi, M., Peyré, G., & Bach, F. (2016). [Stochastic Optimization for Large-scale Optimal Transport](https://arxiv.org/abs/1605.08527). Advances in Neural Information Processing Systems.
+
+[19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A., & Blondel, M. (2018). [Large-scale Optimal Transport and Mapping Estimation](https://arxiv.org/pdf/1711.02283.pdf). International Conference on Learning Representations (ICLR).
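
The semi-dual approach referenced as [18] can be illustrated without POT. The sketch below is a NumPy-only approximation of averaged stochastic gradient ascent on the entropic semi-dual. It is not the code added to `ot/stochastic.py` in this commit: the helper names (`semi_dual_grad`, `averaged_sgd`, `plan_from_beta`), the `lr / sqrt(k)` step schedule, and sampling source points according to `a` are all assumptions made for illustration.

```python
import numpy as np


def semi_dual_grad(b, M, reg, beta, i):
    """Stochastic gradient of the semi-dual w.r.t. beta for source point i."""
    r = (beta - M[i]) / reg
    w = b * np.exp(r - r.max())  # shift by the max for numerical stability
    return b - w / w.sum()


def averaged_sgd(a, b, M, reg, n_iter=2000, lr=1.0, seed=0):
    """Averaged stochastic gradient ascent; returns the averaged beta."""
    rng = np.random.RandomState(seed)
    beta = np.zeros(M.shape[1])
    beta_avg = np.zeros(M.shape[1])
    for k in range(1, n_iter + 1):
        i = rng.choice(M.shape[0], p=a)  # assumption: sample sources from a
        beta += (lr / np.sqrt(k)) * semi_dual_grad(b, M, reg, beta, i)
        beta_avg += (beta - beta_avg) / k  # running mean of the iterates
    return beta_avg


def plan_from_beta(a, b, M, reg, beta):
    """Rebuild a transport plan from beta via a row-wise softmax."""
    r = (beta[None, :] - M) / reg
    w = b[None, :] * np.exp(r - r.max(axis=1, keepdims=True))
    return a[:, None] * w / w.sum(axis=1, keepdims=True)


# Toy data shaped like the example script: 7 source and 4 target points.
rng = np.random.RandomState(0)
n_source, n_target, reg = 7, 4, 1.0
a = np.full(n_source, 1.0 / n_source)
b = np.full(n_target, 1.0 / n_target)
X = rng.randn(n_source, 2)
Y = rng.randn(n_target, 2)
M = ((X[:, None, :] - Y[None, :, :]) ** 2).sum(axis=2)  # squared Euclidean cost

beta = averaged_sgd(a, b, M, reg)
pi = plan_from_beta(a, b, M, reg, beta)
print(pi.sum(axis=1))  # row marginals
print(pi.sum(axis=0))  # column marginals, only approximately equal to b
```

Because the plan is rebuilt through a softmax over the targets, its row sums match `a` by construction; only the column marginals depend on how well `beta` has converged.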

docs/source/all.rst

Lines changed: 11 additions & 0 deletions
@@ -25,6 +25,11 @@ ot.smooth
 .. automodule:: ot.smooth
     :members:

+ot.smooth
+-----
+.. automodule:: ot.smooth
+    :members:
+
 ot.gromov
 ----------

@@ -68,3 +73,9 @@ ot.plot

 .. automodule:: ot.plot
     :members:
+
+ot.stochastic
+-------------
+
+.. automodule:: ot.stochastic
+    :members:

examples/plot_stochastic.py

Lines changed: 207 additions & 0 deletions
@@ -0,0 +1,207 @@
+"""
+==========================
+Stochastic examples
+==========================
+
+This example is designed to show how to use the stochastic optimization
+algorithms for discrete and semicontinuous measures from the POT library.
+
+"""
+
+# Author: Kilian Fatras <kilian.fatras@gmail.com>
+#
+# License: MIT License
+
+import matplotlib.pylab as pl
+import numpy as np
+import ot
+import ot.plot
+
+
+#############################################################################
+# COMPUTE TRANSPORTATION MATRIX FOR SEMI-DUAL PROBLEM
+#############################################################################
+print("------------SEMI-DUAL PROBLEM------------")
+#############################################################################
+# DISCRETE CASE
+# Sample two discrete measures for the discrete case
+# ---------------------------------------------
+#
+# Define two discrete measures a and b, the points on which the source and
+# target measures are defined, and finally the cost matrix M.
+
+n_source = 7
+n_target = 4
+reg = 1
+numItermax = 1000
+
+a = ot.utils.unif(n_source)
+b = ot.utils.unif(n_target)
+
+rng = np.random.RandomState(0)
+X_source = rng.randn(n_source, 2)
+Y_target = rng.randn(n_target, 2)
+M = ot.dist(X_source, Y_target)
+
+#############################################################################
+#
+# Call the "SAG" method to find the transportation matrix in the discrete case
+# ---------------------------------------------
+#
+# Define the method "SAG", call ot.stochastic.solve_semi_dual_entropic and
+# print the results.
+
+method = "SAG"
+sag_pi = ot.stochastic.solve_semi_dual_entropic(a, b, M, reg, method,
+                                                numItermax)
+print(sag_pi)
+
+#############################################################################
+# SEMICONTINUOUS CASE
+# Sample one general measure a, one discrete measure b for the semicontinuous
+# case
+# ---------------------------------------------
+#
+# Define one general measure a, one discrete measure b, the points on which
+# the source and target measures are defined, and finally the cost matrix M.
+
+n_source = 7
+n_target = 4
+reg = 1
+numItermax = 1000
+log = True
+
+a = ot.utils.unif(n_source)
+b = ot.utils.unif(n_target)
+
+rng = np.random.RandomState(0)
+X_source = rng.randn(n_source, 2)
+Y_target = rng.randn(n_target, 2)
+M = ot.dist(X_source, Y_target)
+
+#############################################################################
+#
+# Call the "ASGD" method to find the transportation matrix in the
+# semicontinuous case
+# ---------------------------------------------
+#
+# Define the method "ASGD", call ot.stochastic.solve_semi_dual_entropic and
+# print the results.
+
+method = "ASGD"
+asgd_pi, log_asgd = ot.stochastic.solve_semi_dual_entropic(a, b, M, reg, method,
+                                                           numItermax, log=log)
+print(log_asgd['alpha'], log_asgd['beta'])
+print(asgd_pi)
+
+#############################################################################
+#
+# Compare the results with the Sinkhorn algorithm
+# ---------------------------------------------
+#
+# Call the Sinkhorn algorithm from POT
+
+sinkhorn_pi = ot.sinkhorn(a, b, M, reg)
+print(sinkhorn_pi)
+
+
+##############################################################################
+# PLOT TRANSPORTATION MATRIX
+##############################################################################
+
+##############################################################################
+# Plot SAG results
+# ----------------
+
+pl.figure(4, figsize=(5, 5))
+ot.plot.plot1D_mat(a, b, sag_pi, 'semi-dual : OT matrix SAG')
+pl.show()
+
+
+##############################################################################
+# Plot ASGD results
+# -----------------
+
+pl.figure(4, figsize=(5, 5))
+ot.plot.plot1D_mat(a, b, asgd_pi, 'semi-dual : OT matrix ASGD')
+pl.show()
+
+
+##############################################################################
+# Plot Sinkhorn results
+# ---------------------
+
+pl.figure(4, figsize=(5, 5))
+ot.plot.plot1D_mat(a, b, sinkhorn_pi, 'OT matrix Sinkhorn')
+pl.show()
+
+
+#############################################################################
+# COMPUTE TRANSPORTATION MATRIX FOR DUAL PROBLEM
+#############################################################################
+print("------------DUAL PROBLEM------------")
+#############################################################################
+# SEMICONTINUOUS CASE
+# Sample one general measure a, one discrete measure b for the semicontinuous
+# case
+# ---------------------------------------------
+#
+# Define one general measure a, one discrete measure b, the points on which
+# the source and target measures are defined, and finally the cost matrix M.
+
+n_source = 7
+n_target = 4
+reg = 1
+numItermax = 100000
+lr = 0.1
+batch_size = 3
+log = True
+
+a = ot.utils.unif(n_source)
+b = ot.utils.unif(n_target)
+
+rng = np.random.RandomState(0)
+X_source = rng.randn(n_source, 2)
+Y_target = rng.randn(n_target, 2)
+M = ot.dist(X_source, Y_target)
+
+#############################################################################
+#
+# Call the "SGD" dual method to find the transportation matrix in the
+# semicontinuous case
+# ---------------------------------------------
+#
+# Call ot.stochastic.solve_dual_entropic and print the results.
+
+sgd_dual_pi, log_sgd = ot.stochastic.solve_dual_entropic(a, b, M, reg,
+                                                         batch_size, numItermax,
+                                                         lr, log=log)
+print(log_sgd['alpha'], log_sgd['beta'])
+print(sgd_dual_pi)
+
+#############################################################################
+#
+# Compare the results with the Sinkhorn algorithm
+# ---------------------------------------------
+#
+# Call the Sinkhorn algorithm from POT
+
+sinkhorn_pi = ot.sinkhorn(a, b, M, reg)
+print(sinkhorn_pi)
+
+##############################################################################
+# Plot SGD results
+# -----------------
+
+pl.figure(4, figsize=(5, 5))
+ot.plot.plot1D_mat(a, b, sgd_dual_pi, 'dual : OT matrix SGD')
+pl.show()
+
+
+##############################################################################
+# Plot Sinkhorn results
+# ---------------------
+
+pl.figure(4, figsize=(5, 5))
+ot.plot.plot1D_mat(a, b, sinkhorn_pi, 'OT matrix Sinkhorn')
+pl.show()
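
The example script compares the stochastic solvers with Sinkhorn only by printing matrices. A numerical check may be easier to read. The NumPy-only sketch below assumes the dual plan has the form pi_ij = a_i b_j exp((alpha_i + beta_j - M_ij) / reg); under that assumption it converts Sinkhorn scalings into dual potentials and checks that the full-batch gradient of the entropic dual vanishes there. The function names are hypothetical, and a minibatch method such as [19] would estimate this gradient from sampled rows and columns rather than computing it in full.

```python
import numpy as np


def sinkhorn_scalings(a, b, M, reg, n_iter=2000):
    """Plain Sinkhorn iterations; the plan is diag(u) @ K @ diag(v)."""
    K = np.exp(-M / reg)
    u = np.ones_like(a)
    for _ in range(n_iter):
        v = b / (K.T @ u)
        u = a / (K @ v)
    return u, v


def dual_plan(a, b, M, reg, alpha, beta):
    """Plan induced by dual potentials (assumed parametrization)."""
    expo = (alpha[:, None] + beta[None, :] - M) / reg
    return a[:, None] * b[None, :] * np.exp(expo)


def dual_grad(a, b, M, reg, alpha, beta):
    """Full-batch gradient of the entropic dual: the marginal residuals."""
    pi = dual_plan(a, b, M, reg, alpha, beta)
    return a - pi.sum(axis=1), b - pi.sum(axis=0)


# Toy data shaped like the example script: 7 source and 4 target points.
rng = np.random.RandomState(0)
n_source, n_target, reg = 7, 4, 1.0
a = np.full(n_source, 1.0 / n_source)
b = np.full(n_target, 1.0 / n_target)
X = rng.randn(n_source, 2)
Y = rng.randn(n_target, 2)
M = ((X[:, None, :] - Y[None, :, :]) ** 2).sum(axis=2)  # squared Euclidean cost

# Map Sinkhorn scalings to dual potentials: u_i = a_i * exp(alpha_i / reg).
u, v = sinkhorn_scalings(a, b, M, reg)
alpha = reg * np.log(u / a)
beta = reg * np.log(v / b)

g_alpha, g_beta = dual_grad(a, b, M, reg, alpha, beta)
print(np.abs(g_alpha).max(), np.abs(g_beta).max())
```

The same residuals, `a - pi.sum(axis=1)` and `b - pi.sum(axis=0)`, can be computed for any plan returned by a stochastic solver, which gives a single number to compare against Sinkhorn instead of two printed matrices.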

ot/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -19,7 +19,7 @@
 from . import da
 from . import gromov
 from . import smooth
-
+from . import stochastic

 # OT functions
 from .lp import emd, emd2
