
Commit 6e69973

csferng authored and tensorflow-copybara committed
Convert input dictionary to a list for functional Keras models.
Functional Keras models may expect their input features to be in a specific order, which may be different from the alphabetic order used for serializing input dictionaries. Keras `Model` class handles the different ordering by performing a name lookup before executing the model's forward pass. However, the name lookup is only performed when the model is called via high-level interfaces like `model.fit()`, but not when the model is called directly like `model(input)`. Since `nsl.keras.AdversarialRegularization` always calls its base model directly, this creates an interface discrepancy. For example, ``` input = {'a': ..., 'b': ...} model = tf.keras.Model( [tf.keras.Input(..., name='b'), tf.keras.Input(..., name='a')], ...) adv_model = nsl.keras.AdversarialRegularization(model) ... # Compiles both models model.fit(input) # works adv_model.fit(input) # error ``` This fix does the name lookup before calling the base model if the base model is a functional model. Sequential models are excluded because their feature name may not be specified. Subclassed Keras models are also excluded because some subclassed models actually expect dictionary-style input instead of a list. Fixes #27 PiperOrigin-RevId: 289938495
1 parent bfab889
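For illustration, here is a minimal standalone sketch of the reordering the fix performs (it assumes the TF 2.x-era `tf.keras` API, where functional models expose an `input_names` attribute listing inputs in declaration order):

```python
import tensorflow as tf

# A functional model whose inputs are declared in non-alphabetical order.
inp_b = tf.keras.Input(shape=(1,), name='b')
inp_a = tf.keras.Input(shape=(1,), name='a')
outputs = tf.keras.layers.Concatenate()([inp_b, inp_a])
model = tf.keras.Model(inputs=[inp_b, inp_a], outputs=outputs)

features = {'a': tf.constant([[1.0]]), 'b': tf.constant([[2.0]])}

# The fix: order the dictionary values by the model's declared input names
# ('b', 'a' here) before calling the model directly.
ordered = [features[name] for name in model.input_names]
print(model(ordered))  # [[2.0, 1.0]] -- column order matches declaration
```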

2 files changed: +65 −2 lines changed

neural_structured_learning/keras/adversarial_regularization.py

Lines changed: 14 additions & 2 deletions
```diff
@@ -585,6 +585,8 @@ def _compute_total_loss(self, labels, outputs, sample_weights=None):
 
   def _split_inputs(self, inputs):
     sample_weights = inputs.get(self.sample_weight_key, None)
+    if sample_weights is not None:
+      sample_weights = tf.stop_gradient(sample_weights)
     # Labels shouldn't be perturbed when generating adversarial examples.
     labels = [
         tf.stop_gradient(inputs[label_key]) for label_key in self.label_keys
@@ -599,11 +601,21 @@ def _split_inputs(self, inputs):
     }
     return inputs, labels, sample_weights
 
+  def _call_base_model(self, inputs, **kwargs):
+    if (self.base_model._is_graph_network and  # pylint: disable=protected-access
+        not isinstance(self.base_model, tf.keras.Sequential)):
+      base_input_names = getattr(self.base_model, 'input_names', None)
+      if base_input_names:
+        # Converts input dictionary to a list so it conforms with the model's
+        # expected input.
+        inputs = [inputs[name] for name in base_input_names]
+    return self.base_model(inputs, **kwargs)
+
   def _forward_pass(self, inputs, labels, sample_weights, base_model_kwargs):
     """Runs the usual forward pass to compute outputs, loss, and metrics."""
     with tf.GradientTape() as tape:
       tape.watch(list(inputs.values()))
-      outputs = self.base_model(inputs, **base_model_kwargs)
+      outputs = self._call_base_model(inputs, **base_model_kwargs)
       # If the base_model is a subclassed model, its output_names are not
       # available before its first call. If it is a dynamic subclassed model,
       # its output_names are not available even after its first call, so we
@@ -634,7 +646,7 @@ def call(self, inputs, **kwargs):
     adv_loss = adversarial_loss(
         inputs,
         labels,
-        self.base_model,
+        self._call_base_model,
        self._compute_total_loss,
        sample_weights=sample_weights,
        adv_config=self.adv_config,
```
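To see the change in context, here is a hedged end-to-end sketch of the scenario this unblocks (the commit-message example made concrete; shapes and layer choices are illustrative, and it assumes `neural_structured_learning` is importable as `nsl`):

```python
import numpy as np
import tensorflow as tf
import neural_structured_learning as nsl

# Input order ('b' before 'a') differs from alphabetical dictionary order.
inp_b = tf.keras.Input(shape=(2,), name='b')
inp_a = tf.keras.Input(shape=(1,), name='a')
merged = tf.keras.layers.Concatenate()([inp_b, inp_a])
model = tf.keras.Model(inputs=[inp_b, inp_a],
                       outputs=tf.keras.layers.Dense(1)(merged))

adv_model = nsl.keras.AdversarialRegularization(model, label_keys=['label'])
adv_model.compile(optimizer='sgd', loss='mse')

batch = {'a': np.ones((8, 1)), 'b': np.ones((8, 2)), 'label': np.zeros((8, 1))}
# Before this commit, fit() failed because the base model was called directly
# with a dict; _call_base_model now reorders it to ['b', 'a'] first.
adv_model.fit(batch, batch_size=4, epochs=1)
```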

neural_structured_learning/keras/adversarial_regularization_test.py

Lines changed: 51 additions & 0 deletions
```diff
@@ -39,6 +39,16 @@ def build_linear_keras_sequential_model(input_shape, weights):
   return model
 
 
+def build_linear_keras_sequential_model_no_input_layer(input_shape, weights):
+  return tf.keras.Sequential([
+      tf.keras.layers.Dense(
+          weights.shape[-1],
+          use_bias=False,
+          input_shape=input_shape,
+          kernel_initializer=tf.keras.initializers.Constant(weights)),
+  ])
+
+
 def build_linear_keras_functional_model(input_shape,
                                         weights,
                                         input_name='feature'):
@@ -276,6 +286,8 @@ def _set_up_linear_regression(self, sample_weight=1.0):
 
   @parameterized.named_parameters([
       ('sequential', build_linear_keras_sequential_model),
+      ('sequential_no_input_layer',
+       build_linear_keras_sequential_model_no_input_layer),
       ('functional', build_linear_keras_functional_model),
       ('subclassed', build_linear_keras_subclassed_model),
   ])
@@ -460,6 +472,45 @@ def test_train_with_2_outputs(self):
     self.assertAllClose(expected_metric,
                         history.history['mean_absolute_error_label2'][0])
 
+  @parameterized.named_parameters([
+      ('order_1_2', 'first', 'second'),
+      ('order_2_1', 'second', 'first'),
+  ])
+  def test_train_with_2_inputs(self, name1, name2):
+    x1, x2 = np.array([[1.]]), np.array([[4., 5.]])
+    w1, w2 = np.array([[2.]]), np.array([[3.], [6.]])
+    y = np.array([0.])
+    inputs = {name1: x1, name2: x2, 'label': y}
+    lr, adv_step_size = 0.001, 0.1
+
+    input1 = tf.keras.Input(shape=(1,), name=name1)
+    input2 = tf.keras.Input(shape=(2,), name=name2)
+    dense1 = tf.keras.layers.Dense(
+        w1.shape[-1],
+        use_bias=False,
+        kernel_initializer=tf.keras.initializers.Constant(w1))
+    dense2 = tf.keras.layers.Dense(
+        w2.shape[-1],
+        use_bias=False,
+        kernel_initializer=tf.keras.initializers.Constant(w2))
+    output = tf.keras.layers.Add()([dense1(input1), dense2(input2)])
+    model = tf.keras.Model(inputs=[input1, input2], outputs=output)
+
+    adv_config = configs.make_adv_reg_config(
+        multiplier=1.0, adv_step_size=adv_step_size, adv_grad_norm='l2')
+    adv_model = adversarial_regularization.AdversarialRegularization(
+        model, label_keys=['label'], adv_config=adv_config)
+    adv_model.compile(optimizer=tf.keras.optimizers.SGD(lr), loss='MAE')
+    adv_model.fit(x=inputs, batch_size=1, steps_per_epoch=1)
+
+    # loss = |x1 * w1 + x2 * w2|, gradient(loss, [x1, x2]) = [w1, w2]
+    w_norm = np.sqrt((np.sum(w1 * w1) + np.sum(w2 * w2)))
+    x1_adv, x2_adv = x1 + adv_step_size * w1.T / w_norm, x2 + adv_step_size * w2.T / w_norm
+    # gradient(loss, [w1, w2]) = [x1, x2]
+    w1_new, w2_new = w1 - lr * (x1 + x1_adv).T, w2 - lr * (x2 + x2_adv).T
+    self.assertAllClose(w1_new, tf.keras.backend.get_value(dense1.weights[0]))
+    self.assertAllClose(w2_new, tf.keras.backend.get_value(dense2.weights[0]))
+
   def test_evaluate_binary_classification_metrics(self):
     # multi-label binary classification model
     w = np.array([[4.0, 1.0, -5.0], [-3.0, 1.0, 2.0]])
```
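The closed-form expectations in `test_train_with_2_inputs` can be sanity-checked with plain NumPy. A minimal sketch mirroring the test's constants (the prediction 1·2 + 4·3 + 5·6 = 44 is positive, so the absolute-value loss has gradient w with respect to x and gradient x with respect to w):

```python
import numpy as np

x1, x2 = np.array([[1.]]), np.array([[4., 5.]])
w1, w2 = np.array([[2.]]), np.array([[3.], [6.]])
lr, adv_step_size = 0.001, 0.1

# L2-normalized gradient step on the inputs: ||(w1, w2)|| = sqrt(4 + 45) = 7.
w_norm = np.sqrt(np.sum(w1 * w1) + np.sum(w2 * w2))
x1_adv = x1 + adv_step_size * w1.T / w_norm   # [[1 + 0.2/7]]
x2_adv = x2 + adv_step_size * w2.T / w_norm   # [[4 + 0.3/7, 5 + 0.6/7]]

# One SGD step on the combined (original + adversarial) loss.
w1_new = w1 - lr * (x1 + x1_adv).T            # [[~1.9979714]]
w2_new = w2 - lr * (x2 + x2_adv).T
print(w1_new, w2_new, sep='\n')
```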
