From 42492941370e0aaf31dce1df29024ad2e07c81f4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20du=20Hamel?= Date: Fri, 5 Dec 2025 20:47:23 +0100 Subject: [PATCH 01/12] Support LongCat Image model --- flux.hpp | 44 +++++++++++++++++++++------- ggml_extend.hpp | 69 ++++++++++++++++++++++++++++++++++++++++++++ model.cpp | 29 ++++++++++++------- model.h | 11 ++++++- name_conversion.cpp | 8 ++++- stable-diffusion.cpp | 56 +++++++++++++++++++++-------------- 6 files changed, 173 insertions(+), 44 deletions(-) diff --git a/flux.hpp b/flux.hpp index 1df2874ae..7cd63d78f 100644 --- a/flux.hpp +++ b/flux.hpp @@ -90,10 +90,15 @@ namespace Flux { SelfAttention(int64_t dim, int64_t num_heads = 8, bool qkv_bias = false, - bool proj_bias = true) + bool proj_bias = true, + bool diffusers_style = false) : num_heads(num_heads) { int64_t head_dim = dim / num_heads; - blocks["qkv"] = std::shared_ptr(new Linear(dim, dim * 3, qkv_bias)); + if(diffusers_style) { + blocks["qkv"] = std::shared_ptr(new SplitLinear(dim, {dim, dim, dim}, qkv_bias)); + } else { + blocks["qkv"] = std::shared_ptr(new Linear(dim, dim * 3, qkv_bias)); + } blocks["norm"] = std::shared_ptr(new QKNorm(head_dim)); blocks["proj"] = std::shared_ptr(new Linear(dim, dim, proj_bias)); } @@ -258,7 +263,8 @@ namespace Flux { bool share_modulation = false, bool mlp_proj_bias = true, bool use_yak_mlp = false, - bool use_mlp_silu_act = false) + bool use_mlp_silu_act = false, + bool diffusers_style = false) : idx(idx), prune_mod(prune_mod) { int64_t mlp_hidden_dim = hidden_size * mlp_ratio; @@ -266,7 +272,7 @@ namespace Flux { blocks["img_mod"] = std::shared_ptr(new Modulation(hidden_size, true)); } blocks["img_norm1"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-6f, false)); - blocks["img_attn"] = std::shared_ptr(new SelfAttention(hidden_size, num_heads, qkv_bias, mlp_proj_bias)); + blocks["img_attn"] = std::shared_ptr(new SelfAttention(hidden_size, num_heads, qkv_bias, mlp_proj_bias, diffusers_style)); blocks["img_norm2"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-6f, false)); if (use_yak_mlp) { @@ -279,7 +285,7 @@ namespace Flux { blocks["txt_mod"] = std::shared_ptr(new Modulation(hidden_size, true)); } blocks["txt_norm1"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-6f, false)); - blocks["txt_attn"] = std::shared_ptr(new SelfAttention(hidden_size, num_heads, qkv_bias, mlp_proj_bias)); + blocks["txt_attn"] = std::shared_ptr(new SelfAttention(hidden_size, num_heads, qkv_bias, mlp_proj_bias, diffusers_style)); blocks["txt_norm2"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-6f, false)); if (use_yak_mlp) { @@ -421,6 +427,7 @@ namespace Flux { bool use_yak_mlp; bool use_mlp_silu_act; int64_t mlp_mult_factor; + bool diffusers_style = false; public: SingleStreamBlock(int64_t hidden_size, @@ -432,7 +439,8 @@ namespace Flux { bool share_modulation = false, bool mlp_proj_bias = true, bool use_yak_mlp = false, - bool use_mlp_silu_act = false) + bool use_mlp_silu_act = false, + bool diffusers_style = false) : hidden_size(hidden_size), num_heads(num_heads), idx(idx), prune_mod(prune_mod), use_yak_mlp(use_yak_mlp), use_mlp_silu_act(use_mlp_silu_act) { int64_t head_dim = hidden_size / num_heads; float scale = qk_scale; @@ -444,8 +452,11 @@ namespace Flux { if (use_yak_mlp || use_mlp_silu_act) { mlp_mult_factor = 2; } - - blocks["linear1"] = std::shared_ptr(new Linear(hidden_size, hidden_size * 3 + mlp_hidden_dim * mlp_mult_factor, mlp_proj_bias)); + if (diffusers_style) { + blocks["linear1"] = std::shared_ptr(new SplitLinear(hidden_size, {hidden_size, hidden_size, hidden_size, mlp_hidden_dim * mlp_mult_factor}, mlp_proj_bias)); + } else { + blocks["linear1"] = std::shared_ptr(new Linear(hidden_size, hidden_size * 3 + mlp_hidden_dim * mlp_mult_factor, mlp_proj_bias)); + } blocks["linear2"] = std::shared_ptr(new Linear(hidden_size + mlp_hidden_dim, hidden_size, mlp_proj_bias)); blocks["norm"] = std::shared_ptr(new QKNorm(head_dim)); blocks["pre_norm"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-6f, false)); @@ -772,6 +783,7 @@ namespace Flux { bool use_mlp_silu_act = false; float ref_index_scale = 1.f; ChromaRadianceParams chroma_radiance_params; + bool diffusers_style = false; }; struct Flux : public GGMLBlock { @@ -817,7 +829,8 @@ namespace Flux { params.share_modulation, !params.disable_bias, params.use_yak_mlp, - params.use_mlp_silu_act); + params.use_mlp_silu_act, + params.diffusers_style); } for (int i = 0; i < params.depth_single_blocks; i++) { @@ -830,7 +843,8 @@ namespace Flux { params.share_modulation, !params.disable_bias, params.use_yak_mlp, - params.use_mlp_silu_act); + params.use_mlp_silu_act, + params.diffusers_style); } if (params.version == VERSION_CHROMA_RADIANCE) { @@ -1281,6 +1295,9 @@ namespace Flux { flux_params.share_modulation = true; flux_params.ref_index_scale = 10.f; flux_params.use_mlp_silu_act = true; + } else if (sd_version_is_longcat(version)) { + flux_params.context_in_dim = 3584; + flux_params.vec_in_dim = 0; } for (auto pair : tensor_storage_map) { std::string tensor_name = pair.first; @@ -1290,6 +1307,9 @@ namespace Flux { // not schnell flux_params.guidance_embed = true; } + if (tensor_name.find("model.diffusion_model.single_blocks.0.linear1.weight.1") == std::string::npos) { + flux_params.diffusers_style = true; + } if (tensor_name.find("distilled_guidance_layer.in_proj.weight") != std::string::npos) { // Chroma flux_params.is_chroma = true; @@ -1319,6 +1339,10 @@ namespace Flux { LOG_INFO("Flux guidance is disabled (Schnell mode)"); } + if (flux_params.diffusers_style) { + LOG_INFO("Using diffusers-style naming"); + } + flux = Flux(flux_params); flux.init(params_ctx, tensor_storage_map, prefix); } diff --git a/ggml_extend.hpp b/ggml_extend.hpp index 5024eb911..57b0fff80 100644 --- a/ggml_extend.hpp +++ b/ggml_extend.hpp @@ -2173,6 +2173,75 @@ class Linear : public UnaryBlock { } }; +class SplitLinear : public Linear { +protected: + int64_t in_features; + std::vector out_features_vec; + bool bias; + bool force_f32; + bool force_prec_f32; + float scale; + std::string prefix; + + void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override { + this->prefix = prefix; + enum ggml_type wtype = get_type(prefix + "weight", tensor_storage_map, GGML_TYPE_F32); + if (in_features % ggml_blck_size(wtype) != 0 || force_f32) { + wtype = GGML_TYPE_F32; + } + params["weight"] = ggml_new_tensor_2d(ctx, wtype, in_features, out_features_vec[0]); + for (int i = 1; i < out_features_vec.size(); i++) { + // most likely same type as the first weight + params["weight." + std::to_string(i)] = ggml_new_tensor_2d(ctx, wtype, in_features, out_features_vec[i]); + } + if (bias) { + enum ggml_type wtype = GGML_TYPE_F32; + params["bias"] = ggml_new_tensor_1d(ctx, wtype, out_features_vec[0]); + for (int i = 1; i < out_features_vec.size(); i++) { + params["bias." + std::to_string(i)] = ggml_new_tensor_1d(ctx, wtype, out_features_vec[i]); + } + } + } + +public: + SplitLinear(int64_t in_features, + std::vector out_features_vec, + bool bias = true, + bool force_f32 = false, + bool force_prec_f32 = false, + float scale = 1.f) + : Linear(in_features, out_features_vec[0], bias, force_f32, force_prec_f32, scale), + in_features(in_features), + out_features_vec(out_features_vec), + bias(bias), + force_f32(force_f32), + force_prec_f32(force_prec_f32), + scale(scale) {} + + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) { + struct ggml_tensor* w = params["weight"]; + struct ggml_tensor* b = nullptr; + if (bias) { + b = params["bias"]; + } + // concat all weights and biases together + for (int i = 1; i < out_features_vec.size(); i++) { + w = ggml_concat(ctx->ggml_ctx, w, params["weight." + std::to_string(i)], 1); + if (bias) { + b = ggml_concat(ctx->ggml_ctx, b, params["bias." + std::to_string(i)], 0); + } + } + if (ctx->weight_adapter) { + WeightAdapter::ForwardParams forward_params; + forward_params.op_type = WeightAdapter::ForwardParams::op_type_t::OP_LINEAR; + forward_params.linear.force_prec_f32 = force_prec_f32; + forward_params.linear.scale = scale; + return ctx->weight_adapter->forward_with_lora(ctx->ggml_ctx, x, w, b, prefix, forward_params); + } + return ggml_ext_linear(ctx->ggml_ctx, x, w, b, force_prec_f32, scale); + } +}; + __STATIC_INLINE__ bool support_get_rows(ggml_type wtype) { std::set allow_types = {GGML_TYPE_F16, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0}; if (allow_types.find(wtype) != allow_types.end()) { diff --git a/model.cpp b/model.cpp index 0480efefb..135a2108c 100644 --- a/model.cpp +++ b/model.cpp @@ -1027,7 +1027,7 @@ bool ModelLoader::init_from_ckpt_file(const std::string& file_path, const std::s } SDVersion ModelLoader::get_sd_version() { - TensorStorage token_embedding_weight, input_block_weight; + TensorStorage token_embedding_weight, input_block_weight, context_ebedding_weight; bool has_multiple_encoders = false; bool is_unet = false; @@ -1041,7 +1041,7 @@ SDVersion ModelLoader::get_sd_version() { for (auto& [name, tensor_storage] : tensor_storage_map) { if (!(is_xl)) { - if (tensor_storage.name.find("model.diffusion_model.double_blocks.") != std::string::npos) { + if (tensor_storage.name.find("model.diffusion_model.double_blocks.") != std::string::npos || tensor_storage.name.find("model.diffusion_model.single_transformer_blocks.") != std::string::npos) { is_flux = true; } if (tensor_storage.name.find("model.diffusion_model.nerf_final_layer_conv.") != std::string::npos) { @@ -1108,6 +1108,9 @@ SDVersion ModelLoader::get_sd_version() { tensor_storage.name == "unet.conv_in.weight") { input_block_weight = tensor_storage; } + if (tensor_storage.name == "model.diffusion_model.txt_in.weight" || tensor_storage.name == "model.diffusion_model.context_embedder.weight") { + context_ebedding_weight = tensor_storage; + } } if (is_wan) { LOG_DEBUG("patch_embedding_channels %d", patch_embedding_channels); @@ -1135,16 +1138,20 @@ SDVersion ModelLoader::get_sd_version() { } if (is_flux) { - if (input_block_weight.ne[0] == 384) { - return VERSION_FLUX_FILL; - } - if (input_block_weight.ne[0] == 128) { - return VERSION_FLUX_CONTROLS; - } - if (input_block_weight.ne[0] == 196) { - return VERSION_FLEX_2; + if (context_ebedding_weight.ne[0] == 3584) { + return VERSION_LONGCAT; + } else { + if (input_block_weight.ne[0] == 384) { + return VERSION_FLUX_FILL; + } + if (input_block_weight.ne[0] == 128) { + return VERSION_FLUX_CONTROLS; + } + if (input_block_weight.ne[0] == 196) { + return VERSION_FLEX_2; + } + return VERSION_FLUX; } - return VERSION_FLUX; } if (token_embedding_weight.ne[0] == 768) { diff --git a/model.h b/model.h index d38aee1c1..27af3d9c7 100644 --- a/model.h +++ b/model.h @@ -46,6 +46,7 @@ enum SDVersion { VERSION_FLUX2, VERSION_Z_IMAGE, VERSION_OVIS_IMAGE, + VERSION_LONGCAT, VERSION_COUNT, }; @@ -126,6 +127,13 @@ static inline bool sd_version_is_z_image(SDVersion version) { return false; } +static inline bool sd_version_is_longcat(SDVersion version) { + if (version == VERSION_LONGCAT) { + return true; + } + return false; +} + static inline bool sd_version_is_inpaint(SDVersion version) { if (version == VERSION_SD1_INPAINT || version == VERSION_SD2_INPAINT || @@ -143,7 +151,8 @@ static inline bool sd_version_is_dit(SDVersion version) { sd_version_is_sd3(version) || sd_version_is_wan(version) || sd_version_is_qwen_image(version) || - sd_version_is_z_image(version)) { + sd_version_is_z_image(version) || + sd_version_is_longcat(version)) { return true; } return false; diff --git a/name_conversion.cpp b/name_conversion.cpp index 8b521486d..1a37dd25c 100644 --- a/name_conversion.cpp +++ b/name_conversion.cpp @@ -508,6 +508,12 @@ std::string convert_diffusers_dit_to_original_flux(std::string name) { static std::unordered_map flux_name_map; if (flux_name_map.empty()) { + // --- time_embed (longcat) --- + flux_name_map["time_embed.timestep_embedder.linear_1.weight"] = "time_in.in_layer.weight"; + flux_name_map["time_embed.timestep_embedder.linear_1.bias"] = "time_in.in_layer.bias"; + flux_name_map["time_embed.timestep_embedder.linear_2.weight"] = "time_in.out_layer.weight"; + flux_name_map["time_embed.timestep_embedder.linear_2.bias"] = "time_in.out_layer.bias"; + // --- time_text_embed --- flux_name_map["time_text_embed.timestep_embedder.linear_1.weight"] = "time_in.in_layer.weight"; flux_name_map["time_text_embed.timestep_embedder.linear_1.bias"] = "time_in.in_layer.bias"; @@ -660,7 +666,7 @@ std::string convert_diffusion_model_name(std::string name, std::string prefix, S name = convert_diffusers_unet_to_original_sdxl(name); } else if (sd_version_is_sd3(version)) { name = convert_diffusers_dit_to_original_sd3(name); - } else if (sd_version_is_flux(version) || sd_version_is_flux2(version)) { + } else if (sd_version_is_flux(version) || sd_version_is_flux2(version) || sd_version_is_longcat(version)) { name = convert_diffusers_dit_to_original_flux(name); } else if (sd_version_is_z_image(version)) { name = convert_diffusers_dit_to_original_lumina2(name); diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 1ef851247..73f832f73 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -47,6 +47,7 @@ const char* model_version_to_str[] = { "Flux.2", "Z-Image", "Ovis Image", + "Longcat-Image", }; const char* sampling_methods_str[] = { @@ -372,7 +373,7 @@ class StableDiffusionGGML { } else if (sd_version_is_sd3(version)) { scale_factor = 1.5305f; shift_factor = 0.0609f; - } else if (sd_version_is_flux(version) || sd_version_is_z_image(version)) { + } else if (sd_version_is_flux(version) || sd_version_is_z_image(version) || sd_version_is_longcat(version)) { scale_factor = 0.3611f; shift_factor = 0.1159f; } else if (sd_version_is_wan(version) || @@ -400,8 +401,8 @@ class StableDiffusionGGML { offload_params_to_cpu, tensor_storage_map); diffusion_model = std::make_shared(backend, - offload_params_to_cpu, - tensor_storage_map); + offload_params_to_cpu, + tensor_storage_map); } else if (sd_version_is_flux(version)) { bool is_chroma = false; for (auto pair : tensor_storage_map) { @@ -449,10 +450,23 @@ class StableDiffusionGGML { tensor_storage_map, version); diffusion_model = std::make_shared(backend, - offload_params_to_cpu, - tensor_storage_map, - version, - sd_ctx_params->chroma_use_dit_mask); + offload_params_to_cpu, + tensor_storage_map, + version, + sd_ctx_params->chroma_use_dit_mask); + } else if (sd_version_is_longcat(version)) { + bool enable_vision = false; + cond_stage_model = std::make_shared(clip_backend, + offload_params_to_cpu, + tensor_storage_map, + version, + "", + enable_vision); + diffusion_model = std::make_shared(backend, + offload_params_to_cpu, + tensor_storage_map, + version, + sd_ctx_params->chroma_use_dit_mask); } else if (sd_version_is_wan(version)) { cond_stage_model = std::make_shared(clip_backend, offload_params_to_cpu, @@ -461,10 +475,10 @@ class StableDiffusionGGML { 1, true); diffusion_model = std::make_shared(backend, - offload_params_to_cpu, - tensor_storage_map, - "model.diffusion_model", - version); + offload_params_to_cpu, + tensor_storage_map, + "model.diffusion_model", + version); if (strlen(SAFE_STR(sd_ctx_params->high_noise_diffusion_model_path)) > 0) { high_noise_diffusion_model = std::make_shared(backend, offload_params_to_cpu, @@ -493,20 +507,20 @@ class StableDiffusionGGML { "", enable_vision); diffusion_model = std::make_shared(backend, - offload_params_to_cpu, - tensor_storage_map, - "model.diffusion_model", - version); + offload_params_to_cpu, + tensor_storage_map, + "model.diffusion_model", + version); } else if (sd_version_is_z_image(version)) { cond_stage_model = std::make_shared(clip_backend, offload_params_to_cpu, tensor_storage_map, version); diffusion_model = std::make_shared(backend, - offload_params_to_cpu, - tensor_storage_map, - "model.diffusion_model", - version); + offload_params_to_cpu, + tensor_storage_map, + "model.diffusion_model", + version); } else { // SD1.x SD2.x SDXL std::map embbeding_map; for (int i = 0; i < sd_ctx_params->embedding_count; i++) { @@ -827,7 +841,7 @@ class StableDiffusionGGML { flow_shift = 3.f; } } - } else if (sd_version_is_flux(version)) { + } else if (sd_version_is_flux(version) || sd_version_is_longcat(version)) { pred_type = FLUX_FLOW_PRED; if (flow_shift == INFINITY) { flow_shift = 1.0f; // TODO: validate @@ -1341,7 +1355,7 @@ class StableDiffusionGGML { if (sd_version_is_sd3(version)) { latent_rgb_proj = sd3_latent_rgb_proj; latent_rgb_bias = sd3_latent_rgb_bias; - } else if (sd_version_is_flux(version) || sd_version_is_z_image(version)) { + } else if (sd_version_is_flux(version) || sd_version_is_z_image(version) || sd_version_is_longcat(version)) { latent_rgb_proj = flux_latent_rgb_proj; latent_rgb_bias = flux_latent_rgb_bias; } else if (sd_version_is_wan(version) || sd_version_is_qwen_image(version)) { From 52ef50a7ce94dc329501aa30b10b32a8159b3060 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20du=20Hamel?= Date: Fri, 5 Dec 2025 20:58:53 +0100 Subject: [PATCH 02/12] temp fix cuda error on quant concat for splitlinear --- ggml_extend.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml_extend.hpp b/ggml_extend.hpp index 57b0fff80..0fcbbb960 100644 --- a/ggml_extend.hpp +++ b/ggml_extend.hpp @@ -2214,7 +2214,7 @@ class SplitLinear : public Linear { in_features(in_features), out_features_vec(out_features_vec), bias(bias), - force_f32(force_f32), + force_f32(true), force_prec_f32(force_prec_f32), scale(scale) {} From 7ba7febef2143ba32db0c9942d3e898070d4a010 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20du=20Hamel?= Date: Sat, 6 Dec 2025 02:43:46 +0100 Subject: [PATCH 03/12] pre-patchify --- flux.hpp | 1 + stable-diffusion.cpp | 21 ++++++++++++++++----- vae.hpp | 4 ++-- 3 files changed, 19 insertions(+), 7 deletions(-) diff --git a/flux.hpp b/flux.hpp index 7cd63d78f..758a3d578 100644 --- a/flux.hpp +++ b/flux.hpp @@ -1298,6 +1298,7 @@ namespace Flux { } else if (sd_version_is_longcat(version)) { flux_params.context_in_dim = 3584; flux_params.vec_in_dim = 0; + flux_params.patch_size = 1; } for (auto pair : tensor_storage_map) { std::string tensor_name = pair.first; diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 73f832f73..eed5b0d3c 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -450,10 +450,10 @@ class StableDiffusionGGML { tensor_storage_map, version); diffusion_model = std::make_shared(backend, - offload_params_to_cpu, - tensor_storage_map, - version, - sd_ctx_params->chroma_use_dit_mask); + offload_params_to_cpu, + tensor_storage_map, + version, + sd_ctx_params->chroma_use_dit_mask); } else if (sd_version_is_longcat(version)) { bool enable_vision = false; cond_stage_model = std::make_shared(clip_backend, @@ -850,6 +850,9 @@ class StableDiffusionGGML { flow_shift = 1.15f; } } + if(sd_version_is_longcat(version)) { + flow_shift = 3.0f; + } } } else if (sd_version_is_flux2(version)) { pred_type = FLUX2_FLOW_PRED; @@ -1338,6 +1341,12 @@ class StableDiffusionGGML { if (sd_version_is_flux2(version)) { latent_rgb_proj = flux2_latent_rgb_proj; latent_rgb_bias = flux2_latent_rgb_bias; + patch_sz = 2; + } + } else if (dim == 64) { + if (sd_version_is_flux(version) || sd_version_is_z_image(version) || sd_version_is_longcat(version)) { + latent_rgb_proj = flux_latent_rgb_proj; + latent_rgb_bias = flux_latent_rgb_bias; patch_sz = 2; } } else if (dim == 48) { @@ -1904,7 +1913,7 @@ class StableDiffusionGGML { int vae_scale_factor = 8; if (version == VERSION_WAN2_2_TI2V) { vae_scale_factor = 16; - } else if (sd_version_is_flux2(version)) { + } else if (sd_version_is_flux2(version) || sd_version_is_longcat(version)) { vae_scale_factor = 16; } else if (version == VERSION_CHROMA_RADIANCE) { vae_scale_factor = 1; @@ -1933,6 +1942,8 @@ class StableDiffusionGGML { latent_channel = 3; } else if (sd_version_is_flux2(version)) { latent_channel = 128; + } else if (sd_version_is_longcat(version)) { + latent_channel = 64; } else { latent_channel = 16; } diff --git a/vae.hpp b/vae.hpp index ad5db1b57..740a5655b 100644 --- a/vae.hpp +++ b/vae.hpp @@ -553,7 +553,7 @@ class AutoencodingEngine : public GGMLBlock { struct ggml_tensor* decode(GGMLRunnerContext* ctx, struct ggml_tensor* z) { // z: [N, z_channels, h, w] - if (sd_version_is_flux2(version)) { + if (sd_version_is_flux2(version) || sd_version_is_longcat(version)) { // [N, C*p*p, h, w] -> [N, C, h*p, w*p] int64_t p = 2; @@ -592,7 +592,7 @@ class AutoencodingEngine : public GGMLBlock { auto quant_conv = std::dynamic_pointer_cast(blocks["quant_conv"]); z = quant_conv->forward(ctx, z); // [N, 2*embed_dim, h/8, w/8] } - if (sd_version_is_flux2(version)) { + if (sd_version_is_flux2(version) || sd_version_is_longcat(version)) { z = ggml_ext_chunk(ctx->ggml_ctx, z, 2, 2)[0]; // [N, C, H, W] -> [N, C*p*p, H/p, W/p] From 1241323c4a4372adade74c2b49e0402c61330873 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20du=20Hamel?= Date: Sat, 6 Dec 2025 02:44:20 +0100 Subject: [PATCH 04/12] longcat rope ids --- conditioner.hpp | 11 +++++++++++ flux.hpp | 7 +++---- ggml_extend.hpp | 26 +++++++++++++++++--------- rope.hpp | 35 +++++++++++++++++++++++++---------- 4 files changed, 56 insertions(+), 23 deletions(-) diff --git a/conditioner.hpp b/conditioner.hpp index 55e1502e8..33857eb3a 100644 --- a/conditioner.hpp +++ b/conditioner.hpp @@ -1807,6 +1807,17 @@ struct LLMEmbedder : public Conditioner { prompt_attn_range.second = static_cast(prompt.size()); prompt += "<|im_end|>\n<|im_start|>assistant\n\n\n\n\n"; + } else if (sd_version_is_longcat(version)) { + prompt_template_encode_start_idx = 36; + // prompt_template_encode_end_idx = 5; + + prompt = "<|im_start|>system\nAs an image captioning expert, generate a descriptive text prompt based on an image content, suitable for input to a text-to-image model.<|im_end|>\n<|im_start|>user\n"; + + prompt_attn_range.first = static_cast(prompt.size()); + prompt += conditioner_params.text; + prompt_attn_range.second = static_cast(prompt.size()); + + prompt += "<|im_end|>\n<|im_start|>assistant\n"; } else { prompt_template_encode_start_idx = 34; diff --git a/flux.hpp b/flux.hpp index 758a3d578..d0be65bad 100644 --- a/flux.hpp +++ b/flux.hpp @@ -1341,7 +1341,7 @@ namespace Flux { } if (flux_params.diffusers_style) { - LOG_INFO("Using diffusers-style naming"); + LOG_INFO("Using diffusers-style attention blocks"); } flux = Flux(flux_params); @@ -1455,7 +1455,6 @@ namespace Flux { } else if (version == VERSION_OVIS_IMAGE) { txt_arange_dims = {1, 2}; } - pe_vec = Rope::gen_flux_pe(x->ne[1], x->ne[0], flux_params.patch_size, @@ -1466,9 +1465,9 @@ namespace Flux { increase_ref_index, flux_params.ref_index_scale, flux_params.theta, - flux_params.axes_dim); + flux_params.axes_dim, + sd_version_is_longcat(version)); int pos_len = pe_vec.size() / flux_params.axes_dim_sum / 2; - // LOG_DEBUG("pos_len %d", pos_len); auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, flux_params.axes_dim_sum / 2, pos_len); // pe->data = pe_vec.data(); // print_ggml_tensor(pe); diff --git a/ggml_extend.hpp b/ggml_extend.hpp index 0fcbbb960..3d5020768 100644 --- a/ggml_extend.hpp +++ b/ggml_extend.hpp @@ -2214,7 +2214,7 @@ class SplitLinear : public Linear { in_features(in_features), out_features_vec(out_features_vec), bias(bias), - force_f32(true), + force_f32(force_f32), force_prec_f32(force_prec_f32), scale(scale) {} @@ -2224,21 +2224,29 @@ class SplitLinear : public Linear { if (bias) { b = params["bias"]; } - // concat all weights and biases together - for (int i = 1; i < out_features_vec.size(); i++) { - w = ggml_concat(ctx->ggml_ctx, w, params["weight." + std::to_string(i)], 1); - if (bias) { - b = ggml_concat(ctx->ggml_ctx, b, params["bias." + std::to_string(i)], 0); - } - } if (ctx->weight_adapter) { + // concat all weights and biases together so it runs in one linear layer + for (int i = 1; i < out_features_vec.size(); i++) { + w = ggml_concat(ctx->ggml_ctx, w, params["weight." + std::to_string(i)], 1); + if (bias) { + b = ggml_concat(ctx->ggml_ctx, b, params["bias." + std::to_string(i)], 0); + } + } WeightAdapter::ForwardParams forward_params; forward_params.op_type = WeightAdapter::ForwardParams::op_type_t::OP_LINEAR; forward_params.linear.force_prec_f32 = force_prec_f32; forward_params.linear.scale = scale; return ctx->weight_adapter->forward_with_lora(ctx->ggml_ctx, x, w, b, prefix, forward_params); } - return ggml_ext_linear(ctx->ggml_ctx, x, w, b, force_prec_f32, scale); + auto x0 = ggml_ext_linear(ctx->ggml_ctx, x, w, b, force_prec_f32, scale); + for (int i = 1; i < out_features_vec.size(); i++) { + auto wi = params["weight." + std::to_string(i)]; + auto bi = bias ? params["bias." + std::to_string(i)] : nullptr; + auto xi = ggml_ext_linear(ctx->ggml_ctx, x, wi, bi, force_prec_f32, scale); + x0 = ggml_concat(ctx->ggml_ctx, x0, xi, 0); + } + + return x0; } }; diff --git a/rope.hpp b/rope.hpp index 4abc51469..95def626f 100644 --- a/rope.hpp +++ b/rope.hpp @@ -84,7 +84,16 @@ namespace Rope { return txt_ids; } - __STATIC_INLINE__ std::vector> gen_flux_img_ids(int h, + __STATIC_INLINE__ std::vector> gen_longcat_txt_ids(int bs, int context_len, int axes_dim_num) { + auto txt_ids = std::vector>(bs * context_len, std::vector(axes_dim_num, 0.0f)); + for (int i = 0; i < bs * context_len; i++) { + txt_ids[i][1] = (i % context_len); + txt_ids[i][2] = (i % context_len); + } + return txt_ids; + } + + __STATIC_INLINE__ std::vector> gen_flux_img_ids(int h, int w, int patch_size, int bs, @@ -94,7 +103,6 @@ namespace Rope { int w_offset = 0) { int h_len = (h + (patch_size / 2)) / patch_size; int w_len = (w + (patch_size / 2)) / patch_size; - std::vector> img_ids(h_len * w_len, std::vector(axes_dim_num, 0.0)); std::vector row_ids = linspace(h_offset, h_len - 1 + h_offset, h_len); @@ -169,13 +177,14 @@ namespace Rope { __STATIC_INLINE__ std::vector> gen_refs_ids(int patch_size, int bs, int axes_dim_num, + int start_index, const std::vector& ref_latents, bool increase_ref_index, float ref_index_scale) { std::vector> ids; uint64_t curr_h_offset = 0; uint64_t curr_w_offset = 0; - int index = 1; + int index = start_index; for (ggml_tensor* ref : ref_latents) { uint64_t h_offset = 0; uint64_t w_offset = 0; @@ -216,13 +225,17 @@ namespace Rope { std::set txt_arange_dims, const std::vector& ref_latents, bool increase_ref_index, - float ref_index_scale) { - auto txt_ids = gen_flux_txt_ids(bs, context_len, axes_dim_num, txt_arange_dims); - auto img_ids = gen_flux_img_ids(h, w, patch_size, bs, axes_dim_num); + float ref_index_scale, + bool is_longcat) { + int start_index = is_longcat ? 1 : 0; + + auto txt_ids = is_longcat ? gen_longcat_txt_ids(bs, context_len, axes_dim_num) : gen_flux_txt_ids(bs, context_len, axes_dim_num, txt_arange_dims); + int offset = is_longcat ? context_len : 0; + auto img_ids = gen_flux_img_ids(h, w, patch_size, bs, axes_dim_num, start_index, offset, offset); auto ids = concat_ids(txt_ids, img_ids, bs); if (ref_latents.size() > 0) { - auto refs_ids = gen_refs_ids(patch_size, bs, axes_dim_num, ref_latents, increase_ref_index, ref_index_scale); + auto refs_ids = gen_refs_ids(patch_size, bs, axes_dim_num, start_index + 1, ref_latents, increase_ref_index, ref_index_scale); ids = concat_ids(ids, refs_ids, bs); } return ids; @@ -239,7 +252,8 @@ namespace Rope { bool increase_ref_index, float ref_index_scale, int theta, - const std::vector& axes_dim) { + const std::vector& axes_dim, + bool is_longcat) { std::vector> ids = gen_flux_ids(h, w, patch_size, @@ -249,7 +263,8 @@ namespace Rope { txt_arange_dims, ref_latents, increase_ref_index, - ref_index_scale); + ref_index_scale, + is_longcat); return embed_nd(ids, bs, theta, axes_dim); } @@ -274,7 +289,7 @@ namespace Rope { auto img_ids = gen_flux_img_ids(h, w, patch_size, bs, axes_dim_num); auto ids = concat_ids(txt_ids_repeated, img_ids, bs); if (ref_latents.size() > 0) { - auto refs_ids = gen_refs_ids(patch_size, bs, axes_dim_num, ref_latents, increase_ref_index, 1.f); + auto refs_ids = gen_refs_ids(patch_size, bs, axes_dim_num, 1, ref_latents, increase_ref_index, 1.f); ids = concat_ids(ids, refs_ids, bs); } return ids; From 203d0539fe6e745b4d5fb2b7ccc1ab384a71ba45 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20du=20Hamel?= Date: Sat, 6 Dec 2025 03:47:52 +0100 Subject: [PATCH 05/12] Fix diffusers_style detection --- flux.hpp | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/flux.hpp b/flux.hpp index d0be65bad..c500302c4 100644 --- a/flux.hpp +++ b/flux.hpp @@ -88,19 +88,19 @@ namespace Flux { public: SelfAttention(int64_t dim, - int64_t num_heads = 8, - bool qkv_bias = false, - bool proj_bias = true, - bool diffusers_style = false) + int64_t num_heads = 8, + bool qkv_bias = false, + bool proj_bias = true, + bool diffusers_style = false) : num_heads(num_heads) { int64_t head_dim = dim / num_heads; - if(diffusers_style) { - blocks["qkv"] = std::shared_ptr(new SplitLinear(dim, {dim, dim, dim}, qkv_bias)); + if (diffusers_style) { + blocks["qkv"] = std::shared_ptr(new SplitLinear(dim, {dim, dim, dim}, qkv_bias)); } else { - blocks["qkv"] = std::shared_ptr(new Linear(dim, dim * 3, qkv_bias)); + blocks["qkv"] = std::shared_ptr(new Linear(dim, dim * 3, qkv_bias)); } - blocks["norm"] = std::shared_ptr(new QKNorm(head_dim)); - blocks["proj"] = std::shared_ptr(new Linear(dim, dim, proj_bias)); + blocks["norm"] = std::shared_ptr(new QKNorm(head_dim)); + blocks["proj"] = std::shared_ptr(new Linear(dim, dim, proj_bias)); } std::vector pre_attention(GGMLRunnerContext* ctx, struct ggml_tensor* x) { @@ -782,8 +782,8 @@ namespace Flux { bool use_yak_mlp = false; bool use_mlp_silu_act = false; float ref_index_scale = 1.f; + bool diffusers_style = false; ChromaRadianceParams chroma_radiance_params; - bool diffusers_style = false; }; struct Flux : public GGMLBlock { @@ -1308,7 +1308,7 @@ namespace Flux { // not schnell flux_params.guidance_embed = true; } - if (tensor_name.find("model.diffusion_model.single_blocks.0.linear1.weight.1") == std::string::npos) { + if (tensor_name.find("model.diffusion_model.single_blocks.0.linear1.weight.1") != std::string::npos) { flux_params.diffusers_style = true; } if (tensor_name.find("distilled_guidance_layer.in_proj.weight") != std::string::npos) { @@ -1466,9 +1466,9 @@ namespace Flux { flux_params.ref_index_scale, flux_params.theta, flux_params.axes_dim, - sd_version_is_longcat(version)); + sd_version_is_longcat(version)); int pos_len = pe_vec.size() / flux_params.axes_dim_sum / 2; - auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, flux_params.axes_dim_sum / 2, pos_len); + auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, flux_params.axes_dim_sum / 2, pos_len); // pe->data = pe_vec.data(); // print_ggml_tensor(pe); // pe->data = nullptr; From 37c5e3eca4fad326cd45413d0fbe18652cebebd4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20du=20Hamel?= Date: Sat, 6 Dec 2025 16:05:58 +0100 Subject: [PATCH 06/12] Flux: simplify when patch_size is 1 --- flux.hpp | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/flux.hpp b/flux.hpp index c500302c4..fc3098780 100644 --- a/flux.hpp +++ b/flux.hpp @@ -891,6 +891,11 @@ namespace Flux { int64_t C = x->ne[2]; int64_t H = x->ne[1]; int64_t W = x->ne[0]; + if (params.patch_size == 1) { + x = ggml_reshape_3d(ctx, x, H * W, C, N); // [N, C, H*W] + x = ggml_cont(ctx, ggml_permute(ctx, x, 1, 0, 2, 3)); // [N, H*W, C] + return x; + } int64_t p = params.patch_size; int64_t h = H / params.patch_size; int64_t w = W / params.patch_size; @@ -925,6 +930,12 @@ namespace Flux { int64_t W = w * params.patch_size; int64_t p = params.patch_size; + if (params.patch_size == 1) { + x = ggml_cont(ctx, ggml_permute(ctx, x, 1, 0, 2, 3)); // [N, C, H*W] + x = ggml_reshape_4d(ctx, x, W, H, C, N); // [N, C, H, W] + return x; + } + GGML_ASSERT(C * p * p == x->ne[0]); x = ggml_reshape_4d(ctx, x, p * p, C, w * h, N); // [N, h*w, C, p*p] From a907fe28513b543ff948f7fe6b6f67caa2f1943f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20du=20Hamel?= Date: Sat, 6 Dec 2025 16:06:32 +0100 Subject: [PATCH 07/12] correct rope offset for image tokens stuff --- ggml_extend.hpp | 12 ++++++------ rope.hpp | 13 +++++++------ stable-diffusion.cpp | 6 +++++- 3 files changed, 18 insertions(+), 13 deletions(-) diff --git a/ggml_extend.hpp b/ggml_extend.hpp index 3d5020768..1a630ac39 100644 --- a/ggml_extend.hpp +++ b/ggml_extend.hpp @@ -2238,15 +2238,15 @@ class SplitLinear : public Linear { forward_params.linear.scale = scale; return ctx->weight_adapter->forward_with_lora(ctx->ggml_ctx, x, w, b, prefix, forward_params); } - auto x0 = ggml_ext_linear(ctx->ggml_ctx, x, w, b, force_prec_f32, scale); + auto out = ggml_ext_linear(ctx->ggml_ctx, x, w, b, force_prec_f32, scale); for (int i = 1; i < out_features_vec.size(); i++) { - auto wi = params["weight." + std::to_string(i)]; - auto bi = bias ? params["bias." + std::to_string(i)] : nullptr; - auto xi = ggml_ext_linear(ctx->ggml_ctx, x, wi, bi, force_prec_f32, scale); - x0 = ggml_concat(ctx->ggml_ctx, x0, xi, 0); + auto wi = params["weight." + std::to_string(i)]; + auto bi = bias ? params["bias." + std::to_string(i)] : nullptr; + auto curr_out = ggml_ext_linear(ctx->ggml_ctx, x, wi, bi, force_prec_f32, scale); + out = ggml_concat(ctx->ggml_ctx, out, curr_out, 0); } - return x0; + return out; } }; diff --git a/rope.hpp b/rope.hpp index 95def626f..0c18c0a02 100644 --- a/rope.hpp +++ b/rope.hpp @@ -180,10 +180,11 @@ namespace Rope { int start_index, const std::vector& ref_latents, bool increase_ref_index, - float ref_index_scale) { + float ref_index_scale, + int base_offset = 0) { std::vector> ids; - uint64_t curr_h_offset = 0; - uint64_t curr_w_offset = 0; + uint64_t curr_h_offset = base_offset; + uint64_t curr_w_offset = base_offset; int index = start_index; for (ggml_tensor* ref : ref_latents) { uint64_t h_offset = 0; @@ -227,15 +228,15 @@ namespace Rope { bool increase_ref_index, float ref_index_scale, bool is_longcat) { - int start_index = is_longcat ? 1 : 0; + int x_index = is_longcat ? 1 : 0; auto txt_ids = is_longcat ? gen_longcat_txt_ids(bs, context_len, axes_dim_num) : gen_flux_txt_ids(bs, context_len, axes_dim_num, txt_arange_dims); int offset = is_longcat ? context_len : 0; - auto img_ids = gen_flux_img_ids(h, w, patch_size, bs, axes_dim_num, start_index, offset, offset); + auto img_ids = gen_flux_img_ids(h, w, patch_size, bs, axes_dim_num, x_index, offset, offset); auto ids = concat_ids(txt_ids, img_ids, bs); if (ref_latents.size() > 0) { - auto refs_ids = gen_refs_ids(patch_size, bs, axes_dim_num, start_index + 1, ref_latents, increase_ref_index, ref_index_scale); + auto refs_ids = gen_refs_ids(patch_size, bs, axes_dim_num, x_index + 1, ref_latents, increase_ref_index, ref_index_scale, offset); ids = concat_ids(ids, refs_ids, bs); } return ids; diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index eed5b0d3c..1e8f04aa1 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -456,6 +456,9 @@ class StableDiffusionGGML { sd_ctx_params->chroma_use_dit_mask); } else if (sd_version_is_longcat(version)) { bool enable_vision = false; + if (!vae_decode_only) { + enable_vision = true; + } cond_stage_model = std::make_shared(clip_backend, offload_params_to_cpu, tensor_storage_map, @@ -850,7 +853,7 @@ class StableDiffusionGGML { flow_shift = 1.15f; } } - if(sd_version_is_longcat(version)) { + if (sd_version_is_longcat(version)) { flow_shift = 3.0f; } } @@ -2244,6 +2247,7 @@ class StableDiffusionGGML { sd_version_is_qwen_image(version) || sd_version_is_wan(version) || sd_version_is_flux2(version) || + sd_version_is_longcat(version) || version == VERSION_CHROMA_RADIANCE) { latent = vae_output; } else if (version == VERSION_SD1_PIX2PIX) { From fc8d85e1335c3497090132e0b2ced987c65af5fa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20du=20Hamel?= Date: Mon, 8 Dec 2025 01:36:17 +0100 Subject: [PATCH 08/12] Fix token length --- conditioner.hpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/conditioner.hpp b/conditioner.hpp index 33857eb3a..481ee78d6 100644 --- a/conditioner.hpp +++ b/conditioner.hpp @@ -1698,7 +1698,7 @@ struct LLMEmbedder : public Conditioner { std::vector> image_embeds; std::pair prompt_attn_range; int prompt_template_encode_start_idx = 34; - int max_length = 0; + int max_length = 0; std::set out_layers; if (llm->enable_vision && conditioner_params.ref_images.size() > 0) { LOG_INFO("QwenImageEditPlusPipeline"); @@ -1810,6 +1810,7 @@ struct LLMEmbedder : public Conditioner { } else if (sd_version_is_longcat(version)) { prompt_template_encode_start_idx = 36; // prompt_template_encode_end_idx = 5; + max_length = 512; prompt = "<|im_start|>system\nAs an image captioning expert, generate a descriptive text prompt based on an image content, suitable for input to a text-to-image model.<|im_end|>\n<|im_start|>user\n"; From 9f225e4e631cf99f500a31f56e0cbede3aa5ec2c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20du=20Hamel?= Date: Mon, 8 Dec 2025 02:28:29 +0100 Subject: [PATCH 09/12] Split quoted text into character-level tokens remove debug logs --- conditioner.hpp | 105 +++++++++++++++++++++++++++++++++++------------- 1 file changed, 76 insertions(+), 29 deletions(-) diff --git a/conditioner.hpp b/conditioner.hpp index 481ee78d6..7faa84b58 100644 --- a/conditioner.hpp +++ b/conditioner.hpp @@ -1648,46 +1648,91 @@ struct LLMEmbedder : public Conditioner { } } - std::tuple, std::vector> tokenize(std::string text, - std::pair attn_range, - size_t max_length = 0, - bool padding = false) { + std::tuple, std::vector> tokenize( + std::string text, + std::pair attn_range, + size_t max_length = 0, + bool padding = false, + bool spell_quotes = false) { std::vector> parsed_attention; parsed_attention.emplace_back(text.substr(0, attn_range.first), 1.f); + if (attn_range.second - attn_range.first > 0) { - auto new_parsed_attention = parse_prompt_attention(text.substr(attn_range.first, attn_range.second - attn_range.first)); - parsed_attention.insert(parsed_attention.end(), - new_parsed_attention.begin(), - new_parsed_attention.end()); + auto new_parsed_attention = parse_prompt_attention( + text.substr(attn_range.first, attn_range.second - attn_range.first)); + parsed_attention.insert( + parsed_attention.end(), + new_parsed_attention.begin(), + new_parsed_attention.end()); } parsed_attention.emplace_back(text.substr(attn_range.second), 1.f); - { - std::stringstream ss; - ss << "["; - for (const auto& item : parsed_attention) { - ss << "['" << item.first << "', " << item.second << "], "; - } - ss << "]"; - LOG_DEBUG("parse '%s' to %s", text.c_str(), ss.str().c_str()); - } + + // { + // std::stringstream ss; + // ss << '['; + // for (const auto& item : parsed_attention) { + // ss << "['" << item.first << "', " << item.second << "], "; + // } + // ss << ']'; + // LOG_DEBUG("parse '%s' to %s", text.c_str(), ss.str().c_str()); + // } std::vector tokens; std::vector weights; + for (const auto& item : parsed_attention) { const std::string& curr_text = item.first; float curr_weight = item.second; - std::vector curr_tokens = tokenizer->tokenize(curr_text, nullptr); - tokens.insert(tokens.end(), curr_tokens.begin(), curr_tokens.end()); - weights.insert(weights.end(), curr_tokens.size(), curr_weight); - } - tokenizer->pad_tokens(tokens, weights, max_length, padding); + if (spell_quotes) { + std::vector parts; + bool in_quote = false; + std::string current_part; - // for (int i = 0; i < tokens.size(); i++) { - // std::cout << tokens[i] << ":" << weights[i] << ", " << i << std::endl; - // } - // std::cout << std::endl; + for (char c : curr_text) { + if (c == '\'') { + if (!current_part.empty()) { + parts.push_back(current_part); + current_part.clear(); + } + in_quote = !in_quote; + } else { + current_part += c; + if (in_quote && current_part.size() == 1) { + parts.push_back(current_part); + current_part.clear(); + } + } + } + if (!current_part.empty()) { + parts.push_back(current_part); + } + for (const auto& part : parts) { + if (part.empty()) + continue; + if (part[0] == '\'' && part.back() == '\'') { + std::string quoted_content = part.substr(1, part.size() - 2); + for (char ch : quoted_content) { + std::string char_str(1, ch); + std::vector char_tokens = tokenizer->tokenize(char_str, nullptr); + tokens.insert(tokens.end(), char_tokens.begin(), char_tokens.end()); + weights.insert(weights.end(), char_tokens.size(), curr_weight); + } + } else { + std::vector part_tokens = tokenizer->tokenize(part, nullptr); + tokens.insert(tokens.end(), part_tokens.begin(), part_tokens.end()); + weights.insert(weights.end(), part_tokens.size(), curr_weight); + } + } + } else { + std::vector curr_tokens = tokenizer->tokenize(curr_text, nullptr); + tokens.insert(tokens.end(), curr_tokens.begin(), curr_tokens.end()); + weights.insert(weights.end(), curr_tokens.size(), curr_weight); + } + } + + tokenizer->pad_tokens(tokens, weights, max_length, padding); return {tokens, weights}; } @@ -1698,7 +1743,8 @@ struct LLMEmbedder : public Conditioner { std::vector> image_embeds; std::pair prompt_attn_range; int prompt_template_encode_start_idx = 34; - int max_length = 0; + int max_length = 0; + bool spell_quotes = false; std::set out_layers; if (llm->enable_vision && conditioner_params.ref_images.size() > 0) { LOG_INFO("QwenImageEditPlusPipeline"); @@ -1810,7 +1856,8 @@ struct LLMEmbedder : public Conditioner { } else if (sd_version_is_longcat(version)) { prompt_template_encode_start_idx = 36; // prompt_template_encode_end_idx = 5; - max_length = 512; + max_length = 512; + spell_quotes = true; prompt = "<|im_start|>system\nAs an image captioning expert, generate a descriptive text prompt based on an image content, suitable for input to a text-to-image model.<|im_end|>\n<|im_start|>user\n"; @@ -1831,7 +1878,7 @@ struct LLMEmbedder : public Conditioner { prompt += "<|im_end|>\n<|im_start|>assistant\n"; } - auto tokens_and_weights = tokenize(prompt, prompt_attn_range, max_length, max_length > 0); + auto tokens_and_weights = tokenize(prompt, prompt_attn_range, max_length, max_length > 0, spell_quotes); auto& tokens = std::get<0>(tokens_and_weights); auto& weights = std::get<1>(tokens_and_weights); From c044a406c5411869e3545aba6c3120f4127f1d67 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20du=20Hamel?= Date: Mon, 8 Dec 2025 13:43:37 +0100 Subject: [PATCH 10/12] support longcat-image-edit Fix base rope offset for ref images --- conditioner.hpp | 179 +++++++++++++++++++++++++++++++++--------------- rope.hpp | 13 ++-- 2 files changed, 131 insertions(+), 61 deletions(-) diff --git a/conditioner.hpp b/conditioner.hpp index 7faa84b58..9b9cca9e1 100644 --- a/conditioner.hpp +++ b/conditioner.hpp @@ -1690,7 +1690,7 @@ struct LLMEmbedder : public Conditioner { std::string current_part; for (char c : curr_text) { - if (c == '\'') { + if (c == '"') { if (!current_part.empty()) { parts.push_back(current_part); current_part.clear(); @@ -1711,7 +1711,7 @@ struct LLMEmbedder : public Conditioner { for (const auto& part : parts) { if (part.empty()) continue; - if (part[0] == '\'' && part.back() == '\'') { + if (part[0] == '"' && part.back() == '"') { std::string quoted_content = part.substr(1, part.size() - 2); for (char ch : quoted_content) { std::string char_str(1, ch); @@ -1747,68 +1747,139 @@ struct LLMEmbedder : public Conditioner { bool spell_quotes = false; std::set out_layers; if (llm->enable_vision && conditioner_params.ref_images.size() > 0) { - LOG_INFO("QwenImageEditPlusPipeline"); - prompt_template_encode_start_idx = 64; - int image_embed_idx = 64 + 6; - - int min_pixels = 384 * 384; - int max_pixels = 560 * 560; - std::string placeholder = "<|image_pad|>"; - std::string img_prompt; - - for (int i = 0; i < conditioner_params.ref_images.size(); i++) { - sd_image_f32_t image = sd_image_t_to_sd_image_f32_t(*conditioner_params.ref_images[i]); - double factor = llm->params.vision.patch_size * llm->params.vision.spatial_merge_size; - int height = image.height; - int width = image.width; - int h_bar = static_cast(std::round(height / factor)) * factor; - int w_bar = static_cast(std::round(width / factor)) * factor; - - if (static_cast(h_bar) * w_bar > max_pixels) { - double beta = std::sqrt((height * width) / static_cast(max_pixels)); - h_bar = std::max(static_cast(factor), - static_cast(std::floor(height / beta / factor)) * static_cast(factor)); - w_bar = std::max(static_cast(factor), - static_cast(std::floor(width / beta / factor)) * static_cast(factor)); - } else if (static_cast(h_bar) * w_bar < min_pixels) { - double beta = std::sqrt(static_cast(min_pixels) / (height * width)); - h_bar = static_cast(std::ceil(height * beta / factor)) * static_cast(factor); - w_bar = static_cast(std::ceil(width * beta / factor)) * static_cast(factor); + if (sd_version_is_longcat(version)) { + LOG_INFO("LongCatEditPipeline"); + prompt_template_encode_start_idx = 67; + // prompt_template_encode_end_idx = 5; + int image_embed_idx = 36 + 6; + + int min_pixels = 384 * 384; + int max_pixels = 560 * 560; + std::string placeholder = "<|image_pad|>"; + std::string img_prompt; + + + // Only one image is officicially supported by the model, not sure how it handles multiple images + for (int i = 0; i < conditioner_params.ref_images.size(); i++) { + sd_image_f32_t image = sd_image_t_to_sd_image_f32_t(*conditioner_params.ref_images[i]); + double factor = llm->params.vision.patch_size * llm->params.vision.spatial_merge_size; + int height = image.height; + int width = image.width; + int h_bar = static_cast(std::round(height / factor)) * factor; + int w_bar = static_cast(std::round(width / factor)) * factor; + + if (static_cast(h_bar) * w_bar > max_pixels) { + double beta = std::sqrt((height * width) / static_cast(max_pixels)); + h_bar = std::max(static_cast(factor), + static_cast(std::floor(height / beta / factor)) * static_cast(factor)); + w_bar = std::max(static_cast(factor), + static_cast(std::floor(width / beta / factor)) * static_cast(factor)); + } else if (static_cast(h_bar) * w_bar < min_pixels) { + double beta = std::sqrt(static_cast(min_pixels) / (height * width)); + h_bar = static_cast(std::ceil(height * beta / factor)) * static_cast(factor); + w_bar = static_cast(std::ceil(width * beta / factor)) * static_cast(factor); + } + + LOG_DEBUG("resize conditioner ref image %d from %dx%d to %dx%d", i, image.height, image.width, h_bar, w_bar); + + sd_image_f32_t resized_image = clip_preprocess(image, w_bar, h_bar); + free(image.data); + image.data = nullptr; + + ggml_tensor* image_tensor = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, resized_image.width, resized_image.height, 3, 1); + sd_image_f32_to_ggml_tensor(resized_image, image_tensor, false); + free(resized_image.data); + resized_image.data = nullptr; + + ggml_tensor* image_embed = nullptr; + llm->encode_image(n_threads, image_tensor, &image_embed, work_ctx); + image_embeds.emplace_back(image_embed_idx, image_embed); + image_embed_idx += 1 + image_embed->ne[1] + 6; + + img_prompt += "<|vision_start|>"; + int64_t num_image_tokens = image_embed->ne[1]; + img_prompt.reserve(num_image_tokens * placeholder.size()); + for (int j = 0; j < num_image_tokens; j++) { + img_prompt += placeholder; + } + img_prompt += "<|vision_end|>"; } - LOG_DEBUG("resize conditioner ref image %d from %dx%d to %dx%d", i, image.height, image.width, h_bar, w_bar); + max_length = 512; + spell_quotes = true; + prompt = "<|im_start|>system\nAs an image editing expert, first analyze the content and attributes of the input image(s). Then, based on the user's editing instructions, clearly and precisely determine how to modify the given image(s), ensuring that only the specified parts are altered and all other aspects remain consistent with the original(s).<|im_end|>\n<|im_start|>user\n"; + prompt += img_prompt; + + prompt_attn_range.first = static_cast(prompt.size()); + prompt += conditioner_params.text; + prompt_attn_range.second = static_cast(prompt.size()); + + prompt += "<|im_end|>\n<|im_start|>assistant\n"; + + } else { + LOG_INFO("QwenImageEditPlusPipeline"); + prompt_template_encode_start_idx = 64; + int image_embed_idx = 64 + 6; + + int min_pixels = 384 * 384; + int max_pixels = 560 * 560; + std::string placeholder = "<|image_pad|>"; + std::string img_prompt; + + for (int i = 0; i < conditioner_params.ref_images.size(); i++) { + sd_image_f32_t image = sd_image_t_to_sd_image_f32_t(*conditioner_params.ref_images[i]); + double factor = llm->params.vision.patch_size * llm->params.vision.spatial_merge_size; + int height = image.height; + int width = image.width; + int h_bar = static_cast(std::round(height / factor)) * factor; + int w_bar = static_cast(std::round(width / factor)) * factor; + + if (static_cast(h_bar) * w_bar > max_pixels) { + double beta = std::sqrt((height * width) / static_cast(max_pixels)); + h_bar = std::max(static_cast(factor), + static_cast(std::floor(height / beta / factor)) * static_cast(factor)); + w_bar = std::max(static_cast(factor), + static_cast(std::floor(width / beta / factor)) * static_cast(factor)); + } else if (static_cast(h_bar) * w_bar < min_pixels) { + double beta = std::sqrt(static_cast(min_pixels) / (height * width)); + h_bar = static_cast(std::ceil(height * beta / factor)) * static_cast(factor); + w_bar = static_cast(std::ceil(width * beta / factor)) * static_cast(factor); + } + + LOG_DEBUG("resize conditioner ref image %d from %dx%d to %dx%d", i, image.height, image.width, h_bar, w_bar); - sd_image_f32_t resized_image = clip_preprocess(image, w_bar, h_bar); - free(image.data); - image.data = nullptr; + sd_image_f32_t resized_image = clip_preprocess(image, w_bar, h_bar); + free(image.data); + image.data = nullptr; - ggml_tensor* image_tensor = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, resized_image.width, resized_image.height, 3, 1); - sd_image_f32_to_ggml_tensor(resized_image, image_tensor, false); - free(resized_image.data); - resized_image.data = nullptr; + ggml_tensor* image_tensor = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, resized_image.width, resized_image.height, 3, 1); + sd_image_f32_to_ggml_tensor(resized_image, image_tensor, false); + free(resized_image.data); + resized_image.data = nullptr; - ggml_tensor* image_embed = nullptr; - llm->encode_image(n_threads, image_tensor, &image_embed, work_ctx); - image_embeds.emplace_back(image_embed_idx, image_embed); - image_embed_idx += 1 + image_embed->ne[1] + 6; + ggml_tensor* image_embed = nullptr; + llm->encode_image(n_threads, image_tensor, &image_embed, work_ctx); + image_embeds.emplace_back(image_embed_idx, image_embed); + image_embed_idx += 1 + image_embed->ne[1] + 6; - img_prompt += "Picture " + std::to_string(i + 1) + ": <|vision_start|>"; // [24669, 220, index, 25, 220, 151652] - int64_t num_image_tokens = image_embed->ne[1]; - img_prompt.reserve(num_image_tokens * placeholder.size()); - for (int j = 0; j < num_image_tokens; j++) { - img_prompt += placeholder; + img_prompt += "Picture " + std::to_string(i + 1) + ": <|vision_start|>"; // [24669, 220, index, 25, 220, 151652] + int64_t num_image_tokens = image_embed->ne[1]; + img_prompt.reserve(num_image_tokens * placeholder.size()); + for (int j = 0; j < num_image_tokens; j++) { + img_prompt += placeholder; + } + img_prompt += "<|vision_end|>"; } - img_prompt += "<|vision_end|>"; - } - prompt = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n"; - prompt += img_prompt; + prompt = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n"; + prompt += img_prompt; - prompt_attn_range.first = static_cast(prompt.size()); - prompt += conditioner_params.text; - prompt_attn_range.second = static_cast(prompt.size()); + prompt_attn_range.first = static_cast(prompt.size()); + prompt += conditioner_params.text; + prompt_attn_range.second = static_cast(prompt.size()); - prompt += "<|im_end|>\n<|im_start|>assistant\n"; + prompt += "<|im_end|>\n<|im_start|>assistant\n"; + } } else if (sd_version_is_flux2(version)) { prompt_template_encode_start_idx = 0; out_layers = {10, 20, 30}; diff --git a/rope.hpp b/rope.hpp index 0c18c0a02..7666b02ee 100644 --- a/rope.hpp +++ b/rope.hpp @@ -93,7 +93,7 @@ namespace Rope { return txt_ids; } - __STATIC_INLINE__ std::vector> gen_flux_img_ids(int h, + __STATIC_INLINE__ std::vector> gen_flux_img_ids(int h, int w, int patch_size, int bs, @@ -107,7 +107,6 @@ namespace Rope { std::vector row_ids = linspace(h_offset, h_len - 1 + h_offset, h_len); std::vector col_ids = linspace(w_offset, w_len - 1 + w_offset, w_len); - for (int i = 0; i < h_len; ++i) { for (int j = 0; j < w_len; ++j) { img_ids[i * w_len + j][0] = index; @@ -181,10 +180,10 @@ namespace Rope { const std::vector& ref_latents, bool increase_ref_index, float ref_index_scale, - int base_offset = 0) { + int base_offset = 0) { std::vector> ids; - uint64_t curr_h_offset = base_offset; - uint64_t curr_w_offset = base_offset; + uint64_t curr_h_offset = 0; + uint64_t curr_w_offset = 0; int index = start_index; for (ggml_tensor* ref : ref_latents) { uint64_t h_offset = 0; @@ -203,8 +202,8 @@ namespace Rope { bs, axes_dim_num, static_cast(index * ref_index_scale), - h_offset, - w_offset); + h_offset + base_offset, + w_offset + base_offset); ids = concat_ids(ids, ref_ids, bs); if (increase_ref_index) { From fd032bcfa7eda95033e053c485f894d1caaa7eb4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20du=20Hamel?= Date: Fri, 12 Dec 2025 02:10:46 +0100 Subject: [PATCH 11/12] Split quotes by utf8 characters rather than individual char --- conditioner.hpp | 87 +++++++++++++++++++++++++------------------------ 1 file changed, 45 insertions(+), 42 deletions(-) diff --git a/conditioner.hpp b/conditioner.hpp index 9b9cca9e1..8c5cdcf0c 100644 --- a/conditioner.hpp +++ b/conditioner.hpp @@ -1648,6 +1648,19 @@ struct LLMEmbedder : public Conditioner { } } + size_t get_utf8_char_len(char c) { + unsigned char uc = static_cast(c); + if ((uc & 0x80) == 0) + return 1; // ASCII (1 byte) + if ((uc & 0xE0) == 0xC0) + return 2; // 2-byte char + if ((uc & 0xF0) == 0xE0) + return 3; // 3-byte char (Common for Chinese/Japanese) + if ((uc & 0xF8) == 0xF0) + return 4; // 4-byte char (Emojis, etc.) + return 1; // Fallback (should not happen in valid UTF-8) + } + std::tuple, std::vector> tokenize( std::string text, std::pair attn_range, @@ -1667,16 +1680,6 @@ struct LLMEmbedder : public Conditioner { } parsed_attention.emplace_back(text.substr(attn_range.second), 1.f); - // { - // std::stringstream ss; - // ss << '['; - // for (const auto& item : parsed_attention) { - // ss << "['" << item.first << "', " << item.second << "], "; - // } - // ss << ']'; - // LOG_DEBUG("parse '%s' to %s", text.c_str(), ss.str().c_str()); - // } - std::vector tokens; std::vector weights; @@ -1685,46 +1688,47 @@ struct LLMEmbedder : public Conditioner { float curr_weight = item.second; if (spell_quotes) { - std::vector parts; + std::string buffer; bool in_quote = false; - std::string current_part; - for (char c : curr_text) { - if (c == '"') { - if (!current_part.empty()) { - parts.push_back(current_part); - current_part.clear(); + size_t i = 0; + while (i < curr_text.size()) { + // utf8 character can be 1-4 char + size_t char_len = get_utf8_char_len(curr_text[i]); + + // Safety check to prevent reading past end of string + if (i + char_len > curr_text.size()) { + char_len = curr_text.size() - i; + } + std::string uchar = curr_text.substr(i, char_len); + i += char_len; + + if (uchar == "\"") { + buffer += uchar; + // If we were accumulating normal text, flush it now + if (!in_quote) { + std::vector part_tokens = tokenizer->tokenize(buffer, nullptr); + tokens.insert(tokens.end(), part_tokens.begin(), part_tokens.end()); + weights.insert(weights.end(), part_tokens.size(), curr_weight); + buffer.clear(); } in_quote = !in_quote; } else { - current_part += c; - if (in_quote && current_part.size() == 1) { - parts.push_back(current_part); - current_part.clear(); - } - } - } - if (!current_part.empty()) { - parts.push_back(current_part); - } - - for (const auto& part : parts) { - if (part.empty()) - continue; - if (part[0] == '"' && part.back() == '"') { - std::string quoted_content = part.substr(1, part.size() - 2); - for (char ch : quoted_content) { - std::string char_str(1, ch); - std::vector char_tokens = tokenizer->tokenize(char_str, nullptr); + if (in_quote) { + std::vector char_tokens = tokenizer->tokenize(uchar, nullptr); tokens.insert(tokens.end(), char_tokens.begin(), char_tokens.end()); weights.insert(weights.end(), char_tokens.size(), curr_weight); + } else { + buffer += uchar; } - } else { - std::vector part_tokens = tokenizer->tokenize(part, nullptr); - tokens.insert(tokens.end(), part_tokens.begin(), part_tokens.end()); - weights.insert(weights.end(), part_tokens.size(), curr_weight); } } + + if (!buffer.empty()) { + std::vector part_tokens = tokenizer->tokenize(buffer, nullptr); + tokens.insert(tokens.end(), part_tokens.begin(), part_tokens.end()); + weights.insert(weights.end(), part_tokens.size(), curr_weight); + } } else { std::vector curr_tokens = tokenizer->tokenize(curr_text, nullptr); tokens.insert(tokens.end(), curr_tokens.begin(), curr_tokens.end()); @@ -1751,14 +1755,13 @@ struct LLMEmbedder : public Conditioner { LOG_INFO("LongCatEditPipeline"); prompt_template_encode_start_idx = 67; // prompt_template_encode_end_idx = 5; - int image_embed_idx = 36 + 6; + int image_embed_idx = 36 + 6; int min_pixels = 384 * 384; int max_pixels = 560 * 560; std::string placeholder = "<|image_pad|>"; std::string img_prompt; - // Only one image is officicially supported by the model, not sure how it handles multiple images for (int i = 0; i < conditioner_params.ref_images.size(); i++) { sd_image_f32_t image = sd_image_t_to_sd_image_f32_t(*conditioner_params.ref_images[i]); From 196bb895fd0af69ff424f86252ce73f8b02bef7c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20du=20Hamel?= Date: Fri, 12 Dec 2025 02:39:09 +0100 Subject: [PATCH 12/12] patch size consistent with Flux1 --- flux.hpp | 1 - stable-diffusion.cpp | 11 +---------- vae.hpp | 4 ++-- 3 files changed, 3 insertions(+), 13 deletions(-) diff --git a/flux.hpp b/flux.hpp index fc3098780..df3c4c8de 100644 --- a/flux.hpp +++ b/flux.hpp @@ -1309,7 +1309,6 @@ namespace Flux { } else if (sd_version_is_longcat(version)) { flux_params.context_in_dim = 3584; flux_params.vec_in_dim = 0; - flux_params.patch_size = 1; } for (auto pair : tensor_storage_map) { std::string tensor_name = pair.first; diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 1e8f04aa1..f89e4268a 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -1346,12 +1346,6 @@ class StableDiffusionGGML { latent_rgb_bias = flux2_latent_rgb_bias; patch_sz = 2; } - } else if (dim == 64) { - if (sd_version_is_flux(version) || sd_version_is_z_image(version) || sd_version_is_longcat(version)) { - latent_rgb_proj = flux_latent_rgb_proj; - latent_rgb_bias = flux_latent_rgb_bias; - patch_sz = 2; - } } else if (dim == 48) { if (sd_version_is_wan(version)) { latent_rgb_proj = wan_22_latent_rgb_proj; @@ -1916,7 +1910,7 @@ class StableDiffusionGGML { int vae_scale_factor = 8; if (version == VERSION_WAN2_2_TI2V) { vae_scale_factor = 16; - } else if (sd_version_is_flux2(version) || sd_version_is_longcat(version)) { + } else if (sd_version_is_flux2(version)) { vae_scale_factor = 16; } else if (version == VERSION_CHROMA_RADIANCE) { vae_scale_factor = 1; @@ -1945,8 +1939,6 @@ class StableDiffusionGGML { latent_channel = 3; } else if (sd_version_is_flux2(version)) { latent_channel = 128; - } else if (sd_version_is_longcat(version)) { - latent_channel = 64; } else { latent_channel = 16; } @@ -2247,7 +2239,6 @@ class StableDiffusionGGML { sd_version_is_qwen_image(version) || sd_version_is_wan(version) || sd_version_is_flux2(version) || - sd_version_is_longcat(version) || version == VERSION_CHROMA_RADIANCE) { latent = vae_output; } else if (version == VERSION_SD1_PIX2PIX) { diff --git a/vae.hpp b/vae.hpp index 740a5655b..ad5db1b57 100644 --- a/vae.hpp +++ b/vae.hpp @@ -553,7 +553,7 @@ class AutoencodingEngine : public GGMLBlock { struct ggml_tensor* decode(GGMLRunnerContext* ctx, struct ggml_tensor* z) { // z: [N, z_channels, h, w] - if (sd_version_is_flux2(version) || sd_version_is_longcat(version)) { + if (sd_version_is_flux2(version)) { // [N, C*p*p, h, w] -> [N, C, h*p, w*p] int64_t p = 2; @@ -592,7 +592,7 @@ class AutoencodingEngine : public GGMLBlock { auto quant_conv = std::dynamic_pointer_cast(blocks["quant_conv"]); z = quant_conv->forward(ctx, z); // [N, 2*embed_dim, h/8, w/8] } - if (sd_version_is_flux2(version) || sd_version_is_longcat(version)) { + if (sd_version_is_flux2(version)) { z = ggml_ext_chunk(ctx->ggml_ctx, z, 2, 2)[0]; // [N, C, H, W] -> [N, C*p*p, H/p, W/p]