File tree Expand file tree Collapse file tree 2 files changed +11
-2
lines changed
Expand file tree Collapse file tree 2 files changed +11
-2
lines changed Original file line number Diff line number Diff line change @@ -96,6 +96,8 @@ class DataConfig:
9696 handle_missing_values (bool): Whether to handle missing values in categorical columns as
9797 unknown
9898
99+ pickle_protocol (int): pickle protocol version passed to `torch.save` for dataset caching to disk
100+
99101 dataloader_kwargs (Dict[str, Any]): Additional kwargs to be passed to PyTorch DataLoader. See
100102 https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader
101103
@@ -179,6 +181,11 @@ class DataConfig:
179181 metadata = {"help" : "Whether or not to handle missing values in categorical columns as unknown" },
180182 )
181183
184+ pickle_protocol : int = field (
185+ default = 2 ,
186+ metadata = {"help" : "pickle protocol version passed to `torch.save` for dataset caching to disk" },
187+ )
188+
182189 dataloader_kwargs : Dict [str , Any ] = field (
183190 default_factory = dict ,
184191 metadata = {"help" : "Additional kwargs to be passed to PyTorch DataLoader." },
Original file line number Diff line number Diff line change @@ -484,8 +484,10 @@ def _cache_dataset(self):
484484 self .validation = None
485485
486486 if self .cache_mode is self .CACHE_MODES .DISK :
487- torch .save (train_dataset , self .cache_dir / "train_dataset" )
488- torch .save (validation_dataset , self .cache_dir / "validation_dataset" )
487+ torch .save (train_dataset , self .cache_dir / "train_dataset" , pickle_protocol = self .config .pickle_protocol )
488+ torch .save (
489+ validation_dataset , self .cache_dir / "validation_dataset" , pickle_protocol = self .config .pickle_protocol
490+ )
489491 elif self .cache_mode is self .CACHE_MODES .MEMORY :
490492 self .train_dataset = train_dataset
491493 self .validation_dataset = validation_dataset
You can’t perform that action at this time.
0 commit comments