Skip to content

Commit 9fb56be

Browse files
committed
Merge branch 'master' into prV0.5
2 parents 697bd55 + c9b99df commit 9fb56be

File tree

13 files changed

+337
-511
lines changed

13 files changed

+337
-511
lines changed

Makefile

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11

22

33
PYTHON=python3
4+
branch := $(shell git symbolic-ref --short -q HEAD)
45

56
help :
67
@echo "The following make targets are available:"
@@ -57,6 +58,16 @@ rdoc :
5758
notebook :
5859
ipython notebook --matplotlib=inline --notebook-dir=notebooks/
5960

61+
bench :
62+
@git stash >/dev/null 2>&1
63+
@echo 'Branch master'
64+
@git checkout master >/dev/null 2>&1
65+
python3 $(script)
66+
@echo 'Branch $(branch)'
67+
@git checkout $(branch) >/dev/null 2>&1
68+
python3 $(script)
69+
@git stash apply >/dev/null 2>&1
70+
6071
autopep8 :
6172
autopep8 -ir test ot examples --jobs -1
6273

README.md

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,14 @@ It provides the following solvers:
1818
* Entropic regularization OT solver with Sinkhorn Knopp Algorithm [2] and stabilized version [9][10] with optional GPU implementation (requires cudamat).
1919
* Smooth optimal transport solvers (dual and semi-dual) for KL and squared L2 regularizations [17].
2020
* Non regularized Wasserstein barycenters [16] with LP solver (only small scale).
21-
* Non regularized free support Wasserstein barycenters [20].
2221
* Bregman projections for Wasserstein barycenter [3] and unmixing [4].
2322
* Optimal transport for domain adaptation with group lasso regularization [5]
2423
* Conditional gradient [6] and Generalized conditional gradient for regularized OT [7].
2524
* Linear OT [14] and Joint OT matrix and mapping estimation [8].
2625
* Wasserstein Discriminant Analysis [11] (requires autograd + pymanopt).
2726
* Gromov-Wasserstein distances and barycenters ([13] and regularized [12])
2827
* Stochastic Optimization for Large-scale Optimal Transport (semi-dual problem [18] and dual problem [19])
28+
* Non regularized free support Wasserstein barycenters [20].
2929

3030
Some demonstrations (both in Python and Jupyter Notebook format) are available in the examples folder.
3131

