Skip to content

Commit 7095e03

Browse files
committed
gtomov barycenter tests
1 parent 64ef33d commit 7095e03

File tree

1 file changed

+29
-2
lines changed

1 file changed

+29
-2
lines changed

test/test_gromov.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def test_gromov():
4040

4141
G = log['T']
4242

43-
np.testing.assert_allclose(gw, 0, atol=1e-2, rtol=1e-2)
43+
np.testing.assert_allclose(gw, 0, atol=1e-1, rtol=1e-1)
4444

4545
# check constratints
4646
np.testing.assert_allclose(
@@ -82,10 +82,37 @@ def test_entropic_gromov():
8282

8383
G = log['T']
8484

85-
np.testing.assert_allclose(gw, 0, atol=1e-1, rtol=1e1)
85+
np.testing.assert_allclose(gw, 0, atol=1e-1, rtol=1e-1)
8686

8787
# check constratints
8888
np.testing.assert_allclose(
8989
p, G.sum(1), atol=1e-04) # cf convergence gromov
9090
np.testing.assert_allclose(
9191
q, G.sum(0), atol=1e-04) # cf convergence gromov
92+
93+
94+
def test_gromov_barycenter():
95+
96+
ns = 50
97+
nt = 60
98+
99+
Xs, ys = ot.datasets.get_data_classif('3gauss', ns)
100+
Xt, yt = ot.datasets.get_data_classif('3gauss2', nt)
101+
102+
C1 = ot.dist(Xs)
103+
C2 = ot.dist(Xt)
104+
105+
n_samples = 3
106+
Cb = ot.gromov.gromov_barycenters(n_samples, [C1, C2],
107+
[ot.unif(ns), ot.unif(nt)
108+
], ot.unif(n_samples), [.5, .5],
109+
'square_loss', # 5e-4,
110+
max_iter=100, tol=1e-3)
111+
np.testing.assert_allclose(Cb.shape, (n_samples, n_samples))
112+
113+
Cb2 = ot.gromov.gromov_barycenters(n_samples, [C1, C2],
114+
[ot.unif(ns), ot.unif(nt)
115+
], ot.unif(n_samples), [.5, .5],
116+
'kl_loss', # 5e-4,
117+
max_iter=100, tol=1e-3)
118+
np.testing.assert_allclose(Cb2.shape, (n_samples, n_samples))

0 commit comments

Comments
 (0)