Skip to content

Commit e458b7a

Browse files
committed
add doc for gallery
1 parent 7609f9e commit e458b7a

File tree

77 files changed

+4163
-1
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

77 files changed

+4163
-1
lines changed
36.1 KB
Binary file not shown.
24.8 KB
Binary file not shown.
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
{
2+
"nbformat_minor": 0,
3+
"nbformat": 4,
4+
"cells": [
5+
{
6+
"execution_count": null,
7+
"cell_type": "code",
8+
"source": [
9+
"%matplotlib inline"
10+
],
11+
"outputs": [],
12+
"metadata": {
13+
"collapsed": false
14+
}
15+
},
16+
{
17+
"source": [
18+
"\nDemo for 1D optimal transport\n\n@author: rflamary\n\n"
19+
],
20+
"cell_type": "markdown",
21+
"metadata": {}
22+
},
23+
{
24+
"execution_count": null,
25+
"cell_type": "code",
26+
"source": [
27+
"import numpy as np\nimport matplotlib.pylab as pl\nimport ot\nfrom ot.datasets import get_1D_gauss as gauss\n\n\n#%% parameters\n\nn=100 # nb bins\n\n# bin positions\nx=np.arange(n,dtype=np.float64)\n\n# Gaussian distributions\na=gauss(n,m=n*.2,s=5) # m= mean, s= std\nb=gauss(n,m=n*.6,s=10)\n\n# loss matrix\nM=ot.dist(x.reshape((n,1)),x.reshape((n,1)))\nM/=M.max()\n\n#%% plot the distributions\n\npl.figure(1)\npl.plot(x,a,'b',label='Source distribution')\npl.plot(x,b,'r',label='Target distribution')\npl.legend()\n\n#%% plot distributions and loss matrix\n\npl.figure(2)\not.plot.plot1D_mat(a,b,M,'Cost matrix M')\n\n#%% EMD\n\nG0=ot.emd(a,b,M)\n\npl.figure(3)\not.plot.plot1D_mat(a,b,G0,'OT matrix G0')\n\n#%% Sinkhorn\n\nlambd=1e-3\nGs=ot.sinkhorn(a,b,M,lambd,verbose=True)\n\npl.figure(4)\not.plot.plot1D_mat(a,b,Gs,'OT matrix Sinkhorn')\n\n#%% Sinkhorn\n\nlambd=1e-4\nGss,log=ot.bregman.sinkhorn_stabilized(a,b,M,lambd,verbose=True,log=True)\nGss2,log2=ot.bregman.sinkhorn_stabilized(a,b,M,lambd,verbose=True,log=True,warmstart=log['warmstart'])\n\npl.figure(5)\not.plot.plot1D_mat(a,b,Gss,'OT matrix Sinkhorn stabilized')\n\n#%% Sinkhorn\n\nlambd=1e-11\nGss=ot.bregman.sinkhorn_epsilon_scaling(a,b,M,lambd,verbose=True)\n\npl.figure(5)\not.plot.plot1D_mat(a,b,Gss,'OT matrix Sinkhorn stabilized')"
28+
],
29+
"outputs": [],
30+
"metadata": {
31+
"collapsed": false
32+
}
33+
}
34+
],
35+
"metadata": {
36+
"kernelspec": {
37+
"display_name": "Python 2",
38+
"name": "python2",
39+
"language": "python"
40+
},
41+
"language_info": {
42+
"mimetype": "text/x-python",
43+
"nbconvert_exporter": "python",
44+
"name": "python",
45+
"file_extension": ".py",
46+
"version": "2.7.12",
47+
"pygments_lexer": "ipython2",
48+
"codemirror_mode": {
49+
"version": 2,
50+
"name": "ipython"
51+
}
52+
}
53+
}
54+
}
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
Demo for 1D optimal transport
4+
5+
@author: rflamary
6+
"""
7+
8+
import numpy as np
9+
import matplotlib.pylab as pl
10+
import ot
11+
from ot.datasets import get_1D_gauss as gauss
12+
13+
14+
#%% parameters
15+
16+
n=100 # nb bins
17+
18+
# bin positions
19+
x=np.arange(n,dtype=np.float64)
20+
21+
# Gaussian distributions
22+
a=gauss(n,m=n*.2,s=5) # m= mean, s= std
23+
b=gauss(n,m=n*.6,s=10)
24+
25+
# loss matrix
26+
M=ot.dist(x.reshape((n,1)),x.reshape((n,1)))
27+
M/=M.max()
28+
29+
#%% plot the distributions
30+
31+
pl.figure(1)
32+
pl.plot(x,a,'b',label='Source distribution')
33+
pl.plot(x,b,'r',label='Target distribution')
34+
pl.legend()
35+
36+
#%% plot distributions and loss matrix
37+
38+
pl.figure(2)
39+
ot.plot.plot1D_mat(a,b,M,'Cost matrix M')
40+
41+
#%% EMD
42+
43+
G0=ot.emd(a,b,M)
44+
45+
pl.figure(3)
46+
ot.plot.plot1D_mat(a,b,G0,'OT matrix G0')
47+
48+
#%% Sinkhorn
49+
50+
lambd=1e-3
51+
Gs=ot.sinkhorn(a,b,M,lambd,verbose=True)
52+
53+
pl.figure(4)
54+
ot.plot.plot1D_mat(a,b,Gs,'OT matrix Sinkhorn')
55+
56+
#%% Sinkhorn
57+
58+
lambd=1e-4
59+
Gss,log=ot.bregman.sinkhorn_stabilized(a,b,M,lambd,verbose=True,log=True)
60+
Gss2,log2=ot.bregman.sinkhorn_stabilized(a,b,M,lambd,verbose=True,log=True,warmstart=log['warmstart'])
61+
62+
pl.figure(5)
63+
ot.plot.plot1D_mat(a,b,Gss,'OT matrix Sinkhorn stabilized')
64+
65+
#%% Sinkhorn
66+
67+
lambd=1e-11
68+
Gss=ot.bregman.sinkhorn_epsilon_scaling(a,b,M,lambd,verbose=True)
69+
70+
pl.figure(5)
71+
ot.plot.plot1D_mat(a,b,Gss,'OT matrix Sinkhorn stabilized')
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
2+
3+
.. _sphx_glr_auto_examples_demo_OT_1D_test.py:
4+
5+
6+
Demo for 1D optimal transport
7+
8+
@author: rflamary
9+
10+
11+
12+
.. code-block:: python
13+
14+
15+
import numpy as np
16+
import matplotlib.pylab as pl
17+
import ot
18+
from ot.datasets import get_1D_gauss as gauss
19+
20+
21+
#%% parameters
22+
23+
n=100 # nb bins
24+
25+
# bin positions
26+
x=np.arange(n,dtype=np.float64)
27+
28+
# Gaussian distributions
29+
a=gauss(n,m=n*.2,s=5) # m= mean, s= std
30+
b=gauss(n,m=n*.6,s=10)
31+
32+
# loss matrix
33+
M=ot.dist(x.reshape((n,1)),x.reshape((n,1)))
34+
M/=M.max()
35+
36+
#%% plot the distributions
37+
38+
pl.figure(1)
39+
pl.plot(x,a,'b',label='Source distribution')
40+
pl.plot(x,b,'r',label='Target distribution')
41+
pl.legend()
42+
43+
#%% plot distributions and loss matrix
44+
45+
pl.figure(2)
46+
ot.plot.plot1D_mat(a,b,M,'Cost matrix M')
47+
48+
#%% EMD
49+
50+
G0=ot.emd(a,b,M)
51+
52+
pl.figure(3)
53+
ot.plot.plot1D_mat(a,b,G0,'OT matrix G0')
54+
55+
#%% Sinkhorn
56+
57+
lambd=1e-3
58+
Gs=ot.sinkhorn(a,b,M,lambd,verbose=True)
59+
60+
pl.figure(4)
61+
ot.plot.plot1D_mat(a,b,Gs,'OT matrix Sinkhorn')
62+
63+
#%% Sinkhorn
64+
65+
lambd=1e-4
66+
Gss,log=ot.bregman.sinkhorn_stabilized(a,b,M,lambd,verbose=True,log=True)
67+
Gss2,log2=ot.bregman.sinkhorn_stabilized(a,b,M,lambd,verbose=True,log=True,warmstart=log['warmstart'])
68+
69+
pl.figure(5)
70+
ot.plot.plot1D_mat(a,b,Gss,'OT matrix Sinkhorn stabilized')
71+
72+
#%% Sinkhorn
73+
74+
lambd=1e-11
75+
Gss=ot.bregman.sinkhorn_epsilon_scaling(a,b,M,lambd,verbose=True)
76+
77+
pl.figure(5)
78+
ot.plot.plot1D_mat(a,b,Gss,'OT matrix Sinkhorn stabilized')
79+
80+
**Total running time of the script:** ( 0 minutes 0.000 seconds)
81+
82+
83+
84+
.. container:: sphx-glr-footer
85+
86+
87+
.. container:: sphx-glr-download
88+
89+
:download:`Download Python source code: demo_OT_1D_test.py <demo_OT_1D_test.py>`
90+
91+
92+
93+
.. container:: sphx-glr-download
94+
95+
:download:`Download Jupyter notebook: demo_OT_1D_test.ipynb <demo_OT_1D_test.ipynb>`
96+
97+
.. rst-class:: sphx-glr-signature
98+
99+
`Generated by Sphinx-Gallery <http://sphinx-gallery.readthedocs.io>`_
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
{
2+
"nbformat_minor": 0,
3+
"nbformat": 4,
4+
"cells": [
5+
{
6+
"execution_count": null,
7+
"cell_type": "code",
8+
"source": [
9+
"%matplotlib inline"
10+
],
11+
"outputs": [],
12+
"metadata": {
13+
"collapsed": false
14+
}
15+
},
16+
{
17+
"source": [
18+
"\nDemo for 2D Optimal transport between empirical distributions\n\n@author: rflamary\n\n"
19+
],
20+
"cell_type": "markdown",
21+
"metadata": {}
22+
},
23+
{
24+
"execution_count": null,
25+
"cell_type": "code",
26+
"source": [
27+
"import numpy as np\nimport matplotlib.pylab as pl\nimport ot\n\n#%% parameters and data generation\n\nn=5000 # nb samples\n\nmu_s=np.array([0,0])\ncov_s=np.array([[1,0],[0,1]])\n\nmu_t=np.array([4,4])\ncov_t=np.array([[1,-.8],[-.8,1]])\n\nxs=ot.datasets.get_2D_samples_gauss(n,mu_s,cov_s)\nxt=ot.datasets.get_2D_samples_gauss(n,mu_t,cov_t)\n\na,b = ot.unif(n),ot.unif(n) # uniform distribution on samples\n\n# loss matrix\nM=ot.dist(xs,xt)\nM/=M.max()\n\n#%% plot samples\n\n#pl.figure(1)\n#pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')\n#pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')\n#pl.legend(loc=0)\n#pl.title('Source and traget distributions')\n#\n#pl.figure(2)\n#pl.imshow(M,interpolation='nearest')\n#pl.title('Cost matrix M')\n#\n\n#%% EMD\n\nG0=ot.emd(a,b,M)\n\n#pl.figure(3)\n#pl.imshow(G0,interpolation='nearest')\n#pl.title('OT matrix G0')\n#\n#pl.figure(4)\n#ot.plot.plot2D_samples_mat(xs,xt,G0,c=[.5,.5,1])\n#pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')\n#pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')\n#pl.legend(loc=0)\n#pl.title('OT matrix with samples')\n\n\n#%% sinkhorn\n\n# reg term\nlambd=5e-3\n\nGs=ot.sinkhorn(a,b,M,lambd)\n\n#pl.figure(5)\n#pl.imshow(Gs,interpolation='nearest')\n#pl.title('OT matrix sinkhorn')\n#\n#pl.figure(6)\n#ot.plot.plot2D_samples_mat(xs,xt,Gs,color=[.5,.5,1])\n#pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')\n#pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')\n#pl.legend(loc=0)\n#pl.title('OT matrix Sinkhorn with samples')\n#"
28+
],
29+
"outputs": [],
30+
"metadata": {
31+
"collapsed": false
32+
}
33+
}
34+
],
35+
"metadata": {
36+
"kernelspec": {
37+
"display_name": "Python 2",
38+
"name": "python2",
39+
"language": "python"
40+
},
41+
"language_info": {
42+
"mimetype": "text/x-python",
43+
"nbconvert_exporter": "python",
44+
"name": "python",
45+
"file_extension": ".py",
46+
"version": "2.7.12",
47+
"pygments_lexer": "ipython2",
48+
"codemirror_mode": {
49+
"version": 2,
50+
"name": "ipython"
51+
}
52+
}
53+
}
54+
}
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
Demo for 2D Optimal transport between empirical distributions
4+
5+
@author: rflamary
6+
"""
7+
8+
import numpy as np
9+
import matplotlib.pylab as pl
10+
import ot
11+
12+
#%% parameters and data generation
13+
14+
n=5000 # nb samples
15+
16+
mu_s=np.array([0,0])
17+
cov_s=np.array([[1,0],[0,1]])
18+
19+
mu_t=np.array([4,4])
20+
cov_t=np.array([[1,-.8],[-.8,1]])
21+
22+
xs=ot.datasets.get_2D_samples_gauss(n,mu_s,cov_s)
23+
xt=ot.datasets.get_2D_samples_gauss(n,mu_t,cov_t)
24+
25+
a,b = ot.unif(n),ot.unif(n) # uniform distribution on samples
26+
27+
# loss matrix
28+
M=ot.dist(xs,xt)
29+
M/=M.max()
30+
31+
#%% plot samples
32+
33+
#pl.figure(1)
34+
#pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')
35+
#pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')
36+
#pl.legend(loc=0)
37+
#pl.title('Source and traget distributions')
38+
#
39+
#pl.figure(2)
40+
#pl.imshow(M,interpolation='nearest')
41+
#pl.title('Cost matrix M')
42+
#
43+
44+
#%% EMD
45+
46+
G0=ot.emd(a,b,M)
47+
48+
#pl.figure(3)
49+
#pl.imshow(G0,interpolation='nearest')
50+
#pl.title('OT matrix G0')
51+
#
52+
#pl.figure(4)
53+
#ot.plot.plot2D_samples_mat(xs,xt,G0,c=[.5,.5,1])
54+
#pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')
55+
#pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')
56+
#pl.legend(loc=0)
57+
#pl.title('OT matrix with samples')
58+
59+
60+
#%% sinkhorn
61+
62+
# reg term
63+
lambd=5e-3
64+
65+
Gs=ot.sinkhorn(a,b,M,lambd)
66+
67+
#pl.figure(5)
68+
#pl.imshow(Gs,interpolation='nearest')
69+
#pl.title('OT matrix sinkhorn')
70+
#
71+
#pl.figure(6)
72+
#ot.plot.plot2D_samples_mat(xs,xt,Gs,color=[.5,.5,1])
73+
#pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')
74+
#pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')
75+
#pl.legend(loc=0)
76+
#pl.title('OT matrix Sinkhorn with samples')
77+
#
78+

0 commit comments

Comments
 (0)