Skip to content

Commit 549b95b

Browse files
committed
FGW+gromov changes
1 parent 327b0c6 commit 549b95b

File tree

5 files changed

+546
-22
lines changed

5 files changed

+546
-22
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,3 +219,5 @@ You can also post bug reports and feature requests in Github issues. Make sure t
219219
[16] Agueh, M., & Carlier, G. (2011). [Barycenters in the Wasserstein space](https://hal.archives-ouvertes.fr/hal-00637399/document). SIAM Journal on Mathematical Analysis, 43(2), 904-924.
220220

221221
[17] Blondel, M., Seguy, V., & Rolet, A. (2018). [Smooth and Sparse Optimal Transport](https://arxiv.org/abs/1710.06276). Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics (AISTATS).
222+
223+
[18] Vayer, T., Chapel, L., Flamary, R., Tavenard, R. and Courty, N. (2019). [Optimal Transport for structured data with application on graphs](http://proceedings.mlr.press/v97/titouan19a.html) Proceedings of the 36th International Conference on Machine Learning (ICML).

examples/plot_fgw.py

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
#!/usr/bin/env python3
2+
# -*- coding: utf-8 -*-
3+
"""
4+
==============================
5+
Plot Fused-gromov-Wasserstein
6+
==============================
7+
8+
This example illustrates the computation of FGW for 1D measures[18].
9+
10+
.. [18] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain
11+
and Courty Nicolas
12+
"Optimal Transport for structured data with application on graphs"
13+
International Conference on Machine Learning (ICML). 2019.
14+
15+
"""
16+
17+
# Author: Titouan Vayer <titouan.vayer@irisa.fr>
18+
#
19+
# License: MIT License
20+
21+
import matplotlib.pyplot as pl
22+
import numpy as np
23+
import ot
24+
from ot.gromov import gromov_wasserstein,fused_gromov_wasserstein
25+
26+
#%% parameters
27+
# We create two 1D random measures
28+
n=20
29+
n2=30
30+
sig=1
31+
sig2=0.1
32+
33+
np.random.seed(0)
34+
35+
phi=np.arange(n)[:,None]
36+
xs=phi+sig*np.random.randn(n,1)
37+
ys=np.vstack((np.ones((n//2,1)),0*np.ones((n//2,1))))+sig2*np.random.randn(n,1)
38+
39+
phi2=np.arange(n2)[:,None]
40+
xt=phi2+sig*np.random.randn(n2,1)
41+
yt=np.vstack((np.ones((n2//2,1)),0*np.ones((n2//2,1))))+sig2*np.random.randn(n2,1)
42+
yt= yt[::-1,:]
43+
44+
p=ot.unif(n)
45+
q=ot.unif(n2)
46+
47+
#%% plot the distributions
48+
49+
pl.close(10)
50+
pl.figure(10,(7,7))
51+
52+
pl.subplot(2,1,1)
53+
54+
pl.scatter(ys,xs,c=phi,s=70)
55+
pl.ylabel('Feature value a',fontsize=20)
56+
pl.title('$\mu=\sum_i \delta_{x_i,a_i}$',fontsize=25, usetex=True, y=1)
57+
pl.xticks(())
58+
pl.yticks(())
59+
pl.subplot(2,1,2)
60+
pl.scatter(yt,xt,c=phi2,s=70)
61+
pl.xlabel('coordinates x/y',fontsize=25)
62+
pl.ylabel('Feature value b',fontsize=20)
63+
pl.title('$\\nu=\sum_j \delta_{y_j,b_j}$',fontsize=25, usetex=True, y=1)
64+
pl.yticks(())
65+
pl.tight_layout()
66+
pl.show()
67+
68+
69+
#%% Structure matrices and across-features distance matrix
70+
C1=ot.dist(xs)
71+
C2=ot.dist(xt).T
72+
M=ot.dist(ys,yt)
73+
w1=ot.unif(C1.shape[0])
74+
w2=ot.unif(C2.shape[0])
75+
Got=ot.emd([],[],M)
76+
77+
#%%
78+
cmap='Reds'
79+
pl.close(10)
80+
pl.figure(10,(5,5))
81+
fs=15
82+
l_x=[0,5,10,15]
83+
l_y=[0,5,10,15,20,25]
84+
gs = pl.GridSpec(5, 5)
85+
86+
ax1=pl.subplot(gs[3:,:2])
87+
88+
pl.imshow(C1,cmap=cmap,interpolation='nearest')
89+
pl.title("$C_1$",fontsize=fs)
90+
pl.xlabel("$k$",fontsize=fs)
91+
pl.ylabel("$i$",fontsize=fs)
92+
pl.xticks(l_x)
93+
pl.yticks(l_x)
94+
95+
ax2=pl.subplot(gs[:3,2:])
96+
97+
pl.imshow(C2,cmap=cmap,interpolation='nearest')
98+
pl.title("$C_2$",fontsize=fs)
99+
pl.ylabel("$l$",fontsize=fs)
100+
#pl.ylabel("$l$",fontsize=fs)
101+
pl.xticks(())
102+
pl.yticks(l_y)
103+
ax2.set_aspect('auto')
104+
105+
ax3=pl.subplot(gs[3:,2:],sharex=ax2,sharey=ax1)
106+
pl.imshow(M,cmap=cmap,interpolation='nearest')
107+
pl.yticks(l_x)
108+
pl.xticks(l_y)
109+
pl.ylabel("$i$",fontsize=fs)
110+
pl.title("$M_{AB}$",fontsize=fs)
111+
pl.xlabel("$j$",fontsize=fs)
112+
pl.tight_layout()
113+
ax3.set_aspect('auto')
114+
pl.show()
115+
116+
117+
#%% Computing FGW and GW
118+
alpha=1e-3
119+
120+
ot.tic()
121+
Gwg,logw=fused_gromov_wasserstein(M,C1,C2,p,q,loss_fun='square_loss',alpha=alpha,verbose=True,log=True)
122+
ot.toc()
123+
124+
#%reload_ext WGW
125+
Gg,log=gromov_wasserstein(C1,C2,p,q,loss_fun='square_loss',verbose=True,log=True)
126+
127+
#%% visu OT matrix
128+
cmap='Blues'
129+
fs=15
130+
pl.figure(2,(13,5))
131+
pl.clf()
132+
pl.subplot(1,3,1)
133+
pl.imshow(Got,cmap=cmap,interpolation='nearest')
134+
#pl.xlabel("$y$",fontsize=fs)
135+
pl.ylabel("$i$",fontsize=fs)
136+
pl.xticks(())
137+
138+
pl.title('Wasserstein ($M$ only)')
139+
140+
pl.subplot(1,3,2)
141+
pl.imshow(Gg,cmap=cmap,interpolation='nearest')
142+
pl.title('Gromov ($C_1,C_2$ only)')
143+
pl.xticks(())
144+
pl.subplot(1,3,3)
145+
pl.imshow(Gwg,cmap=cmap,interpolation='nearest')
146+
pl.title('FGW ($M+C_1,C_2$)')
147+
148+
pl.xlabel("$j$",fontsize=fs)
149+
pl.ylabel("$i$",fontsize=fs)
150+
151+
pl.tight_layout()
152+
pl.show()

ot/bregman.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
# Author: Remi Flamary <remi.flamary@unice.fr>
77
# Nicolas Courty <ncourty@irisa.fr>
8-
#
8+
# Titouan Vayer <titouan.vayer@irisa.fr>
99
# License: MIT License
1010

1111
import numpy as np

0 commit comments

Comments
 (0)