Commit 3960f62
Fix to Tuner change trainer and optimizer configs (#387)
* Fix to Tuner change trainer and optimizer configs
* Recreate datamodule when necessary (tuner)
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Remove trainer_config from tuner
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Remove trainer_config from tuner tests
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 8fc27ee commit 3960f62
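
In practice, this change means a tuner search space may now include `optimizer_config__*` keys alongside `model_config__*` and `model_config.head_config__*` keys, while any `trainer_config__*` key is rejected with a ValueError. A minimal sketch of a search space that is valid after this commit (the key names are taken from the updated test further down; the surrounding tuner setup is assumed):

# Keys follow the "<config>__<param>" convention used by the tuner.
search_space = {
    "model_config__layers": ["8-4", "16-8"],
    "model_config.head_config__dropout": [0.1, 0.2],
    "optimizer_config__optimizer": ["RAdam", "AdamW"],
    # "trainer_config__batch_size": [32],  # no longer supported: now raises ValueError
}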

File tree

2 files changed: +18, -14 lines

src/pytorch_tabular/tabular_model_tuner.py

Lines changed: 17 additions & 11 deletions

@@ -109,15 +109,14 @@ def _check_assign_config(self, config, param, value):
                 config[param] = value
             else:
                 raise ValueError(f"{param} is not a valid parameter for {str(config)}")
-        elif isinstance(config, ModelConfig):
+        elif isinstance(config, (ModelConfig, OptimizerConfig)):
             if hasattr(config, param):
                 setattr(config, param, value)
             else:
                 raise ValueError(f"{param} is not a valid parameter for {str(config)}")

     def _update_configs(
         self,
-        trainer_config: TrainerConfig,
         optimizer_config: OptimizerConfig,
         model_config: ModelConfig,
         params: Dict,
@@ -127,7 +126,9 @@ def _update_configs(
         for k, v in params.items():
             root, param = k.split("__")
             if root.startswith("trainer_config"):
-                self._check_assign_config(trainer_config, param, v)
+                raise ValueError(
+                    "The trainer_config is not supported be tuner. Please remove it from tuner parameters!"
+                )
             elif root.startswith("optimizer_config"):
                 self._check_assign_config(optimizer_config, param, v)
             elif root.startswith("model_config.head_config"):
@@ -138,10 +139,10 @@ def _update_configs(
             else:
                 raise ValueError(
                     f"{k} is not in the proper format. Use __ to separate the "
-                    "root and param. for eg. `training_config__batch_size` should be "
-                    "used to update the batch_size parameter in the training_config"
+                    "root and param. for eg. `optimizer_config__optimizer` should be "
+                    "used to update the optimizer parameter in the optimizer_config"
                 )
-        return trainer_config, optimizer_config, model_config
+        return optimizer_config, model_config

     def tune(
         self,
@@ -251,9 +252,11 @@ def tune(
             iterator = ParameterSampler(search_space, n_iter=n_trials, random_state=random_state)
         else:
             raise NotImplementedError(f"{strategy} is not implemented yet.")
+
         if progress_bar:
             iterator = track(iterator, description=f"[green]{strategy.replace('_',' ').title()}...", total=n_trials)
         verbose_tabular_model = self.tabular_model_init_kwargs.pop("verbose", False)
+
         temp_tabular_model = TabularModel(
             data_config=self.data_config,
             model_config=self.model_config,
@@ -262,11 +265,13 @@ def tune(
             verbose=verbose_tabular_model,
             **self.tabular_model_init_kwargs,
         )
+
         prep_dl_kwargs, prep_model_kwargs, train_kwargs = temp_tabular_model._split_kwargs(kwargs)
         if "seed" not in prep_dl_kwargs:
             prep_dl_kwargs["seed"] = random_state
         datamodule = temp_tabular_model.prepare_dataloader(train=train, validation=validation, **prep_dl_kwargs)
         validation = validation if validation is not None else datamodule.validation_dataset.data
+
         if isinstance(metric, str):
             # metric = metric_to_pt_metric(metric)
             is_callable_metric = False
@@ -275,6 +280,7 @@ def tune(
             is_callable_metric = True
             metric_str = metric.__name__
         del temp_tabular_model
+
         trials = []
         best_model = None
         best_score = 0.0
@@ -286,9 +292,7 @@ def tune(
             optimizer_config_t = deepcopy(self.optimizer_config)
             model_config_t = deepcopy(self.model_config)

-            trainer_config_t, optimizer_config_t, model_config_t = self._update_configs(
-                trainer_config_t, optimizer_config_t, model_config_t, params
-            )
+            optimizer_config_t, model_config_t = self._update_configs(optimizer_config_t, model_config_t, params)
             # Initialize Tabular model using the new config
             tabular_model_t = TabularModel(
                 data_config=self.data_config,
@@ -298,6 +302,7 @@ def tune(
                 verbose=verbose_tabular_model,
                 **self.tabular_model_init_kwargs,
             )
+
             if cv is not None:
                 cv_verbose = cv_kwargs.pop("verbose", False)
                 cv_kwargs.pop("handle_oom", None)
@@ -317,7 +322,7 @@ def tune(
                             "Set ignore_oom=True to ignore this error."
                         )
                     else:
-                        params.update({metric_str: "OOM"})
+                        params.update({metric_str: (np.inf if mode == "min" else -np.inf)})
                 else:
                     params.update({metric_str: cv_agg_func(cv_scores)})
             else:
@@ -334,7 +339,7 @@ def tune(
                             "Out of memory error occurred during training. " "Set ignore_oom=True to ignore this error."
                         )
                     else:
-                        params.update({metric_str: "OOM"})
+                        params.update({metric_str: (np.inf if mode == "min" else -np.inf)})
                 else:
                     if is_callable_metric:
                         preds = tabular_model_t.predict(validation, include_input_features=False)
@@ -380,6 +385,7 @@ def tune(

         if return_best_model and best_model is not None:
             best_model.datamodule = datamodule
+
             return self.OUTPUT(trials_df, best_params, best_score, best_model)
         else:
             return self.OUTPUT(trials_df, best_params, best_score, None)
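
For readers skimming the diff above: _check_assign_config now treats OptimizerConfig like ModelConfig, _update_configs rejects trainer_config__* keys outright and returns only the optimizer and model configs, and out-of-memory trials are recorded as +inf (mode "min") or -inf (mode "max") instead of the string "OOM", so they always sort as the worst score. A simplified, standalone sketch of the new dispatch behaviour (this is not the library's code; the config classes below are stand-ins):

from dataclasses import dataclass
from typing import Any, Dict, Tuple


@dataclass
class StubOptimizerConfig:  # stand-in for pytorch_tabular's OptimizerConfig
    optimizer: str = "Adam"


@dataclass
class StubModelConfig:  # stand-in for pytorch_tabular's ModelConfig
    layers: str = "128-64"


def update_configs(
    optimizer_config: StubOptimizerConfig,
    model_config: StubModelConfig,
    params: Dict[str, Any],
) -> Tuple[StubOptimizerConfig, StubModelConfig]:
    """Apply "<config>__<param>" entries to the configs, mirroring the commit's behaviour."""
    for key, value in params.items():
        root, param = key.split("__")
        if root.startswith("trainer_config"):
            # trainer_config is no longer tunable; fail fast, as the commit now does
            raise ValueError("trainer_config is not supported by the tuner; remove it from the search space")
        target = optimizer_config if root.startswith("optimizer_config") else model_config
        if not hasattr(target, param):
            raise ValueError(f"{param} is not a valid parameter for {target}")
        setattr(target, param, value)
    return optimizer_config, model_config


# Example: tune the optimizer while leaving trainer settings untouched.
update_configs(StubOptimizerConfig(), StubModelConfig(), {"optimizer_config__optimizer": "AdamW"})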

tests/test_common.py

Lines changed: 1 addition & 3 deletions

@@ -20,7 +20,7 @@
     TabNetModelConfig,
 )
 from pytorch_tabular.ssl_models import DenoisingAutoEncoderConfig
-from scipy.stats import randint, uniform
+from scipy.stats import uniform
 from sklearn.metrics import accuracy_score, r2_score
 from sklearn.model_selection import KFold

@@ -852,14 +852,12 @@ def test_tuner(
         search_space = {
             "model_config__layers": ["8-4", "16-8"],
             "model_config.head_config__dropout": [0.1, 0.2],
-            "trainer_config__batch_size": [32],
             "optimizer_config__optimizer": ["RAdam", "AdamW"],
         }
     else:
         search_space = {
             "model_config__layers": ["8-4", "16-8"],
             "model_config.head_config__dropout": uniform(0, 0.5),
-            "trainer_config__batch_size": randint(32, 64),
             "optimizer_config__optimizer": ["RAdam", "AdamW"],
         }
     result = tuner.tune(
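
Putting the two files together, a hedged end-to-end sketch of driving the tuner after this commit. Only the search-space keys and the existence of tune() come from the diff; the TabularModelTuner constructor, the config classes, the strategy/metric/mode argument values, and the synthetic data are assumptions based on the surrounding pytorch_tabular API and may need adjusting:

import numpy as np
import pandas as pd
from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig
from pytorch_tabular.models import CategoryEmbeddingModelConfig
from pytorch_tabular.tabular_model_tuner import TabularModelTuner

# Small synthetic regression dataset so the sketch is self-contained.
rng = np.random.default_rng(42)
df = pd.DataFrame({"feat_1": rng.normal(size=500), "feat_2": rng.normal(size=500)})
df["target"] = 2.0 * df["feat_1"] + rng.normal(scale=0.1, size=500)
train_df, val_df = df.iloc[:400], df.iloc[400:]

tuner = TabularModelTuner(
    data_config=DataConfig(target=["target"], continuous_cols=["feat_1", "feat_2"]),
    model_config=CategoryEmbeddingModelConfig(task="regression"),
    optimizer_config=OptimizerConfig(),
    trainer_config=TrainerConfig(max_epochs=5, batch_size=64),  # fixed up front, not tuned
)

search_space = {
    "model_config__layers": ["8-4", "16-8"],
    "model_config.head_config__dropout": [0.1, 0.2],
    "optimizer_config__optimizer": ["RAdam", "AdamW"],
}

result = tuner.tune(
    train=train_df,
    validation=val_df,
    search_space=search_space,
    strategy="grid_search",
    metric="mean_squared_error",
    mode="min",
)
# result is a namedtuple of (trials_df, best_params, best_score, best_model);
# out-of-memory trials now appear with a score of +/- infinity rather than "OOM".
print(result.trials_df)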
