Skip to content

Commit 1c7faf0

Browse files
Migrate recent commits in master-v2 (#5423)
* Update image_uri_config, fw_utils and image_uris.py in sagemaker-core * Add ModelTrainer updates - Used latest code in commit: 9f70fb2#diff-6643c001ac6e4e110393f1a33700adf2054cc594e5ff1e3e2630131d2c6c0551 * Update s3 bucket check in session_helper.py Code change is based on commit: 903cb8a * fix: Map llama models to correct script Based on commit: aws/sagemaker-python-sdk-staging@67a3e5a * fix: honor json serialization of HPs aws/sagemaker-python-sdk-staging@246d560 * fix: clarify model monitor one time schedule bug From commit: ddc54d2 * fix: Allow import failure for internal _hashlib module From commit: aws/sagemaker-python-sdk-staging@5198f28 * Remove duplicate model_trainer.py * Add ignore_patterns in ModelTrainer to ignore specific files/folders For commit: 829030a * Update instance type regex to also include hyphens For commit: aws/sagemaker-python-sdk-staging@824675b * chore: domain support for eu-isoe-west-1 For commit: d0bd4f7 * Fix: Object of type ModelLifeCycle is not JSON serializable For commit: 844b558 * fix: sanitize git clone repo input url For commit: aws/sagemaker-python-sdk-staging@ed143b7 * Add support for MetricDefinitions in ModelTrainer For commit: 0215512 * feat: support pipeline versioning For commit: aws/sagemaker-python-sdk-staging@9bfe85a * add eval custom lambda arn to hyperparameter For commit: aws/sagemaker-python-sdk-staging@bcd5348 * Add Numpy 2.0 support For commit: aws/sagemaker-python-sdk-staging@99210b2 Tested by running sagemaker-serve unit tests * fix: update get_execution_role to directly return the ExecutionRoleArn if it presents in the resource metadata file For commit: aws/sagemaker-python-sdk-staging@b9df334 * HF Optimum Neuron 0.4.1 DLCs For commit: 5d3f175 * Fix import error * Fix llama_v3 in sm_recipes * Remove duplicate json in image_retriever * Add todo notes in pipeline class * Add V2 image_config_url unit tests --------- Co-authored-by: aviruthen <91846056+aviruthen@users.noreply.github.com>
1 parent 8695cca commit 1c7faf0

File tree

194 files changed

+3859
-46337
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

194 files changed

+3859
-46337
lines changed

requirements/extras/test_requirements.txt

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,11 @@ pytest-xdist
44
mock
55
pydantic==2.11.9
66
pydantic_core==2.33.2
7-
pandas
7+
pandas>=2.3.0
8+
numpy>=2.0.0, <3.0
9+
scikit-learn==1.6.1
810
scipy
911
omegaconf
1012
graphene
11-
typing_extensions>=4.9.0
13+
typing_extensions>=4.9.0
14+
tensorflow>=2.16.2,<=2.19.0

sagemaker-core/src/sagemaker/core/common_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
"us-isob-east-1": "sc2s.sgov.gov",
6060
"us-isof-south-1": "csp.hci.ic.gov",
6161
"us-isof-east-1": "csp.hci.ic.gov",
62+
"eu-isoe-west-1": "cloud.adc-e.uk",
6263
}
6364

