Skip to content

Commit 0e8363c

Browse files
committed
doc sinkhorn
1 parent cf2d92e commit 0e8363c

File tree

5 files changed

+51
-15
lines changed

5 files changed

+51
-15
lines changed

docs/source/conf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@
118118

119119
# The theme to use for HTML and HTML Help pages. See the documentation for
120120
# a list of builtin themes.
121-
html_theme = 'alabaster'
121+
html_theme = 'sphinx_rtd_theme'
122122

123123
# Theme options are theme-specific and customize the look and feel of a theme
124124
# further. For a list of options available for each theme, see the

docs/source/index.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
You can adapt this file completely to your liking, but it should at least
44
contain the root `toctree` directive.
55
6-
Welcome to POT's documentation!
6+
POT's documentation!
77
===============================
88

99
Contents:

examples/demo_OT_1D.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import numpy as np
99
import matplotlib.pylab as pl
1010
import ot
11-
11+
from ot.datasets import get_1D_gauss as gauss
1212

1313

1414
#%% parameters
@@ -19,8 +19,8 @@
1919
x=np.arange(n,dtype=np.float64)
2020

2121
# Gaussian distributions
22-
a=ot.datasets.get_1D_gauss(n,m=20,s=20) # m= mean, s= std
23-
b=ot.datasets.get_1D_gauss(n,m=60,s=60)
22+
a=gauss(n,m=20,s=20) # m= mean, s= std
23+
b=gauss(n,m=60,s=60)
2424

2525
# loss matrix
2626
M=ot.dist(x.reshape((n,1)),x.reshape((n,1)))

ot/bregman.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def sinkhorn(a,b, M, reg,numItermax = 1000,stopThr=1e-9):
2626
- :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
2727
- a and b are source and target weights (sum to 1)
2828
29-
The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [1]_
29+
The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [2]_
3030
3131
3232
Parameters
@@ -46,10 +46,22 @@ def sinkhorn(a,b, M, reg,numItermax = 1000,stopThr=1e-9):
4646
gamma: (ns x nt) ndarray
4747
Optimal transportation matrix for the given parameters
4848
49+
50+
Examples
51+
--------
52+
53+
>>> a=[.5,.5]
54+
>>> b=[.5,.5]
55+
>>> M=[[0.,1.],[1.,0.]]
56+
>>> ot.sinkhorn(a,b,M,1)
57+
array([[ 0.36552929, 0.13447071],
58+
[ 0.13447071, 0.36552929]])
59+
60+
4961
References
5062
----------
5163
52-
.. [1] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
64+
.. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
5365
5466
5567
See Also
@@ -58,6 +70,16 @@ def sinkhorn(a,b, M, reg,numItermax = 1000,stopThr=1e-9):
5870
ot.optim.cg : General regularized OT
5971
6072
"""
73+
74+
a=np.asarray(a,dtype=np.float64)
75+
b=np.asarray(b,dtype=np.float64)
76+
M=np.asarray(M,dtype=np.float64)
77+
78+
if len(a)==0:
79+
a=np.ones((M.shape[0],),dtype=np.float64)/M.shape[0]
80+
if len(b)==0:
81+
b=np.ones((M.shape[1],),dtype=np.float64)/M.shape[1]
82+
6183
# init data
6284
Nini = len(a)
6385
Nfin = len(b)

ot/lp/__init__.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ def emd(a,b,M):
77
"""
88
Solves the Earth Movers distance problem and returns the optimal transport matrix
99
10-
gamm=emd(a,b,M)
1110
1211
.. math::
1312
\gamma = arg\min_\gamma <\gamma,M>_F
@@ -21,6 +20,8 @@ def emd(a,b,M):
2120
2221
- M is the metric cost matrix
2322
- a and b are the sample weights
23+
24+
Uses the algorithm proposed in [1]_
2425
2526
Parameters
2627
----------
@@ -31,27 +32,40 @@ def emd(a,b,M):
3132
M : (ns,nt) ndarray, float64
3233
loss matrix
3334
35+
Returns
36+
-------
37+
gamma: (ns x nt) ndarray
38+
Optimal transportation matrix for the given parameters
39+
40+
3441
Examples
3542
--------
3643
37-
Simple example with obvious solution. The function :func:emd accepts lists and
38-
perform automatic conversion tu numpy arrays
44+
Simple example with obvious solution. The function emd accepts lists and
45+
perform automatic conversion to numpy arrays
3946
4047
>>> a=[.5,.5]
4148
>>> b=[.5,.5]
4249
>>> M=[[0.,1.],[1.,0.]]
4350
>>> ot.emd(a,b,M)
4451
array([[ 0.5, 0. ],
4552
[ 0. , 0.5]])
53+
54+
References
55+
----------
4656
47-
Returns
48-
-------
49-
gamma: (ns x nt) ndarray
50-
Optimal transportation matrix for the given parameters
51-
57+
.. [1] Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W. (2011, December). Displacement interpolation using Lagrangian mass transport. In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p. 158). ACM.
58+
59+
See Also
60+
--------
61+
ot.bregman.sinkhorn : Entropic regularized OT
62+
ot.optim.cg : General regularized OT
63+
64+
5265
"""
5366
a=np.asarray(a,dtype=np.float64)
5467
b=np.asarray(b,dtype=np.float64)
68+
M=np.asarray(M,dtype=np.float64)
5569

5670
if len(a)==0:
5771
a=np.ones((M.shape[0],),dtype=np.float64)/M.shape[0]

0 commit comments

Comments
 (0)