diff --git a/sagemaker-core/src/sagemaker/core/local/utils.py b/sagemaker-core/src/sagemaker/core/local/utils.py index 58ef3e7781..5b173cd994 100644 --- a/sagemaker-core/src/sagemaker/core/local/utils.py +++ b/sagemaker-core/src/sagemaker/core/local/utils.py @@ -137,7 +137,11 @@ def get_child_process_ids(pid): Returns: (List[int]): Child process ids """ - cmd = f"pgrep -P {pid}".split() + if not str(pid).isdigit(): + raise ValueError("Invalid PID") + + cmd = ["pgrep", "-P", str(pid)] + process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) output, err = process.communicate() if err: diff --git a/sagemaker-core/tests/unit/local/test_local_utils.py b/sagemaker-core/tests/unit/local/test_local_utils.py index 432c456d54..5dccbe3899 100644 --- a/sagemaker-core/tests/unit/local/test_local_utils.py +++ b/sagemaker-core/tests/unit/local/test_local_utils.py @@ -103,21 +103,24 @@ def test_recursive_copy(copy_tree, m_os_path): @patch("sagemaker.core.local.utils.os") @patch("sagemaker.core.local.utils.get_child_process_ids") def test_kill_child_processes(m_get_child_process_ids, m_os): - m_get_child_process_ids.return_value = ["child_pids"] - kill_child_processes("pid") - m_os.kill.assert_called_with("child_pids", 15) + m_get_child_process_ids.return_value = ["345"] + kill_child_processes("123") + m_os.kill.assert_called_with("345", 15) @patch("sagemaker.core.local.utils.subprocess") def test_get_child_process_ids(m_subprocess): - cmd = "pgrep -P pid".split() + cmd = "pgrep -P 123".split() process_mock = Mock() attrs = {"communicate.return_value": (b"\n", False), "returncode": 0} process_mock.configure_mock(**attrs) m_subprocess.Popen.return_value = process_mock - get_child_process_ids("pid") + get_child_process_ids("123") m_subprocess.Popen.assert_called_with(cmd, stdout=m_subprocess.PIPE, stderr=m_subprocess.PIPE) +def test_get_child_process_ids_exception(): + with pytest.raises(ValueError, match="Invalid PID"): + get_child_process_ids("abc") @patch("sagemaker.core.local.utils.subprocess") def test_get_docker_host(m_subprocess):