6465
ECR_URI_PATTERN = r"^(\d+)(\.)dkr(\.)ecr(\.)(.+)(\.)(.*)(/)(.*:.*)$"
@@ -1555,7 +1556,7 @@ def get_instance_type_family(instance_type: str) -> str:
15551556
"""
15561557
instance_type_family = ""
15571558
if isinstance(instance_type, str):
1558-
match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
1559+
match = re.match(r"^ml[\._]([a-z\d\-]+)\.?\w*$", instance_type)
15591560
if match is not None:
15601561
instance_type_family = match[1]
15611562
return instance_type_family

sagemaker-core/src/sagemaker/core/fw_utils.py

Lines changed: 56 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,13 @@
2525

2626
from packaging import version
2727

28-
import sagemaker.core.common_utils as sagemaker_utils
29-
from sagemaker.core.deprecations import deprecation_warn_base, renamed_kwargs
28+
import sagemaker.core.common_utils as utils
29+
from sagemaker.core.deprecations import deprecation_warn_base, renamed_kwargs, renamed_warning
3030
from sagemaker.core.instance_group import InstanceGroup
31-
from sagemaker.core.s3 import s3_path_join
31+
from sagemaker.core.s3.utils import s3_path_join
3232
from sagemaker.core.session_settings import SessionSettings
3333
from sagemaker.core.workflow import is_pipeline_variable
34-
from sagemaker.core.helper.pipeline_variable import PipelineVariable
34+
from sagemaker.core.workflow.entities import PipelineVariable
3535

3636
logger = logging.getLogger(__name__)
3737

@@ -155,6 +155,9 @@
155155
"2.3.1",
156156
"2.4.1",
157157
"2.5.1",
158+
"2.6.0",
159+
"2.7.1",
160+
"2.8.0",
158161
]
159162

160163
TRAINIUM_SUPPORTED_DISTRIBUTION_STRATEGIES = ["torch_distributed"]
@@ -455,7 +458,7 @@ def tar_and_upload_dir(
455458

456459
try:
457460
source_files = _list_files_to_compress(script, directory) + dependencies
458-
tar_file = sagemaker_utils.create_tar_file(
461+
tar_file = utils.create_tar_file(
459462
source_files, os.path.join(tmp, _TAR_SOURCE_FILENAME)
460463
)
461464

@@ -516,7 +519,7 @@ def framework_name_from_image(image_uri):
516519
- str: The image tag
517520
- str: If the TensorFlow image is script mode
518521
"""
519-
sagemaker_pattern = re.compile(sagemaker_utils.ECR_URI_PATTERN)
522+
sagemaker_pattern = re.compile(utils.ECR_URI_PATTERN)
520523
sagemaker_match = sagemaker_pattern.match(image_uri)
521524
if sagemaker_match is None:
522525
return None, None, None, None
@@ -595,7 +598,7 @@ def model_code_key_prefix(code_location_key_prefix, model_name, image):
595598
"""
596599
name_from_image = f"/model_code/{int(time.time())}"
597600
if not is_pipeline_variable(image):
598-
name_from_image = sagemaker_utils.name_from_image(image)
601+
name_from_image = utils.name_from_image(image)
599602
return s3_path_join(code_location_key_prefix, model_name or name_from_image)
600603

601604

@@ -961,7 +964,7 @@ def validate_distribution_for_instance_type(instance_type, distribution):
961964
"""
962965
err_msg = ""
963966
if isinstance(instance_type, str):
964-
match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
967+
match = re.match(r"^ml[\._]([a-z\d\-]+)\.?\w*$", instance_type)
965968
if match and match[1].startswith("trn"):
966969
keys = list(distribution.keys())
967970
if len(keys) == 0:
@@ -1062,7 +1065,7 @@ def validate_torch_distributed_distribution(
10621065
)
10631066

