|
12 | 12 | """ |
13 | 13 |
|
14 | 14 | import numpy as np |
15 | | -import scipy.ndimage as spi |
| 15 | +from scipy import ndimage |
16 | 16 | import matplotlib.pylab as pl |
17 | 17 | import ot |
18 | 18 |
|
19 | 19 |
|
20 | 20 | #%% Loading images |
21 | 21 |
|
22 | | -I1=spi.imread('../data/ocean_day.jpg').astype(np.float64)/256 |
23 | | -I2=spi.imread('../data/ocean_sunset.jpg').astype(np.float64)/256 |
| 22 | +I1 = ndimage.imread('../data/ocean_day.jpg').astype(np.float64) / 256 |
| 23 | +I2 = ndimage.imread('../data/ocean_sunset.jpg').astype(np.float64) / 256 |
24 | 24 |
|
25 | 25 | #%% Plot images |
26 | 26 |
|
27 | | -pl.figure(1) |
28 | | - |
29 | | -pl.subplot(1,2,1) |
| 27 | +pl.figure(1, figsize=(6.4, 3)) |
| 28 | +pl.subplot(1, 2, 1) |
30 | 29 | pl.imshow(I1) |
| 30 | +pl.axis('off') |
31 | 31 | pl.title('Image 1') |
32 | 32 |
|
33 | | -pl.subplot(1,2,2) |
| 33 | +pl.subplot(1, 2, 2) |
34 | 34 | pl.imshow(I2) |
| 35 | +pl.axis('off') |
35 | 36 | pl.title('Image 2') |
36 | | - |
37 | | -pl.show() |
| 37 | +pl.tight_layout() |
38 | 38 |
|
39 | 39 | #%% Image conversion and dataset generation |
40 | 40 |
|
41 | 41 | def im2mat(I): |
42 | 42 | """Converts and image to matrix (one pixel per line)""" |
43 | | - return I.reshape((I.shape[0]*I.shape[1],I.shape[2])) |
| 43 | + return I.reshape((I.shape[0] * I.shape[1], I.shape[2])) |
| 44 | + |
44 | 45 |
|
45 | | -def mat2im(X,shape): |
| 46 | +def mat2im(X, shape): |
46 | 47 | """Converts back a matrix to an image""" |
47 | 48 | return X.reshape(shape) |
48 | 49 |
|
49 | | -X1=im2mat(I1) |
50 | | -X2=im2mat(I2) |
| 50 | +X1 = im2mat(I1) |
| 51 | +X2 = im2mat(I2) |
51 | 52 |
|
52 | 53 | # training samples |
53 | | -nb=1000 |
54 | | -idx1=np.random.randint(X1.shape[0],size=(nb,)) |
55 | | -idx2=np.random.randint(X2.shape[0],size=(nb,)) |
| 54 | +nb = 1000 |
| 55 | +idx1 = np.random.randint(X1.shape[0], size=(nb,)) |
| 56 | +idx2 = np.random.randint(X2.shape[0], size=(nb,)) |
56 | 57 |
|
57 | | -xs=X1[idx1,:] |
58 | | -xt=X2[idx2,:] |
| 58 | +xs = X1[idx1, :] |
| 59 | +xt = X2[idx2, :] |
59 | 60 |
|
60 | 61 | #%% Plot image distributions |
61 | 62 |
|
62 | 63 |
|
63 | | -pl.figure(2,(10,5)) |
| 64 | +pl.figure(2, figsize=(6.4, 5)) |
64 | 65 |
|
65 | | -pl.subplot(1,2,1) |
66 | | -pl.scatter(xs[:,0],xs[:,2],c=xs) |
67 | | -pl.axis([0,1,0,1]) |
| 66 | +pl.subplot(1, 2, 1) |
| 67 | +pl.scatter(xs[:, 0], xs[:, 2], c=xs) |
| 68 | +pl.axis([0, 1, 0, 1]) |
68 | 69 | pl.xlabel('Red') |
69 | 70 | pl.ylabel('Blue') |
70 | 71 | pl.title('Image 1') |
71 | 72 |
|
72 | | -pl.subplot(1,2,2) |
73 | | -#pl.imshow(I2) |
74 | | -pl.scatter(xt[:,0],xt[:,2],c=xt) |
75 | | -pl.axis([0,1,0,1]) |
| 73 | +pl.subplot(1, 2, 2) |
| 74 | +pl.scatter(xt[:, 0], xt[:, 2], c=xt) |
| 75 | +pl.axis([0, 1, 0, 1]) |
76 | 76 | pl.xlabel('Red') |
77 | 77 | pl.ylabel('Blue') |
78 | 78 | pl.title('Image 2') |
79 | | - |
80 | | -pl.show() |
81 | | - |
82 | | - |
| 79 | +pl.tight_layout() |
83 | 80 |
|
84 | 81 | #%% domain adaptation between images |
85 | 82 | def minmax(I): |
86 | | - return np.minimum(np.maximum(I,0),1) |
| 83 | + return np.clip(I, 0, 1) |
| 84 | + |
87 | 85 | # LP problem |
88 | | -da_emd=ot.da.OTDA() # init class |
89 | | -da_emd.fit(xs,xt) # fit distributions |
| 86 | +da_emd = ot.da.OTDA() # init class |
| 87 | +da_emd.fit(xs, xt) # fit distributions |
90 | 88 |
|
91 | | -X1t=da_emd.predict(X1) # out of sample |
92 | | -I1t=minmax(mat2im(X1t,I1.shape)) |
| 89 | +X1t = da_emd.predict(X1) # out of sample |
| 90 | +I1t = minmax(mat2im(X1t, I1.shape)) |
93 | 91 |
|
94 | 92 | # sinkhorn regularization |
95 | | -lambd=1e-1 |
96 | | -da_entrop=ot.da.OTDA_sinkhorn() |
97 | | -da_entrop.fit(xs,xt,reg=lambd) |
| 93 | +lambd = 1e-1 |
| 94 | +da_entrop = ot.da.OTDA_sinkhorn() |
| 95 | +da_entrop.fit(xs, xt, reg=lambd) |
98 | 96 |
|
99 | | -X1te=da_entrop.predict(X1) |
100 | | -I1te=minmax(mat2im(X1te,I1.shape)) |
| 97 | +X1te = da_entrop.predict(X1) |
| 98 | +I1te = minmax(mat2im(X1te, I1.shape)) |
101 | 99 |
|
102 | 100 | # linear mapping estimation |
103 | | -eta=1e-8 # quadratic regularization for regression |
104 | | -mu=1e0 # weight of the OT linear term |
105 | | -bias=True # estimate a bias |
| 101 | +eta = 1e-8 # quadratic regularization for regression |
| 102 | +mu = 1e0 # weight of the OT linear term |
| 103 | +bias = True # estimate a bias |
106 | 104 |
|
107 | | -ot_mapping=ot.da.OTDA_mapping_linear() |
108 | | -ot_mapping.fit(xs,xt,mu=mu,eta=eta,bias=bias,numItermax = 20,verbose=True) |
| 105 | +ot_mapping = ot.da.OTDA_mapping_linear() |
| 106 | +ot_mapping.fit(xs, xt, mu=mu, eta=eta, bias=bias, numItermax=20, verbose=True) |
109 | 107 |
|
110 | | -X1tl=ot_mapping.predict(X1) # use the estimated mapping |
111 | | -I1tl=minmax(mat2im(X1tl,I1.shape)) |
| 108 | +X1tl = ot_mapping.predict(X1) # use the estimated mapping |
| 109 | +I1tl = minmax(mat2im(X1tl, I1.shape)) |
112 | 110 |
|
113 | 111 | # nonlinear mapping estimation |
114 | | -eta=1e-2 # quadratic regularization for regression |
115 | | -mu=1e0 # weight of the OT linear term |
116 | | -bias=False # estimate a bias |
117 | | -sigma=1 # sigma bandwidth fot gaussian kernel |
118 | | - |
| 112 | +eta = 1e-2 # quadratic regularization for regression |
| 113 | +mu = 1e0 # weight of the OT linear term |
| 114 | +bias = False # estimate a bias |
| 115 | +sigma = 1 # sigma bandwidth fot gaussian kernel |
119 | 116 |
|
120 | | -ot_mapping_kernel=ot.da.OTDA_mapping_kernel() |
121 | | -ot_mapping_kernel.fit(xs,xt,mu=mu,eta=eta,sigma=sigma,bias=bias,numItermax = 10,verbose=True) |
122 | 117 |
|
123 | | -X1tn=ot_mapping_kernel.predict(X1) # use the estimated mapping |
124 | | -I1tn=minmax(mat2im(X1tn,I1.shape)) |
125 | | -#%% plot images |
| 118 | +ot_mapping_kernel = ot.da.OTDA_mapping_kernel() |
| 119 | +ot_mapping_kernel.fit( |
| 120 | + xs, xt, mu=mu, eta=eta, sigma=sigma, bias=bias, numItermax=10, verbose=True) |
126 | 121 |
|
| 122 | +X1tn = ot_mapping_kernel.predict(X1) # use the estimated mapping |
| 123 | +I1tn = minmax(mat2im(X1tn, I1.shape)) |
127 | 124 |
|
128 | | -pl.figure(2,(10,8)) |
| 125 | +#%% plot images |
129 | 126 |
|
130 | | -pl.subplot(2,3,1) |
| 127 | +pl.figure(2, figsize=(8, 4)) |
131 | 128 |
|
| 129 | +pl.subplot(2, 3, 1) |
132 | 130 | pl.imshow(I1) |
| 131 | +pl.axis('off') |
133 | 132 | pl.title('Im. 1') |
134 | 133 |
|
135 | | -pl.subplot(2,3,2) |
136 | | - |
| 134 | +pl.subplot(2, 3, 2) |
137 | 135 | pl.imshow(I2) |
| 136 | +pl.axis('off') |
138 | 137 | pl.title('Im. 2') |
139 | 138 |
|
140 | | - |
141 | | -pl.subplot(2,3,3) |
| 139 | +pl.subplot(2, 3, 3) |
142 | 140 | pl.imshow(I1t) |
| 141 | +pl.axis('off') |
143 | 142 | pl.title('Im. 1 Interp LP') |
144 | 143 |
|
145 | | -pl.subplot(2,3,4) |
| 144 | +pl.subplot(2, 3, 4) |
146 | 145 | pl.imshow(I1te) |
| 146 | +pl.axis('off') |
147 | 147 | pl.title('Im. 1 Interp Entrop') |
148 | 148 |
|
149 | | - |
150 | | -pl.subplot(2,3,5) |
| 149 | +pl.subplot(2, 3, 5) |
151 | 150 | pl.imshow(I1tl) |
| 151 | +pl.axis('off') |
152 | 152 | pl.title('Im. 1 Linear mapping') |
153 | 153 |
|
154 | | -pl.subplot(2,3,6) |
| 154 | +pl.subplot(2, 3, 6) |
155 | 155 | pl.imshow(I1tn) |
| 156 | +pl.axis('off') |
156 | 157 | pl.title('Im. 1 nonlinear mapping') |
| 158 | +pl.tight_layout() |
157 | 159 |
|
158 | 160 | pl.show() |
0 commit comments