|
44 | 44 | from transformers import AutoTokenizer, PretrainedConfig |
45 | 45 |
|
46 | 46 | import diffusers |
47 | | -from diffusers import ( |
48 | | - AutoencoderKL, |
49 | | - DDPMScheduler, |
50 | | - StableDiffusionXLPipeline, |
51 | | - UNet2DConditionModel, |
52 | | -) |
| 47 | +from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionXLPipeline, UNet2DConditionModel |
53 | 48 | from diffusers.optimization import get_scheduler |
54 | 49 | from diffusers.training_utils import EMAModel, compute_snr |
55 | 50 | from diffusers.utils import check_min_version, is_wandb_available |
56 | 51 | from diffusers.utils.import_utils import is_xformers_available |
| 52 | +from diffusers.utils.torch_utils import is_compiled_module |
57 | 53 |
|
58 | 54 |
|
59 | 55 | # Will error if the minimal version of diffusers is not installed. Remove at your own risks. |
@@ -508,11 +504,12 @@ def encode_prompt(batch, text_encoders, tokenizers, proportion_empty_prompts, ca |
508 | 504 | prompt_embeds = text_encoder( |
509 | 505 | text_input_ids.to(text_encoder.device), |
510 | 506 | output_hidden_states=True, |
| 507 | + return_dict=False, |
511 | 508 | ) |
512 | 509 |
|
513 | 510 | # We are only ALWAYS interested in the pooled output of the final text encoder |
514 | 511 | pooled_prompt_embeds = prompt_embeds[0] |
515 | | - prompt_embeds = prompt_embeds.hidden_states[-2] |
| 512 | + prompt_embeds = prompt_embeds[-1][-2] |
516 | 513 | bs_embed, seq_len, _ = prompt_embeds.shape |
517 | 514 | prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1) |
518 | 515 | prompt_embeds_list.append(prompt_embeds) |
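
Note: passing `return_dict=False` makes the text encoder return a plain tuple instead of a `ModelOutput`, which plays better with `torch.compile`; the pooled output is still element `0` and the per-layer hidden states become the last tuple element, so `[-1][-2]` picks the same penultimate hidden state as `.hidden_states[-2]` did before. A minimal standalone sketch, using an example CLIP checkpoint (not the encoders the script actually loads):

```python
import torch
from transformers import AutoTokenizer, CLIPTextModelWithProjection

# Example checkpoint, for illustration only; the training script resolves its own encoders.
name = "openai/clip-vit-base-patch32"
tokenizer = AutoTokenizer.from_pretrained(name)
text_encoder = CLIPTextModelWithProjection.from_pretrained(name)

ids = tokenizer("a photo of a cat", return_tensors="pt").input_ids
with torch.no_grad():
    out = text_encoder(ids, output_hidden_states=True, return_dict=False)

pooled = out[0]            # pooled/projected embedding -> pooled_prompt_embeds
penultimate = out[-1][-2]  # hidden_states is the last tuple element; take the penultimate layer
```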
@@ -955,6 +952,12 @@ def collate_fn(examples): |
955 | 952 | if accelerator.is_main_process: |
956 | 953 | accelerator.init_trackers("text2image-fine-tune-sdxl", config=vars(args)) |
957 | 954 |
|
| 955 | + # Function for unwrapping the model if torch.compile() was used in accelerate. |
| 956 | + def unwrap_model(model): |
| 957 | + model = accelerator.unwrap_model(model) |
| 958 | + model = model._orig_mod if is_compiled_module(model) else model |
| 959 | + return model |
| 960 | + |
958 | 961 | # Train! |
959 | 962 | total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps |
960 | 963 |
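
For reference, `torch.compile` wraps an `nn.Module` in an `OptimizedModule` that keeps the original module on `_orig_mod`; `is_compiled_module` detects that wrapper, which is why the helper above peels it off after `accelerator.unwrap_model`. A small standalone sketch (requires PyTorch 2.x):

```python
import torch
from diffusers.utils.torch_utils import is_compiled_module

net = torch.nn.Linear(4, 4)
compiled = torch.compile(net)        # returns an OptimizedModule wrapper around net

print(is_compiled_module(net))       # False
print(is_compiled_module(compiled))  # True
print(compiled._orig_mod is net)     # True: the original module is still reachable
```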
|
@@ -1054,8 +1057,12 @@ def compute_time_ids(original_size, crops_coords_top_left): |
1054 | 1057 | pooled_prompt_embeds = batch["pooled_prompt_embeds"].to(accelerator.device) |
1055 | 1058 | unet_added_conditions.update({"text_embeds": pooled_prompt_embeds}) |
1056 | 1059 | model_pred = unet( |
1057 | | - noisy_model_input, timesteps, prompt_embeds, added_cond_kwargs=unet_added_conditions |
1058 | | - ).sample |
| 1060 | + noisy_model_input, |
| 1061 | + timesteps, |
| 1062 | + prompt_embeds, |
| 1063 | + added_cond_kwargs=unet_added_conditions, |
| 1064 | + return_dict=False, |
| 1065 | + )[0] |
1059 | 1066 |
|
1060 | 1067 | # Get the target for loss depending on the prediction type |
1061 | 1068 | if args.prediction_type is not None: |
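
The same pattern applies to the UNet call: with `return_dict=False` the forward pass returns a plain tuple, and `[0]` is the predicted sample, equivalent to the previous `.sample` attribute but without the dataclass wrapper. A toy-sized sketch (the config below is made up only to keep the example small; the script loads a pretrained SDXL UNet):

```python
import torch
from diffusers import UNet2DConditionModel

# Tiny, illustrative config; real SDXL UNets come from a pretrained checkpoint.
unet = UNet2DConditionModel(
    sample_size=8,
    block_out_channels=(32, 64),
    down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
    up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
    cross_attention_dim=32,
)

noisy_model_input = torch.randn(1, 4, 8, 8)
timesteps = torch.tensor([10])
prompt_embeds = torch.randn(1, 4, 32)

model_pred = unet(noisy_model_input, timesteps, prompt_embeds, return_dict=False)[0]
print(model_pred.shape)  # torch.Size([1, 4, 8, 8])
```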
@@ -1206,7 +1213,7 @@ def compute_time_ids(original_size, crops_coords_top_left): |
1206 | 1213 |
|
1207 | 1214 | accelerator.wait_for_everyone() |
1208 | 1215 | if accelerator.is_main_process: |
1209 | | - unet = accelerator.unwrap_model(unet) |
| 1216 | + unet = unwrap_model(unet) |
1210 | 1217 | if args.use_ema: |
1211 | 1218 | ema_unet.copy_to(unet.parameters()) |
1212 | 1219 |
|
|