@@ -562,14 +562,27 @@ class StableDiffusionGGML {
562562 }
563563
564564 if (sd_version_is_wan (version) || sd_version_is_qwen_image (version)) {
565- first_stage_model = std::make_shared<WAN::WanVAERunner>(vae_backend,
566- offload_params_to_cpu,
567- tensor_storage_map,
568- " first_stage_model" ,
569- vae_decode_only,
570- version);
571- first_stage_model->alloc_params_buffer ();
572- first_stage_model->get_param_tensors (tensors, " first_stage_model" );
565+ if (!use_tiny_autoencoder) {
566+ first_stage_model = std::make_shared<WAN::WanVAERunner>(vae_backend,
567+ offload_params_to_cpu,
568+ tensor_storage_map,
569+ " first_stage_model" ,
570+ vae_decode_only,
571+ version);
572+ first_stage_model->alloc_params_buffer ();
573+ first_stage_model->get_param_tensors (tensors, " first_stage_model" );
574+ } else {
575+ tae_first_stage = std::make_shared<TinyVideoAutoEncoder>(vae_backend,
576+ offload_params_to_cpu,
577+ tensor_storage_map,
578+ " decoder" ,
579+ vae_decode_only,
580+ version);
581+ if (sd_ctx_params->vae_conv_direct ) {
582+ LOG_INFO (" Using Conv2d direct in the tae model" );
583+ tae_first_stage->set_conv2d_direct_enabled (true );
584+ }
585+ }
573586 } else if (version == VERSION_CHROMA_RADIANCE) {
574587 first_stage_model = std::make_shared<FakeVAE>(vae_backend,
575588 offload_params_to_cpu);
@@ -596,14 +609,13 @@ class StableDiffusionGGML {
596609 }
597610 first_stage_model->alloc_params_buffer ();
598611 first_stage_model->get_param_tensors (tensors, " first_stage_model" );
599- }
600- if (use_tiny_autoencoder) {
601- tae_first_stage = std::make_shared<TinyAutoEncoder>(vae_backend,
602- offload_params_to_cpu,
603- tensor_storage_map,
604- " decoder.layers" ,
605- vae_decode_only,
606- version);
612+ } else if (use_tiny_autoencoder) {
613+ tae_first_stage = std::make_shared<TinyImageAutoEncoder>(vae_backend,
614+ offload_params_to_cpu,
615+ tensor_storage_map,
616+ " decoder.layers" ,
617+ vae_decode_only,
618+ version);
607619 if (sd_ctx_params->vae_conv_direct ) {
608620 LOG_INFO (" Using Conv2d direct in the tae model" );
609621 tae_first_stage->set_conv2d_direct_enabled (true );
@@ -3614,7 +3626,8 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
36143626 denoise_mask = ggml_new_tensor_4d (work_ctx, GGML_TYPE_F32, init_latent->ne [0 ], init_latent->ne [1 ], init_latent->ne [2 ], 1 );
36153627 ggml_set_f32 (denoise_mask, 1 .f );
36163628
3617- sd_ctx->sd ->process_latent_out (init_latent);
3629+ if (!sd_ctx->sd ->use_tiny_autoencoder )
3630+ sd_ctx->sd ->process_latent_out (init_latent);
36183631
36193632 ggml_ext_tensor_iter (init_image_latent, [&](ggml_tensor* t, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
36203633 float value = ggml_ext_tensor_get_f32 (t, i0, i1, i2, i3);
@@ -3624,7 +3637,8 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
36243637 }
36253638 });
36263639
3627- sd_ctx->sd ->process_latent_in (init_latent);
3640+ if (!sd_ctx->sd ->use_tiny_autoencoder )
3641+ sd_ctx->sd ->process_latent_in (init_latent);
36283642
36293643 int64_t t2 = ggml_time_ms ();
36303644 LOG_INFO (" encode_first_stage completed, taking %" PRId64 " ms" , t2 - t1);
@@ -3847,7 +3861,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
38473861 struct ggml_tensor * vid = sd_ctx->sd ->decode_first_stage (work_ctx, final_latent, true );
38483862 int64_t t5 = ggml_time_ms ();
38493863 LOG_INFO (" decode_first_stage completed, taking %.2fs" , (t5 - t4) * 1 .0f / 1000 );
3850- if (sd_ctx->sd ->free_params_immediately ) {
3864+ if (sd_ctx->sd ->free_params_immediately && !sd_ctx-> sd -> use_tiny_autoencoder ) {
38513865 sd_ctx->sd ->first_stage_model ->free_params_buffer ();
38523866 }
38533867
0 commit comments