
Commit 989c055

arjung authored and tensorflow-copybara committed
Restrict the expectation of neighbor features to the training mode. This updates both the Keras and the Estimator wrappers in NSL.
PiperOrigin-RevId: 285269590
1 parent 4014a76 commit 989c055

File tree

7 files changed (+288, −43 lines)


neural_structured_learning/estimator/adversarial_regularization.py

Lines changed: 1 addition & 1 deletion
@@ -80,7 +80,7 @@ def adv_model_fn(features, labels, mode, params=None, config=None):
         `num_ps_replicas`, or `model_dir`. Unused currently.

     Returns:
-      A `tf.EstimatorSpec` whose loss incorporates graph-based regularization.
+      A `tf.estimator.EstimatorSpec` with adversarial regularization.
     """

     # Uses the same variable scope for calculating the original objective and
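For context, a minimal sketch of how this wrapper is typically used. The names `base_estimator` and `train_input_fn` are hypothetical placeholders, and the config values are illustrative only, not recommendations:

    import neural_structured_learning as nsl

    # Hypothetical base estimator and input_fn; multiplier/adv_step_size are
    # illustrative values.
    adv_config = nsl.configs.make_adv_reg_config(
        multiplier=0.2, adv_step_size=0.05)
    adv_estimator = nsl.estimator.add_adversarial_regularization(
        base_estimator, adv_config=adv_config)

    # The wrapped model_fn returns a tf.estimator.EstimatorSpec whose loss
    # includes the adversarial regularization term.
    adv_estimator.train(input_fn=train_input_fn, steps=100)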

neural_structured_learning/estimator/graph_regularization.py

Lines changed: 14 additions & 5 deletions
@@ -77,7 +77,7 @@ def graph_reg_model_fn(features, labels, mode, params=None, config=None):
         as `num_ps_replicas`, or `model_dir`. Unused currently.

     Returns:
-      A `tf.EstimatorSpec` whose loss incorporates graph-based regularization.
+      A `tf.estimator.EstimatorSpec` with graph regularization.
     """

     # Uses the same variable scope for calculating the original objective and
@@ -86,10 +86,19 @@ def graph_reg_model_fn(features, labels, mode, params=None, config=None):
         tf.compat.v1.get_variable_scope(),
         reuse=tf.compat.v1.AUTO_REUSE,
         auxiliary_name_scope=False):
-      # Extract sample features, neighbor features, and neighbor weights.
-      sample_features, nbr_features, nbr_weights = (
-          utils.unpack_neighbor_features(features,
-                                         graph_reg_config.neighbor_config))
+      nbr_features = dict()
+      nbr_weights = None
+      if mode == tf.estimator.ModeKeys.TRAIN:
+        # Extract sample features, neighbor features, and neighbor weights if
+        # we are in training mode.
+        sample_features, nbr_features, nbr_weights = (
+            utils.unpack_neighbor_features(features,
+                                           graph_reg_config.neighbor_config))
+      else:
+        # Otherwise, we strip out all neighbor features and use just the
+        # sample's features.
+        sample_features = utils.strip_neighbor_features(
+            features, graph_reg_config.neighbor_config)

       # If no 'params' is passed, then it is possible for base_model_fn not to
       # accept a 'params' argument. See documentation for tf.estimator.Estimator
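To illustrate the effect of this change on callers, here is a hedged sketch of the resulting train/evaluate asymmetry. The names `base_estimator`, `embedding_fn`, `optimizer_fn`, `train_input_fn`, and `eval_input_fn` are hypothetical placeholders:

    import neural_structured_learning as nsl

    graph_reg_config = nsl.configs.make_graph_reg_config(
        max_neighbors=2, multiplier=1)
    graph_reg_estimator = nsl.estimator.add_graph_regularization(
        base_estimator, embedding_fn, optimizer_fn,
        graph_reg_config=graph_reg_config)

    # TRAIN mode: the input features are expected to contain neighbor keys
    # (e.g. 'NL_nbr_0_x'), which unpack_neighbor_features() separates into
    # sample features, neighbor features, and neighbor weights.
    graph_reg_estimator.train(input_fn=train_input_fn, steps=100)

    # EVAL/PREDICT modes: after this change, neighbor features are no longer
    # required; any neighbor keys present are stripped, and only the sample's
    # own features reach the base model.
    graph_reg_estimator.evaluate(input_fn=eval_input_fn)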

neural_structured_learning/estimator/graph_regularization_test.py

Lines changed: 105 additions & 10 deletions
@@ -174,14 +174,14 @@ def embedding_fn(features, unused_mode):
     """

     input_fn = single_example_input_fn(
-        example, input_shape=[1], max_neighbors=1)
+        example, input_shape=[1], max_neighbors=0)
     predictions = graph_reg_est.predict(input_fn=input_fn)
     predicted_scores = [x['predictions'] for x in predictions]
     self.assertAllClose([[3.0]], predicted_scores)

-  def train_and_check_params(self, example, max_neighbors, weight, bias,
-                             expected_grad_from_weight,
-                             expected_grad_from_bias):
+  def _train_and_check_params(self, example, max_neighbors, weight, bias,
+                              expected_grad_from_weight,
+                              expected_grad_from_bias):
     """Runs training for one step and verifies gradient-based updates."""

     def embedding_fn(features, unused_mode):
@@ -261,7 +261,8 @@ def test_graph_reg_wrapper_one_neighbor_with_training(self):
     # which includes the supervised loss as well as the graph loss.
     orig_pred = np.dot(x0, weight) + bias  # [9.0]

-    # Based on the implementation of embedding_fn inside train_and_check_params.
+    # Based on the implementation of embedding_fn inside
+    # _train_and_check_params.
     x0_embedding = np.dot(x0, weight)
     neighbor0_embedding = np.dot(neighbor0, weight)

@@ -271,8 +272,8 @@ def test_graph_reg_wrapper_one_neighbor_with_training(self):
                               neighbor0).T  # [[2.5], [1.5]]
     orig_grad_b = 2 * (orig_pred - y0).reshape((1,))  # [2.0]

-    self.train_and_check_params(example, 1, weight, bias, orig_grad_w,
-                                orig_grad_b)
+    self._train_and_check_params(example, 1, weight, bias, orig_grad_w,
+                                 orig_grad_b)

   @test_util.run_v1_only('Requires tf.get_variable')
   def test_graph_reg_wrapper_two_neighbors_with_training(self):
@@ -318,7 +319,8 @@ def test_graph_reg_wrapper_two_neighbors_with_training(self):
     # which includes the supervised loss as well as the graph loss.
     orig_pred = np.dot(x0, weight) + bias  # [9.0]

-    # Based on the implementation of embedding_fn inside train_and_check_params.
+    # Based on the implementation of embedding_fn inside
+    # _train_and_check_params.
     x0_embedding = np.dot(x0, weight)
     neighbor0_embedding = np.dot(neighbor0, weight)
     neighbor1_embedding = np.dot(neighbor1, weight)
@@ -338,8 +340,101 @@ def test_graph_reg_wrapper_two_neighbors_with_training(self):
     orig_grad_w = grad_w_supervised_loss + grad_w_graph_loss
     orig_grad_b = 2 * (orig_pred - y0).reshape((1,))  # [2.0]

-    self.train_and_check_params(example, 2, weight, bias, orig_grad_w,
-                                orig_grad_b)
+    self._train_and_check_params(example, 2, weight, bias, orig_grad_w,
+                                 orig_grad_b)
+
+  def _train_and_check_eval_results(self, train_example, test_example,
+                                    max_neighbors, weight, bias):
+    """Verifies evaluation results for the graph-regularized model."""
+
+    def embedding_fn(features, unused_mode):
+      # Computes y = w*x
+      with tf.variable_scope(
+          tf.get_variable_scope(),
+          reuse=tf.AUTO_REUSE,
+          auxiliary_name_scope=False):
+        weight_tensor = tf.reshape(
+            tf.get_variable(
+                WEIGHT_VARIABLE,
+                shape=[2, 1],
+                partitioner=tf.fixed_size_partitioner(1)),
+            shape=[-1, 2])
+
+        x_tensor = tf.reshape(features[FEATURE_NAME], shape=[-1, 2])
+        return tf.reduce_sum(
+            tf.multiply(weight_tensor, x_tensor), 1, keep_dims=True)
+
+    def optimizer_fn():
+      return tf.train.GradientDescentOptimizer(LEARNING_RATE)
+
+    base_est = self.build_linear_regressor(
+        weight=weight, weight_shape=[2, 1], bias=bias, bias_shape=[1])
+
+    graph_reg_config = nsl_configs.make_graph_reg_config(
+        max_neighbors=max_neighbors, multiplier=1)
+    graph_reg_est = nsl_estimator.add_graph_regularization(
+        base_est, embedding_fn, optimizer_fn, graph_reg_config=graph_reg_config)
+
+    train_input_fn = single_example_input_fn(
+        train_example, input_shape=[2], max_neighbors=max_neighbors)
+    graph_reg_est.train(input_fn=train_input_fn, steps=1)
+
+    # Evaluating the graph-regularized model should yield the same results
+    # as evaluating the base model because model parameters are shared.
+    eval_input_fn = single_example_input_fn(
+        test_example, input_shape=[2], max_neighbors=0)
+    graph_eval_results = graph_reg_est.evaluate(input_fn=eval_input_fn)
+    base_eval_results = base_est.evaluate(input_fn=eval_input_fn)
+    self.assertAllClose(base_eval_results, graph_eval_results)
+
+  @test_util.run_v1_only('Requires tf.get_variable')
+  def test_graph_reg_model_evaluate(self):
+    weight = np.array([[4.0], [-3.0]])
+    bias = np.array([0.0], dtype=np.float32)
+
+    train_example = """
+        features {
+          feature {
+            key: "x"
+            value: { float_list { value: [ 2.0, 3.0 ] } }
+          }
+          feature {
+            key: "NL_nbr_0_x"
+            value: { float_list { value: [ 2.5, 3.0 ] } }
+          }
+          feature {
+            key: "NL_nbr_0_weight"
+            value: { float_list { value: 1.0 } }
+          }
+          feature {
+            key: "NL_nbr_1_x"
+            value: { float_list { value: [ 2.0, 2.0 ] } }
+          }
+          feature {
+            key: "NL_nbr_1_weight"
+            value: { float_list { value: 1.0 } }
+          }
+          feature {
+            key: "y"
+            value: { float_list { value: 0.0 } }
+          }
+        }
+    """
+
+    test_example = """
+        features {
+          feature {
+            key: "x"
+            value: { float_list { value: [ 4.0, 2.0 ] } }
+          }
+          feature {
+            key: "y"
+            value: { float_list { value: 4.0 } }
+          }
+        }
+    """
+    self._train_and_check_eval_results(
+        train_example, test_example, max_neighbors=2, weight=weight, bias=bias)


 if __name__ == '__main__':
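These tests rely on a `single_example_input_fn` helper that this diff does not show. Below is a minimal sketch of how such a helper might be implemented, assuming the `NL_nbr_<i>_<feature>` naming convention used in the example protos above; the signature matches the calls in the tests, but the parsing details are assumptions, not the actual test utility:

    import tensorflow as tf
    from google.protobuf import text_format

    def single_example_input_fn(example_proto_text, input_shape, max_neighbors):
      """Returns an input_fn yielding one parsed tf.Example (hypothetical)."""
      example = text_format.Parse(example_proto_text, tf.train.Example())
      serialized = example.SerializeToString()

      def input_fn():
        feature_spec = {
            'x': tf.io.FixedLenFeature(input_shape, tf.float32),
            'y': tf.io.FixedLenFeature([1], tf.float32),
        }
        # With max_neighbors=0 (the evaluation case in this commit), no
        # neighbor keys are requested, so only the sample's own features
        # are parsed.
        for i in range(max_neighbors):
          feature_spec['NL_nbr_{}_x'.format(i)] = tf.io.FixedLenFeature(
              input_shape, tf.float32)
          feature_spec['NL_nbr_{}_weight'.format(i)] = tf.io.FixedLenFeature(
              [1], tf.float32)

        features = tf.io.parse_single_example(serialized, feature_spec)
        labels = features.pop('y')
        return tf.data.Dataset.from_tensors((features, labels)).batch(1)

      return input_fn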

neural_structured_learning/keras/BUILD

Lines changed: 2 additions & 0 deletions
@@ -81,12 +81,14 @@ py_library(
     deps = [
         "//neural_structured_learning/configs",
         "//neural_structured_learning/keras/layers",
+        "//neural_structured_learning/lib",
         # package tensorflow
     ],
 )

 py_test(
     name = "graph_regularization_test",
+    timeout = "long",
     srcs = ["graph_regularization_test.py"],
     srcs_version = "PY2AND3",
     deps = [

neural_structured_learning/keras/graph_regularization.py

Lines changed: 29 additions & 21 deletions
@@ -84,6 +84,20 @@ def compile(self, *args, **kwargs):

   compile.__doc__ = tf.keras.Model.compile.__doc__

+  # Override the evaluate and the predict methods so that we can use the base
+  # model for evaluation/prediction rather than the graph-regularized model.
+  # This is because once the graph-regularized Keras model is built, it expects
+  # neighbor features as input for all modes, not just for training.
+  def evaluate(self, *args, **kwargs):
+    return self.base_model.evaluate(*args, **kwargs)
+
+  evaluate.__doc__ = tf.keras.Model.evaluate.__doc__
+
+  def predict(self, *args, **kwargs):
+    return self.base_model.predict(*args, **kwargs)
+
+  predict.__doc__ = tf.keras.Model.predict.__doc__
+
   def call(self, inputs, training=False, **kwargs):
     """Incorporates graph regularization into the loss of `base_model`.

@@ -99,30 +113,24 @@ def call(self, inputs, training=False, **kwargs):
     Returns:
       The output tensors for the wrapped graph-regularized model.
     """
-    sample_features, nbr_features, nbr_weights = self.nbr_features_layer(inputs)
+    # Invoke the call() function of the neighbor features layer directly
+    # instead of invoking it as a callable, to prevent Keras from wrapping
+    # placeholder tensors with the tf.identity() op.
+    sample_features, nbr_features, nbr_weights = self.nbr_features_layer.call(
+        inputs)
     base_output = self.base_model(sample_features, training=training, **kwargs)

+    # For evaluation and prediction, we use the base model. So, this overridden
+    # call function will get invoked only for training.
     has_nbr_inputs = nbr_weights is not None and nbr_features
-
-    # 'training' is a boolean or boolean tensor. So, we have to use the tf.cond
-    # op to be able to write conditional code based on its value.
-
-    def graph_loss_with_regularization():
-      if (has_nbr_inputs and self.graph_reg_config.multiplier > 0):
-        # Use logits for regularization.
-        sample_logits = base_output
-        nbr_logits = self.base_model(nbr_features, training=training, **kwargs)
-        return self.regularizer(
-            sources=sample_logits, targets=nbr_logits, weights=nbr_weights)
-      else:
-        return tf.constant(0, dtype=tf.float32)
-
-    def graph_loss_without_regularization():
-      return tf.constant(0, dtype=tf.float32)
-
-    graph_loss = tf.cond(
-        tf.equal(training, tf.constant(True)), graph_loss_with_regularization,
-        graph_loss_without_regularization)
+    if has_nbr_inputs and self.graph_reg_config.multiplier > 0:
+      # Use logits for regularization.
+      sample_logits = base_output
+      nbr_logits = self.base_model(nbr_features, training=training, **kwargs)
+      graph_loss = self.regularizer(
+          sources=sample_logits, targets=nbr_logits, weights=nbr_weights)
+    else:
+      graph_loss = tf.constant(0, dtype=tf.float32)

     # Note that add_metric() cannot be invoked in a control flow branch.
     self.add_metric(graph_loss, name='graph_loss', aggregation='mean')
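A hedged usage sketch of the new behavior on the Keras side: training still consumes neighbor features, while evaluate() and predict() now delegate to the base model and accept plain inputs. The helpers `make_train_dataset_with_neighbors` and `make_plain_dataset` are hypothetical input pipelines, and the model architecture is illustrative only:

    import neural_structured_learning as nsl
    import tensorflow as tf

    base_model = tf.keras.Sequential([
        tf.keras.layers.Dense(16, activation='relu', input_shape=(10,)),
        tf.keras.layers.Dense(1),
    ])

    graph_reg_config = nsl.configs.make_graph_reg_config(
        max_neighbors=2, multiplier=0.1)
    graph_reg_model = nsl.keras.GraphRegularization(base_model,
                                                    graph_reg_config)
    graph_reg_model.compile(optimizer='adam', loss='mse')

    # Training consumes neighbor features (e.g. 'NL_nbr_0_x' keys) and adds
    # the graph loss to the objective via the overridden call().
    graph_reg_model.fit(make_train_dataset_with_neighbors(), epochs=1)

    # Evaluation and prediction delegate to base_model, so inputs without
    # neighbor features work in these modes.
    graph_reg_model.evaluate(make_plain_dataset())
    graph_reg_model.predict(make_plain_dataset())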
