
Commit ae040c6

arjung authored and tensorflow-copybara committed
Update example trainer binaries not to parse neighbor features in modes other than training.
PiperOrigin-RevId: 285821272
1 parent 989c055 commit ae040c6
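
In short, the two-dataset helper make_datasets is replaced by a single make_dataset helper whose new include_nbr_features argument controls whether neighbor features are parsed at all. A sketch of the resulting call pattern, taken from the call sites in this diff (variable names follow the diff; nothing else is implied):

# Base MLP model: neighbor features are never parsed.
train_dataset = make_dataset(train_data_path, True, False, hparams)
test_dataset = make_dataset(test_data_path, False, False, hparams)

# Graph-regularized MLP: neighbor features are parsed only for training;
# evaluation still uses base features only.
train_dataset = make_dataset(train_data_path, True, True, hparams)
test_dataset = make_dataset(test_data_path, False, False, hparams)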

File tree: 1 file changed (+30, -27 lines)

neural_structured_learning/examples/graph_keras_mlp_cora.py

Lines changed: 30 additions & 27 deletions
@@ -63,7 +63,7 @@ class HParams(object):
   ### model architecture
   num_fc_units = attr.ib(default=[50, 50])
   ### training parameters
-  train_epochs = attr.ib(default=100)
+  train_epochs = attr.ib(default=10)
   batch_size = attr.ib(default=128)
   dropout_rate = attr.ib(default=0.5)
   ### eval parameters
@@ -93,8 +93,8 @@ def load_dataset(filename):
   return tf.data.TFRecordDataset([filename])
 
 
-def make_datasets(train_data_path, test_data_path, hparams):
-  """Returns training and test data as a pair of `tf.data.Dataset` instances."""
+def make_dataset(file_path, training, include_nbr_features, hparams):
+  """Returns a `tf.data.Dataset` instance based on data in `file_path`."""
 
   def parse_example(example_proto):
     """Extracts relevant fields from the `example_proto`.
@@ -120,35 +120,33 @@ def parse_example(example_proto):
         'label':
             tf.io.FixedLenFeature((), tf.int64, default_value=-1),
     }
-    for i in range(hparams.num_neighbors):
-      nbr_feature_key = '{}{}_{}'.format(NBR_FEATURE_PREFIX, i, 'words')
-      nbr_weight_key = '{}{}{}'.format(NBR_FEATURE_PREFIX, i, NBR_WEIGHT_SUFFIX)
-      feature_spec[nbr_feature_key] = tf.io.FixedLenFeature(
-          [hparams.max_seq_length],
-          tf.int64,
-          default_value=tf.constant(
-              0, dtype=tf.int64, shape=[hparams.max_seq_length]))
-      feature_spec[nbr_weight_key] = tf.io.FixedLenFeature(
-          [1], tf.float32, default_value=tf.constant([0.0]))
+    if include_nbr_features:
+      for i in range(hparams.num_neighbors):
+        nbr_feature_key = '{}{}_{}'.format(NBR_FEATURE_PREFIX, i, 'words')
+        nbr_weight_key = '{}{}{}'.format(NBR_FEATURE_PREFIX, i,
+                                         NBR_WEIGHT_SUFFIX)
+        feature_spec[nbr_feature_key] = tf.io.FixedLenFeature(
+            [hparams.max_seq_length],
+            tf.int64,
+            default_value=tf.constant(
+                0, dtype=tf.int64, shape=[hparams.max_seq_length]))
+        feature_spec[nbr_weight_key] = tf.io.FixedLenFeature(
+            [1], tf.float32, default_value=tf.constant([0.0]))
 
     features = tf.io.parse_single_example(example_proto, feature_spec)
 
     labels = features.pop('label')
     return features, labels
 
-  def make_dataset(file_path, training=False):
-    """Creates a `tf.data.Dataset`."""
-    # If the dataset is sharded, the following code may be required:
-    # filenames = tf.data.Dataset.list_files(file_path, shuffle=True)
-    # dataset = filenames.interleave(load_dataset, cycle_length=1)
-    dataset = load_dataset(file_path)
-    if training:
-      dataset = dataset.shuffle(10000)
-    dataset = dataset.map(parse_example)
-    dataset = dataset.batch(hparams.batch_size)
-    return dataset
-
-  return make_dataset(train_data_path, True), make_dataset(test_data_path)
+  # If the dataset is sharded, the following code may be required:
+  # filenames = tf.data.Dataset.list_files(file_path, shuffle=True)
+  # dataset = filenames.interleave(load_dataset, cycle_length=1)
+  dataset = load_dataset(file_path)
+  if training:
+    dataset = dataset.shuffle(10000)
+  dataset = dataset.map(parse_example)
+  dataset = dataset.batch(hparams.batch_size)
+  return dataset
 
 
 def make_mlp_sequential_model(hparams):
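
The core of this hunk is gating the neighbor-feature keys of the tf.train.Example parse spec behind include_nbr_features. A minimal self-contained sketch of the same pattern; the prefix/suffix constants and the sizes here are illustrative assumptions, not necessarily the example's actual values:

import tensorflow as tf

NBR_FEATURE_PREFIX = 'NL_nbr_'  # assumed prefix, for illustration only
NBR_WEIGHT_SUFFIX = '_weight'   # assumed suffix, for illustration only


def build_feature_spec(include_nbr_features, num_neighbors=2, max_seq_length=5):
  """Returns a parse spec; neighbor keys are added only when requested."""
  spec = {
      'words':
          tf.io.FixedLenFeature([max_seq_length], tf.int64,
                                default_value=[0] * max_seq_length),
      'label':
          tf.io.FixedLenFeature((), tf.int64, default_value=-1),
  }
  if include_nbr_features:
    for i in range(num_neighbors):
      nbr_feature_key = '{}{}_words'.format(NBR_FEATURE_PREFIX, i)
      nbr_weight_key = '{}{}{}'.format(NBR_FEATURE_PREFIX, i,
                                       NBR_WEIGHT_SUFFIX)
      spec[nbr_feature_key] = tf.io.FixedLenFeature(
          [max_seq_length], tf.int64, default_value=[0] * max_seq_length)
      spec[nbr_weight_key] = tf.io.FixedLenFeature(
          [1], tf.float32, default_value=[0.0])
  return spec


# Parse one serialized example with and without neighbor features.
example = tf.train.Example()
example.features.feature['words'].int64_list.value.extend([1, 0, 1, 0, 1])
example.features.feature['label'].int64_list.value.append(3)
serialized = example.SerializeToString()

lean = tf.io.parse_single_example(serialized, build_feature_spec(False))
full = tf.io.parse_single_example(serialized, build_feature_spec(True))
print(sorted(lean))  # ['label', 'words']
print(sorted(full))  # also has NL_nbr_{i}_words and NL_nbr_{i}_weight keys

Missing neighbor keys fall back to the zero-valued defaults, which is why the defaults matter: examples without a given neighbor still parse cleanly.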
@@ -270,7 +268,8 @@ def main(argv):
                          (len(argv) - 1))
 
   hparams = get_hyper_parameters()
-  train_dataset, test_dataset = make_datasets(argv[1], argv[2], hparams)
+  train_data_path = argv[1]
+  test_data_path = argv[2]
 
   # Graph regularization configuration.
   graph_reg_config = nsl.configs.make_graph_reg_config(
@@ -287,6 +286,8 @@ def main(argv):
   }
   for base_model_tag, base_model in base_models.items():
     logging.info('\n====== %s BASE MODEL TEST BEGIN ======', base_model_tag)
+    train_dataset = make_dataset(train_data_path, True, False, hparams)
+    test_dataset = make_dataset(test_data_path, False, False, hparams)
     train_and_evaluate(base_model, 'Base MLP model', train_dataset,
                        test_dataset, hparams)
 
@@ -295,6 +296,8 @@ def main(argv):
     # Wrap the base MLP model with graph regularization.
     graph_reg_model = nsl.keras.GraphRegularization(base_model,
                                                     graph_reg_config)
+    train_dataset = make_dataset(train_data_path, True, True, hparams)
+    test_dataset = make_dataset(test_data_path, False, False, hparams)
     train_and_evaluate(graph_reg_model, 'MLP + graph regularization',
                        train_dataset, test_dataset, hparams)
 
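
Note that even for the graph-regularized model, the test dataset is built with include_nbr_features=False. That matches the commit's intent: the graph regularization loss draws on neighbor features only during training, so evaluation needs just the base features. A hypothetical sketch of that property with a stand-in model; the shapes, layer stack, and toy data below are made up for illustration and are not the example's:

import neural_structured_learning as nsl
import tensorflow as tf

# Tiny stand-in base model; a dict input keyed 'words' mirrors the example.
inputs = tf.keras.Input(shape=(5,), dtype='int64', name='words')
hidden = tf.keras.layers.Lambda(lambda t: tf.cast(t, 'float32'))(inputs)
outputs = tf.keras.layers.Dense(3)(hidden)
base_model = tf.keras.Model(inputs, outputs)

graph_reg_model = nsl.keras.GraphRegularization(
    base_model, nsl.configs.make_graph_reg_config(max_neighbors=2))
graph_reg_model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'])

# The evaluation batch carries only the base 'words' feature and labels;
# no NL_nbr_* keys are present, matching the new test call sites above.
test_ds = tf.data.Dataset.from_tensors(
    ({'words': tf.zeros([2, 5], tf.int64)}, tf.zeros([2], tf.int64)))
graph_reg_model.evaluate(test_ds)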
