@@ -484,66 +484,3 @@ def test_linear_mapping_class():
484484 Cst = np .cov (Xst .T )
485485
486486 np .testing .assert_allclose (Ct , Cst , rtol = 1e-2 , atol = 1e-2 )
487-
488-
489- def test_otda ():
490-
491- n_samples = 150 # nb samples
492- np .random .seed (0 )
493-
494- xs , ys = ot .datasets .make_data_classif ('3gauss' , n_samples )
495- xt , yt = ot .datasets .make_data_classif ('3gauss2' , n_samples )
496-
497- a , b = ot .unif (n_samples ), ot .unif (n_samples )
498-
499- # LP problem
500- da_emd = ot .da .OTDA () # init class
501- da_emd .fit (xs , xt ) # fit distributions
502- da_emd .interp () # interpolation of source samples
503- da_emd .predict (xs ) # interpolation of source samples
504-
505- np .testing .assert_allclose (a , np .sum (da_emd .G , 1 ))
506- np .testing .assert_allclose (b , np .sum (da_emd .G , 0 ))
507-
508- # sinkhorn regularization
509- lambd = 1e-1
510- da_entrop = ot .da .OTDA_sinkhorn ()
511- da_entrop .fit (xs , xt , reg = lambd )
512- da_entrop .interp ()
513- da_entrop .predict (xs )
514-
515- np .testing .assert_allclose (
516- a , np .sum (da_entrop .G , 1 ), rtol = 1e-3 , atol = 1e-3 )
517- np .testing .assert_allclose (b , np .sum (da_entrop .G , 0 ), rtol = 1e-3 , atol = 1e-3 )
518-
519- # non-convex Group lasso regularization
520- reg = 1e-1
521- eta = 1e0
522- da_lpl1 = ot .da .OTDA_lpl1 ()
523- da_lpl1 .fit (xs , ys , xt , reg = reg , eta = eta )
524- da_lpl1 .interp ()
525- da_lpl1 .predict (xs )
526-
527- np .testing .assert_allclose (a , np .sum (da_lpl1 .G , 1 ), rtol = 1e-3 , atol = 1e-3 )
528- np .testing .assert_allclose (b , np .sum (da_lpl1 .G , 0 ), rtol = 1e-3 , atol = 1e-3 )
529-
530- # True Group lasso regularization
531- reg = 1e-1
532- eta = 2e0
533- da_l1l2 = ot .da .OTDA_l1l2 ()
534- da_l1l2 .fit (xs , ys , xt , reg = reg , eta = eta , numItermax = 20 , verbose = True )
535- da_l1l2 .interp ()
536- da_l1l2 .predict (xs )
537-
538- np .testing .assert_allclose (a , np .sum (da_l1l2 .G , 1 ), rtol = 1e-3 , atol = 1e-3 )
539- np .testing .assert_allclose (b , np .sum (da_l1l2 .G , 0 ), rtol = 1e-3 , atol = 1e-3 )
540-
541- # linear mapping
542- da_emd = ot .da .OTDA_mapping_linear () # init class
543- da_emd .fit (xs , xt , numItermax = 10 ) # fit distributions
544- da_emd .predict (xs ) # interpolation of source samples
545-
546- # nonlinear mapping
547- da_emd = ot .da .OTDA_mapping_kernel () # init class
548- da_emd .fit (xs , xt , numItermax = 10 ) # fit distributions
549- da_emd .predict (xs ) # interpolation of source samples
0 commit comments