|
| 1 | +# -*- coding: utf-8 -*- |
| 2 | +""" |
| 3 | +================================= |
| 4 | +Plot graphs' barycenter using FGW |
| 5 | +================================= |
| 6 | +
|
| 7 | +This example illustrates the computation of the barycenter of labeled graphs using FGW |
| 8 | +
|
| 9 | +Requires networkx >=2 |
| 10 | +
|
| 11 | +.. [18] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain |
| 12 | + and Courty Nicolas |
| 13 | + "Optimal Transport for structured data with application on graphs" |
| 14 | + International Conference on Machine Learning (ICML). 2019. |
| 15 | +
|
| 16 | +""" |
| 17 | + |
| 18 | +# Author: Titouan Vayer <titouan.vayer@irisa.fr> |
| 19 | +# |
| 20 | +# License: MIT License |
| 21 | + |
| 22 | +#%% load libraries |
| 23 | +import numpy as np |
| 24 | +import matplotlib.pyplot as plt |
| 25 | +import networkx as nx |
| 26 | +import math |
| 27 | +from scipy.sparse.csgraph import shortest_path |
| 28 | +import matplotlib.colors as mcol |
| 29 | +from matplotlib import cm |
| 30 | +from ot.gromov import fgw_barycenters |
| 31 | +#%% Graph functions |
| 32 | + |
def find_thresh(C, inf=0.5, sup=3, step=10):
    """Search for the threshold that best turns a structure matrix into a graph.

    ``step`` candidate thresholds, evenly spaced between ``inf`` and ``sup``,
    are tried. For each candidate, ``C`` is binarised with
    ``sp_to_adjency(C, 0, candidate)``, the shortest-path matrix of the
    resulting graph is computed, and its distance (matrix norm) to the
    original ``C`` is recorded. The candidate with the smallest
    reconstruction error wins.

    Parameters
    ----------
    C : ndarray, shape (n_nodes, n_nodes)
        The structure matrix to threshold
    inf : float
        The beginning of the line search
    sup : float
        The end of the line search
    step : integer
        Number of thresholds tested

    Returns
    -------
    thresh : float
        The candidate threshold with the lowest reconstruction error
    dist : list of float
        The reconstruction error of every candidate, in search order
    """
    candidates = np.linspace(inf, sup, step)

    def reconstruction_error(upper):
        adjacency = sp_to_adjency(C, 0, upper)
        sp_dist = shortest_path(adjacency, method='D')
        # Disconnected pairs come back as +inf; cap them at a large finite
        # value so the norm below stays finite and comparable.
        sp_dist[sp_dist == float('inf')] = 100
        return np.linalg.norm(sp_dist - C)

    errors = [reconstruction_error(upper) for upper in candidates]
    best = np.argmin(errors)
    return candidates[best], errors
| 57 | + |
def sp_to_adjency(C, threshinf=0.2, threshsup=1.8):
    """Threshold a structure matrix to obtain an adjacency matrix.

    Off-diagonal entries strictly between ``threshinf`` and ``threshsup``
    are considered to represent connected nodes and are set to 1. Every
    other entry, including the whole diagonal, is set to 0. The input
    array is not modified.

    Parameters
    ----------
    C : ndarray, shape (n_nodes, n_nodes)
        The structure matrix to threshold
    threshinf : float
        Lower bound (exclusive): entries at or below it are set to 0
    threshsup : float
        Upper bound (exclusive): entries at or above it are set to 0

    Returns
    -------
    C : ndarray, shape (n_nodes, n_nodes)
        The thresholded matrix. Each element is in {0, 1}
    """
    C = np.asarray(C)
    # An explicit range mask is used instead of clip-then-test-nonzero:
    # clipping lifted everything below a non-zero ``threshinf`` (diagonal
    # included) up to ``threshinf`` and then marked it as connected.
    connected = (C > threshinf) & (C < threshsup)
    # Never produce self-loops, whatever the sign of ``threshinf``.
    np.fill_diagonal(connected, False)
    return connected.astype(np.result_type(C, threshinf, threshsup))
