_MAX_BUFFER_SIZE = 100 * 1024 * 1024  # 100 MB - Maximum buffer size for streaming iterators

# Directories that user-supplied local paths must never point at or into.
_SENSITIVE_SYSTEM_PATHS = [
    abspath(os.path.expanduser("~/.aws")),
    abspath(os.path.expanduser("~/.ssh")),
    abspath(os.path.expanduser("~/.kube")),
    abspath(os.path.expanduser("~/.docker")),
    abspath(os.path.expanduser("~/.config")),
    abspath(os.path.expanduser("~/.credentials")),
    "/etc",
    "/root",
    "/var/lib",
    "/opt/ml/metadata",
]


def _is_under_sensitive_path(resolved_path):
    """Report whether ``resolved_path`` is, or lies inside, a sensitive directory.

    The comparison is made on whole path components, so a sibling that merely
    shares a textual prefix (``/etcetera`` vs ``/etc``) is not a match.

    Args:
        resolved_path (str): An absolute path with symlinks already resolved.

    Returns:
        bool: True if the path equals an entry of ``_SENSITIVE_SYSTEM_PATHS``
        (or its symlink-resolved form) or is located below one.
    """
    for sensitive_path in _SENSITIVE_SYSTEM_PATHS:
        # Callers pass a realpath()-resolved candidate. Also compare against the
        # resolved form of the listed directory, otherwise a sensitive directory
        # that is itself reached through a symlink would never match.
        for root in (sensitive_path, realpath(sensitive_path)):
            prefix = root if root.endswith(os.sep) else root + os.sep
            if resolved_path == root or resolved_path.startswith(prefix):
                return True
    return False


def _validate_source_directory(source_directory):
    """Validate that source_directory is safe to use.

    Ensures the source directory path does not access restricted system locations.
    Falsy values and ``s3://`` URIs (case-insensitive scheme) are accepted without
    further checks; local paths are resolved through symlinks before being compared
    with ``_SENSITIVE_SYSTEM_PATHS``.

    Args:
        source_directory (str): The source directory path to validate.

    Raises:
        ValueError: If the resolved path is a sensitive directory or is inside one.
    """
    if not source_directory or source_directory.lower().startswith("s3://"):
        # S3 paths and None are safe
        return

    # Resolve symlinks to get the actual path
    abs_source = abspath(realpath(source_directory))

    if _is_under_sensitive_path(abs_source):
        raise ValueError(
            f"source_directory cannot access sensitive system paths. "
            f"Got: {source_directory} (resolved to {abs_source})"
        )


def _validate_dependency_path(dependency):
    """Validate that a dependency path is safe to use.

    Ensures the dependency path does not access restricted system locations.
    Falsy values are accepted; any other path is resolved through symlinks before
    being compared with ``_SENSITIVE_SYSTEM_PATHS``.

    Args:
        dependency (str): The dependency path to validate.

    Raises:
        ValueError: If the resolved path is a sensitive directory or is inside one.
    """
    if not dependency:
        return

    # Resolve symlinks to get the actual path
    abs_dependency = abspath(realpath(dependency))

    if _is_under_sensitive_path(abs_dependency):
        raise ValueError(
            f"dependency path cannot access sensitive system paths. "
            f"Got: {dependency} (resolved to {abs_dependency})"
        )
def _validate_extracted_paths(extract_path):
    """Validate that extracted paths remain within the expected directory.

    Performs post-extraction validation: every directory and file name reported by
    ``os.walk(extract_path)`` is resolved with ``_get_resolved_path`` and must be
    the resolved extraction directory itself or be located below it. The check is
    made on whole path components, so a sibling such as ``<extract_path>_other``
    is treated as outside.

    Args:
        extract_path (str): The path where files were extracted.

    Raises:
        ValueError: If any extracted file is outside the expected extraction path.
    """
    base = _get_resolved_path(extract_path)
    # A trailing separator stops "<base>_suffix/..." from passing as inside base.
    base_prefix = base if base.endswith(os.sep) else base + os.sep

    for root, dirs, files in os.walk(extract_path):
        checks = (
            ("Extracted directory escaped extraction path: %s", dirs),
            ("Extracted file escaped extraction path: %s", files),
        )
        for log_message, names in checks:
            for name in names:
                entry_path = os.path.join(root, name)
                resolved = _get_resolved_path(entry_path)
                if resolved != base and not resolved.startswith(base_prefix):
                    logger.error(log_message, entry_path)
                    raise ValueError(f"Extracted path outside expected directory: {entry_path}")
