Skip to content

Commit 54fdc9b

Browse files
charitarthchughpre-commit-ci[bot]manujosephv
authored
TabularDataModule: Fix pandas returning a series when calling nunique() (#420)
* Fix pandas returning a series when calling nunique() Fixes #404 (comment) * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Manu Joseph V <manujosephv@gmail.com>
1 parent 3960f62 commit 54fdc9b

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

src/pytorch_tabular/tabular_datamodule.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -290,11 +290,11 @@ def _update_config(self, config) -> InferredConfig:
290290
raise ValueError(f"{config.task} is an unsupported task.")
291291
if self.train is not None:
292292
categorical_cardinality = [
293-
int(self.train[col].fillna("NA").nunique()) + 1 for col in config.categorical_cols
293+
int(x) + 1 for x in list(self.train[config.categorical_cols].fillna("NA").nunique().values)
294294
]
295295
else:
296296
categorical_cardinality = [
297-
int(self.train_dataset.data[col].nunique()) + 1 for col in config.categorical_cols
297+
int(x) + 1 for x in list(self.train_dataset.data[config.categorical_cols].nunique().values)
298298
]
299299
if getattr(config, "embedding_dims", None) is not None:
300300
embedding_dims = config.embedding_dims

0 commit comments

Comments
 (0)