
Commit 9fb6a3d

Multi-GPU training (#366)
* Added code to run multi-node, multi-GPU training with Dask.
1 parent b1a3cf0 commit 9fb6a3d

15 files changed: +315 additions, −9 deletions


.gitignore

Lines changed: 1 addition & 0 deletions
@@ -8,3 +8,4 @@ __pycache__
 .coverage*
 .mypy_cache/
 .idea/
+.DS_Store

requirements.txt

Lines changed: 2 additions & 0 deletions
@@ -4,6 +4,8 @@ Pillow==9.1.1
 boto3==1.17.52
 botocore==1.20.52
 cryptography==35.0.0
+dask==2022.11.1
+dask-cuda==22.12.0
 gunicorn==19.10.0
 itsdangerous==2.0.1
 matplotlib==3.4.1

src/sagemaker_xgboost_container/algorithm_mode/train.py

Lines changed: 10 additions & 9 deletions
@@ -29,6 +29,7 @@
 from sagemaker_xgboost_container.constants.sm_env_constants import SM_OUTPUT_DATA_DIR
 from sagemaker_xgboost_container.constants.xgb_constants import (
     CUSTOMER_ERRORS,
+    MODEL_NAME,
     XGB_MAXIMIZE_METRICS,
 )
 from sagemaker_xgboost_container.data_utils import (
@@ -40,8 +41,6 @@
 )
 from sagemaker_xgboost_container.prediction_utils import ValidationPredictionRecorder

-MODEL_NAME = "xgboost-model"
-
 logger = logging.getLogger(__name__)


@@ -157,10 +156,6 @@ def sagemaker_train(

     validation_channel = validated_data_config.get("validation", None)
     combine_train_val = "_kfold" in validated_train_config
-    train_dmatrix, val_dmatrix, train_val_dmatrix = get_validated_dmatrices(
-        train_path, val_path, file_type, csv_weights, is_pipe, combine_train_val
-    )
-    checkpoint_dir = checkpoint_config.get("LocalPath", None)
     if val_path is not None:
         if train_path == val_path or os.path.basename(train_path) == os.path.basename(val_path):
             logger.warning(
@@ -170,6 +165,15 @@ def sagemaker_train(
         elif not is_pipe:
             # Check if there is potential data redundancy between training and validation sets
             check_data_redundancy(train_path, val_path)
+
+    # Obtain information about training resources to determine which distributed setup to use, if needed.
+    num_hosts = len(sm_hosts)
+
+    train_dmatrix, val_dmatrix, train_val_dmatrix = get_validated_dmatrices(
+        train_path, val_path, file_type, csv_weights, is_pipe, combine_train_val
+    )
+    checkpoint_dir = checkpoint_config.get("LocalPath", None)
+
     train_args = dict(
         train_cfg=validated_train_config,
         train_dmatrix=train_dmatrix,
@@ -179,9 +183,6 @@ def sagemaker_train(
         checkpoint_dir=checkpoint_dir,
     )

-    # Obtain information about training resources to determine whether to set up Rabit or not
-    num_hosts = len(sm_hosts)
-
     if num_hosts > 1:
         # Wait for hosts to find each other
         logging.info("Distributed node training with {} hosts: {}".format(num_hosts, sm_hosts))

src/sagemaker_xgboost_container/constants/sm_env_constants.py

Lines changed: 1 addition & 0 deletions
@@ -15,6 +15,7 @@
 # Resource related constants
 SM_CURRENT_HOST = "SM_CURRENT_HOST"
 SM_HOSTS = "SM_HOSTS"
+SM_NUM_GPUS = "SM_NUM_GPUS"

 # Data related constants
 SM_CHANNEL_TRAIN = "SM_CHANNEL_TRAIN"

src/sagemaker_xgboost_container/constants/xgb_constants.py

Lines changed: 2 additions & 0 deletions
@@ -91,3 +91,5 @@
 BINARY_HINGE = "binary:hinge"
 MULTI_SOFTMAX = "multi:softmax"
 MULTI_SOFTPROB = "multi:softprob"
+
+MODEL_NAME = "xgboost-model"

src/sagemaker_xgboost_container/distributed_gpu/__init__.py

Whitespace-only changes.
src/sagemaker_xgboost_container/distributed_gpu/dask_cluster_utils.py

Lines changed: 52 additions & 0 deletions
@@ -0,0 +1,52 @@
+# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License'). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the 'license' file accompanying this file. This file is
+# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+
+import socket
+from subprocess import Popen
+
+from dask.distributed import Client
+
+from sagemaker_algorithm_toolkit.exceptions import AlgorithmError, PlatformError
+
+SCHEDULER_EXEC_PATH = "/miniconda3/bin/dask-scheduler"
+CUDA_WORKER_EXEC_PATH = "/miniconda3/bin/dask-cuda-worker"
+
+SCHEDULER_CONN_TIMEOUT = "20s"
+
+
+def start_daemons_in_current_instance(scheduler_address: str, is_scheduler_host: bool):
+    # Dask distributed scheduler API doc: https://docs.dask.org/en/stable/deploying-cli.html
+    scheduler_cli_command = [SCHEDULER_EXEC_PATH, "--no-dashboard"]
+    scheduler_conn_string = f"tcp://{scheduler_address}"
+    # Dask cuda worker API doc: https://docs.rapids.ai/api/dask-cuda/nightly/api.html
+    worker_cli_command = [CUDA_WORKER_EXEC_PATH, scheduler_conn_string, "--no-dashboard"]
+    if is_scheduler_host:
+        Popen(scheduler_cli_command)
+    try:
+        # Ensure that the scheduler is up before starting workers.
+        with Client(scheduler_address, timeout=SCHEDULER_CONN_TIMEOUT):
+            Popen(worker_cli_command)
+    except TimeoutError as e:
+        raise AlgorithmError(
+            f"Couldn't connect to scheduler after {SCHEDULER_CONN_TIMEOUT}. Please try re-running the training job."
+            f" Exception: {e}"
+        )
+
+
+def get_host_ip(host_name: str) -> str:
+    try:
+        host_ip = socket.gethostbyname(host_name)
+    except socket.gaierror as e:
+        # This shouldn't have happened, and it's not the user's fault.
+        raise PlatformError(f"Failed hostname resolution for host '{host_name}', exception: {e}")
+    return host_ip
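
For orientation, a minimal sketch of how the two helpers above might be wired together on each training host. The environment-variable parsing and the 8786 scheduler port mirror values used elsewhere in this commit, but the snippet itself is illustrative and not part of the change.

# Illustrative sketch (not part of this commit): drive the cluster helpers on each host.
import json
import os

from sagemaker_xgboost_container.distributed_gpu.dask_cluster_utils import (
    get_host_ip,
    start_daemons_in_current_instance,
)

# SM_HOSTS is a JSON-encoded list of host names; defaults here are example values.
sm_hosts = json.loads(os.environ.get("SM_HOSTS", '["algo-1"]'))
current_host = os.environ.get("SM_CURRENT_HOST", "algo-1")

scheduler_host = sm_hosts[0]
scheduler_address = f"{get_host_ip(scheduler_host)}:8786"

# The scheduler runs only on the first host; every host starts a dask-cuda worker.
start_daemons_in_current_instance(scheduler_address, current_host == scheduler_host)
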
src/sagemaker_xgboost_container/distributed_gpu/dask_data_utils.py

Lines changed: 65 additions & 0 deletions
@@ -0,0 +1,65 @@
+# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License'). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the 'license' file accompanying this file. This file is
+# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+
+import os
+
+import dask.dataframe as dask_dataframe
+from dask.dataframe import DataFrame, Series
+from dask.distributed import Client, wait
+from xgboost.dask import DaskDMatrix
+
+from sagemaker_algorithm_toolkit.exceptions import AlgorithmError, UserError
+from sagemaker_xgboost_container.data_utils import CSV, PARQUET
+
+
+def _read_data(local_path: str, content_type: str) -> (DataFrame, Series):
+    if content_type == CSV:
+        dataframe = dask_dataframe.read_csv(os.path.join(local_path, "*.csv"), header=None)
+    elif content_type == PARQUET:
+        dataframe = dask_dataframe.read_parquet(local_path)
+    else:
+        raise UserError(f"Unexpected content type '{content_type}'. Supported content types are CSV and PARQUET.")
+
+    target_column = dataframe.columns[0]
+    labels = dataframe[target_column]
+    features = dataframe[dataframe.columns.difference([target_column])]
+
+    return features, labels
+
+
+def get_dataframe_dimensions(dataframe: DataFrame) -> (int, int):
+    df_shape = dataframe.shape
+    # Note that dataframe.shape[0].compute() is an expensive operation.
+    rows = df_shape[0].compute()
+    cols = df_shape[1]
+    return rows, cols
+
+
+def load_data_into_memory(client: Client, local_data_path: str, content_type: str) -> (DataFrame, Series):
+    try:
+        features, labels = _read_data(local_data_path, content_type)
+        # Due to the lazy nature of Dask collections,
+        # most data related errors will likely show up once data load is started here.
+        features, labels = client.persist([features, labels])
+        wait([features, labels])
+    except Exception as e:
+        raise UserError(f"Failed to load data. Exception: {e}")
+    return features, labels
+
+
+def create_dask_dmatrix(client: Client, features: DataFrame, labels: Series) -> DaskDMatrix:
+    try:
+        dmatrix = DaskDMatrix(client, features, labels)
+    except Exception as e:
+        raise AlgorithmError(f"Failed to create DaskDMatrix with given data. Exception: {e}")
+    return dmatrix
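
As a quick orientation, a hedged sketch of how these data helpers might be exercised against a local Dask client; the channel path and the use of the CSV constant are example assumptions, not something prescribed by this commit.

# Illustrative sketch (not part of this commit): load a channel and build a DaskDMatrix.
from dask.distributed import Client

from sagemaker_xgboost_container.data_utils import CSV
from sagemaker_xgboost_container.distributed_gpu.dask_data_utils import (
    create_dask_dmatrix,
    get_dataframe_dimensions,
    load_data_into_memory,
)

with Client() as client:  # in-process local cluster; no GPUs needed for this check
    features, labels = load_data_into_memory(client, "/opt/ml/input/data/train", CSV)
    rows, cols = get_dataframe_dimensions(features)  # the row count triggers a compute()
    print(f"Loaded {rows} rows and {cols} feature columns")
    dmatrix = create_dask_dmatrix(client, features, labels)
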
Lines changed: 111 additions & 0 deletions
@@ -0,0 +1,111 @@
+# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License'). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the 'license' file accompanying this file. This file is
+# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+
+import logging
+import os
+import socket
+import time
+from typing import Dict
+
+import xgboost as xgb
+from dask.distributed import Client
+
+from sagemaker_algorithm_toolkit import exceptions as exc
+from sagemaker_xgboost_container.constants.xgb_constants import MODEL_NAME
+from sagemaker_xgboost_container.distributed_gpu.dask_cluster_utils import (
+    get_host_ip,
+    start_daemons_in_current_instance,
+)
+from sagemaker_xgboost_container.distributed_gpu.dask_data_utils import (
+    create_dask_dmatrix,
+    get_dataframe_dimensions,
+    load_data_into_memory,
+)
+
+logger = logging.getLogger(__name__)
+
+SCHEDULER_PORT = "8786"
+WAIT_FOR_ALL_WORKERS_TIMEOUT_SEC = 20
+WORKER_STAY_ALIVE_CHECK_FREQ_SEC = 10
+
+
+def run_training_with_dask(
+    hyperparameters: Dict,
+    train_path: str,
+    validation_path: str,
+    model_dir: str,
+    content_type: str,
+    sm_hosts: [str],
+    current_host: str,
+    num_gpus: int,
+):
+    scheduler_host = sm_hosts[0]
+    scheduler_host_ip = get_host_ip(scheduler_host)
+
+    scheduler_address = f"{scheduler_host_ip}:{SCHEDULER_PORT}"
+    is_scheduler_host = current_host == scheduler_host
+
+    start_daemons_in_current_instance(scheduler_address, is_scheduler_host)
+
+    total_num_workers = len(sm_hosts) * num_gpus
+
+    # We only need to submit the job from one instance.
+    if is_scheduler_host:
+        with Client(scheduler_address) as client:
+            # We ensure that all workers are present before proceeding.
+            client.wait_for_workers(total_num_workers, WAIT_FOR_ALL_WORKERS_TIMEOUT_SEC)
+
+            logging.info("Starting to read training data...")
+            watchlist = []
+
+            X_train, y_train = load_data_into_memory(client, train_path, content_type)
+
+            dtrain = create_dask_dmatrix(client, X_train, y_train)
+
+            # Log train data dimension for sanity check.
+            train_num_rows, train_num_cols = get_dataframe_dimensions(X_train)
+            logging.info(f"Train features matrix has {train_num_rows} rows and {train_num_cols} columns")
+
+            watchlist.append((dtrain, "train"))
+
+            if validation_path is not None:
+                X_valid, y_valid = load_data_into_memory(client, validation_path, content_type)
+                dvalid = create_dask_dmatrix(client, X_valid, y_valid)
+                watchlist.append((dvalid, "validation"))
+
+            logging.info("Data load complete. Starting training...")
+
+            try:
+                output = xgb.dask.train(
+                    client, hyperparameters, dtrain, num_boost_round=hyperparameters["num_round"], evals=watchlist
+                )
+                booster = output["booster"]
+
+                logging.info("Training complete. Saving model...")
+                booster.save_model(os.path.join(model_dir, MODEL_NAME))
+            except Exception as e:
+                exception_prefix = "XGB train call failed with exception"
+                raise exc.AlgorithmError(f"{exception_prefix}:\n {str(e)}")
+
+            logging.info("Terminating cluster...")
+
+    else:
+        scheduler = (scheduler_host_ip, int(SCHEDULER_PORT))
+        # Do not exit till the job is done.
+        while True:
+            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as alive_socket:
+                alive_check = alive_socket.connect_ex(scheduler)
+                if alive_check != 0:
+                    logging.info("Received a shutdown signal from scheduler. Exiting...")
+                    break
+            time.sleep(WORKER_STAY_ALIVE_CHECK_FREQ_SEC)
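
For context, a hedged sketch of how run_training_with_dask (defined above) might be invoked from the training entry point. The hyperparameters and channel paths are example values, and the direct environment-variable reads stand in for the container's usual argument plumbing; none of this is taken from the commit itself.

# Illustrative sketch (not part of this commit): invoke the Dask GPU training path.
# run_training_with_dask is assumed to be importable from the new module shown above
# (its file name is not visible in this view).
import json
import os

from sagemaker_xgboost_container.data_utils import CSV

hyperparameters = {
    "objective": "binary:logistic",
    "tree_method": "gpu_hist",  # example setting for GPU training
    "num_round": 100,  # read inside run_training_with_dask as num_boost_round
}

run_training_with_dask(
    hyperparameters=hyperparameters,
    train_path="/opt/ml/input/data/train",
    validation_path="/opt/ml/input/data/validation",
    model_dir="/opt/ml/model",
    content_type=CSV,
    sm_hosts=json.loads(os.environ["SM_HOSTS"]),
    current_host=os.environ["SM_CURRENT_HOST"],
    num_gpus=int(os.environ["SM_NUM_GPUS"]),
)
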
Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+0,1,0,0,0,0
+0,1,0,0,0,0
+0,1,0,0,0,0
+0,1,0,0,0,0
+1,0,1,0,0,0

0 commit comments
