
Commit 342a097

Dask Flow Integration (#374)

* Dask Flow Integration
* Remove explicit persist of data
* Updating error messages, some refactoring

Co-authored-by: Cho <choeric@amazon.com>

1 parent 7056e3a commit 342a097

File tree

9 files changed: +330 -151 lines

.gitignore

Lines changed: 1 addition & 0 deletions

@@ -9,3 +9,4 @@ __pycache__
 .mypy_cache/
 .idea/
 .DS_Store
+test.parquet

src/sagemaker_xgboost_container/algorithm_mode/hyperparameter_validation.py

Lines changed: 2 additions & 0 deletions

@@ -338,6 +338,8 @@ def interaction_constraints_validator(value, dependencies):
         hpv.CategoricalHyperparameter(name="deterministic_histogram", range=["true", "false"], required=False),
         hpv.CategoricalHyperparameter(name="sampling_method", range=["uniform", "gradient_based"], required=False),
         hpv.IntegerHyperparameter(name="prob_buffer_row", range=hpv.Interval(min_open=1.0), required=False),
+        # Not an XGB training HP, but is used to determine which distributed training framework to use by SM XGB.
+        hpv.CategoricalHyperparameter(name="use_dask_gpu_training", range=["true", "false"], required=False),
     )

     hyperparameters.declare_alias("eta", "learning_rate")
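
For context, the new flag rides along with the ordinary algorithm hyperparameters. Below is a minimal sketch of enabling it from the SageMaker Python SDK; the image URI, role, instance type, and S3 paths are placeholders, not values taken from this commit.

from sagemaker.estimator import Estimator

estimator = Estimator(
    image_uri="<sagemaker-xgboost-container-image-uri>",  # placeholder
    role="<execution-role-arn>",                          # placeholder
    instance_count=2,
    instance_type="ml.g4dn.12xlarge",                     # any GPU instance type
    output_path="s3://<bucket>/output",                   # placeholder
)
estimator.set_hyperparameters(
    objective="reg:squarederror",
    num_round=100,
    tree_method="gpu_hist",
    use_dask_gpu_training="true",  # the hyperparameter added above
)
estimator.fit({"train": "s3://<bucket>/train", "validation": "s3://<bucket>/validation"})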

src/sagemaker_xgboost_container/algorithm_mode/train.py

Lines changed: 84 additions & 121 deletions

@@ -12,25 +12,26 @@
 # language governing permissions and limitations under the License.
 import logging
 import os
-import signal

 import numpy as np
 import xgboost as xgb
 from sklearn.model_selection import RepeatedKFold, RepeatedStratifiedKFold

 from sagemaker_algorithm_toolkit import exceptions as exc
 from sagemaker_algorithm_toolkit.channel_validation import Channel
-from sagemaker_xgboost_container import checkpointing, distributed
+from sagemaker_xgboost_container import distributed
 from sagemaker_xgboost_container.algorithm_mode import channel_validation as cv
 from sagemaker_xgboost_container.algorithm_mode import hyperparameter_validation as hpv
 from sagemaker_xgboost_container.algorithm_mode import metrics as metrics_mod
 from sagemaker_xgboost_container.algorithm_mode import train_utils
-from sagemaker_xgboost_container.callback import add_debugging
-from sagemaker_xgboost_container.constants.sm_env_constants import SM_OUTPUT_DATA_DIR
+from sagemaker_xgboost_container.callback import add_debugging, get_callbacks
+from sagemaker_xgboost_container.constants.sm_env_constants import (
+    SM_NUM_GPUS,
+    SM_OUTPUT_DATA_DIR,
+)
 from sagemaker_xgboost_container.constants.xgb_constants import (
     CUSTOMER_ERRORS,
     MODEL_NAME,
-    XGB_MAXIMIZE_METRICS,
 )
 from sagemaker_xgboost_container.data_utils import (
     check_data_redundancy,

@@ -39,30 +40,12 @@
     get_size,
     validate_data_file_path,
 )
+from sagemaker_xgboost_container.distributed_gpu import distributed_gpu_training
 from sagemaker_xgboost_container.prediction_utils import ValidationPredictionRecorder

 logger = logging.getLogger(__name__)

-
-def add_sigterm_handler(model_dir, is_master):
-    """Stop training and cleanup model directory when SIGTERM is received.
-
-    Model directory is only cleaned if is_master is True. Otherwise program terminates.
-
-    :param model_dir: Directory where model is saved
-    :param is_master: True if single node training, or the current node is the master node in distributed training
-    """
-
-    def _terminate():
-        os._exit(0)
-
-    def _cleanup_files(signo, frame):
-        if is_master:
-            train_utils.cleanup_dir(model_dir, MODEL_NAME)
-
-        _terminate()
-
-    signal.signal(signal.SIGTERM, _cleanup_files)
+DOCUMENTATION_LINK = "https://docs.aws.amazon.com/sagemaker/latest/dg/xgboost.html"


 def get_validated_dmatrices(

@@ -169,50 +152,86 @@ def sagemaker_train(
     # Obtain information about training resources to determine which distributed setup to use, if needed.
     num_hosts = len(sm_hosts)

-    train_dmatrix, val_dmatrix, train_val_dmatrix = get_validated_dmatrices(
-        train_path, val_path, file_type, csv_weights, is_pipe, combine_train_val
-    )
     checkpoint_dir = checkpoint_config.get("LocalPath", None)

-    train_args = dict(
-        train_cfg=validated_train_config,
-        train_dmatrix=train_dmatrix,
-        val_dmatrix=val_dmatrix,
-        train_val_dmatrix=train_val_dmatrix,
-        model_dir=model_dir,
-        checkpoint_dir=checkpoint_dir,
-    )
+    num_gpus = int(os.getenv(SM_NUM_GPUS, 0))
+    logging.info(f"Determined {num_gpus} GPU(s) available on the instance.")
+    tree_method_hp = validated_train_config.get("tree_method")
+
+    is_dask_job = validated_train_config.pop("use_dask_gpu_training", "false")

-    if num_hosts > 1:
-        # Wait for hosts to find each other
-        logging.info("Distributed node training with {} hosts: {}".format(num_hosts, sm_hosts))
-        distributed.wait_hostname_resolution(sm_hosts)
+    if is_dask_job == "true":
+        gpu_train_validation_errors = distributed_gpu_training.validate_gpu_train_configuration(
+            tree_method_hp=tree_method_hp,
+            num_hosts=num_hosts,
+            num_gpus=num_gpus,
+            input_mode=input_mode,
+            input_format=file_type,
+            data_config=validated_data_config,
+        )

-        if not train_dmatrix:
+        if gpu_train_validation_errors:
+            raise exc.UserError(f"Some configurations unsuitable for Dask GPU training were found: "
+                                f"{'. '.join(gpu_train_validation_errors)}")
+
+        logging.info("Going to run distributed GPU training through Dask.")
+        distributed_gpu_training.run_training_with_dask(
+            hyperparameters=validated_train_config,
+            train_path=train_path,
+            validation_path=val_path,
+            model_dir=model_dir,
+            content_type=file_type,
+            sm_hosts=sm_hosts,
+            current_host=sm_current_host,
+            checkpoint_dir=checkpoint_dir,
+            num_gpus=num_gpus,
+        )
+    else:
+        if num_gpus > 1:
             logging.warning(
-                "Host {} does not have data. Will broadcast to cluster and will not be used in distributed"
-                " training.".format(sm_current_host)
+                f"If you're using GPU training, not all GPUs on the instance will be used. "
+                f"See how to use all GPUs at {DOCUMENTATION_LINK}"
             )
-        distributed.rabit_run(
-            exec_fun=train_job,
-            args=train_args,
-            include_in_training=(train_dmatrix is not None),
-            hosts=sm_hosts,
-            current_host=sm_current_host,
-            update_rabit_args=True,
+
+        train_dmatrix, val_dmatrix, train_val_dmatrix = get_validated_dmatrices(
+            train_path, val_path, file_type, csv_weights, is_pipe, combine_train_val
         )
-    elif num_hosts == 1:
-        if train_dmatrix:
-            if validation_channel:
-                if not val_dmatrix:
+        train_args = dict(
+            train_cfg=validated_train_config,
+            train_dmatrix=train_dmatrix,
+            val_dmatrix=val_dmatrix,
+            train_val_dmatrix=train_val_dmatrix,
+            model_dir=model_dir,
+            checkpoint_dir=checkpoint_dir,
+        )
+        if num_hosts > 1:
+            # Wait for hosts to find each other
+            logging.info("Distributed node training with {} hosts: {}".format(num_hosts, sm_hosts))
+            distributed.wait_hostname_resolution(sm_hosts)
+            if not train_dmatrix:
+                logging.warning(
+                    "Host {} does not have data. Will broadcast to cluster and will not be used in distributed"
+                    " training.".format(sm_current_host)
+                )
+            distributed.rabit_run(
+                exec_fun=train_job,
+                args=train_args,
+                include_in_training=(train_dmatrix is not None),
+                hosts=sm_hosts,
+                current_host=sm_current_host,
+                update_rabit_args=True,
+            )
+        elif num_hosts == 1:
+            if train_dmatrix:
+                if validation_channel and not val_dmatrix:
                     raise exc.UserError("No data in validation channel path {}".format(val_path))
-            logging.info("Single node training.")
-            train_args.update({"is_master": True})
-            train_job(**train_args)
+                logging.info("Single node training.")
+                train_args.update({"is_master": True})
+                train_job(**train_args)
+            else:
+                raise exc.UserError("No data in training channel path {}".format(train_path))
         else:
-            raise exc.UserError("No data in training channel path {}".format(train_path))
-    else:
-        raise exc.PlatformError("Number of hosts should be an int greater than or equal to 1")
+            raise exc.PlatformError("Number of hosts should be an int greater than or equal to 1")


 def train_job(train_cfg, train_dmatrix, val_dmatrix, train_val_dmatrix, model_dir, checkpoint_dir, is_master):

@@ -259,11 +278,12 @@ def train_job(train_cfg, train_dmatrix, val_dmatrix, train_val_dmatrix, model_dir, checkpoint_dir, is_master):

     try:
         kfold = train_cfg.pop("_kfold", None)
+        watchlist = [(train_dmatrix, "train")]
+        if val_dmatrix is not None:
+            watchlist.append((val_dmatrix, "validation"))

         if kfold is None:
-            xgb_model, iteration, callbacks, watchlist = get_callbacks_watchlist(
-                train_dmatrix=train_dmatrix,
-                val_dmatrix=val_dmatrix,
+            xgb_model, iteration, callbacks = get_callbacks(
                 model_dir=model_dir,
                 checkpoint_dir=checkpoint_dir,
                 early_stopping_data_name=early_stopping_data_name,

@@ -322,9 +342,7 @@ def train_job(train_cfg, train_dmatrix, val_dmatrix, train_val_dmatrix, model_dir, checkpoint_dir, is_master):
                 cv_train_dmatrix = train_val_dmatrix.slice(train_idx)
                 cv_val_dmatrix = train_val_dmatrix.slice(val_idx)

-                xgb_model, iteration, callbacks, watchlist = get_callbacks_watchlist(
-                    train_dmatrix=cv_train_dmatrix,
-                    val_dmatrix=cv_val_dmatrix,
+                xgb_model, iteration, callbacks = get_callbacks(
                     model_dir=model_dir,
                     checkpoint_dir=checkpoint_dir,
                     early_stopping_data_name=early_stopping_data_name,

@@ -391,61 +409,6 @@ def train_job(train_cfg, train_dmatrix, val_dmatrix, train_val_dmatrix, model_dir, checkpoint_dir, is_master):
                 logging.debug("Stored trained model {} at {}".format(fold, model_location))


-def get_callbacks_watchlist(
-    train_dmatrix,
-    val_dmatrix,
-    model_dir,
-    checkpoint_dir,
-    early_stopping_data_name,
-    early_stopping_metric,
-    early_stopping_rounds,
-    save_model_on_termination,
-    is_master,
-    fold=None,
-):
-    if checkpoint_dir and fold is not None:
-        checkpoint_dir = os.path.join(checkpoint_dir, f"model-{fold}")
-
-    # Set callbacks
-    xgb_model, iteration = checkpointing.load_checkpoint(checkpoint_dir)
-    if xgb_model is not None:
-        if fold is not None:
-            xgb_model = f"{xgb_model}-{fold}"
-        logging.info("Checkpoint loaded from %s", xgb_model)
-        logging.info("Resuming from iteration %s", iteration)
-
-    callbacks = []
-    callbacks.append(xgb.callback.EvaluationMonitor())
-    if checkpoint_dir:
-        save_checkpoint = xgb.callback.TrainingCheckPoint(
-            directory=checkpoint_dir, iterations=iteration, name=checkpointing.CHECKPOINT_FILENAME
-        )
-        callbacks.append(save_checkpoint)
-
-    if save_model_on_termination == "true":
-        model_name = f"{MODEL_NAME}-{fold}" if fold is not None else MODEL_NAME
-        save_intermediate_model = checkpointing.SaveIntermediateModelCallBack(model_dir, model_name, is_master)
-        callbacks.append(save_intermediate_model)
-        add_sigterm_handler(model_dir, is_master)
-
-    if early_stopping_data_name and early_stopping_metric and early_stopping_rounds:
-        maximize = early_stopping_metric in XGB_MAXIMIZE_METRICS
-        early_stop = xgb.callback.EarlyStopping(
-            rounds=early_stopping_rounds,
-            data_name=early_stopping_data_name,
-            metric_name=early_stopping_metric,
-            maximize=maximize,
-            save_best=True,
-        )
-        callbacks.append(early_stop)
-
-    watchlist = [(train_dmatrix, "train")]
-    if val_dmatrix is not None:
-        watchlist.append((val_dmatrix, "validation"))
-
-    return xgb_model, iteration, callbacks, watchlist
-
-
 def print_cv_metric(num_round, evals_results):
     cv_eval_report = f"[{num_round}]"
     for metric_name in evals_results[0]["train"]:
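
The Dask code path calls into the new distributed_gpu_training module, which is not shown in this excerpt. For orientation, here is a minimal, speculative sketch of the dask-xgboost pattern that run_training_with_dask builds on; the helper name, label column, and parameters below are illustrative assumptions, not the container's actual implementation.

import dask.dataframe as dd
import xgboost as xgb
from dask.distributed import Client
from dask_cuda import LocalCUDACluster


def dask_gpu_train_sketch(train_path, num_gpus, params, num_round=100):
    # One Dask worker per GPU on the instance (multi-host setup omitted for brevity).
    with LocalCUDACluster(n_workers=num_gpus) as cluster, Client(cluster) as client:
        df = dd.read_parquet(train_path)           # or dd.read_csv for CSV input
        y = df["label"]                            # assumed label column name
        X = df.drop(columns=["label"])
        dtrain = xgb.dask.DaskDMatrix(client, X, y)
        output = xgb.dask.train(
            client,
            {**params, "tree_method": "gpu_hist"},  # GPU histogram tree method
            dtrain,
            num_boost_round=num_round,
            evals=[(dtrain, "train")],
        )
        return output["booster"]                   # a regular xgboost.Booster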

src/sagemaker_xgboost_container/callback.py

Lines changed: 76 additions & 1 deletion

@@ -1,6 +1,11 @@
 import logging
-
+import os
+import signal
 import xgboost as xgb
+
+from sagemaker_xgboost_container import checkpointing
+from sagemaker_xgboost_container.algorithm_mode import train_utils
+from sagemaker_xgboost_container.constants.xgb_constants import MODEL_NAME, XGB_MAXIMIZE_METRICS
 from smdebug.xgboost import Hook

 logger = logging.getLogger(__name__)

@@ -45,3 +50,73 @@ def add_debugging(callbacks, hyperparameters, train_dmatrix, val_dmatrix=None, j
         logging.debug("Failed to create debug hook", e)
     else:
         callbacks.append(hook)
+
+
+def add_sigterm_handler(model_dir, is_master):
+    """Stop training and cleanup model directory when SIGTERM is received.
+
+    Model directory is only cleaned if is_master is True. Otherwise program terminates.
+
+    :param model_dir: Directory where model is saved
+    :param is_master: True if single node training, or the current node is the master node in distributed training
+    """
+
+    def _terminate():
+        os._exit(0)
+
+    def _cleanup_files(signo, frame):
+        if is_master:
+            train_utils.cleanup_dir(model_dir, MODEL_NAME)
+
+        _terminate()
+
+    signal.signal(signal.SIGTERM, _cleanup_files)
+
+
+def get_callbacks(
+    model_dir,
+    checkpoint_dir,
+    early_stopping_data_name,
+    early_stopping_metric,
+    early_stopping_rounds,
+    save_model_on_termination,
+    is_master,
+    fold=None,
+):
+    if checkpoint_dir and fold is not None:
+        checkpoint_dir = os.path.join(checkpoint_dir, f"model-{fold}")
+
+    # Set callbacks
+    xgb_model, iteration = checkpointing.load_checkpoint(checkpoint_dir)
+    if xgb_model is not None:
+        if fold is not None:
+            xgb_model = f"{xgb_model}-{fold}"
+        logging.info("Checkpoint loaded from %s", xgb_model)
+        logging.info("Resuming from iteration %s", iteration)
+
+    callbacks = []
+    callbacks.append(xgb.callback.EvaluationMonitor())
+    if checkpoint_dir:
+        save_checkpoint = xgb.callback.TrainingCheckPoint(
+            directory=checkpoint_dir, iterations=iteration, name=checkpointing.CHECKPOINT_FILENAME
+        )
+        callbacks.append(save_checkpoint)
+
+    if save_model_on_termination == "true":
+        model_name = f"{MODEL_NAME}-{fold}" if fold is not None else MODEL_NAME
+        save_intermediate_model = checkpointing.SaveIntermediateModelCallBack(model_dir, model_name, is_master)
+        callbacks.append(save_intermediate_model)
+        add_sigterm_handler(model_dir, is_master)
+
+    if early_stopping_data_name and early_stopping_metric and early_stopping_rounds:
+        maximize = early_stopping_metric in XGB_MAXIMIZE_METRICS
+        early_stop = xgb.callback.EarlyStopping(
+            rounds=early_stopping_rounds,
+            data_name=early_stopping_data_name,
+            metric_name=early_stopping_metric,
+            maximize=maximize,
+            save_best=True,
+        )
+        callbacks.append(early_stop)
+
+    return xgb_model, iteration, callbacks
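
To illustrate how the relocated helper is consumed, here is a hedged usage sketch mirroring train_job in train.py; the directory paths, metric, round count, and toy data are illustrative assumptions, not values from this commit.

import numpy as np
import xgboost as xgb

from sagemaker_xgboost_container.callback import get_callbacks

# Toy DMatrices purely for illustration.
train_dmatrix = xgb.DMatrix(np.random.rand(64, 4), label=np.random.rand(64))
val_dmatrix = xgb.DMatrix(np.random.rand(16, 4), label=np.random.rand(16))

# The caller now builds the watchlist itself ...
watchlist = [(train_dmatrix, "train")]
if val_dmatrix is not None:
    watchlist.append((val_dmatrix, "validation"))

# ... and gets checkpoint state plus callbacks from the shared helper.
xgb_model, iteration, callbacks = get_callbacks(
    model_dir="/tmp/model",            # illustrative paths
    checkpoint_dir="/tmp/checkpoints",
    early_stopping_data_name="validation",
    early_stopping_metric="rmse",
    early_stopping_rounds=10,
    save_model_on_termination="true",
    is_master=True,
)
booster = xgb.train(
    {"objective": "reg:squarederror"},
    train_dmatrix,
    num_boost_round=100 - iteration,   # offset illustrates resuming from a checkpoint
    evals=watchlist,
    callbacks=callbacks,
    xgb_model=xgb_model,               # None when starting fresh
)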

src/sagemaker_xgboost_container/constants/xgb_constants.py

Lines changed: 4 additions & 0 deletions

@@ -93,3 +93,7 @@
 MULTI_SOFTPROB = "multi:softprob"

 MODEL_NAME = "xgboost-model"
+GPU_TREE_METHOD = "gpu_hist"
+
+FULLY_REPLICATED = "FullyReplicated"
+PIPE_MODE = "Pipe"
