12 | 12 | # language governing permissions and limitations under the License. |
13 | 13 | import logging |
14 | 14 | import os |
15 | | -import signal |
16 | 15 |
17 | 16 | import numpy as np |
18 | 17 | import xgboost as xgb |
19 | 18 | from sklearn.model_selection import RepeatedKFold, RepeatedStratifiedKFold |
20 | 19 |
21 | 20 | from sagemaker_algorithm_toolkit import exceptions as exc |
22 | 21 | from sagemaker_algorithm_toolkit.channel_validation import Channel |
23 | | -from sagemaker_xgboost_container import checkpointing, distributed |
| 22 | +from sagemaker_xgboost_container import distributed |
24 | 23 | from sagemaker_xgboost_container.algorithm_mode import channel_validation as cv |
25 | 24 | from sagemaker_xgboost_container.algorithm_mode import hyperparameter_validation as hpv |
26 | 25 | from sagemaker_xgboost_container.algorithm_mode import metrics as metrics_mod |
27 | 26 | from sagemaker_xgboost_container.algorithm_mode import train_utils |
28 | | -from sagemaker_xgboost_container.callback import add_debugging |
29 | | -from sagemaker_xgboost_container.constants.sm_env_constants import SM_OUTPUT_DATA_DIR |
| 27 | +from sagemaker_xgboost_container.callback import add_debugging, get_callbacks |
| 28 | +from sagemaker_xgboost_container.constants.sm_env_constants import ( |
| 29 | + SM_NUM_GPUS, |
| 30 | + SM_OUTPUT_DATA_DIR, |
| 31 | +) |
30 | 32 | from sagemaker_xgboost_container.constants.xgb_constants import ( |
31 | 33 | CUSTOMER_ERRORS, |
32 | 34 | MODEL_NAME, |
33 | | - XGB_MAXIMIZE_METRICS, |
34 | 35 | ) |
35 | 36 | from sagemaker_xgboost_container.data_utils import ( |
36 | 37 | check_data_redundancy, |
39 | 40 | get_size, |
40 | 41 | validate_data_file_path, |
41 | 42 | ) |
| 43 | +from sagemaker_xgboost_container.distributed_gpu import distributed_gpu_training |
42 | 44 | from sagemaker_xgboost_container.prediction_utils import ValidationPredictionRecorder |
43 | 45 |
44 | 46 | logger = logging.getLogger(__name__) |
45 | 47 |
46 | | - |
47 | | -def add_sigterm_handler(model_dir, is_master): |
48 | | - """Stop training and cleanup model directory when SIGTERM is received. |
49 | | -
50 | | - Model directory is only cleaned if is_master is True. Otherwise program terminates. |
51 | | -
52 | | - :param model_dir: Directory where model is saved |
53 | | - :param is_master: True if single node training, or the current node is the master node in distributed training |
54 | | - """ |
55 | | - |
56 | | - def _terminate(): |
57 | | - os._exit(0) |
58 | | - |
59 | | - def _cleanup_files(signo, frame): |
60 | | - if is_master: |
61 | | - train_utils.cleanup_dir(model_dir, MODEL_NAME) |
62 | | - |
63 | | - _terminate() |
64 | | - |
65 | | - signal.signal(signal.SIGTERM, _cleanup_files) |
| 48 | +DOCUMENTATION_LINK = "https://docs.aws.amazon.com/sagemaker/latest/dg/xgboost.html" |
66 | 49 |
67 | 50 |
68 | 51 | def get_validated_dmatrices( |
@@ -169,50 +152,86 @@ def sagemaker_train( |
169 | 152 | # Obtain information about training resources to determine which distributed setup to use, if needed. |
170 | 153 | num_hosts = len(sm_hosts) |
171 | 154 |
172 | | - train_dmatrix, val_dmatrix, train_val_dmatrix = get_validated_dmatrices( |
173 | | - train_path, val_path, file_type, csv_weights, is_pipe, combine_train_val |
174 | | - ) |
175 | 155 | checkpoint_dir = checkpoint_config.get("LocalPath", None) |
176 | 156 |
177 | | - train_args = dict( |
178 | | - train_cfg=validated_train_config, |
179 | | - train_dmatrix=train_dmatrix, |
180 | | - val_dmatrix=val_dmatrix, |
181 | | - train_val_dmatrix=train_val_dmatrix, |
182 | | - model_dir=model_dir, |
183 | | - checkpoint_dir=checkpoint_dir, |
184 | | - ) |
| 157 | + num_gpus = int(os.getenv(SM_NUM_GPUS, 0)) |
| 158 | + logging.info(f"Determined {num_gpus} GPU(s) available on the instance.") |
| 159 | + tree_method_hp = validated_train_config.get("tree_method") |
| 160 | + |
| 161 | + is_dask_job = validated_train_config.pop("use_dask_gpu_training", "false") |
185 | 162 |
186 | | - if num_hosts > 1: |
187 | | - # Wait for hosts to find each other |
188 | | - logging.info("Distributed node training with {} hosts: {}".format(num_hosts, sm_hosts)) |
189 | | - distributed.wait_hostname_resolution(sm_hosts) |
| 163 | + if is_dask_job == "true": |
| 164 | + gpu_train_validation_errors = distributed_gpu_training.validate_gpu_train_configuration( |
| 165 | + tree_method_hp=tree_method_hp, |
| 166 | + num_hosts=num_hosts, |
| 167 | + num_gpus=num_gpus, |
| 168 | + input_mode=input_mode, |
| 169 | + input_format=file_type, |
| 170 | + data_config=validated_data_config, |
| 171 | + ) |
190 | 172 |
191 | | - if not train_dmatrix: |
| 173 | + if gpu_train_validation_errors: |
| 174 | + raise exc.UserError(f"Some configurations unsuitable for Dask GPU training were found: " |
| 175 | + f"{'. '.join(gpu_train_validation_errors)}") |
| 176 | + |
| 177 | + logging.info("Going to run distributed GPU training through Dask.") |
| 178 | + distributed_gpu_training.run_training_with_dask( |
| 179 | + hyperparameters=validated_train_config, |
| 180 | + train_path=train_path, |
| 181 | + validation_path=val_path, |
| 182 | + model_dir=model_dir, |
| 183 | + content_type=file_type, |
| 184 | + sm_hosts=sm_hosts, |
| 185 | + current_host=sm_current_host, |
| 186 | + checkpoint_dir=checkpoint_dir, |
| 187 | + num_gpus=num_gpus, |
| 188 | + ) |
| 189 | + else: |
| 190 | + if num_gpus > 1: |
192 | 191 | logging.warning( |
193 | | - "Host {} does not have data. Will broadcast to cluster and will not be used in distributed" |
194 | | - " training.".format(sm_current_host) |
| 192 | + f"If you're using GPU training, not all GPUs on the instance will be used. " |
| 193 | + f"See how to use all GPUs at {DOCUMENTATION_LINK}" |
195 | 194 | ) |
196 | | - distributed.rabit_run( |
197 | | - exec_fun=train_job, |
198 | | - args=train_args, |
199 | | - include_in_training=(train_dmatrix is not None), |
200 | | - hosts=sm_hosts, |
201 | | - current_host=sm_current_host, |
202 | | - update_rabit_args=True, |
| 195 | + |
| 196 | + train_dmatrix, val_dmatrix, train_val_dmatrix = get_validated_dmatrices( |
| 197 | + train_path, val_path, file_type, csv_weights, is_pipe, combine_train_val |
203 | 198 | ) |
204 | | - elif num_hosts == 1: |
205 | | - if train_dmatrix: |
206 | | - if validation_channel: |
207 | | - if not val_dmatrix: |
| 199 | + train_args = dict( |
| 200 | + train_cfg=validated_train_config, |
| 201 | + train_dmatrix=train_dmatrix, |
| 202 | + val_dmatrix=val_dmatrix, |
| 203 | + train_val_dmatrix=train_val_dmatrix, |
| 204 | + model_dir=model_dir, |
| 205 | + checkpoint_dir=checkpoint_dir, |
| 206 | + ) |
| 207 | + if num_hosts > 1: |
| 208 | + # Wait for hosts to find each other |
| 209 | + logging.info("Distributed node training with {} hosts: {}".format(num_hosts, sm_hosts)) |
| 210 | + distributed.wait_hostname_resolution(sm_hosts) |
| 211 | + if not train_dmatrix: |
| 212 | + logging.warning( |
| 213 | + "Host {} does not have data. Will broadcast to cluster and will not be used in distributed" |
| 214 | + " training.".format(sm_current_host) |
| 215 | + ) |
| 216 | + distributed.rabit_run( |
| 217 | + exec_fun=train_job, |
| 218 | + args=train_args, |
| 219 | + include_in_training=(train_dmatrix is not None), |
| 220 | + hosts=sm_hosts, |
| 221 | + current_host=sm_current_host, |
| 222 | + update_rabit_args=True, |
| 223 | + ) |
| 224 | + elif num_hosts == 1: |
| 225 | + if train_dmatrix: |
| 226 | + if validation_channel and not val_dmatrix: |
208 | 227 | raise exc.UserError("No data in validation channel path {}".format(val_path)) |
209 | | - logging.info("Single node training.") |
210 | | - train_args.update({"is_master": True}) |
211 | | - train_job(**train_args) |
| 228 | + logging.info("Single node training.") |
| 229 | + train_args.update({"is_master": True}) |
| 230 | + train_job(**train_args) |
| 231 | + else: |
| 232 | + raise exc.UserError("No data in training channel path {}".format(train_path)) |
212 | 233 | else: |
213 | | - raise exc.UserError("No data in training channel path {}".format(train_path)) |
214 | | - else: |
215 | | - raise exc.PlatformError("Number of hosts should be an int greater than or equal to 1") |
| 234 | + raise exc.PlatformError("Number of hosts should be an int greater than or equal to 1") |
216 | 235 |
217 | 236 |
218 | 237 | def train_job(train_cfg, train_dmatrix, val_dmatrix, train_val_dmatrix, model_dir, checkpoint_dir, is_master): |
@@ -259,11 +278,12 @@ def train_job(train_cfg, train_dmatrix, val_dmatrix, train_val_dmatrix, model_di |
259 | 278 |
260 | 279 | try: |
261 | 280 | kfold = train_cfg.pop("_kfold", None) |
| 281 | + watchlist = [(train_dmatrix, "train")] |
| 282 | + if val_dmatrix is not None: |
| 283 | + watchlist.append((val_dmatrix, "validation")) |
262 | 284 |
263 | 285 | if kfold is None: |
264 | | - xgb_model, iteration, callbacks, watchlist = get_callbacks_watchlist( |
265 | | - train_dmatrix=train_dmatrix, |
266 | | - val_dmatrix=val_dmatrix, |
| 286 | + xgb_model, iteration, callbacks = get_callbacks( |
267 | 287 | model_dir=model_dir, |
268 | 288 | checkpoint_dir=checkpoint_dir, |
269 | 289 | early_stopping_data_name=early_stopping_data_name, |
@@ -322,9 +342,7 @@ def train_job(train_cfg, train_dmatrix, val_dmatrix, train_val_dmatrix, model_di |
322 | 342 | cv_train_dmatrix = train_val_dmatrix.slice(train_idx) |
323 | 343 | cv_val_dmatrix = train_val_dmatrix.slice(val_idx) |
324 | 344 |
325 | | - xgb_model, iteration, callbacks, watchlist = get_callbacks_watchlist( |
326 | | - train_dmatrix=cv_train_dmatrix, |
327 | | - val_dmatrix=cv_val_dmatrix, |
| 345 | + xgb_model, iteration, callbacks = get_callbacks( |
328 | 346 | model_dir=model_dir, |
329 | 347 | checkpoint_dir=checkpoint_dir, |
330 | 348 | early_stopping_data_name=early_stopping_data_name, |
@@ -391,61 +409,6 @@ def train_job(train_cfg, train_dmatrix, val_dmatrix, train_val_dmatrix, model_di |
391 | 409 | logging.debug("Stored trained model {} at {}".format(fold, model_location)) |
392 | 410 |
393 | 411 |
394 | | -def get_callbacks_watchlist( |
395 | | - train_dmatrix, |
396 | | - val_dmatrix, |
397 | | - model_dir, |
398 | | - checkpoint_dir, |
399 | | - early_stopping_data_name, |
400 | | - early_stopping_metric, |
401 | | - early_stopping_rounds, |
402 | | - save_model_on_termination, |
403 | | - is_master, |
404 | | - fold=None, |
405 | | -): |
406 | | - if checkpoint_dir and fold is not None: |
407 | | - checkpoint_dir = os.path.join(checkpoint_dir, f"model-{fold}") |
408 | | - |
409 | | - # Set callbacks |
410 | | - xgb_model, iteration = checkpointing.load_checkpoint(checkpoint_dir) |
411 | | - if xgb_model is not None: |
412 | | - if fold is not None: |
413 | | - xgb_model = f"{xgb_model}-{fold}" |
414 | | - logging.info("Checkpoint loaded from %s", xgb_model) |
415 | | - logging.info("Resuming from iteration %s", iteration) |
416 | | - |
417 | | - callbacks = [] |
418 | | - callbacks.append(xgb.callback.EvaluationMonitor()) |
419 | | - if checkpoint_dir: |
420 | | - save_checkpoint = xgb.callback.TrainingCheckPoint( |
421 | | - directory=checkpoint_dir, iterations=iteration, name=checkpointing.CHECKPOINT_FILENAME |
422 | | - ) |
423 | | - callbacks.append(save_checkpoint) |
424 | | - |
425 | | - if save_model_on_termination == "true": |
426 | | - model_name = f"{MODEL_NAME}-{fold}" if fold is not None else MODEL_NAME |
427 | | - save_intermediate_model = checkpointing.SaveIntermediateModelCallBack(model_dir, model_name, is_master) |
428 | | - callbacks.append(save_intermediate_model) |
429 | | - add_sigterm_handler(model_dir, is_master) |
430 | | - |
431 | | - if early_stopping_data_name and early_stopping_metric and early_stopping_rounds: |
432 | | - maximize = early_stopping_metric in XGB_MAXIMIZE_METRICS |
433 | | - early_stop = xgb.callback.EarlyStopping( |
434 | | - rounds=early_stopping_rounds, |
435 | | - data_name=early_stopping_data_name, |
436 | | - metric_name=early_stopping_metric, |
437 | | - maximize=maximize, |
438 | | - save_best=True, |
439 | | - ) |
440 | | - callbacks.append(early_stop) |
441 | | - |
442 | | - watchlist = [(train_dmatrix, "train")] |
443 | | - if val_dmatrix is not None: |
444 | | - watchlist.append((val_dmatrix, "validation")) |
445 | | - |
446 | | - return xgb_model, iteration, callbacks, watchlist |
447 | | - |
448 | | - |
449 | 412 | def print_cv_metric(num_round, evals_results): |
450 | 413 | cv_eval_report = f"[{num_round}]" |
451 | 414 | for metric_name in evals_results[0]["train"]: |