|
6 | 6 |
|
7 | 7 |
|
8 | 8 | ############################################################################## |
9 | | -# Optimization toolbox for SEMI - DUAL problem |
| 9 | +# Optimization toolbox for SEMI - DUAL problems |
10 | 10 | ############################################################################## |
11 | | -def coordinate_gradient(b, M, reg, beta, i): |
| 11 | +def coordinate_grad_semi_dual(b, M, reg, beta, i): |
12 | 12 | ''' |
13 | 13 | Compute the coordinate gradient update for regularized discrete |
14 | 14 | distributions for (i, :) |
@@ -161,7 +161,8 @@ def sag_entropic_transport(a, b, M, reg, numItermax=10000, lr=0.1): |
161 | 161 | sum_stored_gradient = np.zeros(n_target) |
162 | 162 | for _ in range(numItermax): |
163 | 163 | i = np.random.randint(n_source) |
164 | | - cur_coord_grad = a[i] * coordinate_gradient(b, M, reg, cur_beta, i) |
| 164 | + cur_coord_grad = a[i] * coordinate_grad_semi_dual(b, M, reg, |
| 165 | + cur_beta, i) |
165 | 166 | sum_stored_gradient += (cur_coord_grad - stored_gradient[i]) |
166 | 167 | stored_gradient[i] = cur_coord_grad |
167 | 168 | cur_beta += lr * (1. / n_source) * sum_stored_gradient |
@@ -245,7 +246,7 @@ def averaged_sgd_entropic_transport(b, M, reg, numItermax=300000, lr=1): |
245 | 246 | for cur_iter in range(numItermax): |
246 | 247 | k = cur_iter + 1 |
247 | 248 | i = np.random.randint(n_source) |
248 | | - cur_coord_grad = coordinate_gradient(b, M, reg, cur_beta, i) |
| 249 | + cur_coord_grad = coordinate_grad_semi_dual(b, M, reg, cur_beta, i) |
249 | 250 | cur_beta += (lr / np.sqrt(k)) * cur_coord_grad |
250 | 251 | ave_beta = (1. / k) * cur_beta + (1 - 1. / k) * ave_beta |
251 | 252 | return ave_beta |
@@ -428,11 +429,12 @@ def solve_semi_dual_entropic(a, b, M, reg, method, numItermax=10000, lr=0.1, |
428 | 429 |
|
429 | 430 |
|
430 | 431 | ############################################################################## |
431 | | -# Optimization toolbox for DUAL problem |
| 432 | +# Optimization toolbox for DUAL problems |
432 | 433 | ############################################################################## |
433 | 434 |
|
434 | 435 |
|
435 | | -def grad_dF_dalpha(M, reg, alpha, beta, batch_size, batch_alpha, batch_beta): |
| 436 | +def batch_grad_dual_alpha(M, reg, alpha, beta, batch_size, batch_alpha, |
| 437 | + batch_beta): |
436 | 438 | ''' |
437 | 439 | Computes the partial gradient of F_{W_\varepsilon} |
438 | 440 |
|
@@ -513,7 +515,8 @@ def grad_dF_dalpha(M, reg, alpha, beta, batch_size, batch_alpha, batch_beta): |
513 | 515 | return grad_alpha |
514 | 516 |
|
515 | 517 |
|
516 | | -def grad_dF_dbeta(M, reg, alpha, beta, batch_size, batch_alpha, batch_beta): |
| 518 | +def batch_grad_dual_beta(M, reg, alpha, beta, batch_size, batch_alpha, |
| 519 | + batch_beta): |
517 | 520 | ''' |
518 | 521 | Computes the partial gradient of F_{W_\varepsilon} |
519 | 522 |
|
@@ -676,22 +679,26 @@ def sgd_entropic_regularization(M, reg, batch_size, numItermax, lr, |
676 | 679 | k = np.sqrt(cur_iter + 1) |
677 | 680 | batch_alpha = np.random.choice(n_source, batch_size, replace=False) |
678 | 681 | batch_beta = np.random.choice(n_target, batch_size, replace=False) |
679 | | - grad_F_alpha = grad_dF_dalpha(M, reg, cur_alpha, cur_beta, |
680 | | - batch_size, batch_alpha, batch_beta) |
| 682 | + grad_F_alpha = batch_grad_dual_alpha(M, reg, cur_alpha, cur_beta, |
| 683 | + batch_size, batch_alpha, |
| 684 | + batch_beta) |
681 | 685 | cur_alpha[batch_alpha] += (lr / k) * grad_F_alpha |
682 | | - grad_F_beta = grad_dF_dbeta(M, reg, cur_alpha, cur_beta, |
683 | | - batch_size, batch_alpha, batch_beta) |
| 686 | + grad_F_beta = batch_grad_dual_beta(M, reg, cur_alpha, cur_beta, |
| 687 | + batch_size, batch_alpha, |
| 688 | + batch_beta) |
684 | 689 | cur_beta[batch_beta] += (lr / k) * grad_F_beta |
685 | 690 |
|
686 | 691 | else: |
687 | 692 | for cur_iter in range(numItermax): |
688 | 693 | k = np.sqrt(cur_iter + 1) |
689 | 694 | batch_alpha = np.random.choice(n_source, batch_size, replace=False) |
690 | 695 | batch_beta = np.random.choice(n_target, batch_size, replace=False) |
691 | | - grad_F_alpha = grad_dF_dalpha(M, reg, cur_alpha, cur_beta, |
692 | | - batch_size, batch_alpha, batch_beta) |
693 | | - grad_F_beta = grad_dF_dbeta(M, reg, cur_alpha, cur_beta, |
694 | | - batch_size, batch_alpha, batch_beta) |
| 696 | + grad_F_alpha = batch_grad_dual_alpha(M, reg, cur_alpha, cur_beta, |
| 697 | + batch_size, batch_alpha, |
| 698 | + batch_beta) |
| 699 | + grad_F_beta = batch_grad_dual_beta(M, reg, cur_alpha, cur_beta, |
| 700 | + batch_size, batch_alpha, |
| 701 | + batch_beta) |
695 | 702 | cur_alpha[batch_alpha] += (lr / k) * grad_F_alpha |
696 | 703 | cur_beta[batch_beta] += (lr / k) * grad_F_beta |
697 | 704 |
|
|
0 commit comments