
Commit 69aa91c

Fix documentation and indentation.

1 parent 25a71f2 commit 69aa91c


neural_structured_learning/research/gam/data/loaders.py

Lines changed: 46 additions & 32 deletions
@@ -122,26 +122,39 @@ def load_data_realistic_ssl(dataset_name, data_path, label_map_path):
   return data


-def load_from_planetoid_files(dataset_str, path):
-  """Loads input data from gcn/data directory.
-
-  This function is copied and adapted from https://github.com/tkipf/gcn/blob/master/gcn/utils.py.
-
-  ind.dataset_str.x => the feature vectors of the training instances as scipy.sparse.csr.csr_matrix object;
-  ind.dataset_str.tx => the feature vectors of the test instances as scipy.sparse.csr.csr_matrix object;
-  ind.dataset_str.allx => the feature vectors of both labeled and unlabeled training instances
-    (a superset of ind.dataset_str.x) as scipy.sparse.csr.csr_matrix object;
-  ind.dataset_str.y => the one-hot labels of the labeled training instances as numpy.ndarray object;
-  ind.dataset_str.ty => the one-hot labels of the test instances as numpy.ndarray object;
-  ind.dataset_str.ally => the labels for instances in ind.dataset_str.allx as numpy.ndarray object;
-  ind.dataset_str.graph => a dict in the format {index: [index_of_neighbor_nodes]} as collections.defaultdict
-    object;
-  ind.dataset_str.test.index => the indices of test instances in graph, for the inductive setting as list object.
-
-  All objects above must be saved using python pickle module.
-
-  :param dataset_str: Dataset name
-  :return: All data input files loaded (as well the training/test data).
+def load_from_planetoid_files(dataset_name, path):
+  """Loads Planetoid data in GCN format, as released with the GCN code.
+
+  This function is adapted from https://github.com/tkipf/gcn.
+
+  This function assumes that the following files can be found at the location
+  specified by `path`:
+  ind.dataset_str.x => the feature vectors of the training instances
+                       as scipy.sparse.csr.csr_matrix object.
+  ind.dataset_str.tx => the feature vectors of the test instances as
+                        scipy.sparse.csr.csr_matrix object.
+  ind.dataset_str.allx => the feature vectors of both labeled and
+                          unlabeled training instances (a superset of
+                          ind.dataset_str.x) as
+                          scipy.sparse.csr.csr_matrix object.
+  ind.dataset_str.y => the one-hot labels of the labeled training
+                       instances as numpy.ndarray object.
+  ind.dataset_str.ty => the one-hot labels of the test instances as
+                        numpy.ndarray object.
+  ind.dataset_str.ally => the labels for instances in
+                          ind.dataset_str.allx as numpy.ndarray object.
+  ind.dataset_str.graph => a dict in the format
+                           {index: [index_of_neighbor_nodes]} as
+                           collections.defaultdict object.
+  ind.dataset_str.test.index => the indices of test instances in graph, for
+                                the inductive setting as list object.
+
+  Arguments:
+    dataset_name: A string representing the dataset name (e.g., `cora`).
+    path: Path to the directory containing the files.
+
+  Returns:
+    All data input files loaded (as well the training/test data).
   """
   def _sample_mask(idx, l):
     """Create mask."""
@@ -159,7 +172,7 @@ def _parse_index_file(filename):
   names = ['x', 'y', 'tx', 'ty', 'allx', 'ally', 'graph']
   objects = []
   for i in range(len(names)):
-    filename = "ind.{}.{}".format(dataset_str, names[i])
+    filename = "ind.{}.{}".format(dataset_name, names[i])
     filename = os.path.join(path, filename)
     with open(filename, 'rb') as f:
       if sys.version_info > (3, 0):
@@ -168,15 +181,16 @@ def _parse_index_file(filename):
         objects.append(pkl.load(f))

   x, y, tx, ty, allx, ally, graph = tuple(objects)
-  filename = "ind.{}.test.index".format(dataset_str)
+  filename = "ind.{}.test.index".format(dataset_name)
   filename = os.path.join(path, filename)
   test_idx_reorder = _parse_index_file(filename)
   test_idx_range = np.sort(test_idx_reorder)

-  if dataset_str == 'citeseer':
+  if dataset_name == 'citeseer':
     # Fix citeseer dataset (there are some isolated nodes in the graph).
     # Find isolated nodes, add them as zero-vecs into the right position.
-    test_idx_range_full = range(min(test_idx_reorder), max(test_idx_reorder)+1)
+    test_idx_range_full = range(min(test_idx_reorder),
+                                max(test_idx_reorder)+1)
     tx_extended = sp.lil_matrix((len(test_idx_range_full), x.shape[1]))
     tx_extended[test_idx_range-min(test_idx_range), :] = tx
     tx = tx_extended
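
The citeseer branch above compensates for isolated test nodes: indices missing from ind.citeseer.test.index get all-zero feature rows so the test matrix covers the full contiguous index range. A self-contained sketch of the same padding idea, with made-up toy indices (node 11 plays the isolated node):

import numpy as np
import scipy.sparse as sp

tx = sp.csr_matrix(np.ones((3, 5)))  # features observed for 3 test nodes
test_idx_reorder = [10, 12, 13]      # node 11 is isolated: no features on disk
test_idx_range = np.sort(test_idx_reorder)

# Allocate a zero matrix over the full contiguous range 10..13, then scatter
# the observed rows into their positions; the row for node 11 stays all-zero.
test_idx_range_full = range(min(test_idx_reorder), max(test_idx_reorder) + 1)
tx_extended = sp.lil_matrix((len(test_idx_range_full), tx.shape[1]))
tx_extended[test_idx_range - min(test_idx_range), :] = tx
tx = tx_extended
print(tx.toarray())  # second row (node 11) is all zeros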
@@ -187,7 +201,6 @@ def _parse_index_file(filename):
   features = sp.vstack((allx, tx)).tolil()
   features[test_idx_reorder, :] = features[test_idx_range, :]
   adj = nx.adjacency_matrix(nx.from_dict_of_lists(graph))
-  #adj = nx.adjacency_matrix(nx.from_dict_of_lists(graph, create_using=nx.DiGraph))

   labels = np.vstack((ally, ty))
   labels[test_idx_reorder, :] = labels[test_idx_range, :]
@@ -207,19 +220,20 @@ def _parse_index_file(filename):
   y_val[val_mask, :] = labels[val_mask, :]
   y_test[test_mask, :] = labels[test_mask, :]

-  return adj, features, y_train, y_val, y_test, train_mask, val_mask, test_mask, labels
+  return (adj, features, y_train, y_val, y_test, train_mask, val_mask,
+          test_mask, labels)


 def load_data_planetoid(name, path, splits_path=None, row_normalize=False):
-  # Load from file.
   if splits_path is None:
-    adj, features, y_train, y_val, y_test, train_mask, val_mask, test_mask, \
-      labels = load_from_planetoid_files(name, path)
+    # Load from file in Planetoid format.
+    adj, features, y_train, y_val, y_test, train_mask, val_mask, test_mask,\
+      labels = load_from_planetoid_files(name, path)
   else:
-    # Otherwise load from splits path.
+    # Otherwise load from a path where we saved a pickle with random splits.
     logging.info('Loading from splits path: %s', splits_path)
-    adj, features, y_train, y_val, y_test, train_mask, val_mask, test_mask, \
-      labels = pickle.load(open(splits_path, "rb"))
+    adj, features, y_train, y_val, y_test, train_mask, val_mask, test_mask,\
+      labels = pickle.load(open(splits_path, "rb"))

   return PlanetoidDataset(name, adj, features, train_mask, val_mask, test_mask,
                           labels, row_normalize=row_normalize)
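
After this change, a call into the loader might look like the sketch below. The import path and data directory are assumptions for illustration; only load_data_planetoid and its signature come from the diff above.

# Assumed import path; adjust to wherever loaders.py lives in your checkout.
from gam.data import loaders

# Loads cora from Planetoid-format files (splits_path=None takes the
# load_from_planetoid_files branch shown in the diff).
dataset = loaders.load_data_planetoid(
    name='cora',
    path='/tmp/planetoid',
    splits_path=None,
    row_normalize=False)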
