5353from diffusers .utils .torch_utils import is_compiled_module
5454
5555
56+ if is_wandb_available ():
57+ import wandb
58+
5659# Will error if the minimum required version of diffusers is not installed. Remove at your own risk.
5760check_min_version ("0.28.0.dev0" )
5861
6467WANDB_TABLE_COL_NAMES = ["original_image" , "edited_image" , "edit_prompt" ]
6568
6669
def log_validation(
    pipeline,
    args,
    accelerator,
    generator,
):
    """Run validation inference and log the results to any wandb tracker.

    The image at ``args.val_image_url`` is fetched once via ``download_image``
    and edited ``args.num_validation_images`` times with
    ``args.validation_prompt``. For every tracker in ``accelerator.trackers``
    named ``"wandb"``, a table with one row per edited image is logged under
    the ``"validation"`` key. Other trackers are ignored.

    Args:
        pipeline: Callable pipeline exposing ``.to()`` and
            ``.set_progress_bar_config()``; its call result must expose
            ``.images``. Presumably an InstructPix2Pix pipeline -- confirm
            against the call sites.
        args: Namespace providing ``num_validation_images``,
            ``validation_prompt`` and ``val_image_url``.
        accelerator: Object providing ``device`` and ``trackers``.
        generator: Forwarded unchanged to every pipeline call as ``generator``.

    Returns:
        None.
    """
    logger.info(
        f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
        f" {args.validation_prompt} ."
    )

    # Move the pipeline to the accelerator's device and silence its progress bar.
    pipeline = pipeline.to(accelerator.device)
    pipeline.set_progress_bar_config(disable=True)

    # run inference
    source_image = download_image(args.val_image_url)

    # Autocast is skipped whenever an MPS backend is available.
    inference_ctx = (
        nullcontext() if torch.backends.mps.is_available() else torch.autocast(accelerator.device.type)
    )

    with inference_ctx:
        generated = [
            pipeline(
                args.validation_prompt,
                image=source_image,
                num_inference_steps=20,
                image_guidance_scale=1.5,
                guidance_scale=7,
                generator=generator,
            ).images[0]
            for _ in range(args.num_validation_images)
        ]

    for tracker in accelerator.trackers:
        if tracker.name != "wandb":
            continue

        # One row per edited image: (original, edited, prompt).
        table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
        for result in generated:
            table.add_data(wandb.Image(source_image), wandb.Image(result), args.validation_prompt)
        tracker.log({"validation": table})
