Skip to content

Commit 6683c91

Browse files
committed
Address PerformanceWarning related to frame.insert()
1 parent 8032d0b commit 6683c91

File tree

1 file changed

+11
-5
lines changed

1 file changed

+11
-5
lines changed

src/pytorch_tabular/feature_extractor.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -79,15 +79,21 @@ def transform(self, X: pd.DataFrame, y=None) -> pd.DataFrame:
7979
if k in ret_value.keys():
8080
logits_predictions[k].append(ret_value[k].detach().cpu())
8181

82+
logits_dfs = []
8283
for k, v in logits_predictions.items():
8384
v = torch.cat(v, dim=0).numpy()
8485
if v.ndim == 1:
8586
v = v.reshape(-1, 1)
86-
for i in range(v.shape[-1]):
87-
if v.shape[-1] > 1:
88-
X_encoded[f"{k}_{i}"] = v[:, i]
89-
else:
90-
X_encoded[f"{k}"] = v[:, i]
87+
if v.shape[-1] > 1:
88+
temp_df = pd.DataFrame({f"{k}_{i}": v[:, i] for i in range(v.shape[-1])})
89+
else:
90+
temp_df = pd.DataFrame({f"{k}": v[:, 0]})
91+
92+
# Append the temp DataFrame to the list
93+
logits_dfs.append(temp_df)
94+
95+
preds = pd.concat(logits_dfs, axis=1)
96+
X_encoded = pd.concat([X_encoded, preds], axis=1)
9197

9298
if self.drop_original:
9399
X_encoded.drop(columns=orig_features, inplace=True)

0 commit comments

Comments
 (0)