Skip to content

Commit c790615

Browse files
committed
Save the index of the original dataframe in TabularDataset
so that it can be restored when accessing `TabularDataset.data`
1 parent 20016f8 commit c790615

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

src/pytorch_tabular/tabular_datamodule.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ def __init__(
6161
self.task = task
6262
self.n = data.shape[0]
6363
self.target = target
64+
self.index = data.index
6465
if target:
6566
self.y = data[target].astype(np.float32).values
6667
if isinstance(target, str):
@@ -87,11 +88,12 @@ def data(self):
8788
data = pd.DataFrame(
8889
np.concatenate([self.categorical_X, self.continuous_X], axis=1),
8990
columns=self.categorical_cols + self.continuous_cols,
91+
index=self.index,
9092
)
9193
elif self.continuous_cols:
92-
data = pd.DataFrame(self.continuous_X, columns=self.continuous_cols)
94+
data = pd.DataFrame(self.continuous_X, columns=self.continuous_cols, index=self.index)
9395
elif self.categorical_cols:
94-
data = pd.DataFrame(self.categorical_X, columns=self.categorical_cols)
96+
data = pd.DataFrame(self.categorical_X, columns=self.categorical_cols, index=self.index)
9597
else:
9698
data = pd.DataFrame()
9799
for i, t in enumerate(self.target):
@@ -474,6 +476,7 @@ def _cache_dataset(self):
474476
target=self.target,
475477
)
476478
self.train = None
479+
477480
validation_dataset = TabularDataset(
478481
task=self.config.task,
479482
data=self.validation,

0 commit comments

Comments
 (0)