Skip to content

Commit 212f388

Browse files
committed
update tests
1 parent ec67362 commit 212f388

File tree

9 files changed

+148
-57
lines changed

9 files changed

+148
-57
lines changed

examples/plot_OT_1D.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@
44
1D optimal transport
55
====================
66
7+
This example illustrate the computation of EMD and Sinkhorn transport plans
8+
and their visualization.
9+
710
"""
811

912
# Author: Remi Flamary <remi.flamary@unice.fr>

examples/plot_OT_2D_samples.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@
1414
import matplotlib.pylab as pl
1515
import ot
1616

17+
##############################################################################
18+
# Generate data
19+
##############################################################################
20+
1721
#%% parameters and data generation
1822

1923
n = 50 # nb samples
@@ -33,6 +37,10 @@
3337
M = ot.dist(xs, xt)
3438
M /= M.max()
3539

40+
##############################################################################
41+
# Plot data
42+
##############################################################################
43+
3644
#%% plot samples
3745

3846
pl.figure(1)
@@ -45,6 +53,9 @@
4553
pl.imshow(M, interpolation='nearest')
4654
pl.title('Cost matrix M')
4755

56+
##############################################################################
57+
# Compute EMD
58+
##############################################################################
4859

4960
#%% EMD
5061

@@ -62,6 +73,10 @@
6273
pl.title('OT matrix with samples')
6374

6475

76+
##############################################################################
77+
# Compute Sinkhorn
78+
##############################################################################
79+
6580
#%% sinkhorn
6681

6782
# reg term

examples/plot_WDA.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,12 @@
44
Wasserstein Discriminant Analysis
55
=================================
66
7+
This example illustrate the use of WDA as proposed in [11].
8+
9+
10+
[11] Flamary, R., Cuturi, M., Courty, N., & Rakotomamonjy, A. (2016).
11+
Wasserstein Discriminant Analysis.
12+
713
"""
814

915
# Author: Remi Flamary <remi.flamary@unice.fr>
@@ -16,6 +22,10 @@
1622
from ot.dr import wda, fda
1723

1824

25+
##############################################################################
26+
# Generate data
27+
##############################################################################
28+
1929
#%% parameters
2030

2131
n = 1000 # nb samples in source and target datasets
@@ -39,6 +49,10 @@
3949
xs = np.hstack((xs, np.random.randn(n, nbnoise)))
4050
xt = np.hstack((xt, np.random.randn(n, nbnoise)))
4151

52+
##############################################################################
53+
# Plot data
54+
##############################################################################
55+
4256
#%% plot samples
4357
pl.figure(1, figsize=(6.4, 3.5))
4458

@@ -53,11 +67,19 @@
5367
pl.title('Other dimensions')
5468
pl.tight_layout()
5569

70+
##############################################################################
71+
# Compute Fisher Discriminant Analysis
72+
##############################################################################
73+
5674
#%% Compute FDA
5775
p = 2
5876

5977
Pfda, projfda = fda(xs, ys, p)
6078

79+
##############################################################################
80+
# Compute Wasserstein Discriminant Analysis
81+
##############################################################################
82+
6183
#%% Compute WDA
6284
p = 2
6385
reg = 1e0
@@ -66,6 +88,11 @@
6688

6789
Pwda, projwda = wda(xs, ys, p, reg, k, maxiter=maxiter)
6890

91+
92+
##############################################################################
93+
# Plot 2D projections
94+
##############################################################################
95+
6996
#%% plot samples
7097

7198
xsp = projfda(xs)

