34 changes: 23 additions & 11 deletions src/lightning/fabric/strategies/launchers/multiprocessing.py
@@ -195,17 +195,29 @@ def _check_bad_cuda_fork() -> None:
     Lightning users.
 
     """
-    if not torch.cuda.is_initialized():
-        return
-
-    message = (
-        "Lightning can't create new processes if CUDA is already initialized. Did you manually call"
-        " `torch.cuda.*` functions, have moved the model to the device, or allocated memory on the GPU any"
-        " other way? Please remove any such calls, or change the selected strategy."
-    )
-    if _IS_INTERACTIVE:
-        message += " You will have to restart the Python kernel."
-    raise RuntimeError(message)
+    # Use PyTorch's internal check for bad fork state, which is more accurate than just checking if CUDA
+    # is initialized. This allows passive CUDA initialization (e.g., from library imports or device queries)
+    # while still catching actual problematic cases where CUDA context was created before forking.
+    _is_in_bad_fork = getattr(torch.cuda, "_is_in_bad_fork", None)
+    if _is_in_bad_fork is not None and _is_in_bad_fork():
Copilot AI commented on Dec 3, 2025:

The code uses `getattr` to check whether `_is_in_bad_fork` exists, but it doesn't verify that the returned value is callable. While PyTorch's `_is_in_bad_fork` is indeed a function, it is better practice to verify callability when using `getattr` on potentially undefined attributes, especially for internal/private APIs that could change.

Consider adding a callable check. Suggested change:

-    if _is_in_bad_fork is not None and _is_in_bad_fork():
+    if _is_in_bad_fork is not None and callable(_is_in_bad_fork) and _is_in_bad_fork():
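The guard the comment proposes can be sketched generically. This is a minimal illustration using a stand-in namespace rather than `torch.cuda` itself, so it runs without PyTorch installed; the helper name `is_in_bad_fork` is hypothetical.

```python
# Sketch of the getattr + callable() guard suggested above, exercised
# against a stand-in object instead of torch.cuda (an assumption for
# portability of the example).
from types import SimpleNamespace


def is_in_bad_fork(ns) -> bool:
    """Return True only if ns exposes a callable _is_in_bad_fork that returns True."""
    fn = getattr(ns, "_is_in_bad_fork", None)
    # callable() protects against the attribute existing but not being a function
    return callable(fn) and bool(fn())


good = SimpleNamespace(_is_in_bad_fork=lambda: True)
missing = SimpleNamespace()                            # attribute absent entirely
not_callable = SimpleNamespace(_is_in_bad_fork=True)   # attribute exists, not a function

print(is_in_bad_fork(good))          # True
print(is_in_bad_fork(missing))       # False
print(is_in_bad_fork(not_callable))  # False: without callable(), fn() would raise TypeError
```

Without the `callable()` check, a non-callable `_is_in_bad_fork` attribute would raise `TypeError` instead of being treated as "check unavailable".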
+        message = (
+            "Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, "
+            "you must use the 'spawn' start method or avoid CUDA initialization in the main process."
+        )
+        if _IS_INTERACTIVE:
+            message += " You will have to restart the Python kernel."
+        raise RuntimeError(message)
+
Copilot AI commented on Dec 3, 2025:

Trailing whitespace found at the end of the line. Remove the extra whitespace.
+    # Fallback to the old check if _is_in_bad_fork is not available (older PyTorch versions)
+    if _is_in_bad_fork is None and torch.cuda.is_initialized():
+        message = (
+            "Lightning can't create new processes if CUDA is already initialized. Did you manually call"
+            " `torch.cuda.*` functions, have moved the model to the device, or allocated memory on the GPU any"
+            " other way? Please remove any such calls, or change the selected strategy."
+        )
+        if _IS_INTERACTIVE:
+            message += " You will have to restart the Python kernel."
+        raise RuntimeError(message)
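The two-path control flow in the diff above can be sketched in isolation. This is a simplified model using a stand-in `cuda` namespace (an assumption, so the example runs without PyTorch); it returns a label instead of raising, purely for illustration.

```python
# Sketch of the version-tolerant dispatch: prefer the internal bad-fork
# check when present, fall back to is_initialized() on older versions.
# The "cuda" argument stands in for torch.cuda.
from types import SimpleNamespace


def check_bad_fork(cuda) -> str:
    fork_check = getattr(cuda, "_is_in_bad_fork", None)
    if fork_check is not None and fork_check():
        return "bad-fork"      # precise signal: CUDA context existed before the fork
    if fork_check is None and cuda.is_initialized():
        return "initialized"   # older PyTorch: fall back to the coarser check
    return "ok"


newer = SimpleNamespace(_is_in_bad_fork=lambda: True, is_initialized=lambda: True)
older = SimpleNamespace(is_initialized=lambda: True)
clean = SimpleNamespace(_is_in_bad_fork=lambda: False, is_initialized=lambda: True)

print(check_bad_fork(newer))  # bad-fork
print(check_bad_fork(older))  # initialized
print(check_bad_fork(clean))  # ok: passive CUDA initialization is tolerated
```

The `clean` case is the behavioral change the PR is after: CUDA being initialized no longer fails the check by itself when the more precise signal is available and negative.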
Comment on lines +198 to +220

Copilot AI commented on Dec 3, 2025:

The existing test `test_check_for_bad_cuda_fork` mocks `torch.cuda.is_initialized()` to return `True`, which will only exercise the fallback path in the new implementation (when `_is_in_bad_fork` is `None`). The test should be updated to also verify the new behavior when `torch.cuda._is_in_bad_fork` is available and returns `True`.
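The test update the comment asks for could be sketched as follows. This is a hedged outline only: it patches a stand-in namespace with `unittest.mock.patch.object` (a real test would patch `torch.cuda` instead), and the simplified `check_bad_fork` below models, not reproduces, the function under review.

```python
# Sketch of covering the new code path: patch _is_in_bad_fork to return True
# and assert that the check raises. "cuda" is a stand-in for torch.cuda.
from types import SimpleNamespace
from unittest import mock

cuda = SimpleNamespace(_is_in_bad_fork=lambda: False, is_initialized=lambda: False)


def check_bad_fork() -> None:
    fn = getattr(cuda, "_is_in_bad_fork", None)
    if fn is not None and fn():
        raise RuntimeError("Cannot re-initialize CUDA in forked subprocess.")
    if fn is None and cuda.is_initialized():
        raise RuntimeError("Lightning can't create new processes if CUDA is already initialized.")


# New path: _is_in_bad_fork exists and reports a bad fork state.
with mock.patch.object(cuda, "_is_in_bad_fork", lambda: True):
    try:
        check_bad_fork()
        raised = False
    except RuntimeError:
        raised = True

print(raised)  # True: the bad-fork path raised as expected
```

In a real pytest suite the same shape would use `monkeypatch.setattr(torch.cuda, "_is_in_bad_fork", lambda: True)` together with `pytest.raises(RuntimeError)`.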


def _disable_module_memory_sharing(data: Any) -> Any: