Commit 9758519

Neural-Link Team authored and tensorflow-copybara committed
Copybara import of the project:
-- 2b68d08 by Otilia Stretcu <otiliastr@gmail.com>: Small fixes.
-- 3265019 by Otilia Stretcu <otiliastr@gmail.com>: Changing model name.
-- dca1457 by Otilia Stretcu <otiliastr@gmail.com>: Refactoring.
-- 8a90308 by Otilia Stretcu <otiliastr@gmail.com>: Refactoring.
-- eef1760 by Otilia Stretcu <otiliastr@gmail.com>: Removed unused function.
-- 7954700 by Otilia Stretcu <otiliastr@gmail.com>: Rename run script.
-- 25a71f2 by Otilia Stretcu <otiliastr@gmail.com>: Added support for Planetoid datasets, and training using the graph edges.
-- 69aa91c by Otilia Stretcu <otiliastr@gmail.com>: Fix documentation and indentation.
-- b43e5dc by Otilia Stretcu <otiliastr@gmail.com>: Changed model name.
-- daf5d6b by Otilia Stretcu <otiliastr@gmail.com>: Small refactoring.
-- dbfdcf5 by Otilia Stretcu <otiliastr@gmail.com>: Fix add_negative_edges_agr.
-- 1fa5376 by Otilia Stretcu <otiliastr@gmail.com>: Added NGM agreement.
-- 48ff042 by Otilia Stretcu <otiliastr@gmail.com>: Removed indices val from unlabeled in transductive setting.
-- 1c46ee1 by Otilia Stretcu <otiliastr@gmail.com>: Removed indices val from self-labeling step.
-- ddab3ee by Otilia Stretcu <otiliastr@gmail.com>: Fixing issue with no graph edges available between two labeled nodes.
-- 2e8eff8 by Otilia Stretcu <otiliastr@gmail.com>: Fix edge iterator when none are available.

PiperOrigin-RevId: 273611867
1 parent 8650926 commit 9758519

File tree

9 files changed: +1492 additions, -419 deletions


neural_structured_learning/research/gam/data/dataset.py

Lines changed: 269 additions & 111 deletions
Large diffs are not rendered by default.
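Although the dataset.py diff is collapsed here, the new construction entry point it introduces, Dataset.build_from_splits, is visible from the call sites in the loaders.py diff below. A minimal sketch of that path, using toy placeholder arrays (the keyword names are taken from those call sites; everything else is illustrative):

import numpy as np

from gam.data.dataset import Dataset

# Toy placeholder splits; real callers pass the arrays produced by the loaders.
data = Dataset.build_from_splits(
    name='toy',
    inputs_train=np.zeros((10, 4)),
    labels_train=np.zeros(10, dtype=np.int64),
    inputs_val=np.zeros((5, 4)),
    labels_val=np.zeros(5, dtype=np.int64),
    inputs_test=np.zeros((5, 4)),
    labels_test=np.zeros(5, dtype=np.int64),
    inputs_unlabeled=np.zeros((20, 4)),
    labels_unlabeled=np.zeros(20, dtype=np.int64),
    feature_preproc_fn=lambda x: x)  # identity preprocessing, as used for cifar10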

neural_structured_learning/research/gam/data/loaders.py

Lines changed: 169 additions & 15 deletions
@@ -19,18 +19,23 @@
 
 import json
 import logging
+import os
 import pickle
+import sys
 
-from gam.data.dataset import FixedDataset
+from gam.data.dataset import Dataset
+from gam.data.dataset import PlanetoidDataset
 from gam.data.preprocessing import convert_image
 from gam.data.preprocessing import split_train_val_unlabeled
 
+import networkx as nx
 import numpy as np
+from scipy import sparse as sp
 import tensorflow_datasets as tfds
 
 
-def load_data_tf_datasets(
-    dataset_name, target_num_train_per_class, target_num_val, seed):
+def load_data_tf_datasets(dataset_name, target_num_train_per_class,
+                          target_num_val, seed):
   """Load and preprocess data from tensorflow_datasets."""
   logging.info('Loading and preprocessing data from tensorflow datasets...')
   # Load train data.
@@ -58,17 +63,24 @@ def load_data_tf_datasets(
   unlabeled_labels = data[5]
 
   logging.info('Converting data to Dataset format...')
-  data = FixedDataset(train_inputs, train_labels, val_inputs, val_labels,
-                      test_inputs, test_labels, unlabeled_inputs,
-                      unlabeled_labels, feature_preproc_fn=convert_image)
+  data = Dataset.build_from_splits(
+      name=dataset_name,
+      inputs_train=train_inputs,
+      labels_train=train_labels,
+      inputs_val=val_inputs,
+      labels_val=val_labels,
+      inputs_test=test_inputs,
+      labels_test=test_labels,
+      inputs_unlabeled=unlabeled_inputs,
+      labels_unlabeled=unlabeled_labels,
+      feature_preproc_fn=convert_image)
   return data
 
 
 def load_data_realistic_ssl(dataset_name, data_path, label_map_path):
   """Loads data from the `ealistic Evaluation of Deep SSL Algorithms`."""
   logging.info('Loading data from pickle at %s.', data_path)
-  train_set, validation_set, test_set = pickle.load(
-      open(data_path, 'rb'))
+  train_set, validation_set, test_set = pickle.load(open(data_path, 'rb'))
   train_inputs = train_set['images']
   train_labels = train_set['labels']
   val_inputs = validation_set['images']
@@ -77,8 +89,9 @@ def load_data_realistic_ssl(dataset_name, data_path, label_map_path):
   test_labels = test_set['labels']
   # Load label map that specifies which trainining labeles are available.
   train_indices = json.load(open(label_map_path, 'r'))
-  train_indices = [int(key.encode('ascii', 'ignore'))
-                   for key in train_indices['values']]
+  train_indices = [
+      int(key.encode('ascii', 'ignore')) for key in train_indices['values']
+  ]
   train_indices = np.asarray(train_indices)
 
   # Select the loaded train indices, and make the rest unlabeled.
@@ -90,11 +103,152 @@ def load_data_realistic_ssl(dataset_name, data_path, label_map_path):
   train_labels = train_labels[train_indices]
 
   # Select a feature preprocessing function, depending on the dataset.
-  feature_preproc_fn = ((lambda image: image) if dataset_name == 'cifar10' else
-                        convert_image)
+  feature_preproc_fn = ((lambda image: image)
+                        if dataset_name == 'cifar10' else convert_image)
 
-  data = FixedDataset(
-      train_inputs, train_labels, val_inputs, val_labels, test_inputs,
-      test_labels, unlabeled_inputs, unlabeled_labels,
+  data = Dataset.build_from_splits(
+      name=dataset_name,
+      inputs_train=train_inputs,
+      labels_train=train_labels,
+      inputs_val=val_inputs,
+      labels_val=val_labels,
+      inputs_test=test_inputs,
+      labels_test=test_labels,
+      inputs_unlabeled=unlabeled_inputs,
+      labels_unlabeled=unlabeled_labels,
       feature_preproc_fn=feature_preproc_fn)
   return data
+
+
+def load_from_planetoid_files(dataset_name, path):
+  """Loads Planetoid data in GCN format, as released with the GCN code.
+
+  This function is adapted from https://github.com/tkipf/gcn.
+
+  This function assumes that the following files can be found at the location
+  specified by `path`:
+
+    ind.dataset_str.x          => the feature vectors of the training instances
+                                  as scipy.sparse.csr.csr_matrix object.
+    ind.dataset_str.tx         => the feature vectors of the test instances as
+                                  scipy.sparse.csr.csr_matrix object.
+    ind.dataset_str.allx       => the feature vectors of both labeled and
+                                  unlabeled training instances (a superset of
+                                  ind.dataset_str.x) as
+                                  scipy.sparse.csr.csr_matrix object.
+    ind.dataset_str.y          => the one-hot labels of the labeled training
+                                  instances as numpy.ndarray object.
+    ind.dataset_str.ty         => the one-hot labels of the test instances as
+                                  numpy.ndarray object.
+    ind.dataset_str.ally       => the labels for instances in
+                                  ind.dataset_str.allx as numpy.ndarray object.
+    ind.dataset_str.graph      => a dict in the format
+                                  {index: [index_of_neighbor_nodes]} as
+                                  collections.defaultdict object.
+    ind.dataset_str.test.index => the indices of test instances in graph, for
+                                  the inductive setting as list object.
+
+  Args:
+    dataset_name: A string representing the dataset name (e.g., `cora`).
+    path: Path to the directory containing the files.
+
+  Returns:
+    All data input files loaded (as well the training/test data).
+  """
+
+  def _sample_mask(idx, l):
+    """Create mask."""
+    mask = np.zeros(l)
+    mask[idx] = 1
+    return np.array(mask, dtype=np.bool)
+
+  def _parse_index_file(filename):
+    """Parse index file."""
+    index = []
+    for line in open(filename):
+      index.append(int(line.strip()))
+    return index
+
+  def _load_file(name):
+    """Load from data file."""
+    filename = 'ind.{}.{}'.format(dataset_name, name)
+    filename = os.path.join(path, filename)
+    with open(filename, 'rb') as f:
+      if sys.version_info > (3, 0):
+        return pickle.load(f, encoding='latin1')  # pylint: disable=unexpected-keyword-arg
+      else:
+        return pickle.load(f)
+
+  x = _load_file('x')
+  y = _load_file('y')
+  tx = _load_file('tx')
+  ty = _load_file('ty')
+  allx = _load_file('allx')
+  ally = _load_file('ally')
+  graph = _load_file('graph')
+
+  filename = 'ind.{}.test.index'.format(dataset_name)
+  filename = os.path.join(path, filename)
+  test_idx_reorder = _parse_index_file(filename)
+  test_idx_range = np.sort(test_idx_reorder)
+
+  if dataset_name == 'citeseer':
+    # Fix citeseer dataset (there are some isolated nodes in the graph).
+    # Find isolated nodes, add them as zero-vecs into the right position.
+    test_idx_range_full = range(
+        min(test_idx_reorder),
+        max(test_idx_reorder) + 1)
+    tx_extended = sp.lil_matrix((len(test_idx_range_full), x.shape[1]))
+    tx_extended[test_idx_range - min(test_idx_range), :] = tx
+    tx = tx_extended
+    ty_extended = np.zeros((len(test_idx_range_full), y.shape[1]))
+    ty_extended[test_idx_range - min(test_idx_range), :] = ty
+    ty = ty_extended
+
+  features = sp.vstack((allx, tx)).tolil()
+  features[test_idx_reorder, :] = features[test_idx_range, :]
+  adj = nx.adjacency_matrix(nx.from_dict_of_lists(graph))
+
+  labels = np.vstack((ally, ty))
+  labels[test_idx_reorder, :] = labels[test_idx_range, :]
+
+  idx_test = test_idx_range.tolist()
+  idx_train = range(len(y))
+  idx_val = range(len(y), len(y) + 500)
+
+  train_mask = _sample_mask(idx_train, labels.shape[0])
+  val_mask = _sample_mask(idx_val, labels.shape[0])
+  test_mask = _sample_mask(idx_test, labels.shape[0])
+
+  y_train = np.zeros(labels.shape)
+  y_val = np.zeros(labels.shape)
+  y_test = np.zeros(labels.shape)
+  y_train[train_mask, :] = labels[train_mask, :]
+  y_val[val_mask, :] = labels[val_mask, :]
+  y_test[test_mask, :] = labels[test_mask, :]
+
+  return (adj, features, y_train, y_val, y_test, train_mask, val_mask,
+          test_mask, labels)
+
+
+def load_data_planetoid(name, path, splits_path=None, row_normalize=False):
+  """Load Planetoid data."""
+  if splits_path is None:
+    # Load from file in Planetoid format.
+    (adj, features, _, _, _, train_mask, val_mask, test_mask,
+     labels) = load_from_planetoid_files(name, path)
+  else:
+    # Otherwise load from a path where we saved a pickle with random splits.
+    logging.info('Loading from splits path: %s', splits_path)
+    (adj, features, _, _, _, train_mask, val_mask, test_mask,
+     labels) = pickle.load(open(splits_path, 'rb'))
+
+  return PlanetoidDataset(
+      name,
+      adj,
+      features,
+      train_mask,
+      val_mask,
+      test_mask,
+      labels,
+      row_normalize=row_normalize)
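
As a usage note, a minimal sketch of driving the new Planetoid loading path added above; it assumes the GCN-format files (ind.cora.x through ind.cora.test.index) already sit in a local directory, and the directory name below is hypothetical:

from gam.data.loaders import load_data_planetoid

# Hypothetical directory holding the ind.cora.* files released with the GCN code.
planetoid_dir = '/tmp/planetoid_data'

# Loads the adjacency matrix, features, labels and train/val/test masks and
# wraps them in a PlanetoidDataset; row_normalize presumably row-normalizes
# the feature matrix.
data = load_data_planetoid(name='cora', path=planetoid_dir, row_normalize=True)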

neural_structured_learning/research/gam/data/preprocessing.py

Lines changed: 32 additions & 4 deletions
@@ -29,8 +29,36 @@ def convert_image(image):
   return image
 
 
-def split_train_val_unlabeled(train_inputs, train_labels,
-                              target_num_train_per_class, target_num_val,
+def split_train_val(indices, ratio_val, rng, max_num_val=None):
+  """Split the train sample indices into train and validation.
+
+  Args:
+    indices: A numpy array containing the indices of the training samples.
+    ratio_val: A float number between (0, 1) representing the ratio of samples
+      to use for validation.
+    rng: A random number generator.
+    max_num_val: An integer representing the maximum number of samples to
+      include in the validation set.
+
+  Returns:
+    Two numpy arrays containing the subset of indices used for training, and
+    validation, respectively.
+  """
+  num_samples = indices.shape[0]
+  num_val = int(ratio_val * num_samples)
+  if max_num_val and num_val > max_num_val:
+    num_val = max_num_val
+  ind = np.arange(0, num_samples)
+  rng.shuffle(ind)
+  ind_val = ind[:num_val]
+  ind_train = ind[num_val:]
+  return ind_train, ind_val
+
+
+def split_train_val_unlabeled(train_inputs,
+                              train_labels,
+                              target_num_train_per_class,
+                              target_num_val,
                               seed=None):
   """Splits the training data into train, validation and unlabeled samples.
 
@@ -102,5 +130,5 @@ def split_train_val_unlabeled(train_inputs, train_labels,
   train_inputs = train_inputs[ind_train]
   train_labels = train_labels[ind_train]
 
-  return (train_inputs, train_labels, val_inputs, val_labels,
-          unlabeled_inputs, unlabeled_labels)
+  return (train_inputs, train_labels, val_inputs, val_labels, unlabeled_inputs,
+          unlabeled_labels)
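
A small sketch of the new split_train_val helper in isolation (toy indices; the rng argument is any NumPy-style random state exposing shuffle):

import numpy as np

from gam.data.preprocessing import split_train_val

rng = np.random.RandomState(seed=42)
indices = np.arange(1000)  # positions of the labeled training samples

# Hold out 10% of the samples for validation, capped at 50 samples.
ind_train, ind_val = split_train_val(
    indices, ratio_val=0.1, rng=rng, max_num_val=50)
assert len(ind_val) == 50 and len(ind_train) == 950
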
Lines changed: 116 additions & 0 deletions
@@ -0,0 +1,116 @@
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Helper functions for GAMs."""
+from gam.models.cnn import ImageCNNAgreement
+from gam.models.mlp import MLP
+from gam.models.wide_resnet import WideResnet
+
+import tensorflow as tf
+
+
+def parse_layers_string(layers_string):
+  """Convert a layer size string (e.g., `128_64_32`) to a list of integers."""
+  if not layers_string:
+    return ()
+  num_hidden = layers_string.split('_')
+  num_hidden = [int(num) for num in num_hidden]
+  return num_hidden
+
+
+def get_model_cls(model_name, data, dataset_name, hidden=None, **unused_kwargs):
+  """Picks the models depending on the provided configuration flags."""
+  # Create model classification.
+  if model_name == 'mlp':
+    hidden = parse_layers_string(hidden) if hidden is not None else ()
+    return MLP(
+        output_dim=data.num_classes,
+        hidden_sizes=hidden,
+        activation=tf.nn.leaky_relu,
+        name='mlp_cls')
+  elif model_name == 'cnn':
+    if dataset_name in ('mnist', 'fashion_mnist'):
+      channels = 1
+    elif dataset_name in ('cifar10', 'cifar100', 'svhn_cropped', 'svhn'):
+      channels = 3
+    else:
+      raise ValueError('Dataset name `%s` unsupported.' % dataset_name)
+    return ImageCNNAgreement(
+        output_dim=data.num_classes,
+        channels=channels,
+        activation=tf.nn.leaky_relu,
+        name='cnn_cls')
+  elif model_name == 'wide_resnet':
+    return WideResnet(
+        num_classes=data.num_classes,
+        lrelu_leakiness=0.1,
+        horizontal_flip=dataset_name in ('cifar10',),
+        random_translation=True,
+        gaussian_noise=dataset_name not in ('svhn', 'svhn_cropped'),
+        width=2,
+        num_residual_units=4,
+        name='wide_resnet_cls')
+  else:
+    raise NotImplementedError()
+
+
+def get_model_agr(model_name,
+                  dataset_name,
+                  hidden_aggreg=None,
+                  aggregation_agr_inputs='dist',
+                  hidden=None,
+                  **unused_kwargs):
+  """Create agreement model."""
+  hidden = parse_layers_string(hidden) if hidden is not None else ()
+  hidden_aggreg = (
+      parse_layers_string(hidden_aggreg) if hidden_aggreg is not None else ())
+  if model_name == 'mlp':
+    return MLP(
+        output_dim=1,
+        hidden_sizes=hidden,
+        activation=tf.nn.leaky_relu,
+        aggregation=aggregation_agr_inputs,
+        hidden_aggregation=hidden_aggreg,
+        is_binary_classification=True,
+        name='mlp_agr')
+  elif model_name == 'cnn':
+    if dataset_name in ('mnist', 'fashion_mnist'):
+      channels = 1
+    elif dataset_name in ('cifar10', 'cifar100', 'svhn_cropped', 'svhn'):
+      channels = 3
+    else:
+      raise ValueError('Dataset name `%s` unsupported.' % dataset_name)
+    return ImageCNNAgreement(
+        output_dim=1,
+        channels=channels,
+        activation=tf.nn.leaky_relu,
+        aggregation=aggregation_agr_inputs,
+        hidden_aggregation=hidden_aggreg,
+        is_binary_classification=True,
+        name='cnn_agr')
+  elif model_name == 'wide_resnet':
+    return WideResnet(
+        num_classes=1,
+        lrelu_leakiness=0.1,
+        horizontal_flip=dataset_name in ('cifar10',),
+        random_translation=True,
+        gaussian_noise=dataset_name not in ('svhn', 'svhn_cropped'),
+        width=2,
+        num_residual_units=4,
+        name='wide_resnet_cls',
+        is_binary_classification=True,
+        aggregation=aggregation_agr_inputs,
+        activation=tf.nn.leaky_relu,
+        hidden_aggregation=hidden_aggreg)
+  else:
+    raise NotImplementedError()
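
To illustrate how these helpers fit together, a sketch of building a classification and an agreement model pair; the module path of this new file is not shown in the rendered diff, so the import is left implicit, and `data` stands for any loaded Dataset exposing num_classes (see loaders.py above):

# Assume get_model_cls / get_model_agr from the new helper module are in scope,
# e.g. imported from wherever this file lives in the gam package.

# Classification model: a CNN sized for 3-channel CIFAR-10 images.
model_cls = get_model_cls(model_name='cnn', data=data, dataset_name='cifar10')

# Agreement model: an MLP that predicts (as a binary output) whether two
# samples share a label. Hidden sizes are underscore-separated strings parsed
# by parse_layers_string, e.g. '256_128' -> [256, 128].
model_agr = get_model_agr(
    model_name='mlp',
    dataset_name='cifar10',
    hidden_aggreg='128_64',
    aggregation_agr_inputs='dist',
    hidden='256_128')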
