|
| 1 | + |
| 2 | + |
| 3 | +import ot |
| 4 | +import numpy as np |
| 5 | + |
| 6 | +# import pytest |
| 7 | + |
| 8 | + |
| 9 | +def test_OTDA(): |
| 10 | + |
| 11 | + n = 150 # nb bins |
| 12 | + |
| 13 | + xs, ys = ot.datasets.get_data_classif('3gauss', n) |
| 14 | + xt, yt = ot.datasets.get_data_classif('3gauss2', n) |
| 15 | + |
| 16 | + a, b = ot.unif(n), ot.unif(n) |
| 17 | + |
| 18 | + # LP problem |
| 19 | + da_emd = ot.da.OTDA() # init class |
| 20 | + da_emd.fit(xs, xt) # fit distributions |
| 21 | + da_emd.interp() # interpolation of source samples |
| 22 | + da_emd.predict(xs) # interpolation of source samples |
| 23 | + |
| 24 | + assert np.allclose(a, np.sum(da_emd.G, 1)) |
| 25 | + assert np.allclose(b, np.sum(da_emd.G, 0)) |
| 26 | + |
| 27 | + # sinkhorn regularization |
| 28 | + lambd = 1e-1 |
| 29 | + da_entrop = ot.da.OTDA_sinkhorn() |
| 30 | + da_entrop.fit(xs, xt, reg=lambd) |
| 31 | + da_entrop.interp() |
| 32 | + da_entrop.predict(xs) |
| 33 | + |
| 34 | + assert np.allclose(a, np.sum(da_entrop.G, 1), rtol=1e-3, atol=1e-3) |
| 35 | + assert np.allclose(b, np.sum(da_entrop.G, 0), rtol=1e-3, atol=1e-3) |
| 36 | + |
| 37 | + # non-convex Group lasso regularization |
| 38 | + reg = 1e-1 |
| 39 | + eta = 1e0 |
| 40 | + da_lpl1 = ot.da.OTDA_lpl1() |
| 41 | + da_lpl1.fit(xs, ys, xt, reg=reg, eta=eta) |
| 42 | + da_lpl1.interp() |
| 43 | + da_lpl1.predict(xs) |
| 44 | + |
| 45 | + assert np.allclose(a, np.sum(da_lpl1.G, 1), rtol=1e-3, atol=1e-3) |
| 46 | + assert np.allclose(b, np.sum(da_lpl1.G, 0), rtol=1e-3, atol=1e-3) |
| 47 | + |
| 48 | + # True Group lasso regularization |
| 49 | + reg = 1e-1 |
| 50 | + eta = 2e0 |
| 51 | + da_l1l2 = ot.da.OTDA_l1l2() |
| 52 | + da_l1l2.fit(xs, ys, xt, reg=reg, eta=eta, numItermax=20, verbose=True) |
| 53 | + da_l1l2.interp() |
| 54 | + da_l1l2.predict(xs) |
| 55 | + |
| 56 | + assert np.allclose(a, np.sum(da_l1l2.G, 1), rtol=1e-3, atol=1e-3) |
| 57 | + assert np.allclose(b, np.sum(da_l1l2.G, 0), rtol=1e-3, atol=1e-3) |
| 58 | + |
| 59 | + # linear mapping |
| 60 | + da_emd = ot.da.OTDA_mapping_linear() # init class |
| 61 | + da_emd.fit(xs, xt, numItermax=10) # fit distributions |
| 62 | + da_emd.predict(xs) # interpolation of source samples |
| 63 | + |
| 64 | + # nonlinear mapping |
| 65 | + da_emd = ot.da.OTDA_mapping_kernel() # init class |
| 66 | + da_emd.fit(xs, xt, numItermax=10) # fit distributions |
| 67 | + da_emd.predict(xs) # interpolation of source samples |
0 commit comments