Skip to content

Commit cc3504a

Browse files
authored
Protection for MDN Head misuse (#448)
* add protection for MDN heads * minor bug fix
1 parent 9dfec99 commit cc3504a

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

src/pytorch_tabular/config/config.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -922,6 +922,10 @@ def __post_init__(self):
922922

923923
if self.task != "backbone":
924924
assert self.head in dir(heads.blocks), f"{self.head} is not a valid head"
925+
if hasattr(self, "_config_name") and self._config_name != "MDNConfig":
926+
assert self.head != "MixtureDensityHead", "MixtureDensityHead is not supported as a head for regular "
927+
"models. Use `MDNConfig` instead. Please see Probabilistic Regression with MDN How-to-Guide in "
928+
"documentation for the right usage."
925929
_head_callable = getattr(heads.blocks, self.head)
926930
ideal_head_config = _head_callable._config_template
927931
invalid_keys = set(self.head_config.keys()) - set(ideal_head_config.__dict__.keys())

0 commit comments

Comments
 (0)