Skip to content

Commit 64cf2fc

Browse files
committed
tets barycenter
1 parent 83ecc6d commit 64cf2fc

File tree

1 file changed

+29
-0
lines changed

1 file changed

+29
-0
lines changed

test/test_bregman.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,3 +72,32 @@ def test_sinkhorn_variants():
7272
assert np.allclose(G0, Gs, atol=1e-05)
7373
assert np.allclose(G0, Ges, atol=1e-05)
7474
assert np.allclose(G0, Gerr)
75+
76+
77+
def test_bary():
78+
79+
n = 100 # nb bins
80+
81+
# bin positions
82+
x = np.arange(n, dtype=np.float64)
83+
84+
# Gaussian distributions
85+
a1 = ot.datasets.get_1D_gauss(n, m=30, s=10) # m= mean, s= std
86+
a2 = ot.datasets.get_1D_gauss(n, m=40, s=10)
87+
88+
# creating matrix A containing all distributions
89+
A = np.vstack((a1, a2)).T
90+
n_distributions = A.shape[1]
91+
92+
# loss matrix + normalization
93+
M = ot.utils.dist0(n)
94+
M /= M.max()
95+
96+
alpha = 0.5 # 0<=alpha<=1
97+
weights = np.array([1 - alpha, alpha])
98+
99+
# wasserstein
100+
reg = 1e-3
101+
bary_wass = ot.bregman.barycenter(A, M, reg, weights)
102+
103+
assert np.allclose(1, np.sum(bary_wass))

0 commit comments

Comments
 (0)