@@ -607,6 +607,95 @@ def _save_model(repacked_model_uri, tmp_model_path, sagemaker_session, kms_key):
607607 shutil .move (tmp_model_path , repacked_model_uri .replace ("file://" , "" ))
608608
609609
610+ def _validate_source_directory (source_directory ):
611+ """Validate that source_directory is safe to use.
612+
613+ Ensures the source directory path does not access restricted system locations.
614+
615+ Args:
616+ source_directory (str): The source directory path to validate.
617+
618+ Raises:
619+ ValueError: If the path is not allowed.
620+ """
621+ if not source_directory or source_directory .lower ().startswith ("s3://" ):
622+ # S3 paths and None are safe
623+ return
624+
625+ abs_source = abspath (source_directory )
626+
627+ # Blocklist of sensitive directories that should not be accessible
628+ sensitive_paths = [
629+ abspath (os .path .expanduser ("~/.aws" )),
630+ abspath (os .path .expanduser ("~/.ssh" )),
631+ abspath (os .path .expanduser ("~/.kube" )),
632+ abspath (os .path .expanduser ("~/.docker" )),
633+ abspath (os .path .expanduser ("~/.config" )),
634+ abspath (os .path .expanduser ("~/.credentials" )),
635+ "/etc" ,
636+ "/root" ,
637+ "/home" ,
638+ "/var/lib" ,
639+ "/opt/ml/metadata" ,
640+ ]
641+
642+ # Check if the source path is under any sensitive directory
643+ for sensitive_path in sensitive_paths :
644+ if abs_source .startswith (sensitive_path ):
645+ raise ValueError (
646+ f"source_directory cannot access sensitive system paths. "
647+ f"Got: { source_directory } (resolved to { abs_source } )"
648+ )
649+
650+ # Check for symlinks to prevent symlink-based escapes
651+ if os .path .islink (abs_source ):
652+ raise ValueError (f"source_directory cannot be a symlink: { source_directory } " )
653+
654+
655+ def _validate_dependency_path (dependency ):
656+ """Validate that a dependency path is safe to use.
657+
658+ Ensures the dependency path does not access restricted system locations.
659+
660+ Args:
661+ dependency (str): The dependency path to validate.
662+
663+ Raises:
664+ ValueError: If the path is not allowed.
665+ """
666+ if not dependency :
667+ return
668+
669+ abs_dependency = abspath (dependency )
670+
671+ # Blocklist of sensitive directories that should not be accessible
672+ sensitive_paths = [
673+ abspath (os .path .expanduser ("~/.aws" )),
674+ abspath (os .path .expanduser ("~/.ssh" )),
675+ abspath (os .path .expanduser ("~/.kube" )),
676+ abspath (os .path .expanduser ("~/.docker" )),
677+ abspath (os .path .expanduser ("~/.config" )),
678+ abspath (os .path .expanduser ("~/.credentials" )),
679+ "/etc" ,
680+ "/root" ,
681+ "/home" ,
682+ "/var/lib" ,
683+ "/opt/ml/metadata" ,
684+ ]
685+
686+ # Check if the dependency path is under any sensitive directory
687+ for sensitive_path in sensitive_paths :
688+ if abs_dependency .startswith (sensitive_path ):
689+ raise ValueError (
690+ f"dependency path cannot access sensitive system paths. "
691+ f"Got: { dependency } (resolved to { abs_dependency } )"
692+ )
693+
694+ # Check for symlinks to prevent symlink-based escapes
695+ if os .path .islink (abs_dependency ):
696+ raise ValueError (f"dependency path cannot be a symlink: { dependency } " )
697+
698+
610699def _create_or_update_code_dir (
611700 model_dir , inference_script , source_directory , dependencies , sagemaker_session , tmp
612701):
@@ -620,6 +709,8 @@ def _create_or_update_code_dir(
620709 custom_extractall_tarfile (t , code_dir )
621710
622711 elif source_directory :
712+ # Validate source_directory for security
713+ _validate_source_directory (source_directory )
623714 if os .path .exists (code_dir ):
624715 shutil .rmtree (code_dir )
625716 shutil .copytree (source_directory , code_dir )
@@ -635,6 +726,8 @@ def _create_or_update_code_dir(
635726 raise
636727
637728 for dependency in dependencies :
729+ # Validate dependency path for security
730+ _validate_dependency_path (dependency )
638731 lib_dir = os .path .join (code_dir , "lib" )
639732 if os .path .isdir (dependency ):
640733 shutil .copytree (dependency , os .path .join (lib_dir , os .path .basename (dependency )))
@@ -1646,6 +1739,38 @@ def _get_safe_members(members):
16461739 yield file_info
16471740
16481741
def _validate_extracted_paths(extract_path):
    """Validate that extracted paths remain within the expected directory.

    Walks ``extract_path`` after extraction and checks that every directory
    and file found there resolves (via ``_get_resolved_path``) to a location
    at or beneath the resolved extraction path.

    Args:
        extract_path (str): The path where files were extracted.

    Raises:
        ValueError: If any extracted file or directory resolves outside the
            expected extraction path.
    """
    base = _get_resolved_path(extract_path)
    # Containment must be tested on a path-component boundary; a bare
    # startswith(base) would accept a sibling such as "<base>_evil/...".
    # rstrip keeps the prefix correct when base is the filesystem root.
    base_prefix = base.rstrip(os.sep) + os.sep

    for root, dirs, files in os.walk(extract_path):
        # Directories first, then files, matching the order they are reported.
        for kind, names in (("directory", dirs), ("file", files)):
            for name in names:
                entry_path = os.path.join(root, name)
                resolved = _get_resolved_path(entry_path)
                if resolved != base and not resolved.startswith(base_prefix):
                    logger.error("Extracted %s escaped extraction path: %s", kind, entry_path)
                    raise ValueError(
                        f"Extracted path outside expected directory: {entry_path}"
                    )
1772+
1773+
16491774def custom_extractall_tarfile (tar , extract_path ):
16501775 """Extract a tarfile, optionally using data_filter if available.
16511776
@@ -1666,6 +1791,8 @@ def custom_extractall_tarfile(tar, extract_path):
16661791 tar .extractall (path = extract_path , filter = "data" )
16671792 else :
16681793 tar .extractall (path = extract_path , members = _get_safe_members (tar ))
1794+ # Re-validate extracted paths to catch symlink race conditions
1795+ _validate_extracted_paths (extract_path )
16691796
16701797
16711798def can_model_package_source_uri_autopopulate (source_uri : str ):
0 commit comments