10641067
# Check entry point type
1065-
if not entry_point.endswith(".py"):
1068+
if entry_point is not None and not entry_point.endswith(".py"):
10661069
err_msg += (
10671070
"Unsupported entry point type for the distribution torch_distributed.\n"
10681071
"Only python programs (*.py) are supported."
@@ -1082,7 +1085,7 @@ def _is_gpu_instance(instance_type):
10821085
bool: Whether or not the instance_type supports GPU
10831086
"""
10841087
if isinstance(instance_type, str):
1085-
match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
1088+
match = re.match(r"^ml[\._]([a-z\d\-]+)\.?\w*$", instance_type)
10861089
if match:
10871090
if match[1].startswith("p") or match[1].startswith("g"):
10881091
return True
@@ -1101,7 +1104,7 @@ def _is_trainium_instance(instance_type):
11011104
bool: Whether or not the instance_type is a Trainium instance
11021105
"""
11031106
if isinstance(instance_type, str):
1104-
match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
1107+
match = re.match(r"^ml[\._]([a-z\d\-]+)\.?\w*$", instance_type)
11051108
if match and match[1].startswith("trn"):
11061109
return True
11071110
return False
@@ -1148,7 +1151,7 @@ def _instance_type_supports_profiler(instance_type):
11481151
bool: Whether or not the region supports Amazon SageMaker Debugger profiling feature.
11491152
"""
11501153
if isinstance(instance_type, str):
1151-
match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
1154+
match = re.match(r"^ml[\._]([a-z\d\-]+)\.?\w*$", instance_type)
11521155
if match and match[1].startswith("trn"):
11531156
return True
11541157
return False
@@ -1174,3 +1177,44 @@ def validate_version_or_image_args(framework_version, py_version, image_uri):
11741177
"framework_version or py_version was None, yet image_uri was also None. "
11751178
"Either specify both framework_version and py_version, or specify image_uri."
11761179
)
1180+
1181+
1182+
def create_image_uri(
    region,
    framework,
    instance_type,
    framework_version,
    py_version=None,
    account=None,  # pylint: disable=W0613
    accelerator_type=None,
    optimized_families=None,  # pylint: disable=W0613
):
    """Deprecated shim that forwards to ``sagemaker.core.image_uris.retrieve()``.

    Calls ``renamed_warning`` and then delegates the lookup to
    ``image_uris.retrieve``. New code should call ``image_uris.retrieve``
    directly.

    Args:
        region (str): AWS region where the image is uploaded.
        framework (str): Framework used by the image.
        instance_type (str): SageMaker instance type.
        framework_version (str): Version of the framework; forwarded as the
            ``version`` argument of ``image_uris.retrieve``.
        py_version (str): Optional. Python version, e.g. ``py38``, ``py39``,
            ``py310``, ``py311``.
        account (str): Accepted for backward compatibility only; it is not
            forwarded to ``image_uris.retrieve``.
        accelerator_type (str): SageMaker Elastic Inference accelerator type.
        optimized_families (str): Deprecated. A no-op argument; it is not
            forwarded to ``image_uris.retrieve``.

    Returns:
        The value returned by ``image_uris.retrieve`` (the image URI).
    """
    # Imported at call time, as in the original; presumably to avoid a
    # circular import between the two modules -- confirm before hoisting.
    from sagemaker.core import image_uris

    renamed_warning("The method create_image_uri")

    # Only these arguments are understood by the replacement API; ``account``
    # and ``optimized_families`` are intentionally dropped.
    retrieve_kwargs = {
        "framework": framework,
        "region": region,
        "version": framework_version,
        "py_version": py_version,
        "instance_type": instance_type,
        "accelerator_type": accelerator_type,
    }
    return image_uris.retrieve(**retrieve_kwargs)

sagemaker-core/src/sagemaker/core/git_utils.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,69 @@
2020
import warnings
2121
import six
2222
from six.moves import urllib
23+
import re
24+
from pathlib import Path
25+
from urllib.parse import urlparse
26+
27+
def _sanitize_git_url(repo_url):
28+
"""Sanitize Git repository URL to prevent URL injection attacks.
29+
30+
Args:
31+
repo_url (str): The Git repository URL to sanitize
2332
33+
Returns:
34+
str: The sanitized URL
35+
36+
Raises:
37+
ValueError: If the URL contains suspicious patterns that could indicate injection
38+
"""
39+
at_count = repo_url.count("@")
40+
41+
if repo_url.startswith("git@"):
42+
# git@ format requires exactly one @
43+
if at_count != 1:
44+
raise ValueError("Invalid SSH URL format: git@ URLs must have exactly one @ symbol")
45+
elif repo_url.startswith("ssh://"):
46+
# ssh:// format can have 0 or 1 @ symbols
47+
if at_count > 1:
48+
raise ValueError("Invalid SSH URL format: multiple @ symbols detected")
49+
elif repo_url.startswith("https://") or repo_url.startswith("http://"):
50+
# HTTPS format allows 0 or 1 @ symbols
51+
if at_count > 1:
52+
raise ValueError("Invalid HTTPS URL format: multiple @ symbols detected")
53+
54+
# Check for invalid characters in the URL before parsing
55+
# These characters should not appear in legitimate URLs
56+
invalid_chars = ["<", ">", "[", "]", "{", "}", "\\", "^", "`", "|"]
57+
for char in invalid_chars:
58+
if char in repo_url:
59+
raise ValueError("Invalid characters in hostname")
60+
61+
try:
62+
parsed = urlparse(repo_url)
63+
64+
# Check for suspicious characters in hostname that could indicate injection
65+
if parsed.hostname:
66+
# Check for URL-encoded characters that might be used for obfuscation
67+
suspicious_patterns = ["%25", "%40", "%2F", "%3A"] # encoded %, @, /, :
68+
for pattern in suspicious_patterns:
69+
if pattern in parsed.hostname.lower():
70+
raise ValueError(f"Suspicious URL encoding detected in hostname: {pattern}")
71+
72+
# Validate that the hostname looks legitimate
73+
if not re.match(r"^[a-zA-Z0-9.-]+$", parsed.hostname):
74+
raise ValueError("Invalid characters in hostname")
75+
76+
except Exception as e:
77+
if isinstance(e, ValueError):
78+
raise
79+
raise ValueError(f"Failed to parse URL: {str(e)}")
80+
else:
81+
raise ValueError(
82+
"Unsupported URL scheme: only https://, http://, git@, and ssh:// are allowed"
83+
)
84+
85+
return repo_url
2486

2587
def git_clone_repo(git_config, entry_point, source_dir=None, dependencies=None):
2688
"""Git clone repo containing the training code and serving code.
@@ -87,6 +149,10 @@ def git_clone_repo(git_config, entry_point, source_dir=None, dependencies=None):
87149
if entry_point is None:
88150
raise ValueError("Please provide an entry point.")
89151
_validate_git_config(git_config)
152+
153+
# SECURITY: Sanitize the repository URL to prevent injection attacks
154+
git_config["repo"] = _sanitize_git_url(git_config["repo"])
155+
90156
dest_dir = tempfile.mkdtemp()
91157
_generate_and_run_clone_command(git_config, dest_dir)
92158

sagemaker-core/src/sagemaker/core/helper/session_helper.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -330,16 +330,16 @@ def get_caller_identity_arn(self):
330330
user_profile_name = metadata.get("UserProfileName")
331331
execution_role_arn = metadata.get("ExecutionRoleArn")
332332
try:
333+
# find execution role from the metadata file if present
334+
if execution_role_arn is not None:
335+
return execution_role_arn
336+
333337
if domain_id is None:
334338
instance_desc = self.sagemaker_client.describe_notebook_instance(
335339
NotebookInstanceName=instance_name
336340
)
337341
return instance_desc["RoleArn"]
338342

339-
# find execution role from the metadata file if present
340-
if execution_role_arn is not None:
341-
return execution_role_arn
342-
343343
user_profile_desc = self.sagemaker_client.describe_user_profile(
344344
DomainId=domain_id, UserProfileName=user_profile_name
345345
)
@@ -666,9 +666,16 @@ def expected_bucket_owner_id_bucket_check(self, bucket_name, s3, expected_bucket
666666
667667
"""
668668
try:
669-
s3.meta.client.head_bucket(
670-
Bucket=bucket_name, ExpectedBucketOwner=expected_bucket_owner_id
671-
)
669+
if self.default_bucket_prefix:
670+
s3.meta.client.list_objects_v2(
671+
Bucket=bucket_name,
672+
Prefix=self.default_bucket_prefix,
673+
ExpectedBucketOwner=expected_bucket_owner_id,
674+
)
675+
else:
676+
s3.meta.client.head_bucket(
677+
Bucket=bucket_name, ExpectedBucketOwner=expected_bucket_owner_id
678+
)
672679
except ClientError as e:
673680
error_code = e.response["Error"]["Code"]
674681
message = e.response["Error"]["Message"]
@@ -699,7 +706,12 @@ def general_bucket_check_if_user_has_permission(
699706
bucket_creation_date_none (bool):Indicating whether S3 bucket already exists or not
700707
"""
701708
try:
702-
s3.meta.client.head_bucket(Bucket=bucket_name)
709+
if self.default_bucket_prefix:
710+
s3.meta.client.list_objects_v2(
711+
Bucket=bucket_name, Prefix=self.default_bucket_prefix
712+
)
713+
else:
714+
s3.meta.client.head_bucket(Bucket=bucket_name)
703715
except ClientError as e:
704716
error_code = e.response["Error"]["Code"]
705717
message = e.response["Error"]["Message"]

sagemaker-core/src/sagemaker/core/huggingface/__init__.py

Lines changed: 0 additions & 29 deletions
This file was deleted.

0 commit comments

Comments
 (0)