Skip to content

Commit 3af9b06

Browse files
committed
add em2 computation example
1 parent a632c40 commit 3af9b06

File tree

1 file changed

+57
-0
lines changed

1 file changed

+57
-0
lines changed

examples/plot_compute_emd.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
====================
4+
1D optimal transport
5+
====================
6+
7+
@author: rflamary
8+
"""
9+
10+
import numpy as np
11+
import matplotlib.pylab as pl
12+
import ot
13+
from ot.datasets import get_1D_gauss as gauss
14+
15+
16+
#%% parameters
17+
18+
n=100 # nb bins
19+
n_target=10 # nb target distributions
20+
21+
22+
# bin positions
23+
x=np.arange(n,dtype=np.float64)
24+
25+
lst_m=np.linspace(20,90,n_target)
26+
27+
# Gaussian distributions
28+
a=gauss(n,m=20,s=5) # m= mean, s= std
29+
30+
B=np.zeros((n,n_target))
31+
32+
for i,m in enumerate(lst_m):
33+
B[:,i]=gauss(n,m=m,s=5)
34+
35+
# loss matrix
36+
M=ot.dist(x.reshape((n,1)),x.reshape((n,1)),'euclidean')
37+
M2=ot.dist(x.reshape((n,1)),x.reshape((n,1)),'sqeuclidean')
38+
39+
#%% plot the distributions
40+
41+
pl.figure(1)
42+
pl.subplot(2,1,1)
43+
pl.plot(x,a,'b',label='Source distribution')
44+
pl.title('Source distribution')
45+
pl.subplot(2,1,2)
46+
pl.plot(x,B,label='Target distributions')
47+
pl.title('Target distributions')
48+
49+
#%% plot distributions and loss matrix
50+
51+
emd=ot.emd2(a,B,M)
52+
emd2=ot.emd2(a,B,M2)
53+
pl.figure(2)
54+
pl.plot(emd,label='Euclidean loss')
55+
pl.plot(emd,label='Squared Euclidean loss')
56+
pl.legend()
57+

0 commit comments

Comments
 (0)