110+
111+
67112def parse_args ():
68113 parser = argparse .ArgumentParser (description = "Simple example of a training script for InstructPix2Pix." )
69114 parser .add_argument (
@@ -411,11 +456,6 @@ def main():
411456
412457 generator = torch .Generator (device = accelerator .device ).manual_seed (args .seed )
413458
414- if args .report_to == "wandb" :
415- if not is_wandb_available ():
416- raise ImportError ("Make sure to install wandb if you want to use it for logging during training." )
417- import wandb
418-
419459 # Make one log on every process with the configuration for debugging.
420460 logging .basicConfig (
421461 format = "%(asctime)s - %(levelname)s - %(name)s - %(message)s" ,
@@ -517,7 +557,8 @@ def save_model_hook(models, weights, output_dir):
517557 model .save_pretrained (os .path .join (output_dir , "unet" ))
518558
519559 # make sure to pop weight so that corresponding model is not saved again
520- weights .pop ()
560+ if weights :
561+ weights .pop ()
521562
522563 def load_model_hook (models , input_dir ):
523564 if args .use_ema :
@@ -923,11 +964,6 @@ def collate_fn(examples):
923964 and (args .validation_prompt is not None )
924965 and (epoch % args .validation_epochs == 0 )
925966 ):
926- logger .info (
927- f"Running validation... \n Generating { args .num_validation_images } images with prompt:"
928- f" { args .validation_prompt } ."
929- )
930- # create pipeline
931967 if args .use_ema :
932968 # Store the UNet parameters temporarily and load the EMA parameters to perform inference.
933969 ema_unet .store (unet .parameters ())
@@ -942,38 +978,14 @@ def collate_fn(examples):
942978 variant = args .variant ,
943979 torch_dtype = weight_dtype ,
944980 )
945- pipeline = pipeline .to (accelerator .device )
946- pipeline .set_progress_bar_config (disable = True )
947-
948- # run inference
949- original_image = download_image (args .val_image_url )
950- edited_images = []
951- if torch .backends .mps .is_available ():
952- autocast_ctx = nullcontext ()
953- else :
954- autocast_ctx = torch .autocast (accelerator .device .type )
955-
956- with autocast_ctx :
957- for _ in range (args .num_validation_images ):
958- edited_images .append (
959- pipeline (
960- args .validation_prompt ,
961- image = original_image ,
962- num_inference_steps = 20 ,
963- image_guidance_scale = 1.5 ,
964- guidance_scale = 7 ,
965- generator = generator ,
966- ).images [0 ]
967- )
968-
969- for tracker in accelerator .trackers :
970- if tracker .name == "wandb" :
971- wandb_table = wandb .Table (columns = WANDB_TABLE_COL_NAMES )
972- for edited_image in edited_images :
973- wandb_table .add_data (
974- wandb .Image (original_image ), wandb .Image (edited_image ), args .validation_prompt
975- )
976- tracker .log ({"validation" : wandb_table })
981+
982+ log_validation (
983+ pipeline ,
984+ args ,
985+ accelerator ,
986+ generator ,
987+ )
988+
977989 if args .use_ema :
978990 # Switch back to the original UNet parameters.
979991 ema_unet .restore (unet .parameters ())
@@ -984,15 +996,14 @@ def collate_fn(examples):
984996 # Create the pipeline using the trained modules and save it.
985997 accelerator .wait_for_everyone ()
986998 if accelerator .is_main_process :
987- unet = unwrap_model (unet )
988999 if args .use_ema :
9891000 ema_unet .copy_to (unet .parameters ())
9901001
9911002 pipeline = StableDiffusionInstructPix2PixPipeline .from_pretrained (
9921003 args .pretrained_model_name_or_path ,
9931004 text_encoder = unwrap_model (text_encoder ),
9941005 vae = unwrap_model (vae ),
995- unet = unet ,
1006+ unet = unwrap_model ( unet ) ,
9961007 revision = args .revision ,
9971008 variant = args .variant ,
9981009 )
@@ -1006,31 +1017,13 @@ def collate_fn(examples):
10061017 ignore_patterns = ["step_*" , "epoch_*" ],
10071018 )
10081019
1009- if args .validation_prompt is not None :
1010- edited_images = []
1011- pipeline = pipeline .to (accelerator .device )
1012- with torch .autocast (str (accelerator .device ).replace (":0" , "" )):
1013- for _ in range (args .num_validation_images ):
1014- edited_images .append (
1015- pipeline (
1016- args .validation_prompt ,
1017- image = original_image ,
1018- num_inference_steps = 20 ,
1019- image_guidance_scale = 1.5 ,
1020- guidance_scale = 7 ,
1021- generator = generator ,
1022- ).images [0 ]
1023- )
1024-
1025- for tracker in accelerator .trackers :
1026- if tracker .name == "wandb" :
1027- wandb_table = wandb .Table (columns = WANDB_TABLE_COL_NAMES )
1028- for edited_image in edited_images :
1029- wandb_table .add_data (
1030- wandb .Image (original_image ), wandb .Image (edited_image ), args .validation_prompt
1031- )
1032- tracker .log ({"test" : wandb_table })
1033-
1020+ if (args .val_image_url is not None ) and (args .validation_prompt is not None ):
1021+ log_validation (
1022+ pipeline ,
1023+ args ,
1024+ accelerator ,
1025+ generator ,
1026+ )
10341027 accelerator .end_training ()
10351028
10361029
0 commit comments