@@ -56,7 +56,6 @@ def coordinate_grad_semi_dual(b, M, reg, beta, i):
5656 >>> n_target = 4
5757 >>> reg = 1
5858 >>> numItermax = 300000
59- >>> lr = 1
6059 >>> a = ot.utils.unif(n_source)
6160 >>> b = ot.utils.unif(n_target)
6261 >>> rng = np.random.RandomState(0)
@@ -65,8 +64,7 @@ def coordinate_grad_semi_dual(b, M, reg, beta, i):
6564 >>> M = ot.dist(X_source, Y_target)
6665 >>> method = "ASGD"
6766 >>> asgd_pi = stochastic.solve_semi_dual_entropic(a, b, M, reg,
68- method, numItermax,
69- lr)
67+ method, numItermax)
7068 >>> print(asgd_pi)
7169
7270 References
@@ -85,7 +83,7 @@ def coordinate_grad_semi_dual(b, M, reg, beta, i):
8583 return b - khi
8684
8785
88- def sag_entropic_transport (a , b , M , reg , numItermax = 10000 , lr = 0.1 ):
86+ def sag_entropic_transport (a , b , M , reg , numItermax = 10000 , lr = None ):
8987 '''
9088 Compute the SAG algorithm to solve the regularized discrete measures
9189 optimal transport max problem
@@ -134,17 +132,15 @@ def sag_entropic_transport(a, b, M, reg, numItermax=10000, lr=0.1):
134132 >>> n_target = 4
135133 >>> reg = 1
136134 >>> numItermax = 300000
137- >>> lr = 1
138135 >>> a = ot.utils.unif(n_source)
139136 >>> b = ot.utils.unif(n_target)
140137 >>> rng = np.random.RandomState(0)
141138 >>> X_source = rng.randn(n_source, 2)
142139 >>> Y_target = rng.randn(n_target, 2)
143140 >>> M = ot.dist(X_source, Y_target)
144- >>> method = "SAG"
145- >>> sag_pi = stochastic.solve_semi_dual_entropic(a, b, M, reg,
146- method, numItermax,
147- lr)
141+ >>> method = "ASGD"
142+ >>> asgd_pi = stochastic.solve_semi_dual_entropic(a, b, M, reg,
143+ method, numItermax)
148144 >>> print(asgd_pi)
149145
150146 References
@@ -156,6 +152,8 @@ def sag_entropic_transport(a, b, M, reg, numItermax=10000, lr=0.1):
156152 arXiv preprint arxiv:1605.08527.
157153 '''
158154
155+ if lr is None :
156+ lr = 1. / max (a )
159157 n_source = np .shape (M )[0 ]
160158 n_target = np .shape (M )[1 ]
161159 cur_beta = np .zeros (n_target )
@@ -171,7 +169,7 @@ def sag_entropic_transport(a, b, M, reg, numItermax=10000, lr=0.1):
171169 return cur_beta
172170
173171
174- def averaged_sgd_entropic_transport (b , M , reg , numItermax = 300000 , lr = 1 ):
172+ def averaged_sgd_entropic_transport (a , b , M , reg , numItermax = 300000 , lr = None ):
175173 '''
176174 Compute the ASGD algorithm to solve the regularized semi contibous measures
177175 optimal transport max problem
@@ -219,7 +217,6 @@ def averaged_sgd_entropic_transport(b, M, reg, numItermax=300000, lr=1):
219217 >>> n_target = 4
220218 >>> reg = 1
221219 >>> numItermax = 300000
222- >>> lr = 1
223220 >>> a = ot.utils.unif(n_source)
224221 >>> b = ot.utils.unif(n_target)
225222 >>> rng = np.random.RandomState(0)
@@ -228,8 +225,7 @@ def averaged_sgd_entropic_transport(b, M, reg, numItermax=300000, lr=1):
228225 >>> M = ot.dist(X_source, Y_target)
229226 >>> method = "ASGD"
230227 >>> asgd_pi = stochastic.solve_semi_dual_entropic(a, b, M, reg,
231- method, numItermax,
232- lr)
228+ method, numItermax)
233229 >>> print(asgd_pi)
234230
235231 References
@@ -241,6 +237,8 @@ def averaged_sgd_entropic_transport(b, M, reg, numItermax=300000, lr=1):
241237 arXiv preprint arxiv:1605.08527.
242238 '''
243239
240+ if lr is None :
241+ lr = 1. / max (a )
244242 n_source = np .shape (M )[0 ]
245243 n_target = np .shape (M )[1 ]
246244 cur_beta = np .zeros (n_target )
@@ -296,7 +294,6 @@ def c_transform_entropic(b, M, reg, beta):
296294 >>> n_target = 4
297295 >>> reg = 1
298296 >>> numItermax = 300000
299- >>> lr = 1
300297 >>> a = ot.utils.unif(n_source)
301298 >>> b = ot.utils.unif(n_target)
302299 >>> rng = np.random.RandomState(0)
@@ -305,8 +302,7 @@ def c_transform_entropic(b, M, reg, beta):
305302 >>> M = ot.dist(X_source, Y_target)
306303 >>> method = "ASGD"
307304 >>> asgd_pi = stochastic.solve_semi_dual_entropic(a, b, M, reg,
308- method, numItermax,
309- lr)
305+ method, numItermax)
310306 >>> print(asgd_pi)
311307
312308 References
@@ -328,7 +324,7 @@ def c_transform_entropic(b, M, reg, beta):
328324 return alpha
329325
330326
331- def solve_semi_dual_entropic (a , b , M , reg , method , numItermax = 10000 , lr = 0.1 ,
327+ def solve_semi_dual_entropic (a , b , M , reg , method , numItermax = 10000 , lr = None ,
332328 log = False ):
333329 '''
334330 Compute the transportation matrix to solve the regularized discrete
@@ -388,7 +384,6 @@ def solve_semi_dual_entropic(a, b, M, reg, method, numItermax=10000, lr=0.1,
388384 >>> n_target = 4
389385 >>> reg = 1
390386 >>> numItermax = 300000
391- >>> lr = 1
392387 >>> a = ot.utils.unif(n_source)
393388 >>> b = ot.utils.unif(n_target)
394389 >>> rng = np.random.RandomState(0)
@@ -397,8 +392,7 @@ def solve_semi_dual_entropic(a, b, M, reg, method, numItermax=10000, lr=0.1,
397392 >>> M = ot.dist(X_source, Y_target)
398393 >>> method = "ASGD"
399394 >>> asgd_pi = stochastic.solve_semi_dual_entropic(a, b, M, reg,
400- method, numItermax,
401- lr)
395+ method, numItermax)
402396 >>> print(asgd_pi)
403397
404398 References
@@ -409,10 +403,11 @@ def solve_semi_dual_entropic(a, b, M, reg, method, numItermax=10000, lr=0.1,
409403 Advances in Neural Information Processing Systems (2016),
410404 arXiv preprint arxiv:1605.08527.
411405 '''
406+
412407 if method .lower () == "sag" :
413408 opt_beta = sag_entropic_transport (a , b , M , reg , numItermax , lr )
414409 elif method .lower () == "asgd" :
415- opt_beta = averaged_sgd_entropic_transport (b , M , reg , numItermax , lr )
410+ opt_beta = averaged_sgd_entropic_transport (a , b , M , reg , numItermax , lr )
416411 else :
417412 print ("Please, select your method between SAG and ASGD" )
418413 return None
0 commit comments