Skip to content

Commit cd95b19

Browse files
pintaoz-aws and pintaoz authored
Fix _run_shell_cmd() to use list input (#5422)
Co-authored-by: pintaoz <pintaoz@amazon.com>
1 parent 2c7c4b5 commit cd95b19

File tree

2 files changed

+78
-16
lines changed

2 files changed

+78
-16
lines changed

sagemaker-train/src/sagemaker/train/remote_function/runtime_environment/runtime_environment_manager.py

Lines changed: 76 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,50 @@ def from_dependency_file_path(dependency_file_path):
9494
class RuntimeEnvironmentManager:
9595
"""Runtime Environment Manager class to manage runtime environment."""
9696

97+
def _validate_path(self, path: str) -> str:
98+
"""Validate and sanitize file path to prevent path traversal attacks.
99+
100+
Args:
101+
path (str): The file path to validate
102+
103+
Returns:
104+
str: The validated absolute path
105+
106+
Raises:
107+
ValueError: If the path is invalid or contains suspicious patterns
108+
"""
109+
if not path:
110+
raise ValueError("Path cannot be empty")
111+
112+
# Get absolute path to prevent path traversal
113+
abs_path = os.path.abspath(path)
114+
115+
# Check for null bytes (common in path traversal attacks)
116+
if '\x00' in path:
117+
raise ValueError(f"Invalid path contains null byte: {path}")
118+
119+
return abs_path
120+
121+
def _validate_env_name(self, env_name: str) -> None:
122+
"""Validate conda environment name to prevent command injection.
123+
124+
Args:
125+
env_name (str): The environment name to validate
126+
127+
Raises:
128+
ValueError: If the environment name contains invalid characters
129+
"""
130+
if not env_name:
131+
raise ValueError("Environment name cannot be empty")
132+
133+
# Allow only alphanumeric, underscore, and hyphen
134+
import re
135+
if not re.match(r'^[a-zA-Z0-9_-]+$', env_name):
136+
raise ValueError(
137+
f"Invalid environment name '{env_name}'. "
138+
"Only alphanumeric characters, underscores, and hyphens are allowed."
139+
)
140+
97141
def snapshot(self, dependencies: str = None) -> str:
98142
"""Creates snapshot of the user's environment
99143
@@ -252,39 +296,50 @@ def _is_file_exists(self, dependencies):
252296

253297
def _install_requirements_txt(self, local_path, python_executable):
    """Install (with upgrade) the packages listed in a requirements file using pip.

    Args:
        local_path: Path to the requirements file; it is passed through
            ``self._validate_path`` and the returned absolute path is used.
        python_executable: Interpreter used to run ``-m pip``.
    """
    requirements_file = self._validate_path(local_path)

    # Argument list (not a shell string) so no value is shell-interpreted.
    pip_args = ["-m", "pip", "install", "-r", requirements_file, "-U"]
    cmd = [python_executable, *pip_args]

    logger.info("Running command: '%s' in the dir: '%s' ", " ".join(cmd), os.getcwd())
    _run_shell_cmd(cmd)
    logger.info("Command %s ran successfully", " ".join(cmd))
259305

260306
def _create_conda_env(self, env_name, local_path):
    """Create a named conda environment from a conda yml file.

    Args:
        env_name: Name of the environment to create; checked with
            ``self._validate_env_name``.
        local_path: Path to the yml file; checked with
            ``self._validate_path`` and used in its absolute form.
    """
    # Validate both inputs before anything is executed.
    self._validate_env_name(env_name)
    yml_file = self._validate_path(local_path)

    conda_exe = self._get_conda_exe()
    cmd = [conda_exe, "env", "create", "-n", env_name, "--file", yml_file]

    logger.info("Creating conda environment %s using: %s.", env_name, " ".join(cmd))
    _run_shell_cmd(cmd)
    logger.info("Conda environment %s created successfully.", env_name)
267316

268317
def _install_req_txt_in_conda_env(self, env_name, local_path):
    """Install (with upgrade) a requirements file inside a named conda environment.

    Args:
        env_name: Name of the target conda environment; checked with
            ``self._validate_env_name``.
        local_path: Path to the requirements file; checked with
            ``self._validate_path`` and used in its absolute form.
    """
    # Validate both inputs before anything is executed.
    self._validate_env_name(env_name)
    requirements_file = self._validate_path(local_path)

    conda_prefix = [self._get_conda_exe(), "run", "-n", env_name]
    pip_part = ["pip", "install", "-r", requirements_file, "-U"]
    cmd = conda_prefix + pip_part

    logger.info("Activating conda env and installing requirements: %s", " ".join(cmd))
    _run_shell_cmd(cmd)
    logger.info("Requirements installed successfully in conda env %s", env_name)
275327

276328
def _update_conda_env(self, env_name, local_path):
    """Update an existing named conda environment from a conda yml file.

    Args:
        env_name: Name of the environment to update; checked with
            ``self._validate_env_name``.
        local_path: Path to the yml file; checked with
            ``self._validate_path`` and used in its absolute form.
    """
    # Validate both inputs before anything is executed.
    self._validate_env_name(env_name)
    yml_file = self._validate_path(local_path)

    conda_exe = self._get_conda_exe()
    cmd = [conda_exe, "env", "update", "-n", env_name, "--file", yml_file]

    logger.info("Updating conda env: %s", " ".join(cmd))
    _run_shell_cmd(cmd)
    logger.info("Conda env %s updated succesfully", env_name)
283338

284339
def _export_conda_env_from_prefix(self, prefix, local_path):
    """Export the conda env at ``prefix`` to a conda yml file at ``local_path``.

    The command is run without a shell, so a ``>`` redirection cannot be
    part of the argument list (it would be handed to conda as a literal
    argument and no file would be written). Instead the output file is
    opened here and passed as the child process's stdout.

    Args:
        prefix: Filesystem prefix of the conda environment to export.
        local_path: Destination file that receives the exported yml.

    Raises:
        RuntimeEnvironmentError: If the export command exits with a
            non-zero status.
    """
    cmd = [self._get_conda_exe(), "env", "export", "-p", prefix, "--no-builds"]
    logger.info("Exporting conda environment: %s", " ".join(cmd))

    # Equivalent of the former `> local_path`: the file is created/truncated
    # first, then the command's stdout is written straight into it.
    with open(local_path, "wb") as export_file:
        completed = subprocess.run(
            cmd,
            stdout=export_file,
            stderr=subprocess.PIPE,
            shell=False,
            check=False,
        )

    if completed.returncode:
        error_logs = completed.stderr.decode("utf-8", errors="replace")
        error_message = (
            f"Encountered error while running command '{' '.join(cmd)}'. Reason: {error_logs}"
        )
        raise RuntimeEnvironmentError(error_message)

    logger.info("Conda environment %s exported successfully", prefix)
@@ -402,19 +457,26 @@ def _run_pre_execution_command_script(script_path: str):
402457
return return_code, error_logs
403458

404459

405-
def _run_shell_cmd(cmd: str):
460+
def _run_shell_cmd(cmd: list):
406461
"""This method runs a given shell command using subprocess
407462
408-
Raises RuntimeEnvironmentError if the command fails
463+
Args:
464+
cmd (list): Command and arguments as a list (e.g., ['pip', 'install', '-r', 'requirements.txt'])
465+
466+
Raises:
467+
RuntimeEnvironmentError: If the command fails
468+
ValueError: If cmd is not a list
409469
"""
470+
if not isinstance(cmd, list):
471+
raise ValueError("Command must be a list of arguments for security reasons")
410472

411-
process = subprocess.Popen((cmd), stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
473+
process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=False)
412474

413475
_log_output(process)
414476
error_logs = _log_error(process)
415477
return_code = process.wait()
416478
if return_code:
417-
error_message = f"Encountered error while running command '{cmd}'. Reason: {error_logs}"
479+
error_message = f"Encountered error while running command '{' '.join(cmd)}'. Reason: {error_logs}"
418480
raise RuntimeEnvironmentError(error_message)
419481

420482

sagemaker-train/tests/unit/train/remote_function/test_runtime_environment_manager.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -490,7 +490,7 @@ def test_runs_command_successfully(self, mock_popen, mock_log_output, mock_log_e
490490
mock_popen.return_value = mock_process
491491
mock_log_error.return_value = ""
492492

493-
_run_shell_cmd("echo test")
493+
_run_shell_cmd(["echo", "test"])
494494

495495
mock_popen.assert_called_once()
496496

@@ -505,7 +505,7 @@ def test_runs_command_raises_error_on_failure(self, mock_popen, mock_log_output,
505505
mock_log_error.return_value = "Error message"
506506

507507
with pytest.raises(RuntimeEnvironmentError):
508-
_run_shell_cmd("false")
508+
_run_shell_cmd(["false"])
509509

510510

511511
class TestLogOutput:

0 commit comments

Comments
 (0)