Skip to content

Commit 10f9b0d

Browse files
committed
add example file for smooth OT
1 parent 1d34716 commit 10f9b0d

File tree

1 file changed

+110
-0
lines changed

1 file changed

+110
-0
lines changed

examples/plot_OT_1D_smooth.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
====================
4+
1D optimal transport
5+
====================
6+
7+
This example illustrates the computation of EMD, Sinkhorn and smooth OT plans
8+
and their visualization.
9+
10+
"""
11+
12+
# Author: Remi Flamary <remi.flamary@unice.fr>
13+
#
14+
# License: MIT License
15+
16+
import numpy as np
17+
import matplotlib.pylab as pl
18+
import ot
19+
import ot.plot
20+
from ot.datasets import get_1D_gauss as gauss
21+
22+
##############################################################################
23+
# Generate data
24+
# -------------
25+
26+
27+
#%% parameters
28+
29+
n = 100 # nb bins
30+
31+
# bin positions
32+
x = np.arange(n, dtype=np.float64)
33+
34+
# Gaussian distributions
35+
a = gauss(n, m=20, s=5) # m= mean, s= std
36+
b = gauss(n, m=60, s=10)
37+
38+
# loss matrix
39+
M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)))
40+
M /= M.max()
41+
42+
43+
##############################################################################
44+
# Plot distributions and loss matrix
45+
# ----------------------------------
46+
47+
#%% plot the distributions
48+
49+
pl.figure(1, figsize=(6.4, 3))
50+
pl.plot(x, a, 'b', label='Source distribution')
51+
pl.plot(x, b, 'r', label='Target distribution')
52+
pl.legend()
53+
54+
#%% plot distributions and loss matrix
55+
56+
pl.figure(2, figsize=(5, 5))
57+
ot.plot.plot1D_mat(a, b, M, 'Cost matrix M')
58+
59+
##############################################################################
60+
# Solve EMD
61+
# ---------
62+
63+
64+
#%% EMD
65+
66+
G0 = ot.emd(a, b, M)
67+
68+
pl.figure(3, figsize=(5, 5))
69+
ot.plot.plot1D_mat(a, b, G0, 'OT matrix G0')
70+
71+
##############################################################################
72+
# Solve Sinkhorn
73+
# --------------
74+
75+
76+
#%% Sinkhorn
77+
78+
lambd = 2e-3
79+
Gs = ot.sinkhorn(a, b, M, lambd, verbose=True)
80+
81+
pl.figure(4, figsize=(5, 5))
82+
ot.plot.plot1D_mat(a, b, Gs, 'OT matrix Sinkhorn')
83+
84+
pl.show()
85+
86+
##############################################################################
87+
# Solve Smooth OT
88+
# --------------
89+
90+
91+
#%% Smooth OT with KL regularization
92+
93+
lambd = 2e-3
94+
Gsm = ot.smooth.smooth_ot_dual(a, b, M, lambd, reg_type='kl')
95+
96+
pl.figure(5, figsize=(5, 5))
97+
ot.plot.plot1D_mat(a, b, Gsm, 'OT matrix Smooth OT KL reg.')
98+
99+
pl.show()
100+
101+
102+
#%% Smooth OT with KL regularization
103+
104+
lambd = 1e-1
105+
Gsm = ot.smooth.smooth_ot_dual(a, b, M, lambd, reg_type='l2')
106+
107+
pl.figure(6, figsize=(5, 5))
108+
ot.plot.plot1D_mat(a, b, Gsm, 'OT matrix Smooth OT l2 reg.')
109+
110+
pl.show()

0 commit comments

Comments
 (0)