@@ -490,6 +490,41 @@ def test_barycenter(nx, method, verbose, warn):
490490 ot .bregman .barycenter (A_nx , M_nx , reg , log = True )
491491
492492
493+ @pytest .mark .parametrize ("method, verbose, warn" ,
494+ product (["sinkhorn" , "sinkhorn_stabilized" , "sinkhorn_log" ],
495+ [True , False ], [True , False ]))
496+ def test_barycenter_assymetric_cost (nx , method , verbose , warn ):
497+ n_bins = 20 # nb bins
498+
499+ # Gaussian distributions
500+ A = ot .datasets .make_1D_gauss (n_bins , m = 30 , s = 10 ) # m= mean, s= std
501+
502+ # creating matrix A containing all distributions
503+ A = A [:, None ]
504+
505+ # assymetric loss matrix + normalization
506+ rng = np .random .RandomState (42 )
507+ M = rng .randn (n_bins , n_bins ) ** 2
508+ M /= M .max ()
509+
510+ A_nx , M_nx = nx .from_numpy (A , M )
511+ reg = 1e-2
512+
513+ if nx .__name__ in ("jax" , "tf" ) and method == "sinkhorn_log" :
514+ with pytest .raises (NotImplementedError ):
515+ ot .bregman .barycenter (A_nx , M_nx , reg , method = method )
516+ else :
517+ # wasserstein
518+ bary_wass_np = ot .bregman .barycenter (A , M , reg , method = method , verbose = verbose , warn = warn )
519+ bary_wass , _ = ot .bregman .barycenter (A_nx , M_nx , reg , method = method , log = True )
520+ bary_wass = nx .to_numpy (bary_wass )
521+
522+ np .testing .assert_allclose (1 , np .sum (bary_wass ))
523+ np .testing .assert_allclose (bary_wass , bary_wass_np )
524+
525+ ot .bregman .barycenter (A_nx , M_nx , reg , log = True )
526+
527+
493528@pytest .mark .parametrize ("method, verbose, warn" ,
494529 product (["sinkhorn" , "sinkhorn_log" ],
495530 [True , False ], [True , False ]))
0 commit comments