Skip to content

Commit f639518

Browse files
committed
add example norm
1 parent 4fba2c9 commit f639518

File tree

1 file changed

+110
-0
lines changed

1 file changed

+110
-0
lines changed

examples/plot_OT_L1_vs_L2.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
====================================================
4+
2D Optimal transport between empirical distributions
5+
====================================================
6+
7+
Stoile the figure idea from:
8+
https://arxiv.org/pdf/1706.07650.pdf
9+
10+
11+
@author: rflamary
12+
"""
13+
14+
import numpy as np
15+
import matplotlib.pylab as pl
16+
import ot
17+
18+
#%% parameters and data generation
19+
20+
for data in range(2):
21+
22+
if data:
23+
n=20 # nb samples
24+
xs=np.zeros((n,2))
25+
xs[:,0]=np.arange(n)+1
26+
xs[:,1]=(np.arange(n)+1)*-0.001 # to make it strictly convex...
27+
28+
xt=np.zeros((n,2))
29+
xt[:,1]=np.arange(n)+1
30+
else:
31+
32+
n=50 # nb samples
33+
xtot=np.zeros((n+1,2))
34+
xtot[:,0]=np.cos((np.arange(n+1)+1.0)*0.9/(n+2)*2*np.pi)
35+
xtot[:,1]=np.sin((np.arange(n+1)+1.0)*0.9/(n+2)*2*np.pi)
36+
37+
xs=xtot[:n,:]
38+
xt=xtot[1:,:]
39+
40+
41+
42+
a,b = ot.unif(n),ot.unif(n) # uniform distribution on samples
43+
44+
# loss matrix
45+
M1=ot.dist(xs,xt,metric='euclidean')
46+
M1/=M1.max()
47+
48+
# loss matrix
49+
M2=ot.dist(xs,xt,metric='sqeuclidean')
50+
M2/=M2.max()
51+
52+
# loss matrix
53+
Mp=np.sqrt(ot.dist(xs,xt,metric='euclidean'))
54+
Mp/=Mp.max()
55+
56+
#%% plot samples
57+
58+
pl.figure(1+3*data)
59+
pl.clf()
60+
pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')
61+
pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')
62+
pl.axis('equal')
63+
pl.title('Source and traget distributions')
64+
65+
pl.figure(2+3*data,(15,5))
66+
pl.subplot(1,3,1)
67+
pl.imshow(M1,interpolation='nearest')
68+
pl.title('Eucidean cost')
69+
pl.subplot(1,3,2)
70+
pl.imshow(M2,interpolation='nearest')
71+
pl.title('Squared Euclidean cost')
72+
73+
pl.subplot(1,3,3)
74+
pl.imshow(Mp,interpolation='nearest')
75+
pl.title('Sqrt Euclidean cost')
76+
#%% EMD
77+
78+
G1=ot.emd(a,b,M1)
79+
G2=ot.emd(a,b,M2)
80+
Gp=ot.emd(a,b,Mp)
81+
82+
pl.figure(3+3*data,(15,5))
83+
84+
pl.subplot(1,3,1)
85+
ot.plot.plot2D_samples_mat(xs,xt,G1,c=[.5,.5,1])
86+
pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')
87+
pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')
88+
pl.axis('equal')
89+
#pl.legend(loc=0)
90+
pl.title('OT Euclidean')
91+
92+
pl.subplot(1,3,2)
93+
94+
ot.plot.plot2D_samples_mat(xs,xt,G2,c=[.5,.5,1])
95+
pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')
96+
pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')
97+
pl.axis('equal')
98+
#pl.legend(loc=0)
99+
pl.title('OT squared Euclidean')
100+
101+
pl.subplot(1,3,3)
102+
103+
ot.plot.plot2D_samples_mat(xs,xt,Gp,c=[.5,.5,1])
104+
pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')
105+
pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')
106+
pl.axis('equal')
107+
#pl.legend(loc=0)
108+
pl.title('OT sqrt Euclidean')
109+
110+
#%% sinkhorn

0 commit comments

Comments
 (0)