|
5 | 5 | # License: MIT License |
6 | 6 |
|
7 | 7 | import numpy as np |
| 8 | + |
8 | 9 | import ot |
9 | | - |
10 | 10 | from ot.datasets import get_1D_gauss as gauss |
11 | 11 |
|
12 | 12 |
|
13 | 13 | def test_doctest(): |
14 | | - |
15 | 14 | import doctest |
16 | 15 |
|
17 | 16 | # test lp solver |
@@ -100,63 +99,63 @@ def test_emd2_multi(): |
100 | 99 |
|
101 | 100 | np.testing.assert_allclose(emd1, emdn) |
102 | 101 |
|
| 102 | + |
def test_dual_variables():
    """Exercise the exact EMD solver with dual variables returned.

    Verifies that (1) the primal transport cost matches ``ot.emd2``,
    (2) the dual objective equals the primal one, (3) solving the
    transposed problem gives the same cost, (4) the cost agrees with
    the closed-form W1 distance between two 1D Gaussians, and (5) the
    reduced costs vanish on the support of the optimal plan.
    """
    # %% parameters

    n_src = 5000  # nb bins on the source side
    n_tgt = 6000  # nb bins on the target side

    mu_src = 1000
    mu_tgt = 1100

    # bin positions
    xs = np.arange(n_src, dtype=np.float64)
    xt = np.arange(n_tgt, dtype=np.float64)

    # Gaussian distributions (m = mean, s = std)
    hist_src = gauss(n_src, m=mu_src, s=5)

    hist_tgt = gauss(n_tgt, m=mu_tgt, s=10)

    # loss matrix: pairwise ground cost between bin positions
    cost_mat = ot.dist(xs.reshape((-1, 1)), xt.reshape((-1, 1))) ** (1. / 2)

    # %%

    print('Computing {} EMD '.format(1))

    # emd loss 1 proc
    ot.tic()
    plan, alpha, beta = ot.emd(hist_src, hist_tgt, cost_mat, dual_variables=True)
    ot.toc('1 proc : {} s')

    primal_cost = (plan * cost_mat).sum()
    dual_cost = np.vdot(hist_src, alpha) + np.vdot(hist_tgt, beta)

    # emd loss 1 proc
    ot.tic()
    cost_emd2 = ot.emd2(hist_src, hist_tgt, cost_mat)
    ot.toc('1 proc : {} s')

    ot.tic()
    plan_rev = ot.emd(hist_tgt, hist_src, np.ascontiguousarray(cost_mat.T))
    ot.toc('1 proc : {} s')

    reversed_cost = (plan_rev * cost_mat.T).sum()

    # Check that both cost computations are equivalent
    np.testing.assert_almost_equal(primal_cost, cost_emd2)
    # Check that dual and primal cost are equal
    np.testing.assert_almost_equal(primal_cost, dual_cost)
    # Check symmetry
    np.testing.assert_almost_equal(primal_cost, reversed_cost)
    # Check with closed-form solution for gaussians
    np.testing.assert_almost_equal(primal_cost, np.abs(mu_src - mu_tgt))

    [row_idx, col_idx] = np.nonzero(plan)

    # Check that reduced cost is zero on transport arcs
    np.testing.assert_array_almost_equal(
        (cost_mat - alpha.reshape(-1, 1) - beta.reshape(1, -1))[row_idx, col_idx],
        np.zeros(row_idx.size))
0 commit comments