Skip to content

Commit e848a7a

Browse files
csferngtensorflow-copybara
authored andcommitted
Minor improvement in AdversarialRegularization.
PiperOrigin-RevId: 298399230
1 parent 83d3f76 commit e848a7a

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

neural_structured_learning/keras/adversarial_regularization.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -572,9 +572,11 @@ def _build_labeled_metrics(self, output_names, labeled_losses):
572572
self._labeled_metrics.append(per_output_metrics)
573573

574574
def _get_or_create_base_output_names(self, outputs):
575-
num_output = len(tf.nest.flatten(outputs))
576-
return getattr(self.base_model, 'output_names',
577-
['output_%d' % i for i in range(1, num_output + 1)])
575+
output_names = getattr(self.base_model, 'output_names', None)
576+
if not output_names:
577+
num_output = len(tf.nest.flatten(outputs))
578+
output_names = ['output_%d' % i for i in range(1, num_output + 1)]
579+
return output_names
578580

579581
def _compute_total_loss(self, labels, outputs, sample_weights=None):
580582
# `None` is passed instead of the actual metrics in order to skip computing

0 commit comments

Comments
 (0)