Skip to content

Commit b4665fe

Browse files
committed
should pass tests now
1 parent d41ffdb commit b4665fe

File tree

3 files changed

+34
-1
lines changed

3 files changed

+34
-1
lines changed

ot/gromov.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,8 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, **kwargs):
307307
Print information along iterations
308308
log : bool, optional
309309
record log if True
310+
**kwargs : dict
311+
parameters can be directly pased to the ot.optim.cg solver
310312
311313
Returns
312314
-------

test/test_gromov.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,36 @@ def test_gromov():
2828
C1 /= C1.max()
2929
C2 /= C2.max()
3030

31-
G = ot.gromov_wasserstein(C1, C2, p, q, 'square_loss', epsilon=5e-4)
31+
G = ot.gromov.gromov_wasserstein(C1, C2, p, q, 'square_loss')
32+
33+
# check constratints
34+
np.testing.assert_allclose(
35+
p, G.sum(1), atol=1e-04) # cf convergence gromov
36+
np.testing.assert_allclose(
37+
q, G.sum(0), atol=1e-04) # cf convergence gromov
38+
39+
40+
def test_entropic_gromov():
41+
n_samples = 50 # nb samples
42+
43+
mu_s = np.array([0, 0])
44+
cov_s = np.array([[1, 0], [0, 1]])
45+
46+
xs = ot.datasets.get_2D_samples_gauss(n_samples, mu_s, cov_s)
47+
48+
xt = xs[::-1].copy()
49+
50+
p = ot.unif(n_samples)
51+
q = ot.unif(n_samples)
52+
53+
C1 = ot.dist(xs, xs)
54+
C2 = ot.dist(xt, xt)
55+
56+
C1 /= C1.max()
57+
C2 /= C2.max()
58+
59+
G = ot.gromov.entropic_gromov_wasserstein(
60+
C1, C2, p, q, 'square_loss', epsilon=5e-4)
3261

3362
# check constratints
3463
np.testing.assert_allclose(

test/test_plot.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
def test_plot1D_mat():
1313

1414
import ot
15+
import ot.plot
1516

1617
n_bins = 100 # nb bins
1718

@@ -32,6 +33,7 @@ def test_plot1D_mat():
3233
def test_plot2D_samples_mat():
3334

3435
import ot
36+
import ot.plot
3537

3638
n_bins = 50 # nb samples
3739

0 commit comments

Comments
 (0)