Skip to content

Commit f204e98

Browse files
committed
add test da 58% coverage
1 parent bd705ed commit f204e98

File tree

1 file changed

+67
-0
lines changed

1 file changed

+67
-0
lines changed

test/test_da.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
2+
3+
import ot
4+
import numpy as np
5+
6+
# import pytest
7+
8+
9+
def test_OTDA():
10+
11+
n = 150 # nb bins
12+
13+
xs, ys = ot.datasets.get_data_classif('3gauss', n)
14+
xt, yt = ot.datasets.get_data_classif('3gauss2', n)
15+
16+
a, b = ot.unif(n), ot.unif(n)
17+
18+
# LP problem
19+
da_emd = ot.da.OTDA() # init class
20+
da_emd.fit(xs, xt) # fit distributions
21+
da_emd.interp() # interpolation of source samples
22+
da_emd.predict(xs) # interpolation of source samples
23+
24+
assert np.allclose(a, np.sum(da_emd.G, 1))
25+
assert np.allclose(b, np.sum(da_emd.G, 0))
26+
27+
# sinkhorn regularization
28+
lambd = 1e-1
29+
da_entrop = ot.da.OTDA_sinkhorn()
30+
da_entrop.fit(xs, xt, reg=lambd)
31+
da_entrop.interp()
32+
da_entrop.predict(xs)
33+
34+
assert np.allclose(a, np.sum(da_entrop.G, 1), rtol=1e-3, atol=1e-3)
35+
assert np.allclose(b, np.sum(da_entrop.G, 0), rtol=1e-3, atol=1e-3)
36+
37+
# non-convex Group lasso regularization
38+
reg = 1e-1
39+
eta = 1e0
40+
da_lpl1 = ot.da.OTDA_lpl1()
41+
da_lpl1.fit(xs, ys, xt, reg=reg, eta=eta)
42+
da_lpl1.interp()
43+
da_lpl1.predict(xs)
44+
45+
assert np.allclose(a, np.sum(da_lpl1.G, 1), rtol=1e-3, atol=1e-3)
46+
assert np.allclose(b, np.sum(da_lpl1.G, 0), rtol=1e-3, atol=1e-3)
47+
48+
# True Group lasso regularization
49+
reg = 1e-1
50+
eta = 2e0
51+
da_l1l2 = ot.da.OTDA_l1l2()
52+
da_l1l2.fit(xs, ys, xt, reg=reg, eta=eta, numItermax=20, verbose=True)
53+
da_l1l2.interp()
54+
da_l1l2.predict(xs)
55+
56+
assert np.allclose(a, np.sum(da_l1l2.G, 1), rtol=1e-3, atol=1e-3)
57+
assert np.allclose(b, np.sum(da_l1l2.G, 0), rtol=1e-3, atol=1e-3)
58+
59+
# linear mapping
60+
da_emd = ot.da.OTDA_mapping_linear() # init class
61+
da_emd.fit(xs, xt, numItermax=10) # fit distributions
62+
da_emd.predict(xs) # interpolation of source samples
63+
64+
# nonlinear mapping
65+
da_emd = ot.da.OTDA_mapping_kernel() # init class
66+
da_emd.fit(xs, xt, numItermax=10) # fit distributions
67+
da_emd.predict(xs) # interpolation of source samples

0 commit comments

Comments
 (0)