
Commit 53b063e

better coverage options verbose and log
1 parent 46523dc commit 53b063e

6 files changed (+42, -17)

ot/bregman.py

Lines changed: 0 additions & 5 deletions
@@ -909,11 +909,6 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4,
     else:
         alpha, beta = warmstart
 
-    def get_K(alpha, beta):
-        """log space computation"""
-        return np.exp(-(M - alpha.reshape((dim_a, 1))
-                        - beta.reshape((1, dim_b))) / reg)
-
     # print(np.min(K))
     def get_reg(n):  # exponential decreasing
         return (epsilon0 - reg) * np.exp(-n) + reg
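
Note: the deleted get_K closure appears to have been unused by the epsilon-scaling loop (this commit targets test coverage), so the change drops dead code. For reference only, the log-domain kernel it computed can be restated as a standalone sketch; M, reg, alpha and beta are the quantities defined inside sinkhorn_epsilon_scaling:

import numpy as np

def get_K(M, alpha, beta, reg):
    """Log-space kernel from the deleted helper: K[i, j] = exp(-(M[i, j] - alpha[i] - beta[j]) / reg)."""
    dim_a, dim_b = M.shape
    return np.exp(-(M - alpha.reshape((dim_a, 1))
                    - beta.reshape((1, dim_b))) / reg)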

test/test_bregman.py

Lines changed: 6 additions & 3 deletions
@@ -57,6 +57,9 @@ def test_sinkhorn_empty():
     np.testing.assert_allclose(u, G.sum(1), atol=1e-05)
     np.testing.assert_allclose(u, G.sum(0), atol=1e-05)
 
+    # test empty weights greenkhorn
+    ot.sinkhorn([], [], M, 1, method='greenkhorn', stopThr=1e-10, log=True)
+
 
 def test_sinkhorn_variants():
     # test sinkhorn
@@ -124,7 +127,7 @@ def test_barycenter(method):
 
     # wasserstein
     reg = 1e-2
-    bary_wass = ot.bregman.barycenter(A, M, reg, weights, method=method)
+    bary_wass, log = ot.bregman.barycenter(A, M, reg, weights, method=method, log=True)
 
     np.testing.assert_allclose(1, np.sum(bary_wass))
 
@@ -152,9 +155,9 @@ def test_barycenter_stabilization():
     reg = 1e-2
     bar_stable = ot.bregman.barycenter(A, M, reg, weights,
                                        method="sinkhorn_stabilized",
-                                       stopThr=1e-8)
+                                       stopThr=1e-8, verbose=True)
     bar = ot.bregman.barycenter(A, M, reg, weights, method="sinkhorn",
-                                stopThr=1e-8)
+                                stopThr=1e-8, verbose=True)
     np.testing.assert_allclose(bar, bar_stable)
 
 
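
With log=True, ot.bregman.barycenter returns a (barycenter, log) pair instead of the bare array, which is what the updated assertion now unpacks. A minimal sketch of that calling pattern, assuming the usual POT test fixtures (make_1D_gauss histograms and a dist0 loss matrix are assumptions beyond what this diff shows):

import numpy as np
import ot

n = 50
a1 = ot.datasets.make_1D_gauss(n, m=15, s=5)
a2 = ot.datasets.make_1D_gauss(n, m=35, s=5)
A = np.vstack((a1, a2)).T          # one histogram per column
M = ot.utils.dist0(n)              # loss matrix on the 1D grid
M /= M.max()

bary_wass, log = ot.bregman.barycenter(A, M, reg=1e-2, weights=np.array([0.5, 0.5]),
                                       method="sinkhorn", log=True)
np.testing.assert_allclose(1, np.sum(bary_wass))   # the barycenter stays a probability vector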

test/test_optim.py

Lines changed: 1 addition & 1 deletion
@@ -38,7 +38,7 @@ def df(G):
 
 
 def test_conditional_gradient2():
-    n = 4000  # nb samples
+    n = 1000  # nb samples
 
     mu_s = np.array([0, 0])
     cov_s = np.array([[1, 0], [0, 1]])

test/test_partial.py

Lines changed: 25 additions & 1 deletion
@@ -9,6 +9,30 @@
 import scipy as sp
 import ot
 
+def test_partial_wasserstein_lagrange():
+
+    n_samples = 20  # nb samples (gaussian)
+    n_noise = 20  # nb of samples (noise)
+
+    mu = np.array([0, 0])
+    cov = np.array([[1, 0], [0, 2]])
+
+    xs = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov)
+    xs = np.append(xs, (np.random.rand(n_noise, 2) + 1) * 4).reshape((-1, 2))
+    xt = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov)
+    xt = np.append(xt, (np.random.rand(n_noise, 2) + 1) * -3).reshape((-1, 2))
+
+    M = ot.dist(xs, xt)
+
+    p = ot.unif(n_samples + n_noise)
+    q = ot.unif(n_samples + n_noise)
+
+    m = 0.5
+
+    w0, log0 = ot.partial.partial_wasserstein_lagrange(p, q, M, 1, log=True)
+
+
+
 
 def test_partial_wasserstein():
 
@@ -32,7 +56,7 @@ def test_partial_wasserstein():
 
     w0, log0 = ot.partial.partial_wasserstein(p, q, M, m=m, log=True)
     w, log = ot.partial.entropic_partial_wasserstein(p, q, M, reg=1, m=m,
-                                                     log=True)
+                                                     log=True, verbose=True)
 
     # check constratints
     np.testing.assert_equal(
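
The new test_partial_wasserstein_lagrange covers the Lagrangian (penalized) formulation with log=True, and entropic_partial_wasserstein now also gets verbose=True. A short usage sketch, with randomly generated points standing in for the test fixtures (an assumption, not part of this diff):

import numpy as np
import ot

n = 20
xs = np.random.randn(n, 2)
xt = np.random.randn(n, 2) + 2
M = ot.dist(xs, xt)
p, q = ot.unif(n), ot.unif(n)

# Hard mass constraint: only a fraction m of the total mass is transported.
w, log = ot.partial.partial_wasserstein(p, q, M, m=0.5, log=True)
print(w.sum())   # close to 0.5

# Lagrangian variant exercised by the new test: a penalty term replaces the hard m.
w0, log0 = ot.partial.partial_wasserstein_lagrange(p, q, M, 1, log=True)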

test/test_stochastic.py

Lines changed: 4 additions & 4 deletions
@@ -70,8 +70,8 @@ def test_stochastic_asgd():
 
     M = ot.dist(x, x)
 
-    G = ot.stochastic.solve_semi_dual_entropic(u, u, M, reg, "asgd",
-                                               numItermax=numItermax)
+    G, log = ot.stochastic.solve_semi_dual_entropic(u, u, M, reg, "asgd",
+                                                    numItermax=numItermax, log=True)
 
     # check constratints
     np.testing.assert_allclose(
@@ -145,8 +145,8 @@ def test_stochastic_dual_sgd():
 
     M = ot.dist(x, x)
 
-    G = ot.stochastic.solve_dual_entropic(u, u, M, reg, batch_size,
-                                          numItermax=numItermax)
+    G, log = ot.stochastic.solve_dual_entropic(u, u, M, reg, batch_size,
+                                               numItermax=numItermax, log=True)
 
     # check constratints
     np.testing.assert_allclose(
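
Passing log=True makes both stochastic solvers return the transport plan together with a log dict, which is what the tests now unpack (and what brings the log branch under coverage). A minimal sketch under assumed problem sizes:

import numpy as np
import ot

n = 15
reg = 1.0
x = np.random.randn(n, 2)
u = ot.unif(n)
M = ot.dist(x, x)

G_asgd, log_asgd = ot.stochastic.solve_semi_dual_entropic(u, u, M, reg, "asgd",
                                                          numItermax=1000, log=True)
G_sgd, log_sgd = ot.stochastic.solve_dual_entropic(u, u, M, reg, 5,
                                                   numItermax=1000, log=True)
print(G_asgd.shape, G_sgd.shape)   # both are (n, n) transport plans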

test/test_unbalanced.py

Lines changed: 6 additions & 3 deletions
@@ -31,9 +31,11 @@ def test_unbalanced_convergence(method):
     G, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon,
                                                reg_m=reg_m,
                                                method=method,
-                                               log=True)
+                                               log=True,
+                                               verbose=True)
     loss = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m,
-                                              method=method)
+                                              method=method,
+                                              verbose=True)
     # check fixed point equations
     # in log-domain
     fi = reg_m / (reg_m + epsilon)
@@ -73,7 +75,8 @@ def test_unbalanced_multiple_inputs(method):
     loss, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon,
                                                   reg_m=reg_m,
                                                   method=method,
-                                                  log=True)
+                                                  log=True,
+                                                  verbose=True)
     # check fixed point equations
     # in log-domain
     fi = reg_m / (reg_m + epsilon)
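
Both unbalanced solvers accept verbose=True, which prints progress information during the iterations, and sinkhorn_unbalanced with log=True returns a (plan, log) pair. A sketch with assumed marginals (the unbalanced setting allows the two marginals to carry different total mass):

import numpy as np
import ot

n = 10
a = ot.unif(n)
b = ot.unif(n) * 1.5                      # total masses may differ in the unbalanced case
M = ot.dist(np.random.randn(n, 2), np.random.randn(n, 2))

G, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=1.0, reg_m=1.0,
                                           method="sinkhorn", log=True, verbose=True)
loss = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, 1.0, 1.0,
                                          method="sinkhorn", verbose=True)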
