Skip to content

Commit 414331c

Browse files
committed
Merge readme with master
2 parents 75fe96c + c9b99df commit 414331c

File tree

9 files changed

+233
-5
lines changed

9 files changed

+233
-5
lines changed

README.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -230,4 +230,6 @@ You can also post bug reports and feature requests in Github issues. Make sure t
230230

231231
[20] Cuturi, M. and Doucet, A. (2014) [Fast Computation of Wasserstein Barycenters](http://proceedings.mlr.press/v32/cuturi14.html). International Conference in Machine Learning
232232

233-
[21] J. Altschuler, J.Weed, P. Rigollet, (2017) Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration, Advances in Neural Information Processing Systems (NIPS) 31
233+
[21] Solomon, J., De Goes, F., Peyré, G., Cuturi, M., Butscher, A., Nguyen, A. & Guibas, L. (2015). [Convolutional wasserstein distances: Efficient optimal transportation on geometric domains](https://dl.acm.org/citation.cfm?id=2766963). ACM Transactions on Graphics (TOG), 34(4), 66.
234+
235+
[21] J. Altschuler, J.Weed, P. Rigollet, (2017) [Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration](https://papers.nips.cc/paper/6792-near-linear-time-approximation-algorithms-for-optimal-transport-via-sinkhorn-iteration.pdf), Advances in Neural Information Processing Systems (NIPS) 31

data/duck.png

4.99 KB
Loading

data/heart.png

5.1 KB
Loading

data/redcross.png

1.64 KB
Loading

data/tooth.png

4.82 KB
Loading
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
2+
#%%
3+
# -*- coding: utf-8 -*-
4+
"""
5+
============================================
6+
Convolutional Wasserstein Barycenter example
7+
============================================
8+
9+
This example is designed to illustrate how the Convolutional Wasserstein Barycenter
10+
function of POT works.
11+
"""
12+
13+
# Author: Nicolas Courty <ncourty@irisa.fr>
14+
#
15+
# License: MIT License
16+
17+
18+
import numpy as np
19+
import pylab as pl
20+
import ot
21+
22+
##############################################################################
23+
# Data preparation
24+
# ----------------
25+
#
26+
# The four distributions are constructed from 4 simple images
27+
28+
29+
f1 = 1 - pl.imread('../data/redcross.png')[:, :, 2]
30+
f2 = 1 - pl.imread('../data/duck.png')[:, :, 2]
31+
f3 = 1 - pl.imread('../data/heart.png')[:, :, 2]
32+
f4 = 1 - pl.imread('../data/tooth.png')[:, :, 2]
33+
34+
A = []
35+
f1 = f1 / np.sum(f1)
36+
f2 = f2 / np.sum(f2)
37+
f3 = f3 / np.sum(f3)
38+
f4 = f4 / np.sum(f4)
39+
A.append(f1)
40+
A.append(f2)
41+
A.append(f3)
42+
A.append(f4)
43+
A = np.array(A)
44+
45+
nb_images = 5
46+
47+
# those are the four corners coordinates that will be interpolated by bilinear
48+
# interpolation
49+
v1 = np.array((1, 0, 0, 0))
50+
v2 = np.array((0, 1, 0, 0))
51+
v3 = np.array((0, 0, 1, 0))
52+
v4 = np.array((0, 0, 0, 1))
53+
54+
55+
##############################################################################
56+
# Barycenter computation and visualization
57+
# ----------------------------------------
58+
#
59+
60+
pl.figure(figsize=(10, 10))
61+
pl.title('Convolutional Wasserstein Barycenters in POT')
62+
cm = 'Blues'
63+
# regularization parameter
64+
reg = 0.004
65+
for i in range(nb_images):
66+
for j in range(nb_images):
67+
pl.subplot(nb_images, nb_images, i * nb_images + j + 1)
68+
tx = float(i) / (nb_images - 1)
69+
ty = float(j) / (nb_images - 1)
70+
71+
# weights are constructed by bilinear interpolation
72+
tmp1 = (1 - tx) * v1 + tx * v2
73+
tmp2 = (1 - tx) * v3 + tx * v4
74+
weights = (1 - ty) * tmp1 + ty * tmp2
75+
76+
if i == 0 and j == 0:
77+
pl.imshow(f1, cmap=cm)
78+
pl.axis('off')
79+
elif i == 0 and j == (nb_images - 1):
80+
pl.imshow(f3, cmap=cm)
81+
pl.axis('off')
82+
elif i == (nb_images - 1) and j == 0:
83+
pl.imshow(f2, cmap=cm)
84+
pl.axis('off')
85+
elif i == (nb_images - 1) and j == (nb_images - 1):
86+
pl.imshow(f4, cmap=cm)
87+
pl.axis('off')
88+
else:
89+
# call to barycenter computation
90+
pl.imshow(ot.bregman.convolutional_barycenter2d(A, reg, weights), cmap=cm)
91+
pl.axis('off')
92+
pl.show()

ot/bregman.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1070,6 +1070,116 @@ def barycenter(A, M, reg, weights=None, numItermax=1000,
10701070
return geometricBar(weights, UKv)
10711071

10721072

1073+
def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, stopThr=1e-9, stabThr=1e-30, verbose=False, log=False):
1074+
"""Compute the entropic regularized wasserstein barycenter of distributions A
1075+
where A is a collection of 2D images.
1076+
1077+
The function solves the following optimization problem:
1078+
1079+
.. math::
1080+
\mathbf{a} = arg\min_\mathbf{a} \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i)
1081+
1082+
where :
1083+
1084+
- :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see ot.bregman.sinkhorn)
1085+
- :math:`\mathbf{a}_i` are training distributions (2D images) in the mast two dimensions of matrix :math:`\mathbf{A}`
1086+
- reg is the regularization strength scalar value
1087+
1088+
The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [21]_
1089+
1090+
Parameters
1091+
----------
1092+
A : np.ndarray (n,w,h)
1093+
n distributions (2D images) of size w x h
1094+
reg : float
1095+
Regularization term >0
1096+
weights : np.ndarray (n,)
1097+
Weights of each image on the simplex (barycentric coodinates)
1098+
numItermax : int, optional
1099+
Max number of iterations
1100+
stopThr : float, optional
1101+
Stop threshol on error (>0)
1102+
stabThr : float, optional
1103+
Stabilization threshold to avoid numerical precision issue
1104+
verbose : bool, optional
1105+
Print information along iterations
1106+
log : bool, optional
1107+
record log if True
1108+
1109+
1110+
Returns
1111+
-------
1112+
a : (w,h) ndarray
1113+
2D Wasserstein barycenter
1114+
log : dict
1115+
log dictionary return only if log==True in parameters
1116+
1117+
1118+
References
1119+
----------
1120+
1121+
.. [21] Solomon, J., De Goes, F., Peyré, G., Cuturi, M., Butscher, A., Nguyen, A. & Guibas, L. (2015).
1122+
Convolutional wasserstein distances: Efficient optimal transportation on geometric domains
1123+
ACM Transactions on Graphics (TOG), 34(4), 66
1124+
1125+
1126+
"""
1127+
1128+
if weights is None:
1129+
weights = np.ones(A.shape[0]) / A.shape[0]
1130+
else:
1131+
assert(len(weights) == A.shape[0])
1132+
1133+
if log:
1134+
log = {'err': []}
1135+
1136+
b = np.zeros_like(A[0, :, :])
1137+
U = np.ones_like(A)
1138+
KV = np.ones_like(A)
1139+
1140+
cpt = 0
1141+
err = 1
1142+
1143+
# build the convolution operator
1144+
t = np.linspace(0, 1, A.shape[1])
1145+
[Y, X] = np.meshgrid(t, t)
1146+
xi1 = np.exp(-(X - Y)**2 / reg)
1147+
1148+
def K(x):
1149+
return np.dot(np.dot(xi1, x), xi1)
1150+
1151+
while (err > stopThr and cpt < numItermax):
1152+
1153+
bold = b
1154+
cpt = cpt + 1
1155+
1156+
b = np.zeros_like(A[0, :, :])
1157+
for r in range(A.shape[0]):
1158+
KV[r, :, :] = K(A[r, :, :] / np.maximum(stabThr, K(U[r, :, :])))
1159+
b += weights[r] * np.log(np.maximum(stabThr, U[r, :, :] * KV[r, :, :]))
1160+
b = np.exp(b)
1161+
for r in range(A.shape[0]):
1162+
U[r, :, :] = b / np.maximum(stabThr, KV[r, :, :])
1163+
1164+
if cpt % 10 == 1:
1165+
err = np.sum(np.abs(bold - b))
1166+
# log and verbose print
1167+
if log:
1168+
log['err'].append(err)
1169+
1170+
if verbose:
1171+
if cpt % 200 == 0:
1172+
print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
1173+
print('{:5d}|{:8e}|'.format(cpt, err))
1174+
1175+
if log:
1176+
log['niter'] = cpt
1177+
log['U'] = U
1178+
return b, log
1179+
else:
1180+
return b
1181+
1182+
10731183
def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000,
10741184
stopThr=1e-3, verbose=False, log=False):
10751185
"""

test/test_bregman.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,30 @@ def test_bary():
108108
ot.bregman.barycenter(A, M, reg, log=True, verbose=True)
109109

110110

111+
def test_wassersteinbary():
112+
113+
size = 100 # size of a square image
114+
a1 = np.random.randn(size, size)
115+
a1 += a1.min()
116+
a1 = a1 / np.sum(a1)
117+
a2 = np.random.randn(size, size)
118+
a2 += a2.min()
119+
a2 = a2 / np.sum(a2)
120+
# creating matrix A containing all distributions
121+
A = np.zeros((2, 100, 100))
122+
A[0, :, :] = a1
123+
A[1, :, :] = a2
124+
125+
# wasserstein
126+
reg = 1e-3
127+
bary_wass = ot.bregman.convolutional_barycenter2d(A, reg)
128+
129+
np.testing.assert_allclose(1, np.sum(bary_wass))
130+
131+
# help in checking if log and verbose do not bug the function
132+
ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True)
133+
134+
111135
def test_unmix():
112136

113137
n_bins = 50 # nb bins

test/test_stochastic.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def test_stochastic_sag():
3232
# test sag
3333
n = 15
3434
reg = 1
35-
numItermax = 300000
35+
numItermax = 30000
3636
rng = np.random.RandomState(0)
3737

3838
x = rng.randn(n, 2)
@@ -62,7 +62,7 @@ def test_stochastic_asgd():
6262
# test asgd
6363
n = 15
6464
reg = 1
65-
numItermax = 300000
65+
numItermax = 100000
6666
rng = np.random.RandomState(0)
6767

6868
x = rng.randn(n, 2)
@@ -92,7 +92,7 @@ def test_sag_asgd_sinkhorn():
9292
# test all algorithms
9393
n = 15
9494
reg = 1
95-
nb_iter = 300000
95+
nb_iter = 100000
9696
rng = np.random.RandomState(0)
9797

9898
x = rng.randn(n, 2)
@@ -167,7 +167,7 @@ def test_dual_sgd_sinkhorn():
167167
# test all dual algorithms
168168
n = 10
169169
reg = 1
170-
nb_iter = 150000
170+
nb_iter = 15000
171171
batch_size = 10
172172
rng = np.random.RandomState(0)
173173

0 commit comments

Comments
 (0)