Commit b2bf5c1
Unset Loss.reduction to prevent double-reduction in AdversarialRegularization.
`AdversarialRegularization` creates a loss wrapper around the provided loss in
`compile()` for handling sample weights and loss reduction (aggregation). If the
provided loss is a `tf.keras.losses.Loss` object, it comes with loss reduction
by default which causes an error in the loss wrapper because the wrapper expects
unreduced loss values.
This change disables the loss reduction in the provided `Loss` object, so the
loss wrapper can function properly. An alternative approach would be disabling
the loss reduction in the loss wrapper while doing the loss reduction in the
`Loss` object. However, the alternative approach would run into an error when
running with `tf.distribute.Strategy`, because the `SUM_OVER_BATCH_SIZE`
reduction type requires special logic outside the `Loss` object. Such logic is
already implemented in the loss wrapper, so letting the wrapper handle loss
reduction looks cleaner.
Fixes #21
PiperOrigin-RevId: 2729230761 parent c1fb3df commit b2bf5c1
File tree
3 files changed
+48
-5
lines changed- neural_structured_learning/keras
3 files changed
+48
-5
lines changedLines changed: 9 additions & 2 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
165 | 165 | | |
166 | 166 | | |
167 | 167 | | |
168 | | - | |
169 | 168 | | |
| 169 | + | |
| 170 | + | |
| 171 | + | |
| 172 | + | |
| 173 | + | |
170 | 174 | | |
171 | 175 | | |
172 | 176 | | |
173 | 177 | | |
174 | 178 | | |
175 | | - | |
| 179 | + | |
| 180 | + | |
| 181 | + | |
| 182 | + | |
176 | 183 | | |
177 | 184 | | |
178 | 185 | | |
| |||
Lines changed: 25 additions & 3 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
89 | 89 | | |
90 | 90 | | |
91 | 91 | | |
| 92 | + | |
| 93 | + | |
| 94 | + | |
| 95 | + | |
| 96 | + | |
92 | 97 | | |
93 | 98 | | |
94 | 99 | | |
95 | 100 | | |
96 | 101 | | |
97 | 102 | | |
98 | 103 | | |
99 | | - | |
100 | | - | |
101 | | - | |
| 104 | + | |
102 | 105 | | |
103 | 106 | | |
104 | 107 | | |
| |||
112 | 115 | | |
113 | 116 | | |
114 | 117 | | |
| 118 | + | |
| 119 | + | |
| 120 | + | |
| 121 | + | |
| 122 | + | |
| 123 | + | |
| 124 | + | |
| 125 | + | |
| 126 | + | |
| 127 | + | |
| 128 | + | |
| 129 | + | |
| 130 | + | |
| 131 | + | |
| 132 | + | |
| 133 | + | |
| 134 | + | |
| 135 | + | |
| 136 | + | |
115 | 137 | | |
116 | 138 | | |
117 | 139 | | |
| |||
Lines changed: 14 additions & 0 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
344 | 344 | | |
345 | 345 | | |
346 | 346 | | |
| 347 | + | |
| 348 | + | |
| 349 | + | |
| 350 | + | |
| 351 | + | |
| 352 | + | |
| 353 | + | |
| 354 | + | |
| 355 | + | |
| 356 | + | |
| 357 | + | |
| 358 | + | |
| 359 | + | |
| 360 | + | |
347 | 361 | | |
348 | 362 | | |
349 | 363 | | |
| |||
0 commit comments