@@ -109,15 +109,14 @@ def _check_assign_config(self, config, param, value):
                 config[param] = value
             else:
                 raise ValueError(f"{param} is not a valid parameter for {str(config)}")
-        elif isinstance(config, ModelConfig):
+        elif isinstance(config, (ModelConfig, OptimizerConfig)):
             if hasattr(config, param):
                 setattr(config, param, value)
             else:
                 raise ValueError(f"{param} is not a valid parameter for {str(config)}")

     def _update_configs(
         self,
-        trainer_config: TrainerConfig,
         optimizer_config: OptimizerConfig,
         model_config: ModelConfig,
         params: Dict,
@@ -127,7 +126,9 @@ def _update_configs(
         for k, v in params.items():
             root, param = k.split("__")
             if root.startswith("trainer_config"):
-                self._check_assign_config(trainer_config, param, v)
+                raise ValueError(
+                    "The trainer_config is not supported by the tuner. Please remove it from the tuner parameters!"
+                )
             elif root.startswith("optimizer_config"):
                 self._check_assign_config(optimizer_config, param, v)
             elif root.startswith("model_config.head_config"):
@@ -138,10 +139,10 @@ def _update_configs(
             else:
                 raise ValueError(
                     f"{k} is not in the proper format. Use __ to separate the "
-                    "root and param. for eg. `training_config__batch_size` should be "
-                    "used to update the batch_size parameter in the training_config "
+                    "root and param. for eg. `optimizer_config__optimizer` should be "
+                    "used to update the optimizer parameter in the optimizer_config "
                 )
-        return trainer_config, optimizer_config, model_config
+        return optimizer_config, model_config

     def tune(
         self,
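# Illustration (not from the patch above): how _update_configs expects search-space
# keys to look after this change. Each key is split once on "__" into a config root
# and a parameter name; "trainer_config__*" keys now raise, and the remaining roots
# are routed to optimizer_config / model_config (including the "model_config.head_config"
# sub-root). The keys and values below are hypothetical examples; valid parameter
# names depend on the concrete config classes in use.
search_space = {
    "optimizer_config__optimizer": ["Adam", "SGD"],
    "model_config__layers": ["128-64", "64-32"],
    # "trainer_config__batch_size": [...]  # would now raise ValueError
}
for key in search_space:
    root, param = key.split("__")
    print(root, param)  # e.g. "optimizer_config", "optimizer"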
@@ -251,9 +252,11 @@ def tune(
             iterator = ParameterSampler(search_space, n_iter=n_trials, random_state=random_state)
         else:
             raise NotImplementedError(f"{strategy} is not implemented yet.")
+
         if progress_bar:
             iterator = track(iterator, description=f"[green]{strategy.replace('_', ' ').title()}...", total=n_trials)
         verbose_tabular_model = self.tabular_model_init_kwargs.pop("verbose", False)
+
         temp_tabular_model = TabularModel(
             data_config=self.data_config,
             model_config=self.model_config,
@@ -262,11 +265,13 @@ def tune(
             verbose=verbose_tabular_model,
             **self.tabular_model_init_kwargs,
         )
+
         prep_dl_kwargs, prep_model_kwargs, train_kwargs = temp_tabular_model._split_kwargs(kwargs)
         if "seed" not in prep_dl_kwargs:
             prep_dl_kwargs["seed"] = random_state
         datamodule = temp_tabular_model.prepare_dataloader(train=train, validation=validation, **prep_dl_kwargs)
         validation = validation if validation is not None else datamodule.validation_dataset.data
+
         if isinstance(metric, str):
             # metric = metric_to_pt_metric(metric)
             is_callable_metric = False
@@ -275,6 +280,7 @@ def tune(
             is_callable_metric = True
             metric_str = metric.__name__
         del temp_tabular_model
+
         trials = []
         best_model = None
         best_score = 0.0
@@ -286,9 +292,7 @@ def tune(
             optimizer_config_t = deepcopy(self.optimizer_config)
             model_config_t = deepcopy(self.model_config)

-            trainer_config_t, optimizer_config_t, model_config_t = self._update_configs(
-                trainer_config_t, optimizer_config_t, model_config_t, params
-            )
+            optimizer_config_t, model_config_t = self._update_configs(optimizer_config_t, model_config_t, params)
             # Initialize Tabular model using the new config
             tabular_model_t = TabularModel(
                 data_config=self.data_config,
@@ -298,6 +302,7 @@ def tune(
                 verbose=verbose_tabular_model,
                 **self.tabular_model_init_kwargs,
             )
+
             if cv is not None:
                 cv_verbose = cv_kwargs.pop("verbose", False)
                 cv_kwargs.pop("handle_oom", None)
@@ -317,7 +322,7 @@ def tune(
                             "Set ignore_oom=True to ignore this error."
                         )
                     else:
-                        params.update({metric_str: "OOM"})
+                        params.update({metric_str: (np.inf if mode == "min" else -np.inf)})
                 else:
                     params.update({metric_str: cv_agg_func(cv_scores)})
             else:
@@ -334,7 +339,7 @@ def tune(
                             "Out of memory error occurred during training. " "Set ignore_oom=True to ignore this error."
                         )
                     else:
-                        params.update({metric_str: (np.inf if mode == "min" else -np.inf)})
                 else:
                     if is_callable_metric:
                         preds = tabular_model_t.predict(validation, include_input_features=False)
@@ -380,6 +385,7 @@ def tune(

         if return_best_model and best_model is not None:
             best_model.datamodule = datamodule
+
             return self.OUTPUT(trials_df, best_params, best_score, best_model)
         else:
             return self.OUTPUT(trials_df, best_params, best_score, None)
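# Illustration (not from the patch above): the OOM change replaces the "OOM" string
# with a signed infinity, which keeps the metric column numeric so an out-of-memory
# trial can never be ranked as the best score under either optimization mode.
import numpy as np

mode = "min"  # the tuner's `mode` argument: "min" or "max"
scores = [0.42, np.inf if mode == "min" else -np.inf, 0.37]  # second trial hit OOM

best = min(scores) if mode == "min" else max(scores)
print(best)  # 0.37 -- the OOM sentinel always loses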