|
| 1 | +# -*- coding: utf-8 -*- |
| 2 | +""" |
| 3 | +==================== |
| 4 | +1D optimal transport |
| 5 | +==================== |
| 6 | +
|
| 7 | +@author: rflamary |
| 8 | +""" |
| 9 | + |
| 10 | +import numpy as np |
| 11 | +import matplotlib.pylab as pl |
| 12 | +import ot |
| 13 | +from ot.datasets import get_1D_gauss as gauss |
| 14 | + |
| 15 | + |
| 16 | +#%% parameters |
| 17 | + |
| 18 | +n=100 # nb bins |
| 19 | +n_target=10 # nb target distributions |
| 20 | + |
| 21 | + |
| 22 | +# bin positions |
| 23 | +x=np.arange(n,dtype=np.float64) |
| 24 | + |
| 25 | +lst_m=np.linspace(20,90,n_target) |
| 26 | + |
| 27 | +# Gaussian distributions |
| 28 | +a=gauss(n,m=20,s=5) # m= mean, s= std |
| 29 | + |
| 30 | +B=np.zeros((n,n_target)) |
| 31 | + |
| 32 | +for i,m in enumerate(lst_m): |
| 33 | + B[:,i]=gauss(n,m=m,s=5) |
| 34 | + |
| 35 | +# loss matrix |
| 36 | +M=ot.dist(x.reshape((n,1)),x.reshape((n,1)),'euclidean') |
| 37 | +M2=ot.dist(x.reshape((n,1)),x.reshape((n,1)),'sqeuclidean') |
| 38 | + |
| 39 | +#%% plot the distributions |
| 40 | + |
| 41 | +pl.figure(1) |
| 42 | +pl.subplot(2,1,1) |
| 43 | +pl.plot(x,a,'b',label='Source distribution') |
| 44 | +pl.title('Source distribution') |
| 45 | +pl.subplot(2,1,2) |
| 46 | +pl.plot(x,B,label='Target distributions') |
| 47 | +pl.title('Target distributions') |
| 48 | + |
| 49 | +#%% plot distributions and loss matrix |
| 50 | + |
| 51 | +emd=ot.emd2(a,B,M) |
| 52 | +emd2=ot.emd2(a,B,M2) |
| 53 | +pl.figure(2) |
| 54 | +pl.plot(emd,label='Euclidean loss') |
| 55 | +pl.plot(emd,label='Squared Euclidean loss') |
| 56 | +pl.legend() |
| 57 | + |
0 commit comments