Skip to content

Commit ab65f86

Browse files
committed
Added log option to multiprocess emd
1 parent 12d9b3f commit ab65f86

File tree

2 files changed

+57
-39
lines changed

2 files changed

+57
-39
lines changed

ot/lp/__init__.py

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,13 @@
77
#
88
# License: MIT License
99

10+
import multiprocessing
11+
1012
import numpy as np
13+
1114
# import compiled emd
1215
from .emd_wrap import emd_c
1316
from ..utils import parmap
14-
import multiprocessing
1517

1618

1719
def emd(a, b, M, numItermax=100000, log=False):
@@ -88,9 +90,9 @@ def emd(a, b, M, numItermax=100000, log=False):
8890

8991
# if empty array given then use uniform distributions
9092
if len(a) == 0:
91-
a = np.ones((M.shape[0], ), dtype=np.float64)/M.shape[0]
93+
a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0]
9294
if len(b) == 0:
93-
b = np.ones((M.shape[1], ), dtype=np.float64)/M.shape[1]
95+
b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1]
9496

9597
G, cost, u, v = emd_c(a, b, M, numItermax)
9698
if log:
@@ -101,7 +103,8 @@ def emd(a, b, M, numItermax=100000, log=False):
101103
return G, log
102104
return G
103105

104-
def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=100000):
106+
107+
def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=100000, log=False):
105108
"""Solves the Earth Movers distance problem and returns the loss
106109
107110
.. math::
@@ -168,16 +171,26 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=100000):
168171

169172
# if empty array given then use uniform distributions
170173
if len(a) == 0:
171-
a = np.ones((M.shape[0], ), dtype=np.float64)/M.shape[0]
174+
a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0]
172175
if len(b) == 0:
173-
b = np.ones((M.shape[1], ), dtype=np.float64)/M.shape[1]
174-
175-
if len(b.shape)==1:
176-
return emd_c(a, b, M, numItermax)[1]
176+
b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1]
177+
178+
if log:
179+
def f(b):
180+
G, cost, u, v = emd_c(a, b, M, numItermax)
181+
log = {}
182+
log['G'] = G
183+
log['u'] = u
184+
log['v'] = v
185+
return [cost, log]
186+
else:
187+
def f(b):
188+
return emd_c(a, b, M, numItermax)[1]
189+
190+
if len(b.shape) == 1:
191+
return f(b)
177192
nb = b.shape[1]
178193
# res = [emd2_c(a, b[:, i].copy(), M, numItermax) for i in range(nb)]
179194

180-
def f(b):
181-
return emd_c(a,b,M, numItermax)[1]
182-
res= parmap(f, [b[:,i] for i in range(nb)],processes)
183-
return np.array(res)
195+
res = parmap(f, [b[:, i] for i in range(nb)], processes)
196+
return res

test/test_ot.py

Lines changed: 31 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,12 @@
44
#
55
# License: MIT License
66

7+
import warnings
8+
79
import numpy as np
810

911
import ot
1012
from ot.datasets import get_1D_gauss as gauss
11-
import warnings
1213

1314

1415
def test_doctest():
@@ -100,6 +101,21 @@ def test_emd2_multi():
100101

101102
np.testing.assert_allclose(emd1, emdn)
102103

104+
# emd loss multi proc with log
105+
ot.tic()
106+
emdn = ot.emd2(a, b, M, log=True)
107+
ot.toc('multi proc : {} s')
108+
109+
for i in range(len(emdn)):
110+
emd = emdn[i]
111+
log = emd[1]
112+
cost = emd[0]
113+
check_duality_gap(a, b[:, i], M, log['G'], log['u'], log['v'], cost)
114+
emdn[i] = cost
115+
116+
emdn = np.array(emdn)
117+
np.testing.assert_allclose(emd1, emdn)
118+
103119

104120
def test_warnings():
105121
n = 100 # nb bins
@@ -119,32 +135,22 @@ def test_warnings():
119135

120136
# loss matrix
121137
M = ot.dist(x.reshape((-1, 1)), y.reshape((-1, 1))) ** (1. / 2)
122-
# M/=M.max()
123-
124-
# %%
125138

126139
print('Computing {} EMD '.format(1))
127140
with warnings.catch_warnings(record=True) as w:
128-
# Cause all warnings to always be triggered.
129141
warnings.simplefilter("always")
130-
# Trigger a warning.
131142
print('Computing {} EMD '.format(1))
132143
G = ot.emd(a, b, M, numItermax=1)
133-
# Verify some things
134144
assert "numItermax" in str(w[-1].message)
135145
assert len(w) == 1
136-
# Trigger a warning.
137-
a[0]=100
146+
a[0] = 100
138147
print('Computing {} EMD '.format(2))
139148
G = ot.emd(a, b, M)
140-
# Verify some things
141149
assert "infeasible" in str(w[-1].message)
142150
assert len(w) == 2
143-
# Trigger a warning.
144-
a[0]=-1
151+
a[0] = -1
145152
print('Computing {} EMD '.format(2))
146153
G = ot.emd(a, b, M)
147-
# Verify some things
148154
assert "infeasible" in str(w[-1].message)
149155
assert len(w) == 3
150156

@@ -167,9 +173,6 @@ def test_dual_variables():
167173

168174
# loss matrix
169175
M = ot.dist(x.reshape((-1, 1)), y.reshape((-1, 1))) ** (1. / 2)
170-
# M/=M.max()
171-
172-
# %%
173176

174177
print('Computing {} EMD '.format(1))
175178

@@ -178,26 +181,28 @@ def test_dual_variables():
178181
G, log = ot.emd(a, b, M, log=True)
179182
ot.toc('1 proc : {} s')
180183

181-
cost1 = (G * M).sum()
182-
cost_dual = np.vdot(a, log['u']) + np.vdot(b, log['v'])
183-
184184
ot.tic()
185185
G2 = ot.emd(b, a, np.ascontiguousarray(M.T))
186186
ot.toc('1 proc : {} s')
187187

188-
cost2 = (G2 * M.T).sum()
188+
cost1 = (G * M).sum()
189+
# Check symmetry
190+
np.testing.assert_array_almost_equal(cost1, (M * G2.T).sum())
191+
# Check with closed-form solution for gaussians
192+
np.testing.assert_almost_equal(cost1, np.abs(mean1 - mean2))
189193

190194
# Check that both cost computations are equivalent
191195
np.testing.assert_almost_equal(cost1, log['cost'])
196+
check_duality_gap(a, b, M, G, log['u'], log['v'], log['cost'])
197+
198+
199+
def check_duality_gap(a, b, M, G, u, v, cost):
200+
cost_dual = np.vdot(a, u) + np.vdot(b, v)
192201
# Check that dual and primal cost are equal
193-
np.testing.assert_almost_equal(cost1, cost_dual)
194-
# Check symmetry
195-
np.testing.assert_almost_equal(cost1, cost2)
196-
# Check with closed-form solution for gaussians
197-
np.testing.assert_almost_equal(cost1, np.abs(mean1 - mean2))
202+
np.testing.assert_almost_equal(cost_dual, cost)
198203

199204
[ind1, ind2] = np.nonzero(G)
200205

201206
# Check that reduced cost is zero on transport arcs
202-
np.testing.assert_array_almost_equal((M - log['u'].reshape(-1, 1) - log['v'].reshape(1, -1))[ind1, ind2],
207+
np.testing.assert_array_almost_equal((M - u.reshape(-1, 1) - v.reshape(1, -1))[ind1, ind2],
203208
np.zeros(ind1.size))

0 commit comments

Comments
 (0)