Skip to content

Commit 25a71f2

Browse files
committed
Added support for Planetoid datasets, and training using the graph edges.
1 parent 7954700 commit 25a71f2

File tree

9 files changed

+1147
-243
lines changed

9 files changed

+1147
-243
lines changed

neural_structured_learning/research/gam/data/dataset.py

Lines changed: 204 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -12,107 +12,106 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
"""Data containers for Graph Agreement Models."""
15-
import abc
1615
import collections
1716
import logging
1817
import os
1918
import pickle
19+
import scipy
2020

2121
import numpy as np
2222
import tensorflow as tf
2323

24-
25-
class Dataset(object):
26-
"""Interface for different types of datasets."""
27-
__metaclass__ = abc.ABCMeta
28-
29-
@abc.abstractmethod
30-
def num_train(self):
31-
pass
32-
33-
@abc.abstractmethod
34-
def num_val(self):
35-
pass
36-
37-
@abc.abstractmethod
38-
def num_test(self):
39-
pass
40-
41-
@abc.abstractmethod
42-
def num_unlabeled(self):
43-
pass
44-
45-
@abc.abstractmethod
46-
def copy_labels(self):
47-
pass
48-
49-
def save_to_pickle(self, file_path):
50-
pickle.dump(self, open(file_path, 'w'))
51-
52-
@staticmethod
53-
def load_from_pickle(file_path):
54-
dataset = pickle.load(open(file_path, 'r'))
55-
return dataset
24+
from gam.data.preprocessing import split_train_val
5625

5726

