File tree Expand file tree Collapse file tree 1 file changed +29
-0
lines changed
Expand file tree Collapse file tree 1 file changed +29
-0
lines changed Original file line number Diff line number Diff line change @@ -72,3 +72,32 @@ def test_sinkhorn_variants():
7272 assert np .allclose (G0 , Gs , atol = 1e-05 )
7373 assert np .allclose (G0 , Ges , atol = 1e-05 )
7474 assert np .allclose (G0 , Gerr )
75+
76+
77+ def test_bary ():
78+
79+ n = 100 # nb bins
80+
81+ # bin positions
82+ x = np .arange (n , dtype = np .float64 )
83+
84+ # Gaussian distributions
85+ a1 = ot .datasets .get_1D_gauss (n , m = 30 , s = 10 ) # m= mean, s= std
86+ a2 = ot .datasets .get_1D_gauss (n , m = 40 , s = 10 )
87+
88+ # creating matrix A containing all distributions
89+ A = np .vstack ((a1 , a2 )).T
90+ n_distributions = A .shape [1 ]
91+
92+ # loss matrix + normalization
93+ M = ot .utils .dist0 (n )
94+ M /= M .max ()
95+
96+ alpha = 0.5 # 0<=alpha<=1
97+ weights = np .array ([1 - alpha , alpha ])
98+
99+ # wasserstein
100+ reg = 1e-3
101+ bary_wass = ot .bregman .barycenter (A , M , reg , weights )
102+
103+ assert np .allclose (1 , np .sum (bary_wass ))
You can’t perform that action at this time.
0 commit comments