examples/plot_barycenter_1D.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,14 @@
44
1D Wasserstein barycenter demo
55
==============================
66
7+
This example illustrate the computation of regularized Wassersyein Barycenter
8+
as proposed in [3].
9+
10+
11+
[3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015).
12+
Iterative Bregman projections for regularized transportation problems
13+
SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
14+
715
"""
816

917
# Author: Remi Flamary <remi.flamary@unice.fr>

examples/plot_compute_emd.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
# -*- coding: utf-8 -*-
22
"""
3-
====================
4-
1D optimal transport
5-
====================
3+
=================
4+
Plot multiple EMD
5+
=================
66
77
"""
88

@@ -16,6 +16,10 @@
1616
from ot.datasets import get_1D_gauss as gauss
1717

1818

19+
##############################################################################
20+
# Generate data
21+
##############################################################################
22+
1923
#%% parameters
2024

2125
n = 100 # nb bins
@@ -40,6 +44,11 @@
4044
M /= M.max()
4145
M2 = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)), 'sqeuclidean')
4246
M2 /= M2.max()
47+
48+
##############################################################################
49+
# Plot data
50+
##############################################################################
51+
4352
#%% plot the distributions
4453

4554
pl.figure(1)
@@ -51,10 +60,15 @@
5160
pl.title('Target distributions')
5261
pl.tight_layout()
5362

63+
64+
##############################################################################
65+
# Compute EMD for the different losses
66+
##############################################################################
67+
5468
#%% Compute and plot distributions and loss matrix
5569

5670
d_emd = ot.emd2(a, B, M) # direct computation of EMD
57-
d_emd2 = ot.emd2(a, B, M2) # direct computation of EMD with loss M3
71+
d_emd2 = ot.emd2(a, B, M2) # direct computation of EMD with loss M2
5872

5973

6074
pl.figure(2)
@@ -63,6 +77,10 @@
6377
pl.title('EMD distances')
6478
pl.legend()
6579

80+
##############################################################################
81+
# Compute Sinkhorn for the different losses
82+
##############################################################################
83+
6684
#%%
6785
reg = 1e-2
6886
d_sinkhorn = ot.sinkhorn2(a, B, M, reg)

examples/plot_optim_OTreg.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,24 @@
44
Regularized OT with generic solver
55
==================================
66
7+
This example illustrate the use of the generic solver for regularized OT with
8+
user designed regularization term. It uses Conditional gradient as in [6] and
9+
generalized Conditional Gradient as proposed in [5][7].
10+
11+
12+
[5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, Optimal Transport for
13+
Domain Adaptation, in IEEE Transactions on Pattern Analysis and Machine
14+
Intelligence , vol.PP, no.99, pp.1-1.
15+
16+
[6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014).
17+
Regularized discrete optimal transport. SIAM Journal on Imaging Sciences,
18+
7(3), 1853-1882.
19+
20+
[7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). Generalized
21+
conditional gradient: analysis of convergence and applications.
22+
arXiv preprint arXiv:1510.06567.
23+
24+
725
826
"""
927

examples/plot_otda_color_images.py

