Skip to content

Commit bd705ed

Browse files
committed
add test yunmlix and bary
1 parent 33f3d30 commit bd705ed

File tree

3 files changed

+36
-2
lines changed

3 files changed

+36
-2
lines changed

test/test_bregman.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,3 +97,37 @@ def test_bary():
9797
bary_wass = ot.bregman.barycenter(A, M, reg, weights)
9898

9999
assert np.allclose(1, np.sum(bary_wass))
100+
101+
ot.bregman.barycenter(A, M, reg, log=True, verbose=True)
102+
103+
104+
def test_unmix():
105+
106+
n = 50 # nb bins
107+
108+
# Gaussian distributions
109+
a1 = ot.datasets.get_1D_gauss(n, m=20, s=10) # m= mean, s= std
110+
a2 = ot.datasets.get_1D_gauss(n, m=40, s=10)
111+
112+
a = ot.datasets.get_1D_gauss(n, m=30, s=10)
113+
114+
# creating matrix A containing all distributions
115+
D = np.vstack((a1, a2)).T
116+
117+
# loss matrix + normalization
118+
M = ot.utils.dist0(n)
119+
M /= M.max()
120+
121+
M0 = ot.utils.dist0(2)
122+
M0 /= M0.max()
123+
h0 = ot.unif(2)
124+
125+
# wasserstein
126+
reg = 1e-3
127+
um = ot.bregman.unmix(a, D, M, M0, h0, reg, 1, alpha=0.01,)
128+
129+
assert np.allclose(1, np.sum(um), rtol=1e-03, atol=1e-03)
130+
assert np.allclose([0.5, 0.5], um, rtol=1e-03, atol=1e-03)
131+
132+
ot.bregman.unmix(a, D, M, M0, h0, reg,
133+
1, alpha=0.01, log=True, verbose=True)

test/test_gpu.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def describeRes(r):
4848
print("min:{:.3E}, max:{:.3E}, mean:{:.3E}, std:{:.3E}"
4949
.format(np.min(r), np.max(r), np.mean(r), np.std(r)))
5050

51-
for n in [50, 100, 500, 1000]:
51+
for n in [50, 100, 500]:
5252
print(n)
5353
a = np.random.rand(n // 4, 100)
5454
labels_a = np.random.randint(10, size=(n // 4))

test/test_ot.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def test_emd2_multi():
7676
# Gaussian distributions
7777
a = gauss(n, m=20, s=5) # m= mean, s= std
7878

79-
ls = np.arange(20, 1000, 10)
79+
ls = np.arange(20, 1000, 20)
8080
nb = len(ls)
8181
b = np.zeros((n, nb))
8282
for i in range(nb):

0 commit comments

Comments
 (0)