4646from diffusers .training_utils import cast_training_params , compute_snr
4747from diffusers .utils import check_min_version , convert_state_dict_to_diffusers , is_wandb_available
4848from diffusers .utils .import_utils import is_xformers_available
49+ from diffusers .utils .torch_utils import is_compiled_module
4950
5051
5152# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
@@ -596,6 +597,11 @@ def tokenize_captions(examples, is_train=True):
596597 ]
597598 )
598599
def unwrap_model(model):
    """Return the underlying module behind the accelerator and compile wrappers.

    The model is first passed through ``accelerator.unwrap_model``. If the
    result is reported as a compiled module by ``is_compiled_module``, its
    ``_orig_mod`` attribute is returned instead, so callers always receive
    the original module (e.g. for extracting a state dict or building a
    pipeline).

    Note: ``accelerator`` is not defined in this block; it is expected to be
    available from the enclosing scope.

    Args:
        model: The (possibly wrapped and/or compiled) model.

    Returns:
        The unwrapped model, with the compile wrapper stripped when present.
    """
    unwrapped = accelerator.unwrap_model(model)
    # A compiled module keeps the original module on `_orig_mod`; hand that
    # back so downstream code does not operate on the wrapper.
    if is_compiled_module(unwrapped):
        return unwrapped._orig_mod
    return unwrapped
604+
599605 def preprocess_train (examples ):
600606 images = [image .convert ("RGB" ) for image in examples [image_column ]]
601607 examples ["pixel_values" ] = [train_transforms (image ) for image in images ]
@@ -729,7 +735,7 @@ def collate_fn(examples):
729735 noisy_latents = noise_scheduler .add_noise (latents , noise , timesteps )
730736
731737 # Get the text embedding for conditioning
732- encoder_hidden_states = text_encoder (batch ["input_ids" ])[0 ]
738+ encoder_hidden_states = text_encoder (batch ["input_ids" ], return_dict = False )[0 ]
733739
734740 # Get the target for loss depending on the prediction type
735741 if args .prediction_type is not None :
@@ -744,7 +750,7 @@ def collate_fn(examples):
744750 raise ValueError (f"Unknown prediction type { noise_scheduler .config .prediction_type } " )
745751
746752 # Predict the noise residual and compute loss
747- model_pred = unet (noisy_latents , timesteps , encoder_hidden_states ). sample
753+ model_pred = unet (noisy_latents , timesteps , encoder_hidden_states , return_dict = False )[ 0 ]
748754
749755 if args .snr_gamma is None :
750756 loss = F .mse_loss (model_pred .float (), target .float (), reduction = "mean" )
@@ -809,7 +815,7 @@ def collate_fn(examples):
809815 save_path = os .path .join (args .output_dir , f"checkpoint-{ global_step } " )
810816 accelerator .save_state (save_path )
811817
812- unwrapped_unet = accelerator . unwrap_model (unet )
818+ unwrapped_unet = unwrap_model (unet )
813819 unet_lora_state_dict = convert_state_dict_to_diffusers (
814820 get_peft_model_state_dict (unwrapped_unet )
815821 )
@@ -837,7 +843,7 @@ def collate_fn(examples):
837843 # create pipeline
838844 pipeline = DiffusionPipeline .from_pretrained (
839845 args .pretrained_model_name_or_path ,
840- unet = accelerator . unwrap_model (unet ),
846+ unet = unwrap_model (unet ),
841847 revision = args .revision ,
842848 variant = args .variant ,
843849 torch_dtype = weight_dtype ,
@@ -878,7 +884,7 @@ def collate_fn(examples):
878884 if accelerator .is_main_process :
879885 unet = unet .to (torch .float32 )
880886
881- unwrapped_unet = accelerator . unwrap_model (unet )
887+ unwrapped_unet = unwrap_model (unet )
882888 unet_lora_state_dict = convert_state_dict_to_diffusers (get_peft_model_state_dict (unwrapped_unet ))
883889 StableDiffusionPipeline .save_lora_weights (
884890 save_directory = args .output_dir ,
0 commit comments