Skip to content

Commit 3674768

Browse files
csferng and tensorflow-copybara
authored and committed
Add a constructor parameter for passing label features to base model.
Fixes #37. PiperOrigin-RevId: 305126834
1 parent 50e1098 commit 3674768

File tree

2 files changed

+80
-28
lines changed

2 files changed

+80
-28
lines changed

neural_structured_learning/keras/adversarial_regularization.py

Lines changed: 36 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -460,7 +460,8 @@ def __init__(self,
460460
base_model,
461461
label_keys=('label',),
462462
sample_weight_key=None,
463-
adv_config=None):
463+
adv_config=None,
464+
base_with_labels_in_features=False):
464465
"""Constructor of `AdversarialRegularization` class.
465466
466467
Args:
@@ -474,13 +475,22 @@ def __init__(self,
474475
the weight is 1.0 for each input example.
475476
adv_config: Instance of `nsl.configs.AdvRegConfig` for configuring
476477
adversarial regularization.
478+
base_with_labels_in_features: A Boolean value indicating whether the base
479+
model expects label features as input. This option is effective only
480+
when the base model is a subclassed Keras model. (For functional and
481+
Sequential models, the expected inputs can be inferred from the model
482+
itself.) If set to true, the base model will be called with an input
483+
dictionary including label and sample-weight features. If set to false,
484+
label and sample-weight features will not be present in base model's input
485+
dictionary.
477486
"""
478487
super(AdversarialRegularization,
479488
self).__init__(name='AdversarialRegularization')
480489
self.base_model = base_model
481490
self.label_keys = label_keys
482491
self.sample_weight_key = sample_weight_key
483492
self.adv_config = adv_config or nsl_configs.AdvRegConfig()
493+
self._base_with_labels_in_features = base_with_labels_in_features
484494

485495
def compile(self,
486496
optimizer,
@@ -585,42 +595,45 @@ def _compute_total_loss(self, labels, outputs, sample_weights=None):
585595
outputs, sample_weights)
586596
return loss
587597

