Skip to content

Commit 12d9b3f

Browse files
committed
Return dual variables in an optional dictionary
Also remove some code duplication (emd2_c is dropped; emd_c now also returns the cost)
1 parent f8c1c87 commit 12d9b3f

File tree

3 files changed

+25
-88
lines changed

3 files changed

+25
-88
lines changed

ot/lp/__init__.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,12 @@
99

1010
import numpy as np
1111
# import compiled emd
12-
from .emd_wrap import emd_c, emd2_c
12+
from .emd_wrap import emd_c
1313
from ..utils import parmap
1414
import multiprocessing
1515

1616

17-
def emd(a, b, M, numItermax=100000, dual_variables=False):
17+
def emd(a, b, M, numItermax=100000, log=False):
1818
"""Solves the Earth Movers distance problem and returns the OT matrix
1919
2020
@@ -42,11 +42,17 @@ def emd(a, b, M, numItermax=100000, dual_variables=False):
4242
numItermax : int, optional (default=100000)
4343
The maximum number of iterations before stopping the optimization
4444
algorithm if it has not converged.
45+
log: boolean, optional (default=False)
46+
If True, returns a dictionary containing the cost and dual
47+
variables. Otherwise returns only the optimal transportation matrix.
4548
4649
Returns
4750
-------
4851
gamma: (ns x nt) ndarray
4952
Optimal transportation matrix for the given parameters
53+
log: dict
54+
If input log is true, a dictionary containing the cost and dual
55+
variables
5056
5157
5258
Examples
@@ -86,9 +92,13 @@ def emd(a, b, M, numItermax=100000, dual_variables=False):
8692
if len(b) == 0:
8793
b = np.ones((M.shape[1], ), dtype=np.float64)/M.shape[1]
8894

89-
G, alpha, beta = emd_c(a, b, M, numItermax)
90-
if dual_variables:
91-
return G, alpha, beta
95+
G, cost, u, v = emd_c(a, b, M, numItermax)
96+
if log:
97+
log = {}
98+
log['cost'] = cost
99+
log['u'] = u
100+
log['v'] = v
101+
return G, log
92102
return G
93103

94104
def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=100000):
@@ -163,11 +173,11 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=100000):
163173
b = np.ones((M.shape[1], ), dtype=np.float64)/M.shape[1]
164174

165175
if len(b.shape)==1:
166-
return emd2_c(a, b, M, numItermax)[0]
176+
return emd_c(a, b, M, numItermax)[1]
167177
nb = b.shape[1]
168178
# res = [emd2_c(a, b[:, i].copy(), M, numItermax) for i in range(nb)]
169179

170180
def f(b):
171-
return emd2_c(a,b,M, numItermax)[0]
181+
return emd_c(a,b,M, numItermax)[1]
172182
res= parmap(f, [b[:,i] for i in range(nb)],processes)
173183
return np.array(res)

ot/lp/emd_wrap.pyx

Lines changed: 1 addition & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -86,71 +86,4 @@ def emd_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mod
8686
elif resultSolver == MAX_ITER_REACHED:
8787
warnings.warn("numItermax reached before optimality. Try to increase numItermax.")
8888

89-
return G, alpha, beta
90-
91-
@cython.boundscheck(False)
92-
@cython.wraparound(False)
93-
def emd2_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mode="c"] b,np.ndarray[double, ndim=2, mode="c"] M, int numItermax):
94-
"""
95-
Solves the Earth Movers distance problem and returns the optimal transport loss
96-
97-
gamm=emd(a,b,M)
98-
99-
.. math::
100-
\gamma = arg\min_\gamma <\gamma,M>_F
101-
102-
s.t. \gamma 1 = a
103-
104-
\gamma^T 1= b
105-
106-
\gamma\geq 0
107-
where :
108-
109-
- M is the metric cost matrix
110-
- a and b are the sample weights
111-
112-
Parameters
113-
----------
114-
a : (ns,) ndarray, float64
115-
source histogram
116-
b : (nt,) ndarray, float64
117-
target histogram
118-
M : (ns,nt) ndarray, float64
119-
loss matrix
120-
numItermax : int
121-
The maximum number of iterations before stopping the optimization
122-
algorithm if it has not converged.
123-
124-
125-
Returns
126-
-------
127-
gamma: (ns x nt) ndarray
128-
Optimal transportation matrix for the given parameters
129-
130-
"""
131-
cdef int n1= M.shape[0]
132-
cdef int n2= M.shape[1]
133-
134-
cdef double cost=0
135-
cdef np.ndarray[double, ndim=2, mode="c"] G=np.zeros([n1, n2])
136-
137-
cdef np.ndarray[double, ndim = 1, mode = "c"] alpha = np.zeros([n1])
138-
cdef np.ndarray[double, ndim = 1, mode = "c"] beta = np.zeros([n2])
139-
140-
if not len(a):
141-
a=np.ones((n1,))/n1
142-
143-
if not len(b):
144-
b=np.ones((n2,))/n2
145-
# calling the function
146-
cdef int resultSolver = EMD_wrap(n1,n2,<double*> a.data,<double*> b.data,<double*> M.data,<double*> G.data, <double*> alpha.data, <double*> beta.data, <double*> &cost, numItermax)
147-
if resultSolver != OPTIMAL:
148-
if resultSolver == INFEASIBLE:
149-
warnings.warn("Problem infeasible. Check that a and b are in the simplex")
150-
elif resultSolver == UNBOUNDED:
151-
warnings.warn("Problem unbounded")
152-
elif resultSolver == MAX_ITER_REACHED:
153-
warnings.warn("numItermax reached before optimality. Try to increase numItermax.")
154-
155-
return cost, alpha, beta
156-
89+
return G, cost, alpha, beta

test/test_ot.py

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -124,27 +124,26 @@ def test_warnings():
124124
# %%
125125

126126
print('Computing {} EMD '.format(1))
127-
G, alpha, beta = ot.emd(a, b, M, dual_variables=True)
128127
with warnings.catch_warnings(record=True) as w:
129128
# Cause all warnings to always be triggered.
130129
warnings.simplefilter("always")
131130
# Trigger a warning.
132131
print('Computing {} EMD '.format(1))
133-
G, alpha, beta = ot.emd(a, b, M, dual_variables=True, numItermax=1)
132+
G = ot.emd(a, b, M, numItermax=1)
134133
# Verify some things
135134
assert "numItermax" in str(w[-1].message)
136135
assert len(w) == 1
137136
# Trigger a warning.
138137
a[0]=100
139138
print('Computing {} EMD '.format(2))
140-
G, alpha, beta = ot.emd(a, b, M, dual_variables=True)
139+
G = ot.emd(a, b, M)
141140
# Verify some things
142141
assert "infeasible" in str(w[-1].message)
143142
assert len(w) == 2
144143
# Trigger a warning.
145144
a[0]=-1
146145
print('Computing {} EMD '.format(2))
147-
G, alpha, beta = ot.emd(a, b, M, dual_variables=True)
146+
G = ot.emd(a, b, M)
148147
# Verify some things
149148
assert "infeasible" in str(w[-1].message)
150149
assert len(w) == 3
@@ -176,16 +175,11 @@ def test_dual_variables():
176175

177176
# emd loss 1 proc
178177
ot.tic()
179-
G, alpha, beta = ot.emd(a, b, M, dual_variables=True)
178+
G, log = ot.emd(a, b, M, log=True)
180179
ot.toc('1 proc : {} s')
181180

182181
cost1 = (G * M).sum()
183-
cost_dual = np.vdot(a, alpha) + np.vdot(b, beta)
184-
185-
# emd loss 1 proc
186-
ot.tic()
187-
cost_emd2 = ot.emd2(a, b, M)
188-
ot.toc('1 proc : {} s')
182+
cost_dual = np.vdot(a, log['u']) + np.vdot(b, log['v'])
189183

190184
ot.tic()
191185
G2 = ot.emd(b, a, np.ascontiguousarray(M.T))
@@ -194,7 +188,7 @@ def test_dual_variables():
194188
cost2 = (G2 * M.T).sum()
195189

196190
# Check that both cost computations are equivalent
197-
np.testing.assert_almost_equal(cost1, cost_emd2)
191+
np.testing.assert_almost_equal(cost1, log['cost'])
198192
# Check that dual and primal cost are equal
199193
np.testing.assert_almost_equal(cost1, cost_dual)
200194
# Check symmetry
@@ -205,5 +199,5 @@ def test_dual_variables():
205199
[ind1, ind2] = np.nonzero(G)
206200

207201
# Check that reduced cost is zero on transport arcs
208-
np.testing.assert_array_almost_equal((M - alpha.reshape(-1, 1) - beta.reshape(1, -1))[ind1, ind2],
202+
np.testing.assert_array_almost_equal((M - log['u'].reshape(-1, 1) - log['v'].reshape(1, -1))[ind1, ind2],
209203
np.zeros(ind1.size))

0 commit comments

Comments
 (0)