Skip to content

Commit c8eda44

Browse files
author
Kilian Fatras
committed
add problems solved in doc
1 parent 90efa5a commit c8eda44

File tree

4 files changed

+667
-0
lines changed

4 files changed

+667
-0
lines changed

examples/plot_stochastic.py

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
"""
2+
==========================
3+
Stochastic examples
4+
==========================
5+
6+
This example is designed to show how to use the stochatic optimization
7+
algorithms for descrete and semicontinous measures from the POT library.
8+
9+
"""
10+
11+
# Author: Kilian Fatras <kilian.fatras@gmail.com>
12+
#
13+
# License: MIT License
14+
15+
import matplotlib.pylab as pl
16+
import numpy as np
17+
import ot
18+
19+
20+
#############################################################################
21+
# COMPUTE TRANSPORTATION MATRIX
22+
#############################################################################
23+
24+
#############################################################################
25+
# DISCRETE CASE
26+
# Sample two discrete measures for the discrete case
27+
# ---------------------------------------------
28+
#
29+
# Define 2 discrete measures a and b, the points where are defined the source
30+
# and the target measures and finally the cost matrix c.
31+
32+
n_source = 7
33+
n_target = 4
34+
reg = 1
35+
numItermax = 10000
36+
lr = 0.1
37+
38+
a = ot.utils.unif(n_source)
39+
b = ot.utils.unif(n_target)
40+
41+
rng = np.random.RandomState(0)
42+
X_source = rng.randn(n_source, 2)
43+
Y_target = rng.randn(n_target, 2)
44+
M = ot.dist(X_source, Y_target)
45+
46+
#############################################################################
47+
#
48+
# Call the "SAG" method to find the transportation matrix in the discrete case
49+
# ---------------------------------------------
50+
#
51+
# Define the method "SAG", call ot.transportation_matrix_entropic and plot the
52+
# results.
53+
54+
method = "SAG"
55+
sag_pi = ot.stochastic.transportation_matrix_entropic(a, b, M, reg, method,
56+
numItermax, lr)
57+
print(sag_pi)
58+
59+
#############################################################################
60+
# SEMICONTINOUS CASE
61+
# Sample one general measure a, one discrete measures b for the semicontinous
62+
# case
63+
# ---------------------------------------------
64+
#
65+
# Define one general measure a, one discrete measures b, the points where
66+
# are defined the source and the target measures and finally the cost matrix c.
67+
68+
n_source = 7
69+
n_target = 4
70+
reg = 1
71+
numItermax = 500000
72+
lr = 1
73+
74+
a = ot.utils.unif(n_source)
75+
b = ot.utils.unif(n_target)
76+
77+
rng = np.random.RandomState(0)
78+
X_source = rng.randn(n_source, 2)
79+
Y_target = rng.randn(n_target, 2)
80+
M = ot.dist(X_source, Y_target)
81+
82+
#############################################################################
83+
#
84+
# Call the "ASGD" method to find the transportation matrix in the semicontinous
85+
# case
86+
# ---------------------------------------------
87+
#
88+
# Define the method "ASGD", call ot.transportation_matrix_entropic and plot the
89+
# results.
90+
91+
method = "ASGD"
92+
asgd_pi = ot.stochastic.transportation_matrix_entropic(a, b, M, reg, method,
93+
numItermax, lr)
94+
print(asgd_pi)
95+
96+
#############################################################################
97+
#
98+
# Compare the results with the Sinkhorn algorithm
99+
# ---------------------------------------------
100+
#
101+
# Call the Sinkhorn algorithm from POT
102+
103+
sinkhorn_pi = ot.sinkhorn(a, b, M, 1)
104+
print(sinkhorn_pi)
105+
106+
107+
##############################################################################
108+
# PLOT TRANSPORTATION MATRIX
109+
##############################################################################
110+
111+
##############################################################################
112+
# Plot SAG results
113+
# ----------------
114+
115+
pl.figure(4, figsize=(5, 5))
116+
ot.plot.plot1D_mat(a, b, sag_pi, 'OT matrix SAG')
117+
pl.show()
118+
119+
120+
##############################################################################
121+
# Plot ASGD results
122+
# -----------------
123+
124+
pl.figure(4, figsize=(5, 5))
125+
ot.plot.plot1D_mat(a, b, asgd_pi, 'OT matrix ASGD')
126+
pl.show()
127+
128+
129+
##############################################################################
130+
# Plot Sinkhorn results
131+
# ---------------------
132+
133+
pl.figure(4, figsize=(5, 5))
134+
ot.plot.plot1D_mat(a, b, sinkhorn_pi, 'OT matrix Sinkhorn')
135+
pl.show()

ot/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
from . import datasets
1919
from . import da
2020
from . import gromov
21+
from . import smooth
22+
from . import stochastic
2123

2224
# OT functions
2325
from .lp import emd, emd2

0 commit comments

Comments
 (0)