@@ -221,15 +221,18 @@ def __init__(self,
221221 # Create train op.
222222 grads_and_vars = self .optimizer .compute_gradients (
223223 loss_op ,
224- tf .trainable_variables ())
224+ tf .trainable_variables (scope = tf . get_default_graph (). get_name_scope () ))
225225 # Clip gradients.
226226 if self .gradient_clip :
227227 variab = [elem [1 ] for elem in grads_and_vars ]
228228 gradients = [elem [0 ] for elem in grads_and_vars ]
229229 gradients , _ = tf .clip_by_global_norm (gradients , self .gradient_clip )
230230 grads_and_vars = tuple (zip (gradients , variab ))
231- train_op = self .optimizer .apply_gradients (
232- grads_and_vars , global_step = self .global_step )
231+ with tf .control_dependencies (
232+ tf .get_collection (tf .GraphKeys .UPDATE_OPS ,
233+ scope = tf .get_default_graph ().get_name_scope ())):
234+ train_op = self .optimizer .apply_gradients (
235+ grads_and_vars , global_step = self .global_step )
233236
234237 # Create Tensorboard summaries.
235238 if self .enable_summaries :
@@ -333,8 +336,11 @@ def _create_weight_decay_var(self, weight_decay_initial,
333336 weight_decay_var = None
334337 weight_decay_update = None
335338 if weight_decay_schedule is None :
336- weight_decay_var = tf .constant (
337- weight_decay_initial , dtype = tf .float32 , name = 'weight_decay' )
339+ if weight_decay_initial is None :
340+ weight_decay_var = None
341+ else :
342+ weight_decay_var = tf .constant (
343+ weight_decay_initial , dtype = tf .float32 , name = 'weight_decay' )
338344 elif weight_decay_schedule == 'linear' :
339345 weight_decay_var = tf .get_variable (
340346 name = 'weight_decay' ,
@@ -410,6 +416,75 @@ def _eval_random_pairs(self, data, session):
410416 acc = session .run (self .accuracy , feed_dict = feed_dict )
411417 return acc
412418
419+
420+ def _eval_train (self , session , feed_dict ):
421+ """Computes the accuracy of the predictions for the provided batch.
422+
423+ This calculates the accuracy for both class 1 (agreement) and class 0
424+ (disagreement).
425+
426+ Arguments:
427+ session: A TensorFlow session.
428+ feed_dict: A train feed dictionary.
429+ Returns:
430+ The computed train accuracy.
431+ """
432+ train_acc , pred , targ = session .run (
433+ (self .accuracy , self .normalized_predictions , self .labels ),
434+ feed_dict = feed_dict )
435+ # Assume the threshold is at 0.5, and binarize the predictions.
436+ binary_pred = pred > 0.5
437+ targ = targ .astype (np .int32 )
438+ acc_per_sample = binary_pred == targ
439+ acc_1 = acc_per_sample [targ == 1 ]
440+ if acc_1 .shape [0 ] > 0 :
441+ acc_1 = sum (acc_1 ) / np .float32 (len (acc_1 ))
442+ else :
443+ acc_1 = - 1
444+ acc_0 = acc_per_sample [targ == 0 ]
445+ if acc_0 .shape [0 ] > 0 :
446+ acc_0 = sum (acc_0 ) / np .float32 (len (acc_0 ))
447+ else :
448+ acc_0 = - 1
449+ logging .info ('Train acc: %.2f. Acc class 1: %.2f. Acc class 0: %.2f' ,
450+ train_acc , acc_1 , acc_0 )
451+ return train_acc
452+
453+
454+ def _eval_validation (self , data , labeled_nodes_val , ratio_pos_to_neg ,
455+ num_samples_val , session ):
456+ """Evaluate the current model on validation data.
457+
458+ Args:
459+ data: A CotrainDataset object.
460+ labeled_nodes_val: An array of indices of labeled nodes from which to
461+ sample validation pairs.
462+ ratio_pos_to_neg: The ratio of positive to negative samples, which is
463+ used to keep the samples agreement pairs balanced.
464+ num_samples_val: Number of sample pairs to use for validation. Since the
465+ number of combinations of samples in `labeled_nodes_val` can be very
466+ high, for validation we use only `num_samples_val` pairs.
467+ session: A TensorFlow session.
468+
469+ Returns:
470+ Total accuracy on random pairs of samples.
471+ """
472+ data_iterator_val = self ._pair_iterator (labeled_nodes_val , data ,
473+ ratio_pos_neg = ratio_pos_to_neg )
474+ feed_dict_val = self ._construct_feed_dict (
475+ data_iterator_val , is_train = False )
476+ cummulative_val_acc = 0.0
477+ samples_seen = 0
478+ while feed_dict_val is not None and samples_seen < num_samples_val :
479+ val_acc , batch_size_actual = session .run (
480+ (self .accuracy , self .batch_size_actual ), feed_dict = feed_dict_val )
481+ cummulative_val_acc += val_acc * batch_size_actual
482+ samples_seen += batch_size_actual
483+ feed_dict_val = self ._construct_feed_dict (
484+ data_iterator_val , is_train = False )
485+ cummulative_val_acc /= samples_seen
486+ return cummulative_val_acc
487+
413488 def _train_iterator (self , labeled_samples , neighbors_val , data ,
414489 ratio_pos_to_neg = None ):
415490 """An iterator over pairs of samples for training the agreement model.
@@ -653,19 +728,10 @@ def train(self, data, session=None, **kwargs):
653728 if num_samples_val == 0 :
654729 logging .info ('Skipping validation. No validation samples available.' )
655730 break
656- data_iterator_val = self ._pair_iterator (labeled_nodes_val , data )
657- feed_dict_val = self ._construct_feed_dict (
658- data_iterator_val , is_train = False )
659- cummulative_val_acc = 0.0
660- samples_seen = 0
661- while feed_dict_val is not None and samples_seen < num_samples_val :
662- val_acc , batch_size_actual = session .run (
663- (self .accuracy , self .batch_size_actual ), feed_dict = feed_dict_val )
664- cummulative_val_acc += val_acc * batch_size_actual
665- samples_seen += batch_size_actual
666- feed_dict_val = self ._construct_feed_dict (
667- data_iterator_val , is_train = False )
668- cummulative_val_acc /= samples_seen
731+
732+ # Evaluate on the selected validation data.
733+ val_acc = self ._eval_validation (
734+ data , labeled_nodes_val , ratio_pos_to_neg , num_samples_val , session )
669735
670736 # Evaluate over a random choice of sample pairs, either labeled or not.
671737 acc_random = self ._eval_random_pairs (data , session )
@@ -680,20 +746,20 @@ def train(self, data, session=None, **kwargs):
680746 summary .value .add (tag = 'AgreementModel/train_acc' ,
681747 simple_value = acc_train )
682748 summary .value .add (tag = 'AgreementModel/val_acc' ,
683- simple_value = cummulative_val_acc )
749+ simple_value = val_acc )
684750 if acc_random is not None :
685751 summary .value .add (tag = 'AgreementModel/random_acc' ,
686752 simple_value = acc_random )
687753 iter_total = session .run (self .iter_agr_total )
688754 summary_writer .add_summary (summary , iter_total )
689755 summary_writer .flush ()
690- if step % self .logging_step == 0 or cummulative_val_acc > best_val_acc :
756+ if step % self .logging_step == 0 or val_acc > best_val_acc :
691757 logging .info (
692758 'Agreement step %6d | Loss: %10.4f | val_acc: %10.4f |'
693759 'random_acc: %10.4f | acc_train: %10.4f' , step , loss_val ,
694- cummulative_val_acc , acc_random , acc_train )
695- if cummulative_val_acc > best_val_acc :
696- best_val_acc = cummulative_val_acc
760+ val_acc , acc_random , acc_train )
761+ if val_acc > best_val_acc :
762+ best_val_acc = val_acc
697763 if self .checkpoint_path :
698764 self .saver .save (
699765 session , self .checkpoint_path , write_meta_graph = False )
@@ -759,39 +825,6 @@ def predict(self, session, src_features, tgt_features, **unused_kwargs):
759825 # Predict always disagreement.
760826 return np .zeros (shape = (len (src_features ),), dtype = np .float32 )
761827
762- def _eval_train (self , session , feed_dict ):
763- """Computes the accuracy of the predictions for the provided batch.
764-
765- This calculates the accuracy for both class 1 (agreement) and class 0
766- (disagreement).
767-
768- Arguments:
769- session: A TensorFlow session.
770- feed_dict: A train feed dictionary.
771- Returns:
772- The computed train accuracy.
773- """
774- train_acc , pred , targ = session .run (
775- (self .accuracy , self .normalized_predictions , self .labels ),
776- feed_dict = feed_dict )
777- # Assume the threshold is at 0.5, and binarize the predictions.
778- binary_pred = pred > 0.5
779- targ = targ .astype (np .int32 )
780- acc_per_sample = binary_pred == targ
781- acc_1 = acc_per_sample [targ == 1 ]
782- if acc_1 .shape [0 ] > 0 :
783- acc_1 = sum (acc_1 ) / np .float32 (len (acc_1 ))
784- else :
785- acc_1 = - 1
786- acc_0 = acc_per_sample [targ == 0 ]
787- if acc_0 .shape [0 ] > 0 :
788- acc_0 = sum (acc_0 ) / np .float32 (len (acc_0 ))
789- else :
790- acc_0 = - 1
791- logging .info ('Train acc: %.2f. Acc class 1: %.2f. Acc class 0: %.2f' ,
792- train_acc , acc_1 , acc_0 )
793- return train_acc
794-
795828 def predict_label_by_agreement (self , session , indices , num_neighbors = 100 ):
796829 """Predict class labels using agreement with other labeled samples.
797830
0 commit comments