Skip to content

Commit dca1457

Browse files
committed
Refactoring.
1 parent 3265019 commit dca1457

File tree

2 files changed

+62
-54
lines changed

2 files changed

+62
-54
lines changed

neural_structured_learning/research/gam/trainer/trainer_classification.py

Lines changed: 60 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -461,7 +461,7 @@ def _get_agreement_reg_loss(self, data, is_train, features_shape):
461461
if self.penalize_neg_agr:
462462
# Since the agreement is predicting scores between [0, 1], anything
463463
# under 0.5 should represent disagreement. Therefore, we want to encourage
464-
# agreement whenever the score is > 0.5, otherwise don't incurr any loss.
464+
# agreement whenever the score is > 0.5, otherwise don't incur any loss.
465465
agreement = tf.nn.relu(agreement - 0.5)
466466

467467
# Create a Tensor containing the weights assigned to each pair in the
@@ -497,7 +497,7 @@ def _get_agreement_reg_loss(self, data, is_train, features_shape):
497497

498498
def _construct_feed_dict(self,
499499
data_iterator,
500-
is_train,
500+
split,
501501
pair_ll_iterator=None,
502502
pair_lu_iterator=None,
503503
pair_uu_iterator=None):
@@ -506,12 +506,12 @@ def _construct_feed_dict(self,
506506
input_indices = next(data_iterator)
507507
# Select the labels. Use the true, correct labels at test time, and the
508508
# self-labeled ones otherwise (training and validation).
509-
labels = (self.data.get_labels(input_indices) if is_train else
510-
self.data.get_original_labels(input_indices))
509+
labels = (self.data.get_original_labels(input_indices) if split == 'test'
510+
else self.data.get_labels(input_indices))
511511
feed_dict = {
512512
self.input_features: self.data.get_features(input_indices),
513513
self.input_labels: labels,
514-
self.is_train: is_train
514+
self.is_train: split == 'train'
515515
}
516516
if pair_ll_iterator is not None:
517517
_, _, _, features_tgt, labels_src, labels_tgt = next(pair_ll_iterator)
@@ -584,6 +584,37 @@ def _select_from_pool(indices):
584584
yield (indices_src, indices_tgt, features_src, features_tgt,
585585
labels_src, labels_tgt)
586586

587+
def _evaluate(self, indices, split, session, summary_writer):
588+
"""Evaluates the samples with the provided indices."""
589+
data_iterator_val = batch_iterator(
590+
indices,
591+
batch_size=self.batch_size,
592+
shuffle=False,
593+
allow_smaller_batch=True,
594+
repeat=False)
595+
feed_dict_val = self._construct_feed_dict(data_iterator_val, split)
596+
cummulative_acc = 0.0
597+
num_samples = 0
598+
while feed_dict_val is not None:
599+
val_acc, batch_size_actual = session.run(
600+
(self.accuracy, self.batch_size_actual), feed_dict=feed_dict_val)
601+
cummulative_acc += val_acc * batch_size_actual
602+
num_samples += batch_size_actual
603+
feed_dict_val = self._construct_feed_dict(data_iterator_val, split)
604+
if num_samples > 0:
605+
cummulative_acc /= num_samples
606+
607+
if self.enable_summaries:
608+
summary = tf.Summary()
609+
summary.value.add(
610+
tag='ClassificationModel/' + split + '_acc',
611+
simple_value=cummulative_acc)
612+
iter_cls_total = session.run(self.iter_cls_total)
613+
summary_writer.add_summary(summary, iter_cls_total)
614+
summary_writer.flush()
615+
616+
return cummulative_acc
617+
587618
def train(self, data, session=None, **kwargs):
588619
"""Train the classification model on the provided dataset.
589620
@@ -646,68 +677,37 @@ def train(self, data, session=None, **kwargs):
646677
best_val_acc = -1
647678
checkpoint_saved = False
648679
while not has_converged:
649-
feed_dict = self._construct_feed_dict(data_iterator_train, True,
650-
pair_ll_iterator, pair_lu_iterator,
651-
pair_uu_iterator)
680+
feed_dict = self._construct_feed_dict(
681+
data_iterator=data_iterator_train,
682+
split='train',
683+
pair_ll_iterator=pair_ll_iterator,
684+
pair_lu_iterator=pair_lu_iterator,
685+
pair_uu_iterator=pair_uu_iterator)
652686
if self.enable_summaries and step % self.summary_step == 0:
653-
loss_val, summary, _ = session.run(
654-
[self.loss_op, self.summary_op, self.train_op],
687+
loss_val, summary, iter_cls_total, _ = session.run(
688+
[self.loss_op, self.summary_op, self.iter_cls_total, self.train_op],
655689
feed_dict=feed_dict)
656-
iter_cls_total = session.run(self.iter_cls_total)
657690
summary_writer.add_summary(summary, iter_cls_total)
658691
summary_writer.flush()
659692
else:
660-
loss_val, _ = session.run((self.loss_op, self.train_op),
661-
feed_dict=feed_dict)
693+
loss_val, _ = session.run(
694+
(self.loss_op, self.train_op), feed_dict=feed_dict)
662695

663696
# Log the loss, if necessary.
664697
if step % self.logging_step == 0:
665698
logging.info('Classification step %6d | Loss: %10.4f', step, loss_val)
666699

667700
# Evaluate, if necessary.
668-
def _evaluate(indices, name):
669-
"""Evaluates the samples with the provided indices."""
670-
data_iterator_val = batch_iterator(
671-
indices,
672-
batch_size=self.batch_size,
673-
shuffle=False,
674-
allow_smaller_batch=True,
675-
repeat=False)
676-
feed_dict_val = self._construct_feed_dict(data_iterator_val, False)
677-
cummulative_acc = 0.0
678-
num_samples = 0
679-
while feed_dict_val is not None:
680-
val_acc, batch_size_actual = session.run(
681-
(self.accuracy, self.batch_size_actual), feed_dict=feed_dict_val)
682-
cummulative_acc += val_acc * batch_size_actual
683-
num_samples += batch_size_actual
684-
feed_dict_val = self._construct_feed_dict(data_iterator_val, False)
685-
if num_samples > 0:
686-
cummulative_acc /= num_samples
687-
688-
if self.enable_summaries:
689-
summary = tf.Summary()
690-
summary.value.add(
691-
tag='ClassificationModel/' + name + '_acc',
692-
simple_value=cummulative_acc)
693-
iter_cls_total = session.run(self.iter_cls_total)
694-
summary_writer.add_summary(summary, iter_cls_total)
695-
summary_writer.flush()
696-
697-
return cummulative_acc
698-
699-
# Run validation, if necessary.
700701
if step % self.eval_step == 0:
701702
logging.info('Evaluating on %d validation samples...', len(val_indices))
702-
val_acc = _evaluate(val_indices, 'val_acc')
703+
val_acc = self._evaluate(val_indices, 'val', session, summary_writer)
703704
logging.info('Evaluating on %d test samples...', len(test_indices))
704-
test_acc = _evaluate(test_indices, 'test_acc')
705+
test_acc = self._evaluate(test_indices, 'test', session, summary_writer)
705706

706707
if step % self.logging_step == 0 or val_acc > best_val_acc:
707708
logging.info(
708-
'Classification step %6d | Loss: %10.4f | '
709-
'val_acc: %10.4f | test_acc: %10.4f', step, loss_val, val_acc,
710-
test_acc)
709+
'Classification step %6d | Loss: %10.4f | val_acc: %10.4f | '
710+
'test_acc: %10.4f', step, loss_val, val_acc, test_acc)
711711
if val_acc > best_val_acc:
712712
best_val_acc = val_acc
713713
best_test_acc = test_acc
@@ -738,9 +738,15 @@ def _evaluate(indices, name):
738738
logging.info('Restoring best model...')
739739
self.saver.restore(session, self.checkpoint_path)
740740

741+
########################
742+
# TEST
743+
test_acc = self._evaluate(test_indices, 'test', session, summary_writer)
744+
print('\n\nTest acc after restoration: %f. Test at best val: %f \n\n' % (test_acc, best_test_acc))
745+
##########################
746+
741747
return best_test_acc, best_val_acc
742748

743-
def predict(self, session, indices):
749+
def predict(self, session, indices, is_train):
744750
"""Make predictions for the provided sample indices."""
745751
num_inputs = len(indices)
746752
idx_start = 0
@@ -751,7 +757,8 @@ def predict(self, session, indices):
751757
input_features = self.data.get_features(batch_indices)
752758
batch_predictions = session.run(
753759
self.normalized_predictions,
754-
feed_dict={self.input_features: input_features})
760+
feed_dict={self.input_features: input_features,
761+
self.is_train:is_train})
755762
predictions.append(batch_predictions)
756763
idx_start = idx_end
757764
if not predictions:
@@ -770,7 +777,7 @@ def train(self, unused_data, unused_session=None, **unused_kwargs):
770777
logging.info('Perfect classifier, no need to train...')
771778
return 1.0, 1.0
772779

773-
def predict(self, unused_session, indices_unlabeled):
780+
def predict(self, unused_session, indices_unlabeled, **unused_kwargs):
774781
labels = self.data.get_original_labels(indices_unlabeled)
775782
num_samples = len(indices_unlabeled)
776783
predictions = np.zeros((num_samples, self.data.num_classes))

neural_structured_learning/research/gam/trainer/trainer_cotrain.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -334,7 +334,8 @@ def _select_samples_to_label(self, data, trainer_cls, session):
334334
"""
335335
# Select the candidate samples for self-labeling, and make predictions.
336336
indices_unlabeled = data.get_indices_unlabeled()
337-
predictions = trainer_cls.predict(session, indices_unlabeled)
337+
predictions = trainer_cls.predict(
338+
session, indices_unlabeled, is_train=False)
338339
# Select most confident nodes. Compute confidence and most confident label,
339340
# which will be used as the new label.
340341
predicted_label = np.argmax(predictions, axis=-1)

0 commit comments

Comments
 (0)