Skip to content

Commit fa98906

Browse files
committed
Reame +pep8
1 parent 63bbeb3 commit fa98906

File tree

5 files changed

+190
-174
lines changed

5 files changed

+190
-174
lines changed

README.md

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,3 +222,17 @@ You can also post bug reports and feature requests in Github issues. Make sure t
222222
[16] Agueh, M., & Carlier, G. (2011). [Barycenters in the Wasserstein space](https://hal.archives-ouvertes.fr/hal-00637399/document). SIAM Journal on Mathematical Analysis, 43(2), 904-924.
223223

224224
[17] Blondel, M., Seguy, V., & Rolet, A. (2018). [Smooth and Sparse Optimal Transport](https://arxiv.org/abs/1710.06276). Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics (AISTATS).
225+
226+
[18] Genevay, A., Cuturi, M., Peyré, G. & Bach, F. (2016) [Stochastic Optimization for Large-scale Optimal Transport](https://arxiv.org/abs/1605.08527). Advances in Neural Information Processing Systems (2016).
227+
228+
[19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A. & Blondel, M. [Large-scale Optimal Transport and Mapping Estimation](https://arxiv.org/pdf/1711.02283.pdf). International Conference on Learning Representations (2018)
229+
230+
[20] Cuturi, M. and Doucet, A. (2014) [Fast Computation of Wasserstein Barycenters](http://proceedings.mlr.press/v32/cuturi14.html). International Conference on Machine Learning
231+
232+
[21] Solomon, J., De Goes, F., Peyré, G., Cuturi, M., Butscher, A., Nguyen, A. & Guibas, L. (2015). [Convolutional wasserstein distances: Efficient optimal transportation on geometric domains](https://dl.acm.org/citation.cfm?id=2766963). ACM Transactions on Graphics (TOG), 34(4), 66.
233+
234+
[22] J. Altschuler, J. Weed, P. Rigollet, (2017) [Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration](https://papers.nips.cc/paper/6792-near-linear-time-approximation-algorithms-for-optimal-transport-via-sinkhorn-iteration.pdf), Advances in Neural Information Processing Systems (NIPS) 31
235+
236+
[23] Genevay, A., Peyré, G., Cuturi, M., [Learning Generative Models with Sinkhorn Divergences](https://arxiv.org/abs/1706.00292), Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics, (AISTATS) 21, 2018
237+
238+
[24] Vayer, T., Chapel, L., Flamary, R., Tavenard, R. and Courty, N. (2019). [Optimal Transport for structured data with application on graphs](http://proceedings.mlr.press/v97/titouan19a.html) Proceedings of the 36th International Conference on Machine Learning (ICML).

examples/plot_barycenter_fgw.py

Lines changed: 75 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,11 @@
3030
from ot.gromov import fgw_barycenters
3131
#%% Graph functions
3232

33-
def find_thresh(C,inf=0.5,sup=3,step=10):
33+
34+
def find_thresh(C, inf=0.5, sup=3, step=10):
    """ Search the threshold that best turns the structure matrix C into a graph

    "step" candidate thresholds are taken on a regular grid between "inf" and "sup".
    For each candidate, C is thresholded into an adjacency matrix (values below the
    candidate are considered as connections), the shortest path matrix of the
    resulting graph is computed and compared with C. The selected threshold is the
    one with the smallest reconstruction error, measured by the norm of the difference.

    Parameters
    ----------
    C : ndarray, shape (n_nodes,n_nodes)
        The structure matrix to threshold
    inf : float
        The beginning of the linesearch
    sup : float
        The end of the linesearch
    step : integer
        Number of thresholds tested

    Returns
    -------
    thresh : float
        The tested threshold with the smallest reconstruction error
    dist : list of float
        The reconstruction error of every tested threshold, in the order of the linesearch
    """
    candidates = np.linspace(inf, sup, step)
    errors = []
    for upper in candidates:
        adjacency = sp_to_adjency(C, 0, upper)
        geodesic = shortest_path(adjacency, method='D')
        # Unreachable pairs have an infinite distance: replace it by a large
        # finite value so that the norm below stays finite.
        geodesic[geodesic == np.inf] = 100
        errors.append(np.linalg.norm(geodesic - C))
    best = np.argmin(errors)
    return candidates[best], errors
58+
59+
60+
def sp_to_adjency(C, threshinf=0.2, threshsup=1.8):
    """ Thresholds the structure matrix in order to compute an adjacency matrix.

    Off-diagonal values strictly between threshinf and threshsup are considered
    as representing connected nodes and set to 1. Every other entry, including
    the whole diagonal (no self-loop), is set to 0.

    Parameters
    ----------
    C : array-like, shape (n_nodes,n_nodes)
        The structure matrix to threshold. It is not modified.
    threshinf : float
        Exclusive lower bound of the values considered as connections
    threshsup : float
        Exclusive upper bound of the values considered as connections

    Returns
    -------
    C : ndarray, shape (n_nodes,n_nodes)
        The threshold matrix. Each element is in {0,1}
    """
    C = np.asarray(C)
    # Explicit mask instead of clipping: clipping lifted the values below
    # threshinf (and the zeroed diagonal) up to threshinf, so they ended up
    # marked as connections whenever threshinf > 0.
    connected = (C > threshinf) & (C < threshsup)
    np.fill_diagonal(connected, False)
    adjacency = np.zeros(C.shape, dtype=np.result_type(C, threshinf, threshsup))
    adjacency[connected] = 1
    return adjacency
84+
85+
86+
def build_noisy_circular_graph(N=20, mu=0, sigma=0.3, with_noise=False, structure_noise=False, p=None):
    """ Create a noisy circular graph

    The graph is a cycle over the nodes 0, ..., N. Node i carries the feature
    sin(2*pi*i/N) in its 'attr_name' attribute, optionally perturbed by gaussian noise.

    Parameters
    ----------
    N : integer
        Index of the last node of the cycle (the graph has N + 1 nodes)
    mu : float
        Mean of the gaussian noise added to the features
    sigma : float
        Standard deviation of the gaussian noise added to the features
    with_noise : bool
        If True, the gaussian noise is added to the feature of every node
    structure_noise : bool
        If True, each of the nodes 0, ..., N - 1 gets an extra edge to a node
        two steps further with probability 1/p
    p : integer
        Inverse of the probability of adding an extra edge. Required when
        structure_noise is True.

    Returns
    -------
    g : networkx.Graph
        The noisy circular graph

    Raises
    ------
    ValueError
        If structure_noise is True and p is not given
    """
    if structure_noise and p is None:
        raise ValueError("p must be given when structure_noise is True")

    g = nx.Graph()
    g.add_nodes_from(range(N))
    for i in range(N):
        # The noise is drawn even when it is not used so that the sequence of
        # random numbers does not depend on with_noise. The sample is indexed
        # explicitly because converting a 1-element array to a float is
        # deprecated in recent NumPy versions.
        noise = float(np.random.normal(mu, sigma, 1)[0])
        if with_noise:
            g.add_node(i, attr_name=math.sin((2 * i * math.pi / N)) + noise)
        else:
            g.add_node(i, attr_name=math.sin(2 * i * math.pi / N))
        g.add_edge(i, i + 1)
        if structure_noise:
            randomint = np.random.randint(0, p)
            if randomint == 0:
                # Extra chord two steps ahead, wrapping around for the last nodes
                if i <= N - 3:
                    g.add_edge(i, i + 2)
                if i == N - 2:
                    g.add_edge(i, 0)
                if i == N - 1:
                    g.add_edge(i, 1)
    # Close the cycle and set the feature of the last node
    g.add_edge(N, 0)
    noise = float(np.random.normal(mu, sigma, 1)[0])
    if with_noise:
        g.add_node(N, attr_name=math.sin((2 * N * math.pi / N)) + noise)
    else:
        g.add_node(N, attr_name=math.sin(2 * N * math.pi / N))
    return g
111114

112-
def graph_colors(nx_graph,vmin=0,vmax=7):
113-
cnorm = mcol.Normalize(vmin=vmin,vmax=vmax)
114-
cpick = cm.ScalarMappable(norm=cnorm,cmap='viridis')
115+
116+
def graph_colors(nx_graph, vmin=0, vmax=7):
    """ Compute the color of each node of a graph from its 'attr_name' attribute

    Parameters
    ----------
    nx_graph : networkx graph
        Graph whose nodes all carry an 'attr_name' attribute
    vmin : float
        Attribute value mapped to the lower end of the 'viridis' colormap
    vmax : float
        Attribute value mapped to the upper end of the 'viridis' colormap

    Returns
    -------
    colors : list
        One RGBA color per node, in the order of nx_graph.nodes()
    """
    normalizer = mcol.Normalize(vmin=vmin, vmax=vmax)
    mapper = cm.ScalarMappable(norm=normalizer, cmap='viridis')
    mapper.set_array([])
    features = nx.get_node_attributes(nx_graph, 'attr_name')
    rgba_by_node = {node: mapper.to_rgba(value) for node, value in features.items()}
    return [rgba_by_node[node] for node in nx_graph.nodes()]
123-
127+
124128
#%% create dataset
# We build a dataset of noisy circular graphs.
# Noise is added on the structures by random connections and on the features by gaussian noise.

np.random.seed(30)
X0 = []
for k in range(9):
    X0.append(build_noisy_circular_graph(np.random.randint(15, 25), with_noise=True, structure_noise=True, p=3))

#%% Plot dataset

plt.figure(figsize=(8, 10))
for i, g in enumerate(X0):
    plt.subplot(3, 3, i + 1)
    pos = nx.kamada_kawai_layout(g)
    nx.draw(g, pos=pos, node_color=graph_colors(g, vmin=-1, vmax=1), with_labels=False, node_size=100)
plt.suptitle('Dataset of noisy graphs. Color indicates the label', fontsize=20)
plt.show()


#%%
# We compute the barycenter using FGW. Structure matrices are computed using the shortest_path distance in the graph
# Features distances are the euclidean distances
Cs = [shortest_path(nx.adjacency_matrix(x)) for x in X0]
ps = [np.ones(len(x.nodes())) / len(x.nodes()) for x in X0]
Ys = [np.array([v for (k, v) in nx.get_node_attributes(x, 'attr_name').items()]).reshape(-1, 1) for x in X0]
lambdas = np.array([np.ones(len(Ys)) / len(Ys)]).ravel()
sizebary = 15  # we choose a barycenter with 15 nodes

#%%

A, C, log = fgw_barycenters(sizebary, Ys, Cs, ps, lambdas, alpha=0.95)

#%%
# The structure matrix of the barycenter is thresholded to get a graph.
# from_numpy_array is used because from_numpy_matrix was removed in NetworkX 3.0.
bary = nx.from_numpy_array(sp_to_adjency(C, threshinf=0, threshsup=find_thresh(C, sup=100, step=100)[0]))
for i, feature in enumerate(A.ravel()):
    bary.add_node(i, attr_name=float(feature))

#%%
pos = nx.kamada_kawai_layout(bary)
nx.draw(bary, pos=pos, node_color=graph_colors(bary, vmin=-1, vmax=1), with_labels=False)
plt.suptitle('Barycenter', fontsize=20)
plt.show()
169-
170-
171-
172-

0 commit comments

Comments
 (0)