| 82 | + |
def build_noisy_circular_graph(N=20, mu=0, sigma=0.3, with_noise=False, structure_noise=False, p=None):
    """Create a (possibly noisy) circular graph with a sine-wave node feature.

    The graph has nodes ``0..N`` (that is, ``N + 1`` nodes) chained
    ``i -- i+1`` and closed by the edge ``N -- 0``. Node ``i`` stores
    ``sin(2*pi*i/N)`` under the attribute ``'attr_name'``.

    Parameters
    ----------
    N : int
        Index of the last node; the graph contains ``N + 1`` nodes
    mu : float
        Mean of the Gaussian feature noise
    sigma : float
        Standard deviation of the Gaussian feature noise
    with_noise : bool
        If True, a Gaussian sample is added to every node feature
    structure_noise : bool
        If True, each node of the main loop gets one extra "shortcut" edge
        with probability ``1/p``
    p : int
        Exclusive upper bound of the uniform integer draw that decides
        whether a shortcut edge is added; required when ``structure_noise``
        is True

    Returns
    -------
    g : networkx.Graph
        The generated graph

    Raises
    ------
    ValueError
        If ``structure_noise`` is True and ``p`` is None
    """
    if structure_noise and p is None:
        raise ValueError("p must be provided when structure_noise is True")

    def feature(i):
        # The noise is drawn even when unused so that the random stream
        # does not depend on ``with_noise``. A scalar is drawn directly:
        # converting a 1-element array with float() is deprecated in NumPy.
        noise = float(np.random.normal(mu, sigma))
        value = math.sin(2 * i * math.pi / N)
        return value + noise if with_noise else value

    g = nx.Graph()
    g.add_nodes_from(list(range(N)))
    for i in range(N):
        g.add_node(i, attr_name=feature(i))
        g.add_edge(i, i + 1)
        if structure_noise and np.random.randint(0, p) == 0:
            # Shortcut edge; the last two indices use fixed targets.
            if i <= N - 3:
                g.add_edge(i, i + 2)
            elif i == N - 2:
                g.add_edge(i, 0)
            elif i == N - 1:
                g.add_edge(i, 1)
    # Close the ring and give the last node its feature.
    g.add_edge(N, 0)
    g.add_node(N, attr_name=feature(N))
    return g
| 111 | + |
def graph_colors(nx_graph, vmin=0, vmax=7):
    """Return one RGBA colour per node, derived from its ``'attr_name'`` value.

    Values are normalised with ``vmin``/``vmax`` and mapped through the
    'viridis' colormap.

    Parameters
    ----------
    nx_graph : networkx graph
        Graph whose nodes carry an ``'attr_name'`` attribute
    vmin : float
        Value passed as ``vmin`` to the colour normalisation
    vmax : float
        Value passed as ``vmax`` to the colour normalisation

    Returns
    -------
    colors : list
        RGBA colours, in the iteration order of ``nx_graph.nodes()``
    """
    mapper = cm.ScalarMappable(norm=mcol.Normalize(vmin=vmin, vmax=vmax), cmap='viridis')
    mapper.set_array([])
    node_values = nx.get_node_attributes(nx_graph, 'attr_name')
    rgba_by_node = {node: mapper.to_rgba(value) for node, value in node_values.items()}
    return [rgba_by_node[node] for node in nx_graph.nodes()]
| 123 | + |
#%% create dataset
# We build a dataset of noisy circular graphs.
# Noise is added on the structures by random connections and on the features by Gaussian noise.

np.random.seed(30)
# Nine graphs; the size argument of each one is drawn in [15, 25) right before it is built.
X0 = [
    build_noisy_circular_graph(np.random.randint(15, 25), with_noise=True, structure_noise=True, p=3)
    for _ in range(9)
]
| 132 | + |
#%% Plot dataset

plt.figure(figsize=(8, 10))
# One cell of a 3x3 grid per graph; node colours encode the node feature.
for cell, graph in enumerate(X0, start=1):
    plt.subplot(3, 3, cell)
    layout = nx.kamada_kawai_layout(graph)
    node_colors = graph_colors(graph, vmin=-1, vmax=1)
    nx.draw(graph, pos=layout, node_color=node_colors, with_labels=False, node_size=100)
plt.suptitle('Dataset of noisy graphs. Color indicates the label', fontsize=20)
plt.show()
| 143 | + |
| 144 | + |
| 145 | + |
#%%
# We compute the barycenter using FGW. Structure matrices are computed using the shortest_path distance in the graph
# Features distances are the euclidean distances

# Structure matrices: all-pairs shortest-path distances of each graph
Cs = [shortest_path(nx.adjacency_matrix(graph)) for graph in X0]
# Uniform weights over the nodes of each graph
ps = [np.ones(len(graph.nodes())) / len(graph.nodes()) for graph in X0]
# Node features of each graph as a column vector
Ys = [
    np.array(list(nx.get_node_attributes(graph, 'attr_name').values())).reshape(-1, 1)
    for graph in X0
]
# Every graph gets the same weight in the barycenter
lambdas = np.ones(len(Ys)) / len(Ys)
sizebary = 15  # we choose a barycenter with 15 nodes

#%%

# NOTE(review): unpacking three values assumes fgw_barycenters also returns
# a log as its third item -- confirm against the installed POT version.
A, C, log = fgw_barycenters(sizebary, Ys, Cs, ps, lambdas, alpha=0.95)
| 158 | + |
#%%
# Turn the barycenter structure matrix C into a graph: search the best upper
# threshold, binarise C with it, then attach the barycenter features A to the nodes.
best_thresh = find_thresh(C, sup=100, step=100)[0]
# from_numpy_array is used because from_numpy_matrix was removed in NetworkX 3.0.
bary = nx.from_numpy_array(sp_to_adjency(C, threshinf=0, threshsup=best_thresh))
for node, value in enumerate(A.ravel()):
    bary.add_node(node, attr_name=float(value))

#%%
pos = nx.kamada_kawai_layout(bary)
nx.draw(bary, pos=pos, node_color=graph_colors(bary, vmin=-1, vmax=1), with_labels=False)
plt.suptitle('Barycenter', fontsize=20)
plt.show()
| 169 | + |
| 170 | + |
| 171 | + |
| 172 | + |
0 commit comments