
Commit 2b68d08

Small fixes.

1 parent 0384ee1

File tree: 5 files changed (+107, -66 lines)

5 files changed

+107
-66
lines changed

neural_structured_learning/research/gam/experiments/run_train_mnist.py

Lines changed: 2 additions & 2 deletions

@@ -176,14 +176,14 @@
     'String representing the number of units of the hidden layers of the '
     'aggregation network of the agreement model.')
 flags.DEFINE_float(
-    'weight_decay_cls', 0,
+    'weight_decay_cls', None,
     'Weight of the L2 penalty on the classification model weights.')
 flags.DEFINE_string(
     'weight_decay_schedule_cls', None,
     'Schedule for decaying the weight decay in the classification model. '
     'Choose between None or linear.')
 flags.DEFINE_float(
-    'weight_decay_agr', 0,
+    'weight_decay_agr', None,
     'Weight of the L2 penalty on the agreement model weights.')
 flags.DEFINE_string(
     'weight_decay_schedule_agr', None,
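
With the defaults switched from 0 to None, an unset flag now means "no weight decay" rather than a zero-weighted penalty, so the trainers below can skip building the L2 term entirely instead of adding a useless op to the graph. A minimal sketch of that pattern (the build_loss helper here is hypothetical, not part of this commit):

from absl import flags
import tensorflow as tf  # TF 1.x API, as used in this repo

flags.DEFINE_float(
    'weight_decay_cls', None,
    'Weight of the L2 penalty on the classification model weights.')

def build_loss(base_loss, reg_params, weight_decay):
  """Hypothetical helper: add L2 only when a weight decay is requested."""
  if weight_decay is None:
    return base_loss  # no regularization op is added to the graph
  for var in reg_params.values():
    base_loss = base_loss + weight_decay * tf.nn.l2_loss(var)
  return base_loss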
neural_structured_learning/research/gam/models/cnn.py

Lines changed: 1 addition & 1 deletion

@@ -289,7 +289,7 @@ def get_loss(self,
         weight decay is applied.
       **kwargs: Keyword arguments, potentially containing the weight of the
         regularization term, passed under the name `weight_decay`. If this is
-        not provided, it defaults to 0.0.
+        not provided, it defaults to 0.004.
 
     Returns:
       loss: The accumulated loss value.
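
The documented default changes from 0.0 to 0.004. A hedged sketch of how a kwargs default like this is typically read; the body of get_loss is not shown in this diff, so the helper below is illustrative only:

def regularizer_weight(**kwargs):
  """Illustrative only: mirrors the documented behavior, falling back to
  0.004 when `weight_decay` is not passed."""
  return kwargs.get('weight_decay', 0.004)

assert regularizer_weight() == 0.004
assert regularizer_weight(weight_decay=0.0) == 0.0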

neural_structured_learning/research/gam/models/mlp.py

Lines changed: 1 addition & 1 deletion

@@ -227,7 +227,7 @@ def get_loss(self,
     # Weight decay loss.
     if weight_decay is not None:
       for var in reg_params.values():
-        loss += weight_decay * tf.nn.l2_loss(var)
+        loss = loss + weight_decay * tf.nn.l2_loss(var)
     return loss
 
   def normalize_predictions(self, predictions):
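
For TensorFlow tensors the two spellings are equivalent, since tensors are immutable and `+=` simply rebinds the name; for NumPy arrays, however, `+=` mutates in place, so `loss = loss + ...` is the safer habit if `loss` could ever arrive as an array the caller still holds. A standalone illustration of that aliasing hazard:

import numpy as np

a = np.zeros(3)
alias = a
a += 1.0      # in-place: the caller's array changes too
print(alias)  # [1. 1. 1.]

b = np.zeros(3)
alias = b
b = b + 1.0   # rebinds: a new array is created, the alias is untouched
print(alias)  # [0. 0. 0.]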

neural_structured_learning/research/gam/trainer/trainer_agreement.py

Lines changed: 89 additions & 56 deletions

@@ -221,15 +221,18 @@ def __init__(self,
     # Create train op.
     grads_and_vars = self.optimizer.compute_gradients(
         loss_op,
-        tf.trainable_variables())
+        tf.trainable_variables(scope=tf.get_default_graph().get_name_scope()))
     # Clip gradients.
     if self.gradient_clip:
       variab = [elem[1] for elem in grads_and_vars]
       gradients = [elem[0] for elem in grads_and_vars]
       gradients, _ = tf.clip_by_global_norm(gradients, self.gradient_clip)
       grads_and_vars = tuple(zip(gradients, variab))
-    train_op = self.optimizer.apply_gradients(
-        grads_and_vars, global_step=self.global_step)
+    with tf.control_dependencies(
+        tf.get_collection(tf.GraphKeys.UPDATE_OPS,
+                          scope=tf.get_default_graph().get_name_scope())):
+      train_op = self.optimizer.apply_gradients(
+          grads_and_vars, global_step=self.global_step)
 
     # Create Tensorboard summaries.
     if self.enable_summaries:
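
Wrapping apply_gradients in a control dependency on UPDATE_OPS is the standard TF1 recipe for running deferred update ops (for example, batch-norm moving averages) on every training step, and scoping both the variable list and the collection to the current name scope keeps this trainer from touching state that belongs to other graph components. A minimal self-contained sketch of the pattern on a toy model, not the trainer's actual graph:

import tensorflow as tf  # TF 1.x

x = tf.placeholder(tf.float32, shape=[None, 4])
is_train = tf.placeholder(tf.bool, shape=[])
# Batch norm registers its moving-average updates in UPDATE_OPS.
h = tf.layers.batch_normalization(x, training=is_train)
loss = tf.reduce_mean(tf.square(h))

optimizer = tf.train.GradientDescentOptimizer(0.1)
grads_and_vars = optimizer.compute_gradients(loss, tf.trainable_variables())
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
  # Without this dependency, the moving averages would never be refreshed.
  train_op = optimizer.apply_gradients(grads_and_vars)
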
@@ -333,8 +336,11 @@ def _create_weight_decay_var(self, weight_decay_initial,
     weight_decay_var = None
     weight_decay_update = None
     if weight_decay_schedule is None:
-      weight_decay_var = tf.constant(
-          weight_decay_initial, dtype=tf.float32, name='weight_decay')
+      if weight_decay_initial is None:
+        weight_decay_var = None
+      else:
+        weight_decay_var = tf.constant(
+            weight_decay_initial, dtype=tf.float32, name='weight_decay')
     elif weight_decay_schedule == 'linear':
       weight_decay_var = tf.get_variable(
           name='weight_decay',
@@ -410,6 +416,75 @@ def _eval_random_pairs(self, data, session):
     acc = session.run(self.accuracy, feed_dict=feed_dict)
     return acc
 
+
+  def _eval_train(self, session, feed_dict):
+    """Computes the accuracy of the predictions for the provided batch.
+
+    This calculates the accuracy for both class 1 (agreement) and class 0
+    (disagreement).
+
+    Arguments:
+      session: A TensorFlow session.
+      feed_dict: A train feed dictionary.
+    Returns:
+      The computed train accuracy.
+    """
+    train_acc, pred, targ = session.run(
+        (self.accuracy, self.normalized_predictions, self.labels),
+        feed_dict=feed_dict)
+    # Assume the threshold is at 0.5, and binarize the predictions.
+    binary_pred = pred > 0.5
+    targ = targ.astype(np.int32)
+    acc_per_sample = binary_pred == targ
+    acc_1 = acc_per_sample[targ == 1]
+    if acc_1.shape[0] > 0:
+      acc_1 = sum(acc_1) / np.float32(len(acc_1))
+    else:
+      acc_1 = -1
+    acc_0 = acc_per_sample[targ == 0]
+    if acc_0.shape[0] > 0:
+      acc_0 = sum(acc_0) / np.float32(len(acc_0))
+    else:
+      acc_0 = -1
+    logging.info('Train acc: %.2f. Acc class 1: %.2f. Acc class 0: %.2f',
+                 train_acc, acc_1, acc_0)
+    return train_acc
+
+
+  def _eval_validation(self, data, labeled_nodes_val, ratio_pos_to_neg,
+                       num_samples_val, session):
+    """Evaluate the current model on validation data.
+
+    Args:
+      data: A CotrainDataset object.
+      labeled_nodes_val: An array of indices of labeled nodes from which to
+        sample validation pairs.
+      ratio_pos_to_neg: The ratio of positive to negative samples, which is
+        used to keep the sampled agreement pairs balanced.
+      num_samples_val: Number of sample pairs to use for validation. Since the
+        number of combinations of samples in `labeled_nodes_val` can be very
+        high, for validation we use only `num_samples_val` pairs.
+      session: A TensorFlow session.
+
+    Returns:
+      Total accuracy on random pairs of samples.
+    """
+    data_iterator_val = self._pair_iterator(labeled_nodes_val, data,
+                                            ratio_pos_neg=ratio_pos_to_neg)
+    feed_dict_val = self._construct_feed_dict(
+        data_iterator_val, is_train=False)
+    cummulative_val_acc = 0.0
+    samples_seen = 0
+    while feed_dict_val is not None and samples_seen < num_samples_val:
+      val_acc, batch_size_actual = session.run(
+          (self.accuracy, self.batch_size_actual), feed_dict=feed_dict_val)
+      cummulative_val_acc += val_acc * batch_size_actual
+      samples_seen += batch_size_actual
+      feed_dict_val = self._construct_feed_dict(
+          data_iterator_val, is_train=False)
+    cummulative_val_acc /= samples_seen
+    return cummulative_val_acc
+
   def _train_iterator(self, labeled_samples, neighbors_val, data,
                       ratio_pos_to_neg=None):
     """An iterator over pairs of samples for training the agreement model.
@@ -653,19 +728,10 @@ def train(self, data, session=None, **kwargs):
         if num_samples_val == 0:
           logging.info('Skipping validation. No validation samples available.')
           break
-        data_iterator_val = self._pair_iterator(labeled_nodes_val, data)
-        feed_dict_val = self._construct_feed_dict(
-            data_iterator_val, is_train=False)
-        cummulative_val_acc = 0.0
-        samples_seen = 0
-        while feed_dict_val is not None and samples_seen < num_samples_val:
-          val_acc, batch_size_actual = session.run(
-              (self.accuracy, self.batch_size_actual), feed_dict=feed_dict_val)
-          cummulative_val_acc += val_acc * batch_size_actual
-          samples_seen += batch_size_actual
-          feed_dict_val = self._construct_feed_dict(
-              data_iterator_val, is_train=False)
-        cummulative_val_acc /= samples_seen
+
+        # Evaluate on the selected validation data.
+        val_acc = self._eval_validation(
+            data, labeled_nodes_val, ratio_pos_to_neg, num_samples_val, session)
 
         # Evaluate over a random choice of sample pairs, either labeled or not.
         acc_random = self._eval_random_pairs(data, session)
@@ -680,20 +746,20 @@ def train(self, data, session=None, **kwargs):
         summary.value.add(tag='AgreementModel/train_acc',
                           simple_value=acc_train)
         summary.value.add(tag='AgreementModel/val_acc',
-                          simple_value=cummulative_val_acc)
+                          simple_value=val_acc)
         if acc_random is not None:
           summary.value.add(tag='AgreementModel/random_acc',
                             simple_value=acc_random)
         iter_total = session.run(self.iter_agr_total)
         summary_writer.add_summary(summary, iter_total)
         summary_writer.flush()
-      if step % self.logging_step == 0 or cummulative_val_acc > best_val_acc:
+      if step % self.logging_step == 0 or val_acc > best_val_acc:
         logging.info(
             'Agreement step %6d | Loss: %10.4f | val_acc: %10.4f |'
             'random_acc: %10.4f | acc_train: %10.4f', step, loss_val,
-            cummulative_val_acc, acc_random, acc_train)
-        if cummulative_val_acc > best_val_acc:
-          best_val_acc = cummulative_val_acc
+            val_acc, acc_random, acc_train)
+        if val_acc > best_val_acc:
+          best_val_acc = val_acc
           if self.checkpoint_path:
             self.saver.save(
                 session, self.checkpoint_path, write_meta_graph=False)
@@ -759,39 +825,6 @@ def predict(self, session, src_features, tgt_features, **unused_kwargs):
       # Predict always disagreement.
       return np.zeros(shape=(len(src_features),), dtype=np.float32)
 
-  def _eval_train(self, session, feed_dict):
-    """Computes the accuracy of the predictions for the provided batch.
-
-    This calculates the accuracy for both class 1 (agreement) and class 0
-    (disagreement).
-
-    Arguments:
-      session: A TensorFlow session.
-      feed_dict: A train feed dictionary.
-    Returns:
-      The computed train accuracy.
-    """
-    train_acc, pred, targ = session.run(
-        (self.accuracy, self.normalized_predictions, self.labels),
-        feed_dict=feed_dict)
-    # Assume the threshold is at 0.5, and binarize the predictions.
-    binary_pred = pred > 0.5
-    targ = targ.astype(np.int32)
-    acc_per_sample = binary_pred == targ
-    acc_1 = acc_per_sample[targ == 1]
-    if acc_1.shape[0] > 0:
-      acc_1 = sum(acc_1) / np.float32(len(acc_1))
-    else:
-      acc_1 = -1
-    acc_0 = acc_per_sample[targ == 0]
-    if acc_0.shape[0] > 0:
-      acc_0 = sum(acc_0) / np.float32(len(acc_0))
-    else:
-      acc_0 = -1
-    logging.info('Train acc: %.2f. Acc class 1: %.2f. Acc class 0: %.2f',
-                 train_acc, acc_1, acc_0)
-    return train_acc
-
   def predict_label_by_agreement(self, session, indices, num_neighbors=100):
     """Predict class labels using agreement with other labeled samples.

neural_structured_learning/research/gam/trainer/trainer_classification.py

Lines changed: 14 additions & 6 deletions

@@ -217,7 +217,8 @@ def __init__(self,
       loss_supervised = tf.reduce_mean(loss_supervised)
     else:
       loss_supervised = self.model.get_loss(predictions=predictions,
-                                            targets=one_hot_labels)
+                                            targets=one_hot_labels,
+                                            weight_decay=None)
 
     # Agreement regularization loss.
     loss_agr = self._get_agreement_reg_loss(data, is_train, features_shape)
@@ -229,8 +230,9 @@ def __init__(self,
 
     # Weight decay loss.
     loss_reg = 0.0
-    for var in reg_params.values():
-      loss_reg += weight_decay_var * tf.nn.l2_loss(var)
+    if weight_decay_var is not None:
+      for var in reg_params.values():
+        loss_reg += weight_decay_var * tf.nn.l2_loss(var)
 
     # Total loss.
     loss_op = loss_supervised + loss_agr + loss_reg
@@ -272,8 +274,11 @@ def __init__(self,
       gradients = [elem[0] for elem in grads_and_vars]
       gradients, _ = tf.clip_by_global_norm(gradients, self.gradient_clip)
       grads_and_vars = tuple(zip(gradients, variab))
-    train_op = self.optimizer.apply_gradients(
-        grads_and_vars, global_step=self.global_step)
+    with tf.control_dependencies(
+        tf.get_collection(tf.GraphKeys.UPDATE_OPS,
+                          scope=tf.get_default_graph().get_name_scope())):
+      train_op = self.optimizer.apply_gradients(
+          grads_and_vars, global_step=self.global_step)
 
     # Create a saver for model variables.
     trainable_vars = [v for _, v in grads_and_vars]
@@ -320,8 +325,11 @@ def _create_weight_decay_var(self, weight_decay_initial,
     weight_decay_var = None
     weight_decay_update = None
     if weight_decay_schedule is None:
-      weight_decay_var = tf.constant(
+      if weight_decay_initial is not None:
+        weight_decay_var = tf.constant(
          weight_decay_initial, dtype=tf.float32, name='weight_decay')
+      else:
+        weight_decay_var = None
     elif weight_decay_schedule == 'linear':
       weight_decay_var = tf.get_variable(
           name='weight_decay',
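
Taken together with the flag changes above, the None path now flows end to end: an unset --weight_decay_cls flag leaves weight_decay_var as None, and the loss construction skips the L2 loop. A condensed, self-contained sketch of that path; the helper condenses only the schedule-None branch and is not the trainer's full method:

import tensorflow as tf  # TF 1.x

def create_weight_decay_var(weight_decay_initial, weight_decay_schedule=None):
  # Mirrors the guarded branch: no flag value -> no constant in the graph.
  if weight_decay_schedule is None:
    if weight_decay_initial is None:
      return None
    return tf.constant(weight_decay_initial, dtype=tf.float32,
                       name='weight_decay')
  raise ValueError('Unsupported schedule: %s' % weight_decay_schedule)

w = tf.get_variable('w', shape=[4, 2])
reg_params = {'w': w}

loss_reg = 0.0
weight_decay_var = create_weight_decay_var(None)  # flag left unset
if weight_decay_var is not None:
  for var in reg_params.values():
    loss_reg += weight_decay_var * tf.nn.l2_loss(var)
# loss_reg stays 0.0 and contributes nothing to the total loss.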
