Fix ddp_notebook CUDA fork check to allow passive initialization #21402
base: master
```diff
@@ -195,17 +195,29 @@ def _check_bad_cuda_fork() -> None:
     Lightning users.
 
     """
-    if not torch.cuda.is_initialized():
-        return
-
-    message = (
-        "Lightning can't create new processes if CUDA is already initialized. Did you manually call"
-        " `torch.cuda.*` functions, have moved the model to the device, or allocated memory on the GPU any"
-        " other way? Please remove any such calls, or change the selected strategy."
-    )
-    if _IS_INTERACTIVE:
-        message += " You will have to restart the Python kernel."
-    raise RuntimeError(message)
+    # Use PyTorch's internal check for bad fork state, which is more accurate than just checking if CUDA
+    # is initialized. This allows passive CUDA initialization (e.g., from library imports or device queries)
+    # while still catching actual problematic cases where CUDA context was created before forking.
+    _is_in_bad_fork = getattr(torch.cuda, "_is_in_bad_fork", None)
+    if _is_in_bad_fork is not None and _is_in_bad_fork():
+        message = (
+            "Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, "
+            "you must use the 'spawn' start method or avoid CUDA initialization in the main process."
+        )
+        if _IS_INTERACTIVE:
+            message += " You will have to restart the Python kernel."
+        raise RuntimeError(message)
+
+    # Fallback to the old check if _is_in_bad_fork is not available (older PyTorch versions)
+    if _is_in_bad_fork is None and torch.cuda.is_initialized():
+        message = (
+            "Lightning can't create new processes if CUDA is already initialized. Did you manually call"
+            " `torch.cuda.*` functions, have moved the model to the device, or allocated memory on the GPU any"
+            " other way? Please remove any such calls, or change the selected strategy."
+        )
+        if _IS_INTERACTIVE:
+            message += " You will have to restart the Python kernel."
+        raise RuntimeError(message)
 
 
 def _disable_module_memory_sharing(data: Any) -> Any:
```
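The branching the diff introduces can be exercised without a GPU. The sketch below mirrors the PR's control flow against a stub in place of `torch.cuda`; `FakeCuda` and `check` are illustrative stand-ins, not Lightning APIs. It shows the key behavioral change: passive CUDA initialization no longer raises when the precise bad-fork probe is available.

```python
class FakeCuda:
    """Stub standing in for torch.cuda; toggle attributes to simulate PyTorch versions."""

    def __init__(self, bad_fork=None, initialized=False):
        if bad_fork is not None:
            # Newer builds expose the private probe torch.cuda._is_in_bad_fork.
            self._is_in_bad_fork = lambda: bad_fork
        self._initialized = initialized

    def is_initialized(self):
        return self._initialized


def check(cuda):
    """Mirror of the PR's _check_bad_cuda_fork logic (illustrative)."""
    # Prefer the precise bad-fork probe when the build provides it.
    is_in_bad_fork = getattr(cuda, "_is_in_bad_fork", None)
    if is_in_bad_fork is not None and is_in_bad_fork():
        raise RuntimeError("Cannot re-initialize CUDA in forked subprocess.")
    # Older builds: fall back to the coarser "already initialized" test.
    if is_in_bad_fork is None and cuda.is_initialized():
        raise RuntimeError("Lightning can't create new processes if CUDA is already initialized.")


# Passive initialization with the probe present: no error is raised.
check(FakeCuda(bad_fork=False, initialized=True))
```

With the old code, the last call would have raised because `is_initialized()` returns `True`; the new probe distinguishes a merely-initialized context from one that actually breaks after a fork.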
Comment on lines +198 to +220
The code uses `getattr` to check whether `_is_in_bad_fork` exists, but it doesn't verify that the returned value is callable. While PyTorch's `_is_in_bad_fork` is indeed a function, it is better practice to verify callability when using `getattr` on potentially undefined attributes, especially for internal/private APIs that could change. Consider adding a callable check:
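A minimal sketch of the suggested guard, run against stubs so it needs no PyTorch; `is_bad_fork` and the `*_torch` stubs are illustrative names, not part of the PR:

```python
from types import SimpleNamespace


def is_bad_fork(cuda_module):
    """Return True only if the probe exists, is callable, and reports a bad fork."""
    probe = getattr(cuda_module, "_is_in_bad_fork", None)
    # callable() guards against the attribute existing but not being a function.
    return callable(probe) and probe()


# Stubs simulating torch.cuda across hypothetical PyTorch versions:
old_torch = SimpleNamespace()                              # attribute missing
odd_torch = SimpleNamespace(_is_in_bad_fork=True)          # present but not callable
new_torch = SimpleNamespace(_is_in_bad_fork=lambda: True)  # real probe

print(is_bad_fork(old_torch))  # False
print(is_bad_fork(odd_torch))  # False -- the callable() guard prevents a TypeError
print(is_bad_fork(new_torch))  # True
```

Compared with the PR's `is not None` test, `callable(...)` also covers the (unlikely) case where the attribute exists but has been replaced by a non-function value.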
getattrto check if_is_in_bad_forkexists, but it doesn't verify if the returned value is callable. While PyTorch's_is_in_bad_forkis indeed a function, it's better practice to verify callability when usinggetattron potentially undefined attributes, especially for internal/private APIs that could change.Consider adding a callable check: