Skip to content

Commit 73bf599

Browse files
authored
Ran flake8, isort, black and pyupgrade (#375)
1 parent 342a097 commit 73bf599

File tree

3 files changed

+19 additions, -17 deletions

src/sagemaker_xgboost_container/algorithm_mode/train.py

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -128,8 +128,8 @@ def sagemaker_train(
128128
channels = cv.initialize()
129129
validated_data_config = channels.validate(data_config)
130130

131-
logging.debug("hyperparameters {}".format(validated_train_config))
132-
logging.debug("channels {}".format(validated_data_config))
131+
logging.debug(f"hyperparameters {validated_train_config}")
132+
logging.debug(f"channels {validated_data_config}")
133133

134134
# Get Training and Validation Data Matrices
135135
file_type = get_content_type(validated_data_config["train"].get("ContentType"))
@@ -171,8 +171,10 @@ def sagemaker_train(
171171
)
172172

173173
if gpu_train_validation_errors:
174-
raise exc.UserError(f"Some configurations unsuitable for Dask GPU training were found: "
175-
f"{'. '.join(gpu_train_validation_errors)}")
174+
raise exc.UserError(
175+
f"Some configurations unsuitable for Dask GPU training were found: "
176+
f"{'. '.join(gpu_train_validation_errors)}"
177+
)
176178

177179
logging.info("Going to run distributed GPU training through Dask.")
178180
distributed_gpu_training.run_training_with_dask(
@@ -206,7 +208,7 @@ def sagemaker_train(
206208
)
207209
if num_hosts > 1:
208210
# Wait for hosts to find each other
209-
logging.info("Distributed node training with {} hosts: {}".format(num_hosts, sm_hosts))
211+
logging.info(f"Distributed node training with {num_hosts} hosts: {sm_hosts}")
210212
distributed.wait_hostname_resolution(sm_hosts)
211213
if not train_dmatrix:
212214
logging.warning(
@@ -224,12 +226,12 @@ def sagemaker_train(
224226
elif num_hosts == 1:
225227
if train_dmatrix:
226228
if validation_channel and not val_dmatrix:
227-
raise exc.UserError("No data in validation channel path {}".format(val_path))
229+
raise exc.UserError(f"No data in validation channel path {val_path}")
228230
logging.info("Single node training.")
229231
train_args.update({"is_master": True})
230232
train_job(**train_args)
231233
else:
232-
raise exc.UserError("No data in training channel path {}".format(train_path))
234+
raise exc.UserError(f"No data in training channel path {train_path}")
233235
else:
234236
raise exc.PlatformError("Number of hosts should be an int greater than or equal to 1")
235237

@@ -272,9 +274,9 @@ def train_job(train_cfg, train_dmatrix, val_dmatrix, train_val_dmatrix, model_di
272274
elif eval_metric:
273275
early_stopping_metric = eval_metric[-1]
274276

275-
logging.info("Train matrix has {} rows and {} columns".format(train_dmatrix.num_row(), train_dmatrix.num_col()))
277+
logging.info(f"Train matrix has {train_dmatrix.num_row()} rows and {train_dmatrix.num_col()} columns")
276278
if val_dmatrix:
277-
logging.info("Validation matrix has {} rows".format(val_dmatrix.num_row()))
279+
logging.info(f"Validation matrix has {val_dmatrix.num_row()} rows")
278280

279281
try:
280282
kfold = train_cfg.pop("_kfold", None)
@@ -360,7 +362,7 @@ def train_job(train_cfg, train_dmatrix, val_dmatrix, train_val_dmatrix, model_di
360362
)
361363

362364
evals_result = {}
363-
logging.info("Train cross validation fold {}".format((len(bst) % kfold) + 1))
365+
logging.info(f"Train cross validation fold {(len(bst) % kfold) + 1}")
364366
booster = xgb.train(
365367
train_cfg,
366368
cv_train_dmatrix,
@@ -377,13 +379,13 @@ def train_job(train_cfg, train_dmatrix, val_dmatrix, train_val_dmatrix, model_di
377379
val_pred.record(val_idx, booster.predict(cv_val_dmatrix))
378380

379381
if len(bst) % kfold == 0:
380-
logging.info("The metrics of round {} cross validation".format(int(len(bst) / kfold)))
382+
logging.info(f"The metrics of round {int(len(bst) / kfold)} cross validation")
381383
print_cv_metric(num_round, evals_results[-kfold:])
382384

383385
val_pred.save()
384386

385387
if num_cv_round > 1:
386-
logging.info("The overall metrics of {}-round cross validation".format(num_cv_round))
388+
logging.info(f"The overall metrics of {num_cv_round}-round cross validation")
387389
print_cv_metric(num_round, evals_results)
388390

389391
except Exception as e:
@@ -392,7 +394,7 @@ def train_job(train_cfg, train_dmatrix, val_dmatrix, train_val_dmatrix, model_di
392394
raise exc.UserError(str(e))
393395

394396
exception_prefix = "XGB train call failed with exception"
395-
raise exc.AlgorithmError("{}:\n {}".format(exception_prefix, str(e)))
397+
raise exc.AlgorithmError(f"{exception_prefix}:\n {str(e)}")
396398

397399
if not os.path.exists(model_dir):
398400
os.makedirs(model_dir)
@@ -401,18 +403,18 @@ def train_job(train_cfg, train_dmatrix, val_dmatrix, train_val_dmatrix, model_di
401403
if type(bst) is not list:
402404
model_location = os.path.join(model_dir, MODEL_NAME)
403405
bst.save_model(model_location)
404-
logging.debug("Stored trained model at {}".format(model_location))
406+
logging.debug(f"Stored trained model at {model_location}")
405407
else:
406408
for fold in range(len(bst)):
407409
model_location = os.path.join(model_dir, f"{MODEL_NAME}-{fold}")
408410
bst[fold].save_model(model_location)
409-
logging.debug("Stored trained model {} at {}".format(fold, model_location))
411+
logging.debug(f"Stored trained model {fold} at {model_location}")
410412

411413

412414
def print_cv_metric(num_round, evals_results):
413415
cv_eval_report = f"[{num_round}]"
414416
for metric_name in evals_results[0]["train"]:
415417
for data_name in ["train", "validation"]:
416418
metric_val = [evals_result[data_name][metric_name][-1] for evals_result in evals_results]
417-
cv_eval_report += "\t{0}-{1}:{2:.5f}".format(data_name, metric_name, np.mean(metric_val))
419+
cv_eval_report += f"\t{data_name}-{metric_name}:{np.mean(metric_val):.5f}"
418420
print(cv_eval_report)

src/sagemaker_xgboost_container/distributed_gpu/distributed_gpu_training.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ def validate_gpu_train_configuration(
7272

7373
return all_exceptions
7474

75+
7576
def run_training_with_dask(
7677
hyperparameters: Dict,
7778
train_path: str,

test/unit/distributed_gpu/test_distributed_gpu_training.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
import unittest
1515

1616
from sagemaker_algorithm_toolkit import channel_validation as cv
17-
from sagemaker_algorithm_toolkit.exceptions import UserError
1817
from sagemaker_xgboost_container.distributed_gpu.distributed_gpu_training import (
1918
INPUT_FORMAT_ERROR_MSG,
2019
NON_GPU_ERROR_MSG,

0 commit comments

Comments
 (0)