Skip to content

Commit fdb2f3a

Browse files
committed
add test for barycenter
1 parent 36f4f7e commit fdb2f3a

File tree

3 files changed

+42
-7
lines changed

3 files changed

+42
-7
lines changed

ot/lp/cvx.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,9 @@ def barycenter(A, M, weights=None, verbose=False, log=False, solver='interior-po
3939
- :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}`
4040
4141
The linear program is solved using the interior point solver from scipy.optimize.
42-
If cvxopt solver if installed it can use cvxopt.
42+
If cvxopt solver if installed it can use cvxopt
43+
44+
Note that this problem do not scale well (both in memory and computational time).
4345
4446
Parameters
4547
----------
@@ -114,14 +116,14 @@ def barycenter(A, M, weights=None, verbose=False, log=False, solver='interior-po
114116
A_eq = sps.vstack((A_eq1, A_eq2))
115117
b_eq = np.concatenate((b_eq1, b_eq2))
116118

117-
if not cvxopt or solver in ['interior-point']:
119+
if not cvxopt or solver in ['interior-point']:
118120
# cvxopt not installed or interior point
119121

120122
if solver is None:
121123
solver = 'interior-point'
122124

123125
options = {'sparse': True, 'disp': verbose}
124-
sol = sp.optimize.linprog(c, A_eq=A_eq, b_eq=b_eq, method=solver,
126+
sol = sp.optimize.linprog(c, A_eq=A_eq, b_eq=b_eq, method=solver,
125127
options=options)
126128
x = sol.x
127129
b = x[-n:]
@@ -131,8 +133,8 @@ def barycenter(A, M, weights=None, verbose=False, log=False, solver='interior-po
131133
h = np.zeros((n_distributions * n2 + n))
132134
G = -sps.eye(n_distributions * n2 + n)
133135

134-
sol = solvers.lp(matrix(c), scipy_sparse_to_spmatrix(G), matrix(h),
135-
A=scipy_sparse_to_spmatrix(A_eq), b=matrix(b_eq),
136+
sol = solvers.lp(matrix(c), scipy_sparse_to_spmatrix(G), matrix(h),
137+
A=scipy_sparse_to_spmatrix(A_eq), b=matrix(b_eq),
136138
solver=solver)
137139

138140
x = np.array(sol['x'])

test/test_gpu.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,4 +76,4 @@ def describe_res(r):
7676
time3 - time2))
7777
describe_res(G2)
7878

79-
np.testing.assert_allclose(G1, G2, rtol=1e-5, atol=1e-5)
79+
np.testing.assert_allclose(G1, G2, rtol=1e-3, atol=1e-3)

test/test_ot.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
import ot
1212
from ot.datasets import get_1D_gauss as gauss
13-
13+
import pytest
1414

1515
def test_doctest():
1616
import doctest
@@ -117,6 +117,39 @@ def test_emd2_multi():
117117
np.testing.assert_allclose(emd1, emdn)
118118

119119

120+
def test_lp_barycenter():
121+
122+
a1 = np.array([1.0, 0, 0])[:, None]
123+
a2 = np.array([0, 0, 1.0])[:, None]
124+
125+
A = np.hstack((a1, a2))
126+
M = np.array([[0, 1.0, 4.0], [1.0, 0, 1.0], [4.0, 1.0, 0]])
127+
128+
# obvious barycenter between two diracs
129+
bary0 = np.array([0, 1.0, 0])
130+
131+
bary = ot.lp.barycenter(A, M, [.5, .5])
132+
133+
np.testing.assert_allclose(bary, bary0, rtol=1e-5, atol=1e-7)
134+
np.testing.assert_allclose(bary.sum(), 1)
135+
136+
@pytest.mark.skipif(not ot.lp.cvx.cvxopt, reason="No cvxopt available")
137+
def test_lp_barycenter_cvxopt():
138+
139+
a1 = np.array([1.0, 0, 0])[:, None]
140+
a2 = np.array([0, 0, 1.0])[:, None]
141+
142+
A = np.hstack((a1, a2))
143+
M = np.array([[0, 1.0, 4.0], [1.0, 0, 1.0], [4.0, 1.0, 0]])
144+
145+
# obvious barycenter between two diracs
146+
bary0 = np.array([0, 1.0, 0])
147+
148+
bary = ot.lp.barycenter(A, M, [.5, .5],solver=None)
149+
150+
np.testing.assert_allclose(bary, bary0, rtol=1e-5, atol=1e-7)
151+
np.testing.assert_allclose(bary.sum(), 1)
152+
120153
def test_warnings():
121154
n = 100 # nb bins
122155
m = 100 # nb bins

0 commit comments

Comments
 (0)