
Commit db7fcf2

Merge with upstream.
2 parents: ddab3ee + 74745fc

File tree: 8 files changed (+396, -185 lines)

neural_structured_learning/research/gam/trainer/trainer_agreement.py

Lines changed: 67 additions & 52 deletions
Large diffs are not rendered by default.

neural_structured_learning/research/gam/trainer/trainer_classification.py

Lines changed: 59 additions & 53 deletions
@@ -69,14 +69,13 @@ class TrainerClassification(Trainer):
     summary_step: Integer representing the summary step size.
     summary_dir: String representing the path to a directory where to save the
       variable summaries.
-    logging_step: Integer representing the number of iterations after which
-      we log the loss of the model.
+    logging_step: Integer representing the number of iterations after which we
+      log the loss of the model.
     eval_step: Integer representing the number of iterations after which we
       evaluate the model.
-    warm_start: Whether the model parameters are initialized at their
-      best value in the previous cotrain iteration. If False, they are
-      reinitialized.
-    gradient_clip=None,
+    warm_start: Whether the model parameters are initialized at their best value
+      in the previous cotrain iteration. If False, they are reinitialized.
+    gradient_clip=None,
     abs_loss_chg_tol: A float representing the absolute tolerance for checking
       if the training loss has converged. If the difference between the current
       loss and previous loss is less than `abs_loss_chg_tol`, we count this
@@ -89,19 +88,19 @@ class TrainerClassification(Trainer):
       iterations that pass the convergence criteria before stopping training.
     checkpoints_dir: Path to the folder where to store TensorFlow model
       checkpoints.
-    weight_decay: Weight for the weight decay term in the classification
-      model loss.
+    weight_decay: Weight for the weight decay term in the classification model
+      loss.
     weight_decay_schedule: Schedule how to adjust the classification weight
       decay weight after every cotrain iteration.
     penalize_neg_agr: Whether to not only encourage agreement between samples
       that the agreement model believes should have the same label, but also
       penalize agreement when two samples agree when the agreement model
       predicts they should disagree.
-    use_l2_clssif: Whether to use L2 loss for classification, as opposed to the
-      whichever loss is specified in the provided model_cls.
     first_iter_original: A boolean specifying whether the first cotrain
       iteration trains the original classification model (with no agreement
       term).
+    use_l2_clssif: Whether to use L2 loss for classification, as opposed to the
+      whichever loss is specified in the provided model_cls.
     seed: Seed used by all the random number generators in this class.
     use_graph: Boolean specifying whether the agreement loss is applied to graph
       edges, as opposed to random pairs of samples.
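The tolerances documented above drive a patience-style stopping rule: training halts once the loss change stays below the thresholds for enough consecutive iterations. A minimal sketch, assuming a relative tolerance (rel_loss_chg_tol) and a patience counter (loss_chg_iter_below_tol) alongside the documented absolute one; the control flow is an illustration, not the file's actual code:

    # Illustrative sketch of the documented convergence criteria; names
    # `rel_loss_chg_tol` and `loss_chg_iter_below_tol` are assumptions.
    def check_convergence(prev_loss, loss, iter_below_tol,
                          abs_loss_chg_tol=1e-10, rel_loss_chg_tol=1e-7,
                          loss_chg_iter_below_tol=30):
      """Returns (has_converged, updated_iter_below_tol)."""
      loss_diff = abs(loss - prev_loss)
      rel_ok = prev_loss != 0 and loss_diff / abs(prev_loss) < rel_loss_chg_tol
      if loss_diff < abs_loss_chg_tol or rel_ok:
        iter_below_tol += 1  # This iteration passes the convergence criteria.
      else:
        iter_below_tol = 0  # Progress resumed; reset the patience counter.
      return iter_below_tol >= loss_chg_iter_below_tol, iter_below_tol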
@@ -162,8 +161,9 @@ def __init__(self,
     self.gradient_clip = gradient_clip
     self.logging_step = logging_step
     self.eval_step = eval_step
-    self.checkpoint_path = (os.path.join(checkpoints_dir, 'classif_best.ckpt')
-                            if checkpoints_dir is not None else None)
+    self.checkpoint_path = (
+        os.path.join(checkpoints_dir, 'classif_best.ckpt')
+        if checkpoints_dir is not None else None)
     self.weight_decay_initial = weight_decay
     self.weight_decay_schedule = weight_decay_schedule
     self.num_pairs_reg = num_pairs_reg
@@ -186,11 +186,11 @@ def __init__(self,
     # First obtain the features shape from the dataset, and append a batch_size
     # dimension to it (i.e., `None` to allow for variable batch size).
     features_shape = [None] + list(data.features_shape)
-    input_features = tf.placeholder(tf.float32, shape=features_shape,
-                                    name='input_features')
+    input_features = tf.placeholder(
+        tf.float32, shape=features_shape, name='input_features')
     input_labels = tf.placeholder(tf.int64, shape=(None,), name='input_labels')
-    one_hot_labels = tf.one_hot(input_labels, data.num_classes,
-                                name='input_labels_one_hot')
+    one_hot_labels = tf.one_hot(
+        input_labels, data.num_classes, name='input_labels_one_hot')
     # Create a placeholder specifying if this is train time.
     is_train = tf.placeholder_with_default(False, shape=[], name='is_train')

@@ -201,8 +201,8 @@ def __init__(self,
     self.variables = variables
     self.reg_params = reg_params
     predictions, variables, reg_params = (
-        self.model.get_predictions_and_params(encoding=encoding,
-                                              is_train=is_train))
+        self.model.get_predictions_and_params(
+            encoding=encoding, is_train=is_train))
     self.variables.update(variables)
     self.reg_params.update(reg_params)
     normalized_predictions = self.model.normalize_predictions(predictions)
@@ -221,9 +221,10 @@ def __init__(self,
       loss_supervised = tf.reduce_sum(loss_supervised, axis=-1)
       loss_supervised = tf.reduce_mean(loss_supervised)
     else:
-      loss_supervised = self.model.get_loss(predictions=predictions,
-                                            targets=one_hot_labels,
-                                            weight_decay=None)
+      loss_supervised = self.model.get_loss(
+          predictions=predictions,
+          targets=one_hot_labels,
+          weight_decay=None)

     # Agreement regularization loss.
     loss_agr = self._get_agreement_reg_loss(data, is_train, features_shape)
@@ -280,8 +281,9 @@ def __init__(self,
       gradients, _ = tf.clip_by_global_norm(gradients, self.gradient_clip)
     grads_and_vars = tuple(zip(gradients, variab))
     with tf.control_dependencies(
-        tf.get_collection(tf.GraphKeys.UPDATE_OPS,
-                          scope=tf.get_default_graph().get_name_scope())):
+        tf.get_collection(
+            tf.GraphKeys.UPDATE_OPS,
+            scope=tf.get_default_graph().get_name_scope())):
       train_op = self.optimizer.apply_gradients(
           grads_and_vars, global_step=self.global_step)

@@ -332,7 +334,7 @@ def _create_weight_decay_var(self, weight_decay_initial,
     if weight_decay_schedule is None:
       if weight_decay_initial is not None:
         weight_decay_var = tf.constant(
-          weight_decay_initial, dtype=tf.float32, name='weight_decay')
+            weight_decay_initial, dtype=tf.float32, name='weight_decay')
       else:
         weight_decay_var = None
     elif weight_decay_schedule == 'linear':
@@ -406,32 +408,28 @@ def _get_agreement_reg_loss(self, data, is_train, features_shape):

     with tf.variable_scope('predictions', reuse=True):
       encoding, _, _ = self.model.get_encoding_and_params(
-          inputs=features_ll_right, is_train=is_train,
-          update_batch_stats=False)
+          inputs=features_ll_right, is_train=is_train, update_batch_stats=False)
       predictions_ll_right, _, _ = self.model.get_predictions_and_params(
           encoding=encoding, is_train=is_train)
       predictions_ll_right = self.model.normalize_predictions(
           predictions_ll_right)

       encoding, _, _ = self.model.get_encoding_and_params(
-          inputs=features_lu_right, is_train=is_train,
-          update_batch_stats=False)
+          inputs=features_lu_right, is_train=is_train, update_batch_stats=False)
       predictions_lu_right, _, _ = self.model.get_predictions_and_params(
           encoding=encoding, is_train=is_train)
       predictions_lu_right = self.model.normalize_predictions(
           predictions_lu_right)

       encoding, _, _ = self.model.get_encoding_and_params(
-          inputs=features_uu_left, is_train=is_train,
-          update_batch_stats=False)
+          inputs=features_uu_left, is_train=is_train, update_batch_stats=False)
       predictions_uu_left, _, _ = self.model.get_predictions_and_params(
           encoding=encoding, is_train=is_train)
       predictions_uu_left = self.model.normalize_predictions(
           predictions_uu_left)

       encoding, _, _ = self.model.get_encoding_and_params(
-          inputs=features_uu_right, is_train=is_train,
-          update_batch_stats=False)
+          inputs=features_uu_right, is_train=is_train, update_batch_stats=False)
       predictions_uu_right, _, _ = self.model.get_predictions_and_params(
           encoding=encoding, is_train=is_train)
       predictions_uu_right = self.model.normalize_predictions(
@@ -442,8 +440,8 @@ def _get_agreement_reg_loss(self, data, is_train, features_shape):
     # Stop gradients need to be added
     # The case where there are no more uu or lu
     # edges at the end of training, so the shapes don't match needs fixing.
-    left = tf.concat(
-        (labels_ll_left, labels_lu_left, predictions_uu_left), axis=0)
+    left = tf.concat((labels_ll_left, labels_lu_left, predictions_uu_left),
+                     axis=0)
     right = tf.concat(
         (predictions_ll_right, predictions_lu_right, predictions_uu_right),
         axis=0)
@@ -455,12 +453,16 @@ def _get_agreement_reg_loss(self, data, is_train, features_shape):
     agreement_ll = tf.cast(
         tf.equal(labels_ll_left_idx, labels_ll_right_idx), dtype=tf.float32)
     _, agreement_lu, _, _ = self.trainer_agr.create_agreement_prediction(
-        src_features=features_lu_left, tgt_features=features_lu_right,
-        is_train=is_train, src_indices=indices_lu_left,
+        src_features=features_lu_left,
+        tgt_features=features_lu_right,
+        is_train=is_train,
+        src_indices=indices_lu_left,
         tgt_indices=indices_lu_right)
     _, agreement_uu, _, _ = self.trainer_agr.create_agreement_prediction(
-        src_features=features_uu_left, tgt_features=features_uu_right,
-        is_train=is_train, src_indices=indices_uu_left,
+        src_features=features_uu_left,
+        tgt_features=features_uu_right,
+        is_train=is_train,
+        src_indices=indices_uu_left,
         tgt_indices=indices_uu_right)
     agreement = tf.concat((agreement_ll, agreement_lu, agreement_uu), axis=0)
     if self.penalize_neg_agr:
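The hunk cuts off at the penalize_neg_agr branch, so its body is not shown here. A minimal sketch of the documented behavior, assuming agreement scores in [0, 1]; the transformation below is an illustration, not the file's actual code:

    import numpy as np

    # Agreement scores in [0, 1]; values below 0.5 mean the agreement model
    # predicts the pair should have different labels.
    agreement = np.array([0.9, 0.6, 0.2])

    # Illustrative transformation: remap to [-1, 1] so that predicted
    # disagreement contributes a negative (penalizing) factor when multiplied
    # into the distance-based regularizer.
    agreement = 2.0 * agreement - 1.0  # approx. [0.8, 0.2, -0.6]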
@@ -476,10 +478,10 @@ def _get_agreement_reg_loss(self, data, is_train, features_shape):
     num_ll = tf.shape(predictions_ll_right)[0]
     num_lu = tf.shape(predictions_lu_right)[0]
     num_uu = tf.shape(predictions_uu_left)[0]
-    weights = tf.concat((self.reg_weight_ll * tf.ones(num_ll,),
-                         self.reg_weight_lu * tf.ones(num_lu,),
-                         self.reg_weight_uu * tf.ones(num_uu,)),
-                        axis=0)
+    weights = tf.concat(
+        (self.reg_weight_ll * tf.ones(num_ll,), self.reg_weight_lu *
+         tf.ones(num_lu,), self.reg_weight_uu * tf.ones(num_uu,)),
+        axis=0)

     # Scale each distance by its agreement weight and regularzation weight.
     loss = tf.reduce_mean(dists * weights * agreement)
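For context, the regularizer assembled here scales each pair's prediction distance by an edge-type weight (ll = labeled-labeled, lu = labeled-unlabeled, uu = unlabeled-unlabeled) and by its agreement score, then averages. A self-contained NumPy sketch of the same reduction, with made-up sizes and weights:

    import numpy as np

    # Illustrative pair counts and per-edge-type weights; all values assumed.
    num_ll, num_lu, num_uu = 4, 8, 16
    reg_weight_ll, reg_weight_lu, reg_weight_uu = 1.0, 0.5, 0.25

    # One scalar weight per pair, concatenated in the same ll/lu/uu order as
    # the predictions in the hunk above.
    weights = np.concatenate((reg_weight_ll * np.ones(num_ll),
                              reg_weight_lu * np.ones(num_lu),
                              reg_weight_uu * np.ones(num_uu)))

    # Stand-ins for the per-pair prediction distances and agreement scores.
    rng = np.random.default_rng(0)
    dists = rng.random(num_ll + num_lu + num_uu)
    agreement = rng.random(num_ll + num_lu + num_uu)

    # Same reduction as the diff: mean of distance * weight * agreement.
    loss_agr = np.mean(dists * weights * agreement)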
@@ -511,8 +513,9 @@ def _construct_feed_dict(self,
     input_indices = next(data_iterator)
     # Select the labels. Use the true, correct labels, at test time, and the
     # self-labeled ones at train time.
-    labels = (self.data.get_original_labels(input_indices) if split == 'test'
-              else self.data.get_labels(input_indices))
+    labels = (
+        self.data.get_original_labels(input_indices)
+        if split == 'test' else self.data.get_labels(input_indices))
     feed_dict = {
         self.input_features: self.data.get_features(input_indices),
         self.input_labels: labels,
@@ -586,8 +589,8 @@ def _select_from_pool(indices):
     while True:
       indices_src, features_src, labels_src = _select_from_pool(src_indices)
       indices_tgt, features_tgt, labels_tgt = _select_from_pool(tgt_indices)
-      yield (indices_src, indices_tgt, features_src, features_tgt,
-             labels_src, labels_tgt)
+      yield (indices_src, indices_tgt, features_src, features_tgt, labels_src,
+             labels_tgt)

   def edge_iterator(self, data, batch_size, labeling):
     """An iterator over graph edges.
@@ -679,6 +682,7 @@ def train(self, data, session=None, **kwargs):
       data: A CotrainDataset object.
       session: A TensorFlow session or None.
       **kwargs: Other keyword arguments.
+
     Returns:
       best_test_acc: A float representing the test accuracy at the iteration
         where the validation accuracy is maximum.
@@ -742,11 +746,11 @@ def train(self, data, session=None, **kwargs):
     checkpoint_saved = False
     while not has_converged:
       feed_dict = self._construct_feed_dict(
-        data_iterator=data_iterator_train,
-        split='train',
-        pair_ll_iterator=pair_ll_iterator,
-        pair_lu_iterator=pair_lu_iterator,
-        pair_uu_iterator=pair_uu_iterator)
+          data_iterator=data_iterator_train,
+          split='train',
+          pair_ll_iterator=pair_ll_iterator,
+          pair_lu_iterator=pair_lu_iterator,
+          pair_uu_iterator=pair_uu_iterator)
       if self.enable_summaries and step % self.summary_step == 0:
         loss_val, summary, iter_cls_total, _ = session.run(
             [self.loss_op, self.summary_op, self.iter_cls_total, self.train_op],
@@ -813,8 +817,10 @@ def predict(self, session, indices, is_train):
       input_features = self.data.get_features(batch_indices)
       batch_predictions = session.run(
          self.normalized_predictions,
-          feed_dict={self.input_features: input_features,
-                     self.is_train:is_train})
+          feed_dict={
+              self.input_features: input_features,
+              self.is_train: is_train
+          })
       predictions.append(batch_predictions)
       idx_start = idx_end
     if not predictions:

neural_structured_learning/research/gam/trainer/trainer_cotrain.py

Lines changed: 4 additions & 3 deletions
@@ -349,10 +349,11 @@ def _select_samples_to_label(self, data, trainer_cls, session):
     # self-labeling them.
     indices_unlabeled = data.get_indices_unlabeled()
     val_ind = set(data.get_indices_val())
-    indices_unlabeled = np.asarray([ind for ind in indices_unlabeled
-                                    if ind not in val_ind])
+    indices_unlabeled = np.asarray(
+        [ind for ind in indices_unlabeled if ind not in val_ind])
     predictions = trainer_cls.predict(
-       session, indices_unlabeled, is_train=False)
+        session, indices_unlabeled, is_train=False)
+
     # Select most confident nodes. Compute confidence and most confident label,
     # which will be used as the new label.
     predicted_label = np.argmax(predictions, axis=-1)
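The comments at the end of this hunk describe the selection that follows: rank unlabeled nodes by the classifier's confidence and self-label the most confident ones. A minimal NumPy sketch of that idea, with illustrative shapes and an assumed top-k cutoff:

    import numpy as np

    # Illustrative class-probability predictions for 10 unlabeled nodes.
    rng = np.random.default_rng(0)
    predictions = rng.dirichlet(np.ones(3), size=10)  # 10 nodes, 3 classes

    # Most confident label per node, and the confidence of that label.
    predicted_label = np.argmax(predictions, axis=-1)
    confidence = np.max(predictions, axis=-1)

    # Self-label the k most confident nodes; k is an illustrative choice.
    k = 3
    most_confident = np.argsort(-confidence)[:k]
    new_labels = predicted_label[most_confident]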

neural_structured_learning/tools/BUILD

Lines changed: 26 additions & 8 deletions
@@ -29,7 +29,7 @@ py_library(
     srcs = ["__init__.py"],
     srcs_version = "PY2AND3",
     deps = [
-        ":build_graph_lib",
+        ":graph_builder",
         ":graph_utils",
         ":pack_nbrs_lib",
     ],
@@ -56,25 +56,41 @@ py_test(
 )

 py_library(
-    name = "build_graph_lib",
-    srcs = ["build_graph.py"],
+    name = "graph_builder",
+    srcs = ["graph_builder.py"],
     srcs_version = "PY2AND3",
     deps = [
         ":graph_utils",
-        # package absl:app
-        # package absl/flags
         # package absl/logging
         # package numpy
         # package six
         # package tensorflow
     ],
 )

+py_test(
+    name = "graph_builder_test",
+    srcs = ["graph_builder_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":graph_builder",
+        ":graph_utils",
+        # package protobuf,
+        # package absl/testing:absltest
+        # package tensorflow
+    ],
+)
+
 py_binary(
-    name = "build_graph",
-    srcs = ["build_graph.py"],
+    name = "graph_builder_main",
+    srcs = ["graph_builder_main.py"],
     python_version = "PY3",
-    deps = [":build_graph_lib"],
+    deps = [
+        ":graph_builder",
+        # package absl:app
+        # package absl/flags
+        # package tensorflow
+    ],
 )

 py_library(
@@ -103,6 +119,8 @@ py_binary(
     srcs = ["build_docs.py"],
     python_version = "PY3",
     deps = [
+        # package absl:app
+        # package absl/flags
         "//neural_structured_learning",
         # package tensorflow_docs/api_generator
     ],

neural_structured_learning/tools/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 """Tools and APIs for preparing data for Neural Structured Learning."""

-import neural_structured_learning.tools.build_graph
+from neural_structured_learning.tools.graph_builder import build_graph
 from neural_structured_learning.tools.graph_utils import add_edge
 from neural_structured_learning.tools.graph_utils import add_undirected_edges
 from neural_structured_learning.tools.graph_utils import read_tsv_graph
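After the rename, build_graph is imported from graph_builder and re-exported at the package level. A hedged usage sketch; the argument names and threshold semantics below are assumptions based on the tool's purpose, not taken from this diff:

    # Usage sketch of the re-exported graph builder; the exact signature is
    # not shown in this diff, so the arguments below are assumed.
    from neural_structured_learning.tools.graph_builder import build_graph

    # Build a similarity graph from precomputed embeddings, writing edges
    # above the similarity threshold to a TSV file.
    build_graph(['/tmp/embeddings.tfr'], '/tmp/graph.tsv',
                similarity_threshold=0.8)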
