Skip to content

Commit 6777ffd

Browse files
author
Kilian Fatras
committed
gave better step size ASGD & SAG
1 parent 7073e41 commit 6777ffd

File tree

3 files changed

+22
-32
lines changed

3 files changed

+22
-32
lines changed

examples/plot_stochastic.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,7 @@
3232
n_source = 7
3333
n_target = 4
3434
reg = 1
35-
numItermax = 10000
36-
lr = 0.1
35+
numItermax = 1000
3736

3837
a = ot.utils.unif(n_source)
3938
b = ot.utils.unif(n_target)
@@ -53,7 +52,7 @@
5352

5453
method = "SAG"
5554
sag_pi = ot.stochastic.solve_semi_dual_entropic(a, b, M, reg, method,
56-
numItermax, lr)
55+
numItermax)
5756
print(sag_pi)
5857

5958
#############################################################################
@@ -68,8 +67,7 @@
6867
n_source = 7
6968
n_target = 4
7069
reg = 1
71-
numItermax = 100000
72-
lr = 1
70+
numItermax = 1000
7371
log = True
7472

7573
a = ot.utils.unif(n_source)
@@ -91,7 +89,7 @@
9189

9290
method = "ASGD"
9391
asgd_pi, log = ot.stochastic.solve_semi_dual_entropic(a, b, M, reg, method,
94-
numItermax, lr, log)
92+
numItermax, log)
9593
print(log['alpha'], log['beta'])
9694
print(asgd_pi)
9795

ot/stochastic.py

Lines changed: 16 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,6 @@ def coordinate_grad_semi_dual(b, M, reg, beta, i):
5656
>>> n_target = 4
5757
>>> reg = 1
5858
>>> numItermax = 300000
59-
>>> lr = 1
6059
>>> a = ot.utils.unif(n_source)
6160
>>> b = ot.utils.unif(n_target)
6261
>>> rng = np.random.RandomState(0)
@@ -65,8 +64,7 @@ def coordinate_grad_semi_dual(b, M, reg, beta, i):
6564
>>> M = ot.dist(X_source, Y_target)
6665
>>> method = "ASGD"
6766
>>> asgd_pi = stochastic.solve_semi_dual_entropic(a, b, M, reg,
68-
method, numItermax,
69-
lr)
67+
method, numItermax)
7068
>>> print(asgd_pi)
7169
7270
References
@@ -85,7 +83,7 @@ def coordinate_grad_semi_dual(b, M, reg, beta, i):
8583
return b - khi
8684

8785

88-
def sag_entropic_transport(a, b, M, reg, numItermax=10000, lr=0.1):
86+
def sag_entropic_transport(a, b, M, reg, numItermax=10000, lr=None):
8987
'''
9088
Compute the SAG algorithm to solve the regularized discrete measures
9189
optimal transport max problem
@@ -134,17 +132,15 @@ def sag_entropic_transport(a, b, M, reg, numItermax=10000, lr=0.1):
134132
>>> n_target = 4
135133
>>> reg = 1
136134
>>> numItermax = 300000
137-
>>> lr = 1
138135
>>> a = ot.utils.unif(n_source)
139136
>>> b = ot.utils.unif(n_target)
140137
>>> rng = np.random.RandomState(0)
141138
>>> X_source = rng.randn(n_source, 2)
142139
>>> Y_target = rng.randn(n_target, 2)
143140
>>> M = ot.dist(X_source, Y_target)
144-
>>> method = "SAG"
145-
>>> sag_pi = stochastic.solve_semi_dual_entropic(a, b, M, reg,
146-
method, numItermax,
147-
lr)
141+
>>> method = "ASGD"
142+
>>> asgd_pi = stochastic.solve_semi_dual_entropic(a, b, M, reg,
143+
method, numItermax)
148144
>>> print(asgd_pi)
149145
150146
References
@@ -156,6 +152,8 @@ def sag_entropic_transport(a, b, M, reg, numItermax=10000, lr=0.1):
156152
arXiv preprint arxiv:1605.08527.
157153
'''
158154

155+
if lr is None:
156+
lr = 1. / max(a)
159157
n_source = np.shape(M)[0]
160158
n_target = np.shape(M)[1]
161159
cur_beta = np.zeros(n_target)
@@ -171,7 +169,7 @@ def sag_entropic_transport(a, b, M, reg, numItermax=10000, lr=0.1):
171169
return cur_beta
172170

173171

174-
def averaged_sgd_entropic_transport(b, M, reg, numItermax=300000, lr=1):
172+
def averaged_sgd_entropic_transport(a, b, M, reg, numItermax=300000, lr=None):
175173
'''
176174
Compute the ASGD algorithm to solve the regularized semi contibous measures
177175
optimal transport max problem
@@ -219,7 +217,6 @@ def averaged_sgd_entropic_transport(b, M, reg, numItermax=300000, lr=1):
219217
>>> n_target = 4
220218
>>> reg = 1
221219
>>> numItermax = 300000
222-
>>> lr = 1
223220
>>> a = ot.utils.unif(n_source)
224221
>>> b = ot.utils.unif(n_target)
225222
>>> rng = np.random.RandomState(0)
@@ -228,8 +225,7 @@ def averaged_sgd_entropic_transport(b, M, reg, numItermax=300000, lr=1):
228225
>>> M = ot.dist(X_source, Y_target)
229226
>>> method = "ASGD"
230227
>>> asgd_pi = stochastic.solve_semi_dual_entropic(a, b, M, reg,
231-
method, numItermax,
232-
lr)
228+
method, numItermax)
233229
>>> print(asgd_pi)
234230
235231
References
@@ -241,6 +237,8 @@ def averaged_sgd_entropic_transport(b, M, reg, numItermax=300000, lr=1):
241237
arXiv preprint arxiv:1605.08527.
242238
'''
243239