Lines changed: 33 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
# -*- coding: utf-8 -*-
22
"""
3-
========================================================
4-
OT for domain adaptation with image color adaptation [6]
5-
========================================================
3+
=============================
4+
OT for image color adaptation
5+
=============================
66
77
This example presents a way of transferring colors between two image
88
with Optimal Transport as introduced in [6]
@@ -41,7 +41,7 @@ def minmax(I):
4141

4242

4343
##############################################################################
44-
# generate data
44+
# Generate data
4545
##############################################################################
4646

4747
# Loading images
@@ -61,33 +61,7 @@ def minmax(I):
6161

6262

6363
##############################################################################
64-
# Instantiate the different transport algorithms and fit them
65-
##############################################################################
66-
67-
# EMDTransport
68-
ot_emd = ot.da.EMDTransport()
69-
ot_emd.fit(Xs=Xs, Xt=Xt)
70-
71-
# SinkhornTransport
72-
ot_sinkhorn = ot.da.SinkhornTransport(reg_e=1e-1)
73-
ot_sinkhorn.fit(Xs=Xs, Xt=Xt)
74-
75-
# prediction between images (using out of sample prediction as in [6])
76-
transp_Xs_emd = ot_emd.transform(Xs=X1)
77-
transp_Xt_emd = ot_emd.inverse_transform(Xt=X2)
78-
79-
transp_Xs_sinkhorn = ot_emd.transform(Xs=X1)
80-
transp_Xt_sinkhorn = ot_emd.inverse_transform(Xt=X2)
81-
82-
I1t = minmax(mat2im(transp_Xs_emd, I1.shape))
83-
I2t = minmax(mat2im(transp_Xt_emd, I2.shape))
84-
85-
I1te = minmax(mat2im(transp_Xs_sinkhorn, I1.shape))
86-
I2te = minmax(mat2im(transp_Xt_sinkhorn, I2.shape))
87-
88-
89-
##############################################################################
90-
# plot original image
64+
# Plot original image
9165
##############################################################################
9266

9367
pl.figure(1, figsize=(6.4, 3))
@@ -104,7 +78,7 @@ def minmax(I):
10478

10579

10680
##############################################################################
107-
# scatter plot of colors
81+
# Scatter plot of colors
10882
##############################################################################
10983

11084
pl.figure(2, figsize=(6.4, 3))
@@ -126,7 +100,33 @@ def minmax(I):
126100

127101

128102
##############################################################################
129-
# plot new images
103+
# Instantiate the different transport algorithms and fit them
104+
##############################################################################
105+
106+
# EMDTransport
107+
ot_emd = ot.da.EMDTransport()
108+
ot_emd.fit(Xs=Xs, Xt=Xt)
109+
110+
# SinkhornTransport
111+
ot_sinkhorn = ot.da.SinkhornTransport(reg_e=1e-1)
112+
ot_sinkhorn.fit(Xs=Xs, Xt=Xt)
113+
114+
# prediction between images (using out of sample prediction as in [6])
115+
transp_Xs_emd = ot_emd.transform(Xs=X1)
116+
transp_Xt_emd = ot_emd.inverse_transform(Xt=X2)
117+
118+
transp_Xs_sinkhorn = ot_emd.transform(Xs=X1)
119+
transp_Xt_sinkhorn = ot_emd.inverse_transform(Xt=X2)
120+
121+
I1t = minmax(mat2im(transp_Xs_emd, I1.shape))
122+
I2t = minmax(mat2im(transp_Xt_emd, I2.shape))
123+
124+
I1te = minmax(mat2im(transp_Xs_sinkhorn, I1.shape))
125+
I2te = minmax(mat2im(transp_Xt_sinkhorn, I2.shape))
126+
127+
128+
##############################################################################
129+
# Plot new images
130130
##############################################################################
131131

132132
pl.figure(3, figsize=(8, 4))

examples/plot_otda_mapping.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
77
This example presents how to use MappingTransport to estimate at the same
88
time both the coupling transport and approximate the transport map with either
9-
a linear or a kernelized mapping as introduced in [8]
9+
a linear or a kernelized mapping as introduced in [8].
1010
1111
[8] M. Perrot, N. Courty, R. Flamary, A. Habrard,
1212
"Mapping estimation for discrete optimal transport",
@@ -43,6 +43,17 @@
4343
Xt[yt == 2] *= 3
4444
Xt = Xt + 4
4545

46+
##############################################################################
47+
# plot data
48+
##############################################################################
49+
50+
pl.figure(1, (10, 5))
51+
pl.clf()
52+
pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker='+', label='Source samples')
53+
pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', label='Target samples')
54+
pl.legend(loc=0)
55+
pl.title('Source and target distributions')
56+
4657

4758
##############################################################################
4859
# Instantiate the different transport algorithms and fit them
@@ -76,19 +87,7 @@
7687

7788

7889
##############################################################################
79-
# plot data
80-
##############################################################################
81-
82-
pl.figure(1, (10, 5))
83-
pl.clf()
84-
pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker='+', label='Source samples')
85-
pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', label='Target samples')
86-
pl.legend(loc=0)
87-
pl.title('Source and target distributions')
88-
89-
90-
##############################################################################
91-
# plot transported samples
90+
# Plot transported samples
9291
##############################################################################
9392

9493
pl.figure(2)

0 commit comments

Comments
 (0)