1+ # -*- coding: utf-8 -*-
2+ """
3+ demo of Optimal transport for domain adaptation
4+ """
5+
6+ import numpy as np
7+ import matplotlib .pylab as pl
8+ import ot
9+
10+
11+
12+ #%% parameters
13+
14+ n = 150 # nb bins
15+
16+ xs ,ys = ot .datasets .get_data_classif ('3gauss' ,n )
17+ xt ,yt = ot .datasets .get_data_classif ('3gauss2' ,n )
18+
19+ a ,b = ot .unif (n ),ot .unif (n )
20+ # loss matrix
21+ M = ot .dist (xs ,xt )
22+ #M/=M.max()
23+
24+ #%% plot samples
25+
26+ pl .figure (1 )
27+
28+ pl .subplot (2 ,2 ,1 )
29+ pl .scatter (xs [:,0 ],xs [:,1 ],c = ys ,marker = '+' ,label = 'Source samples' )
30+ pl .legend (loc = 0 )
31+ pl .title ('Source distributions' )
32+
33+ pl .subplot (2 ,2 ,2 )
34+ pl .scatter (xt [:,0 ],xt [:,1 ],c = yt ,marker = 'o' ,label = 'Target samples' )
35+ pl .legend (loc = 0 )
36+ pl .title ('target distributions' )
37+
38+
39+ #%% OT estimation
40+
41+ # EMD
42+
43+
44+ da_emd = ot .da .OTDA ()
45+ da_emd .fit (xs ,xt )
46+
47+ # interpolate samples
48+ xst0 = da_emd .interp ()
49+
50+
51+ # sinkhorn
52+ lambd = 1e-1
53+ da_entrop = ot .da .OTDA_sinkhorn ()
54+ da_entrop .fit (xs ,xt ,reg = lambd )
55+ xsts = da_entrop .interp ()
56+
57+ # Group lasso regularization
58+ reg = 1e-1
59+ eta = 1e0
60+ da_lpl1 = ot .da .OTDA_lpl1 ()
61+ da_lpl1 .fit (xs ,ys ,xt ,reg = lambd ,eta = eta )
62+ xstg = da_lpl1 .interp ()
63+
64+ #%% plot interpolated source samples
65+ pl .figure (4 )
66+
67+ param_img = {'interpolation' :'nearest' ,'cmap' :'jet' }
68+
69+ pl .subplot (2 ,3 ,1 )
70+ pl .imshow (da_emd .G ,** param_img )
71+ pl .title ('OT matrix' )
72+
73+
74+ pl .subplot (2 ,3 ,2 )
75+ pl .imshow (da_entrop .G ,** param_img )
76+ pl .title ('OT matrix sinkhorn' )
77+
78+ pl .subplot (2 ,3 ,3 )
79+ pl .imshow (da_lpl1 .G ,** param_img )
80+ pl .title ('OT matrix Group Lasso' )
81+
82+ pl .subplot (2 ,3 ,4 )
83+ pl .scatter (xt [:,0 ],xt [:,1 ],c = yt ,marker = 'o' ,label = 'Target samples' ,alpha = 0.3 )
84+ pl .scatter (xst0 [:,0 ],xst0 [:,1 ],c = ys ,marker = '+' ,label = 'Transp samples' ,s = 30 )
85+ pl .title ('Interp samples' )
86+ pl .legend (loc = 0 )
87+
88+ pl .subplot (2 ,3 ,5 )
89+ pl .scatter (xt [:,0 ],xt [:,1 ],c = yt ,marker = 'o' ,label = 'Target samples' ,alpha = 0.3 )
90+ pl .scatter (xsts [:,0 ],xsts [:,1 ],c = ys ,marker = '+' ,label = 'Transp samples' ,s = 30 )
91+ pl .title ('Interp samples Sinkhorn' )
92+
93+ pl .subplot (2 ,3 ,6 )
94+ pl .scatter (xt [:,0 ],xt [:,1 ],c = yt ,marker = 'o' ,label = 'Target samples' ,alpha = 0.3 )
95+ pl .scatter (xstg [:,0 ],xstg [:,1 ],c = ys ,marker = '+' ,label = 'Transp samples' ,s = 30 )
96+ pl .title ('Interp samples Group Lasso' )
0 commit comments