Skip to content

Commit 9fa7f41

Browse files
authored
feat: add taehv support for Wan/Qwen (#937)
1 parent: a23262d · commit: 9fa7f41

File tree

5 files changed

+432
-26
lines changed

5 files changed

+432
-26
lines changed

examples/cli/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ Context Options:
3131
--high-noise-diffusion-model <string> path to the standalone high noise diffusion model
3232
--vae <string> path to standalone vae model
3333
--taesd <string> path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)
34+
--tae <string> alias of --taesd
3435
--control-net <string> path to control net model
3536
--embd-dir <string> embeddings directory
3637
--lora-model-dir <string> lora model directory

examples/common/common.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -406,6 +406,10 @@ struct SDContextParams {
406406
"--taesd",
407407
"path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)",
408408
&taesd_path},
409+
{"",
410+
"--tae",
411+
"alias of --taesd",
412+
&taesd_path},
409413
{"",
410414
"--control-net",
411415
"path to control net model",

examples/server/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ Context Options:
2424
--high-noise-diffusion-model <string> path to the standalone high noise diffusion model
2525
--vae <string> path to standalone vae model
2626
--taesd <string> path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)
27+
--tae <string> alias of --taesd
2728
--control-net <string> path to control net model
2829
--embd-dir <string> embeddings directory
2930
--lora-model-dir <string> lora model directory

stable-diffusion.cpp

Lines changed: 33 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -562,14 +562,27 @@ class StableDiffusionGGML {
562562
}
563563

564564
if (sd_version_is_wan(version) || sd_version_is_qwen_image(version)) {
565-
first_stage_model = std::make_shared<WAN::WanVAERunner>(vae_backend,
566-
offload_params_to_cpu,
567-
tensor_storage_map,
568-
"first_stage_model",
569-
vae_decode_only,
570-
version);
571-
first_stage_model->alloc_params_buffer();
572-
first_stage_model->get_param_tensors(tensors, "first_stage_model");
565+
if (!use_tiny_autoencoder) {
566+
first_stage_model = std::make_shared<WAN::WanVAERunner>(vae_backend,
567+
offload_params_to_cpu,
568+
tensor_storage_map,
569+
"first_stage_model",
570+
vae_decode_only,
571+
version);
572+
first_stage_model->alloc_params_buffer();
573+
first_stage_model->get_param_tensors(tensors, "first_stage_model");
574+
} else {
575+
tae_first_stage = std::make_shared<TinyVideoAutoEncoder>(vae_backend,
576+
offload_params_to_cpu,
577+
tensor_storage_map,
578+
"decoder",
579+
vae_decode_only,
580+
version);
581+
if (sd_ctx_params->vae_conv_direct) {
582+
LOG_INFO("Using Conv2d direct in the tae model");
583+
tae_first_stage->set_conv2d_direct_enabled(true);
584+
}
585+
}
573586
} else if (version == VERSION_CHROMA_RADIANCE) {
574587
first_stage_model = std::make_shared<FakeVAE>(vae_backend,
575588
offload_params_to_cpu);
@@ -596,14 +609,13 @@ class StableDiffusionGGML {
596609
}
597610
first_stage_model->alloc_params_buffer();
598611
first_stage_model->get_param_tensors(tensors, "first_stage_model");
599-
}
600-
if (use_tiny_autoencoder) {
601-
tae_first_stage = std::make_shared<TinyAutoEncoder>(vae_backend,
602-
offload_params_to_cpu,
603-
tensor_storage_map,
604-
"decoder.layers",
605-
vae_decode_only,
606-
version);
612+
} else if (use_tiny_autoencoder) {
613+
tae_first_stage = std::make_shared<TinyImageAutoEncoder>(vae_backend,
614+
offload_params_to_cpu,
615+
tensor_storage_map,
616+
"decoder.layers",
617+
vae_decode_only,
618+
version);
607619
if (sd_ctx_params->vae_conv_direct) {
608620
LOG_INFO("Using Conv2d direct in the tae model");
609621
tae_first_stage->set_conv2d_direct_enabled(true);
@@ -3614,7 +3626,8 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
36143626
denoise_mask = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], init_latent->ne[2], 1);
36153627
ggml_set_f32(denoise_mask, 1.f);
36163628

3617-
sd_ctx->sd->process_latent_out(init_latent);
3629+
if (!sd_ctx->sd->use_tiny_autoencoder)
3630+
sd_ctx->sd->process_latent_out(init_latent);
36183631

36193632
ggml_ext_tensor_iter(init_image_latent, [&](ggml_tensor* t, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
36203633
float value = ggml_ext_tensor_get_f32(t, i0, i1, i2, i3);
@@ -3624,7 +3637,8 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
36243637
}
36253638
});
36263639

3627-
sd_ctx->sd->process_latent_in(init_latent);
3640+
if (!sd_ctx->sd->use_tiny_autoencoder)
3641+
sd_ctx->sd->process_latent_in(init_latent);
36283642

36293643
int64_t t2 = ggml_time_ms();
36303644
LOG_INFO("encode_first_stage completed, taking %" PRId64 " ms", t2 - t1);
@@ -3847,7 +3861,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
38473861
struct ggml_tensor* vid = sd_ctx->sd->decode_first_stage(work_ctx, final_latent, true);
38483862
int64_t t5 = ggml_time_ms();
38493863
LOG_INFO("decode_first_stage completed, taking %.2fs", (t5 - t4) * 1.0f / 1000);
3850-
if (sd_ctx->sd->free_params_immediately) {
3864+
if (sd_ctx->sd->free_params_immediately && !sd_ctx->sd->use_tiny_autoencoder) {
38513865
sd_ctx->sd->first_stage_model->free_params_buffer();
38523866
}
38533867

0 commit comments

Comments (0)