Skip to content

Commit 7ad4725

Browse files
committed
more
1 parent 8239423 commit 7ad4725

File tree

2 files changed

+70
-68
lines changed

2 files changed

+70
-68
lines changed

examples/plot_OTDA_mapping_color_images.py

Lines changed: 69 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -12,147 +12,149 @@
1212
"""
1313

1414
import numpy as np
15-
import scipy.ndimage as spi
15+
from scipy import ndimage
1616
import matplotlib.pylab as pl
1717
import ot
1818

1919

2020
#%% Loading images
2121

22-
I1=spi.imread('../data/ocean_day.jpg').astype(np.float64)/256
23-
I2=spi.imread('../data/ocean_sunset.jpg').astype(np.float64)/256
22+
I1 = ndimage.imread('../data/ocean_day.jpg').astype(np.float64) / 256
23+
I2 = ndimage.imread('../data/ocean_sunset.jpg').astype(np.float64) / 256
2424

2525
#%% Plot images
2626

27-
pl.figure(1)
28-
29-
pl.subplot(1,2,1)
27+
pl.figure(1, figsize=(6.4, 3))
28+
pl.subplot(1, 2, 1)
3029
pl.imshow(I1)
30+
pl.axis('off')
3131
pl.title('Image 1')
3232

33-
pl.subplot(1,2,2)
33+
pl.subplot(1, 2, 2)
3434
pl.imshow(I2)
35+
pl.axis('off')
3536
pl.title('Image 2')
36-
37-
pl.show()
37+
pl.tight_layout()
3838

3939
#%% Image conversion and dataset generation
4040

4141
def im2mat(I):
4242
"""Converts and image to matrix (one pixel per line)"""
43-
return I.reshape((I.shape[0]*I.shape[1],I.shape[2]))
43+
return I.reshape((I.shape[0] * I.shape[1], I.shape[2]))
44+
4445

45-
def mat2im(X,shape):
46+
def mat2im(X, shape):
4647
"""Converts back a matrix to an image"""
4748
return X.reshape(shape)
4849

49-
X1=im2mat(I1)
50-
X2=im2mat(I2)
50+
X1 = im2mat(I1)
51+
X2 = im2mat(I2)
5152

5253
# training samples
53-
nb=1000
54-
idx1=np.random.randint(X1.shape[0],size=(nb,))
55-
idx2=np.random.randint(X2.shape[0],size=(nb,))
54+
nb = 1000
55+
idx1 = np.random.randint(X1.shape[0], size=(nb,))
56+
idx2 = np.random.randint(X2.shape[0], size=(nb,))
5657

57-
xs=X1[idx1,:]
58-
xt=X2[idx2,:]
58+
xs = X1[idx1, :]
59+
xt = X2[idx2, :]
5960

6061
#%% Plot image distributions
6162

6263

63-
pl.figure(2,(10,5))
64+
pl.figure(2, figsize=(6.4, 5))
6465

65-
pl.subplot(1,2,1)
66-
pl.scatter(xs[:,0],xs[:,2],c=xs)
67-
pl.axis([0,1,0,1])
66+
pl.subplot(1, 2, 1)
67+
pl.scatter(xs[:, 0], xs[:, 2], c=xs)
68+
pl.axis([0, 1, 0, 1])
6869
pl.xlabel('Red')
6970
pl.ylabel('Blue')
7071
pl.title('Image 1')
7172

72-
pl.subplot(1,2,2)
73-
#pl.imshow(I2)
74-
pl.scatter(xt[:,0],xt[:,2],c=xt)
75-
pl.axis([0,1,0,1])
73+
pl.subplot(1, 2, 2)
74+
pl.scatter(xt[:, 0], xt[:, 2], c=xt)
75+
pl.axis([0, 1, 0, 1])
7676
pl.xlabel('Red')
7777
pl.ylabel('Blue')
7878
pl.title('Image 2')
79-
80-
pl.show()
81-
82-
79+
pl.tight_layout()
8380

8481
#%% domain adaptation between images
8582
def minmax(I):
86-
return np.minimum(np.maximum(I,0),1)
83+
return np.clip(I, 0, 1)
84+
8785
# LP problem
88-
da_emd=ot.da.OTDA() # init class
89-
da_emd.fit(xs,xt) # fit distributions
86+
da_emd = ot.da.OTDA() # init class
87+
da_emd.fit(xs, xt) # fit distributions
9088

91-
X1t=da_emd.predict(X1) # out of sample
92-
I1t=minmax(mat2im(X1t,I1.shape))
89+
X1t = da_emd.predict(X1) # out of sample
90+
I1t = minmax(mat2im(X1t, I1.shape))
9391

9492
# sinkhorn regularization
95-
lambd=1e-1
96-
da_entrop=ot.da.OTDA_sinkhorn()
97-
da_entrop.fit(xs,xt,reg=lambd)
93+
lambd = 1e-1
94+
da_entrop = ot.da.OTDA_sinkhorn()
95+
da_entrop.fit(xs, xt, reg=lambd)
9896

99-
X1te=da_entrop.predict(X1)
100-
I1te=minmax(mat2im(X1te,I1.shape))
97+
X1te = da_entrop.predict(X1)
98+
I1te = minmax(mat2im(X1te, I1.shape))
10199

102100
# linear mapping estimation
103-
eta=1e-8 # quadratic regularization for regression
104-
mu=1e0 # weight of the OT linear term
105-
bias=True # estimate a bias
101+
eta = 1e-8 # quadratic regularization for regression
102+
mu = 1e0 # weight of the OT linear term
103+
bias = True # estimate a bias
106104

107-
ot_mapping=ot.da.OTDA_mapping_linear()
108-
ot_mapping.fit(xs,xt,mu=mu,eta=eta,bias=bias,numItermax = 20,verbose=True)
105+
ot_mapping = ot.da.OTDA_mapping_linear()
106+
ot_mapping.fit(xs, xt, mu=mu, eta=eta, bias=bias, numItermax=20, verbose=True)
109107

110-
X1tl=ot_mapping.predict(X1) # use the estimated mapping
111-
I1tl=minmax(mat2im(X1tl,I1.shape))
108+
X1tl = ot_mapping.predict(X1) # use the estimated mapping
109+
I1tl = minmax(mat2im(X1tl, I1.shape))
112110

113111
# nonlinear mapping estimation
114-
eta=1e-2 # quadratic regularization for regression
115-
mu=1e0 # weight of the OT linear term
116-
bias=False # estimate a bias
117-
sigma=1 # sigma bandwidth fot gaussian kernel
118-
112+
eta = 1e-2 # quadratic regularization for regression
113+
mu = 1e0 # weight of the OT linear term
114+
bias = False # estimate a bias
115+
sigma = 1 # sigma bandwidth fot gaussian kernel
119116

120-
ot_mapping_kernel=ot.da.OTDA_mapping_kernel()
121-
ot_mapping_kernel.fit(xs,xt,mu=mu,eta=eta,sigma=sigma,bias=bias,numItermax = 10,verbose=True)
122117

123-
X1tn=ot_mapping_kernel.predict(X1) # use the estimated mapping
124-
I1tn=minmax(mat2im(X1tn,I1.shape))
125-
#%% plot images
118+
ot_mapping_kernel = ot.da.OTDA_mapping_kernel()
119+
ot_mapping_kernel.fit(
120+
xs, xt, mu=mu, eta=eta, sigma=sigma, bias=bias, numItermax=10, verbose=True)
126121

122+
X1tn = ot_mapping_kernel.predict(X1) # use the estimated mapping
123+
I1tn = minmax(mat2im(X1tn, I1.shape))
127124

128-
pl.figure(2,(10,8))
125+
#%% plot images
129126

130-
pl.subplot(2,3,1)
127+
pl.figure(2, figsize=(8, 4))
131128

129+
pl.subplot(2, 3, 1)
132130
pl.imshow(I1)
131+
pl.axis('off')
133132
pl.title('Im. 1')
134133

135-
pl.subplot(2,3,2)
136-
134+
pl.subplot(2, 3, 2)
137135
pl.imshow(I2)
136+
pl.axis('off')
138137
pl.title('Im. 2')
139138

140-
141-
pl.subplot(2,3,3)
139+
pl.subplot(2, 3, 3)
142140
pl.imshow(I1t)
141+
pl.axis('off')
143142
pl.title('Im. 1 Interp LP')
144143

145-
pl.subplot(2,3,4)
144+
pl.subplot(2, 3, 4)
146145
pl.imshow(I1te)
146+
pl.axis('off')
147147
pl.title('Im. 1 Interp Entrop')
148148

149-
150-
pl.subplot(2,3,5)
149+
pl.subplot(2, 3, 5)
151150
pl.imshow(I1tl)
151+
pl.axis('off')
152152
pl.title('Im. 1 Linear mapping')
153153

154-
pl.subplot(2,3,6)
154+
pl.subplot(2, 3, 6)
155155
pl.imshow(I1tn)
156+
pl.axis('off')
156157
pl.title('Im. 1 nonlinear mapping')
158+
pl.tight_layout()
157159

158160
pl.show()

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,4 @@ description-file = README.md
33

44
[flake8]
55
exclude = __init__.py
6-
ignore = E265
6+
ignore = E265,E501

0 commit comments

Comments
 (0)