Skip to content

Commit 74cfe5a

Browse files
author
Kilian Fatras
committed
add sgd
1 parent 055417e commit 74cfe5a

File tree

4 files changed

+608
-67
lines changed

4 files changed

+608
-67
lines changed

README.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ It provides the following solvers:
2222
* Linear OT [14] and Joint OT matrix and mapping estimation [8].
2323
* Wasserstein Discriminant Analysis [11] (requires autograd + pymanopt).
2424
* Gromov-Wasserstein distances and barycenters ([13] and regularized [12])
25+
* Stochastic Optimization for Large-scale Optimal Transport (semi-dual problem [18] and dual problem [19])
2526

2627
Some demonstrations (both in Python and Jupyter Notebook format) are available in the examples folder.
2728

@@ -149,6 +150,7 @@ The contributors to this library are:
149150
* [Stanislas Chambon](https://slasnista.github.io/)
150151
* [Antoine Rolet](https://arolet.github.io/)
151152
* Erwan Vautier (Gromov-Wasserstein)
153+
* [Kilian Fatras](https://kilianfatras.github.io/) (Stochastic optimization)
152154

153155
This toolbox benefit a lot from open source research and we would like to thank the following persons for providing some code (in various languages):
154156

@@ -213,3 +215,9 @@ You can also post bug reports and feature requests in Github issues. Make sure t
213215
[15] Peyré, G., & Cuturi, M. (2018). [Computational Optimal Transport](https://arxiv.org/pdf/1803.00567.pdf) .
214216

215217
[16] Agueh, M., & Carlier, G. (2011). [Barycenters in the Wasserstein space](https://hal.archives-ouvertes.fr/hal-00637399/document). SIAM Journal on Mathematical Analysis, 43(2), 904-924.
218+
219+
[17] Blondel, M., Seguy, V., & Rolet, A. (2018). [Smooth and Sparse Optimal Transport](https://arxiv.org/pdf/1710.06276.pdf). Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics (AISTATS).
220+
221+
[18] Genevay, A., Cuturi, M., Peyré, G. & Bach, F. (2016) [Stochastic Optimization for Large-scale Optimal Transport](arXiv preprint arxiv:1605.08527). Advances in Neural Information Processing Systems (2016).
222+
223+
[19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A.& Blondel, M. [Large-scale Optimal Transport and Mapping Estimation](https://arxiv.org/pdf/1711.02283.pdf). International Conference on Learning Representation (2018)

examples/plot_stochastic.py

Lines changed: 83 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@
1818

1919

2020
#############################################################################
21-
# COMPUTE TRANSPORTATION MATRIX
21+
# COMPUTE TRANSPORTATION MATRIX FOR SEMI-DUAL PROBLEM
2222
#############################################################################
23-
23+
print("------------SEMI-DUAL PROBLEM------------")
2424
#############################################################################
2525
# DISCRETE CASE
2626
# Sample two discrete measures for the discrete case
@@ -48,12 +48,12 @@
4848
# Call the "SAG" method to find the transportation matrix in the discrete case
4949
# ---------------------------------------------
5050
#
51-
# Define the method "SAG", call ot.transportation_matrix_entropic and plot the
51+
# Define the method "SAG", call ot.solve_semi_dual_entropic and plot the
5252
# results.
5353

5454
method = "SAG"
55-
sag_pi = ot.stochastic.transportation_matrix_entropic(a, b, M, reg, method,
56-
numItermax, lr)
55+
sag_pi = ot.stochastic.solve_semi_dual_entropic(a, b, M, reg, method,
56+
numItermax, lr)
5757
print(sag_pi)
5858

5959
#############################################################################
@@ -68,8 +68,9 @@
6868
n_source = 7
6969
n_target = 4
7070
reg = 1
71-
numItermax = 500000
71+
numItermax = 100000
7272
lr = 1
73+
log = True
7374

7475
a = ot.utils.unif(n_source)
7576
b = ot.utils.unif(n_target)
@@ -85,12 +86,13 @@
8586
# case
8687
# ---------------------------------------------
8788
#
88-
# Define the method "ASGD", call ot.transportation_matrix_entropic and plot the
89+
# Define the method "ASGD", call ot.solve_semi_dual_entropic and plot the
8990
# results.
9091

9192
method = "ASGD"
92-
asgd_pi = ot.stochastic.transportation_matrix_entropic(a, b, M, reg, method,
93-
numItermax, lr)
93+
asgd_pi, log = ot.stochastic.solve_semi_dual_entropic(a, b, M, reg, method,
94+
numItermax, lr, log)
95+
print(log['alpha'], log['beta'])
9496
print(asgd_pi)
9597

9698
#############################################################################
@@ -100,7 +102,7 @@
100102
#
101103
# Call the Sinkhorn algorithm from POT
102104

103-
sinkhorn_pi = ot.sinkhorn(a, b, M, 1)
105+
sinkhorn_pi = ot.sinkhorn(a, b, M, reg)
104106
print(sinkhorn_pi)
105107

106108

@@ -113,7 +115,7 @@
113115
# ----------------
114116

115117
pl.figure(4, figsize=(5, 5))
116-
ot.plot.plot1D_mat(a, b, sag_pi, 'OT matrix SAG')
118+
ot.plot.plot1D_mat(a, b, sag_pi, 'semi-dual : OT matrix SAG')
117119
pl.show()
118120

119121

@@ -122,7 +124,76 @@
122124
# -----------------
123125

124126
pl.figure(4, figsize=(5, 5))
125-
ot.plot.plot1D_mat(a, b, asgd_pi, 'OT matrix ASGD')
127+
ot.plot.plot1D_mat(a, b, asgd_pi, 'semi-dual : OT matrix ASGD')
128+
pl.show()
129+
130+
131+
##############################################################################
132+
# Plot Sinkhorn results
133+
# ---------------------
134+
135+
pl.figure(4, figsize=(5, 5))
136+
ot.plot.plot1D_mat(a, b, sinkhorn_pi, 'OT matrix Sinkhorn')
137+
pl.show()
138+
139+
#############################################################################
140+
# COMPUTE TRANSPORTATION MATRIX FOR DUAL PROBLEM
141+
#############################################################################
142+
print("------------DUAL PROBLEM------------")
143+
#############################################################################
144+
# SEMICONTINOUS CASE
145+
# Sample one general measure a, one discrete measures b for the semicontinous
146+
# case
147+
# ---------------------------------------------
148+
#
149+
# Define one general measure a, one discrete measures b, the points where
150+
# are defined the source and the target measures and finally the cost matrix c.
151+
152+
n_source = 7
153+
n_target = 4
154+
reg = 1
155+
numItermax = 100000
156+
lr = 0.1
157+
batch_size = 3
158+
log = True
159+
160+
a = ot.utils.unif(n_source)
161+
b = ot.utils.unif(n_target)
162+
163+
rng = np.random.RandomState(0)
164+
X_source = rng.randn(n_source, 2)
165+
Y_target = rng.randn(n_target, 2)
166+
M = ot.dist(X_source, Y_target)
167+
168+
#############################################################################
169+
#
170+
# Call the "SGD" dual method to find the transportation matrix in the semicontinous
171+
# case
172+
# ---------------------------------------------
173+
#
174+
# Call ot.solve_dual_entropic and plot the results.
175+
176+
sgd_dual_pi, log = ot.stochastic.solve_dual_entropic(a, b, M, reg, batch_size,
177+
numItermax, lr, log)
178+
print(log['alpha'], log['beta'])
179+
print(sgd_dual_pi)
180+
181+
#############################################################################
182+
#
183+
# Compare the results with the Sinkhorn algorithm
184+
# ---------------------------------------------
185+
#
186+
# Call the Sinkhorn algorithm from POT
187+
188+
sinkhorn_pi = ot.sinkhorn(a, b, M, reg)
189+
print(sinkhorn_pi)
190+
191+
##############################################################################
192+
# Plot SGD results
193+
# -----------------
194+
195+
pl.figure(4, figsize=(5, 5))
196+
ot.plot.plot1D_mat(a, b, sgd_dual_pi, 'dual : OT matrix SGD')
126197
pl.show()
127198

128199

0 commit comments

Comments
 (0)