@@ -42,14 +42,34 @@ def test_class_jax_tf():
4242 otda .fit (Xs = Xs , ys = ys , Xt = Xt )
4343
4444
45+ @pytest .skip_backend ("jax" )
46+ @pytest .skip_backend ("tf" )
47+ @pytest .mark .parametrize ("class_to_test" , [ot .da .EMDTransport , ot .da .SinkhornTransport , ot .da .SinkhornLpl1Transport , ot .da .SinkhornL1l2Transport , ot .da .SinkhornL1l2Transport ])
48+ def test_log_da (nx , class_to_test ):
49+
50+ ns = 50
51+ nt = 50
52+
53+ Xs , ys = make_data_classif ('3gauss' , ns )
54+ Xt , yt = make_data_classif ('3gauss2' , nt )
55+
56+ Xs , ys , Xt , yt = nx .from_numpy (Xs , ys , Xt , yt )
57+
58+ otda = class_to_test (log = True )
59+
60+ # test its computed
61+ otda .fit (Xs = Xs , ys = ys , Xt = Xt )
62+ assert hasattr (otda , "log_" )
63+
64+
4565@pytest .skip_backend ("jax" )
4666@pytest .skip_backend ("tf" )
4767def test_sinkhorn_lpl1_transport_class (nx ):
4868 """test_sinkhorn_transport
4969 """
5070
51- ns = 150
52- nt = 200
71+ ns = 50
72+ nt = 50
5373
5474 Xs , ys = make_data_classif ('3gauss' , ns )
5575 Xt , yt = make_data_classif ('3gauss2' , nt )
@@ -136,7 +156,7 @@ def test_sinkhorn_l1l2_transport_class(nx):
136156 """
137157
138158 ns = 50
139- nt = 100
159+ nt = 50
140160
141161 Xs , ys = make_data_classif ('3gauss' , ns )
142162 Xt , yt = make_data_classif ('3gauss2' , nt )
@@ -230,8 +250,8 @@ def test_sinkhorn_transport_class(nx):
230250 """test_sinkhorn_transport
231251 """
232252
233- ns = 150
234- nt = 200
253+ ns = 50
254+ nt = 50
235255
236256 Xs , ys = make_data_classif ('3gauss' , ns )
237257 Xt , yt = make_data_classif ('3gauss2' , nt )
@@ -323,8 +343,8 @@ def test_unbalanced_sinkhorn_transport_class(nx):
323343 """test_sinkhorn_transport
324344 """
325345
326- ns = 150
327- nt = 200
346+ ns = 50
347+ nt = 50
328348
329349 Xs , ys = make_data_classif ('3gauss' , ns )
330350 Xt , yt = make_data_classif ('3gauss2' , nt )
@@ -402,8 +422,8 @@ def test_emd_transport_class(nx):
402422 """test_sinkhorn_transport
403423 """
404424
405- ns = 150
406- nt = 200
425+ ns = 50
426+ nt = 50
407427
408428 Xs , ys = make_data_classif ('3gauss' , ns )
409429 Xt , yt = make_data_classif ('3gauss2' , nt )
@@ -558,8 +578,8 @@ def test_mapping_transport_class_specific_seed(nx):
558578@pytest .skip_backend ("jax" )
559579@pytest .skip_backend ("tf" )
560580def test_linear_mapping (nx ):
561- ns = 150
562- nt = 200
581+ ns = 50
582+ nt = 50
563583
564584 Xs , ys = make_data_classif ('3gauss' , ns )
565585 Xt , yt = make_data_classif ('3gauss2' , nt )
@@ -579,8 +599,8 @@ def test_linear_mapping(nx):
579599@pytest .skip_backend ("jax" )
580600@pytest .skip_backend ("tf" )
581601def test_linear_mapping_class (nx ):
582- ns = 150
583- nt = 200
602+ ns = 50
603+ nt = 50
584604
585605 Xs , ys = make_data_classif ('3gauss' , ns )
586606 Xt , yt = make_data_classif ('3gauss2' , nt )
@@ -609,9 +629,9 @@ def test_jcpot_transport_class(nx):
609629 """test_jcpot_transport
610630 """
611631
612- ns1 = 150
613- ns2 = 150
614- nt = 200
632+ ns1 = 50
633+ ns2 = 50
634+ nt = 50
615635
616636 Xs1 , ys1 = make_data_classif ('3gauss' , ns1 )
617637 Xs2 , ys2 = make_data_classif ('3gauss' , ns2 )
@@ -681,9 +701,9 @@ def test_jcpot_barycenter(nx):
681701 """test_jcpot_barycenter
682702 """
683703
684- ns1 = 150
685- ns2 = 150
686- nt = 200
704+ ns1 = 50
705+ ns2 = 50
706+ nt = 50
687707
688708 sigma = 0.1
689709 np .random .seed (1985 )
@@ -713,8 +733,8 @@ def test_jcpot_barycenter(nx):
713733def test_emd_laplace_class (nx ):
714734 """test_emd_laplace_transport
715735 """
716- ns = 150
717- nt = 200
736+ ns = 50
737+ nt = 50
718738
719739 Xs , ys = make_data_classif ('3gauss' , ns )
720740 Xt , yt = make_data_classif ('3gauss2' , nt )
0 commit comments