|
6 | 6 |
|
7 | 7 | import numpy as np |
8 | 8 | import ot |
| 9 | + |
| 10 | +from ot.datasets import get_1D_gauss as gauss |
9 | 11 |
|
10 | 12 |
|
11 | 13 | def test_doctest(): |
@@ -66,9 +68,6 @@ def test_emd_empty(): |
66 | 68 |
|
67 | 69 |
|
68 | 70 | def test_emd2_multi(): |
69 | | - |
70 | | - from ot.datasets import get_1D_gauss as gauss |
71 | | - |
72 | 71 | n = 1000 # nb bins |
73 | 72 |
|
74 | 73 | # bin positions |
@@ -100,3 +99,64 @@ def test_emd2_multi(): |
100 | 99 | ot.toc('multi proc : {} s') |
101 | 100 |
|
102 | 101 | np.testing.assert_allclose(emd1, emdn) |
| 102 | + |
| 103 | +def test_dual_variables(): |
| 104 | + #%% parameters |
| 105 | + |
| 106 | + n=5000 # nb bins |
| 107 | + m=6000 # nb bins |
| 108 | + |
| 109 | + mean1 = 1000 |
| 110 | + mean2 = 1100 |
| 111 | + |
| 112 | + # bin positions |
| 113 | + x=np.arange(n,dtype=np.float64) |
| 114 | + y=np.arange(m,dtype=np.float64) |
| 115 | + |
| 116 | + # Gaussian distributions |
| 117 | + a=gauss(n,m=mean1,s=5) # m= mean, s= std |
| 118 | + |
| 119 | + b=gauss(m,m=mean2,s=10) |
| 120 | + |
| 121 | + # loss matrix |
| 122 | + M=ot.dist(x.reshape((-1,1)), y.reshape((-1,1))) ** (1./2) |
| 123 | + #M/=M.max() |
| 124 | + |
| 125 | + #%% |
| 126 | + |
| 127 | + print('Computing {} EMD '.format(1)) |
| 128 | + |
| 129 | + # emd loss 1 proc |
| 130 | + ot.tic() |
| 131 | + G, alpha, beta = ot.emd(a,b,M, dual_variables=True) |
| 132 | + ot.toc('1 proc : {} s') |
| 133 | + |
| 134 | + cost1 = (G * M).sum() |
| 135 | + cost_dual = np.vdot(a, alpha) + np.vdot(b, beta) |
| 136 | + |
| 137 | + # emd loss 1 proc |
| 138 | + ot.tic() |
| 139 | + cost_emd2 = ot.emd2(a,b,M) |
| 140 | + ot.toc('1 proc : {} s') |
| 141 | + |
| 142 | + ot.tic() |
| 143 | + G2 = ot.emd(b, a, np.ascontiguousarray(M.T)) |
| 144 | + ot.toc('1 proc : {} s') |
| 145 | + |
| 146 | + cost2 = (G2 * M.T).sum() |
| 147 | + |
| 148 | + M_reduced = M - alpha.reshape(-1,1) - beta.reshape(1, -1) |
| 149 | + |
| 150 | + # Check that both cost computations are equivalent |
| 151 | + np.testing.assert_almost_equal(cost1, cost_emd2) |
| 152 | + # Check that dual and primal cost are equal |
| 153 | + np.testing.assert_almost_equal(cost1, cost_dual) |
| 154 | + # Check symmetry |
| 155 | + np.testing.assert_almost_equal(cost1, cost2) |
| 156 | + # Check with closed-form solution for gaussians |
| 157 | + np.testing.assert_almost_equal(cost1, np.abs(mean1-mean2)) |
| 158 | + |
| 159 | + [ind1, ind2] = np.nonzero(G) |
| 160 | + |
| 161 | + # Check that reduced cost is zero on transport arcs |
| 162 | + np.testing.assert_array_almost_equal((M - alpha.reshape(-1, 1) - beta.reshape(1, -1))[ind1, ind2], np.zeros(ind1.size)) |
0 commit comments