@@ -460,7 +460,8 @@ def __init__(self,
460460 base_model ,
461461 label_keys = ('label' ,),
462462 sample_weight_key = None ,
463- adv_config = None ):
463+ adv_config = None ,
464+ base_with_labels_in_features = False ):
464465 """Constructor of `AdversarialRegularization` class.
465466
466467 Args:
@@ -474,13 +475,22 @@ def __init__(self,
474475 the weight is 1.0 for each input example.
475476 adv_config: Instance of `nsl.configs.AdvRegConfig` for configuring
476477 adversarial regularization.
478+ base_with_labels_in_features: A Boolean value indicating whether the base
479+ model expects label features as input. This option is effective only
480+ when the base model is a subclassed Keras model. (For functional and
481+ Sequential models, the expected inputs can be inferred from the model
482+ itself.) If set to true, the base model will be called with an input
483+ dictionary including label and sample-weight features. If set to false,
484+ label and sample-weight features will not present in base model's input
485+ dictionary.
477486 """
478487 super (AdversarialRegularization ,
479488 self ).__init__ (name = 'AdversarialRegularization' )
480489 self .base_model = base_model
481490 self .label_keys = label_keys
482491 self .sample_weight_key = sample_weight_key
483492 self .adv_config = adv_config or nsl_configs .AdvRegConfig ()
493+ self ._base_with_labels_in_features = base_with_labels_in_features
484494
485495 def compile (self ,
486496 optimizer ,
@@ -585,42 +595,45 @@ def _compute_total_loss(self, labels, outputs, sample_weights=None):
585595 outputs , sample_weights )
586596 return loss
587597
588- def _split_inputs (self , inputs ):
598+ def _extract_labels_and_weights (self , inputs ):
589599 sample_weights = inputs .get (self .sample_weight_key , None )
590600 if sample_weights is not None :
591601 sample_weights = tf .stop_gradient (sample_weights )
592602 # Labels shouldn't be perturbed when generating adversarial examples.
593603 labels = [
594604 tf .stop_gradient (inputs [label_key ]) for label_key in self .label_keys
595605 ]
596- # Removes labels and sample weights from the input dictionary, since they
597- # are only used in this class and base model does not need them as inputs.
606+ return labels , sample_weights
607+
608+ def _remove_labels_and_weights (self , inputs ):
598609 non_feature_keys = set (self .label_keys ).union ([self .sample_weight_key ])
599- inputs = {
610+ return {
600611 key : value
601612 for key , value in six .iteritems (inputs )
602613 if key not in non_feature_keys
603614 }
604- # In some cases, Sequential models are automatically compiled to graph
605- # networks with automatically generated input names. In this case, the user
606- # isn't expected to know those names, so we just flatten the inputs. But the
607- # input names are sometimes meaningful (e.g. DenseFeatures layer). We check
608- # if there is any intersection between the user-provided names and model's
609- # input names. If there is, we assume the names are meaningful and preserve
610- # the dictionary.
611- if (isinstance (self .base_model , tf .keras .Sequential ) and
612- not (set (getattr (self .base_model , 'input_names' , []))
613- & set (inputs .keys ()))):
614- inputs = tf .nest .flatten (inputs )
615- return inputs , labels , sample_weights
616615
617616 def _call_base_model (self , inputs , ** kwargs ):
618- if isinstance (inputs , dict ) and self .base_model ._is_graph_network : # pylint: disable=protected-access
619- base_input_names = getattr (self .base_model , 'input_names' , None )
617+ base_input_names = getattr (self .base_model , 'input_names' , [])
618+ if (isinstance (self .base_model , tf .keras .Sequential ) and
619+ not set (base_input_names ) & set (inputs .keys ())):
620+ # In some cases, Sequential models are automatically compiled to graph
621+ # networks with automatically generated input names. In this case, the
622+ # user isn't expected to know those names, so we just flatten the inputs.
623+ # But the input names are sometimes meaningful (e.g. DenseFeatures layer).
624+ # We check if there is any intersection between the user-provided names
625+ # and model's input names. If there is, we assume the names are meaningful
626+ # and do name-based lookup in the next branch.
627+ inputs = tf .nest .flatten (self ._remove_labels_and_weights (inputs ))
628+ elif self .base_model ._is_graph_network : # pylint: disable=protected-access
620629 if base_input_names :
621630 # Converts input dictionary to a list so it conforms with the model's
622631 # expected input.
623632 inputs = [inputs [name ] for name in base_input_names ]
633+ elif not self ._base_with_labels_in_features :
634+ # Removes labels and sample weights from the input dictionary, since they
635+ # are only used in this class and base model does not need them as inputs.
636+ inputs = self ._remove_labels_and_weights (inputs )
624637 return self .base_model (inputs , ** kwargs )
625638
626639 def _forward_pass (self , inputs , labels , sample_weights , base_model_kwargs ):
@@ -647,7 +660,7 @@ def call(self, inputs, **kwargs):
647660 raise ValueError ('Labels are not in the input. For predicting examples '
648661 'without labels, please use the base model instead.' )
649662
650- inputs , labels , sample_weights = self ._split_inputs (inputs )
663+ labels , sample_weights = self ._extract_labels_and_weights (inputs )
651664 outputs , labeled_loss , metrics , tape = self ._forward_pass (
652665 inputs , labels , sample_weights , kwargs )
653666 self .add_loss (labeled_loss )
@@ -690,8 +703,9 @@ def perturb_on_batch(self, x, **config_kwargs):
690703 A dictionary of NumPy arrays, `SparseTensor`, or `RaggedTensor` objects of
691704 the generated adversarial examples.
692705 """
693- x = tf .nest .map_structure (tf .convert_to_tensor , x , expand_composites = True )
694- inputs , labels , sample_weights = self ._split_inputs (x )
706+ inputs = tf .nest .map_structure (
707+ tf .convert_to_tensor , x , expand_composites = True )
708+ labels , sample_weights = self ._extract_labels_and_weights (inputs )
695709 _ , labeled_loss , _ , tape = self ._forward_pass (inputs , labels ,
696710 sample_weights ,
697711 {'training' : False })
0 commit comments