@@ -50,6 +50,7 @@
 from diffusers.optimization import get_scheduler
 from diffusers.utils import check_min_version, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.torch_utils import is_compiled_module
 
 
 MAX_SEQ_LENGTH = 77
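The newly imported `is_compiled_module` helper detects whether a module has been wrapped by `torch.compile`. As a rough sketch of the idea (hypothetical function name; not the exact library code, which lives in `diffusers.utils.torch_utils` and also guards on the installed PyTorch version), it is essentially an `isinstance` check against dynamo's wrapper class:

```python
import torch

def is_compiled_module_sketch(module) -> bool:
    # torch.compile() wraps an nn.Module in an OptimizedModule that keeps
    # the original module on `_orig_mod`; detecting the wrapper is just an
    # isinstance check.
    if not hasattr(torch, "_dynamo"):
        return False  # torch < 2.0: nothing can be compiled
    return isinstance(module, torch._dynamo.eval_frame.OptimizedModule)
```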
@@ -926,6 +927,11 @@ def load_model_hook(models, input_dir):
         else:
             raise ValueError("xformers is not available. Make sure it is installed correctly")
 
+    def unwrap_model(model):
+        model = accelerator.unwrap_model(model)
+        model = model._orig_mod if is_compiled_module(model) else model
+        return model
+
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
 
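The `unwrap_model` helper is needed because `accelerator.unwrap_model` strips the distributed wrappers added by `accelerator.prepare`, but may still return the `OptimizedModule` wrapper produced by `torch.compile`; the underlying module then has to be recovered from `_orig_mod` by hand. A self-contained check of that wrapping behavior (assumes PyTorch 2.x):

```python
import torch
import torch.nn as nn

net = nn.Linear(4, 4)
compiled = torch.compile(net)

# torch.compile() returns a wrapper object, not the module itself; the
# original module is stored on `_orig_mod`, which is exactly what the
# unwrap_model() helper above reads.
assert compiled is not net
assert compiled._orig_mod is net
```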
@@ -935,9 +941,9 @@ def load_model_hook(models, input_dir):
935941 " doing mixed precision training, copy of the weights should still be float32."
936942 )
937943
938- if accelerator . unwrap_model (t2iadapter ).dtype != torch .float32 :
944+ if unwrap_model (t2iadapter ).dtype != torch .float32 :
939945 raise ValueError (
940- f"Controlnet loaded as datatype { accelerator . unwrap_model (t2iadapter ).dtype } . { low_precision_error_string } "
946+ f"Controlnet loaded as datatype { unwrap_model (t2iadapter ).dtype } . { low_precision_error_string } "
941947 )
942948
943949 # Enable TF32 for faster training on Ampere GPUs,
@@ -1198,7 +1204,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                     encoder_hidden_states=batch["prompt_ids"],
                     added_cond_kwargs=batch["unet_added_conditions"],
                     down_block_additional_residuals=down_block_additional_residuals,
-                ).sample
+                    return_dict=False,
+                )[0]
 
                 # Denoise the latents
                 denoised_latents = model_pred * (-sigmas) + noisy_latents
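With `return_dict=False` the UNet forward returns a plain tuple instead of a `UNet2DConditionOutput` dataclass, so `[0]` retrieves the same tensor that `.sample` did; the tuple path skips the output-wrapper attribute access, which tends to interact better with `torch.compile`. A toy-sized sanity check of the equivalence (the config below is an arbitrary small one for illustration, not the SDXL UNet the script trains):

```python
import torch
from diffusers import UNet2DConditionModel

# Tiny, arbitrary UNet config purely for demonstration.
unet = UNet2DConditionModel(
    sample_size=8,
    in_channels=4,
    out_channels=4,
    block_out_channels=(32, 64),
    down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
    up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
    layers_per_block=1,
    cross_attention_dim=32,
).eval()

sample = torch.randn(1, 4, 8, 8)
timestep = torch.tensor([10])
encoder_hidden_states = torch.randn(1, 77, 32)

with torch.no_grad():
    as_attr = unet(sample, timestep, encoder_hidden_states=encoder_hidden_states).sample
    as_tuple = unet(sample, timestep, encoder_hidden_states=encoder_hidden_states, return_dict=False)[0]

assert torch.allclose(as_attr, as_tuple)  # same prediction either way
```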
@@ -1279,7 +1286,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
     # Create the pipeline using the trained modules and save it.
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
-        t2iadapter = accelerator.unwrap_model(t2iadapter)
+        t2iadapter = unwrap_model(t2iadapter)
         t2iadapter.save_pretrained(args.output_dir)
 
         if args.push_to_hub:
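Unwrapping before `save_pretrained` matters here too: the checkpoint should be written by the bare `T2IAdapter`, not by a DDP or dynamo wrapper. The saved adapter can then be reloaded as usual (sketch; the placeholder path stands in for the script's `--output_dir`):

```python
from diffusers import T2IAdapter

# Reload the trained adapter from the training run's output directory.
adapter = T2IAdapter.from_pretrained("path/to/output_dir")
```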