588-
def _split_inputs(self, inputs):
598+
def _extract_labels_and_weights(self, inputs):
589599
sample_weights = inputs.get(self.sample_weight_key, None)
590600
if sample_weights is not None:
591601
sample_weights = tf.stop_gradient(sample_weights)
592602
# Labels shouldn't be perturbed when generating adversarial examples.
593603
labels = [
594604
tf.stop_gradient(inputs[label_key]) for label_key in self.label_keys
595605
]
596-
# Removes labels and sample weights from the input dictionary, since they
597-
# are only used in this class and base model does not need them as inputs.
606+
return labels, sample_weights
607+
608+
def _remove_labels_and_weights(self, inputs):
598609
non_feature_keys = set(self.label_keys).union([self.sample_weight_key])
599-
inputs = {
610+
return {
600611
key: value
601612
for key, value in six.iteritems(inputs)
602613
if key not in non_feature_keys
603614
}
604-
# In some cases, Sequential models are automatically compiled to graph
605-
# networks with automatically generated input names. In this case, the user
606-
# isn't expected to know those names, so we just flatten the inputs. But the
607-
# input names are sometimes meaningful (e.g. DenseFeatures layer). We check
608-
# if there is any intersection between the user-provided names and model's
609-
# input names. If there is, we assume the names are meaningful and preserve
610-
# the dictionary.
611-
if (isinstance(self.base_model, tf.keras.Sequential) and
612-
not (set(getattr(self.base_model, 'input_names', []))
613-
& set(inputs.keys()))):
614-
inputs = tf.nest.flatten(inputs)
615-
return inputs, labels, sample_weights
616615

617616
def _call_base_model(self, inputs, **kwargs):
618-
if isinstance(inputs, dict) and self.base_model._is_graph_network: # pylint: disable=protected-access
619-
base_input_names = getattr(self.base_model, 'input_names', None)
617+
base_input_names = getattr(self.base_model, 'input_names', [])
618+
if (isinstance(self.base_model, tf.keras.Sequential) and
619+
not set(base_input_names) & set(inputs.keys())):
620+
# In some cases, Sequential models are automatically compiled to graph
621+
# networks with automatically generated input names. In this case, the
622+
# user isn't expected to know those names, so we just flatten the inputs.
623+
# But the input names are sometimes meaningful (e.g. DenseFeatures layer).
624+
# We check if there is any intersection between the user-provided names
625+
# and model's input names. If there is, we assume the names are meaningful
626+
# and do name-based lookup in the next branch.
627+
inputs = tf.nest.flatten(self._remove_labels_and_weights(inputs))
628+
elif self.base_model._is_graph_network: # pylint: disable=protected-access
620629
if base_input_names:
621630
# Converts input dictionary to a list so it conforms with the model's
622631
# expected input.
623632
inputs = [inputs[name] for name in base_input_names]
633+
elif not self._base_with_labels_in_features:
634+
# Removes labels and sample weights from the input dictionary, since they
635+
# are only used in this class and base model does not need them as inputs.
636+
inputs = self._remove_labels_and_weights(inputs)
624637
return self.base_model(inputs, **kwargs)
625638

626639
def _forward_pass(self, inputs, labels, sample_weights, base_model_kwargs):
@@ -647,7 +660,7 @@ def call(self, inputs, **kwargs):
647660
raise ValueError('Labels are not in the input. For predicting examples '
648661
'without labels, please use the base model instead.')
649662

650-
inputs, labels, sample_weights = self._split_inputs(inputs)
663+
labels, sample_weights = self._extract_labels_and_weights(inputs)
651664
outputs, labeled_loss, metrics, tape = self._forward_pass(
652665
inputs, labels, sample_weights, kwargs)
653666
self.add_loss(labeled_loss)
@@ -690,8 +703,9 @@ def perturb_on_batch(self, x, **config_kwargs):
690703
A dictionary of NumPy arrays, `SparseTensor`, or `RaggedTensor` objects of
691704
the generated adversarial examples.
692705
"""
693-
x = tf.nest.map_structure(tf.convert_to_tensor, x, expand_composites=True)
694-
inputs, labels, sample_weights = self._split_inputs(x)
706+
inputs = tf.nest.map_structure(
707+
tf.convert_to_tensor, x, expand_composites=True)
708+
labels, sample_weights = self._extract_labels_and_weights(inputs)
695709
_, labeled_loss, _, tape = self._forward_pass(inputs, labels,
696710
sample_weights,
697711
{'training': False})

neural_structured_learning/keras/adversarial_regularization_test.py

Lines changed: 44 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -117,14 +117,14 @@ def evaluate(self, *args, **kwargs):
117117
# is not created until the first call to the model, so the initialization
118118
# is not captured in the global_variables_initializer above.
119119
with tf.keras.backend.get_session().as_default():
120-
return super(AdversarialLossTest, self).evaluate(
121-
*args, **kwargs)
120+
return super(AdversarialLossTest, self).evaluate(*args, **kwargs)
122121
else:
123-
return super(AdversarialLossTest, self).evaluate(
124-
*args, **kwargs)
122+
return super(AdversarialLossTest, self).evaluate(*args, **kwargs)
125123

126124
@parameterized.named_parameters([
127125
('sequential', build_linear_keras_sequential_model),
126+
('sequential_no_input_layer',
127+
build_linear_keras_sequential_model_no_input_layer),
128128
('functional', build_linear_keras_functional_model),
129129
('subclassed', build_linear_keras_subclassed_model),
130130
])
@@ -511,6 +511,37 @@ def test_train_with_2_inputs(self, name1, name2):
511511
self.assertAllClose(w1_new, tf.keras.backend.get_value(dense1.weights[0]))
512512
self.assertAllClose(w2_new, tf.keras.backend.get_value(dense2.weights[0]))
513513

514+
def test_train_subclassed_base_model_with_label_input(self):
515+
w, x0, y0, lr, adv_config, _ = self._set_up_linear_regression()
516+
517+
inputs = {'feature': tf.constant(x0), 'label': tf.constant(y0)}
518+
519+
class BaseModel(tf.keras.Model):
520+
521+
def __init__(self):
522+
super(BaseModel, self).__init__()
523+
self.dense = tf.keras.layers.Dense(
524+
w.shape[-1],
525+
use_bias=False,
526+
kernel_initializer=tf.keras.initializers.Constant(w))
527+
self.seen_input_keys = set()
528+
529+
def call(self, inputs):
530+
self.seen_input_keys |= set(inputs.keys())
531+
return self.dense(inputs['feature'])
532+
533+
model = BaseModel()
534+
adv_model = adversarial_regularization.AdversarialRegularization(
535+
model,
536+
label_keys=['label'],
537+
adv_config=adv_config,
538+
base_with_labels_in_features=True)
539+
adv_model.compile(
540+
optimizer=tf.keras.optimizers.SGD(lr), loss='MSE', metrics=['mae'])
541+
adv_model.fit(x=inputs, batch_size=1, steps_per_epoch=1)
542+
543+
self.assertIn('label', model.seen_input_keys)
544+
514545
def test_evaluate_binary_classification_metrics(self):
515546
# multi-label binary classification model
516547
w = np.array([[4.0, 1.0, -5.0], [-3.0, 1.0, 2.0]])
@@ -564,10 +595,17 @@ def test_evaluate_classification_metrics(self):
564595
self.assertAllClose(cross_entropy,
565596
results['sparse_categorical_crossentropy'])
566597

567-
def test_perturb_on_batch(self):
598+
@parameterized.named_parameters([
599+
('sequential', build_linear_keras_sequential_model),
600+
('sequential_no_input_layer',
601+
build_linear_keras_sequential_model_no_input_layer),
602+
('functional', build_linear_keras_functional_model),
603+
('subclassed', build_linear_keras_subclassed_model),
604+
])
605+
def test_perturb_on_batch(self, model_fn):
568606
w, x0, y0, lr, adv_config, _ = self._set_up_linear_regression()
569607
inputs = {'feature': x0, 'label': y0}
570-
model = build_linear_keras_functional_model(input_shape=(2,), weights=w)
608+
model = model_fn(input_shape=(2,), weights=w)
571609
adv_model = adversarial_regularization.AdversarialRegularization(
572610
model, label_keys=['label'], adv_config=adv_config)
573611
adv_model.compile(optimizer=tf.keras.optimizers.SGD(lr), loss=['MSE'])

0 commit comments

Comments
 (0)