1+ # -*- coding: utf-8 -*-
2+ """
3+ Demo of Optimal transport for domain adaptation with image color adaptation as in [6] with mapping estimation from [8]
4+
5+ [6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014). Regularized
6+ discrete optimal transport. SIAM Journal on Imaging Sciences, 7(3), 1853-1882.
7+ [8] M. Perrot, N. Courty, R. Flamary, A. Habrard, "Mapping estimation for
8+ discrete optimal transport", Neural Information Processing Systems (NIPS), 2016.
9+
10+
11+ """
12+
13+ import numpy as np
14+ import scipy .ndimage as spi
15+ import matplotlib .pylab as pl
16+ import ot
17+
18+
19+ #%% Loading images
20+
21+ I1 = spi .imread ('../data/ocean_day.jpg' ).astype (np .float64 )/ 256
22+ I2 = spi .imread ('../data/ocean_sunset.jpg' ).astype (np .float64 )/ 256
23+
24+ #%% Plot images
25+
26+ pl .figure (1 )
27+
28+ pl .subplot (1 ,2 ,1 )
29+ pl .imshow (I1 )
30+ pl .title ('Image 1' )
31+
32+ pl .subplot (1 ,2 ,2 )
33+ pl .imshow (I2 )
34+ pl .title ('Image 2' )
35+
36+ pl .show ()
37+
38+ #%% Image conversion and dataset generation
39+
40+ def im2mat (I ):
41+ """Converts and image to matrix (one pixel per line)"""
42+ return I .reshape ((I .shape [0 ]* I .shape [1 ],I .shape [2 ]))
43+
44+ def mat2im (X ,shape ):
45+ """Converts back a matrix to an image"""
46+ return X .reshape (shape )
47+
48+ X1 = im2mat (I1 )
49+ X2 = im2mat (I2 )
50+
51+ # training samples
52+ nb = 1000
53+ idx1 = np .random .randint (X1 .shape [0 ],size = (nb ,))
54+ idx2 = np .random .randint (X2 .shape [0 ],size = (nb ,))
55+
56+ xs = X1 [idx1 ,:]
57+ xt = X2 [idx2 ,:]
58+
59+ #%% Plot image distributions
60+
61+
62+ pl .figure (2 ,(10 ,5 ))
63+
64+ pl .subplot (1 ,2 ,1 )
65+ pl .scatter (xs [:,0 ],xs [:,2 ],c = xs )
66+ pl .axis ([0 ,1 ,0 ,1 ])
67+ pl .xlabel ('Red' )
68+ pl .ylabel ('Blue' )
69+ pl .title ('Image 1' )
70+
71+ pl .subplot (1 ,2 ,2 )
72+ #pl.imshow(I2)
73+ pl .scatter (xt [:,0 ],xt [:,2 ],c = xt )
74+ pl .axis ([0 ,1 ,0 ,1 ])
75+ pl .xlabel ('Red' )
76+ pl .ylabel ('Blue' )
77+ pl .title ('Image 2' )
78+
79+ pl .show ()
80+
81+
82+
83+ #%% domain adaptation between images
84+ def minmax (I ):
85+ return np .minimum (np .maximum (I ,0 ),1 )
86+ # LP problem
87+ da_emd = ot .da .OTDA () # init class
88+ da_emd .fit (xs ,xt ) # fit distributions
89+
90+ X1t = da_emd .predict (X1 ) # out of sample
91+ I1t = minmax (mat2im (X1t ,I1 .shape ))
92+
93+ # sinkhorn regularization
94+ lambd = 1e-1
95+ da_entrop = ot .da .OTDA_sinkhorn ()
96+ da_entrop .fit (xs ,xt ,reg = lambd )
97+
98+ X1te = da_entrop .predict (X1 )
99+ I1te = minmax (mat2im (X1te ,I1 .shape ))
100+
101+ # linear mapping estimation
102+ eta = 1e-8 # quadratic regularization for regression
103+ mu = 1e0 # weight of the OT linear term
104+ bias = True # estimate a bias
105+
106+ ot_mapping = ot .da .OTDA_mapping_linear ()
107+ ot_mapping .fit (xs ,xt ,mu = mu ,eta = eta ,bias = bias ,numItermax = 20 ,verbose = True )
108+
109+ X1tl = ot_mapping .predict (X1 ) # use the estimated mapping
110+ I1tl = minmax (mat2im (X1tl ,I1 .shape ))
111+
112+ # nonlinear mapping estimation
113+ eta = 1e-2 # quadratic regularization for regression
114+ mu = 1e0 # weight of the OT linear term
115+ bias = False # estimate a bias
116+ sigma = 1 # sigma bandwidth fot gaussian kernel
117+
118+
119+ ot_mapping_kernel = ot .da .OTDA_mapping_kernel ()
120+ ot_mapping_kernel .fit (xs ,xt ,mu = mu ,eta = eta ,sigma = sigma ,bias = bias ,numItermax = 10 ,verbose = True )
121+
122+ X1tn = ot_mapping_kernel .predict (X1 ) # use the estimated mapping
123+ I1tn = minmax (mat2im (X1tn ,I1 .shape ))
124+ #%% plot images
125+
126+
127+ pl .figure (2 ,(10 ,8 ))
128+
129+ pl .subplot (2 ,3 ,1 )
130+
131+ pl .imshow (I1 )
132+ pl .title ('Im. 1' )
133+
134+ pl .subplot (2 ,3 ,2 )
135+
136+ pl .imshow (I2 )
137+ pl .title ('Im. 2' )
138+
139+
140+ pl .subplot (2 ,3 ,3 )
141+ pl .imshow (I1t )
142+ pl .title ('Im. 1 Interp LP' )
143+
144+ pl .subplot (2 ,3 ,4 )
145+ pl .imshow (I1te )
146+ pl .title ('Im. 1 Interp Entrop' )
147+
148+
149+ pl .subplot (2 ,3 ,5 )
150+ pl .imshow (I1tl )
151+ pl .title ('Im. 1 Linear mapping' )
152+
153+ pl .subplot (2 ,3 ,6 )
154+ pl .imshow (I1tn )
155+ pl .title ('Im. 1 nonlinear mapping' )
156+
157+ pl .show ()
0 commit comments