Skip to content

Commit a3497b1

Browse files
committed
Reformat
1 parent d52b4ea commit a3497b1

File tree

1 file changed

+33
-34
lines changed

1 file changed

+33
-34
lines changed

test/test_ot.py

Lines changed: 33 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,12 @@
55
# License: MIT License
66

77
import numpy as np
8+
89
import ot
9-
1010
from ot.datasets import get_1D_gauss as gauss
1111

1212

1313
def test_doctest():
14-
1514
import doctest
1615

1716
# test lp solver
@@ -100,63 +99,63 @@ def test_emd2_multi():
10099

101100
np.testing.assert_allclose(emd1, emdn)
102101

102+
103103
def test_dual_variables():
104-
#%% parameters
105-
106-
n=5000 # nb bins
107-
m=6000 # nb bins
108-
104+
# %% parameters
105+
106+
n = 5000 # nb bins
107+
m = 6000 # nb bins
108+
109109
mean1 = 1000
110110
mean2 = 1100
111-
111+
112112
# bin positions
113-
x=np.arange(n,dtype=np.float64)
114-
y=np.arange(m,dtype=np.float64)
115-
113+
x = np.arange(n, dtype=np.float64)
114+
y = np.arange(m, dtype=np.float64)
115+
116116
# Gaussian distributions
117-
a=gauss(n,m=mean1,s=5) # m= mean, s= std
118-
119-
b=gauss(m,m=mean2,s=10)
120-
117+
a = gauss(n, m=mean1, s=5) # m= mean, s= std
118+
119+
b = gauss(m, m=mean2, s=10)
120+
121121
# loss matrix
122-
M=ot.dist(x.reshape((-1,1)), y.reshape((-1,1))) ** (1./2)
123-
#M/=M.max()
124-
125-
#%%
126-
122+
M = ot.dist(x.reshape((-1, 1)), y.reshape((-1, 1))) ** (1. / 2)
123+
# M/=M.max()
124+
125+
# %%
126+
127127
print('Computing {} EMD '.format(1))
128-
128+
129129
# emd loss 1 proc
130130
ot.tic()
131-
G, alpha, beta = ot.emd(a,b,M, dual_variables=True)
131+
G, alpha, beta = ot.emd(a, b, M, dual_variables=True)
132132
ot.toc('1 proc : {} s')
133-
133+
134134
cost1 = (G * M).sum()
135135
cost_dual = np.vdot(a, alpha) + np.vdot(b, beta)
136-
136+
137137
# emd loss 1 proc
138138
ot.tic()
139-
cost_emd2 = ot.emd2(a,b,M)
139+
cost_emd2 = ot.emd2(a, b, M)
140140
ot.toc('1 proc : {} s')
141-
141+
142142
ot.tic()
143143
G2 = ot.emd(b, a, np.ascontiguousarray(M.T))
144144
ot.toc('1 proc : {} s')
145-
145+
146146
cost2 = (G2 * M.T).sum()
147-
148-
M_reduced = M - alpha.reshape(-1,1) - beta.reshape(1, -1)
149-
147+
150148
# Check that both cost computations are equivalent
151149
np.testing.assert_almost_equal(cost1, cost_emd2)
152150
# Check that dual and primal cost are equal
153151
np.testing.assert_almost_equal(cost1, cost_dual)
154152
# Check symmetry
155153
np.testing.assert_almost_equal(cost1, cost2)
156154
# Check with closed-form solution for gaussians
157-
np.testing.assert_almost_equal(cost1, np.abs(mean1-mean2))
158-
155+
np.testing.assert_almost_equal(cost1, np.abs(mean1 - mean2))
156+
159157
[ind1, ind2] = np.nonzero(G)
160-
158+
161159
# Check that reduced cost is zero on transport arcs
162-
np.testing.assert_array_almost_equal((M - alpha.reshape(-1, 1) - beta.reshape(1, -1))[ind1, ind2], np.zeros(ind1.size))
160+
np.testing.assert_array_almost_equal((M - alpha.reshape(-1, 1) - beta.reshape(1, -1))[ind1, ind2],
161+
np.zeros(ind1.size))

0 commit comments

Comments
 (0)