Skip to content

Commit 7a97f52

Browse files
ppham27tensorflow-copybara
authored andcommitted
Make sequential model inputs with adversarial regularization more flexible
tensorflow/tensorflow@eea7bbc tries to map dictionary keys back to an input. In sequential models the input names are arbitrary, so we should just flatten to a list. PiperOrigin-RevId: 301224247
1 parent c6e35ab commit 7a97f52

File tree

1 file changed

+13
-3
lines changed

1 file changed

+13
-3
lines changed

neural_structured_learning/keras/adversarial_regularization.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -601,11 +601,21 @@ def _split_inputs(self, inputs):
601601
for key, value in six.iteritems(inputs)
602602
if key not in non_feature_keys
603603
}
604+
# In some cases, Sequential models are automatically compiled to graph
605+
# networks with automatically generated input names. In this case, the user
606+
# isn't expected to know those names, so we just flatten the inputs. But the
607+
# input names are sometimes meaningful (e.g. DenseFeatures layer). We check
608+
# if there is any intersection between the user-provided names and model's
609+
# input names. If there is, we assume the names are meaningful and preserve
610+
# the dictionary.
611+
if (isinstance(self.base_model, tf.keras.Sequential) and
612+
not (set(getattr(self.base_model, 'input_names', []))
613+
& set(inputs.keys()))):
614+
inputs = tf.nest.flatten(inputs)
604615
return inputs, labels, sample_weights
605616

606617
def _call_base_model(self, inputs, **kwargs):
607-
if (self.base_model._is_graph_network and # pylint: disable=protected-access
608-
not isinstance(self.base_model, tf.keras.Sequential)):
618+
if isinstance(inputs, dict) and self.base_model._is_graph_network: # pylint: disable=protected-access
609619
base_input_names = getattr(self.base_model, 'input_names', None)
610620
if base_input_names:
611621
# Converts input dictionary to a list so it conforms with the model's
@@ -616,7 +626,7 @@ def _call_base_model(self, inputs, **kwargs):
616626
def _forward_pass(self, inputs, labels, sample_weights, base_model_kwargs):
617627
"""Runs the usual forward pass to compute outputs, loss, and metrics."""
618628
with tf.GradientTape() as tape:
619-
tape.watch(list(inputs.values()))
629+
tape.watch(tf.nest.flatten(inputs))
620630
outputs = self._call_base_model(inputs, **base_model_kwargs)
621631
# If the base_model is a subclassed model, its output_names are not
622632
# available before its first call. If it is a dynamic subclassed model,

0 commit comments

Comments
 (0)