Skip to content

Commit d99abf0

Browse files
committed
Wasserstein convolutional barycenter
1 parent 5180023 commit d99abf0

File tree

7 files changed

+201
-1
lines changed

7 files changed

+201
-1
lines changed

README.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -227,4 +227,6 @@ You can also post bug reports and feature requests in Github issues. Make sure t
227227

228228
[19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A.& Blondel, M. [Large-scale Optimal Transport and Mapping Estimation](https://arxiv.org/pdf/1711.02283.pdf). International Conference on Learning Representation (2018)
229229

230-
[20] Cuturi, M. and Doucet, A. (2014) [Fast Computation of Wasserstein Barycenters](http://proceedings.mlr.press/v32/cuturi14.html). International Conference in Machine Learning
230+
[20] Cuturi, M. and Doucet, A. (2014) [Fast Computation of Wasserstein Barycenters](http://proceedings.mlr.press/v32/cuturi14.html). International Conference in Machine Learning
231+
232+
[21] Solomon, J., De Goes, F., Peyré, G., Cuturi, M., Butscher, A., Nguyen, A. & Guibas, L. (2015). [Convolutional wasserstein distances: Efficient optimal transportation on geometric domains](https://dl.acm.org/citation.cfm?id=2766963). ACM Transactions on Graphics (TOG), 34(4), 66.

data/duck.png

4.99 KB
Loading

data/heart.png

5.1 KB
Loading

data/redcross.png

1.64 KB
Loading

data/tooth.png

4.82 KB
Loading
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
2+
#%%
3+
# -*- coding: utf-8 -*-
4+
"""
5+
============================================
6+
Convolutional Wasserstein Barycenter example
7+
============================================
8+
9+
This example is designed to illustrate how the Convolutional Wasserstein Barycenter
10+
function of POT works.
11+
"""
12+
13+
# Author: Nicolas Courty <ncourty@irisa.fr>
14+
#
15+
# License: MIT License
16+
17+
18+
import numpy as np
19+
import pylab as pl
20+
import ot
21+
22+
##############################################################################
23+
# Data preparation
24+
# ----------------
25+
#
26+
# The four distributions are constructed from 4 simple images
27+
28+
29+
f1 = 1 - pl.imread('../data/redcross.png')[:, :, 2]
30+
f2 = 1 - pl.imread('../data/duck.png')[:, :, 2]
31+
f3 = 1 - pl.imread('../data/heart.png')[:, :, 2]
32+
f4 = 1 - pl.imread('../data/tooth.png')[:, :, 2]
33+
34+
A = []
35+
f1=f1/np.sum(f1)
36+
f2=f2/np.sum(f2)
37+
f3=f3/np.sum(f3)
38+
f4=f4/np.sum(f4)
39+
A.append(f1)
40+
A.append(f2)
41+
A.append(f3)
42+
A.append(f4)
43+
A=np.array(A)
44+
45+
nb_images = 5
46+
47+
# those are the four corners coordinates that will be interpolated by bilinear
48+
# interpolation
49+
v1=np.array((1,0,0,0))
50+
v2=np.array((0,1,0,0))
51+
v3=np.array((0,0,1,0))
52+
v4=np.array((0,0,0,1))
53+
54+
55+
##############################################################################
56+
# Barycenter computation and visualization
57+
# ----------------------------------------
58+
#
59+
60+
pl.figure(figsize=(10,10))
61+
pl.title('Convolutional Wasserstein Barycenters in POT')
62+
cm='Blues'
63+
# regularization parameter
64+
reg=0.004
65+
for i in range(nb_images):
66+
for j in range(nb_images):
67+
pl.subplot(nb_images,nb_images,i*nb_images+j+1)
68+
tx=float(i)/(nb_images-1)
69+
ty=float(j)/(nb_images-1)
70+
71+
# weights are constructed by bilinear interpolation
72+
tmp1=(1-tx)*v1+tx*v2
73+
tmp2=(1-tx)*v3+tx*v4
74+
weights=(1-ty)*tmp1+ty*tmp2
75+
76+
if i==0 and j==0:
77+
pl.imshow(f1,cmap=cm)
78+
pl.axis('off')
79+
elif i==0 and j==(nb_images-1):
80+
pl.imshow(f3,cmap=cm)
81+
pl.axis('off')
82+
elif i==(nb_images-1) and j==0:
83+
pl.imshow(f2,cmap=cm)
84+
pl.axis('off')
85+
elif i==(nb_images-1) and j==(nb_images-1):
86+
pl.imshow(f4,cmap=cm)
87+
pl.axis('off')
88+
else:
89+
# call to barycenter computation
90+
pl.imshow(ot.convolutional_barycenter2d(A,reg,weights),cmap=cm)
91+
pl.axis('off')
92+
pl.show()

ot/bregman.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -918,6 +918,112 @@ def barycenter(A, M, reg, weights=None, numItermax=1000,
918918
else:
919919
return geometricBar(weights, UKv)
920920

921+
def convolutional_barycenter2d(A,reg,weights=None,numItermax = 10000, stopThr=1e-9, verbose=False, log=False):
922+
"""Compute the entropic regularized wasserstein barycenter of distributions A
923+
where A is a collection of 2D images.
924+
925+
The function solves the following optimization problem:
926+
927+
.. math::
928+
\mathbf{a} = arg\min_\mathbf{a} \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i)
929+
930+
where :
931+
932+
- :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see ot.bregman.sinkhorn)
933+
- :math:`\mathbf{a}_i` are training distributions (2D images) in the mast two dimensions of matrix :math:`\mathbf{A}`
934+
- reg is the regularization strength scalar value
935+
936+
The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [21]_
937+
938+
Parameters
939+
----------
940+
A : np.ndarray (n,w,h)
941+
n distributions (2D images) of size w x h
942+
reg : float
943+
Regularization term >0
944+
weights : np.ndarray (n,)
945+
Weights of each image on the simplex (barycentric coodinates)
946+
numItermax : int, optional
947+
Max number of iterations
948+
stopThr : float, optional
949+
Stop threshol on error (>0)
950+
verbose : bool, optional
951+
Print information along iterations
952+
log : bool, optional
953+
record log if True
954+
955+
956+
Returns
957+
-------
958+
a : (w,h) ndarray
959+
2D Wasserstein barycenter
960+
log : dict
961+
log dictionary return only if log==True in parameters
962+
963+
964+
References
965+
----------
966+
967+
.. [21] Solomon, J., De Goes, F., Peyré, G., Cuturi, M., Butscher, A., Nguyen, A. & Guibas, L. (2015).
968+
Convolutional wasserstein distances: Efficient optimal transportation on geometric domains
969+
ACM Transactions on Graphics (TOG), 34(4), 66
970+
971+
972+
"""
973+
974+
if weights is None:
975+
weights = np.ones(A.shape[0]) / A.shape[0]
976+
else:
977+
assert(len(weights) == A.shape[0])
978+
979+
if log:
980+
log = {'err': []}
981+
982+
b=np.zeros_like(A[0,:,:])
983+
U=np.ones_like(A)
984+
KV=np.ones_like(A)
985+
threshold = 1e-30 # in order to avoids numerical precision issues
986+
987+
cpt = 0
988+
err=1
989+
990+
# build the convolution operator
991+
t = np.linspace(0,1,A.shape[1])
992+
[Y,X] = np.meshgrid(t,t)
993+
xi1 = np.exp(-(X-Y)**2/reg)
994+
K = lambda x: np.dot(np.dot(xi1,x),xi1)
995+
996+
while (err>stopThr and cpt<numItermax):
997+
998+
bold=b
999+
cpt = cpt +1
1000+
1001+
b=np.zeros_like(A[0,:,:])
1002+
for r in range(A.shape[0]):
1003+
KV[r,:,:]=K(A[r,:,:]/np.maximum(threshold,K(U[r,:,:])))
1004+
b += weights[r] * np.log(np.maximum(threshold, U[r,:,:]*KV[r,:,:]))
1005+
b = np.exp(b)
1006+
for r in range(A.shape[0]):
1007+
U[r,:,:]=b/np.maximum(threshold,KV[r,:,:])
1008+
1009+
if cpt%10==1:
1010+
err=np.sum(np.abs(bold-b))
1011+
# log and verbose print
1012+
if log:
1013+
log['err'].append(err)
1014+
1015+
if verbose:
1016+
if cpt%200 ==0:
1017+
print('{:5s}|{:12s}'.format('It.','Err')+'\n'+'-'*19)
1018+
print('{:5d}|{:8e}|'.format(cpt,err))
1019+
1020+
if log:
1021+
log['niter']=cpt
1022+
log['U']=U
1023+
return b,log
1024+
else:
1025+
return b
1026+
9211027

9221028
def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000,
9231029
stopThr=1e-3, verbose=False, log=False):

0 commit comments

Comments
 (0)