@@ -1666,6 +1778,8 @@ def custom_extractall_tarfile(tar, extract_path): tar.extractall(path=extract_path, filter="data") else: tar.extractall(path=extract_path, members=_get_safe_members(tar)) + # Re-validate extracted paths to catch symlink race conditions + _validate_extracted_paths(extract_path) def can_model_package_source_uri_autopopulate(source_uri: str): diff --git a/sagemaker-core/src/sagemaker/core/iterators.py b/sagemaker-core/src/sagemaker/core/iterators.py index a921742764..60914cbdd0 100644 --- a/sagemaker-core/src/sagemaker/core/iterators.py +++ b/sagemaker-core/src/sagemaker/core/iterators.py @@ -17,6 +17,7 @@ import io from sagemaker.core.exceptions import ModelStreamError, InternalStreamFailure +from sagemaker.core.common_utils import _MAX_BUFFER_SIZE def handle_stream_errors(chunk): @@ -182,5 +183,15 @@ def __next__(self): # print and move on to next response byte print("Unknown event type:" + chunk) continue + + # Check buffer size before writing to prevent unbounded memory consumption + chunk_size = len(chunk["PayloadPart"]["Bytes"]) + current_size = self.buffer.getbuffer().nbytes + if current_size + chunk_size > _MAX_BUFFER_SIZE: + raise RuntimeError( + f"Line buffer exceeded maximum size of {_MAX_BUFFER_SIZE} bytes. " + f"No newline found in stream." 
+ ) + self.buffer.seek(0, io.SEEK_END) self.buffer.write(chunk["PayloadPart"]["Bytes"]) diff --git a/sagemaker-core/src/sagemaker/core/local/data.py b/sagemaker-core/src/sagemaker/core/local/data.py index c3113835ca..087e06c0bc 100644 --- a/sagemaker-core/src/sagemaker/core/local/data.py +++ b/sagemaker-core/src/sagemaker/core/local/data.py @@ -24,6 +24,7 @@ from six.moves.urllib.parse import urlparse import sagemaker.core +from sagemaker.core.common_utils import _SENSITIVE_SYSTEM_PATHS def get_data_source_instance(data_source, sagemaker_session): @@ -120,6 +121,15 @@ def __init__(self, root_path): super(LocalFileDataSource, self).__init__() self.root_path = os.path.abspath(root_path) + + # Validate that the path is not in restricted locations + for restricted_path in _SENSITIVE_SYSTEM_PATHS: + if self.root_path != "/" and self.root_path.startswith(restricted_path): + raise ValueError( + f"Local Mode does not support mounting from restricted system paths. " + f"Got: {root_path}" + ) + if not os.path.exists(self.root_path): raise RuntimeError("Invalid data source: %s does not exist." 
def _requires_existing_dir(path):
    """Build a skip marker for tests that are only exercised when ``path`` exists.

    Using an explicit skip (instead of an ``if os.path.exists(...)`` guard inside
    the test) makes unexercised checks show up as skipped rather than passed.
    """
    return pytest.mark.skipif(
        not os.path.exists(path), reason=f"{path} does not exist on this machine"
    )


