Commit 6e69973
Convert input dictionary to a list for functional Keras models.
Functional Keras models may expect their input features to be in a
specific order, which may be different from the alphabetic order used
for serializing input dictionaries. Keras `Model` class handles the
different ordering by performing a name lookup before executing the
model's forward pass. However, the name lookup is only performed when
the model is called via high-level interfaces like `model.fit()`, but
not when the model is called directly like `model(input)`. Since
`nsl.keras.AdversarialRegularization` always calls its base model
directly, this creates an interface discrepancy. For example,
```
input = {'a': ..., 'b': ...}
model = tf.keras.Model(
[tf.keras.Input(..., name='b'), tf.keras.Input(..., name='a')], ...)
adv_model = nsl.keras.AdversarialRegularization(model)
... # Compiles both models
model.fit(input) # works
adv_model.fit(input) # error
```
This fix does the name lookup before calling the base model if the base
model is a functional model. Sequential models are excluded because their
feature name may not be specified. Subclassed Keras models are also
excluded because some subclassed models actually expect dictionary-style
input instead of a list.
Fixes #27
PiperOrigin-RevId: 2899384951 parent bfab889 commit 6e69973
File tree
2 files changed
+65
-2
lines changed- neural_structured_learning/keras
2 files changed
+65
-2
lines changedLines changed: 14 additions & 2 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
585 | 585 | | |
586 | 586 | | |
587 | 587 | | |
| 588 | + | |
| 589 | + | |
588 | 590 | | |
589 | 591 | | |
590 | 592 | | |
| |||
599 | 601 | | |
600 | 602 | | |
601 | 603 | | |
| 604 | + | |
| 605 | + | |
| 606 | + | |
| 607 | + | |
| 608 | + | |
| 609 | + | |
| 610 | + | |
| 611 | + | |
| 612 | + | |
| 613 | + | |
602 | 614 | | |
603 | 615 | | |
604 | 616 | | |
605 | 617 | | |
606 | | - | |
| 618 | + | |
607 | 619 | | |
608 | 620 | | |
609 | 621 | | |
| |||
634 | 646 | | |
635 | 647 | | |
636 | 648 | | |
637 | | - | |
| 649 | + | |
638 | 650 | | |
639 | 651 | | |
640 | 652 | | |
| |||
Lines changed: 51 additions & 0 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
39 | 39 | | |
40 | 40 | | |
41 | 41 | | |
| 42 | + | |
| 43 | + | |
| 44 | + | |
| 45 | + | |
| 46 | + | |
| 47 | + | |
| 48 | + | |
| 49 | + | |
| 50 | + | |
| 51 | + | |
42 | 52 | | |
43 | 53 | | |
44 | 54 | | |
| |||
276 | 286 | | |
277 | 287 | | |
278 | 288 | | |
| 289 | + | |
| 290 | + | |
279 | 291 | | |
280 | 292 | | |
281 | 293 | | |
| |||
460 | 472 | | |
461 | 473 | | |
462 | 474 | | |
| 475 | + | |
| 476 | + | |
| 477 | + | |
| 478 | + | |
| 479 | + | |
| 480 | + | |
| 481 | + | |
| 482 | + | |
| 483 | + | |
| 484 | + | |
| 485 | + | |
| 486 | + | |
| 487 | + | |
| 488 | + | |
| 489 | + | |
| 490 | + | |
| 491 | + | |
| 492 | + | |
| 493 | + | |
| 494 | + | |
| 495 | + | |
| 496 | + | |
| 497 | + | |
| 498 | + | |
| 499 | + | |
| 500 | + | |
| 501 | + | |
| 502 | + | |
| 503 | + | |
| 504 | + | |
| 505 | + | |
| 506 | + | |
| 507 | + | |
| 508 | + | |
| 509 | + | |
| 510 | + | |
| 511 | + | |
| 512 | + | |
| 513 | + | |
463 | 514 | | |
464 | 515 | | |
465 | 516 | | |
| |||
0 commit comments