Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 127 additions & 0 deletions sagemaker-core/src/sagemaker/core/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,6 +607,95 @@ def _save_model(repacked_model_uri, tmp_model_path, sagemaker_session, kms_key):
shutil.move(tmp_model_path, repacked_model_uri.replace("file://", ""))


def _validate_source_directory(source_directory):
"""Validate that source_directory is safe to use.

Ensures the source directory path does not access restricted system locations.

Args:
source_directory (str): The source directory path to validate.

Raises:
ValueError: If the path is not allowed.
"""
if not source_directory or source_directory.lower().startswith("s3://"):
# S3 paths and None are safe
return

abs_source = abspath(source_directory)

# Blocklist of sensitive directories that should not be accessible
sensitive_paths = [
abspath(os.path.expanduser("~/.aws")),
abspath(os.path.expanduser("~/.ssh")),
abspath(os.path.expanduser("~/.kube")),
abspath(os.path.expanduser("~/.docker")),
abspath(os.path.expanduser("~/.config")),
abspath(os.path.expanduser("~/.credentials")),
"/etc",
"/root",
"/home",
"/var/lib",
"/opt/ml/metadata",
]

# Check if the source path is under any sensitive directory
for sensitive_path in sensitive_paths:
if abs_source.startswith(sensitive_path):
raise ValueError(
f"source_directory cannot access sensitive system paths. "
f"Got: {source_directory} (resolved to {abs_source})"
)

# Check for symlinks to prevent symlink-based escapes
if os.path.islink(abs_source):
raise ValueError(f"source_directory cannot be a symlink: {source_directory}")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just curious: Are symlinks not supported for source path ?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They should be supported... and this is a good point! Now that I think about it, we don't have to completely prevent sym-links, we just need to resolve the sym-link with os.path.realpath() before validation. This prevents the bug with race conditions. I will make that change



def _validate_dependency_path(dependency):
"""Validate that a dependency path is safe to use.

Ensures the dependency path does not access restricted system locations.

Args:
dependency (str): The dependency path to validate.

Raises:
ValueError: If the path is not allowed.
"""
if not dependency:
return

abs_dependency = abspath(dependency)

# Blocklist of sensitive directories that should not be accessible
sensitive_paths = [
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this sensitive_paths duplicate of _validate_source_directory's sensitive path? if so, can we refactor to use a common source of sensitive_paths

I can also see another list RESTRICTED_PATH defined in sagemaker-core/src/sagemaker/core/local/data.py, that seem duplicate.
If these are duplicate , can we fix redundancy and use a common list of sensitive/restricted paths ?

abspath(os.path.expanduser("~/.aws")),
abspath(os.path.expanduser("~/.ssh")),
abspath(os.path.expanduser("~/.kube")),
abspath(os.path.expanduser("~/.docker")),
abspath(os.path.expanduser("~/.config")),
abspath(os.path.expanduser("~/.credentials")),
"/etc",
"/root",
"/home",
"/var/lib",
"/opt/ml/metadata",
]

# Check if the dependency path is under any sensitive directory
for sensitive_path in sensitive_paths:
if abs_dependency.startswith(sensitive_path):
raise ValueError(
f"dependency path cannot access sensitive system paths. "
f"Got: {dependency} (resolved to {abs_dependency})"
)

# Check for symlinks to prevent symlink-based escapes
if os.path.islink(abs_dependency):
raise ValueError(f"dependency path cannot be a symlink: {dependency}")


def _create_or_update_code_dir(
model_dir, inference_script, source_directory, dependencies, sagemaker_session, tmp
):
Expand All @@ -620,6 +709,8 @@ def _create_or_update_code_dir(
custom_extractall_tarfile(t, code_dir)

elif source_directory:
# Validate source_directory for security
_validate_source_directory(source_directory)
if os.path.exists(code_dir):
shutil.rmtree(code_dir)
shutil.copytree(source_directory, code_dir)
Expand All @@ -635,6 +726,8 @@ def _create_or_update_code_dir(
raise

for dependency in dependencies:
# Validate dependency path for security
_validate_dependency_path(dependency)
lib_dir = os.path.join(code_dir, "lib")
if os.path.isdir(dependency):
shutil.copytree(dependency, os.path.join(lib_dir, os.path.basename(dependency)))
Expand Down Expand Up @@ -1646,6 +1739,38 @@ def _get_safe_members(members):
yield file_info


def _validate_extracted_paths(extract_path):
"""Validate that extracted paths remain within the expected directory.

Performs post-extraction validation to ensure all extracted files and directories
are within the intended extraction path.

Args:
extract_path (str): The path where files were extracted.

Raises:
ValueError: If any extracted file is outside the expected extraction path.
"""
base = _get_resolved_path(extract_path)

for root, dirs, files in os.walk(extract_path):
# Check directories
for dir_name in dirs:
dir_path = os.path.join(root, dir_name)
resolved = _get_resolved_path(dir_path)
if not resolved.startswith(base):
logger.error("Extracted directory escaped extraction path: %s", dir_path)
raise ValueError(f"Extracted path outside expected directory: {dir_path}")

# Check files
for file_name in files:
file_path = os.path.join(root, file_name)
resolved = _get_resolved_path(file_path)
if not resolved.startswith(base):
logger.error("Extracted file escaped extraction path: %s", file_path)
raise ValueError(f"Extracted path outside expected directory: {file_path}")


def custom_extractall_tarfile(tar, extract_path):
"""Extract a tarfile, optionally using data_filter if available.

Expand All @@ -1666,6 +1791,8 @@ def custom_extractall_tarfile(tar, extract_path):
tar.extractall(path=extract_path, filter="data")
else:
tar.extractall(path=extract_path, members=_get_safe_members(tar))
# Re-validate extracted paths to catch symlink race conditions
_validate_extracted_paths(extract_path)


def can_model_package_source_uri_autopopulate(source_uri: str):
Expand Down
13 changes: 13 additions & 0 deletions sagemaker-core/src/sagemaker/core/iterators.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,9 @@ def __next__(self):
class LineIterator(BaseIterator):
"""A helper class for parsing the byte Event Stream input to provide Line iteration."""

# Maximum buffer size to prevent unbounded memory consumption (10 MB)
MAX_BUFFER_SIZE = 10 * 1024 * 1024
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Move this to some constants file.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point! Adding it to common_utils.py


def __init__(self, event_stream):
"""Initialises a LineIterator Iterator object

Expand Down Expand Up @@ -182,5 +185,15 @@ def __next__(self):
# print and move on to next response byte
print("Unknown event type:" + chunk)
continue

# Check buffer size before writing to prevent unbounded memory consumption
chunk_size = len(chunk["PayloadPart"]["Bytes"])
current_size = self.buffer.getbuffer().nbytes
if current_size + chunk_size > self.MAX_BUFFER_SIZE:
raise RuntimeError(
f"Line buffer exceeded maximum size of {self.MAX_BUFFER_SIZE} bytes. "
f"No newline found in stream."
)

self.buffer.seek(0, io.SEEK_END)
self.buffer.write(chunk["PayloadPart"]["Bytes"])
24 changes: 24 additions & 0 deletions sagemaker-core/src/sagemaker/core/local/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,10 +116,34 @@ def get_root_dir(self):
class LocalFileDataSource(DataSource):
"""Represents a data source within the local filesystem."""

# Blocklist of sensitive directories that should not be accessible
RESTRICTED_PATHS = [
os.path.abspath(os.path.expanduser("~/.aws")),
os.path.abspath(os.path.expanduser("~/.ssh")),
os.path.abspath(os.path.expanduser("~/.kube")),
os.path.abspath(os.path.expanduser("~/.docker")),
os.path.abspath(os.path.expanduser("~/.config")),
os.path.abspath(os.path.expanduser("~/.credentials")),
"/etc",
"/root",
"/home",
"/var/lib",
"/opt/ml/metadata",
]

def __init__(self, root_path):
super(LocalFileDataSource, self).__init__()

self.root_path = os.path.abspath(root_path)

# Validate that the path is not in restricted locations
for restricted_path in self.RESTRICTED_PATHS:
if self.root_path.startswith(restricted_path):
raise ValueError(
f"Local Mode does not support mounting from restricted system paths. "
f"Got: {root_path}"
)

if not os.path.exists(self.root_path):
raise RuntimeError("Invalid data source: %s does not exist." % self.root_path)

Expand Down
5 changes: 1 addition & 4 deletions sagemaker-core/src/sagemaker/core/local/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,7 @@ def copy_directory_structure(destination_directory, relative_path):
destination_directory
"""
full_path = os.path.join(destination_directory, relative_path)
if os.path.exists(full_path):
return

os.makedirs(destination_directory, relative_path)
os.makedirs(full_path, exist_ok=True)


def move_to_destination(source, destination, job_name, sagemaker_session, prefix=""):
Expand Down
4 changes: 2 additions & 2 deletions sagemaker-core/tests/unit/local/test_local_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@
@patch("sagemaker.core.local.utils.os.path")
@patch("sagemaker.core.local.utils.os")
def test_copy_directory_structure(m_os, m_os_path):
m_os_path.exists.return_value = False
m_os_path.join.return_value = "/tmp/code/"
copy_directory_structure("/tmp/", "code/")
m_os.makedirs.assert_called_with("/tmp/", "code/")
m_os.makedirs.assert_called_with("/tmp/code/", exist_ok=True)


@patch("shutil.rmtree", Mock())
Expand Down
Loading