class TestValidateSourceDirectory:
    """Test _validate_source_directory function."""

    def test_validate_source_directory_none(self):
        """Test with None source directory."""
        from sagemaker.core.common_utils import _validate_source_directory

        # Should not raise
        _validate_source_directory(None)

    def test_validate_source_directory_s3_path(self):
        """Test with S3 path."""
        from sagemaker.core.common_utils import _validate_source_directory

        # Should not raise for S3 paths
        _validate_source_directory("s3://my-bucket/my-code")

    def test_validate_source_directory_valid_local_path(self):
        """Test with valid local path."""
        from sagemaker.core.common_utils import _validate_source_directory

        with tempfile.TemporaryDirectory() as tmpdir:
            # Should not raise for valid local paths
            _validate_source_directory(tmpdir)

    @_requires_existing_dir(os.path.expanduser("~/.aws"))
    def test_validate_source_directory_sensitive_path_aws(self):
        """Test rejection of ~/.aws path."""
        from sagemaker.core.common_utils import _validate_source_directory

        aws_dir = os.path.expanduser("~/.aws")
        with pytest.raises(ValueError, match="cannot access sensitive system paths"):
            _validate_source_directory(aws_dir)

    @_requires_existing_dir(os.path.expanduser("~/.ssh"))
    def test_validate_source_directory_sensitive_path_ssh(self):
        """Test rejection of ~/.ssh path."""
        from sagemaker.core.common_utils import _validate_source_directory

        ssh_dir = os.path.expanduser("~/.ssh")
        with pytest.raises(ValueError, match="cannot access sensitive system paths"):
            _validate_source_directory(ssh_dir)

    @pytest.mark.skipif(
        not (os.path.exists("/root") and os.access("/root", os.R_OK)),
        reason="/root does not exist or is not readable on this machine",
    )
    def test_validate_source_directory_sensitive_path_root(self):
        """Test rejection of /root path."""
        from sagemaker.core.common_utils import _validate_source_directory

        # /root is listed as a sensitive path
        with pytest.raises(ValueError, match="cannot access sensitive system paths"):
            _validate_source_directory("/root")

    def test_validate_source_directory_symlink_resolution(self):
        """Test that symlinks are resolved correctly."""
        from sagemaker.core.common_utils import _validate_source_directory

        with tempfile.TemporaryDirectory() as tmpdir:
            # Create a real directory
            real_dir = os.path.join(tmpdir, "real_code")
            os.makedirs(real_dir)

            # Create a symlink to it
            symlink_path = os.path.join(tmpdir, "link_to_code")
            os.symlink(real_dir, symlink_path)

            # Should not raise - symlink should be resolved and validated
            _validate_source_directory(symlink_path)


class TestValidateDependencyPath:
    """Test _validate_dependency_path function."""

    def test_validate_dependency_path_none(self):
        """Test with None dependency."""
        from sagemaker.core.common_utils import _validate_dependency_path

        # Should not raise
        _validate_dependency_path(None)

    def test_validate_dependency_path_valid_local_path(self):
        """Test with valid local path."""
        from sagemaker.core.common_utils import _validate_dependency_path

        with tempfile.TemporaryDirectory() as tmpdir:
            # Should not raise for valid local paths
            _validate_dependency_path(tmpdir)

    @_requires_existing_dir(os.path.expanduser("~/.aws"))
    def test_validate_dependency_path_sensitive_path_aws(self):
        """Test rejection of ~/.aws path."""
        from sagemaker.core.common_utils import _validate_dependency_path

        aws_dir = os.path.expanduser("~/.aws")
        with pytest.raises(ValueError, match="cannot access sensitive system paths"):
            _validate_dependency_path(aws_dir)

    @_requires_existing_dir(os.path.expanduser("~/.credentials"))
    def test_validate_dependency_path_sensitive_path_credentials(self):
        """Test rejection of ~/.credentials path."""
        from sagemaker.core.common_utils import _validate_dependency_path

        creds_dir = os.path.expanduser("~/.credentials")
        with pytest.raises(ValueError, match="cannot access sensitive system paths"):
            _validate_dependency_path(creds_dir)

    def test_validate_dependency_path_symlink_resolution(self):
        """Test that symlinks are resolved correctly."""
        from sagemaker.core.common_utils import _validate_dependency_path

        with tempfile.TemporaryDirectory() as tmpdir:
            # Create a real directory
            real_dir = os.path.join(tmpdir, "real_lib")
            os.makedirs(real_dir)

            # Create a symlink to it
            symlink_path = os.path.join(tmpdir, "link_to_lib")
            os.symlink(real_dir, symlink_path)

            # Should not raise - symlink should be resolved and validated
            _validate_dependency_path(symlink_path)