@@ -165,7 +165,7 @@ The contributors to this library are:
165165
* [Stanislas Chambon](https://slasnista.github.io/)
166166
* [Antoine Rolet](https://arolet.github.io/)
167167
* Erwan Vautier (Gromov-Wasserstein)
168-
* [Kilian Fatras](https://kilianfatras.github.io/) (Stochastic optimization)
168+
* [Kilian Fatras](https://kilianfatras.github.io/)
169169

170170
This toolbox benefit a lot from open source research and we would like to thank the following persons for providing some code (in various languages):
171171

@@ -224,8 +224,10 @@ You can also post bug reports and feature requests in Github issues. Make sure t
224224

225225
[17] Blondel, M., Seguy, V., & Rolet, A. (2018). [Smooth and Sparse Optimal Transport](https://arxiv.org/abs/1710.06276). Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics (AISTATS).
226226

227-
[18] Genevay, A., Cuturi, M., Peyré, G. & Bach, F. (2016) [Stochastic Optimization for Large-scale Optimal Transport](arXiv preprint arxiv:1605.08527). Advances in Neural Information Processing Systems (2016).
227+
[18] Genevay, A., Cuturi, M., Peyré, G. & Bach, F. (2016) [Stochastic Optimization for Large-scale Optimal Transport](https://arxiv.org/abs/1605.08527). Advances in Neural Information Processing Systems (2016).
228228

229229
[19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A.& Blondel, M. [Large-scale Optimal Transport and Mapping Estimation](https://arxiv.org/pdf/1711.02283.pdf). International Conference on Learning Representation (2018)
230230

231-
[20] Cuturi, M. and Doucet, A. (2014) [Fast Computation of Wasserstein Barycenters](http://proceedings.mlr.press/v32/cuturi14.html). International Conference in Machine Learning
231+
[20] Cuturi, M. and Doucet, A. (2014) [Fast Computation of Wasserstein Barycenters](http://proceedings.mlr.press/v32/cuturi14.html). International Conference in Machine Learning
232+
233+
[21] Solomon, J., De Goes, F., Peyré, G., Cuturi, M., Butscher, A., Nguyen, A. & Guibas, L. (2015). [Convolutional wasserstein distances: Efficient optimal transportation on geometric domains](https://dl.acm.org/citation.cfm?id=2766963). ACM Transactions on Graphics (TOG), 34(4), 66.

data/duck.png

4.99 KB
Loading

data/heart.png

5.1 KB
Loading

data/redcross.png

1.64 KB
Loading

data/tooth.png

4.82 KB
Loading
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
2+
#%%
3+
# -*- coding: utf-8 -*-
4+
"""
5+
============================================
6+
Convolutional Wasserstein Barycenter example
7+
============================================
8+
9+
This example is designed to illustrate how the Convolutional Wasserstein Barycenter
10+
function of POT works.
11+
"""
12+
13+
# Author: Nicolas Courty <ncourty@irisa.fr>
14+
#
15+
# License: MIT License
16+
17+
18+
import numpy as np
19+
import pylab as pl
20+
import ot
21+
22+
##############################################################################
23+
# Data preparation
24+
# ----------------
25+
#
26+
# The four distributions are constructed from 4 simple images
27+
28+
29+
f1 = 1 - pl.imread('../data/redcross.png')[:, :, 2]
30+
f2 = 1 - pl.imread('../data/duck.png')[:, :, 2]
31+
f3 = 1 - pl.imread('../data/heart.png')[:, :, 2]
32+
f4 = 1 - pl.imread('../data/tooth.png')[:, :, 2]
33+
34+
A = []
35+
f1 = f1 / np.sum(f1)
36+
f2 = f2 / np.sum(f2)
37+
f3 = f3 / np.sum(f3)
38+
f4 = f4 / np.sum(f4)
39+
A.append(f1)
40+
A.append(f2)
41+
A.append(f3)
42+
A.append(f4)
43+
A = np.array(A)
44+
45+
nb_images = 5
46+
47+
# those are the four corners coordinates that will be interpolated by bilinear
48+
# interpolation
49+
v1 = np.array((1, 0, 0, 0))
50+
v2 = np.array((0, 1, 0, 0))
51+
v3 = np.array((0, 0, 1, 0))
52+
v4 = np.array((0, 0, 0, 1))
53+
54+
55+
##############################################################################
56+
# Barycenter computation and visualization
57+
# ----------------------------------------
58+
#
59+
60+
pl.figure(figsize=(10, 10))
61+
pl.title('Convolutional Wasserstein Barycenters in POT')
62+
cm = 'Blues'
63+
# regularization parameter
64+
reg = 0.004
65+
for i in range(nb_images):
66+
for j in range(nb_images):
67+
pl.subplot(nb_images, nb_images, i * nb_images + j + 1)
68+
tx = float(i) / (nb_images - 1)
69+
ty = float(j) / (nb_images - 1)
70+
71+
# weights are constructed by bilinear interpolation
72+
tmp1 = (1 - tx) * v1 + tx * v2
73+
tmp2 = (1 - tx) * v3 + tx * v4
74+
weights = (1 - ty) * tmp1 + ty * tmp2
75+
76+
if i == 0 and j == 0:
77+
pl.imshow(f1, cmap=cm)
78+
pl.axis('off')
79+
elif i == 0 and j == (nb_images - 1):
80+
pl.imshow(f3, cmap=cm)
81+
pl.axis('off')
82+
elif i == (nb_images - 1) and j == 0:
83+
pl.imshow(f2, cmap=cm)
84+
pl.axis('off')
85+
elif i == (nb_images - 1) and j == (nb_images - 1):
86+
pl.imshow(f4, cmap=cm)
87+
pl.axis('off')
88+
else:
89+
# call to barycenter computation
90+
pl.imshow(ot.bregman.convolutional_barycenter2d(A, reg, weights), cmap=cm)
91+
pl.axis('off')
92+
pl.show()

ot/bregman.py

Lines changed: 115 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -350,7 +350,6 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
350350
np.exp(K, out=K)
351351

352352
# print(np.min(K))
353-
tmp = np.empty(K.shape, dtype=M.dtype)
354353
tmp2 = np.empty(b.shape, dtype=M.dtype)
355354

356355
Kp = (1 / a).reshape(-1, 1) * K
@@ -359,6 +358,7 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
359358
while (err > stopThr and cpt < numItermax):
360359
uprev = u
361360
vprev = v
361+
362362
KtransposeU = np.dot(K.T, u)
363363
v = np.divide(b, KtransposeU)
364364
u = 1. / np.dot(Kp, v)
@@ -379,11 +379,9 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
379379
err = np.sum((u - uprev)**2) / np.sum((u)**2) + \
380380
np.sum((v - vprev)**2) / np.sum((v)**2)
381381
else:
382-
np.multiply(u.reshape(-1, 1), K, out=tmp)
383-
np.multiply(tmp, v.reshape(1, -1), out=tmp)
384-
np.sum(tmp, axis=0, out=tmp2)
385-
tmp2 -= b
386-
err = np.linalg.norm(tmp2)**2
382+
# compute right marginal tmp2= (diag(u)Kdiag(v))^T1
383+
np.einsum('i,ij,j->j', u, K, v, out=tmp2)
384+
err = np.linalg.norm(tmp2 - b)**2 # violation of marginal
387385
if log:
388386
log['err'].append(err)
389387

@@ -398,10 +396,7 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
398396
log['v'] = v
399397

400398
if nbb: # return only loss
401-
res = np.zeros((nbb))
402-
for i in range(nbb):
403-
res[i] = np.sum(
404-
u[:, i].reshape((-1, 1)) * K * v[:, i].reshape((1, -1)) * M)
399+
res = np.einsum('ik,ij,jk,ij->k', u, K, v, M)
405400
if log:
406401
return res, log
407402
else:
@@ -924,6 +919,116 @@ def barycenter(A, M, reg, weights=None, numItermax=1000,
924919
return geometricBar(weights, UKv)
925920

926921

922+
def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, stopThr=1e-9, stabThr=1e-30, verbose=False, log=False):
923+
"""Compute the entropic regularized wasserstein barycenter of distributions A
924+
where A is a collection of 2D images.
925+
926+
The function solves the following optimization problem:
927+
928+
.. math::
929+
\mathbf{a} = arg\min_\mathbf{a} \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i)
930+
931+
where :
932+
933+
- :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see ot.bregman.sinkhorn)
934+
- :math:`\mathbf{a}_i` are training distributions (2D images) in the mast two dimensions of matrix :math:`\mathbf{A}`
935+
- reg is the regularization strength scalar value
936+
937+
The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [21]_
938+
939+
Parameters
940+
----------
941+
A : np.ndarray (n,w,h)
942+
n distributions (2D images) of size w x h
943+
reg : float
944+
Regularization term >0
945+
weights : np.ndarray (n,)
946+
Weights of each image on the simplex (barycentric coodinates)
947+
numItermax : int, optional
948+
Max number of iterations
949+
stopThr : float, optional
950+
Stop threshol on error (>0)
951+
stabThr : float, optional
952+
Stabilization threshold to avoid numerical precision issue
953+
verbose : bool, optional
954+
Print information along iterations
955+
log : bool, optional
956+
record log if True
957+
958+
959+
Returns
960+
-------
961+
a : (w,h) ndarray
962+
2D Wasserstein barycenter
963+
log : dict
964+
log dictionary return only if log==True in parameters
965+
966+
967+
References
968+
----------
969+
970+
.. [21] Solomon, J., De Goes, F., Peyré, G., Cuturi, M., Butscher, A., Nguyen, A. & Guibas, L. (2015).
971+
Convolutional wasserstein distances: Efficient optimal transportation on geometric domains
972+
ACM Transactions on Graphics (TOG), 34(4), 66
973+
974+
975+
"""
976+
977+
if weights is None:
978+
weights = np.ones(A.shape[0]) / A.shape[0]
979+
else:
980+
assert(len(weights) == A.shape[0])
981+
982+
if log:
983+
log = {'err': []}
984+
985+
b = np.zeros_like(A[0, :, :])
986+
U = np.ones_like(A)
987+
KV = np.ones_like(A)
988+
989+
cpt = 0
990+
err = 1
991+
992+
# build the convolution operator
993+
t = np.linspace(0, 1, A.shape[1])
994+
[Y, X] = np.meshgrid(t, t)
995+
xi1 = np.exp(-(X - Y)**2 / reg)
996+
997+
def K(x):
998+
return np.dot(np.dot(xi1, x), xi1)
999+
1000+
while (err > stopThr and cpt < numItermax):
1001+
1002+
bold = b
1003+
cpt = cpt + 1
1004+
1005+
b = np.zeros_like(A[0, :, :])
1006+
for r in range(A.shape[0]):
1007+
KV[r, :, :] = K(A[r, :, :] / np.maximum(stabThr, K(U[r, :, :])))
1008+
b += weights[r] * np.log(np.maximum(stabThr, U[r, :, :] * KV[r, :, :]))
1009+
b = np.exp(b)
1010+
for r in range(A.shape[0]):
1011+
U[r, :, :] = b / np.maximum(stabThr, KV[r, :, :])
1012+
1013+
if cpt % 10 == 1:
1014+
err = np.sum(np.abs(bold - b))
1015+
# log and verbose print
1016+
if log:
1017+
log['err'].append(err)
1018+
1019+
if verbose:
1020+
if cpt % 200 == 0:
1021+
print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
1022+
print('{:5d}|{:8e}|'.format(cpt, err))
1023+
1024+
if log:
1025+
log['niter'] = cpt
1026+
log['U'] = U
1027+
return b, log
1028+
else:
1029+
return b
1030+
1031+
9271032
def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000,
9281033
stopThr=1e-3, verbose=False, log=False):
9291034
"""

0 commit comments

Comments
 (0)