@@ -172,6 +172,7 @@ def __call__(
         prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         ip_adapter_image: Optional[PipelineImageInput] = None,
+        ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
@@ -296,21 +297,15 @@ def __call__(
             negative_prompt,
             prompt_embeds,
             negative_prompt_embeds,
+            ip_adapter_image,
+            ip_adapter_image_embeds,
             callback_on_step_end_tensor_inputs,
         )
         self._guidance_scale = guidance_scale
         self._image_guidance_scale = image_guidance_scale

         device = self._execution_device

-        if ip_adapter_image is not None:
-            output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
-            image_embeds, negative_image_embeds = self.encode_image(
-                ip_adapter_image, device, num_images_per_prompt, output_hidden_state
-            )
-            if self.do_classifier_free_guidance:
-                image_embeds = torch.cat([image_embeds, negative_image_embeds, negative_image_embeds])
-
         if image is None:
             raise ValueError("`image` input cannot be undefined.")

@@ -335,6 +330,14 @@ def __call__(
             negative_prompt_embeds=negative_prompt_embeds,
         )

+        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+            image_embeds = self.prepare_ip_adapter_image_embeds(
+                ip_adapter_image,
+                ip_adapter_image_embeds,
+                device,
+                batch_size * num_images_per_prompt,
+                self.do_classifier_free_guidance,
+            )
         # 3. Preprocess image
         image = self.image_processor.preprocess(image)

@@ -635,6 +638,65 @@ def encode_image(self, image, device, num_images_per_prompt, output_hidden_state

         return image_embeds, uncond_image_embeds

+    def prepare_ip_adapter_image_embeds(
+        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
+    ):
+        if ip_adapter_image_embeds is None:
+            if not isinstance(ip_adapter_image, list):
+                ip_adapter_image = [ip_adapter_image]
+
+            if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
+                raise ValueError(
+                    f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
+                )
+
+            image_embeds = []
+            for single_ip_adapter_image, image_proj_layer in zip(
+                ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
+            ):
+                output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
+                single_image_embeds, single_negative_image_embeds = self.encode_image(
+                    single_ip_adapter_image, device, 1, output_hidden_state
+                )
+                single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
+                single_negative_image_embeds = torch.stack(
+                    [single_negative_image_embeds] * num_images_per_prompt, dim=0
+                )
+
+                if do_classifier_free_guidance:
+                    single_image_embeds = torch.cat(
+                        [single_image_embeds, single_negative_image_embeds, single_negative_image_embeds]
+                    )
+                    single_image_embeds = single_image_embeds.to(device)
+
+                image_embeds.append(single_image_embeds)
+        else:
+            repeat_dims = [1]
+            image_embeds = []
+            for single_image_embeds in ip_adapter_image_embeds:
+                if do_classifier_free_guidance:
+                    (
+                        single_image_embeds,
+                        single_negative_image_embeds,
+                        single_negative_image_embeds,
+                    ) = single_image_embeds.chunk(3)
+                    single_image_embeds = single_image_embeds.repeat(
+                        num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
+                    )
+                    single_negative_image_embeds = single_negative_image_embeds.repeat(
+                        num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
+                    )
+                    single_image_embeds = torch.cat(
+                        [single_image_embeds, single_negative_image_embeds, single_negative_image_embeds]
+                    )
+                else:
+                    single_image_embeds = single_image_embeds.repeat(
+                        num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
+                    )
+                image_embeds.append(single_image_embeds)
+
+        return image_embeds
+
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
     def run_safety_checker(self, image, device, dtype):
         if self.safety_checker is None:
@@ -687,6 +749,8 @@ def check_inputs(
         negative_prompt=None,
         prompt_embeds=None,
         negative_prompt_embeds=None,
+        ip_adapter_image=None,
+        ip_adapter_image_embeds=None,
         callback_on_step_end_tensor_inputs=None,
     ):
         if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
@@ -728,6 +792,21 @@ def check_inputs(
                 f" {negative_prompt_embeds.shape}."
             )

+        if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
+            raise ValueError(
+                "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
+            )
+
+        if ip_adapter_image_embeds is not None:
+            if not isinstance(ip_adapter_image_embeds, list):
+                raise ValueError(
+                    f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
+                )
+            elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
+                raise ValueError(
+                    f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
+                )
+
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
     def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
         shape = (
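
For reference, a minimal usage sketch of the new `ip_adapter_image_embeds` argument, not part of the diff: it assumes the `timbrooks/instruct-pix2pix` checkpoint, the `h94/IP-Adapter` SD 1.5 weights, and placeholder image URLs, and it reuses the pipeline's own `prepare_ip_adapter_image_embeds` helper (added above) to encode the reference image once and pass the cached embeddings to later calls.

import torch
from diffusers import StableDiffusionInstructPix2PixPipeline
from diffusers.utils import load_image

# Model IDs, weight names, and URLs below are illustrative assumptions, not taken from the diff.
pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
    "timbrooks/instruct-pix2pix", torch_dtype=torch.float16
).to("cuda")
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")

image = load_image("https://example.com/input.png")        # image to edit (placeholder)
style_image = load_image("https://example.com/style.png")  # IP-Adapter reference (placeholder)

# Encode the IP-Adapter reference once. With do_classifier_free_guidance=True the returned
# tensors carry the [cond, negative, negative] layout that the pipeline's else-branch expects.
image_embeds = pipe.prepare_ip_adapter_image_embeds(
    ip_adapter_image=style_image,
    ip_adapter_image_embeds=None,
    device="cuda",
    num_images_per_prompt=1,
    do_classifier_free_guidance=True,
)

# Reuse the cached embeddings instead of re-encoding the reference image on every call.
result = pipe(
    prompt="make it a watercolor painting",
    image=image,
    ip_adapter_image_embeds=image_embeds,
    num_inference_steps=30,
    guidance_scale=7.5,
    image_guidance_scale=1.5,
).images[0]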