 from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
 from huggingface_hub import create_repo, upload_folder
 from packaging import version
-from peft import LoraConfig
+from peft import LoraConfig, set_peft_model_state_dict
 from peft.utils import get_peft_model_state_dict
 from PIL import Image
 from PIL.ImageOps import exif_transpose
 )
 from diffusers.loaders import LoraLoaderMixin
 from diffusers.optimization import get_scheduler
-from diffusers.training_utils import compute_snr
+from diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params, compute_snr
 from diffusers.utils import (
     check_min_version,
     convert_all_state_dict_to_peft,
     convert_state_dict_to_diffusers,
     convert_state_dict_to_kohya,
+    convert_unet_state_dict_to_peft,
     is_wandb_available,
 )
 from diffusers.utils.import_utils import is_xformers_available
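
Of the new imports, `convert_unet_state_dict_to_peft` bridges two naming schemes: diffusers serializes LoRA checkpoints with its own key layout, while `set_peft_model_state_dict` expects PEFT-style keys. A rough, illustrative sketch of the kind of renaming involved (the real helper in `diffusers.utils` handles more patterns than this):

```python
def rename_lora_keys_sketch(state_dict):
    # Illustration only: map diffusers-style LoRA keys (".lora.down"/".lora.up")
    # to the PEFT convention (".lora_A"/".lora_B") so PEFT can load them.
    renamed = {}
    for key, value in state_dict.items():
        key = key.replace(".lora.down.", ".lora_A.").replace(".lora.up.", ".lora_B.")
        renamed[key] = value
    return renamed
```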
@@ -1292,17 +1293,6 @@ def main(args):
             else:
                 param.requires_grad = False
 
-    # Make sure the trainable params are in float32.
-    if args.mixed_precision == "fp16":
-        models = [unet]
-        if args.train_text_encoder:
-            models.extend([text_encoder_one, text_encoder_two])
-        for model in models:
-            for param in model.parameters():
-                # only upcast trainable parameters (LoRA) into fp32
-                if param.requires_grad:
-                    param.data = param.to(torch.float32)
-
     # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
     def save_model_hook(models, weights, output_dir):
         if accelerator.is_main_process:
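
The deleted upcasting loop is not lost: it is replaced further down by `cast_training_params` from `diffusers.training_utils`, which factors the same logic into a reusable helper. A rough sketch of the behaviour being centralized (an approximation, not the library's exact implementation):

```python
import torch

def cast_training_params_sketch(models, dtype=torch.float32):
    # Upcast only the trainable (LoRA) parameters; frozen base weights keep
    # their mixed-precision dtype.
    if not isinstance(models, list):
        models = [models]
    for model in models:
        for param in model.parameters():
            if param.requires_grad:
                param.data = param.to(dtype)
```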
@@ -1358,17 +1348,34 @@ def load_model_hook(models, input_dir):
                 raise ValueError(f"unexpected save model: {model.__class__}")

         lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
-        LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)

-        text_encoder_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder." in k}
-        LoraLoaderMixin.load_lora_into_text_encoder(
-            text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_
-        )
+        unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
+        unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
+        incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
+        if incompatible_keys is not None:
+            # check only for unexpected keys
+            unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+            if unexpected_keys:
+                logger.warning(
+                    f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+                    f" {unexpected_keys}. "
+                )

-        text_encoder_2_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder_2." in k}
-        LoraLoaderMixin.load_lora_into_text_encoder(
-            text_encoder_2_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_
-        )
+        if args.train_text_encoder:
+            _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_)
+
+            _set_state_dict_into_text_encoder(
+                lora_state_dict, prefix="text_encoder_2.", text_encoder=text_encoder_two_
+            )
+
+        # Make sure the trainable params are in float32. This is again needed since the base models
+        # are in `weight_dtype`. More details:
+        # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
+        if args.mixed_precision == "fp16":
+            models = [unet_]
+            if args.train_text_encoder:
+                models.extend([text_encoder_one_, text_encoder_two_])
+            cast_training_params(models)

     accelerator.register_save_state_pre_hook(save_model_hook)
     accelerator.register_load_state_pre_hook(load_model_hook)
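
The hooks registered here only run when the training loop checkpoints via `accelerator.save_state` and resumes via `accelerator.load_state`. A hedged usage sketch of how they get triggered (variable names follow the script; the checkpoint path is a placeholder):

```python
import os

# Writing a checkpoint invokes save_model_hook, which serializes the LoRA
# layers in the diffusers format rather than a raw accelerate state dict.
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
accelerator.save_state(save_path)

# Resuming invokes load_model_hook, which now restores the UNet LoRA weights
# via convert_unet_state_dict_to_peft + set_peft_model_state_dict.
accelerator.load_state(save_path)
```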
@@ -1383,6 +1390,13 @@ def load_model_hook(models, input_dir):
             args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
         )

+    # Make sure the trainable params are in float32.
+    if args.mixed_precision == "fp16":
+        models = [unet]
+        if args.train_text_encoder:
+            models.extend([text_encoder_one, text_encoder_two])
+        cast_training_params(models, dtype=torch.float32)
+
     unet_lora_parameters = list(filter(lambda p: p.requires_grad, unet.parameters()))

     if args.train_text_encoder:
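
Filtering on `requires_grad` picks up exactly the LoRA parameters because, earlier in the script, the base weights are frozen and adapters are attached on top. A minimal, hedged sketch of that pattern (the rank and target modules here are illustrative, not the script's exact CLI-driven values):

```python
from peft import LoraConfig

# Illustrative adapter config; the script builds its own from command-line arguments.
unet_lora_config = LoraConfig(
    r=4,
    lora_alpha=4,
    init_lora_weights="gaussian",
    target_modules=["to_k", "to_q", "to_v", "to_out.0"],
)
unet.requires_grad_(False)          # freeze the base UNet
unet.add_adapter(unet_lora_config)  # only the injected LoRA layers are trainable

# hence this collects just the adapter weights for the optimizer
unet_lora_parameters = list(filter(lambda p: p.requires_grad, unet.parameters()))
```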