|
10 | 10 |
|
11 | 11 | import ot |
12 | 12 | from ot.datasets import get_1D_gauss as gauss |
13 | | - |
| 13 | +import pytest |
14 | 14 |
|
15 | 15 | def test_doctest(): |
16 | 16 | import doctest |
@@ -117,6 +117,39 @@ def test_emd2_multi(): |
117 | 117 | np.testing.assert_allclose(emd1, emdn) |
118 | 118 |
|
119 | 119 |
|
| 120 | +def test_lp_barycenter(): |
| 121 | + |
| 122 | + a1 = np.array([1.0, 0, 0])[:, None] |
| 123 | + a2 = np.array([0, 0, 1.0])[:, None] |
| 124 | + |
| 125 | + A = np.hstack((a1, a2)) |
| 126 | + M = np.array([[0, 1.0, 4.0], [1.0, 0, 1.0], [4.0, 1.0, 0]]) |
| 127 | + |
| 128 | + # obvious barycenter between two diracs |
| 129 | + bary0 = np.array([0, 1.0, 0]) |
| 130 | + |
| 131 | + bary = ot.lp.barycenter(A, M, [.5, .5]) |
| 132 | + |
| 133 | + np.testing.assert_allclose(bary, bary0, rtol=1e-5, atol=1e-7) |
| 134 | + np.testing.assert_allclose(bary.sum(), 1) |
| 135 | + |
| 136 | +@pytest.mark.skipif(not ot.lp.cvx.cvxopt, reason="No cvxopt available") |
| 137 | +def test_lp_barycenter_cvxopt(): |
| 138 | + |
| 139 | + a1 = np.array([1.0, 0, 0])[:, None] |
| 140 | + a2 = np.array([0, 0, 1.0])[:, None] |
| 141 | + |
| 142 | + A = np.hstack((a1, a2)) |
| 143 | + M = np.array([[0, 1.0, 4.0], [1.0, 0, 1.0], [4.0, 1.0, 0]]) |
| 144 | + |
| 145 | + # obvious barycenter between two diracs |
| 146 | + bary0 = np.array([0, 1.0, 0]) |
| 147 | + |
| 148 | + bary = ot.lp.barycenter(A, M, [.5, .5],solver=None) |
| 149 | + |
| 150 | + np.testing.assert_allclose(bary, bary0, rtol=1e-5, atol=1e-7) |
| 151 | + np.testing.assert_allclose(bary.sum(), 1) |
| 152 | + |
120 | 153 | def test_warnings(): |
121 | 154 | n = 100 # nb bins |
122 | 155 | m = 100 # nb bins |
|
0 commit comments