@@ -898,6 +898,12 @@ def prepare_latents(
                 f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
             )
 
+        latents_mean = latents_std = None
+        if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None:
+            latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
+        if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None:
+            latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)
+
         # Offload text encoder if `enable_model_cpu_offload` was enabled
         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
             self.text_encoder_2.to("cpu")
@@ -935,7 +941,12 @@ def prepare_latents(
                 self.vae.to(dtype)
 
             init_latents = init_latents.to(dtype)
-            init_latents = self.vae.config.scaling_factor * init_latents
+            if latents_mean is not None and latents_std is not None:
+                latents_mean = latents_mean.to(device=self.device, dtype=dtype)
+                latents_std = latents_std.to(device=self.device, dtype=dtype)
+                init_latents = (init_latents - latents_mean) * self.vae.config.scaling_factor / latents_std
+            else:
+                init_latents = self.vae.config.scaling_factor * init_latents
 
         if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
             # expand init_latents for batch_size
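
Context for the change above (a sketch, not part of the diff): when the VAE config carries per-channel `latents_mean` and `latents_std`, the encoded image latents are standardized before scaling instead of only being multiplied by `scaling_factor`. The helpers below are hypothetical and use made-up tensors plus an illustrative scaling factor; in practice the values come from `vae.config`.

import torch

def normalize_latents(latents, latents_mean, latents_std, scaling_factor):
    # Same form as the added branch: shift by the per-channel mean, rescale by the std.
    return (latents - latents_mean) * scaling_factor / latents_std

def denormalize_latents(latents, latents_mean, latents_std, scaling_factor):
    # Inverse transform, e.g. before decoding latents back into an image.
    return latents * latents_std / scaling_factor + latents_mean

latents = torch.randn(1, 4, 64, 64)
latents_mean = torch.randn(1, 4, 1, 1)
latents_std = torch.rand(1, 4, 1, 1) + 0.5  # keep the std away from zero
scaling_factor = 0.13025  # illustrative value; read vae.config.scaling_factor in practice

normalized = normalize_latents(latents, latents_mean, latents_std, scaling_factor)
recovered = denormalize_latents(normalized, latents_mean, latents_std, scaling_factor)
assert torch.allclose(recovered, latents, atol=1e-5)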