Skip to content

Commit 20016f8

Browse files
committed
Add pickle_protocol to DataConfig for passing
it down to `torch.save()` when caching datasets to disk
1 parent b504132 commit 20016f8

File tree

2 files changed

+11
-2
lines changed

2 files changed

+11
-2
lines changed

src/pytorch_tabular/config/config.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,8 @@ class DataConfig:
9696
handle_missing_values (bool): Whether to handle missing values in categorical columns as
9797
unknown
9898
99+
pickle_protocol (int): pickle protocol version passed to `torch.save` for dataset caching to disk
100+
99101
dataloader_kwargs (Dict[str, Any]): Additional kwargs to be passed to PyTorch DataLoader. See
100102
https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader
101103
@@ -179,6 +181,11 @@ class DataConfig:
179181
metadata={"help": "Whether or not to handle missing values in categorical columns as unknown"},
180182
)
181183

184+
pickle_protocol: int = field(
185+
default=2,
186+
metadata={"help": "pickle protocol version passed to `torch.save` for dataset caching to disk"},
187+
)
188+
182189
dataloader_kwargs: Dict[str, Any] = field(
183190
default_factory=dict,
184191
metadata={"help": "Additional kwargs to be passed to PyTorch DataLoader."},

src/pytorch_tabular/tabular_datamodule.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -484,8 +484,10 @@ def _cache_dataset(self):
484484
self.validation = None
485485

486486
if self.cache_mode is self.CACHE_MODES.DISK:
487-
torch.save(train_dataset, self.cache_dir / "train_dataset")
488-
torch.save(validation_dataset, self.cache_dir / "validation_dataset")
487+
torch.save(train_dataset, self.cache_dir / "train_dataset", pickle_protocol=self.config.pickle_protocol)
488+
torch.save(
489+
validation_dataset, self.cache_dir / "validation_dataset", pickle_protocol=self.config.pickle_protocol
490+
)
489491
elif self.cache_mode is self.CACHE_MODES.MEMORY:
490492
self.train_dataset = train_dataset
491493
self.validation_dataset = validation_dataset

0 commit comments

Comments
 (0)