|
11 | 11 | import numpy as np |
12 | 12 |
|
13 | 13 |
|
14 | | -def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, stopThr=1e-9, verbose=False, log=False, **kwargs): |
| 14 | +def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, |
| 15 | + stopThr=1e-9, verbose=False, log=False, **kwargs): |
15 | 16 | u""" |
16 | 17 | Solve the entropic regularization optimal transport problem and return the OT matrix |
17 | 18 |
|
@@ -120,7 +121,8 @@ def sink(): |
120 | 121 | return sink() |
121 | 122 |
|
122 | 123 |
|
123 | | -def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, stopThr=1e-9, verbose=False, log=False, **kwargs): |
| 124 | +def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, |
| 125 | + stopThr=1e-9, verbose=False, log=False, **kwargs): |
124 | 126 | u""" |
125 | 127 | Solve the entropic regularization optimal transport problem and return the loss |
126 | 128 |
|
@@ -233,7 +235,8 @@ def sink(): |
233 | 235 | return sink() |
234 | 236 |
|
235 | 237 |
|
236 | | -def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False, log=False, **kwargs): |
| 238 | +def sinkhorn_knopp(a, b, M, reg, numItermax=1000, |
| 239 | + stopThr=1e-9, verbose=False, log=False, **kwargs): |
237 | 240 | """ |
238 | 241 | Solve the entropic regularization optimal transport problem and return the OT matrix |
239 | 242 |
|
@@ -403,7 +406,8 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False, l |
403 | 406 | return u.reshape((-1, 1)) * K * v.reshape((1, -1)) |
404 | 407 |
|
405 | 408 |
|
406 | | -def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, warmstart=None, verbose=False, print_period=20, log=False, **kwargs): |
| 409 | +def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, |
| 410 | + warmstart=None, verbose=False, print_period=20, log=False, **kwargs): |
407 | 411 | """ |
408 | 412 | Solve the entropic regularization OT problem with log stabilization |
409 | 413 |
|
@@ -526,11 +530,13 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, wa |
526 | 530 |
|
527 | 531 | def get_K(alpha, beta): |
528 | 532 | """log space computation""" |
529 | | - return np.exp(-(M - alpha.reshape((na, 1)) - beta.reshape((1, nb))) / reg) |
| 533 | + return np.exp(-(M - alpha.reshape((na, 1)) - |
| 534 | + beta.reshape((1, nb))) / reg) |
530 | 535 |
|
531 | 536 | def get_Gamma(alpha, beta, u, v): |
532 | 537 | """log space gamma computation""" |
533 | | - return np.exp(-(M - alpha.reshape((na, 1)) - beta.reshape((1, nb))) / reg + np.log(u.reshape((na, 1))) + np.log(v.reshape((1, nb)))) |
| 538 | + return np.exp(-(M - alpha.reshape((na, 1)) - beta.reshape((1, nb))) / |
| 539 | + reg + np.log(u.reshape((na, 1))) + np.log(v.reshape((1, nb)))) |
534 | 540 |
|
535 | 541 | # print(np.min(K)) |
536 | 542 |
|
@@ -620,7 +626,8 @@ def get_Gamma(alpha, beta, u, v): |
620 | 626 | return get_Gamma(alpha, beta, u, v) |
621 | 627 |
|
622 | 628 |
|
623 | | -def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInnerItermax=100, tau=1e3, stopThr=1e-9, warmstart=None, verbose=False, print_period=10, log=False, **kwargs): |
| 629 | +def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInnerItermax=100, |
| 630 | + tau=1e3, stopThr=1e-9, warmstart=None, verbose=False, print_period=10, log=False, **kwargs): |
624 | 631 | """ |
625 | 632 | Solve the entropic regularization optimal transport problem with log |
626 | 633 | stabilization and epsilon scaling. |
@@ -739,7 +746,8 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInne |
739 | 746 |
|
740 | 747 | def get_K(alpha, beta): |
741 | 748 | """log space computation""" |
742 | | - return np.exp(-(M - alpha.reshape((na, 1)) - beta.reshape((1, nb))) / reg) |
| 749 | + return np.exp(-(M - alpha.reshape((na, 1)) - |
| 750 | + beta.reshape((1, nb))) / reg) |
743 | 751 |
|
744 | 752 | # print(np.min(K)) |
745 | 753 | def get_reg(n): # exponential decreasing |
@@ -811,7 +819,8 @@ def projC(gamma, q): |
811 | 819 | return np.multiply(gamma, q / np.maximum(np.sum(gamma, axis=0), 1e-10)) |
812 | 820 |
|
813 | 821 |
|
814 | | -def barycenter(A, M, reg, weights=None, numItermax=1000, stopThr=1e-4, verbose=False, log=False): |
| 822 | +def barycenter(A, M, reg, weights=None, numItermax=1000, |
| 823 | + stopThr=1e-4, verbose=False, log=False): |
815 | 824 | """Compute the entropic regularized wasserstein barycenter of distributions A |
816 | 825 |
|
817 | 826 | The function solves the following optimization problem: |
@@ -904,7 +913,8 @@ def barycenter(A, M, reg, weights=None, numItermax=1000, stopThr=1e-4, verbose=F |
904 | 913 | return geometricBar(weights, UKv) |
905 | 914 |
|
906 | 915 |
|
907 | | -def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000, stopThr=1e-3, verbose=False, log=False): |
| 916 | +def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000, |
| 917 | + stopThr=1e-3, verbose=False, log=False): |
908 | 918 | """ |
909 | 919 | Compute the unmixing of an observation with a given dictionary using Wasserstein distance |
910 | 920 |
|
|
0 commit comments