Skip to content

Commit 404c29e

Browse files
committed
Fixing documentation.
1 parent bed82b9 commit 404c29e

File tree

3 files changed

+48
-27
lines changed

3 files changed

+48
-27
lines changed

neural_structured_learning/research/gam/models/cnn.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -42,15 +42,11 @@ class ImageCNNAgreement(Model):
4242
output_dim: Integer representing the number of classes.
4343
channels: Integer representing the number of channels in the input images
4444
(e.g., 1 for black and white, 3 for RGB).
45-
aggregation: String representing an aggregation operation that could be
46-
applied to the inputs. See superclass attributes for details.
47-
hidden_prediction: A tuple or list of integers representing the number of
48-
units in each layer of output multilayer percepron. After the inputs are
49-
passed through the convolution layers (and potentially aggregated), they
50-
are passed through a fully connected network with these numbers of hidden
51-
units in each layer.
45+
aggregation: String representing an aggregation operation that is applied
46+
on the two inputs of the agreement model, after they are encoded through
47+
the convolution layers. See superclass attributes for details.
5248
activation: An activation function to be applied to the outputs of each
53-
fully connected layer.
49+
fully connected layer of the aggregation network.
5450
is_binary_classification: Boolean specifying if this is a model for
5551
binary classification. If so, it uses a different loss function and
5652
returns predictions with a single dimension, batch size.

neural_structured_learning/research/gam/models/wide_resnet.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def get_encoding_and_params(self,
121121
inputs,
122122
is_train,
123123
update_batch_stats=True,
124-
**kwargs):
124+
**unused_kwargs):
125125
"""Creates the model hidden representations and prediction ops.
126126
127127
For this model, the hidden representation is the last layer
@@ -134,7 +134,7 @@ def get_encoding_and_params(self,
134134
this model will be used for training or for test.
135135
update_batch_stats: Boolean specifying whether to update the batch norm
136136
statistics.
137-
**kwargs: Other keyword arguments.
137+
**unused_kwargs: Other unused keyword arguments.
138138
139139
Returns:
140140
encoding: A tensor containing an encoded batch of samples. The first

neural_structured_learning/research/gam/trainer/trainer_agreement.py

Lines changed: 42 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -614,16 +614,12 @@ def train(self, data, session=None, **kwargs):
614614
# Compute ratio of positives to negative samples.
615615
labeled_samples_labels = data.get_labels(labeled_samples)
616616
ratio_pos_to_neg = self._compute_ratio_pos_neg(labeled_samples_labels)
617-
# Select a validation set out of all pairs of labeled samples.
618-
# TODO: remove this.
619-
# neighbors_val, agreement_labels_val = self._select_val_set(
620-
# labeled_samples, num_samples_val, data, ratio_pos_to_neg)
621-
# Create a train iterator that potentially excludes the validation samples.
622-
# data_iterator_train = self._train_iterator(
623-
# labeled_samples, neighbors_val, data, ratio_pos_to_neg=ratio_pos_to_neg)
624617

618+
# Split data into train and validation.
625619
labeled_samples_train, labeled_nodes_val = self._select_val_samples(
626620
labeled_samples, self.ratio_val)
621+
622+
# Create an iterator over training data pairs.
627623
data_iterator_train = self._pair_iterator(labeled_samples_train, data,
628624
ratio_pos_neg=ratio_pos_to_neg)
629625

@@ -659,14 +655,6 @@ def train(self, data, session=None, **kwargs):
659655
if num_samples_val == 0:
660656
logging.info('Skipping validation. No validation samples available.')
661657
break
662-
# TODO: remove this.
663-
# data_iterator_val = batch_iterator(
664-
# neighbors_val,
665-
# agreement_labels_val,
666-
# self.batch_size,
667-
# shuffle=False,
668-
# allow_smaller_batch=True,
669-
# repeat=False)
670658
data_iterator_val = self._pair_iterator(labeled_nodes_val, data)
671659
feed_dict_val = self._construct_feed_dict(
672660
data_iterator_val, is_train=False)
@@ -880,7 +868,28 @@ def predict_label_by_agreement(self, session, indices, num_neighbors=100):
880868
return acc
881869

882870
def _pair_iterator(self, labeled_nodes, data, ratio_pos_neg=None):
883-
# TODO: add documentation and rename neighbors to samples.
871+
"""An iterator over pairs of samples for training the agreement model.
872+
873+
Provides batches of node pairs, including their features and the agreement
874+
label (i.e. whether their labels agree).
875+
876+
Arguments:
877+
labeled_nodes: An array of integers representing the indices of the
878+
labeled samples.
879+
data: A Dataset object used to provide the labels of the labeled samples.
880+
ratio_pos_neg: A float representing the ratio of positive to negative
881+
samples in the training set. If this is provided, the train iterator
882+
will do rejection sampling based on this ratio to keep the training
883+
data balanced. If None, we sample uniformly.
884+
885+
Yields:
886+
neighbors_batch: An array of shape (batch_size, 2), where each row
887+
represents a pair of sample indices used for training. It will not
888+
include pairs of samples that are in the provided neighbors_val.
889+
agreement_batch: An array of shape (batch_size,) with binary values,
890+
where each row represents whether the labels of the corresponding
891+
neighbor pair agree (1.0) or not (0.0).
892+
"""
884893
neighbors_batch = np.empty(shape=(self.batch_size, 2), dtype=np.int32)
885894
agreement_batch = np.empty(shape=(self.batch_size,), dtype=np.float32)
886895
while True:
@@ -905,7 +914,23 @@ def _pair_iterator(self, labeled_nodes, data, ratio_pos_neg=None):
905914
yield neighbors_batch, agreement_batch
906915

907916
def _select_val_samples(self, labeled_samples, ratio_val):
908-
# TODO: add documentation.
917+
"""Split the labeled samples into a train and a validation set.
918+
919+
The agreement model is trained using pairs of labeled samples from the train
920+
set, and is evaluated on pairs of labeled samples from the validation set.
921+
922+
Arguments:
923+
labeled_samples: An array of indices of the labeled samples (shuffled in place).
924+
ratio_val: A number between (0, 1) representing the ratio of all labeled
925+
samples to be set aside for validation.
926+
927+
Returns:
928+
labeled_samples_train: An array containig a subset of the provided
929+
labeled_samples which will be used for training.
930+
labeled_samples_val: An array containig a subset of the provided
931+
labeled_samples which will be used for validation. The train and
932+
validation indices are non-overlapping.
933+
"""
909934
num_labeled_samples = labeled_samples.shape[0]
910935
num_labeled_samples_val = int(num_labeled_samples * ratio_val)
911936
self.rng.shuffle(labeled_samples)

0 commit comments

Comments
 (0)