Skip to content

Commit 532fe77

Browse files
committed
Add input validation and resource management improvements V3
1 parent fb0d789 commit 532fe77

File tree

5 files changed

+167
-6
lines changed

5 files changed

+167
-6
lines changed

sagemaker-core/src/sagemaker/core/common_utils.py

Lines changed: 127 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -607,6 +607,95 @@ def _save_model(repacked_model_uri, tmp_model_path, sagemaker_session, kms_key):
607607
shutil.move(tmp_model_path, repacked_model_uri.replace("file://", ""))
608608

609609

610+
def _validate_source_directory(source_directory):
611+
"""Validate that source_directory is safe to use.
612+
613+
Ensures the source directory path does not access restricted system locations.
614+
615+
Args:
616+
source_directory (str): The source directory path to validate.
617+
618+
Raises:
619+
ValueError: If the path is not allowed.
620+
"""
621+
if not source_directory or source_directory.lower().startswith("s3://"):
622+
# S3 paths and None are safe
623+
return
624+
625+
abs_source = abspath(source_directory)
626+
627+
# Blocklist of sensitive directories that should not be accessible
628+
sensitive_paths = [
629+
abspath(os.path.expanduser("~/.aws")),
630+
abspath(os.path.expanduser("~/.ssh")),
631+
abspath(os.path.expanduser("~/.kube")),
632+
abspath(os.path.expanduser("~/.docker")),
633+
abspath(os.path.expanduser("~/.config")),
634+
abspath(os.path.expanduser("~/.credentials")),
635+
"/etc",
636+
"/root",
637+
"/home",
638+
"/var/lib",
639+
"/opt/ml/metadata",
640+
]
641+
642+
# Check if the source path is under any sensitive directory
643+
for sensitive_path in sensitive_paths:
644+
if abs_source.startswith(sensitive_path):
645+
raise ValueError(
646+
f"source_directory cannot access sensitive system paths. "
647+
f"Got: {source_directory} (resolved to {abs_source})"
648+
)
649+
650+
# Check for symlinks to prevent symlink-based escapes
651+
if os.path.islink(abs_source):
652+
raise ValueError(f"source_directory cannot be a symlink: {source_directory}")
653+
654+
655+
def _validate_dependency_path(dependency):
656+
"""Validate that a dependency path is safe to use.
657+
658+
Ensures the dependency path does not access restricted system locations.
659+
660+
Args:
661+
dependency (str): The dependency path to validate.
662+
663+
Raises:
664+
ValueError: If the path is not allowed.
665+
"""
666+
if not dependency:
667+
return
668+
669+
abs_dependency = abspath(dependency)
670+
671+
# Blocklist of sensitive directories that should not be accessible
672+
sensitive_paths = [
673+
abspath(os.path.expanduser("~/.aws")),
674+
abspath(os.path.expanduser("~/.ssh")),
675+
abspath(os.path.expanduser("~/.kube")),
676+
abspath(os.path.expanduser("~/.docker")),
677+
abspath(os.path.expanduser("~/.config")),
678+
abspath(os.path.expanduser("~/.credentials")),
679+
"/etc",
680+
"/root",
681+
"/home",
682+
"/var/lib",
683+
"/opt/ml/metadata",
684+
]
685+
686+
# Check if the dependency path is under any sensitive directory
687+
for sensitive_path in sensitive_paths:
688+
if abs_dependency.startswith(sensitive_path):
689+
raise ValueError(
690+
f"dependency path cannot access sensitive system paths. "
691+
f"Got: {dependency} (resolved to {abs_dependency})"
692+
)
693+
694+
# Check for symlinks to prevent symlink-based escapes
695+
if os.path.islink(abs_dependency):
696+
raise ValueError(f"dependency path cannot be a symlink: {dependency}")
697+
698+
610699
def _create_or_update_code_dir(
611700
model_dir, inference_script, source_directory, dependencies, sagemaker_session, tmp
612701
):
@@ -620,6 +709,8 @@ def _create_or_update_code_dir(
620709
custom_extractall_tarfile(t, code_dir)
621710

622711
elif source_directory:
712+
# Validate source_directory for security
713+
_validate_source_directory(source_directory)
623714
if os.path.exists(code_dir):
624715
shutil.rmtree(code_dir)
625716
shutil.copytree(source_directory, code_dir)
@@ -635,6 +726,8 @@ def _create_or_update_code_dir(
635726
raise
636727

637728
for dependency in dependencies:
729+
# Validate dependency path for security
730+
_validate_dependency_path(dependency)
638731
lib_dir = os.path.join(code_dir, "lib")
639732
if os.path.isdir(dependency):
640733
shutil.copytree(dependency, os.path.join(lib_dir, os.path.basename(dependency)))
@@ -1646,6 +1739,38 @@ def _get_safe_members(members):
16461739
yield file_info
16471740

16481741

1742+
def _validate_extracted_paths(extract_path):
    """Validate that extracted paths remain within the expected directory.

    Performs post-extraction validation to ensure all extracted files and directories
    are within the intended extraction path. Every directory and file found by
    walking ``extract_path`` is passed through ``_get_resolved_path`` and must be
    the resolved base itself or lie below it. Containment is checked on whole path
    components, so a sibling such as ``<base>_other`` is not mistaken for a child
    of ``<base>``.

    Args:
        extract_path (str): The path where files were extracted.

    Raises:
        ValueError: If any extracted file is outside the expected extraction path.
    """
    base = _get_resolved_path(extract_path)
    # Require a separator after the base so that "/tmp/model_evil" is not
    # accepted as being inside "/tmp/model" (which plain startswith(base) allows).
    base_prefix = base.rstrip(os.sep) + os.sep

    for root, dirs, files in os.walk(extract_path):
        # Check directories
        for dir_name in dirs:
            dir_path = os.path.join(root, dir_name)
            resolved = _get_resolved_path(dir_path)
            if resolved != base and not resolved.startswith(base_prefix):
                logger.error("Extracted directory escaped extraction path: %s", dir_path)
                raise ValueError(f"Extracted path outside expected directory: {dir_path}")

        # Check files
        for file_name in files:
            file_path = os.path.join(root, file_name)
            resolved = _get_resolved_path(file_path)
            if resolved != base and not resolved.startswith(base_prefix):
                logger.error("Extracted file escaped extraction path: %s", file_path)
                raise ValueError(f"Extracted path outside expected directory: {file_path}")
1772+
1773+
16491774
def custom_extractall_tarfile(tar, extract_path):
16501775
"""Extract a tarfile, optionally using data_filter if available.
16511776
@@ -1666,6 +1791,8 @@ def custom_extractall_tarfile(tar, extract_path):
16661791
tar.extractall(path=extract_path, filter="data")
16671792
else:
16681793
tar.extractall(path=extract_path, members=_get_safe_members(tar))
1794+
# Re-validate extracted paths to catch symlink race conditions
1795+
_validate_extracted_paths(extract_path)
16691796

16701797

16711798
def can_model_package_source_uri_autopopulate(source_uri: str):

sagemaker-core/src/sagemaker/core/iterators.py

Lines changed: 13 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -114,6 +114,9 @@ def __next__(self):
114114
class LineIterator(BaseIterator):
115115
"""A helper class for parsing the byte Event Stream input to provide Line iteration."""
116116

117+
# Maximum buffer size to prevent unbounded memory consumption (10 MB)
118+
MAX_BUFFER_SIZE = 10 * 1024 * 1024
119+
117120
def __init__(self, event_stream):
118121
"""Initialises a LineIterator Iterator object
119122
@@ -182,5 +185,15 @@ def __next__(self):
182185
# print and move on to next response byte
183186
print("Unknown event type:" + chunk)
184187
continue
188+
189+
# Check buffer size before writing to prevent unbounded memory consumption
190+
chunk_size = len(chunk["PayloadPart"]["Bytes"])
191+
current_size = self.buffer.getbuffer().nbytes
192+
if current_size + chunk_size > self.MAX_BUFFER_SIZE:
193+
raise RuntimeError(
194+
f"Line buffer exceeded maximum size of {self.MAX_BUFFER_SIZE} bytes. "
195+
f"No newline found in stream."
196+
)
197+
185198
self.buffer.seek(0, io.SEEK_END)
186199
self.buffer.write(chunk["PayloadPart"]["Bytes"])

sagemaker-core/src/sagemaker/core/local/data.py

Lines changed: 24 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -116,10 +116,34 @@ def get_root_dir(self):
116116
class LocalFileDataSource(DataSource):
117117
"""Represents a data source within the local filesystem."""
118118

119+
# Blocklist of sensitive directories that should not be accessible
120+
RESTRICTED_PATHS = [
121+
os.path.abspath(os.path.expanduser("~/.aws")),
122+
os.path.abspath(os.path.expanduser("~/.ssh")),
123+
os.path.abspath(os.path.expanduser("~/.kube")),
124+
os.path.abspath(os.path.expanduser("~/.docker")),
125+
os.path.abspath(os.path.expanduser("~/.config")),
126+
os.path.abspath(os.path.expanduser("~/.credentials")),
127+
"/etc",
128+
"/root",
129+
"/home",
130+
"/var/lib",
131+
"/opt/ml/metadata",
132+
]
133+
119134
def __init__(self, root_path):
120135
super(LocalFileDataSource, self).__init__()
121136

122137
self.root_path = os.path.abspath(root_path)
138+
139+
# Validate that the path is not in restricted locations
140+
for restricted_path in self.RESTRICTED_PATHS:
141+
if self.root_path.startswith(restricted_path):
142+
raise ValueError(
143+
f"Local Mode does not support mounting from restricted system paths. "
144+
f"Got: {root_path}"
145+
)
146+
123147
if not os.path.exists(self.root_path):
124148
raise RuntimeError("Invalid data source: %s does not exist." % self.root_path)
125149

sagemaker-core/src/sagemaker/core/local/utils.py

Lines changed: 1 addition & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -48,10 +48,7 @@ def copy_directory_structure(destination_directory, relative_path):
4848
destination_directory
4949
"""
5050
full_path = os.path.join(destination_directory, relative_path)
51-
if os.path.exists(full_path):
52-
return
53-
54-
os.makedirs(destination_directory, relative_path)
51+
os.makedirs(full_path, exist_ok=True)
5552

5653

5754
def move_to_destination(source, destination, job_name, sagemaker_session, prefix=""):

sagemaker-core/tests/unit/local/test_local_utils.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -35,9 +35,9 @@
3535
@patch("sagemaker.core.local.utils.os.path")
3636
@patch("sagemaker.core.local.utils.os")
3737
def test_copy_directory_structure(m_os, m_os_path):
38-
m_os_path.exists.return_value = False
38+
m_os_path.join.return_value = "/tmp/code/"
3939
copy_directory_structure("/tmp/", "code/")
40-
m_os.makedirs.assert_called_with("/tmp/", "code/")
40+
m_os.makedirs.assert_called_with("/tmp/code/", exist_ok=True)
4141

4242

4343
@patch("shutil.rmtree", Mock())

0 commit comments

Comments
 (0)