diff --git a/src/lightning/fabric/strategies/launchers/multiprocessing.py b/src/lightning/fabric/strategies/launchers/multiprocessing.py
index 3b3e180e63f41..d1a7f097d0c90 100644
--- a/src/lightning/fabric/strategies/launchers/multiprocessing.py
+++ b/src/lightning/fabric/strategies/launchers/multiprocessing.py
@@ -195,17 +195,29 @@ def _check_bad_cuda_fork() -> None:
     Lightning users.

     """
-    if not torch.cuda.is_initialized():
-        return
-
-    message = (
-        "Lightning can't create new processes if CUDA is already initialized. Did you manually call"
-        " `torch.cuda.*` functions, have moved the model to the device, or allocated memory on the GPU any"
-        " other way? Please remove any such calls, or change the selected strategy."
-    )
-    if _IS_INTERACTIVE:
-        message += " You will have to restart the Python kernel."
-    raise RuntimeError(message)
+    # Use PyTorch's internal check for bad fork state, which is more accurate than just checking if CUDA
+    # is initialized. This allows passive CUDA initialization (e.g., from library imports or device queries)
+    # while still catching actual problematic cases where CUDA context was created before forking.
+    _is_in_bad_fork = getattr(torch.cuda, "_is_in_bad_fork", None)
+    if _is_in_bad_fork is not None and callable(_is_in_bad_fork) and _is_in_bad_fork():
+        message = (
+            "Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, "
+            "you must use the 'spawn' start method or avoid CUDA initialization in the main process."
+        )
+        if _IS_INTERACTIVE:
+            message += " You will have to restart the Python kernel."
+        raise RuntimeError(message)
+
+    # Fallback to the old check if _is_in_bad_fork is not available (older PyTorch versions)
+    if _is_in_bad_fork is None and torch.cuda.is_initialized():
+        message = (
+            "Lightning can't create new processes if CUDA is already initialized. Did you manually call"
+            " `torch.cuda.*` functions, have moved the model to the device, or allocated memory on the GPU any"
+            " other way? Please remove any such calls, or change the selected strategy."
+        )
+        if _IS_INTERACTIVE:
+            message += " You will have to restart the Python kernel."
+        raise RuntimeError(message)


 def _disable_module_memory_sharing(data: Any) -> Any:
diff --git a/tests/tests_fabric/strategies/launchers/test_multiprocessing.py b/tests/tests_fabric/strategies/launchers/test_multiprocessing.py
index 5bb85e070f17d..7f7414a161903 100644
--- a/tests/tests_fabric/strategies/launchers/test_multiprocessing.py
+++ b/tests/tests_fabric/strategies/launchers/test_multiprocessing.py
@@ -98,6 +98,17 @@ def test_check_for_bad_cuda_fork(mp_mock, _, start_method):
         launcher.launch(function=Mock())


+@pytest.mark.parametrize("start_method", ["fork", "forkserver"])
+@mock.patch("torch.cuda._is_in_bad_fork", return_value=True)
+@mock.patch("lightning.fabric.strategies.launchers.multiprocessing.mp")
+def test_check_for_bad_cuda_fork_with_is_in_bad_fork(mp_mock, _, start_method):
+    """Test the new _is_in_bad_fork detection when available."""
+    mp_mock.get_all_start_methods.return_value = [start_method]
+    launcher = _MultiProcessingLauncher(strategy=Mock(), start_method=start_method)
+    with pytest.raises(RuntimeError, match="Cannot re-initialize CUDA in forked subprocess"):
+        launcher.launch(function=Mock())
+
+
 def test_check_for_missing_main_guard():
     launcher = _MultiProcessingLauncher(strategy=Mock(), start_method="spawn")
     with (
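Note (not part of the diff): the change above replaces an unconditional "CUDA already initialized" error with PyTorch's own bad-fork flag, falling back to the old check where that private API is missing. A minimal standalone sketch of that detection order, assuming only that PyTorch is installed; `cuda_fork_is_unsafe` is a hypothetical helper name, and `torch.cuda._is_in_bad_fork` is a private API that may be absent in older PyTorch releases, hence the getattr guard:

    import torch

    def cuda_fork_is_unsafe() -> bool:  # hypothetical helper, for illustration only
        checker = getattr(torch.cuda, "_is_in_bad_fork", None)
        if callable(checker):
            # Preferred path: PyTorch records whether a CUDA context existed before forking.
            return checker()
        # Fallback for older PyTorch: treat any prior CUDA initialization as unsafe to fork.
        return torch.cuda.is_initialized()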