1414
1515import inspect
1616from dataclasses import dataclass
17- from typing import Any , Callable , Dict , List , Optional , Union
17+ from typing import Any , Callable , Dict , List , Optional , Tuple , Union
1818
1919import numpy as np
2020import torch
6666 ... custom_pipeline="pipeline_animatediff_controlnet",
6767 ... ).to(device="cuda", dtype=torch.float16)
6868 >>> pipe.scheduler = DPMSolverMultistepScheduler.from_pretrained(
69- ... model_id, subfolder="scheduler", clip_sample=False, timestep_spacing="linspace", steps_offset=1
69+ ... model_id, subfolder="scheduler", clip_sample=False, timestep_spacing="linspace", steps_offset=1, beta_schedule="linear",
7070 ... )
7171 >>> pipe.enable_vae_slicing()
7272
8383 ... height=768,
8484 ... conditioning_frames=conditioning_frames,
8585 ... num_inference_steps=12,
86- ... ).frames[0]
86+ ... )
8787
8888 >>> from diffusers.utils import export_to_gif
8989 >>> export_to_gif(result.frames[0], "result.gif")
@@ -151,7 +151,7 @@ def __init__(
151151 tokenizer : CLIPTokenizer ,
152152 unet : UNet2DConditionModel ,
153153 motion_adapter : MotionAdapter ,
154- controlnet : Union [ControlNetModel , MultiControlNetModel ],
154+ controlnet : Union [ControlNetModel , List [ ControlNetModel ], Tuple [ ControlNetModel ], MultiControlNetModel ],
155155 scheduler : Union [
156156 DDIMScheduler ,
157157 PNDMScheduler ,
@@ -166,6 +166,9 @@ def __init__(
166166 super ().__init__ ()
167167 unet = UNetMotionModel .from_unet2d (unet , motion_adapter )
168168
169+ if isinstance (controlnet , (list , tuple )):
170+ controlnet = MultiControlNetModel (controlnet )
171+
169172 self .register_modules (
170173 vae = vae ,
171174 text_encoder = text_encoder ,
@@ -488,6 +491,7 @@ def check_inputs(
488491 prompt ,
489492 height ,
490493 width ,
494+ num_frames ,
491495 callback_steps ,
492496 negative_prompt = None ,
493497 prompt_embeds = None ,
@@ -557,31 +561,21 @@ def check_inputs(
557561 or is_compiled
558562 and isinstance (self .controlnet ._orig_mod , ControlNetModel )
559563 ):
560- if isinstance (image , list ):
561- for image_ in image :
562- self .check_image (image_ , prompt , prompt_embeds )
563- else :
564- self .check_image (image , prompt , prompt_embeds )
564+ if not isinstance (image , list ):
565+ raise TypeError (f"For single controlnet, `image` must be of type `list` but got { type (image )} " )
566+ if len (image ) != num_frames :
567+ raise ValueError (f"Excepted image to have length { num_frames } but got { len (image )= } " )
565568 elif (
566569 isinstance (self .controlnet , MultiControlNetModel )
567570 or is_compiled
568571 and isinstance (self .controlnet ._orig_mod , MultiControlNetModel )
569572 ):
570- if not isinstance (image , list ):
571- raise TypeError ("For multiple controlnets: `image` must be type `list`" )
572-
573- # When `image` is a nested list:
574- # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
575- elif any (isinstance (i , list ) for i in image ):
576- raise ValueError ("A single batch of multiple conditionings are supported at the moment." )
577- elif len (image ) != len (self .controlnet .nets ):
578- raise ValueError (
579- f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got { len (image )} images and { len (self .controlnet .nets )} ControlNets."
580- )
581-
582- for control_ in image :
583- for image_ in control_ :
584- self .check_image (image_ , prompt , prompt_embeds )
573+ if not isinstance (image , list ) or not isinstance (image [0 ], list ):
574+ raise TypeError (f"For multiple controlnets: `image` must be type list of lists but got { type (image )= } " )
575+ if len (image [0 ]) != num_frames :
576+ raise ValueError (f"Expected length of image sublist as { num_frames } but got { len (image [0 ])= } " )
577+ if any (len (img ) != len (image [0 ]) for img in image ):
578+ raise ValueError ("All conditioning frame batches for multicontrolnet must be same size" )
585579 else :
586580 assert False
587581
@@ -913,6 +907,7 @@ def __call__(
913907 prompt = prompt ,
914908 height = height ,
915909 width = width ,
910+ num_frames = num_frames ,
916911 callback_steps = callback_steps ,
917912 negative_prompt = negative_prompt ,
918913 callback_on_step_end_tensor_inputs = callback_on_step_end_tensor_inputs ,
@@ -1000,9 +995,7 @@ def __call__(
1000995 do_classifier_free_guidance = self .do_classifier_free_guidance ,
1001996 guess_mode = guess_mode ,
1002997 )
1003-
1004998 cond_prepared_frames .append (prepared_frame )
1005-
1006999 conditioning_frames = cond_prepared_frames
10071000 else :
10081001 assert False
0 commit comments