Skip to content

Commit eef1760

Browse files
committed
Removed unused function.
1 parent 8a90308 commit eef1760

File tree

1 file changed

+1
-98
lines changed

1 file changed

+1
-98
lines changed

neural_structured_learning/research/gam/trainer/trainer_agreement.py

Lines changed: 1 addition & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ def __init__(self,
127127
weight_decay_schedule=None,
128128
num_pairs_eval_random=1000,
129129
agree_by_default=False,
130-
percent_val=0.2,
130+
percent_val=0.1,
131131
max_num_samples_val=10000,
132132
seed=None,
133133
lr_decay_steps=None,
@@ -485,103 +485,6 @@ def _eval_validation(self, data, labeled_nodes_val, ratio_pos_to_neg,
485485
cummulative_val_acc /= samples_seen
486486
return cummulative_val_acc
487487

488-
def _train_iterator(self, labeled_samples, neighbors_val, data,
489-
ratio_pos_to_neg=None):
490-
"""An iterator over pairs of samples for training the agreement model.
491-
492-
Provides batches of node pairs, including their features and the agreement
493-
label (i.e. whether their labels agree). A set of validation pairs
494-
is also provided to make sure those samples are not included in train.
495-
496-
Arguments:
497-
labeled_samples: An array of integers representing the indices of the
498-
labeled nodes.
499-
neighbors_val: An array of shape (num_samples, 2), where each row
500-
represents a pair of sample indices used for validation.
501-
data: A Dataset object used to provided the labels of the labeled samples.
502-
ratio_pos_to_neg: A float representing the ratio of positive to negative
503-
samples in the training set. If this is provided, the train iterator
504-
will do rejection sampling based on this ratio to keep the training
505-
data balanced. If None, we sample uniformly.
506-
Yields:
507-
neighbors_batch: An array of shape (batch_size, 2), where each row
508-
represents a pair of sample indices used for training. It will not
509-
include pairs of samples that are in the provided neighbors_val.
510-
agreement_batch: An array of shape (batch_size,) with binary values,
511-
where each row represents whether the labels of the corresponding
512-
neighbor pair agree (1.0) or not (0.0).
513-
"""
514-
neighbors_val = set([(pair[0], pair[1]) if pair[0] < pair[1] else
515-
(pair[1], pair[0]) for pair in neighbors_val])
516-
neighbors_batch = np.empty(shape=(self.batch_size, 2), dtype=np.int32)
517-
agreement_batch = np.empty(shape=(self.batch_size,), dtype=np.float32)
518-
# TODO(otilastr): remove this. Temporary while fixing something.
519-
# For sampling random pairs of samples very fast, we create two buffers,
520-
# one containing elements for the left side of the pair, the other for the
521-
# right side, and we go through them in parallel.
522-
# buffer_left = np.copy(labeled_samples)
523-
# buffer_right = np.copy(labeled_samples)
524-
# idx_buffer = np.inf
525-
# num_labeled = len(labeled_samples)
526-
# while True:
527-
# num_added = 0
528-
# while num_added < self.batch_size:
529-
# if idx_buffer >= num_labeled:
530-
# idx_buffer = 0
531-
# self.rng.shuffle(buffer_left)
532-
# self.rng.shuffle(buffer_right)
533-
# pair = (buffer_left[idx_buffer], buffer_right[idx_buffer])
534-
# idx_buffer += 1
535-
# if pair[0] == pair[1]:
536-
# continue
537-
# ordered_pair = ((pair[0], pair[1]) if pair[0] < pair[1] else
538-
# (pair[1], pair[0]))
539-
# if ordered_pair in neighbors_val:
540-
# continue
541-
# agreement = data.get_labels(pair[0]) == data.get_labels(pair[1])
542-
# if ratio_pos_to_neg is not None:
543-
# # To keep the positive and negatives balanced, do rejection sampling
544-
# # according to their ratio.
545-
# if ratio_pos_to_neg < 1 and not agreement:
546-
# # Reject a negative sample with some probability.
547-
# random_number = self.rng.rand(1)[0]
548-
# if random_number > ratio_pos_to_neg:
549-
# continue
550-
# elif ratio_pos_to_neg > 1 and agreement:
551-
# # Reject a positive sample with some probability.
552-
# random_number = self.rng.random()
553-
# if random_number > 1.0 / ratio_pos_to_neg:
554-
# continue
555-
# neighbors_batch[num_added][0] = pair[0]
556-
# neighbors_batch[num_added][1] = pair[1]
557-
# agreement_batch[num_added] = agreement
558-
# num_added += 1
559-
# yield neighbors_batch, agreement_batch
560-
while True:
561-
num_added = 0
562-
while num_added < self.batch_size:
563-
pair = self.rng.choice(labeled_samples, 2)
564-
ordered_pair = (pair[0], pair[1]) if pair[0] < pair[1] else \
565-
(pair[1], pair[0])
566-
if ordered_pair in neighbors_val:
567-
continue
568-
agreement = data.get_labels(pair[0]) == data.get_labels(pair[1])
569-
if ratio_pos_to_neg is not None:
570-
# Keep positives and negatives balanced.
571-
if ratio_pos_to_neg < 1 and not agreement:
572-
random_number = self.rng.rand(1)[0]
573-
if random_number > ratio_pos_to_neg:
574-
continue
575-
elif ratio_pos_to_neg > 1 and agreement:
576-
random_number = self.rng.rand(1)[0]
577-
if random_number > 1.0 / ratio_pos_to_neg:
578-
continue
579-
neighbors_batch[num_added][0] = pair[0]
580-
neighbors_batch[num_added][1] = pair[1]
581-
agreement_batch[num_added] = agreement
582-
num_added += 1
583-
yield neighbors_batch, agreement_batch
584-
585488
def _select_val_set(self, labeled_samples, num_samples, data,
586489
ratio_pos_to_neg=None):
587490
"""Select a validation set for the agreement model.

0 commit comments

Comments
 (0)