class TestCreateOrUpdateCodeDir:
    """Test _create_or_update_code_dir function."""

    def test_create_or_update_code_dir_basic(self):
        """Test basic code directory creation."""
        from sagemaker.core.common_utils import _create_or_update_code_dir

        with tempfile.TemporaryDirectory() as tmpdir:
            model_dir = os.path.join(tmpdir, "model")
            os.makedirs(model_dir)

            inference_script = os.path.join(tmpdir, "inference.py")
            with open(inference_script, "w") as f:
                f.write("# inference code")

            # Should create code directory and copy inference script
            _create_or_update_code_dir(
                model_dir,
                inference_script,
                None,
                [],
                None,
                tmpdir,
            )

            code_dir = os.path.join(model_dir, "code")
            assert os.path.exists(code_dir)
            assert os.path.exists(os.path.join(code_dir, "inference.py"))

    def test_create_or_update_code_dir_with_source_directory(self):
        """Test code directory creation with source directory."""
        from sagemaker.core.common_utils import _create_or_update_code_dir

        with tempfile.TemporaryDirectory() as tmpdir:
            model_dir = os.path.join(tmpdir, "model")
            os.makedirs(model_dir)

            source_dir = os.path.join(tmpdir, "source")
            os.makedirs(source_dir)
            with open(os.path.join(source_dir, "app.py"), "w") as f:
                f.write("# app code")

            # Should copy source directory to code directory
            _create_or_update_code_dir(
                model_dir,
                "inference.py",
                source_dir,
                [],
                None,
                tmpdir,
            )

            code_dir = os.path.join(model_dir, "code")
            assert os.path.exists(code_dir)
            assert os.path.exists(os.path.join(code_dir, "app.py"))

    def test_create_or_update_code_dir_with_dependencies(self):
        """Test code directory creation with dependencies."""
        from sagemaker.core.common_utils import _create_or_update_code_dir

        with tempfile.TemporaryDirectory() as tmpdir:
            model_dir = os.path.join(tmpdir, "model")
            os.makedirs(model_dir)

            inference_script = os.path.join(tmpdir, "inference.py")
            with open(inference_script, "w") as f:
                f.write("# inference code")

            dep_dir = os.path.join(tmpdir, "my_lib")
            os.makedirs(dep_dir)
            with open(os.path.join(dep_dir, "helper.py"), "w") as f:
                f.write("# helper code")

            # Should create code directory with dependencies
            _create_or_update_code_dir(
                model_dir,
                inference_script,
                None,
                [dep_dir],
                None,
                tmpdir,
            )

            code_dir = os.path.join(model_dir, "code")
            lib_dir = os.path.join(code_dir, "lib")
            assert os.path.exists(lib_dir)
            assert os.path.exists(os.path.join(lib_dir, "my_lib", "helper.py"))

    def test_create_or_update_code_dir_rejects_sensitive_paths(self):
        """Test that code_dir validation rejects sensitive system paths."""
        from sagemaker.core.common_utils import _create_or_update_code_dir

        with tempfile.TemporaryDirectory() as tmpdir:
            # A model_dir that really resolves to a sensitive path is hard to set up,
            # so _get_resolved_path is patched to report one.
            with patch("sagemaker.core.common_utils._get_resolved_path") as mock_resolve:
                mock_resolve.return_value = "/etc"

                inference_script = os.path.join(tmpdir, "inference.py")
                with open(inference_script, "w") as f:
                    f.write("# inference code")

                model_dir = os.path.join(tmpdir, "model")
                os.makedirs(model_dir)

                # Should raise ValueError for sensitive path
                with pytest.raises(ValueError, match="Invalid code_dir path"):
                    _create_or_update_code_dir(
                        model_dir,
                        "inference.py",
                        None,
                        [],
                        None,
                        tmpdir,
                    )

    @_requires_existing_dir(os.path.expanduser("~/.aws"))
    def test_create_or_update_code_dir_validates_source_directory(self):
        """Test that source_directory is validated."""
        from sagemaker.core.common_utils import _create_or_update_code_dir

        with tempfile.TemporaryDirectory() as tmpdir:
            model_dir = os.path.join(tmpdir, "model")
            os.makedirs(model_dir)

            inference_script = os.path.join(tmpdir, "inference.py")
            with open(inference_script, "w") as f:
                f.write("# inference code")

            # Try to use a sensitive path as source_directory
            aws_dir = os.path.expanduser("~/.aws")
            with pytest.raises(ValueError, match="cannot access sensitive system paths"):
                _create_or_update_code_dir(
                    model_dir,
                    "inference.py",
                    aws_dir,
                    [],
                    None,
                    tmpdir,
                )

    @_requires_existing_dir(os.path.expanduser("~/.aws"))
    def test_create_or_update_code_dir_validates_dependencies(self):
        """Test that dependencies are validated."""
        from sagemaker.core.common_utils import _create_or_update_code_dir

        with tempfile.TemporaryDirectory() as tmpdir:
            model_dir = os.path.join(tmpdir, "model")
            os.makedirs(model_dir)

            inference_script = os.path.join(tmpdir, "inference.py")
            with open(inference_script, "w") as f:
                f.write("# inference code")

            # Try to use a sensitive path as dependency
            aws_dir = os.path.expanduser("~/.aws")
            with pytest.raises(ValueError, match="cannot access sensitive system paths"):
                _create_or_update_code_dir(
                    model_dir,
                    inference_script,
                    None,
                    [aws_dir],
                    None,
                    tmpdir,
                )