58-
class FixedDataset(Dataset):
59-
"""A dataset containing features of fixed size.
27+
class Dataset(object):
28+
"""A container for datasets.
6029
6130
In this dataset, each sample has the same number of features.
6231
This class manages different splits of the data for train, validation, test
6332
and unlabeled. These sets of samples are disjoint.
6433
"""
6534

66-
def __init__(self,
67-
x_train,
68-
y_train,
69-
x_val,
70-
y_val,
71-
x_test,
72-
y_test,
73-
x_unlabeled,
74-
y_unlabeled=None,
75-
num_classes=None,
35+
def __init__(self, name, features, labels, indices_train, indices_test,
36+
indices_val, indices_unlabeled, num_classes=None,
7637
feature_preproc_fn=lambda x: x):
77-
n_train = x_train.shape[0]
78-
n_val = x_val.shape[0]
79-
n_test = x_test.shape[0]
80-
n_unlabeled = x_unlabeled.shape[0]
38+
self.name = name
39+
self.features = features
40+
self.labels = labels
41+
42+
self.indices_train = indices_train
43+
self.indices_val = indices_val
44+
self.indices_test = indices_test
45+
self.indices_unlabeled = indices_unlabeled
46+
self.feature_preproc_fn = feature_preproc_fn
8147

82-
if y_unlabeled is None:
83-
y_unlabeled = np.zeros(shape=(n_unlabeled,), dtype=y_train.dtype)
48+
self.num_val = self.indices_val.shape[0]
49+
self.num_test = self.indices_test.shape[0]
8450

85-
# Concatenate samples.
86-
self.features = np.concatenate((x_train, x_val, x_unlabeled, x_test))
87-
self.labels = np.concatenate((y_train, y_val, y_unlabeled, y_test))
51+
self.num_samples = labels.shape[0]
52+
self.features_shape = features.shape[1:]
53+
self.num_features = np.prod(features.shape[1:])
54+
self.num_classes = 1 + max(labels) if num_classes is None else num_classes
8855

89-
self._num_features = np.prod(self.features.shape[1:])
90-
self._num_classes = 1 + max(self.labels) if num_classes is None else \
91-
num_classes
92-
self._num_samples = n_train + n_val + n_unlabeled + n_test
56+
@staticmethod
def build_from_splits(name, inputs_train, labels_train, inputs_val,
                      labels_val, inputs_test, labels_test, inputs_unlabeled,
                      labels_unlabeled=None, num_classes=None,
                      feature_preproc_fn=lambda x: x):
  """Builds a Dataset from separately provided splits.

  The splits are concatenated in the order train, val, unlabeled, test, and
  each split is assigned the contiguous range of indices it occupies in the
  concatenated feature / label arrays.

  Arguments:
    name: Dataset name, forwarded to the Dataset constructor.
    inputs_train: Array of train inputs; the first dimension indexes samples.
    labels_train: Labels of the train samples.
    inputs_val: Array of validation inputs.
    labels_val: Labels of the validation samples.
    inputs_test: Array of test inputs.
    labels_test: Labels of the test samples.
    inputs_unlabeled: Array of unlabeled inputs.
    labels_unlabeled: Optional labels for the unlabeled samples. If None,
      zeros with the dtype of `labels_train` are used as placeholders.
    num_classes: Forwarded to the Dataset constructor.
    feature_preproc_fn: Forwarded to the Dataset constructor.

  Returns:
    A Dataset object containing all the provided samples.
  """
  num_train = inputs_train.shape[0]
  num_val = inputs_val.shape[0]
  num_unlabeled = inputs_unlabeled.shape[0]
  num_test = inputs_test.shape[0]

  if labels_unlabeled is None:
    # Take the dtype from the label array as a whole rather than from its
    # first element, so that an empty train split or a plain Python list of
    # labels does not crash here.
    labels_unlabeled = np.zeros(shape=(num_unlabeled,),
                                dtype=np.asarray(labels_train).dtype)

  features = np.concatenate(
      (inputs_train, inputs_val, inputs_unlabeled, inputs_test))
  labels = np.concatenate(
      (labels_train, labels_val, labels_unlabeled, labels_test))

  # Offsets at which each split starts inside the concatenated arrays.
  start_val = num_train
  start_unlabeled = start_val + num_val
  start_test = start_unlabeled + num_unlabeled
  end_test = start_test + num_test

  indices_train = np.arange(start_val)
  indices_val = np.arange(start_val, start_unlabeled)
  indices_unlabeled = np.arange(start_unlabeled, start_test)
  indices_test = np.arange(start_test, end_test)

  return Dataset(name=name,
                 features=features,
                 labels=labels,
                 indices_train=indices_train,
                 indices_test=indices_test,
                 indices_val=indices_val,
                 indices_unlabeled=indices_unlabeled,
                 num_classes=num_classes,
                 feature_preproc_fn=feature_preproc_fn)
9390

94-
self.indices_train = np.arange(n_train)
95-
self.indices_val = np.arange(n_train, n_train+n_val)
96-
self.indices_unlabeled = np.arange(n_train+n_val, n_train+n_val+n_unlabeled)
97-
self.indices_test = np.arange(n_train+n_val+n_unlabeled, self._num_samples)
91+
@staticmethod
def build_from_features(name, features, labels, indices_train, indices_test,
                        indices_val=None, indices_unlabeled=None,
                        percent_val=0.2, seed=None, num_classes=None,
                        feature_preproc_fn=lambda x: x):
  """Builds a Dataset from full feature / label arrays and split indices.

  Arguments:
    name: Dataset name, forwarded to the Dataset constructor.
    features: Array of features for all samples.
    labels: Array of labels for all samples.
    indices_train: Numpy array with the indices of the train samples.
    indices_test: Indices of the test samples.
    indices_val: Optional indices of the validation samples. If None, a
      validation set is carved out of `indices_train`.
    indices_unlabeled: Indices of the unlabeled samples.
    percent_val: Passed to `split_train_val` when `indices_val` is None.
    seed: Seed of the random number generator used for the train/val split.
    num_classes: Forwarded to the Dataset constructor.
    feature_preproc_fn: Forwarded to the Dataset constructor.

  Returns:
    A Dataset object.
  """
  if indices_val is None:
    rng = np.random.RandomState(seed=seed)
    indices_train = np.asarray(indices_train)
    # The split is computed over positions 0..n-1 and then mapped back through
    # `indices_train`. Using the positions directly as sample indices is only
    # correct when the train samples happen to be the first n samples.
    # NOTE(review): assumes split_train_val returns two integer arrays that
    # partition its first argument -- confirm in gam.data.preprocessing.
    positions_train, positions_val = split_train_val(
        np.arange(indices_train.shape[0]), percent_val, rng)
    indices_val = indices_train[positions_val]
    indices_train = indices_train[positions_train]

  return Dataset(name=name,
                 features=features,
                 labels=labels,
                 indices_train=indices_train,
                 indices_test=indices_test,
                 indices_val=indices_val,
                 indices_unlabeled=indices_unlabeled,
                 num_classes=num_classes,
                 feature_preproc_fn=feature_preproc_fn)
98110

99-
self.feature_preproc_fn = feature_preproc_fn
100111

101112
# Returns a copy (made with np.copy) of the stored labels, not a reference to
# self.labels.
def copy_labels(self):
102113
return np.copy(self.labels)
103114

104-
def update_labels(self, indices_samples, new_labels):
105-
"""Updates the labels of the samples with the provided indices.
106-
107-
Arguments:
108-
indices_samples: A list of integers representing sample indices.
109-
new_labels: A list of integers representing the new labels of th samples
110-
in indices_samples.
111-
"""
112-
indices_samples = np.asarray(indices_samples)
113-
new_labels = np.asarray(new_labels)
114-
self.labels[indices_samples] = new_labels
115-
116115
def get_features(self, indices):
117116
"""Returns the features of the samples with the provided indices."""
118117
f = self.features[indices]
@@ -122,23 +121,6 @@ def get_features(self, indices):
122121
# Returns the stored labels of the samples selected by `indices`.
def get_labels(self, indices):
123122
return self.labels[indices]
124123

125-
@property
126-
def num_features(self):
127-
return self._num_features
128-
129-
@property
130-
def features_shape(self):
131-
"""Returns the shape of the input features, not including batch size."""
132-
return self.features.shape[1:]
133-
134-
@property
135-
def num_classes(self):
136-
return self._num_classes
137-
138-
@property
139-
def num_samples(self):
140-
return self._num_samples
141-
142124
# Returns the number of train samples, i.e. the length of self.indices_train.
def num_train(self):
143125
return self.indices_train.shape[0]
144126

@@ -181,13 +163,134 @@ def label_samples(self, indices_samples, new_labels):
181163
# indices, without checking if they already exist.
182164
self.indices_train = np.concatenate((self.indices_train, indices_samples),
183165
axis=0)
184-
# Remove the recently labeled samples from the unlabeled set.
166+
# Remove the recently labeled samples from the unlabeled set.
185167
indices_samples = set(indices_samples)
186168
self.indices_unlabeled = np.asarray(
187169
[u for u in self.indices_unlabeled if u not in indices_samples])
188170

171+
def update_labels(self, indices, new_labels):
172+
"""Updates the labels of the samples with the provided indices.
173+
174+
Arguments:
175+
indices: A list of integers representing sample indices.
176+
new_labels: A list of integers representing the new labels of th samples
177+
in indices_samples.
178+
"""
179+
indices = np.asarray(indices)
180+
new_labels = np.asarray(new_labels)
181+
self.labels[indices] = new_labels
182+
183+
def save_to_pickle(self, file_path):
  """Serializes this object with pickle and writes it to `file_path`.

  Arguments:
    file_path: Path of the file to write. An existing file is overwritten.
  """
  # Pickle streams are bytes, so the file must be opened in binary mode. The
  # context manager makes sure the handle is flushed and closed.
  with open(file_path, 'wb') as f:
    pickle.dump(self, f)
189185

190-
class CotrainDataset(Dataset):
186+
@staticmethod
def load_from_pickle(file_path):
  """Loads and returns the object pickled at `file_path`.

  Arguments:
    file_path: Path of a file containing a pickled object.

  Returns:
    The unpickled object.
  """
  # SECURITY: unpickling can execute arbitrary code. Only load files that come
  # from a trusted source.
  # Pickle streams are bytes, so the file must be opened in binary mode.
  with open(file_path, 'rb') as f:
    return pickle.load(f)
190+
191+
192+
class GraphDataset(Dataset):
  """Data container for SSL datasets.

  In addition to what the parent Dataset stores, this container keeps a list
  of graph edges between samples.
  """

  class Edge(object):
    """An edge from sample `src` to sample `tgt`, with an optional weight."""

    def __init__(self, src, tgt, weight=None):
      self.src = src
      self.tgt = tgt
      self.weight = weight

  def __init__(self, name, features, labels, edges, indices_train, indices_test,
               indices_val=None, indices_unlabeled=None, percent_val=0.2,
               seed=None, num_classes=None, feature_preproc_fn=lambda x: x):
    """Creates a GraphDataset.

    Arguments:
      name: Dataset name, forwarded to the parent constructor.
      features: Array of features for all samples.
      labels: Array of labels for all samples.
      edges: Iterable of Edge objects; stored as `self.edges`.
      indices_train: Numpy array with the indices of the train samples.
      indices_test: Indices of the test samples.
      indices_val: Optional indices of the validation samples. If None, a
        validation set is carved out of `indices_train`.
      indices_unlabeled: Indices of the unlabeled samples.
      percent_val: Passed to `split_train_val` when `indices_val` is None.
      seed: Seed of the random number generator used for the train/val split.
      num_classes: Forwarded to the parent constructor.
      feature_preproc_fn: Forwarded to the parent constructor.
    """
    self.edges = edges

    if indices_val is None:
      rng = np.random.RandomState(seed=seed)
      indices_train = np.asarray(indices_train)
      # The split is computed over positions 0..n-1 and then mapped back
      # through `indices_train`. Using the positions directly as sample
      # indices is only correct when the train samples are the first n ones.
      # NOTE(review): assumes split_train_val returns two integer arrays that
      # partition its first argument -- confirm in gam.data.preprocessing.
      positions_train, positions_val = split_train_val(
          np.arange(indices_train.shape[0]), percent_val, rng)
      indices_val = indices_train[positions_val]
      indices_train = indices_train[positions_train]

    super().__init__(
        name=name,
        features=features,
        labels=labels,
        indices_train=indices_train,
        indices_test=indices_test,
        indices_val=indices_val,
        indices_unlabeled=indices_unlabeled,
        num_classes=num_classes,
        feature_preproc_fn=feature_preproc_fn)

  def get_edges(self, src_labeled=None, tgt_labeled=None,
                label_must_match=False):
    """Returns the edges that satisfy the requested constraints.

    A sample counts as labeled if its index is among
    `self.get_indices_train()`.

    Arguments:
      src_labeled: If None, the source node is not constrained. Otherwise a
        boolean; only edges whose source labeled-state equals it are kept.
      tgt_labeled: Same as `src_labeled`, but for the target node.
      label_must_match: If True, only edges whose two endpoints currently have
        equal stored labels are kept.

    Returns:
      A list with the selected elements of `self.edges`, in their original
      order.
    """
    labeled_mask = np.full((self.num_samples,), False)
    labeled_mask[self.get_indices_train()] = True

    def _labeled_cond(idx, is_labeled):
      return (is_labeled is None) or (is_labeled == labeled_mask[idx])

    def _agreement_cond(edge):
      return self.get_labels(edge.src) == self.get_labels(edge.tgt)

    # When labels need not match, accept every edge on this criterion.
    agreement_cond = _agreement_cond if label_must_match else lambda e: True

    return [e for e in self.edges
            if _labeled_cond(e.src, src_labeled)
            and _labeled_cond(e.tgt, tgt_labeled)
            and agreement_cond(e)]
238+
239+
240+
class PlanetoidDataset(GraphDataset):
  """Data container for Planetoid datasets."""

  def __init__(self, name, adj, features, train_mask, val_mask, test_mask,
               labels, row_normalize=False):
    """Creates a PlanetoidDataset.

    Arguments:
      name: Dataset name, forwarded to the parent constructor.
      adj: Adjacency matrix, in any form accepted by
        `scipy.sparse.coo_matrix`. Every stored entry becomes one edge whose
        weight is the entry value.
      features: Sparse feature matrix (it must support `.todense()`), one row
        per node.
      train_mask: Mask over nodes selecting the train samples.
      val_mask: Mask over nodes selecting the validation samples.
      test_mask: Mask over nodes selecting the test samples.
      labels: Per-node label scores; the class of a node is the argmax over
        the last axis (presumably one-hot labels -- confirm with the caller).
      row_normalize: If True, each feature row is divided by its sum.
    """
    # Imported here because the file-level `import scipy` does not guarantee
    # that the `scipy.sparse` submodule is loaded.
    import scipy.sparse  # pylint: disable=g-import-not-at-top

    # Extract train, val, test, unlabeled indices.
    train_indices = np.where(train_mask)[0]
    test_indices = np.where(test_mask)[0]
    val_indices = np.where(val_mask)[0]
    # Nodes that are in none of the three masks are treated as unlabeled.
    unlabeled_mask = np.logical_not(train_mask | test_mask | val_mask)
    unlabeled_indices = np.where(unlabeled_mask)[0]

    # Extract node features.
    if row_normalize:
      features = self.preprocess_features(features)
    else:
      features = features.todense()
    features = np.float32(features)

    # Extract labels.
    labels = np.argmax(labels, axis=-1)
    num_classes = np.max(labels) + 1

    # Extract edges.
    adj = scipy.sparse.coo_matrix(adj)
    edges = [self.Edge(src, tgt, val)
             for src, tgt, val in zip(adj.row, adj.col, adj.data)]

    # Convert to Dataset format.
    super().__init__(
        name=name,
        features=features,
        labels=labels,
        edges=edges,
        indices_train=train_indices,
        indices_test=test_indices,
        indices_val=val_indices,
        indices_unlabeled=unlabeled_indices,
        num_classes=num_classes,
        feature_preproc_fn=lambda x: x)

  @staticmethod
  def preprocess_features(features):
    """Row-normalize feature matrix.

    Each row is divided by its sum. Rows whose sum is zero are left as all
    zeros.

    Arguments:
      features: Sparse feature matrix, one row per sample.

    Returns:
      The row-normalized features as a dense matrix.
    """
    # Imported here because the file-level `import scipy` does not guarantee
    # that the `scipy.sparse` submodule is loaded.
    import scipy.sparse  # pylint: disable=g-import-not-at-top

    rowsum = np.array(features.sum(1))
    if not np.issubdtype(rowsum.dtype, np.floating):
      # np.power refuses negative exponents on integer arrays.
      rowsum = rowsum.astype(np.float64)
    # Zero rows yield inf here, which is replaced by 0 right below, so the
    # divide-by-zero warning is only noise.
    with np.errstate(divide='ignore'):
      r_inv = np.power(rowsum, -1).flatten()
    r_inv[np.isinf(r_inv)] = 0.
    r_mat_inv = scipy.sparse.diags(r_inv)
    features = r_mat_inv.dot(features)
    return features.todense()
291+
292+
293+
class CotrainDataset(object):
191294
"""A wrapper around a Dataset object, adding co-training functionality.
192295
193296
Attributes:
@@ -430,3 +533,10 @@ def restore_state_from_file(self, path):
430533

431534
# Delegates to the wrapped dataset and returns whatever its copy_labels()
# returns.
def copy_labels(self):
432535
return self.dataset.copy_labels()
536+
537+
# Forwards the three filter arguments unchanged to the wrapped dataset's
# get_edges() and returns its result.
# NOTE(review): assumes self.dataset defines get_edges (GraphDataset does);
# confirm callers never wrap a dataset without it.
def get_edges(self, src_labeled=None, tgt_labeled=None,
538+
label_must_match=False):
539+
return self.dataset.get_edges(
540+
src_labeled=src_labeled,
541+
tgt_labeled=tgt_labeled,
542+
label_must_match=label_must_match)

0 commit comments

Comments
 (0)