Skip to content

Commit c08847d

Browse files
committed
Add ModelTrainer updates
- Used latest code in commit: 9f70fb2#diff-6643c001ac6e4e110393f1a33700adf2054cc594e5ff1e3e2630131d2c6c0551
1 parent 2adc4bd commit c08847d

File tree

1 file changed

+35
-18
lines changed

1 file changed

+35
-18
lines changed

sagemaker-train/src/sagemaker/train/modules/model_trainer.py

Lines changed: 35 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -373,28 +373,45 @@ def _validate_source_code(self, source_code: Optional[SourceCode]):
373373
"If 'requirements' or 'entry_script' is provided in 'source_code', "
374374
+ "'source_dir' must also be provided.",
375375
)
376-
if not _is_valid_path(source_dir, path_type="Directory"):
376+
if not (
377+
_is_valid_path(source_dir, path_type="Directory")
378+
or _is_valid_s3_uri(source_dir, path_type="Directory")
379+
or (
380+
_is_valid_path(source_dir, path_type="File")
381+
and source_dir.endswith(".tar.gz")
382+
)
383+
or (
384+
_is_valid_s3_uri(source_dir, path_type="File")
385+
and source_dir.endswith(".tar.gz")
386+
)
387+
):
377388
raise ValueError(
378-
f"Invalid 'source_dir' path: {source_dir}. " + "Must be a valid directory.",
389+
f"Invalid 'source_dir' path: {source_dir}. "
390+
"Must be a valid local directory, "
391+
"s3 uri or path to tar.gz file stored locally or in s3."
379392
)
380393
if requirements:
381-
if not _is_valid_path(
382-
f"{source_dir}/{requirements}",
383-
path_type="File",
384-
):
385-
raise ValueError(
386-
f"Invalid 'requirements': {requirements}. "
387-
+ "Must be a valid file within the 'source_dir'.",
388-
)
394+
if not source_dir.endswith(".tar.gz"):
395+
if not _is_valid_path(
396+
f"{source_dir}/{requirements}", path_type="File"
397+
) and not _is_valid_s3_uri(
398+
f"{source_dir}/{requirements}", path_type="File"
399+
):
400+
raise ValueError(
401+
f"Invalid 'requirements': {requirements}. "
402+
"Must be a valid file within the 'source_dir'.",
403+
)
389404
if entry_script:
390-
if not _is_valid_path(
391-
f"{source_dir}/{entry_script}",
392-
path_type="File",
393-
):
394-
raise ValueError(
395-
f"Invalid 'entry_script': {entry_script}. "
396-
+ "Must be a valid file within the 'source_dir'.",
397-
)
405+
if not source_dir.endswith(".tar.gz"):
406+
if not _is_valid_path(
407+
f"{source_dir}/{entry_script}", path_type="File"
408+
) and not _is_valid_s3_uri(
409+
f"{source_dir}/{entry_script}", path_type="File"
410+
):
411+
raise ValueError(
412+
f"Invalid 'entry_script': {entry_script}. "
413+
"Must be a valid file within the 'source_dir'.",
414+
)
398415

399416
def model_post_init(self, __context: Any):
400417
"""Post init method to perform custom validation and set default values."""

0 commit comments

Comments
 (0)