240+
if lr is None:
241+
lr = 1. / max(a)
244242
n_source = np.shape(M)[0]
245243
n_target = np.shape(M)[1]
246244
cur_beta = np.zeros(n_target)
@@ -296,7 +294,6 @@ def c_transform_entropic(b, M, reg, beta):
296294
>>> n_target = 4
297295
>>> reg = 1
298296
>>> numItermax = 300000
299-
>>> lr = 1
300297
>>> a = ot.utils.unif(n_source)
301298
>>> b = ot.utils.unif(n_target)
302299
>>> rng = np.random.RandomState(0)
@@ -305,8 +302,7 @@ def c_transform_entropic(b, M, reg, beta):
305302
>>> M = ot.dist(X_source, Y_target)
306303
>>> method = "ASGD"
307304
>>> asgd_pi = stochastic.solve_semi_dual_entropic(a, b, M, reg,
308-
method, numItermax,
309-
lr)
305+
method, numItermax)
310306
>>> print(asgd_pi)
311307
312308
References
@@ -328,7 +324,7 @@ def c_transform_entropic(b, M, reg, beta):
328324
return alpha
329325

330326

331-
def solve_semi_dual_entropic(a, b, M, reg, method, numItermax=10000, lr=0.1,
327+
def solve_semi_dual_entropic(a, b, M, reg, method, numItermax=10000, lr=None,
332328
log=False):
333329
'''
334330
Compute the transportation matrix to solve the regularized discrete
@@ -388,7 +384,6 @@ def solve_semi_dual_entropic(a, b, M, reg, method, numItermax=10000, lr=0.1,
388384
>>> n_target = 4
389385
>>> reg = 1
390386
>>> numItermax = 300000
391-
>>> lr = 1
392387
>>> a = ot.utils.unif(n_source)
393388
>>> b = ot.utils.unif(n_target)
394389
>>> rng = np.random.RandomState(0)
@@ -397,8 +392,7 @@ def solve_semi_dual_entropic(a, b, M, reg, method, numItermax=10000, lr=0.1,
397392
>>> M = ot.dist(X_source, Y_target)
398393
>>> method = "ASGD"
399394
>>> asgd_pi = stochastic.solve_semi_dual_entropic(a, b, M, reg,
400-
method, numItermax,
401-
lr)
395+
method, numItermax)
402396
>>> print(asgd_pi)
403397
404398
References
@@ -409,10 +403,11 @@ def solve_semi_dual_entropic(a, b, M, reg, method, numItermax=10000, lr=0.1,
409403
Advances in Neural Information Processing Systems (2016),
410404
arXiv preprint arxiv:1605.08527.
411405
'''
406+
412407
if method.lower() == "sag":
413408
opt_beta = sag_entropic_transport(a, b, M, reg, numItermax, lr)
414409
elif method.lower() == "asgd":
415-
opt_beta = averaged_sgd_entropic_transport(b, M, reg, numItermax, lr)
410+
opt_beta = averaged_sgd_entropic_transport(a, b, M, reg, numItermax, lr)
416411
else:
417412
print("Please, select your method between SAG and ASGD")
418413
return None

test/test_stochastic.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,6 @@ def test_stochastic_asgd():
6363
n = 15
6464
reg = 1
6565
numItermax = 300000
66-
lr = 1
6766
rng = np.random.RandomState(0)
6867

6968
x = rng.randn(n, 2)
@@ -72,8 +71,7 @@ def test_stochastic_asgd():
7271
M = ot.dist(x, x)
7372

7473
G = ot.stochastic.solve_semi_dual_entropic(u, u, M, reg, "asgd",
75-
numItermax=numItermax,
76-
lr=lr)
74+
numItermax=numItermax)
7775

7876
# check constratints
7977
np.testing.assert_allclose(
@@ -95,7 +93,6 @@ def test_sag_asgd_sinkhorn():
9593
n = 15
9694
reg = 1
9795
nb_iter = 300000
98-
lr = 1
9996
rng = np.random.RandomState(0)
10097

10198
x = rng.randn(n, 2)
@@ -104,7 +101,7 @@ def test_sag_asgd_sinkhorn():
104101
M = ot.dist(x, x)
105102

106103
G_asgd = ot.stochastic.solve_semi_dual_entropic(u, u, M, reg, "asgd",
107-
numItermax=nb_iter, lr=lr)
104+
numItermax=nb_iter)
108105
G_sag = ot.stochastic.solve_semi_dual_entropic(u, u, M, reg, "sag",
109106
numItermax=nb_iter)
110107
G_sinkhorn = ot.sinkhorn(u, u, M, reg)

0 commit comments

Comments
 (0)