Skip to content

Commit b2bf5c1

Browse files
csferngtensorflow-copybara
authored andcommitted
Unset Loss.reduction to prevent double-reduction in AdversarialRegularization.
`AdversarialRegularization` creates a loss wrapper around the provided loss in `compile()` for handling sample weights and loss reduction (aggregation). If the provided loss is a `tf.keras.losses.Loss` object, it comes with loss reduction by default which causes an error in the loss wrapper because the wrapper expects unreduced loss values. This change disables the loss reduction in the provided `Loss` object, so the loss wrapper can function properly. An alternative approach would be disabling the loss reduction in the loss wrapper while doing the loss reduction in the `Loss` object. However, the alternative approach would run into an error when running with `tf.distribute.Strategy`, because the `SUM_OVER_BATCH_SIZE` reduction type requires special logic outside the `Loss` object. Such logic is already implemented in the loss wrapper, so letting the wrapper handle loss reduction looks cleaner. Fixes #21 PiperOrigin-RevId: 272923076
1 parent c1fb3df commit b2bf5c1

File tree

3 files changed

+48
-5
lines changed

3 files changed

+48
-5
lines changed

neural_structured_learning/keras/adversarial_regularization.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,14 +165,21 @@ def __init__(self, loss_fn, name, weight):
165165
else:
166166
self.batch_size_reduction = False
167167
super(_LossWrapper, self).__init__(name=name, reduction=reduction)
168-
self.loss_fn = loss_fn
169168
self.weight = weight
169+
if isinstance(loss_fn, tf.keras.losses.Loss) and self.batch_size_reduction:
170+
self.loss_fn = loss_fn.__class__.from_config(loss_fn.get_config())
171+
self.loss_fn.reduction = tf.losses.Reduction.NONE
172+
else:
173+
self.loss_fn = loss_fn
170174

171175
def call(self, y_true, y_pred):
172176
return self.loss_fn(y_true, y_pred)
173177

174178
def __call__(self, *args, **kwargs):
175-
loss_value = super(_LossWrapper, self).__call__(*args, **kwargs)
179+
if isinstance(self.loss_fn, tf.keras.losses.Loss):
180+
loss_value = self.loss_fn(*args, **kwargs)
181+
else:
182+
loss_value = super(_LossWrapper, self).__call__(*args, **kwargs)
176183
if self.batch_size_reduction:
177184
size = tf.cast(tf.size(loss_value), dtype=loss_value.dtype)
178185
loss_value = tf.math.divide_no_nan(tf.math.reduce_sum(loss_value), size)

neural_structured_learning/keras/adversarial_regularization_multi_device_test.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -89,16 +89,19 @@ def _set_up_linear_regression(self, sample_weight=1.0):
8989
w_new = w - learning_rate * (grad_w_labeled_loss + grad_w_adv_loss)
9090
return w, x0, y0, learning_rate, adv_config, w_new
9191

92+
def _get_mirrored_strategy(self):
93+
device_type = 'GPU' if tf.test.is_gpu_available() else 'CPU'
94+
devices = ['{}:{}'.format(device_type, i) for i in range(NUM_REPLICAS)]
95+
return tf.distribute.MirroredStrategy(devices)
96+
9297
def test_train_with_distribution_strategy(self):
9398
w, x0, y0, lr, adv_config, w_new = self._set_up_linear_regression()
9499
inputs = tf.data.Dataset.from_tensor_slices({
95100
'feature': x0,
96101
'label': y0
97102
}).batch(NUM_REPLICAS)
98103

99-
device_type = 'GPU' if tf.test.is_gpu_available() else 'CPU'
100-
devices = ['{}:{}'.format(device_type, i) for i in range(NUM_REPLICAS)]
101-
strategy = tf.distribute.MirroredStrategy(devices)
104+
strategy = self._get_mirrored_strategy()
102105
with strategy.scope():
103106
# Makes sure we are running on multiple devices.
104107
self.assertEqual(NUM_REPLICAS, strategy.num_replicas_in_sync)
@@ -112,6 +115,25 @@ def test_train_with_distribution_strategy(self):
112115
# The updated weight should be the same regardless of the number of devices.
113116
self.assertAllClose(w_new, keras.backend.get_value(model.weights[0]))
114117

118+
def test_train_with_loss_object(self):
119+
w, x0, y0, lr, adv_config, w_new = self._set_up_linear_regression()
120+
inputs = tf.data.Dataset.from_tensor_slices({
121+
'feature': x0,
122+
'label': y0
123+
}).batch(NUM_REPLICAS)
124+
125+
strategy = self._get_mirrored_strategy()
126+
with strategy.scope():
127+
model = build_linear_keras_functional_model(input_shape=(2,), weights=w)
128+
adv_model = adversarial_regularization.AdversarialRegularization(
129+
model, label_keys=['label'], adv_config=adv_config)
130+
adv_model.compile(
131+
optimizer=keras.optimizers.SGD(lr),
132+
loss=tf.keras.losses.MeanSquaredError())
133+
adv_model.fit(x=inputs)
134+
135+
self.assertAllClose(w_new, keras.backend.get_value(model.weights[0]))
136+
115137

116138
if __name__ == '__main__':
117139
tf.compat.v1.enable_v2_behavior()

neural_structured_learning/keras/adversarial_regularization_test.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -344,6 +344,20 @@ def test_train_with_distribution_strategy(self, model_fn):
344344

345345
self.assertAllClose(w_new, keras.backend.get_value(model.weights[0]))
346346

347+
def test_train_with_loss_object(self):
348+
w, x0, y0, lr, adv_config, w_new = self._set_up_linear_regression()
349+
350+
inputs = {'feature': tf.constant(x0), 'label': tf.constant(y0)}
351+
model = build_linear_keras_functional_model(input_shape=(2,), weights=w)
352+
adv_model = adversarial_regularization.AdversarialRegularization(
353+
model, label_keys=['label'], adv_config=adv_config)
354+
adv_model.compile(
355+
optimizer=keras.optimizers.SGD(lr),
356+
loss=tf.keras.losses.MeanSquaredError())
357+
adv_model.fit(x=inputs, batch_size=1, steps_per_epoch=1)
358+
359+
self.assertAllClose(w_new, keras.backend.get_value(model.weights[0]))
360+
347361
def test_train_with_metrics(self):
348362
w, x0, y0, lr, adv_config, _ = self._set_up_linear_regression()
349363

0 commit comments

Comments
 (0)