diff --git a/conditioner.hpp b/conditioner.hpp index 55e1502e8..8c5cdcf0c 100644 --- a/conditioner.hpp +++ b/conditioner.hpp @@ -1648,46 +1648,95 @@ struct LLMEmbedder : public Conditioner { } } - std::tuple, std::vector> tokenize(std::string text, - std::pair attn_range, - size_t max_length = 0, - bool padding = false) { + size_t get_utf8_char_len(char c) { + unsigned char uc = static_cast(c); + if ((uc & 0x80) == 0) + return 1; // ASCII (1 byte) + if ((uc & 0xE0) == 0xC0) + return 2; // 2-byte char + if ((uc & 0xF0) == 0xE0) + return 3; // 3-byte char (Common for Chinese/Japanese) + if ((uc & 0xF8) == 0xF0) + return 4; // 4-byte char (Emojis, etc.) + return 1; // Fallback (should not happen in valid UTF-8) + } + + std::tuple, std::vector> tokenize( + std::string text, + std::pair attn_range, + size_t max_length = 0, + bool padding = false, + bool spell_quotes = false) { std::vector> parsed_attention; parsed_attention.emplace_back(text.substr(0, attn_range.first), 1.f); + if (attn_range.second - attn_range.first > 0) { - auto new_parsed_attention = parse_prompt_attention(text.substr(attn_range.first, attn_range.second - attn_range.first)); - parsed_attention.insert(parsed_attention.end(), - new_parsed_attention.begin(), - new_parsed_attention.end()); + auto new_parsed_attention = parse_prompt_attention( + text.substr(attn_range.first, attn_range.second - attn_range.first)); + parsed_attention.insert( + parsed_attention.end(), + new_parsed_attention.begin(), + new_parsed_attention.end()); } parsed_attention.emplace_back(text.substr(attn_range.second), 1.f); - { - std::stringstream ss; - ss << "["; - for (const auto& item : parsed_attention) { - ss << "['" << item.first << "', " << item.second << "], "; - } - ss << "]"; - LOG_DEBUG("parse '%s' to %s", text.c_str(), ss.str().c_str()); - } std::vector tokens; std::vector weights; + for (const auto& item : parsed_attention) { const std::string& curr_text = item.first; float curr_weight = item.second; - std::vector curr_tokens = tokenizer->tokenize(curr_text, nullptr); - tokens.insert(tokens.end(), curr_tokens.begin(), curr_tokens.end()); - weights.insert(weights.end(), curr_tokens.size(), curr_weight); - } - tokenizer->pad_tokens(tokens, weights, max_length, padding); + if (spell_quotes) { + std::string buffer; + bool in_quote = false; - // for (int i = 0; i < tokens.size(); i++) { - // std::cout << tokens[i] << ":" << weights[i] << ", " << i << std::endl; - // } - // std::cout << std::endl; + size_t i = 0; + while (i < curr_text.size()) { + // utf8 character can be 1-4 char + size_t char_len = get_utf8_char_len(curr_text[i]); + + // Safety check to prevent reading past end of string + if (i + char_len > curr_text.size()) { + char_len = curr_text.size() - i; + } + std::string uchar = curr_text.substr(i, char_len); + i += char_len; + + if (uchar == "\"") { + buffer += uchar; + // If we were accumulating normal text, flush it now + if (!in_quote) { + std::vector part_tokens = tokenizer->tokenize(buffer, nullptr); + tokens.insert(tokens.end(), part_tokens.begin(), part_tokens.end()); + weights.insert(weights.end(), part_tokens.size(), curr_weight); + buffer.clear(); + } + in_quote = !in_quote; + } else { + if (in_quote) { + std::vector char_tokens = tokenizer->tokenize(uchar, nullptr); + tokens.insert(tokens.end(), char_tokens.begin(), char_tokens.end()); + weights.insert(weights.end(), char_tokens.size(), curr_weight); + } else { + buffer += uchar; + } + } + } + + if (!buffer.empty()) { + std::vector part_tokens = tokenizer->tokenize(buffer, nullptr); + tokens.insert(tokens.end(), part_tokens.begin(), part_tokens.end()); + weights.insert(weights.end(), part_tokens.size(), curr_weight); + } + } else { + std::vector curr_tokens = tokenizer->tokenize(curr_text, nullptr); + tokens.insert(tokens.end(), curr_tokens.begin(), curr_tokens.end()); + weights.insert(weights.end(), curr_tokens.size(), curr_weight); + } + } + tokenizer->pad_tokens(tokens, weights, max_length, padding); return {tokens, weights}; } @@ -1699,70 +1748,141 @@ struct LLMEmbedder : public Conditioner { std::pair prompt_attn_range; int prompt_template_encode_start_idx = 34; int max_length = 0; + bool spell_quotes = false; std::set out_layers; if (llm->enable_vision && conditioner_params.ref_images.size() > 0) { - LOG_INFO("QwenImageEditPlusPipeline"); - prompt_template_encode_start_idx = 64; - int image_embed_idx = 64 + 6; - - int min_pixels = 384 * 384; - int max_pixels = 560 * 560; - std::string placeholder = "<|image_pad|>"; - std::string img_prompt; - - for (int i = 0; i < conditioner_params.ref_images.size(); i++) { - sd_image_f32_t image = sd_image_t_to_sd_image_f32_t(*conditioner_params.ref_images[i]); - double factor = llm->params.vision.patch_size * llm->params.vision.spatial_merge_size; - int height = image.height; - int width = image.width; - int h_bar = static_cast(std::round(height / factor)) * factor; - int w_bar = static_cast(std::round(width / factor)) * factor; - - if (static_cast(h_bar) * w_bar > max_pixels) { - double beta = std::sqrt((height * width) / static_cast(max_pixels)); - h_bar = std::max(static_cast(factor), - static_cast(std::floor(height / beta / factor)) * static_cast(factor)); - w_bar = std::max(static_cast(factor), - static_cast(std::floor(width / beta / factor)) * static_cast(factor)); - } else if (static_cast(h_bar) * w_bar < min_pixels) { - double beta = std::sqrt(static_cast(min_pixels) / (height * width)); - h_bar = static_cast(std::ceil(height * beta / factor)) * static_cast(factor); - w_bar = static_cast(std::ceil(width * beta / factor)) * static_cast(factor); + if (sd_version_is_longcat(version)) { + LOG_INFO("LongCatEditPipeline"); + prompt_template_encode_start_idx = 67; + // prompt_template_encode_end_idx = 5; + int image_embed_idx = 36 + 6; + + int min_pixels = 384 * 384; + int max_pixels = 560 * 560; + std::string placeholder = "<|image_pad|>"; + std::string img_prompt; + + // Only one image is officicially supported by the model, not sure how it handles multiple images + for (int i = 0; i < conditioner_params.ref_images.size(); i++) { + sd_image_f32_t image = sd_image_t_to_sd_image_f32_t(*conditioner_params.ref_images[i]); + double factor = llm->params.vision.patch_size * llm->params.vision.spatial_merge_size; + int height = image.height; + int width = image.width; + int h_bar = static_cast(std::round(height / factor)) * factor; + int w_bar = static_cast(std::round(width / factor)) * factor; + + if (static_cast(h_bar) * w_bar > max_pixels) { + double beta = std::sqrt((height * width) / static_cast(max_pixels)); + h_bar = std::max(static_cast(factor), + static_cast(std::floor(height / beta / factor)) * static_cast(factor)); + w_bar = std::max(static_cast(factor), + static_cast(std::floor(width / beta / factor)) * static_cast(factor)); + } else if (static_cast(h_bar) * w_bar < min_pixels) { + double beta = std::sqrt(static_cast(min_pixels) / (height * width)); + h_bar = static_cast(std::ceil(height * beta / factor)) * static_cast(factor); + w_bar = static_cast(std::ceil(width * beta / factor)) * static_cast(factor); + } + + LOG_DEBUG("resize conditioner ref image %d from %dx%d to %dx%d", i, image.height, image.width, h_bar, w_bar); + + sd_image_f32_t resized_image = clip_preprocess(image, w_bar, h_bar); + free(image.data); + image.data = nullptr; + + ggml_tensor* image_tensor = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, resized_image.width, resized_image.height, 3, 1); + sd_image_f32_to_ggml_tensor(resized_image, image_tensor, false); + free(resized_image.data); + resized_image.data = nullptr; + + ggml_tensor* image_embed = nullptr; + llm->encode_image(n_threads, image_tensor, &image_embed, work_ctx); + image_embeds.emplace_back(image_embed_idx, image_embed); + image_embed_idx += 1 + image_embed->ne[1] + 6; + + img_prompt += "<|vision_start|>"; + int64_t num_image_tokens = image_embed->ne[1]; + img_prompt.reserve(num_image_tokens * placeholder.size()); + for (int j = 0; j < num_image_tokens; j++) { + img_prompt += placeholder; + } + img_prompt += "<|vision_end|>"; } - LOG_DEBUG("resize conditioner ref image %d from %dx%d to %dx%d", i, image.height, image.width, h_bar, w_bar); + max_length = 512; + spell_quotes = true; + prompt = "<|im_start|>system\nAs an image editing expert, first analyze the content and attributes of the input image(s). Then, based on the user's editing instructions, clearly and precisely determine how to modify the given image(s), ensuring that only the specified parts are altered and all other aspects remain consistent with the original(s).<|im_end|>\n<|im_start|>user\n"; + prompt += img_prompt; + + prompt_attn_range.first = static_cast(prompt.size()); + prompt += conditioner_params.text; + prompt_attn_range.second = static_cast(prompt.size()); + + prompt += "<|im_end|>\n<|im_start|>assistant\n"; + + } else { + LOG_INFO("QwenImageEditPlusPipeline"); + prompt_template_encode_start_idx = 64; + int image_embed_idx = 64 + 6; + + int min_pixels = 384 * 384; + int max_pixels = 560 * 560; + std::string placeholder = "<|image_pad|>"; + std::string img_prompt; + + for (int i = 0; i < conditioner_params.ref_images.size(); i++) { + sd_image_f32_t image = sd_image_t_to_sd_image_f32_t(*conditioner_params.ref_images[i]); + double factor = llm->params.vision.patch_size * llm->params.vision.spatial_merge_size; + int height = image.height; + int width = image.width; + int h_bar = static_cast(std::round(height / factor)) * factor; + int w_bar = static_cast(std::round(width / factor)) * factor; + + if (static_cast(h_bar) * w_bar > max_pixels) { + double beta = std::sqrt((height * width) / static_cast(max_pixels)); + h_bar = std::max(static_cast(factor), + static_cast(std::floor(height / beta / factor)) * static_cast(factor)); + w_bar = std::max(static_cast(factor), + static_cast(std::floor(width / beta / factor)) * static_cast(factor)); + } else if (static_cast(h_bar) * w_bar < min_pixels) { + double beta = std::sqrt(static_cast(min_pixels) / (height * width)); + h_bar = static_cast(std::ceil(height * beta / factor)) * static_cast(factor); + w_bar = static_cast(std::ceil(width * beta / factor)) * static_cast(factor); + } + + LOG_DEBUG("resize conditioner ref image %d from %dx%d to %dx%d", i, image.height, image.width, h_bar, w_bar); - sd_image_f32_t resized_image = clip_preprocess(image, w_bar, h_bar); - free(image.data); - image.data = nullptr; + sd_image_f32_t resized_image = clip_preprocess(image, w_bar, h_bar); + free(image.data); + image.data = nullptr; - ggml_tensor* image_tensor = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, resized_image.width, resized_image.height, 3, 1); - sd_image_f32_to_ggml_tensor(resized_image, image_tensor, false); - free(resized_image.data); - resized_image.data = nullptr; + ggml_tensor* image_tensor = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, resized_image.width, resized_image.height, 3, 1); + sd_image_f32_to_ggml_tensor(resized_image, image_tensor, false); + free(resized_image.data); + resized_image.data = nullptr; - ggml_tensor* image_embed = nullptr; - llm->encode_image(n_threads, image_tensor, &image_embed, work_ctx); - image_embeds.emplace_back(image_embed_idx, image_embed); - image_embed_idx += 1 + image_embed->ne[1] + 6; + ggml_tensor* image_embed = nullptr; + llm->encode_image(n_threads, image_tensor, &image_embed, work_ctx); + image_embeds.emplace_back(image_embed_idx, image_embed); + image_embed_idx += 1 + image_embed->ne[1] + 6; - img_prompt += "Picture " + std::to_string(i + 1) + ": <|vision_start|>"; // [24669, 220, index, 25, 220, 151652] - int64_t num_image_tokens = image_embed->ne[1]; - img_prompt.reserve(num_image_tokens * placeholder.size()); - for (int j = 0; j < num_image_tokens; j++) { - img_prompt += placeholder; + img_prompt += "Picture " + std::to_string(i + 1) + ": <|vision_start|>"; // [24669, 220, index, 25, 220, 151652] + int64_t num_image_tokens = image_embed->ne[1]; + img_prompt.reserve(num_image_tokens * placeholder.size()); + for (int j = 0; j < num_image_tokens; j++) { + img_prompt += placeholder; + } + img_prompt += "<|vision_end|>"; } - img_prompt += "<|vision_end|>"; - } - prompt = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n"; - prompt += img_prompt; + prompt = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n"; + prompt += img_prompt; - prompt_attn_range.first = static_cast(prompt.size()); - prompt += conditioner_params.text; - prompt_attn_range.second = static_cast(prompt.size()); + prompt_attn_range.first = static_cast(prompt.size()); + prompt += conditioner_params.text; + prompt_attn_range.second = static_cast(prompt.size()); - prompt += "<|im_end|>\n<|im_start|>assistant\n"; + prompt += "<|im_end|>\n<|im_start|>assistant\n"; + } } else if (sd_version_is_flux2(version)) { prompt_template_encode_start_idx = 0; out_layers = {10, 20, 30}; @@ -1807,6 +1927,19 @@ struct LLMEmbedder : public Conditioner { prompt_attn_range.second = static_cast(prompt.size()); prompt += "<|im_end|>\n<|im_start|>assistant\n\n\n\n\n"; + } else if (sd_version_is_longcat(version)) { + prompt_template_encode_start_idx = 36; + // prompt_template_encode_end_idx = 5; + max_length = 512; + spell_quotes = true; + + prompt = "<|im_start|>system\nAs an image captioning expert, generate a descriptive text prompt based on an image content, suitable for input to a text-to-image model.<|im_end|>\n<|im_start|>user\n"; + + prompt_attn_range.first = static_cast(prompt.size()); + prompt += conditioner_params.text; + prompt_attn_range.second = static_cast(prompt.size()); + + prompt += "<|im_end|>\n<|im_start|>assistant\n"; } else { prompt_template_encode_start_idx = 34; @@ -1819,7 +1952,7 @@ struct LLMEmbedder : public Conditioner { prompt += "<|im_end|>\n<|im_start|>assistant\n"; } - auto tokens_and_weights = tokenize(prompt, prompt_attn_range, max_length, max_length > 0); + auto tokens_and_weights = tokenize(prompt, prompt_attn_range, max_length, max_length > 0, spell_quotes); auto& tokens = std::get<0>(tokens_and_weights); auto& weights = std::get<1>(tokens_and_weights); diff --git a/flux.hpp b/flux.hpp index 1df2874ae..df3c4c8de 100644 --- a/flux.hpp +++ b/flux.hpp @@ -88,14 +88,19 @@ namespace Flux { public: SelfAttention(int64_t dim, - int64_t num_heads = 8, - bool qkv_bias = false, - bool proj_bias = true) + int64_t num_heads = 8, + bool qkv_bias = false, + bool proj_bias = true, + bool diffusers_style = false) : num_heads(num_heads) { int64_t head_dim = dim / num_heads; - blocks["qkv"] = std::shared_ptr(new Linear(dim, dim * 3, qkv_bias)); - blocks["norm"] = std::shared_ptr(new QKNorm(head_dim)); - blocks["proj"] = std::shared_ptr(new Linear(dim, dim, proj_bias)); + if (diffusers_style) { + blocks["qkv"] = std::shared_ptr(new SplitLinear(dim, {dim, dim, dim}, qkv_bias)); + } else { + blocks["qkv"] = std::shared_ptr(new Linear(dim, dim * 3, qkv_bias)); + } + blocks["norm"] = std::shared_ptr(new QKNorm(head_dim)); + blocks["proj"] = std::shared_ptr(new Linear(dim, dim, proj_bias)); } std::vector pre_attention(GGMLRunnerContext* ctx, struct ggml_tensor* x) { @@ -258,7 +263,8 @@ namespace Flux { bool share_modulation = false, bool mlp_proj_bias = true, bool use_yak_mlp = false, - bool use_mlp_silu_act = false) + bool use_mlp_silu_act = false, + bool diffusers_style = false) : idx(idx), prune_mod(prune_mod) { int64_t mlp_hidden_dim = hidden_size * mlp_ratio; @@ -266,7 +272,7 @@ namespace Flux { blocks["img_mod"] = std::shared_ptr(new Modulation(hidden_size, true)); } blocks["img_norm1"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-6f, false)); - blocks["img_attn"] = std::shared_ptr(new SelfAttention(hidden_size, num_heads, qkv_bias, mlp_proj_bias)); + blocks["img_attn"] = std::shared_ptr(new SelfAttention(hidden_size, num_heads, qkv_bias, mlp_proj_bias, diffusers_style)); blocks["img_norm2"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-6f, false)); if (use_yak_mlp) { @@ -279,7 +285,7 @@ namespace Flux { blocks["txt_mod"] = std::shared_ptr(new Modulation(hidden_size, true)); } blocks["txt_norm1"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-6f, false)); - blocks["txt_attn"] = std::shared_ptr(new SelfAttention(hidden_size, num_heads, qkv_bias, mlp_proj_bias)); + blocks["txt_attn"] = std::shared_ptr(new SelfAttention(hidden_size, num_heads, qkv_bias, mlp_proj_bias, diffusers_style)); blocks["txt_norm2"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-6f, false)); if (use_yak_mlp) { @@ -421,6 +427,7 @@ namespace Flux { bool use_yak_mlp; bool use_mlp_silu_act; int64_t mlp_mult_factor; + bool diffusers_style = false; public: SingleStreamBlock(int64_t hidden_size, @@ -432,7 +439,8 @@ namespace Flux { bool share_modulation = false, bool mlp_proj_bias = true, bool use_yak_mlp = false, - bool use_mlp_silu_act = false) + bool use_mlp_silu_act = false, + bool diffusers_style = false) : hidden_size(hidden_size), num_heads(num_heads), idx(idx), prune_mod(prune_mod), use_yak_mlp(use_yak_mlp), use_mlp_silu_act(use_mlp_silu_act) { int64_t head_dim = hidden_size / num_heads; float scale = qk_scale; @@ -444,8 +452,11 @@ namespace Flux { if (use_yak_mlp || use_mlp_silu_act) { mlp_mult_factor = 2; } - - blocks["linear1"] = std::shared_ptr(new Linear(hidden_size, hidden_size * 3 + mlp_hidden_dim * mlp_mult_factor, mlp_proj_bias)); + if (diffusers_style) { + blocks["linear1"] = std::shared_ptr(new SplitLinear(hidden_size, {hidden_size, hidden_size, hidden_size, mlp_hidden_dim * mlp_mult_factor}, mlp_proj_bias)); + } else { + blocks["linear1"] = std::shared_ptr(new Linear(hidden_size, hidden_size * 3 + mlp_hidden_dim * mlp_mult_factor, mlp_proj_bias)); + } blocks["linear2"] = std::shared_ptr(new Linear(hidden_size + mlp_hidden_dim, hidden_size, mlp_proj_bias)); blocks["norm"] = std::shared_ptr(new QKNorm(head_dim)); blocks["pre_norm"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-6f, false)); @@ -771,6 +782,7 @@ namespace Flux { bool use_yak_mlp = false; bool use_mlp_silu_act = false; float ref_index_scale = 1.f; + bool diffusers_style = false; ChromaRadianceParams chroma_radiance_params; }; @@ -817,7 +829,8 @@ namespace Flux { params.share_modulation, !params.disable_bias, params.use_yak_mlp, - params.use_mlp_silu_act); + params.use_mlp_silu_act, + params.diffusers_style); } for (int i = 0; i < params.depth_single_blocks; i++) { @@ -830,7 +843,8 @@ namespace Flux { params.share_modulation, !params.disable_bias, params.use_yak_mlp, - params.use_mlp_silu_act); + params.use_mlp_silu_act, + params.diffusers_style); } if (params.version == VERSION_CHROMA_RADIANCE) { @@ -877,6 +891,11 @@ namespace Flux { int64_t C = x->ne[2]; int64_t H = x->ne[1]; int64_t W = x->ne[0]; + if (params.patch_size == 1) { + x = ggml_reshape_3d(ctx, x, H * W, C, N); // [N, C, H*W] + x = ggml_cont(ctx, ggml_permute(ctx, x, 1, 0, 2, 3)); // [N, H*W, C] + return x; + } int64_t p = params.patch_size; int64_t h = H / params.patch_size; int64_t w = W / params.patch_size; @@ -911,6 +930,12 @@ namespace Flux { int64_t W = w * params.patch_size; int64_t p = params.patch_size; + if (params.patch_size == 1) { + x = ggml_cont(ctx, ggml_permute(ctx, x, 1, 0, 2, 3)); // [N, C, H*W] + x = ggml_reshape_4d(ctx, x, W, H, C, N); // [N, C, H, W] + return x; + } + GGML_ASSERT(C * p * p == x->ne[0]); x = ggml_reshape_4d(ctx, x, p * p, C, w * h, N); // [N, h*w, C, p*p] @@ -1281,6 +1306,9 @@ namespace Flux { flux_params.share_modulation = true; flux_params.ref_index_scale = 10.f; flux_params.use_mlp_silu_act = true; + } else if (sd_version_is_longcat(version)) { + flux_params.context_in_dim = 3584; + flux_params.vec_in_dim = 0; } for (auto pair : tensor_storage_map) { std::string tensor_name = pair.first; @@ -1290,6 +1318,9 @@ namespace Flux { // not schnell flux_params.guidance_embed = true; } + if (tensor_name.find("model.diffusion_model.single_blocks.0.linear1.weight.1") != std::string::npos) { + flux_params.diffusers_style = true; + } if (tensor_name.find("distilled_guidance_layer.in_proj.weight") != std::string::npos) { // Chroma flux_params.is_chroma = true; @@ -1319,6 +1350,10 @@ namespace Flux { LOG_INFO("Flux guidance is disabled (Schnell mode)"); } + if (flux_params.diffusers_style) { + LOG_INFO("Using diffusers-style attention blocks"); + } + flux = Flux(flux_params); flux.init(params_ctx, tensor_storage_map, prefix); } @@ -1430,7 +1465,6 @@ namespace Flux { } else if (version == VERSION_OVIS_IMAGE) { txt_arange_dims = {1, 2}; } - pe_vec = Rope::gen_flux_pe(x->ne[1], x->ne[0], flux_params.patch_size, @@ -1441,10 +1475,10 @@ namespace Flux { increase_ref_index, flux_params.ref_index_scale, flux_params.theta, - flux_params.axes_dim); + flux_params.axes_dim, + sd_version_is_longcat(version)); int pos_len = pe_vec.size() / flux_params.axes_dim_sum / 2; - // LOG_DEBUG("pos_len %d", pos_len); - auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, flux_params.axes_dim_sum / 2, pos_len); + auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, flux_params.axes_dim_sum / 2, pos_len); // pe->data = pe_vec.data(); // print_ggml_tensor(pe); // pe->data = nullptr; diff --git a/ggml_extend.hpp b/ggml_extend.hpp index 5024eb911..1a630ac39 100644 --- a/ggml_extend.hpp +++ b/ggml_extend.hpp @@ -2173,6 +2173,83 @@ class Linear : public UnaryBlock { } }; +class SplitLinear : public Linear { +protected: + int64_t in_features; + std::vector out_features_vec; + bool bias; + bool force_f32; + bool force_prec_f32; + float scale; + std::string prefix; + + void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override { + this->prefix = prefix; + enum ggml_type wtype = get_type(prefix + "weight", tensor_storage_map, GGML_TYPE_F32); + if (in_features % ggml_blck_size(wtype) != 0 || force_f32) { + wtype = GGML_TYPE_F32; + } + params["weight"] = ggml_new_tensor_2d(ctx, wtype, in_features, out_features_vec[0]); + for (int i = 1; i < out_features_vec.size(); i++) { + // most likely same type as the first weight + params["weight." + std::to_string(i)] = ggml_new_tensor_2d(ctx, wtype, in_features, out_features_vec[i]); + } + if (bias) { + enum ggml_type wtype = GGML_TYPE_F32; + params["bias"] = ggml_new_tensor_1d(ctx, wtype, out_features_vec[0]); + for (int i = 1; i < out_features_vec.size(); i++) { + params["bias." + std::to_string(i)] = ggml_new_tensor_1d(ctx, wtype, out_features_vec[i]); + } + } + } + +public: + SplitLinear(int64_t in_features, + std::vector out_features_vec, + bool bias = true, + bool force_f32 = false, + bool force_prec_f32 = false, + float scale = 1.f) + : Linear(in_features, out_features_vec[0], bias, force_f32, force_prec_f32, scale), + in_features(in_features), + out_features_vec(out_features_vec), + bias(bias), + force_f32(force_f32), + force_prec_f32(force_prec_f32), + scale(scale) {} + + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) { + struct ggml_tensor* w = params["weight"]; + struct ggml_tensor* b = nullptr; + if (bias) { + b = params["bias"]; + } + if (ctx->weight_adapter) { + // concat all weights and biases together so it runs in one linear layer + for (int i = 1; i < out_features_vec.size(); i++) { + w = ggml_concat(ctx->ggml_ctx, w, params["weight." + std::to_string(i)], 1); + if (bias) { + b = ggml_concat(ctx->ggml_ctx, b, params["bias." + std::to_string(i)], 0); + } + } + WeightAdapter::ForwardParams forward_params; + forward_params.op_type = WeightAdapter::ForwardParams::op_type_t::OP_LINEAR; + forward_params.linear.force_prec_f32 = force_prec_f32; + forward_params.linear.scale = scale; + return ctx->weight_adapter->forward_with_lora(ctx->ggml_ctx, x, w, b, prefix, forward_params); + } + auto out = ggml_ext_linear(ctx->ggml_ctx, x, w, b, force_prec_f32, scale); + for (int i = 1; i < out_features_vec.size(); i++) { + auto wi = params["weight." + std::to_string(i)]; + auto bi = bias ? params["bias." + std::to_string(i)] : nullptr; + auto curr_out = ggml_ext_linear(ctx->ggml_ctx, x, wi, bi, force_prec_f32, scale); + out = ggml_concat(ctx->ggml_ctx, out, curr_out, 0); + } + + return out; + } +}; + __STATIC_INLINE__ bool support_get_rows(ggml_type wtype) { std::set allow_types = {GGML_TYPE_F16, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0}; if (allow_types.find(wtype) != allow_types.end()) { diff --git a/model.cpp b/model.cpp index 0480efefb..135a2108c 100644 --- a/model.cpp +++ b/model.cpp @@ -1027,7 +1027,7 @@ bool ModelLoader::init_from_ckpt_file(const std::string& file_path, const std::s } SDVersion ModelLoader::get_sd_version() { - TensorStorage token_embedding_weight, input_block_weight; + TensorStorage token_embedding_weight, input_block_weight, context_ebedding_weight; bool has_multiple_encoders = false; bool is_unet = false; @@ -1041,7 +1041,7 @@ SDVersion ModelLoader::get_sd_version() { for (auto& [name, tensor_storage] : tensor_storage_map) { if (!(is_xl)) { - if (tensor_storage.name.find("model.diffusion_model.double_blocks.") != std::string::npos) { + if (tensor_storage.name.find("model.diffusion_model.double_blocks.") != std::string::npos || tensor_storage.name.find("model.diffusion_model.single_transformer_blocks.") != std::string::npos) { is_flux = true; } if (tensor_storage.name.find("model.diffusion_model.nerf_final_layer_conv.") != std::string::npos) { @@ -1108,6 +1108,9 @@ SDVersion ModelLoader::get_sd_version() { tensor_storage.name == "unet.conv_in.weight") { input_block_weight = tensor_storage; } + if (tensor_storage.name == "model.diffusion_model.txt_in.weight" || tensor_storage.name == "model.diffusion_model.context_embedder.weight") { + context_ebedding_weight = tensor_storage; + } } if (is_wan) { LOG_DEBUG("patch_embedding_channels %d", patch_embedding_channels); @@ -1135,16 +1138,20 @@ SDVersion ModelLoader::get_sd_version() { } if (is_flux) { - if (input_block_weight.ne[0] == 384) { - return VERSION_FLUX_FILL; - } - if (input_block_weight.ne[0] == 128) { - return VERSION_FLUX_CONTROLS; - } - if (input_block_weight.ne[0] == 196) { - return VERSION_FLEX_2; + if (context_ebedding_weight.ne[0] == 3584) { + return VERSION_LONGCAT; + } else { + if (input_block_weight.ne[0] == 384) { + return VERSION_FLUX_FILL; + } + if (input_block_weight.ne[0] == 128) { + return VERSION_FLUX_CONTROLS; + } + if (input_block_weight.ne[0] == 196) { + return VERSION_FLEX_2; + } + return VERSION_FLUX; } - return VERSION_FLUX; } if (token_embedding_weight.ne[0] == 768) { diff --git a/model.h b/model.h index d38aee1c1..27af3d9c7 100644 --- a/model.h +++ b/model.h @@ -46,6 +46,7 @@ enum SDVersion { VERSION_FLUX2, VERSION_Z_IMAGE, VERSION_OVIS_IMAGE, + VERSION_LONGCAT, VERSION_COUNT, }; @@ -126,6 +127,13 @@ static inline bool sd_version_is_z_image(SDVersion version) { return false; } +static inline bool sd_version_is_longcat(SDVersion version) { + if (version == VERSION_LONGCAT) { + return true; + } + return false; +} + static inline bool sd_version_is_inpaint(SDVersion version) { if (version == VERSION_SD1_INPAINT || version == VERSION_SD2_INPAINT || @@ -143,7 +151,8 @@ static inline bool sd_version_is_dit(SDVersion version) { sd_version_is_sd3(version) || sd_version_is_wan(version) || sd_version_is_qwen_image(version) || - sd_version_is_z_image(version)) { + sd_version_is_z_image(version) || + sd_version_is_longcat(version)) { return true; } return false; diff --git a/name_conversion.cpp b/name_conversion.cpp index 8b521486d..1a37dd25c 100644 --- a/name_conversion.cpp +++ b/name_conversion.cpp @@ -508,6 +508,12 @@ std::string convert_diffusers_dit_to_original_flux(std::string name) { static std::unordered_map flux_name_map; if (flux_name_map.empty()) { + // --- time_embed (longcat) --- + flux_name_map["time_embed.timestep_embedder.linear_1.weight"] = "time_in.in_layer.weight"; + flux_name_map["time_embed.timestep_embedder.linear_1.bias"] = "time_in.in_layer.bias"; + flux_name_map["time_embed.timestep_embedder.linear_2.weight"] = "time_in.out_layer.weight"; + flux_name_map["time_embed.timestep_embedder.linear_2.bias"] = "time_in.out_layer.bias"; + // --- time_text_embed --- flux_name_map["time_text_embed.timestep_embedder.linear_1.weight"] = "time_in.in_layer.weight"; flux_name_map["time_text_embed.timestep_embedder.linear_1.bias"] = "time_in.in_layer.bias"; @@ -660,7 +666,7 @@ std::string convert_diffusion_model_name(std::string name, std::string prefix, S name = convert_diffusers_unet_to_original_sdxl(name); } else if (sd_version_is_sd3(version)) { name = convert_diffusers_dit_to_original_sd3(name); - } else if (sd_version_is_flux(version) || sd_version_is_flux2(version)) { + } else if (sd_version_is_flux(version) || sd_version_is_flux2(version) || sd_version_is_longcat(version)) { name = convert_diffusers_dit_to_original_flux(name); } else if (sd_version_is_z_image(version)) { name = convert_diffusers_dit_to_original_lumina2(name); diff --git a/rope.hpp b/rope.hpp index 4abc51469..7666b02ee 100644 --- a/rope.hpp +++ b/rope.hpp @@ -84,6 +84,15 @@ namespace Rope { return txt_ids; } + __STATIC_INLINE__ std::vector> gen_longcat_txt_ids(int bs, int context_len, int axes_dim_num) { + auto txt_ids = std::vector>(bs * context_len, std::vector(axes_dim_num, 0.0f)); + for (int i = 0; i < bs * context_len; i++) { + txt_ids[i][1] = (i % context_len); + txt_ids[i][2] = (i % context_len); + } + return txt_ids; + } + __STATIC_INLINE__ std::vector> gen_flux_img_ids(int h, int w, int patch_size, @@ -94,12 +103,10 @@ namespace Rope { int w_offset = 0) { int h_len = (h + (patch_size / 2)) / patch_size; int w_len = (w + (patch_size / 2)) / patch_size; - std::vector> img_ids(h_len * w_len, std::vector(axes_dim_num, 0.0)); std::vector row_ids = linspace(h_offset, h_len - 1 + h_offset, h_len); std::vector col_ids = linspace(w_offset, w_len - 1 + w_offset, w_len); - for (int i = 0; i < h_len; ++i) { for (int j = 0; j < w_len; ++j) { img_ids[i * w_len + j][0] = index; @@ -169,13 +176,15 @@ namespace Rope { __STATIC_INLINE__ std::vector> gen_refs_ids(int patch_size, int bs, int axes_dim_num, + int start_index, const std::vector& ref_latents, bool increase_ref_index, - float ref_index_scale) { + float ref_index_scale, + int base_offset = 0) { std::vector> ids; uint64_t curr_h_offset = 0; uint64_t curr_w_offset = 0; - int index = 1; + int index = start_index; for (ggml_tensor* ref : ref_latents) { uint64_t h_offset = 0; uint64_t w_offset = 0; @@ -193,8 +202,8 @@ namespace Rope { bs, axes_dim_num, static_cast(index * ref_index_scale), - h_offset, - w_offset); + h_offset + base_offset, + w_offset + base_offset); ids = concat_ids(ids, ref_ids, bs); if (increase_ref_index) { @@ -216,13 +225,17 @@ namespace Rope { std::set txt_arange_dims, const std::vector& ref_latents, bool increase_ref_index, - float ref_index_scale) { - auto txt_ids = gen_flux_txt_ids(bs, context_len, axes_dim_num, txt_arange_dims); - auto img_ids = gen_flux_img_ids(h, w, patch_size, bs, axes_dim_num); + float ref_index_scale, + bool is_longcat) { + int x_index = is_longcat ? 1 : 0; + + auto txt_ids = is_longcat ? gen_longcat_txt_ids(bs, context_len, axes_dim_num) : gen_flux_txt_ids(bs, context_len, axes_dim_num, txt_arange_dims); + int offset = is_longcat ? context_len : 0; + auto img_ids = gen_flux_img_ids(h, w, patch_size, bs, axes_dim_num, x_index, offset, offset); auto ids = concat_ids(txt_ids, img_ids, bs); if (ref_latents.size() > 0) { - auto refs_ids = gen_refs_ids(patch_size, bs, axes_dim_num, ref_latents, increase_ref_index, ref_index_scale); + auto refs_ids = gen_refs_ids(patch_size, bs, axes_dim_num, x_index + 1, ref_latents, increase_ref_index, ref_index_scale, offset); ids = concat_ids(ids, refs_ids, bs); } return ids; @@ -239,7 +252,8 @@ namespace Rope { bool increase_ref_index, float ref_index_scale, int theta, - const std::vector& axes_dim) { + const std::vector& axes_dim, + bool is_longcat) { std::vector> ids = gen_flux_ids(h, w, patch_size, @@ -249,7 +263,8 @@ namespace Rope { txt_arange_dims, ref_latents, increase_ref_index, - ref_index_scale); + ref_index_scale, + is_longcat); return embed_nd(ids, bs, theta, axes_dim); } @@ -274,7 +289,7 @@ namespace Rope { auto img_ids = gen_flux_img_ids(h, w, patch_size, bs, axes_dim_num); auto ids = concat_ids(txt_ids_repeated, img_ids, bs); if (ref_latents.size() > 0) { - auto refs_ids = gen_refs_ids(patch_size, bs, axes_dim_num, ref_latents, increase_ref_index, 1.f); + auto refs_ids = gen_refs_ids(patch_size, bs, axes_dim_num, 1, ref_latents, increase_ref_index, 1.f); ids = concat_ids(ids, refs_ids, bs); } return ids; diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 1ef851247..f89e4268a 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -47,6 +47,7 @@ const char* model_version_to_str[] = { "Flux.2", "Z-Image", "Ovis Image", + "Longcat-Image", }; const char* sampling_methods_str[] = { @@ -372,7 +373,7 @@ class StableDiffusionGGML { } else if (sd_version_is_sd3(version)) { scale_factor = 1.5305f; shift_factor = 0.0609f; - } else if (sd_version_is_flux(version) || sd_version_is_z_image(version)) { + } else if (sd_version_is_flux(version) || sd_version_is_z_image(version) || sd_version_is_longcat(version)) { scale_factor = 0.3611f; shift_factor = 0.1159f; } else if (sd_version_is_wan(version) || @@ -400,8 +401,8 @@ class StableDiffusionGGML { offload_params_to_cpu, tensor_storage_map); diffusion_model = std::make_shared(backend, - offload_params_to_cpu, - tensor_storage_map); + offload_params_to_cpu, + tensor_storage_map); } else if (sd_version_is_flux(version)) { bool is_chroma = false; for (auto pair : tensor_storage_map) { @@ -449,10 +450,26 @@ class StableDiffusionGGML { tensor_storage_map, version); diffusion_model = std::make_shared(backend, - offload_params_to_cpu, - tensor_storage_map, - version, - sd_ctx_params->chroma_use_dit_mask); + offload_params_to_cpu, + tensor_storage_map, + version, + sd_ctx_params->chroma_use_dit_mask); + } else if (sd_version_is_longcat(version)) { + bool enable_vision = false; + if (!vae_decode_only) { + enable_vision = true; + } + cond_stage_model = std::make_shared(clip_backend, + offload_params_to_cpu, + tensor_storage_map, + version, + "", + enable_vision); + diffusion_model = std::make_shared(backend, + offload_params_to_cpu, + tensor_storage_map, + version, + sd_ctx_params->chroma_use_dit_mask); } else if (sd_version_is_wan(version)) { cond_stage_model = std::make_shared(clip_backend, offload_params_to_cpu, @@ -461,10 +478,10 @@ class StableDiffusionGGML { 1, true); diffusion_model = std::make_shared(backend, - offload_params_to_cpu, - tensor_storage_map, - "model.diffusion_model", - version); + offload_params_to_cpu, + tensor_storage_map, + "model.diffusion_model", + version); if (strlen(SAFE_STR(sd_ctx_params->high_noise_diffusion_model_path)) > 0) { high_noise_diffusion_model = std::make_shared(backend, offload_params_to_cpu, @@ -493,20 +510,20 @@ class StableDiffusionGGML { "", enable_vision); diffusion_model = std::make_shared(backend, - offload_params_to_cpu, - tensor_storage_map, - "model.diffusion_model", - version); + offload_params_to_cpu, + tensor_storage_map, + "model.diffusion_model", + version); } else if (sd_version_is_z_image(version)) { cond_stage_model = std::make_shared(clip_backend, offload_params_to_cpu, tensor_storage_map, version); diffusion_model = std::make_shared(backend, - offload_params_to_cpu, - tensor_storage_map, - "model.diffusion_model", - version); + offload_params_to_cpu, + tensor_storage_map, + "model.diffusion_model", + version); } else { // SD1.x SD2.x SDXL std::map embbeding_map; for (int i = 0; i < sd_ctx_params->embedding_count; i++) { @@ -827,7 +844,7 @@ class StableDiffusionGGML { flow_shift = 3.f; } } - } else if (sd_version_is_flux(version)) { + } else if (sd_version_is_flux(version) || sd_version_is_longcat(version)) { pred_type = FLUX_FLOW_PRED; if (flow_shift == INFINITY) { flow_shift = 1.0f; // TODO: validate @@ -836,6 +853,9 @@ class StableDiffusionGGML { flow_shift = 1.15f; } } + if (sd_version_is_longcat(version)) { + flow_shift = 3.0f; + } } } else if (sd_version_is_flux2(version)) { pred_type = FLUX2_FLOW_PRED; @@ -1324,7 +1344,7 @@ class StableDiffusionGGML { if (sd_version_is_flux2(version)) { latent_rgb_proj = flux2_latent_rgb_proj; latent_rgb_bias = flux2_latent_rgb_bias; - patch_sz = 2; + patch_sz = 2; } } else if (dim == 48) { if (sd_version_is_wan(version)) { @@ -1341,7 +1361,7 @@ class StableDiffusionGGML { if (sd_version_is_sd3(version)) { latent_rgb_proj = sd3_latent_rgb_proj; latent_rgb_bias = sd3_latent_rgb_bias; - } else if (sd_version_is_flux(version) || sd_version_is_z_image(version)) { + } else if (sd_version_is_flux(version) || sd_version_is_z_image(version) || sd_version_is_longcat(version)) { latent_rgb_proj = flux_latent_rgb_proj; latent_rgb_bias = flux_latent_rgb_bias; } else if (sd_version_is_wan(version) || sd_version_is_qwen_image(version)) {