
Commit 3f0a15c

Bug fix for saving and loading custom loss functions (#415)
* bug fix for custom loss weight state dict error
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* fixed custom loss, metric setting in load model
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 54fdc9b commit 3f0a15c

File tree: 1 file changed (+30 −11 lines changed)

src/pytorch_tabular/tabular_model.py

Lines changed: 30 additions & 11 deletions
```diff
@@ -459,21 +459,40 @@ def load_model(cls, dir: str, map_location=None, strict=True):
         model_args["optimizer_params"] = {}  # For compatibility. Not Used
 
         # Initializing with default metrics, losses, and optimizers. Will revert once initialized
-        model = model_callable.load_from_checkpoint(
-            checkpoint_path=os.path.join(dir, "model.ckpt"),
-            map_location=map_location,
-            strict=strict,
-            **model_args,
-        )
-        # Updating config with custom parameters for experiment tracking
-        if custom_params.get("custom_loss") is not None:
-            model.custom_loss = custom_params["custom_loss"]
-        if custom_params.get("custom_metrics") is not None:
-            model.custom_metrics = custom_params["custom_metrics"]
+        try:
+            model = model_callable.load_from_checkpoint(
+                checkpoint_path=os.path.join(dir, "model.ckpt"),
+                map_location=map_location,
+                strict=strict,
+                **model_args,
+            )
+        except RuntimeError as e:
+            if (
+                "Unexpected key(s) in state_dict" in str(e)
+                and "loss.weight" in str(e)
+                and "custom_loss.weight" in str(e)
+            ):
+                # Custom loss will be loaded after the model is initialized
+                # continuing with strict=False
+                model = model_callable.load_from_checkpoint(
+                    checkpoint_path=os.path.join(dir, "model.ckpt"),
+                    map_location=map_location,
+                    strict=False,
+                    **model_args,
+                )
+            else:
+                raise e
         if custom_params.get("custom_optimizer") is not None:
             model.custom_optimizer = custom_params["custom_optimizer"]
         if custom_params.get("custom_optimizer_params") is not None:
             model.custom_optimizer_params = custom_params["custom_optimizer_params"]
+        if custom_params.get("custom_loss") is not None:
+            model.loss = custom_params["custom_loss"]
+        if custom_params.get("custom_metrics") is not None:
+            model.custom_metrics = custom_params.get("custom_metrics")
+            model.hparams.metrics = [m.__name__ for m in custom_params.get("custom_metrics")]
+            model.hparams.metrics_params = [{}]
+            model.hparams.metrics_prob_input = custom_params.get("custom_metrics_prob_inputs")
         model._setup_loss()
         model._setup_metrics()
         tabular_model = cls(config=config, model_callable=model_callable)
```
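For context, a minimal sketch of the user-facing scenario this commit fixes. The configs, column names, and training data below are placeholders, and the `fit`/`save_model`/`load_model` usage is assumed from the pytorch_tabular documentation rather than taken from this diff: a loss such as `torch.nn.CrossEntropyLoss(weight=...)` stores its class weights as a registered buffer, so they land in the checkpoint state_dict under `custom_loss.weight` and previously broke the strict `load_from_checkpoint` call.

```python
# Hypothetical reproduction of the bug this commit fixes.
# Configs, column names, and the toy DataFrame are placeholders.
import pandas as pd
import torch

from pytorch_tabular import TabularModel
from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig
from pytorch_tabular.models import CategoryEmbeddingModelConfig

train_df = pd.DataFrame(
    {
        "f1": [0.1, 0.2, 0.3, 0.4],
        "f2": [1.0, 0.5, 0.2, 0.9],
        "target": [0, 1, 0, 1],
    }
)

tabular_model = TabularModel(
    data_config=DataConfig(target=["target"], continuous_cols=["f1", "f2"]),
    model_config=CategoryEmbeddingModelConfig(task="classification"),
    optimizer_config=OptimizerConfig(),
    trainer_config=TrainerConfig(max_epochs=5),
)

# CrossEntropyLoss registers `weight` as a buffer, so it ends up in the
# saved checkpoint as `custom_loss.weight`.
custom_loss = torch.nn.CrossEntropyLoss(weight=torch.tensor([0.3, 0.7]))
tabular_model.fit(train=train_df, loss=custom_loss)

tabular_model.save_model("saved_model")
# Before this fix, loading raised "Unexpected key(s) in state_dict: ...
# custom_loss.weight"; load_model now retries with strict=False and then
# re-attaches the saved custom loss and metrics.
loaded_model = TabularModel.load_model("saved_model")
```

The fix keeps strict loading as the default and falls back to `strict=False` only for the specific `loss.weight` / `custom_loss.weight` mismatch, then reassigns the saved custom loss and metrics before `_setup_loss()` and `_setup_metrics()` run.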
