Skip to content

Commit cf2d92e

Browse files
committed
complete doc emd
1 parent 123a7c7 commit cf2d92e

File tree

4 files changed

+132
-72
lines changed

4 files changed

+132
-72
lines changed

ot/bregman.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def sinkhorn(a,b, M, reg,numItermax = 1000,stopThr=1e-9):
3737
samples in the target domain
3838
M : np.ndarray (ns,nt)
3939
loss matrix
40-
reg: float()
40+
reg: float
4141
Regularization term >0
4242
4343
@@ -54,7 +54,8 @@ def sinkhorn(a,b, M, reg,numItermax = 1000,stopThr=1e-9):
5454
5555
See Also
5656
--------
57-
ot.emd.emd : Unregularized optimal ransport
57+
ot.lp.emd : Unregularized OT
58+
ot.optim.cg : General regularized OT
5859
5960
"""
6061
# init data

ot/lp/__init__.py

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,62 @@
11

22

3-
from .emd import emd
3+
from .emd import emd_c
4+
import numpy as np
5+
6+
def emd(a,b,M):
7+
"""
8+
Solves the Earth Movers distance problem and returns the optimal transport matrix
9+
10+
gamm=emd(a,b,M)
11+
12+
.. math::
13+
\gamma = arg\min_\gamma <\gamma,M>_F
14+
15+
s.t. \gamma 1 = a
16+
17+
\gamma^T 1= b
18+
19+
\gamma\geq 0
20+
where :
21+
22+
- M is the metric cost matrix
23+
- a and b are the sample weights
24+
25+
Parameters
26+
----------
27+
a : (ns,) ndarray, float64
28+
Source histogram (uniform weigth if empty list)
29+
b : (nt,) ndarray, float64
30+
Target histogram (uniform weigth if empty list)
31+
M : (ns,nt) ndarray, float64
32+
loss matrix
33+
34+
Examples
35+
--------
36+
37+
Simple example with obvious solution. The function :func:emd accepts lists and
38+
perform automatic conversion tu numpy arrays
39+
40+
>>> a=[.5,.5]
41+
>>> b=[.5,.5]
42+
>>> M=[[0.,1.],[1.,0.]]
43+
>>> ot.emd(a,b,M)
44+
array([[ 0.5, 0. ],
45+
[ 0. , 0.5]])
46+
47+
Returns
48+
-------
49+
gamma: (ns x nt) ndarray
50+
Optimal transportation matrix for the given parameters
51+
52+
"""
53+
a=np.asarray(a,dtype=np.float64)
54+
b=np.asarray(b,dtype=np.float64)
55+
56+
if len(a)==0:
57+
a=np.ones((M.shape[0],),dtype=np.float64)/M.shape[0]
58+
if len(b)==0:
59+
b=np.ones((M.shape[1],),dtype=np.float64)/M.shape[1]
60+
61+
return emd_c(a,b,M)
62+

0 commit comments

Comments
 (0)