@@ -614,16 +614,12 @@ def train(self, data, session=None, **kwargs):
614614 # Compute ratio of positives to negative samples.
615615 labeled_samples_labels = data .get_labels (labeled_samples )
616616 ratio_pos_to_neg = self ._compute_ratio_pos_neg (labeled_samples_labels )
617- # Select a validation set out of all pairs of labeled samples.
618- # TODO: remove this.
619- # neighbors_val, agreement_labels_val = self._select_val_set(
620- # labeled_samples, num_samples_val, data, ratio_pos_to_neg)
621- # Create a train iterator that potentially excludes the validation samples.
622- # data_iterator_train = self._train_iterator(
623- # labeled_samples, neighbors_val, data, ratio_pos_to_neg=ratio_pos_to_neg)
624617
618+ # Split data into train and validation.
625619 labeled_samples_train , labeled_nodes_val = self ._select_val_samples (
626620 labeled_samples , self .ratio_val )
621+
622+ # Create an iterator over training data pairs.
627623 data_iterator_train = self ._pair_iterator (labeled_samples_train , data ,
628624 ratio_pos_neg = ratio_pos_to_neg )
629625
@@ -659,14 +655,6 @@ def train(self, data, session=None, **kwargs):
659655 if num_samples_val == 0 :
660656 logging .info ('Skipping validation. No validation samples available.' )
661657 break
662- # TODO: remove this.
663- # data_iterator_val = batch_iterator(
664- # neighbors_val,
665- # agreement_labels_val,
666- # self.batch_size,
667- # shuffle=False,
668- # allow_smaller_batch=True,
669- # repeat=False)
670658 data_iterator_val = self ._pair_iterator (labeled_nodes_val , data )
671659 feed_dict_val = self ._construct_feed_dict (
672660 data_iterator_val , is_train = False )
@@ -880,7 +868,28 @@ def predict_label_by_agreement(self, session, indices, num_neighbors=100):
880868 return acc
881869
882870 def _pair_iterator (self , labeled_nodes , data , ratio_pos_neg = None ):
883- # TODO: add documentation and rename neighbors to samples.
871+ """An iterator over pairs of samples for training the agreement model.
872+
873+ Provides batches of node pairs, including their features and the agreement
874+ label (i.e. whether their labels agree).
875+
876+ Arguments:
877+ labeled_nodes: An array of integers representing the indices of the
878+ labeled samples.
879+ data: A Dataset object used to provided the labels of the labeled samples.
880+ ratio_pos_neg: A float representing the ratio of positive to negative
881+ samples in the training set. If this is provided, the train iterator
882+ will do rejection sampling based on this ratio to keep the training
883+ data balanced. If None, we sample uniformly.
884+
885+ Yields:
886+ neighbors_batch: An array of shape (batch_size, 2), where each row
887+ represents a pair of sample indices used for training. It will not
888+ include pairs of samples that are in the provided neighbors_val.
889+ agreement_batch: An array of shape (batch_size,) with binary values,
890+ where each row represents whether the labels of the corresponding
891+ neighbor pair agree (1.0) or not (0.0).
892+ """
884893 neighbors_batch = np .empty (shape = (self .batch_size , 2 ), dtype = np .int32 )
885894 agreement_batch = np .empty (shape = (self .batch_size ,), dtype = np .float32 )
886895 while True :
@@ -905,7 +914,23 @@ def _pair_iterator(self, labeled_nodes, data, ratio_pos_neg=None):
905914 yield neighbors_batch , agreement_batch
906915
907916 def _select_val_samples (self , labeled_samples , ratio_val ):
908- # TODO: add documentation.
917+ """Split the labeled samples into a train and a validation set.
918+
919+ The agreement model is trained using pairs of labeled samples from the train
920+ set, and is evaluated on pairs of labeled samples from the validation set.
921+
922+ Arguments:
923+ labeled_samples:
924+ ratio_val: A number between (0, 1) representing the ratio of all labeled
925+ samples to be set aside for validation.
926+
927+ Returns:
928+ labeled_samples_train: An array containig a subset of the provided
929+ labeled_samples which will be used for training.
930+ labeled_samples_val: An array containig a subset of the provided
931+ labeled_samples which will be used for validation. The train and
932+ validation indices are non-overlapping.
933+ """
909934 num_labeled_samples = labeled_samples .shape [0 ]
910935 num_labeled_samples_val = int (num_labeled_samples * ratio_val )
911936 self .rng .shuffle (labeled_samples )
0 commit comments