Skip to content

Commit 52134e9

Browse files
author
Kilian Fatras
committed
change grad function names
1 parent e068b58 commit 52134e9

File tree

2 files changed

+30
-15
lines changed

2 files changed

+30
-15
lines changed

docs/source/all.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,14 @@ ot.optim
3333
.. automodule:: ot.optim
3434
:members:
3535

36+
37+
ot.stochastic
38+
--------
39+
40+
.. automodule:: ot.stochastic
41+
:members:
42+
43+
3644
ot.da
3745
--------
3846

ot/stochastic.py

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66

77

88
##############################################################################
9-
# Optimization toolbox for SEMI - DUAL problem
9+
# Optimization toolbox for SEMI - DUAL problems
1010
##############################################################################
11-
def coordinate_gradient(b, M, reg, beta, i):
11+
def coordinate_grad_semi_dual(b, M, reg, beta, i):
1212
'''
1313
Compute the coordinate gradient update for regularized discrete
1414
distributions for (i, :)
@@ -161,7 +161,8 @@ def sag_entropic_transport(a, b, M, reg, numItermax=10000, lr=0.1):
161161
sum_stored_gradient = np.zeros(n_target)
162162
for _ in range(numItermax):
163163
i = np.random.randint(n_source)
164-
cur_coord_grad = a[i] * coordinate_gradient(b, M, reg, cur_beta, i)
164+
cur_coord_grad = a[i] * coordinate_grad_semi_dual(b, M, reg,
165+
cur_beta, i)
165166
sum_stored_gradient += (cur_coord_grad - stored_gradient[i])
166167
stored_gradient[i] = cur_coord_grad
167168
cur_beta += lr * (1. / n_source) * sum_stored_gradient
@@ -245,7 +246,7 @@ def averaged_sgd_entropic_transport(b, M, reg, numItermax=300000, lr=1):
245246
for cur_iter in range(numItermax):
246247
k = cur_iter + 1
247248
i = np.random.randint(n_source)
248-
cur_coord_grad = coordinate_gradient(b, M, reg, cur_beta, i)
249+
cur_coord_grad = coordinate_grad_semi_dual(b, M, reg, cur_beta, i)
249250
cur_beta += (lr / np.sqrt(k)) * cur_coord_grad
250251
ave_beta = (1. / k) * cur_beta + (1 - 1. / k) * ave_beta
251252
return ave_beta
@@ -428,11 +429,12 @@ def solve_semi_dual_entropic(a, b, M, reg, method, numItermax=10000, lr=0.1,
428429

429430

430431
##############################################################################
431-
# Optimization toolbox for DUAL problem
432+
# Optimization toolbox for DUAL problems
432433
##############################################################################
433434

434435

435-
def grad_dF_dalpha(M, reg, alpha, beta, batch_size, batch_alpha, batch_beta):
436+
def batch_grad_dual_alpha(M, reg, alpha, beta, batch_size, batch_alpha,
437+
batch_beta):
436438
'''
437439
Computes the partial gradient of F_\W_varepsilon
438440
@@ -513,7 +515,8 @@ def grad_dF_dalpha(M, reg, alpha, beta, batch_size, batch_alpha, batch_beta):
513515
return grad_alpha
514516

515517

516-
def grad_dF_dbeta(M, reg, alpha, beta, batch_size, batch_alpha, batch_beta):
518+
def batch_grad_dual_beta(M, reg, alpha, beta, batch_size, batch_alpha,
519+
batch_beta):
517520
'''
518521
Computes the partial gradient of F_\W_varepsilon
519522
@@ -676,22 +679,26 @@ def sgd_entropic_regularization(M, reg, batch_size, numItermax, lr,
676679
k = np.sqrt(cur_iter + 1)
677680
batch_alpha = np.random.choice(n_source, batch_size, replace=False)
678681
batch_beta = np.random.choice(n_target, batch_size, replace=False)
679-
grad_F_alpha = grad_dF_dalpha(M, reg, cur_alpha, cur_beta,
680-
batch_size, batch_alpha, batch_beta)
682+
grad_F_alpha = batch_grad_dual_alpha(M, reg, cur_alpha, cur_beta,
683+
batch_size, batch_alpha,
684+
batch_beta)
681685
cur_alpha[batch_alpha] += (lr / k) * grad_F_alpha
682-
grad_F_beta = grad_dF_dbeta(M, reg, cur_alpha, cur_beta,
683-
batch_size, batch_alpha, batch_beta)
686+
grad_F_beta = batch_grad_dual_beta(M, reg, cur_alpha, cur_beta,
687+
batch_size, batch_alpha,
688+
batch_beta)
684689
cur_beta[batch_beta] += (lr / k) * grad_F_beta
685690

686691
else:
687692
for cur_iter in range(numItermax):
688693
k = np.sqrt(cur_iter + 1)
689694
batch_alpha = np.random.choice(n_source, batch_size, replace=False)
690695
batch_beta = np.random.choice(n_target, batch_size, replace=False)
691-
grad_F_alpha = grad_dF_dalpha(M, reg, cur_alpha, cur_beta,
692-
batch_size, batch_alpha, batch_beta)
693-
grad_F_beta = grad_dF_dbeta(M, reg, cur_alpha, cur_beta,
694-
batch_size, batch_alpha, batch_beta)
696+
grad_F_alpha = batch_grad_dual_alpha(M, reg, cur_alpha, cur_beta,
697+
batch_size, batch_alpha,
698+
batch_beta)
699+
grad_F_beta = batch_grad_dual_beta(M, reg, cur_alpha, cur_beta,
700+
batch_size, batch_alpha,
701+
batch_beta)
695702
cur_alpha[batch_alpha] += (lr / k) * grad_F_alpha
696703
cur_beta[batch_beta] += (lr / k) * grad_F_beta
697704

0 commit comments

Comments
 (0)