@@ -461,7 +461,7 @@ def _get_agreement_reg_loss(self, data, is_train, features_shape):
461461 if self .penalize_neg_agr :
462462 # Since the agreement is predicting scores between [0, 1], anything
463463 # under 0.5 should represent disagreement. Therefore, we want to encourage
464- # agreement whenever the score is > 0.5, otherwise don't incurr any loss.
464+ # agreement whenever the score is > 0.5, otherwise don't incur any loss.
465465 agreement = tf .nn .relu (agreement - 0.5 )
466466
467467 # Create a Tensor containing the weights assigned to each pair in the
@@ -497,7 +497,7 @@ def _get_agreement_reg_loss(self, data, is_train, features_shape):
497497
498498 def _construct_feed_dict (self ,
499499 data_iterator ,
500- is_train ,
500+ split ,
501501 pair_ll_iterator = None ,
502502 pair_lu_iterator = None ,
503503 pair_uu_iterator = None ):
@@ -506,12 +506,12 @@ def _construct_feed_dict(self,
506506 input_indices = next (data_iterator )
507507 # Select the labels. Use the true, correct labels, at test time, and the
508508 # self-labeled ones at train time.
509- labels = (self .data .get_labels (input_indices ) if is_train else
510- self .data .get_original_labels (input_indices ))
509+ labels = (self .data .get_original_labels (input_indices ) if split == 'test'
510+ else self .data .get_labels (input_indices ))
511511 feed_dict = {
512512 self .input_features : self .data .get_features (input_indices ),
513513 self .input_labels : labels ,
514- self .is_train : is_train
514+ self .is_train : split == 'train'
515515 }
516516 if pair_ll_iterator is not None :
517517 _ , _ , _ , features_tgt , labels_src , labels_tgt = next (pair_ll_iterator )
@@ -584,6 +584,37 @@ def _select_from_pool(indices):
584584 yield (indices_src , indices_tgt , features_src , features_tgt ,
585585 labels_src , labels_tgt )
586586
587+ def _evaluate (self , indices , split , session , summary_writer ):
588+ """Evaluates the samples with the provided indices."""
589+ data_iterator_val = batch_iterator (
590+ indices ,
591+ batch_size = self .batch_size ,
592+ shuffle = False ,
593+ allow_smaller_batch = True ,
594+ repeat = False )
595+ feed_dict_val = self ._construct_feed_dict (data_iterator_val , split )
596+ cummulative_acc = 0.0
597+ num_samples = 0
598+ while feed_dict_val is not None :
599+ val_acc , batch_size_actual = session .run (
600+ (self .accuracy , self .batch_size_actual ), feed_dict = feed_dict_val )
601+ cummulative_acc += val_acc * batch_size_actual
602+ num_samples += batch_size_actual
603+ feed_dict_val = self ._construct_feed_dict (data_iterator_val , split )
604+ if num_samples > 0 :
605+ cummulative_acc /= num_samples
606+
607+ if self .enable_summaries :
608+ summary = tf .Summary ()
609+ summary .value .add (
610+ tag = 'ClassificationModel/' + split + '_acc' ,
611+ simple_value = cummulative_acc )
612+ iter_cls_total = session .run (self .iter_cls_total )
613+ summary_writer .add_summary (summary , iter_cls_total )
614+ summary_writer .flush ()
615+
616+ return cummulative_acc
617+
587618 def train (self , data , session = None , ** kwargs ):
588619 """Train the classification model on the provided dataset.
589620
@@ -646,68 +677,37 @@ def train(self, data, session=None, **kwargs):
646677 best_val_acc = - 1
647678 checkpoint_saved = False
648679 while not has_converged :
649- feed_dict = self ._construct_feed_dict (data_iterator_train , True ,
650- pair_ll_iterator , pair_lu_iterator ,
651- pair_uu_iterator )
680+ feed_dict = self ._construct_feed_dict (
681+ data_iterator = data_iterator_train ,
682+ split = 'train' ,
683+ pair_ll_iterator = pair_ll_iterator ,
684+ pair_lu_iterator = pair_lu_iterator ,
685+ pair_uu_iterator = pair_uu_iterator )
652686 if self .enable_summaries and step % self .summary_step == 0 :
653- loss_val , summary , _ = session .run (
654- [self .loss_op , self .summary_op , self .train_op ],
687+ loss_val , summary , iter_cls_total , _ = session .run (
688+ [self .loss_op , self .summary_op , self .iter_cls_total , self . train_op ],
655689 feed_dict = feed_dict )
656- iter_cls_total = session .run (self .iter_cls_total )
657690 summary_writer .add_summary (summary , iter_cls_total )
658691 summary_writer .flush ()
659692 else :
660- loss_val , _ = session .run (( self . loss_op , self . train_op ),
661- feed_dict = feed_dict )
693+ loss_val , _ = session .run (
694+ ( self . loss_op , self . train_op ), feed_dict = feed_dict )
662695
663696 # Log the loss, if necessary.
664697 if step % self .logging_step == 0 :
665698 logging .info ('Classification step %6d | Loss: %10.4f' , step , loss_val )
666699
667700 # Evaluate, if necessary.
668- def _evaluate (indices , name ):
669- """Evaluates the samples with the provided indices."""
670- data_iterator_val = batch_iterator (
671- indices ,
672- batch_size = self .batch_size ,
673- shuffle = False ,
674- allow_smaller_batch = True ,
675- repeat = False )
676- feed_dict_val = self ._construct_feed_dict (data_iterator_val , False )
677- cummulative_acc = 0.0
678- num_samples = 0
679- while feed_dict_val is not None :
680- val_acc , batch_size_actual = session .run (
681- (self .accuracy , self .batch_size_actual ), feed_dict = feed_dict_val )
682- cummulative_acc += val_acc * batch_size_actual
683- num_samples += batch_size_actual
684- feed_dict_val = self ._construct_feed_dict (data_iterator_val , False )
685- if num_samples > 0 :
686- cummulative_acc /= num_samples
687-
688- if self .enable_summaries :
689- summary = tf .Summary ()
690- summary .value .add (
691- tag = 'ClassificationModel/' + name + '_acc' ,
692- simple_value = cummulative_acc )
693- iter_cls_total = session .run (self .iter_cls_total )
694- summary_writer .add_summary (summary , iter_cls_total )
695- summary_writer .flush ()
696-
697- return cummulative_acc
698-
699- # Run validation, if necessary.
700701 if step % self .eval_step == 0 :
701702 logging .info ('Evaluating on %d validation samples...' , len (val_indices ))
702- val_acc = _evaluate (val_indices , 'val_acc' )
703+ val_acc = self . _evaluate (val_indices , 'val' , session , summary_writer )
703704 logging .info ('Evaluating on %d test samples...' , len (test_indices ))
704- test_acc = _evaluate (test_indices , 'test_acc' )
705+ test_acc = self . _evaluate (test_indices , 'test' , session , summary_writer )
705706
706707 if step % self .logging_step == 0 or val_acc > best_val_acc :
707708 logging .info (
708- 'Classification step %6d | Loss: %10.4f | '
709- 'val_acc: %10.4f | test_acc: %10.4f' , step , loss_val , val_acc ,
710- test_acc )
709+ 'Classification step %6d | Loss: %10.4f | val_acc: %10.4f | '
710+ 'test_acc: %10.4f' , step , loss_val , val_acc , test_acc )
711711 if val_acc > best_val_acc :
712712 best_val_acc = val_acc
713713 best_test_acc = test_acc
@@ -738,9 +738,15 @@ def _evaluate(indices, name):
738738 logging .info ('Restoring best model...' )
739739 self .saver .restore (session , self .checkpoint_path )
740740
741+ ########################
742+ # TEST
743+ test_acc = self ._evaluate (test_indices , 'test' , session , summary_writer )
744+ print ('\n \n Test acc after restoration: %f. Test at best val: %f \n \n ' % (test_acc , best_test_acc ))
745+ ##########################
746+
741747 return best_test_acc , best_val_acc
742748
743- def predict (self , session , indices ):
749+ def predict (self , session , indices , is_train ):
744750 """Make predictions for the provided sample indices."""
745751 num_inputs = len (indices )
746752 idx_start = 0
@@ -751,7 +757,8 @@ def predict(self, session, indices):
751757 input_features = self .data .get_features (batch_indices )
752758 batch_predictions = session .run (
753759 self .normalized_predictions ,
754- feed_dict = {self .input_features : input_features })
760+ feed_dict = {self .input_features : input_features ,
761+ self .is_train :is_train })
755762 predictions .append (batch_predictions )
756763 idx_start = idx_end
757764 if not predictions :
@@ -770,7 +777,7 @@ def train(self, unused_data, unused_session=None, **unused_kwargs):
770777 logging .info ('Perfect classifier, no need to train...' )
771778 return 1.0 , 1.0
772779
773- def predict (self , unused_session , indices_unlabeled ):
780+ def predict (self , unused_session , indices_unlabeled , ** unused_kwargs ):
774781 labels = self .data .get_original_labels (indices_unlabeled )
775782 num_samples = len (indices_unlabeled )
776783 predictions = np .zeros ((num_samples , self .data .num_classes ))
0 commit comments