Skip to content

Commit 9f63ee9

Browse files
committed
initial commit partial Wass and GW
1 parent 81e9d42 commit 9f63ee9

File tree

5 files changed

+1239
-19
lines changed

5 files changed

+1239
-19
lines changed

docs/source/all.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,3 +86,9 @@ ot.unbalanced
8686

8787
.. automodule:: ot.unbalanced
8888
:members:
89+
90+
ot.partial
91+
-------------
92+
93+
.. automodule:: ot.partial
94+
:members:

docs/source/readme.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -391,6 +391,14 @@ of the 36th International Conference on Machine Learning (ICML).
391391
`Learning with a Wasserstein Loss <http://cbcl.mit.edu/wasserstein/>`__
392392
Advances in Neural Information Processing Systems (NIPS).
393393

394+
[26] Caffarelli, L. A., McCann, R. J. (2020). `Free boundaries in optimal transport and
395+
Monge-Ampere obstacle problems <http://www.math.toronto.edu/~mccann/papers/annals2010.pdf>`__,
396+
Annals of mathematics, 673-730.
397+
398+
[27] Chapel, L., Alaya, M., Gasso, G. (2019). `Partial Gromov-Wasserstein with Applications
399+
on Positive-Unlabeled Learning <https://arxiv.org/abs/2002.08276>`__. arXiv preprint
400+
arXiv:2002.08276.
401+
394402
.. |PyPI version| image:: https://badge.fury.io/py/POT.svg
395403
:target: https://badge.fury.io/py/POT
396404
.. |Anaconda Cloud| image:: https://anaconda.org/conda-forge/pot/badges/version.svg
Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
==========================
4+
Partial Wasserstein and Gromov-Wasserstein example
5+
==========================
6+
7+
This example is designed to show how to use the Partial (Gromov-)Wassertsein
8+
distance computation in POT.
9+
"""
10+
11+
# Author: Laetitia Chapel <laetitia.chapel@irisa.fr>
12+
# License: MIT License
13+
14+
import scipy as sp
15+
import numpy as np
16+
import matplotlib.pylab as pl
17+
import ot
18+
19+
20+
#############################################################################
21+
#
22+
# Sample two 2D Gaussian distributions and plot them
23+
# --------------------------------------------------
24+
#
25+
# For demonstration purpose, we sample two Gaussian distributions in 2-d
26+
# spaces and add some random noise.
27+
28+
29+
n_samples = 20 # nb samples (gaussian)
30+
n_noise = 20 # nb of samples (noise)
31+
32+
mu = np.array([0, 0])
33+
cov = np.array([[1, 0], [0, 2]])
34+
35+
xs = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov)
36+
xs = np.append(xs, (np.random.rand(n_noise, 2)+1)*4).reshape((-1, 2))
37+
xt = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov)
38+
xt = np.append(xt, (np.random.rand(n_noise, 2)+1)*-3).reshape((-1, 2))
39+
40+
M = sp.spatial.distance.cdist(xs, xt)
41+
42+
fig = pl.figure()
43+
ax1 = fig.add_subplot(131)
44+
ax1.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
45+
ax2 = fig.add_subplot(132)
46+
ax2.scatter(xt[:, 0], xt[:, 1], color='r')
47+
ax3 = fig.add_subplot(133)
48+
ax3.imshow(M)
49+
pl.show()
50+
51+
#############################################################################
52+
#
53+
# Compute partial Wasserstein plans and distance,
54+
# by transporting 50% of the mass
55+
# ----------------------------------------------
56+
57+
p = ot.unif(n_samples + n_noise)
58+
q = ot.unif(n_samples + n_noise)
59+
60+
w0, log0 = ot.partial.partial_wasserstein(p, q, M, m=0.5, log=True)
61+
w, log = ot.partial.entropic_partial_wasserstein(p, q, M, reg=0.1, m=0.5,
62+
log=True)
63+
64+
print('Partial Wasserstein distance (m = 0.5): ' + str(log0['partial_w_dist']))
65+
print('Entropic partial Wasserstein distance (m = 0.5): ' + \
66+
str(log['partial_w_dist']))
67+
68+
pl.figure(1, (10, 5))
69+
pl.subplot(1, 2, 1)
70+
pl.imshow(w0, cmap='jet')
71+
pl.title('Partial Wasserstein')
72+
pl.subplot(1, 2, 2)
73+
pl.imshow(w, cmap='jet')
74+
pl.title('Entropic partial Wasserstein')
75+
pl.show()
76+
77+
78+
#############################################################################
79+
#
80+
# Sample one 2D and 3D Gaussian distributions and plot them
81+
# ---------------------------------------------------------
82+
#
83+
# The Gromov-Wasserstein distance allows to compute distances with samples that
84+
# do not belong to the same metric space. For demonstration purpose, we sample
85+
# two Gaussian distributions in 2- and 3-dimensional spaces.
86+
87+
n_samples = 20 # nb samples
88+
n_noise = 10 # nb of samples (noise)
89+
90+
p = ot.unif(n_samples + n_noise)
91+
q = ot.unif(n_samples + n_noise)
92+
93+
mu_s = np.array([0, 0])
94+
cov_s = np.array([[1, 0], [0, 1]])
95+
96+
mu_t = np.array([0, 0, 0])
97+
cov_t = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
98+
99+
100+
xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s)
101+
xs = np.concatenate((xs, ((np.random.rand(n_noise, 2)+1)*4)), axis=0)
102+
P = sp.linalg.sqrtm(cov_t)
103+
xt = np.random.randn(n_samples, 3).dot(P) + mu_t
104+
xt = np.concatenate((xt, ((np.random.rand(n_noise, 3)+1)*10)), axis=0)
105+
106+
fig = pl.figure()
107+
ax1 = fig.add_subplot(121)
108+
ax1.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
109+
ax2 = fig.add_subplot(122, projection='3d')
110+
ax2.scatter(xt[:, 0], xt[:, 1], xt[:, 2], color='r')
111+
pl.show()
112+
113+
114+
#############################################################################
115+
#
116+
# Compute partial Gromov-Wasserstein plans and distance,
117+
# by transporting 100% and 2/3 of the mass
118+
# -----------------------------------------------------
119+
120+
C1 = sp.spatial.distance.cdist(xs, xs)
121+
C2 = sp.spatial.distance.cdist(xt, xt)
122+
123+
print('-----m = 1')
124+
m = 1
125+
res0, log0 = ot.partial.partial_gromov_wasserstein(C1, C2, p, q, m=m,
126+
log=True)
127+
res, log = ot.partial.entropic_partial_gromov_wasserstein(C1, C2, p, q, 10,
128+
m=m, log=True)
129+
130+
print('Partial Wasserstein distance (m = 1): ' + str(log0['partial_gw_dist']))
131+
print('Entropic partial Wasserstein distance (m = 1): ' + \
132+
str(log['partial_gw_dist']))
133+
134+
pl.figure(1, (10, 5))
135+
pl.title("mass to be transported m = 1")
136+
pl.subplot(1, 2, 1)
137+
pl.imshow(res0, cmap='jet')
138+
pl.title('Partial Wasserstein')
139+
pl.subplot(1, 2, 2)
140+
pl.imshow(res, cmap='jet')
141+
pl.title('Entropic partial Wasserstein')
142+
pl.show()
143+
144+
print('-----m = 2/3')
145+
m = 2/3
146+
res0, log0 = ot.partial.partial_gromov_wasserstein(C1, C2, p, q, m=m, log=True)
147+
res, log = ot.partial.entropic_partial_gromov_wasserstein(C1, C2, p, q, 10,
148+
m=m, log=True)
149+
150+
print('Partial Wasserstein distance (m = 2/3): ' + \
151+
str(log0['partial_gw_dist']))
152+
print('Entropic partial Wasserstein distance (m = 2/3): ' + \
153+
str(log['partial_gw_dist']))
154+
155+
pl.figure(1, (10, 5))
156+
pl.title("mass to be transported m = 2/3")
157+
pl.subplot(1, 2, 1)
158+
pl.imshow(res0, cmap='jet')
159+
pl.title('Partial Wasserstein')
160+
pl.subplot(1, 2, 2)
161+
pl.imshow(res, cmap='jet')
162+
pl.title('Entropic partial Wasserstein')
163+
pl.show()

0 commit comments

Comments
 (0)