Skip to content

Commit b1b514f

Browse files
committed
bary fgw
1 parent 549b95b commit b1b514f

File tree

3 files changed

+180
-8
lines changed

3 files changed

+180
-8
lines changed

examples/plot_barycenter_fgw.py

Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
=================================
4+
Plot graphs' barycenter using FGW
5+
=================================
6+
7+
This example illustrates the computation barycenter of labeled graphs using FGW
8+
9+
Requires networkx >=2
10+
11+
.. [18] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain
12+
and Courty Nicolas
13+
"Optimal Transport for structured data with application on graphs"
14+
International Conference on Machine Learning (ICML). 2019.
15+
16+
"""
17+
18+
# Author: Titouan Vayer <titouan.vayer@irisa.fr>
19+
#
20+
# License: MIT License
21+
22+
#%% load libraries
23+
import numpy as np
24+
import matplotlib.pyplot as plt
25+
import networkx as nx
26+
import math
27+
from scipy.sparse.csgraph import shortest_path
28+
import matplotlib.colors as mcol
29+
from matplotlib import cm
30+
from ot.gromov import fgw_barycenters
31+
#%% Graph functions
32+
33+
def find_thresh(C,inf=0.5,sup=3,step=10):
34+
""" Trick to find the adequate thresholds from where value of the C matrix are considered close enough to say that nodes are connected
35+
Tthe threshold is found by a linesearch between values "inf" and "sup" with "step" thresholds tested.
36+
The optimal threshold is the one which minimizes the reconstruction error between the shortest_path matrix coming from the thresholded adjency matrix
37+
and the original matrix.
38+
Parameters
39+
----------
40+
C : ndarray, shape (n_nodes,n_nodes)
41+
The structure matrix to threshold
42+
inf : float
43+
The beginning of the linesearch
44+
sup : float
45+
The end of the linesearch
46+
step : integer
47+
Number of thresholds tested
48+
"""
49+
dist=[]
50+
search=np.linspace(inf,sup,step)
51+
for thresh in search:
52+
Cprime=sp_to_adjency(C,0,thresh)
53+
SC=shortest_path(Cprime,method='D')
54+
SC[SC==float('inf')]=100
55+
dist.append(np.linalg.norm(SC-C))
56+
return search[np.argmin(dist)],dist
57+
58+
def sp_to_adjency(C,threshinf=0.2,threshsup=1.8):
59+
""" Thresholds the structure matrix in order to compute an adjency matrix.
60+
All values between threshinf and threshsup are considered representing connected nodes and set to 1. Else are set to 0
61+
Parameters
62+
----------
63+
C : ndarray, shape (n_nodes,n_nodes)
64+
The structure matrix to threshold
65+
threshinf : float
66+
The minimum value of distance from which the new value is set to 1
67+
threshsup : float
68+
The maximum value of distance from which the new value is set to 1
69+
Returns
70+
-------
71+
C : ndarray, shape (n_nodes,n_nodes)
72+
The threshold matrix. Each element is in {0,1}
73+
"""
74+
H=np.zeros_like(C)
75+
np.fill_diagonal(H,np.diagonal(C))
76+
C=C-H
77+
C=np.minimum(np.maximum(C,threshinf),threshsup)
78+
C[C==threshsup]=0
79+
C[C!=0]=1
80+
81+
return C
82+
83+
def build_noisy_circular_graph(N=20,mu=0,sigma=0.3,with_noise=False,structure_noise=False,p=None):
84+
""" Create a noisy circular graph
85+
"""
86+
g=nx.Graph()
87+
g.add_nodes_from(list(range(N)))
88+
for i in range(N):
89+
noise=float(np.random.normal(mu,sigma,1))
90+
if with_noise:
91+
g.add_node(i,attr_name=math.sin((2*i*math.pi/N))+noise)
92+
else:
93+
g.add_node(i,attr_name=math.sin(2*i*math.pi/N))
94+
g.add_edge(i,i+1)
95+
if structure_noise:
96+
randomint=np.random.randint(0,p)
97+
if randomint==0:
98+
if i<=N-3:
99+
g.add_edge(i,i+2)
100+
if i==N-2:
101+
g.add_edge(i,0)
102+
if i==N-1:
103+
g.add_edge(i,1)
104+
g.add_edge(N,0)
105+
noise=float(np.random.normal(mu,sigma,1))
106+
if with_noise:
107+
g.add_node(N,attr_name=math.sin((2*N*math.pi/N))+noise)
108+
else:
109+
g.add_node(N,attr_name=math.sin(2*N*math.pi/N))
110+
return g
111+
112+
def graph_colors(nx_graph,vmin=0,vmax=7):
113+
cnorm = mcol.Normalize(vmin=vmin,vmax=vmax)
114+
cpick = cm.ScalarMappable(norm=cnorm,cmap='viridis')
115+
cpick.set_array([])
116+
val_map = {}
117+
for k,v in nx.get_node_attributes(nx_graph,'attr_name').items():
118+
val_map[k]=cpick.to_rgba(v)
119+
colors=[]
120+
for node in nx_graph.nodes():
121+
colors.append(val_map[node])
122+
return colors
123+
124+
#%% create dataset
125+
# We build a dataset of noisy circular graphs.
126+
# Noise is added on the structures by random connections and on the features by gaussian noise.
127+
128+
np.random.seed(30)
129+
X0=[]
130+
for k in range(9):
131+
X0.append(build_noisy_circular_graph(np.random.randint(15,25),with_noise=True,structure_noise=True,p=3))
132+
133+
#%% Plot dataset
134+
135+
plt.figure(figsize=(8,10))
136+
for i in range(len(X0)):
137+
plt.subplot(3,3,i+1)
138+
g=X0[i]
139+
pos=nx.kamada_kawai_layout(g)
140+
nx.draw(g,pos=pos,node_color = graph_colors(g,vmin=-1,vmax=1),with_labels=False,node_size=100)
141+
plt.suptitle('Dataset of noisy graphs. Color indicates the label',fontsize=20)
142+
plt.show()
143+
144+
145+
146+
#%%
147+
# We compute the barycenter using FGW. Structure matrices are computed using the shortest_path distance in the graph
148+
# Features distances are the euclidean distances
149+
Cs=[shortest_path(nx.adjacency_matrix(x)) for x in X0]
150+
ps=[np.ones(len(x.nodes()))/len(x.nodes()) for x in X0]
151+
Ys=[np.array([v for (k,v) in nx.get_node_attributes(x,'attr_name').items()]).reshape(-1,1) for x in X0]
152+
lambdas=np.array([np.ones(len(Ys))/len(Ys)]).ravel()
153+
sizebary=15 # we choose a barycenter with 15 nodes
154+
155+
#%%
156+
157+
A,C,log=fgw_barycenters(sizebary,Ys,Cs,ps,lambdas,alpha=0.95)
158+
159+
#%%
160+
bary=nx.from_numpy_matrix(sp_to_adjency(C,threshinf=0,threshsup=find_thresh(C,sup=100,step=100)[0]))
161+
for i in range(len(A.ravel())):
162+
bary.add_node(i,attr_name=float(A.ravel()[i]))
163+
164+
#%%
165+
pos = nx.kamada_kawai_layout(bary)
166+
nx.draw(bary,pos=pos,node_color = graph_colors(bary,vmin=-1,vmax=1),with_labels=False)
167+
plt.suptitle('Barycenter',fontsize=20)
168+
plt.show()
169+
170+
171+
172+

examples/plot_fgw.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
#!/usr/bin/env python3
21
# -*- coding: utf-8 -*-
32
"""
43
==============================

