Skip to content

Commit d52b4ea

Browse files
committed
Fixed typo and merged emd tests
1 parent d43ce6f commit d52b4ea

File tree

3 files changed

+64
-72
lines changed

3 files changed

+64
-72
lines changed

ot/lp/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,6 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=100000):
168168
# res = [emd2_c(a, b[:, i].copy(), M, numItermax) for i in range(nb)]
169169

170170
def f(b):
171-
return emd2_c(a,b,M, max_iter)[0]
171+
return emd2_c(a,b,M, numItermax)[0]
172172
res= parmap(f, [b[:,i] for i in range(nb)],processes)
173173
return np.array(res)

test/test_emd.py

Lines changed: 0 additions & 68 deletions
This file was deleted.

test/test_ot.py

Lines changed: 63 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
import numpy as np
88
import ot
9+
10+
from ot.datasets import get_1D_gauss as gauss
911

1012

1113
def test_doctest():
@@ -66,9 +68,6 @@ def test_emd_empty():
6668

6769

6870
def test_emd2_multi():
69-
70-
from ot.datasets import get_1D_gauss as gauss
71-
7271
n = 1000 # nb bins
7372

7473
# bin positions
@@ -100,3 +99,64 @@ def test_emd2_multi():
10099
ot.toc('multi proc : {} s')
101100

102101
np.testing.assert_allclose(emd1, emdn)
102+
103+
def test_dual_variables():
104+
#%% parameters
105+
106+
n=5000 # nb bins
107+
m=6000 # nb bins
108+
109+
mean1 = 1000
110+
mean2 = 1100
111+
112+
# bin positions
113+
x=np.arange(n,dtype=np.float64)
114+
y=np.arange(m,dtype=np.float64)
115+
116+
# Gaussian distributions
117+
a=gauss(n,m=mean1,s=5) # m= mean, s= std
118+
119+
b=gauss(m,m=mean2,s=10)
120+
121+
# loss matrix
122+
M=ot.dist(x.reshape((-1,1)), y.reshape((-1,1))) ** (1./2)
123+
#M/=M.max()
124+
125+
#%%
126+
127+
print('Computing {} EMD '.format(1))
128+
129+
# emd loss 1 proc
130+
ot.tic()
131+
G, alpha, beta = ot.emd(a,b,M, dual_variables=True)
132+
ot.toc('1 proc : {} s')
133+
134+
cost1 = (G * M).sum()
135+
cost_dual = np.vdot(a, alpha) + np.vdot(b, beta)
136+
137+
# emd loss 1 proc
138+
ot.tic()
139+
cost_emd2 = ot.emd2(a,b,M)
140+
ot.toc('1 proc : {} s')
141+
142+
ot.tic()
143+
G2 = ot.emd(b, a, np.ascontiguousarray(M.T))
144+
ot.toc('1 proc : {} s')
145+
146+
cost2 = (G2 * M.T).sum()
147+
148+
M_reduced = M - alpha.reshape(-1,1) - beta.reshape(1, -1)
149+
150+
# Check that both cost computations are equivalent
151+
np.testing.assert_almost_equal(cost1, cost_emd2)
152+
# Check that dual and primal cost are equal
153+
np.testing.assert_almost_equal(cost1, cost_dual)
154+
# Check symmetry
155+
np.testing.assert_almost_equal(cost1, cost2)
156+
# Check with closed-form solution for gaussians
157+
np.testing.assert_almost_equal(cost1, np.abs(mean1-mean2))
158+
159+
[ind1, ind2] = np.nonzero(G)
160+
161+
# Check that reduced cost is zero on transport arcs
162+
np.testing.assert_array_almost_equal((M - alpha.reshape(-1, 1) - beta.reshape(1, -1))[ind1, ind2], np.zeros(ind1.size))

0 commit comments

Comments
 (0)