@@ -97,3 +97,37 @@ def test_bary():
9797 bary_wass = ot .bregman .barycenter (A , M , reg , weights )
9898
9999 assert np .allclose (1 , np .sum (bary_wass ))
100+
101+ ot .bregman .barycenter (A , M , reg , log = True , verbose = True )
102+
103+
104+ def test_unmix ():
105+
106+ n = 50 # nb bins
107+
108+ # Gaussian distributions
109+ a1 = ot .datasets .get_1D_gauss (n , m = 20 , s = 10 ) # m= mean, s= std
110+ a2 = ot .datasets .get_1D_gauss (n , m = 40 , s = 10 )
111+
112+ a = ot .datasets .get_1D_gauss (n , m = 30 , s = 10 )
113+
114+ # creating matrix A containing all distributions
115+ D = np .vstack ((a1 , a2 )).T
116+
117+ # loss matrix + normalization
118+ M = ot .utils .dist0 (n )
119+ M /= M .max ()
120+
121+ M0 = ot .utils .dist0 (2 )
122+ M0 /= M0 .max ()
123+ h0 = ot .unif (2 )
124+
125+ # wasserstein
126+ reg = 1e-3
127+ um = ot .bregman .unmix (a , D , M , M0 , h0 , reg , 1 , alpha = 0.01 ,)
128+
129+ assert np .allclose (1 , np .sum (um ), rtol = 1e-03 , atol = 1e-03 )
130+ assert np .allclose ([0.5 , 0.5 ], um , rtol = 1e-03 , atol = 1e-03 )
131+
132+ ot .bregman .unmix (a , D , M , M0 , h0 , reg ,
133+ 1 , alpha = 0.01 , log = True , verbose = True )
0 commit comments