ot/gromov.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -883,8 +883,9 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun,
883883

884884
return C
885885

886-
def fgw_barycenters(N,Ys,Cs,ps,lambdas,alpha,fixed_structure=False,fixed_features=False,p=None,loss_fun='square_loss',
887-
max_iter=100, tol=1e-9,verbose=False,log=True,init_C=None,init_X=None):
886+
def fgw_barycenters(N,Ys,Cs,ps,lambdas,alpha,fixed_structure=False,fixed_features=False,
887+
p=None,loss_fun='square_loss',max_iter=100, tol=1e-9,
888+
verbose=False,log=True,init_C=None,init_X=None):
888889

889890
"""
890891
Compute the fgw barycenter as presented eq (5) in [3].
@@ -957,7 +958,7 @@ def fgw_barycenters(N,Ys,Cs,ps,lambdas,alpha,fixed_structure=False,fixed_feature
957958
X=np.zeros((N,d))
958959
else:
959960
X = init_X
960-
961+
961962
T=[np.outer(p,q) for q in ps]
962963

963964
# X is N,d
@@ -981,7 +982,7 @@ def fgw_barycenters(N,Ys,Cs,ps,lambdas,alpha,fixed_structure=False,fixed_feature
981982

982983
if not fixed_features:
983984
Ys_temp=[y.T for y in Ys]
984-
X=update_feature_matrix(lambdas,Ys_temp,T,p)
985+
X=update_feature_matrix(lambdas,Ys_temp,T,p).T
985986

986987
# X must be N,d
987988
# Ys must be ns,d
@@ -1024,11 +1025,11 @@ def fgw_barycenters(N,Ys,Cs,ps,lambdas,alpha,fixed_structure=False,fixed_feature
10241025
print('{:5d}|{:8e}|'.format(cpt, err_feature))
10251026

10261027
cpt += 1
1027-
log_['T']=T # ce sont les matrices du barycentre de la target vers les Ys
1028+
log_['T']=T # from target to Ys
10281029
log_['p']=p
1029-
log_['Ms']=Ms #Ms sont de tailles N,ns
1030+
log_['Ms']=Ms #Ms are N,ns
10301031

1031-
return X.T,C,log_
1032+
return X,C,log_
10321033

10331034

10341035
def update_sructure_matrix(p, lambdas, T, Cs):

0 commit comments

Comments
 (0)