Skip to content

Commit c1ccfc4

Browse files
hichamjanatiHicham Janati
andauthored
[MRG] Fix barycenter mass (#375)
* fix transpose in sinkhorn barycenters * add test for assymetric cost barycenters * fix pep8 Co-authored-by: Hicham Janati <hicham.janati@inria.fr>
1 parent 726e84e commit c1ccfc4

File tree

2 files changed

+36
-1
lines changed

2 files changed

+36
-1
lines changed

ot/bregman.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1511,7 +1511,7 @@ def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000,
15111511

15121512
for ii in range(numItermax):
15131513

1514-
UKv = u * nx.dot(K, A / nx.dot(K, u))
1514+
UKv = u * nx.dot(K.T, A / nx.dot(K, u))
15151515
u = (u.T * geometricBar(weights, UKv)).T / UKv
15161516

15171517
if ii % 10 == 1:

test/test_bregman.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -490,6 +490,41 @@ def test_barycenter(nx, method, verbose, warn):
490490
ot.bregman.barycenter(A_nx, M_nx, reg, log=True)
491491

492492

493+
@pytest.mark.parametrize("method, verbose, warn",
494+
product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_log"],
495+
[True, False], [True, False]))
496+
def test_barycenter_assymetric_cost(nx, method, verbose, warn):
497+
n_bins = 20 # nb bins
498+
499+
# Gaussian distributions
500+
A = ot.datasets.make_1D_gauss(n_bins, m=30, s=10) # m= mean, s= std
501+
502+
# creating matrix A containing all distributions
503+
A = A[:, None]
504+
505+
# assymetric loss matrix + normalization
506+
rng = np.random.RandomState(42)
507+
M = rng.randn(n_bins, n_bins) ** 2
508+
M /= M.max()
509+
510+
A_nx, M_nx = nx.from_numpy(A, M)
511+
reg = 1e-2
512+
513+
if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log":
514+
with pytest.raises(NotImplementedError):
515+
ot.bregman.barycenter(A_nx, M_nx, reg, method=method)
516+
else:
517+
# wasserstein
518+
bary_wass_np = ot.bregman.barycenter(A, M, reg, method=method, verbose=verbose, warn=warn)
519+
bary_wass, _ = ot.bregman.barycenter(A_nx, M_nx, reg, method=method, log=True)
520+
bary_wass = nx.to_numpy(bary_wass)
521+
522+
np.testing.assert_allclose(1, np.sum(bary_wass))
523+
np.testing.assert_allclose(bary_wass, bary_wass_np)
524+
525+
ot.bregman.barycenter(A_nx, M_nx, reg, log=True)
526+
527+
493528
@pytest.mark.parametrize("method, verbose, warn",
494529
product(["sinkhorn", "sinkhorn_log"],
495530
[True, False], [True, False]))

0 commit comments

Comments
 (0)