|
5 | 5 |
|
6 | 6 | import numpy as np |
7 | 7 |
|
| 8 | +def sinkhorn(a,b, M, reg,method='sinkhorn', numItermax = 1000, stopThr=1e-9, verbose=False, log=False,**kwargs): |
| 9 | + u""" |
| 10 | + Solve the entropic regularization optimal transport problem |
| 11 | +
|
| 12 | + The function solves the following optimization problem: |
| 13 | +
|
| 14 | + .. math:: |
| 15 | + \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) |
| 16 | +
|
| 17 | + s.t. \gamma 1 = a |
| 18 | +
|
| 19 | + \gamma^T 1= b |
| 20 | +
|
| 21 | + \gamma\geq 0 |
| 22 | + where : |
| 23 | +
|
| 24 | + - M is the (ns,nt) metric cost matrix |
| 25 | + - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` |
| 26 | + - a and b are source and target weights (sum to 1) |
| 27 | +
|
| 28 | + The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [2]_ |
| 29 | +
|
| 30 | +
|
| 31 | + Parameters |
| 32 | + ---------- |
| 33 | + a : np.ndarray (ns,) |
| 34 | + samples weights in the source domain |
| 35 | + b : np.ndarray (nt,) |
| 36 | + samples in the target domain |
| 37 | + M : np.ndarray (ns,nt) |
| 38 | + loss matrix |
| 39 | + reg : float |
| 40 | + Regularization term >0 |
| 41 | + method : str |
| 42 | + method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or |
| 43 | + 'sinkhorn_epsilon_scaling', see those function for specific parameters |
| 44 | + numItermax : int, optional |
| 45 | + Max number of iterations |
| 46 | + stopThr : float, optional |
| 47 | + Stop threshol on error (>0) |
| 48 | + verbose : bool, optional |
| 49 | + Print information along iterations |
| 50 | + log : bool, optional |
| 51 | + record log if True |
| 52 | +
|
| 53 | +
|
| 54 | + Returns |
| 55 | + ------- |
| 56 | + gamma : (ns x nt) ndarray |
| 57 | + Optimal transportation matrix for the given parameters |
| 58 | + log : dict |
| 59 | + log dictionary return only if log==True in parameters |
| 60 | +
|
| 61 | + Examples |
| 62 | + -------- |
| 63 | +
|
| 64 | + >>> import ot |
| 65 | + >>> a=[.5,.5] |
| 66 | + >>> b=[.5,.5] |
| 67 | + >>> M=[[0.,1.],[1.,0.]] |
| 68 | + >>> ot.sinkhorn(a,b,M,1) |
| 69 | + array([[ 0.36552929, 0.13447071], |
| 70 | + [ 0.13447071, 0.36552929]]) |
| 71 | +
|
| 72 | +
|
| 73 | + References |
| 74 | + ---------- |
| 75 | +
|
| 76 | + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 |
| 77 | +
|
| 78 | + .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519. |
| 79 | +
|
| 80 | + .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. |
| 81 | +
|
| 82 | +
|
| 83 | +
|
| 84 | + See Also |
| 85 | + -------- |
| 86 | + ot.lp.emd : Unregularized OT |
| 87 | + ot.optim.cg : General regularized OT |
| 88 | + ot.bregman.sinkhorn_knopp : Classic Sinkhorn [2] |
| 89 | + ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn [9][10] |
| 90 | + ot.bregman.sinkhorn_epsilon_scaling: Sinkhorn with epslilon scaling [9][10] |
| 91 | +
|
| 92 | + """ |
| 93 | + |
| 94 | + if method.lower()=='sinkhorn': |
| 95 | + sink= lambda: sinkhorn_knopp(a,b, M, reg,numItermax=numItermax, |
| 96 | + stopThr=stopThr, verbose=verbose, log=log,**kwargs) |
| 97 | + elif method.lower()=='sinkhorn_stabilized': |
| 98 | + sink= lambda: sinkhorn_stabilized(a,b, M, reg,numItermax=numItermax, |
| 99 | + stopThr=stopThr, verbose=verbose, log=log, **kwargs) |
| 100 | + elif method.lower()=='sinkhorn_epsilon_scaling': |
| 101 | + sink= lambda: sinkhorn_epsilon_scaling(a,b, M, reg,numItermax=numItermax, |
| 102 | + stopThr=stopThr, verbose=verbose, log=log, **kwargs) |
| 103 | + else: |
| 104 | + print('Warning : unknown method using classic Sinkhorn Knopp') |
| 105 | + sink= lambda: sinkhorn_knopp(a,b, M, reg, **kwargs) |
| 106 | + |
| 107 | + return sink() |
| 108 | + |
| 109 | + |
| 110 | + |
| 111 | + |
8 | 112 |
|
9 | | -def sinkhorn(a,b, M, reg, numItermax = 1000, stopThr=1e-9, verbose=False, log=False): |
| 113 | +def sinkhorn_knopp(a,b, M, reg, numItermax = 1000, stopThr=1e-9, verbose=False, log=False,**kwargs): |
10 | 114 | """ |
11 | 115 | Solve the entropic regularization optimal transport problem |
12 | 116 |
|
@@ -147,7 +251,7 @@ def sinkhorn(a,b, M, reg, numItermax = 1000, stopThr=1e-9, verbose=False, log=Fa |
147 | 251 | else: |
148 | 252 | return u.reshape((-1,1))*K*v.reshape((1,-1)) |
149 | 253 |
|
150 | | -def sinkhorn_stabilized(a,b, M, reg, numItermax = 1000,tau=1e3, stopThr=1e-9,warmstart=None, verbose=False,print_period=20, log=False): |
| 254 | +def sinkhorn_stabilized(a,b, M, reg, numItermax = 1000,tau=1e3, stopThr=1e-9,warmstart=None, verbose=False,print_period=20, log=False,**kwargs): |
151 | 255 | """ |
152 | 256 | Solve the entropic regularization OT problem with log stabilization |
153 | 257 |
|
@@ -331,7 +435,7 @@ def get_Gamma(alpha,beta,u,v): |
331 | 435 | else: |
332 | 436 | return get_Gamma(alpha,beta,u,v) |
333 | 437 |
|
334 | | -def sinkhorn_epsilon_scaling(a,b, M, reg, numItermax = 100, epsilon0=1e4, numInnerItermax = 100,tau=1e3, stopThr=1e-9,warmstart=None, verbose=False,print_period=10, log=False): |
| 438 | +def sinkhorn_epsilon_scaling(a,b, M, reg, numItermax = 100, epsilon0=1e4, numInnerItermax = 100,tau=1e3, stopThr=1e-9,warmstart=None, verbose=False,print_period=10, log=False,**kwargs): |
335 | 439 | """ |
336 | 440 | Solve the entropic regularization optimal transport problem with log |
337 | 441 | stabilization and epsilon scaling. |
|
0 commit comments