-
Notifications
You must be signed in to change notification settings - Fork 1.2k
Add input validation and resource management improvements V3 #5418
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
532fe77
4ef8cf0
d2151c2
849b169
8ccb2aa
0d451cd
ca2da09
a7a0563
4b8adb7
1202579
b967400
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -607,6 +607,95 @@ def _save_model(repacked_model_uri, tmp_model_path, sagemaker_session, kms_key): | |
| shutil.move(tmp_model_path, repacked_model_uri.replace("file://", "")) | ||
|
|
||
|
|
||
def _sensitive_path_roots():
    """Return the directory roots that local inputs must not live under.

    The list is rebuilt on every call so that the current home directory is
    used for the ``~``-relative entries.

    Returns:
        list[str]: Absolute directory paths, in both their literal and
            symlink-resolved spellings, without duplicates.
    """
    home_entries = (".aws", ".ssh", ".kube", ".docker", ".config", ".credentials")
    roots = [
        os.path.abspath(os.path.expanduser(os.path.join("~", entry)))
        for entry in home_entries
    ]
    # NOTE(review): "/home" blocks every path beneath it, which includes ordinary
    # project checkouts inside users' home directories on typical Linux hosts.
    # Confirm that this is intended.
    roots += ["/etc", "/root", "/home", "/var/lib", "/opt/ml/metadata"]

    # Candidates are compared in their symlink-resolved form as well, so the roots
    # need a resolved spelling too (for example, on macOS "/etc" is a symlink to
    # "/private/etc").
    resolved_roots = [os.path.realpath(root) for root in roots]

    # dict.fromkeys de-duplicates while preserving order.
    return list(dict.fromkeys(roots + resolved_roots))


def _is_same_or_child(path, root):
    """Return True when ``path`` is ``root`` itself or lies beneath it.

    The comparison is made on whole path components, so ``/etcetera`` is not
    considered to be inside ``/etc``.
    """
    trimmed = root.rstrip(os.sep)
    return path == trimmed or path.startswith(trimmed + os.sep)


def _reject_sensitive_path(path, label):
    """Raise ValueError if ``path`` points into a sensitive directory.

    Both the absolute path as written and its symlink-resolved form are checked,
    so a path that reaches a sensitive directory through a symlink (at any
    component) is rejected, while symlinks to harmless locations are accepted.

    Args:
        path (str): The local path to check.
        label (str): How the path is named in the error message.

    Raises:
        ValueError: If the path is, or resolves to, a location inside one of the
            sensitive directories.
    """
    roots = _sensitive_path_roots()
    candidates = dict.fromkeys((os.path.abspath(path), os.path.realpath(path)))
    for candidate in candidates:
        if any(_is_same_or_child(candidate, root) for root in roots):
            raise ValueError(
                f"{label} cannot access sensitive system paths. "
                f"Got: {path} (resolved to {candidate})"
            )


def _validate_source_directory(source_directory):
    """Validate that source_directory is safe to use.

    Ensures the source directory path does not access restricted system locations.
    Empty values and S3 URIs are accepted without further checks. Symlinks are
    allowed; they are resolved and the resolved location is validated.

    Args:
        source_directory (str): The source directory path to validate.

    Raises:
        ValueError: If the path is not allowed.
    """
    if not source_directory:
        return

    # Only strings can be S3 URIs; other path-like objects go straight to the
    # local-path check instead of failing on ``.lower()``.
    if isinstance(source_directory, str) and source_directory.lower().startswith("s3://"):
        return

    _reject_sensitive_path(source_directory, "source_directory")


def _validate_dependency_path(dependency):
    """Validate that a dependency path is safe to use.

    Ensures the dependency path does not access restricted system locations.
    Empty values are accepted without further checks. Symlinks are allowed; they
    are resolved and the resolved location is validated.

    Args:
        dependency (str): The dependency path to validate.

    Raises:
        ValueError: If the path is not allowed.
    """
    if not dependency:
        return

    _reject_sensitive_path(dependency, "dependency path")
