Skip to content

Commit 3797781

Browse files
committed
add doc ot.gpu.bregman
1 parent 55ff888 commit 3797781

File tree

1 file changed

+74
-0
lines changed

1 file changed

+74
-0
lines changed

ot/gpu/bregman.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,80 @@
99

1010
def sinkhorn(a, b, M_GPU, reg, numItermax=1000, stopThr=1e-9, verbose=False,
1111
log=False, returnAsGPU=False):
12+
"""
13+
Solve the entropic regularization optimal transport problem on GPU
14+
15+
The function solves the following optimization problem:
16+
17+
.. math::
18+
\gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
19+
20+
s.t. \gamma 1 = a
21+
22+
\gamma^T 1= b
23+
24+
\gamma\geq 0
25+
where :
26+
27+
- M is the (ns,nt) metric cost matrix
28+
- :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
29+
- a and b are source and target weights (sum to 1)
30+
31+
The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [2]_
32+
33+
34+
Parameters
35+
----------
36+
a : np.ndarray (ns,)
37+
samples weights in the source domain
38+
b : np.ndarray (nt,)
39+
samples in the target domain
40+
M_GPU : cudamat.CUDAMatrix (ns,nt)
41+
loss matrix
42+
reg : float
43+
Regularization term >0
44+
numItermax : int, optional
45+
Max number of iterations
46+
stopThr : float, optional
47+
Stop threshol on error (>0)
48+
verbose : bool, optional
49+
Print information along iterations
50+
log : bool, optional
51+
record log if True
52+
returnAsGPU : bool, optional
53+
return the OT matrix as a cudamat.CUDAMatrix
54+
55+
Returns
56+
-------
57+
gamma : (ns x nt) ndarray
58+
Optimal transportation matrix for the given parameters
59+
log : dict
60+
log dictionary return only if log==True in parameters
61+
62+
Examples
63+
--------
64+
65+
>>> import ot
66+
>>> a=[.5,.5]
67+
>>> b=[.5,.5]
68+
>>> M=[[0.,1.],[1.,0.]]
69+
>>> ot.sinkhorn(a,b,M,1)
70+
array([[ 0.36552929, 0.13447071],
71+
[ 0.13447071, 0.36552929]])
72+
73+
74+
References
75+
----------
76+
77+
.. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
78+
79+
80+
See Also
81+
--------
82+
ot.lp.emd : Unregularized OT
83+
ot.optim.cg : General regularized OT
84+
85+
"""
1286
# init data
1387
Nini = len(a)
1488
Nfin = len(b)

0 commit comments

Comments
 (0)