Skip to content

Commit c1dc2ae

Browse files
authored
Fix multi-gpu case for train_cm_ct_unconditional.py (#8653)
* Fix multi-gpu case * Prefer previously created `unwrap_model()` function For `torch.compile()` generalizability * `chore: update unwrap_model() function to use accelerator.unwrap_model()`
1 parent e15a8e7 commit c1dc2ae

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

examples/research_projects/consistency_training/train_cm_ct_unconditional.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1195,7 +1195,7 @@ def unwrap_model(model):
11951195

11961196
# Resolve the c parameter for the Pseudo-Huber loss
11971197
if args.huber_c is None:
1198-
args.huber_c = 0.00054 * args.resolution * math.sqrt(unet.config.in_channels)
1198+
args.huber_c = 0.00054 * args.resolution * math.sqrt(unwrap_model(unet).config.in_channels)
11991199

12001200
# Get current number of discretization steps N according to our discretization curriculum
12011201
current_discretization_steps = get_discretization_steps(

0 commit comments

Comments
 (0)