From 532fe772dab8fabb15d20f8d3c8b0899f3fb2c34 Mon Sep 17 00:00:00 2001 From: aviruthen <91846056+aviruthen@users.noreply.github.com> Date: Mon, 15 Dec 2025 17:50:15 -0800 Subject: [PATCH 1/5] Add input validation and resource management improvements V3 --- .../src/sagemaker/core/common_utils.py | 127 ++++++++++++++++++ .../src/sagemaker/core/iterators.py | 13 ++ .../src/sagemaker/core/local/data.py | 24 ++++ .../src/sagemaker/core/local/utils.py | 5 +- .../tests/unit/local/test_local_utils.py | 4 +- 5 files changed, 167 insertions(+), 6 deletions(-) diff --git a/sagemaker-core/src/sagemaker/core/common_utils.py b/sagemaker-core/src/sagemaker/core/common_utils.py index b5dd2ecbef..d98886a285 100644 --- a/sagemaker-core/src/sagemaker/core/common_utils.py +++ b/sagemaker-core/src/sagemaker/core/common_utils.py @@ -607,6 +607,95 @@ def _save_model(repacked_model_uri, tmp_model_path, sagemaker_session, kms_key): shutil.move(tmp_model_path, repacked_model_uri.replace("file://", "")) +def _validate_source_directory(source_directory): + """Validate that source_directory is safe to use. + + Ensures the source directory path does not access restricted system locations. + + Args: + source_directory (str): The source directory path to validate. + + Raises: + ValueError: If the path is not allowed. + """ + if not source_directory or source_directory.lower().startswith("s3://"): + # S3 paths and None are safe + return + + abs_source = abspath(source_directory) + + # Blocklist of sensitive directories that should not be accessible + sensitive_paths = [ + abspath(os.path.expanduser("~/.aws")), + abspath(os.path.expanduser("~/.ssh")), + abspath(os.path.expanduser("~/.kube")), + abspath(os.path.expanduser("~/.docker")), + abspath(os.path.expanduser("~/.config")), + abspath(os.path.expanduser("~/.credentials")), + "/etc", + "/root", + "/home", + "/var/lib", + "/opt/ml/metadata", + ] + + # Check if the source path is under any sensitive directory + for sensitive_path in sensitive_paths: + if abs_source.startswith(sensitive_path): + raise ValueError( + f"source_directory cannot access sensitive system paths. " + f"Got: {source_directory} (resolved to {abs_source})" + ) + + # Check for symlinks to prevent symlink-based escapes + if os.path.islink(abs_source): + raise ValueError(f"source_directory cannot be a symlink: {source_directory}") + + +def _validate_dependency_path(dependency): + """Validate that a dependency path is safe to use. + + Ensures the dependency path does not access restricted system locations. + + Args: + dependency (str): The dependency path to validate. + + Raises: + ValueError: If the path is not allowed. + """ + if not dependency: + return + + abs_dependency = abspath(dependency) + + # Blocklist of sensitive directories that should not be accessible + sensitive_paths = [ + abspath(os.path.expanduser("~/.aws")), + abspath(os.path.expanduser("~/.ssh")), + abspath(os.path.expanduser("~/.kube")), + abspath(os.path.expanduser("~/.docker")), + abspath(os.path.expanduser("~/.config")), + abspath(os.path.expanduser("~/.credentials")), + "/etc", + "/root", + "/home", + "/var/lib", + "/opt/ml/metadata", + ] + + # Check if the dependency path is under any sensitive directory + for sensitive_path in sensitive_paths: + if abs_dependency.startswith(sensitive_path): + raise ValueError( + f"dependency path cannot access sensitive system paths. " + f"Got: {dependency} (resolved to {abs_dependency})" + ) + + # Check for symlinks to prevent symlink-based escapes + if os.path.islink(abs_dependency): + raise ValueError(f"dependency path cannot be a symlink: {dependency}") + + def _create_or_update_code_dir( model_dir, inference_script, source_directory, dependencies, sagemaker_session, tmp ): @@ -620,6 +709,8 @@ def _create_or_update_code_dir( custom_extractall_tarfile(t, code_dir) elif source_directory: + # Validate source_directory for security + _validate_source_directory(source_directory) if os.path.exists(code_dir): shutil.rmtree(code_dir) shutil.copytree(source_directory, code_dir) @@ -635,6 +726,8 @@ def _create_or_update_code_dir( raise for dependency in dependencies: + # Validate dependency path for security + _validate_dependency_path(dependency) lib_dir = os.path.join(code_dir, "lib") if os.path.isdir(dependency): shutil.copytree(dependency, os.path.join(lib_dir, os.path.basename(dependency))) @@ -1646,6 +1739,38 @@ def _get_safe_members(members): yield file_info +def _validate_extracted_paths(extract_path): + """Validate that extracted paths remain within the expected directory. + + Performs post-extraction validation to ensure all extracted files and directories + are within the intended extraction path. + + Args: + extract_path (str): The path where files were extracted. + + Raises: + ValueError: If any extracted file is outside the expected extraction path. + """ + base = _get_resolved_path(extract_path) + + for root, dirs, files in os.walk(extract_path): + # Check directories + for dir_name in dirs: + dir_path = os.path.join(root, dir_name) + resolved = _get_resolved_path(dir_path) + if not resolved.startswith(base): + logger.error("Extracted directory escaped extraction path: %s", dir_path) + raise ValueError(f"Extracted path outside expected directory: {dir_path}") + + # Check files + for file_name in files: + file_path = os.path.join(root, file_name) + resolved = _get_resolved_path(file_path) + if not resolved.startswith(base): + logger.error("Extracted file escaped extraction path: %s", file_path) + raise ValueError(f"Extracted path outside expected directory: {file_path}") + + def custom_extractall_tarfile(tar, extract_path): """Extract a tarfile, optionally using data_filter if available. @@ -1666,6 +1791,8 @@ def custom_extractall_tarfile(tar, extract_path): tar.extractall(path=extract_path, filter="data") else: tar.extractall(path=extract_path, members=_get_safe_members(tar)) + # Re-validate extracted paths to catch symlink race conditions + _validate_extracted_paths(extract_path) def can_model_package_source_uri_autopopulate(source_uri: str): diff --git a/sagemaker-core/src/sagemaker/core/iterators.py b/sagemaker-core/src/sagemaker/core/iterators.py index a921742764..c25f26f496 100644 --- a/sagemaker-core/src/sagemaker/core/iterators.py +++ b/sagemaker-core/src/sagemaker/core/iterators.py @@ -114,6 +114,9 @@ def __next__(self): class LineIterator(BaseIterator): """A helper class for parsing the byte Event Stream input to provide Line iteration.""" + # Maximum buffer size to prevent unbounded memory consumption (10 MB) + MAX_BUFFER_SIZE = 10 * 1024 * 1024 + def __init__(self, event_stream): """Initialises a LineIterator Iterator object @@ -182,5 +185,15 @@ def __next__(self): # print and move on to next response byte print("Unknown event type:" + chunk) continue + + # Check buffer size before writing to prevent unbounded memory consumption + chunk_size = len(chunk["PayloadPart"]["Bytes"]) + current_size = self.buffer.getbuffer().nbytes + if current_size + chunk_size > self.MAX_BUFFER_SIZE: + raise RuntimeError( + f"Line buffer exceeded maximum size of {self.MAX_BUFFER_SIZE} bytes. " + f"No newline found in stream." + ) + self.buffer.seek(0, io.SEEK_END) self.buffer.write(chunk["PayloadPart"]["Bytes"]) diff --git a/sagemaker-core/src/sagemaker/core/local/data.py b/sagemaker-core/src/sagemaker/core/local/data.py index c3113835ca..f8d3892b0d 100644 --- a/sagemaker-core/src/sagemaker/core/local/data.py +++ b/sagemaker-core/src/sagemaker/core/local/data.py @@ -116,10 +116,34 @@ def get_root_dir(self): class LocalFileDataSource(DataSource): """Represents a data source within the local filesystem.""" + # Blocklist of sensitive directories that should not be accessible + RESTRICTED_PATHS = [ + os.path.abspath(os.path.expanduser("~/.aws")), + os.path.abspath(os.path.expanduser("~/.ssh")), + os.path.abspath(os.path.expanduser("~/.kube")), + os.path.abspath(os.path.expanduser("~/.docker")), + os.path.abspath(os.path.expanduser("~/.config")), + os.path.abspath(os.path.expanduser("~/.credentials")), + "/etc", + "/root", + "/home", + "/var/lib", + "/opt/ml/metadata", + ] + def __init__(self, root_path): super(LocalFileDataSource, self).__init__() self.root_path = os.path.abspath(root_path) + + # Validate that the path is not in restricted locations + for restricted_path in self.RESTRICTED_PATHS: + if self.root_path.startswith(restricted_path): + raise ValueError( + f"Local Mode does not support mounting from restricted system paths. " + f"Got: {root_path}" + ) + if not os.path.exists(self.root_path): raise RuntimeError("Invalid data source: %s does not exist." % self.root_path) diff --git a/sagemaker-core/src/sagemaker/core/local/utils.py b/sagemaker-core/src/sagemaker/core/local/utils.py index 58ef3e7781..a33bdf7455 100644 --- a/sagemaker-core/src/sagemaker/core/local/utils.py +++ b/sagemaker-core/src/sagemaker/core/local/utils.py @@ -48,10 +48,7 @@ def copy_directory_structure(destination_directory, relative_path): destination_directory """ full_path = os.path.join(destination_directory, relative_path) - if os.path.exists(full_path): - return - - os.makedirs(destination_directory, relative_path) + os.makedirs(full_path, exist_ok=True) def move_to_destination(source, destination, job_name, sagemaker_session, prefix=""): diff --git a/sagemaker-core/tests/unit/local/test_local_utils.py b/sagemaker-core/tests/unit/local/test_local_utils.py index 432c456d54..845cd9a1bc 100644 --- a/sagemaker-core/tests/unit/local/test_local_utils.py +++ b/sagemaker-core/tests/unit/local/test_local_utils.py @@ -35,9 +35,9 @@ @patch("sagemaker.core.local.utils.os.path") @patch("sagemaker.core.local.utils.os") def test_copy_directory_structure(m_os, m_os_path): - m_os_path.exists.return_value = False + m_os_path.join.return_value = "/tmp/code/" copy_directory_structure("/tmp/", "code/") - m_os.makedirs.assert_called_with("/tmp/", "code/") + m_os.makedirs.assert_called_with("/tmp/code/", exist_ok=True) @patch("shutil.rmtree", Mock()) From d2151c23141d4ed0f84b3c29e9ab0065f6f445d7 Mon Sep 17 00:00:00 2001 From: aviruthen <91846056+aviruthen@users.noreply.github.com> Date: Wed, 17 Dec 2025 16:04:34 -0800 Subject: [PATCH 2/5] Allowing for sym-links, better refactoring --- .../src/sagemaker/core/common_utils.py | 63 +++++++------------ .../src/sagemaker/core/iterators.py | 8 +-- .../src/sagemaker/core/local/data.py | 18 +----- 3 files changed, 26 insertions(+), 63 deletions(-) diff --git a/sagemaker-core/src/sagemaker/core/common_utils.py b/sagemaker-core/src/sagemaker/core/common_utils.py index d98886a285..202c266144 100644 --- a/sagemaker-core/src/sagemaker/core/common_utils.py +++ b/sagemaker-core/src/sagemaker/core/common_utils.py @@ -74,6 +74,21 @@ WAITING_DOT_NUMBER = 10 MAX_ITEMS = 100 PAGE_SIZE = 10 +_MAX_BUFFER_SIZE = 100 * 1024 * 1024 # 100 MB - Maximum buffer size for streaming iterators + +_SENSITIVE_SYSTEM_PATHS = [ + abspath(os.path.expanduser("~/.aws")), + abspath(os.path.expanduser("~/.ssh")), + abspath(os.path.expanduser("~/.kube")), + abspath(os.path.expanduser("~/.docker")), + abspath(os.path.expanduser("~/.config")), + abspath(os.path.expanduser("~/.credentials")), + "/etc", + "/root", + "/home", + "/var/lib", + "/opt/ml/metadata", +] logger = logging.getLogger(__name__) @@ -622,35 +637,17 @@ def _validate_source_directory(source_directory): # S3 paths and None are safe return - abs_source = abspath(source_directory) - - # Blocklist of sensitive directories that should not be accessible - sensitive_paths = [ - abspath(os.path.expanduser("~/.aws")), - abspath(os.path.expanduser("~/.ssh")), - abspath(os.path.expanduser("~/.kube")), - abspath(os.path.expanduser("~/.docker")), - abspath(os.path.expanduser("~/.config")), - abspath(os.path.expanduser("~/.credentials")), - "/etc", - "/root", - "/home", - "/var/lib", - "/opt/ml/metadata", - ] + # Resolve symlinks to get the actual path + abs_source = abspath(realpath(source_directory)) # Check if the source path is under any sensitive directory - for sensitive_path in sensitive_paths: + for sensitive_path in _SENSITIVE_SYSTEM_PATHS: if abs_source.startswith(sensitive_path): raise ValueError( f"source_directory cannot access sensitive system paths. " f"Got: {source_directory} (resolved to {abs_source})" ) - # Check for symlinks to prevent symlink-based escapes - if os.path.islink(abs_source): - raise ValueError(f"source_directory cannot be a symlink: {source_directory}") - def _validate_dependency_path(dependency): """Validate that a dependency path is safe to use. @@ -666,35 +663,17 @@ def _validate_dependency_path(dependency): if not dependency: return - abs_dependency = abspath(dependency) - - # Blocklist of sensitive directories that should not be accessible - sensitive_paths = [ - abspath(os.path.expanduser("~/.aws")), - abspath(os.path.expanduser("~/.ssh")), - abspath(os.path.expanduser("~/.kube")), - abspath(os.path.expanduser("~/.docker")), - abspath(os.path.expanduser("~/.config")), - abspath(os.path.expanduser("~/.credentials")), - "/etc", - "/root", - "/home", - "/var/lib", - "/opt/ml/metadata", - ] + # Resolve symlinks to get the actual path + abs_dependency = abspath(realpath(dependency)) # Check if the dependency path is under any sensitive directory - for sensitive_path in sensitive_paths: + for sensitive_path in _SENSITIVE_SYSTEM_PATHS: if abs_dependency.startswith(sensitive_path): raise ValueError( f"dependency path cannot access sensitive system paths. " f"Got: {dependency} (resolved to {abs_dependency})" ) - # Check for symlinks to prevent symlink-based escapes - if os.path.islink(abs_dependency): - raise ValueError(f"dependency path cannot be a symlink: {dependency}") - def _create_or_update_code_dir( model_dir, inference_script, source_directory, dependencies, sagemaker_session, tmp diff --git a/sagemaker-core/src/sagemaker/core/iterators.py b/sagemaker-core/src/sagemaker/core/iterators.py index c25f26f496..60914cbdd0 100644 --- a/sagemaker-core/src/sagemaker/core/iterators.py +++ b/sagemaker-core/src/sagemaker/core/iterators.py @@ -17,6 +17,7 @@ import io from sagemaker.core.exceptions import ModelStreamError, InternalStreamFailure +from sagemaker.core.common_utils import _MAX_BUFFER_SIZE def handle_stream_errors(chunk): @@ -114,9 +115,6 @@ def __next__(self): class LineIterator(BaseIterator): """A helper class for parsing the byte Event Stream input to provide Line iteration.""" - # Maximum buffer size to prevent unbounded memory consumption (10 MB) - MAX_BUFFER_SIZE = 10 * 1024 * 1024 - def __init__(self, event_stream): """Initialises a LineIterator Iterator object @@ -189,9 +187,9 @@ def __next__(self): # Check buffer size before writing to prevent unbounded memory consumption chunk_size = len(chunk["PayloadPart"]["Bytes"]) current_size = self.buffer.getbuffer().nbytes - if current_size + chunk_size > self.MAX_BUFFER_SIZE: + if current_size + chunk_size > _MAX_BUFFER_SIZE: raise RuntimeError( - f"Line buffer exceeded maximum size of {self.MAX_BUFFER_SIZE} bytes. " + f"Line buffer exceeded maximum size of {_MAX_BUFFER_SIZE} bytes. " f"No newline found in stream." ) diff --git a/sagemaker-core/src/sagemaker/core/local/data.py b/sagemaker-core/src/sagemaker/core/local/data.py index f8d3892b0d..6d011bee78 100644 --- a/sagemaker-core/src/sagemaker/core/local/data.py +++ b/sagemaker-core/src/sagemaker/core/local/data.py @@ -24,6 +24,7 @@ from six.moves.urllib.parse import urlparse import sagemaker.core +from sagemaker.core.common_utils import _SENSITIVE_SYSTEM_PATHS def get_data_source_instance(data_source, sagemaker_session): @@ -116,28 +117,13 @@ def get_root_dir(self): class LocalFileDataSource(DataSource): """Represents a data source within the local filesystem.""" - # Blocklist of sensitive directories that should not be accessible - RESTRICTED_PATHS = [ - os.path.abspath(os.path.expanduser("~/.aws")), - os.path.abspath(os.path.expanduser("~/.ssh")), - os.path.abspath(os.path.expanduser("~/.kube")), - os.path.abspath(os.path.expanduser("~/.docker")), - os.path.abspath(os.path.expanduser("~/.config")), - os.path.abspath(os.path.expanduser("~/.credentials")), - "/etc", - "/root", - "/home", - "/var/lib", - "/opt/ml/metadata", - ] - def __init__(self, root_path): super(LocalFileDataSource, self).__init__() self.root_path = os.path.abspath(root_path) # Validate that the path is not in restricted locations - for restricted_path in self.RESTRICTED_PATHS: + for restricted_path in _SENSITIVE_SYSTEM_PATHS: if self.root_path.startswith(restricted_path): raise ValueError( f"Local Mode does not support mounting from restricted system paths. " From ca2da0932929e3aac7b82a63ad3ba762ca4bb476 Mon Sep 17 00:00:00 2001 From: aviruthen <91846056+aviruthen@users.noreply.github.com> Date: Thu, 18 Dec 2025 12:17:09 -0800 Subject: [PATCH 3/5] Removing home path and adding additional validaiton --- .../src/sagemaker/core/common_utils.py | 10 +- .../tests/unit/test_common_utils.py | 290 ++++++++++++++++++ 2 files changed, 299 insertions(+), 1 deletion(-) diff --git a/sagemaker-core/src/sagemaker/core/common_utils.py b/sagemaker-core/src/sagemaker/core/common_utils.py index 202c266144..2f394ddcbf 100644 --- a/sagemaker-core/src/sagemaker/core/common_utils.py +++ b/sagemaker-core/src/sagemaker/core/common_utils.py @@ -85,7 +85,6 @@ abspath(os.path.expanduser("~/.credentials")), "/etc", "/root", - "/home", "/var/lib", "/opt/ml/metadata", ] @@ -680,6 +679,15 @@ def _create_or_update_code_dir( ): """Placeholder docstring""" code_dir = os.path.join(model_dir, "code") + resolved_code_dir = _get_resolved_path(code_dir) + + # Validate that code_dir does not resolve to a sensitive system path + for sensitive_path in _SENSITIVE_SYSTEM_PATHS: + if resolved_code_dir.startswith(sensitive_path): + raise ValueError( + f"Invalid code_dir path: {code_dir} resolves to sensitive system path {resolved_code_dir}" + ) + if source_directory and source_directory.lower().startswith("s3://"): local_code_path = os.path.join(tmp, "local_code.tar.gz") download_file_from_url(source_directory, local_code_path, sagemaker_session) diff --git a/sagemaker-core/tests/unit/test_common_utils.py b/sagemaker-core/tests/unit/test_common_utils.py index 4052c46f26..291b667191 100644 --- a/sagemaker-core/tests/unit/test_common_utils.py +++ b/sagemaker-core/tests/unit/test_common_utils.py @@ -2209,3 +2209,293 @@ def test_nested_set_dict_multiple_keys(self): d = {} nested_set_dict(d, ["a", "b", "c"], "value") assert d["a"]["b"]["c"] == "value" + + + +class TestValidateSourceDirectory: + """Test _validate_source_directory function.""" + + def test_validate_source_directory_none(self): + """Test with None source directory.""" + from sagemaker.core.common_utils import _validate_source_directory + + # Should not raise + _validate_source_directory(None) + + def test_validate_source_directory_s3_path(self): + """Test with S3 path.""" + from sagemaker.core.common_utils import _validate_source_directory + + # Should not raise for S3 paths + _validate_source_directory("s3://my-bucket/my-code") + + def test_validate_source_directory_valid_local_path(self): + """Test with valid local path.""" + from sagemaker.core.common_utils import _validate_source_directory + + with tempfile.TemporaryDirectory() as tmpdir: + # Should not raise for valid local paths + _validate_source_directory(tmpdir) + + def test_validate_source_directory_sensitive_path_aws(self): + """Test rejection of ~/.aws path.""" + from sagemaker.core.common_utils import _validate_source_directory + + aws_dir = os.path.expanduser("~/.aws") + if os.path.exists(aws_dir): + with pytest.raises(ValueError, match="cannot access sensitive system paths"): + _validate_source_directory(aws_dir) + + def test_validate_source_directory_sensitive_path_ssh(self): + """Test rejection of ~/.ssh path.""" + from sagemaker.core.common_utils import _validate_source_directory + + ssh_dir = os.path.expanduser("~/.ssh") + if os.path.exists(ssh_dir): + with pytest.raises(ValueError, match="cannot access sensitive system paths"): + _validate_source_directory(ssh_dir) + + def test_validate_source_directory_sensitive_path_root(self): + """Test rejection of /root path.""" + from sagemaker.core.common_utils import _validate_source_directory + + # Test with /root which is a sensitive path + if os.path.exists("/root") and os.access("/root", os.R_OK): + with pytest.raises(ValueError, match="cannot access sensitive system paths"): + _validate_source_directory("/root") + + def test_validate_source_directory_symlink_resolution(self): + """Test that symlinks are resolved correctly.""" + from sagemaker.core.common_utils import _validate_source_directory + + with tempfile.TemporaryDirectory() as tmpdir: + # Create a real directory + real_dir = os.path.join(tmpdir, "real_code") + os.makedirs(real_dir) + + # Create a symlink to it + symlink_path = os.path.join(tmpdir, "link_to_code") + os.symlink(real_dir, symlink_path) + + # Should not raise - symlink should be resolved and validated + _validate_source_directory(symlink_path) + + +class TestValidateDependencyPath: + """Test _validate_dependency_path function.""" + + def test_validate_dependency_path_none(self): + """Test with None dependency.""" + from sagemaker.core.common_utils import _validate_dependency_path + + # Should not raise + _validate_dependency_path(None) + + def test_validate_dependency_path_valid_local_path(self): + """Test with valid local path.""" + from sagemaker.core.common_utils import _validate_dependency_path + + with tempfile.TemporaryDirectory() as tmpdir: + # Should not raise for valid local paths + _validate_dependency_path(tmpdir) + + def test_validate_dependency_path_sensitive_path_aws(self): + """Test rejection of ~/.aws path.""" + from sagemaker.core.common_utils import _validate_dependency_path + + aws_dir = os.path.expanduser("~/.aws") + if os.path.exists(aws_dir): + with pytest.raises(ValueError, match="cannot access sensitive system paths"): + _validate_dependency_path(aws_dir) + + def test_validate_dependency_path_sensitive_path_credentials(self): + """Test rejection of ~/.credentials path.""" + from sagemaker.core.common_utils import _validate_dependency_path + + creds_dir = os.path.expanduser("~/.credentials") + if os.path.exists(creds_dir): + with pytest.raises(ValueError, match="cannot access sensitive system paths"): + _validate_dependency_path(creds_dir) + + def test_validate_dependency_path_symlink_resolution(self): + """Test that symlinks are resolved correctly.""" + from sagemaker.core.common_utils import _validate_dependency_path + + with tempfile.TemporaryDirectory() as tmpdir: + # Create a real directory + real_dir = os.path.join(tmpdir, "real_lib") + os.makedirs(real_dir) + + # Create a symlink to it + symlink_path = os.path.join(tmpdir, "link_to_lib") + os.symlink(real_dir, symlink_path) + + # Should not raise - symlink should be resolved and validated + _validate_dependency_path(symlink_path) + + +class TestCreateOrUpdateCodeDir: + """Test _create_or_update_code_dir function.""" + + def test_create_or_update_code_dir_basic(self): + """Test basic code directory creation.""" + from sagemaker.core.common_utils import _create_or_update_code_dir + + with tempfile.TemporaryDirectory() as tmpdir: + model_dir = os.path.join(tmpdir, "model") + os.makedirs(model_dir) + + inference_script = os.path.join(tmpdir, "inference.py") + with open(inference_script, "w") as f: + f.write("# inference code") + + # Should create code directory and copy inference script + _create_or_update_code_dir( + model_dir, + inference_script, + None, + [], + None, + tmpdir, + ) + + code_dir = os.path.join(model_dir, "code") + assert os.path.exists(code_dir) + assert os.path.exists(os.path.join(code_dir, "inference.py")) + + def test_create_or_update_code_dir_with_source_directory(self): + """Test code directory creation with source directory.""" + from sagemaker.core.common_utils import _create_or_update_code_dir + + with tempfile.TemporaryDirectory() as tmpdir: + model_dir = os.path.join(tmpdir, "model") + os.makedirs(model_dir) + + source_dir = os.path.join(tmpdir, "source") + os.makedirs(source_dir) + with open(os.path.join(source_dir, "app.py"), "w") as f: + f.write("# app code") + + # Should copy source directory to code directory + _create_or_update_code_dir( + model_dir, + "inference.py", + source_dir, + [], + None, + tmpdir, + ) + + code_dir = os.path.join(model_dir, "code") + assert os.path.exists(code_dir) + assert os.path.exists(os.path.join(code_dir, "app.py")) + + def test_create_or_update_code_dir_with_dependencies(self): + """Test code directory creation with dependencies.""" + from sagemaker.core.common_utils import _create_or_update_code_dir + + with tempfile.TemporaryDirectory() as tmpdir: + model_dir = os.path.join(tmpdir, "model") + os.makedirs(model_dir) + + inference_script = os.path.join(tmpdir, "inference.py") + with open(inference_script, "w") as f: + f.write("# inference code") + + dep_dir = os.path.join(tmpdir, "my_lib") + os.makedirs(dep_dir) + with open(os.path.join(dep_dir, "helper.py"), "w") as f: + f.write("# helper code") + + # Should create code directory with dependencies + _create_or_update_code_dir( + model_dir, + inference_script, + None, + [dep_dir], + None, + tmpdir, + ) + + code_dir = os.path.join(model_dir, "code") + lib_dir = os.path.join(code_dir, "lib") + assert os.path.exists(lib_dir) + assert os.path.exists(os.path.join(lib_dir, "my_lib", "helper.py")) + + def test_create_or_update_code_dir_rejects_sensitive_paths(self): + """Test that code_dir validation rejects sensitive system paths.""" + from sagemaker.core.common_utils import _create_or_update_code_dir + + with tempfile.TemporaryDirectory() as tmpdir: + # Create a model_dir that would resolve to a sensitive path + # This is tricky to test without mocking, so we'll mock _get_resolved_path + with patch("sagemaker.core.common_utils._get_resolved_path") as mock_resolve: + mock_resolve.return_value = "/etc" + + inference_script = os.path.join(tmpdir, "inference.py") + with open(inference_script, "w") as f: + f.write("# inference code") + + model_dir = os.path.join(tmpdir, "model") + os.makedirs(model_dir) + + # Should raise ValueError for sensitive path + with pytest.raises(ValueError, match="Invalid code_dir path"): + _create_or_update_code_dir( + model_dir, + "inference.py", + None, + [], + None, + tmpdir, + ) + + def test_create_or_update_code_dir_validates_source_directory(self): + """Test that source_directory is validated.""" + from sagemaker.core.common_utils import _create_or_update_code_dir + + with tempfile.TemporaryDirectory() as tmpdir: + model_dir = os.path.join(tmpdir, "model") + os.makedirs(model_dir) + + inference_script = os.path.join(tmpdir, "inference.py") + with open(inference_script, "w") as f: + f.write("# inference code") + + # Try to use a sensitive path as source_directory + aws_dir = os.path.expanduser("~/.aws") + if os.path.exists(aws_dir): + with pytest.raises(ValueError, match="cannot access sensitive system paths"): + _create_or_update_code_dir( + model_dir, + "inference.py", + aws_dir, + [], + None, + tmpdir, + ) + + def test_create_or_update_code_dir_validates_dependencies(self): + """Test that dependencies are validated.""" + from sagemaker.core.common_utils import _create_or_update_code_dir + + with tempfile.TemporaryDirectory() as tmpdir: + model_dir = os.path.join(tmpdir, "model") + os.makedirs(model_dir) + + inference_script = os.path.join(tmpdir, "inference.py") + with open(inference_script, "w") as f: + f.write("# inference code") + + # Try to use a sensitive path as dependency + aws_dir = os.path.expanduser("~/.aws") + if os.path.exists(aws_dir): + with pytest.raises(ValueError, match="cannot access sensitive system paths"): + _create_or_update_code_dir( + model_dir, + inference_script, + None, + [aws_dir], + None, + tmpdir, + ) From 4b8adb73d14de40547f996b347c44f71c090186d Mon Sep 17 00:00:00 2001 From: aviruthen <91846056+aviruthen@users.noreply.github.com> Date: Thu, 18 Dec 2025 12:44:09 -0800 Subject: [PATCH 4/5] Including check for root directory --- sagemaker-core/src/sagemaker/core/common_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sagemaker-core/src/sagemaker/core/common_utils.py b/sagemaker-core/src/sagemaker/core/common_utils.py index 2f394ddcbf..586be0abf5 100644 --- a/sagemaker-core/src/sagemaker/core/common_utils.py +++ b/sagemaker-core/src/sagemaker/core/common_utils.py @@ -641,7 +641,7 @@ def _validate_source_directory(source_directory): # Check if the source path is under any sensitive directory for sensitive_path in _SENSITIVE_SYSTEM_PATHS: - if abs_source.startswith(sensitive_path): + if abs_source != "/" and abs_source.startswith(sensitive_path): raise ValueError( f"source_directory cannot access sensitive system paths. " f"Got: {source_directory} (resolved to {abs_source})" @@ -667,7 +667,7 @@ def _validate_dependency_path(dependency): # Check if the dependency path is under any sensitive directory for sensitive_path in _SENSITIVE_SYSTEM_PATHS: - if abs_dependency.startswith(sensitive_path): + if abs_dependency != "/" and abs_dependency.startswith(sensitive_path): raise ValueError( f"dependency path cannot access sensitive system paths. " f"Got: {dependency} (resolved to {abs_dependency})" From 1202579c1814513fc9482f36716c93966cf596cf Mon Sep 17 00:00:00 2001 From: aviruthen <91846056+aviruthen@users.noreply.github.com> Date: Thu, 18 Dec 2025 13:02:51 -0800 Subject: [PATCH 5/5] Adding root directory validation to other helpers --- sagemaker-core/src/sagemaker/core/common_utils.py | 2 +- sagemaker-core/src/sagemaker/core/local/data.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sagemaker-core/src/sagemaker/core/common_utils.py b/sagemaker-core/src/sagemaker/core/common_utils.py index 586be0abf5..bac0d33925 100644 --- a/sagemaker-core/src/sagemaker/core/common_utils.py +++ b/sagemaker-core/src/sagemaker/core/common_utils.py @@ -683,7 +683,7 @@ def _create_or_update_code_dir( # Validate that code_dir does not resolve to a sensitive system path for sensitive_path in _SENSITIVE_SYSTEM_PATHS: - if resolved_code_dir.startswith(sensitive_path): + if resolved_code_dir != "/" and resolved_code_dir.startswith(sensitive_path): raise ValueError( f"Invalid code_dir path: {code_dir} resolves to sensitive system path {resolved_code_dir}" ) diff --git a/sagemaker-core/src/sagemaker/core/local/data.py b/sagemaker-core/src/sagemaker/core/local/data.py index 6d011bee78..087e06c0bc 100644 --- a/sagemaker-core/src/sagemaker/core/local/data.py +++ b/sagemaker-core/src/sagemaker/core/local/data.py @@ -124,7 +124,7 @@ def __init__(self, root_path): # Validate that the path is not in restricted locations for restricted_path in _SENSITIVE_SYSTEM_PATHS: - if self.root_path.startswith(restricted_path): + if self.root_path != "/" and self.root_path.startswith(restricted_path): raise ValueError( f"Local Mode does not support mounting from restricted system paths. " f"Got: {root_path}"