|
|
||
|
|
||
| def _create_or_update_code_dir( | ||
| model_dir, inference_script, source_directory, dependencies, sagemaker_session, tmp | ||
| ): | ||
|
|
@@ -620,6 +709,8 @@ def _create_or_update_code_dir( | |
| custom_extractall_tarfile(t, code_dir) | ||
|
|
||
| elif source_directory: | ||
| # Validate source_directory for security | ||
| _validate_source_directory(source_directory) | ||
| if os.path.exists(code_dir): | ||
| shutil.rmtree(code_dir) | ||
| shutil.copytree(source_directory, code_dir) | ||
|
|
@@ -635,6 +726,8 @@ def _create_or_update_code_dir( | |
| raise | ||
|
|
||
| for dependency in dependencies: | ||
| # Validate dependency path for security | ||
| _validate_dependency_path(dependency) | ||
| lib_dir = os.path.join(code_dir, "lib") | ||
| if os.path.isdir(dependency): | ||
| shutil.copytree(dependency, os.path.join(lib_dir, os.path.basename(dependency))) | ||
|
|
@@ -1646,6 +1739,38 @@ def _get_safe_members(members): | |
| yield file_info | ||
|
|
||
|
|
||
def _validate_extracted_paths(extract_path):
    """Validate that extracted paths remain within the expected directory.

    Performs post-extraction validation: walks ``extract_path`` and checks that
    every directory and file found, once resolved, is still located inside the
    resolved extraction path.

    Args:
        extract_path (str): The path where files were extracted.

    Raises:
        ValueError: If any extracted file or directory resolves to a location
            outside the expected extraction path.
    """
    # Assumes _get_resolved_path returns an absolute, symlink-resolved string
    # path -- TODO confirm against its definition.
    base = _get_resolved_path(extract_path)
    base_root = base.rstrip(os.sep)
    # Compare against "<base>/" rather than "<base>" so that a sibling such as
    # "<base>-other/file" is not mistaken for something inside the base directory.
    base_prefix = base_root + os.sep

    for root, dirs, files in os.walk(extract_path):
        # Directories are checked before files, one walk level at a time.
        for kind, names in (("directory", dirs), ("file", files)):
            for name in names:
                entry_path = os.path.join(root, name)
                resolved = _get_resolved_path(entry_path)
                if resolved != base_root and not resolved.startswith(base_prefix):
                    logger.error("Extracted %s escaped extraction path: %s", kind, entry_path)
                    raise ValueError(f"Extracted path outside expected directory: {entry_path}")
|
|
||
|
|
||
| def custom_extractall_tarfile(tar, extract_path): | ||
| """Extract a tarfile, optionally using data_filter if available. | ||
|
|
||
|
|
@@ -1666,6 +1791,8 @@ def custom_extractall_tarfile(tar, extract_path): | |
| tar.extractall(path=extract_path, filter="data") | ||
| else: | ||
| tar.extractall(path=extract_path, members=_get_safe_members(tar)) | ||
| # Re-validate extracted paths to catch symlink race conditions | ||
| _validate_extracted_paths(extract_path) | ||
|
|
||
|
|
||
| def can_model_package_source_uri_autopopulate(source_uri: str): | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -114,6 +114,9 @@ def __next__(self): | |
| class LineIterator(BaseIterator): | ||
| """A helper class for parsing the byte Event Stream input to provide Line iteration.""" | ||
|
|
||
| # Maximum buffer size to prevent unbounded memory consumption (10 MB) | ||
| MAX_BUFFER_SIZE = 10 * 1024 * 1024 | ||
|
||
|
|
||
| def __init__(self, event_stream): | ||
| """Initialises a LineIterator Iterator object | ||
|
|
||
|
|
@@ -182,5 +185,15 @@ def __next__(self): | |
| # print and move on to next response byte | ||
| print("Unknown event type:" + chunk) | ||
| continue | ||
|
|
||
| # Check buffer size before writing to prevent unbounded memory consumption | ||
| chunk_size = len(chunk["PayloadPart"]["Bytes"]) | ||
| current_size = self.buffer.getbuffer().nbytes | ||
| if current_size + chunk_size > self.MAX_BUFFER_SIZE: | ||
| raise RuntimeError( | ||
| f"Line buffer exceeded maximum size of {self.MAX_BUFFER_SIZE} bytes. " | ||
| f"No newline found in stream." | ||
| ) | ||
|
|
||
| self.buffer.seek(0, io.SEEK_END) | ||
| self.buffer.write(chunk["PayloadPart"]["Bytes"]) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just curious: are symlinks not supported for the source path?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
They should be supported... and this is a good point! Now that I think about it, we don't have to prevent symlinks completely; we just need to resolve the symlink with os.path.realpath() before validation. This also prevents the race-condition bug. I will make that change.