Skip to content

Commit 4014a76

Browse files
csferng authored and tensorflow-copybara committed
Minor improvements on examples.
PiperOrigin-RevId: 283798223
1 parent b88347f commit 4014a76

File tree

3 files changed

+27
-4
lines changed

3 files changed

+27
-4
lines changed

neural_structured_learning/examples/adv_keras_cnn_mnist.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,9 @@
3737
FLAGS = flags.FLAGS
3838

3939
flags.DEFINE_integer('epochs', None, 'Number of epochs to train.')
40+
flags.DEFINE_integer('steps_per_epoch', None,
41+
'Number of steps in each training epoch.')
42+
flags.DEFINE_integer('eval_steps', None, 'Number of steps to evaluate.')
4043
flags.DEFINE_float('adv_step_size', None,
4144
'Step size for generating adversarial examples.')
4245

@@ -62,14 +65,21 @@ class HParams(object):
6265
batch_size = attr.ib(default=32)
6366
buffer_size = attr.ib(default=10000)
6467
epochs = attr.ib(default=5)
68+
steps_per_epoch = attr.ib(default=None)
69+
eval_steps = attr.ib(default=None)
6570

6671

6772
def get_hparams():
73+
"""Returns the hyperparameters with defaults overwritten by flags."""
6874
hparams = HParams()
6975
if FLAGS.epochs:
7076
hparams.epochs = FLAGS.epochs
7177
if FLAGS.adv_step_size:
7278
hparams.adv_step_size = FLAGS.adv_step_size
79+
if FLAGS.steps_per_epoch:
80+
hparams.steps_per_epoch = FLAGS.steps_per_epoch
81+
if FLAGS.eval_steps:
82+
hparams.eval_steps = FLAGS.eval_steps
7383
return hparams
7484

7585

@@ -136,8 +146,11 @@ def train_and_evaluate(model, hparams, train_dataset, test_dataset):
136146
optimizer='adam',
137147
loss='sparse_categorical_crossentropy',
138148
metrics=['accuracy'])
139-
model.fit(train_dataset, epochs=hparams.epochs)
140-
eval_result = model.evaluate(test_dataset)
149+
model.fit(
150+
train_dataset,
151+
epochs=hparams.epochs,
152+
steps_per_epoch=hparams.steps_per_epoch)
153+
eval_result = model.evaluate(test_dataset, steps=hparams.eval_steps)
141154
return list(zip(model.metrics_names, eval_result))
142155

143156

@@ -169,6 +182,8 @@ def evaluate_robustness(model_to_attack, dataset, models, hparams):
169182
for name in models.keys()
170183
}
171184

185+
if hparams.eval_steps:
186+
dataset = dataset.take(hparams.eval_steps)
172187
# When running on accelerators, looping over the dataset inside a tf.function
173188
# may be much faster.
174189
for batch in dataset:

neural_structured_learning/examples/graph_keras_mlp_cora.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,9 @@
4343
FLAGS = flags.FLAGS
4444
FLAGS.showprefixforinfo = False
4545

46+
flags.DEFINE_integer('train_epochs', None, 'Number of epochs to train.')
47+
flags.DEFINE_integer('eval_steps', None, 'Number of steps to evaluate.')
48+
4649
NBR_FEATURE_PREFIX = 'NL_nbr_'
4750
NBR_WEIGHT_SUFFIX = '_weight'
4851

@@ -69,7 +72,12 @@ class HParams(object):
6972

7073
def get_hyper_parameters():
7174
"""Returns the hyper-parameters used for training."""
72-
return HParams()
75+
hparams = HParams()
76+
if FLAGS.train_epochs:
77+
hparams.train_epochs = FLAGS.train_epochs
78+
if FLAGS.eval_steps:
79+
hparams.eval_steps = FLAGS.eval_steps
80+
return hparams
7381

7482

7583
def load_dataset(filename):

neural_structured_learning/examples/preprocess/cora/prep_data.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ tar -C ${DATA_DIR} -xvzf ${DATA_DIR}/cora.tgz
3838

3939
# Pre-process cora dataset. The file 'preprocess_cora_dataset.py' is assumed to
4040
# be located in the current directory.
41-
python preprocess_cora_dataset.py \
41+
python $(dirname "$0")/preprocess_cora_dataset.py \
4242
--input_cora_content=${DATA_DIR}/cora/cora.content \
4343
--input_cora_graph=${DATA_DIR}/cora/cora.cites \
4444
--max_nbrs=5 \

0 commit comments

Comments (0)