From ec5b158a864aa4f23d98d6c2d23eea81beafee8a Mon Sep 17 00:00:00 2001 From: Guoqing Bao Date: Tue, 2 Dec 2025 10:59:55 +0000 Subject: [PATCH 1/3] Support Fused MoE & Qwen3 GGUF MoE models --- README.md | 5 +- candle-core/src/op.rs | 2 +- candle-core/src/quantized/cuda.rs | 5 + candle-core/src/quantized/dummy_cuda.rs | 4 + candle-core/src/quantized/mod.rs | 18 + .../examples/quantized-qwen3-moe/README.md | 18 + .../examples/quantized-qwen3-moe/main.rs | 357 ++++ candle-examples/examples/qwen/README.md | 5 + candle-examples/examples/qwen/main.rs | 33 +- candle-kernels/Cargo.toml | 2 +- candle-kernels/build.rs | 48 +- candle-kernels/src/ffi.rs | 56 + candle-kernels/src/lib.rs | 2 + candle-kernels/src/moe/gguf.cuh | 1438 +++++++++++++++++ candle-kernels/src/moe/moe_gguf.cu | 216 +++ candle-kernels/src/moe/moe_utils.cuh | 188 +++ candle-kernels/src/moe/moe_wmma.cu | 283 ++++ candle-kernels/src/moe/moe_wmma_gguf.cu | 422 +++++ candle-nn/src/lib.rs | 1 + candle-nn/src/moe.rs | 349 ++++ candle-transformers/src/fused_moe.rs | 302 ++++ candle-transformers/src/lib.rs | 1 + candle-transformers/src/models/mod.rs | 1 + .../src/models/quantized_qwen3.rs | 18 +- .../src/models/quantized_qwen3_moe.rs | 465 ++++++ candle-transformers/src/models/qwen3_moe.rs | 37 +- 26 files changed, 4243 insertions(+), 33 deletions(-) create mode 100644 candle-examples/examples/quantized-qwen3-moe/README.md create mode 100644 candle-examples/examples/quantized-qwen3-moe/main.rs create mode 100644 candle-kernels/src/ffi.rs create mode 100644 candle-kernels/src/moe/gguf.cuh create mode 100644 candle-kernels/src/moe/moe_gguf.cu create mode 100644 candle-kernels/src/moe/moe_utils.cuh create mode 100644 candle-kernels/src/moe/moe_wmma.cu create mode 100644 candle-kernels/src/moe/moe_wmma_gguf.cu create mode 100644 candle-nn/src/moe.rs create mode 100644 candle-transformers/src/fused_moe.rs create mode 100644 candle-transformers/src/models/quantized_qwen3_moe.rs diff --git a/README.md b/README.md index 632afdd782..4a62a27593 100644 --- a/README.md +++ b/README.md @@ -92,6 +92,7 @@ We also provide some command line based examples using state of the art models: - [Quantized LLaMA](./candle-examples/examples/quantized/): quantized version of the LLaMA model using the same quantization techniques as [llama.cpp](https://github.com/ggerganov/llama.cpp). +- [Quantized Qwen3 MoE](./candle-examples/examples/quantized-qwen3-moe/): support gguf quantized models of Qwen3 MoE models. @@ -190,6 +191,7 @@ And then head over to - [`candle-einops`](https://github.com/tomsanbear/candle-einops): A pure rust implementation of the python [einops](https://github.com/arogozhnikov/einops) library. - [`atoma-infer`](https://github.com/atoma-network/atoma-infer): A Rust library for fast inference at scale, leveraging FlashAttention2 for efficient attention computation, PagedAttention for efficient KV-cache memory management, and multi-GPU support. It is OpenAI api compatible. - [`llms-from-scratch-rs`](https://github.com/nerdai/llms-from-scratch-rs): A comprehensive Rust translation of the code from Sebastian Raschka's Build an LLM from Scratch book. +- [`vllm.rs`](https://github.com/guoqingbao/vllm.rs): A minimalist vLLM implementation in Rust based on Candle. If you have an addition to this list, please submit a pull request. @@ -220,7 +222,7 @@ If you have an addition to this list, please submit a pull request. - Replit-code-v1.5-3B. - Bert. - Yi-6B and Yi-34B. - - Qwen1.5, Qwen1.5 MoE. + - Qwen1.5, Qwen1.5 MoE, Qwen3 MoE. - RWKV v5 and v6. 
- Quantized LLMs. - Llama 7b, 13b, 70b, as well as the chat and code variants. @@ -228,6 +230,7 @@ If you have an addition to this list, please submit a pull request. - Mixtral 8x7b. - Zephyr 7b a and b (Mistral-7b based). - OpenChat 3.5 (Mistral-7b based). + - Qwen3 MoE (16B-A3B, 32B-A3B) - Text to text. - T5 and its variants: FlanT5, UL2, MADLAD400 (translation), CoEdit (Grammar correction). - Marian MT (Machine Translation). diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs index a4d5d6cb97..3c3ffb1097 100644 --- a/candle-core/src/op.rs +++ b/candle-core/src/op.rs @@ -1031,7 +1031,7 @@ impl UnaryOpT for Relu { pub struct BackpropOp(Option); impl BackpropOp { - pub(crate) fn none() -> Self { + pub fn none() -> Self { BackpropOp(None) } diff --git a/candle-core/src/quantized/cuda.rs b/candle-core/src/quantized/cuda.rs index 3faf9f695f..563c9ce3db 100644 --- a/candle-core/src/quantized/cuda.rs +++ b/candle-core/src/quantized/cuda.rs @@ -742,6 +742,11 @@ impl QCudaStorage { .memcpy_dtoh(&self.data.inner.slice(..self.data.len), &mut out)?; Ok(out) } + + pub fn device_ptr(&self) -> Result<*const u8> { + use cudarc::driver::DevicePtr; + Ok(self.data.inner.device_ptr(self.data.inner.stream()).0 as *const u8) + } } impl QCudaStorage { diff --git a/candle-core/src/quantized/dummy_cuda.rs b/candle-core/src/quantized/dummy_cuda.rs index 7194439a09..04f19f9fcb 100644 --- a/candle-core/src/quantized/dummy_cuda.rs +++ b/candle-core/src/quantized/dummy_cuda.rs @@ -54,6 +54,10 @@ impl QCudaStorage { Err(Error::NotCompiledWithCudaSupport) } + pub fn device_ptr(&self) -> Result<*const u8> { + Err(Error::NotCompiledWithCudaSupport) + } + pub fn storage_size_in_bytes(&self) -> usize { 0 } diff --git a/candle-core/src/quantized/mod.rs b/candle-core/src/quantized/mod.rs index cee8ccc2ad..7316d29871 100644 --- a/candle-core/src/quantized/mod.rs +++ b/candle-core/src/quantized/mod.rs @@ -239,6 +239,15 @@ impl QStorage { QStorage::Metal(storage) => Ok(Cow::from(storage.data()?)), } } + + pub fn device_ptr(&self) -> Result<*const u8> { + match self { + QStorage::Cuda(storage) => storage.device_ptr(), + QStorage::Metal(_) | QStorage::Cpu(_) => { + crate::bail!("not implemented"); + } + } + } } #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] @@ -670,6 +679,15 @@ impl QTensor { } } } + + pub fn device_ptr(&self) -> Result<*const u8> { + match &self.storage { + QStorage::Cuda(storage) => storage.device_ptr(), + QStorage::Metal(_) | QStorage::Cpu(_) => { + crate::bail!("not implemented"); + } + } + } } #[derive(Clone, Debug)] diff --git a/candle-examples/examples/quantized-qwen3-moe/README.md b/candle-examples/examples/quantized-qwen3-moe/README.md new file mode 100644 index 0000000000..8f82051a31 --- /dev/null +++ b/candle-examples/examples/quantized-qwen3-moe/README.md @@ -0,0 +1,18 @@ +# candle-quantized-qwen3-moe + +[Qwen3 MoE GGUF]((https://huggingface.co/unsloth/Qwen3-30B-A3B-Instruct-2507-GGUF)) contains the GGUF format of Qwen3 32B MoE models, developed by Alibaba Cloud. + +## Running the example + +```bash +# Local GGUF file +cargo run --features cuda --example quantized-qwen3-moe --release -- --model /path/Qwen3-30B-A3B-Instruct-2507-Q4_K_M.gguf --prompt "Write a function to count prime numbers up to N." 
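+# The --dtype flag defined in main.rs (bf16 by default, f16 for GPUs without bf16 support such as V100)
+# selects the activation dtype; assuming the same local GGUF file, a variant of the command above would be:
+# cargo run --features cuda --example quantized-qwen3-moe --release -- --model /path/Qwen3-30B-A3B-Instruct-2507-Q4_K_M.gguf --dtype f16 --prompt "Write a function to count prime numbers up to N."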
+``` + +Models available via `--which` argument: 16b_q2k, 16b_q4k, 16b_q6k, 16b_q80; 32b_q2k, 32b_q4k, 32b_q6k, 32b_q80; + +```bash +# Obtained from Huggingface +cargo run --features cuda --example quantized-qwen3-moe --release -- --which 32b_q4k --prompt "A train is travelling at 120mph, how far does it travel in 3 minutes 30 seconds?" +``` + diff --git a/candle-examples/examples/quantized-qwen3-moe/main.rs b/candle-examples/examples/quantized-qwen3-moe/main.rs new file mode 100644 index 0000000000..8fdfca39ef --- /dev/null +++ b/candle-examples/examples/quantized-qwen3-moe/main.rs @@ -0,0 +1,357 @@ +#[cfg(feature = "mkl")] +extern crate intel_mkl_src; + +#[cfg(feature = "accelerate")] +extern crate accelerate_src; + +use clap::{Parser, ValueEnum}; +use std::io::Write; +use tokenizers::Tokenizer; + +use candle::Tensor; +use candle::{quantized::gguf_file, DType}; +use candle_transformers::generation::{LogitsProcessor, Sampling}; + +use candle_examples::token_output_stream::TokenOutputStream; +use candle_transformers::models::quantized_qwen3_moe::GGUFQWenMoE as Qwen3_MoE; + +const DEFAULT_PROMPT: &str = "Write a Rust function to calculate the factorial of a given number."; + +#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)] +enum Which { + #[value(name = "16b_q2k")] + W3_16bQ2K, + #[value(name = "16b_q4k")] + W3_16bQ4K, + #[value(name = "16b_q6k")] + W3_16bQ6K, + #[value(name = "16b_q80")] + W3_16bQ80, + #[value(name = "32b_q2k")] + W3_32bQ2K, + #[value(name = "32b_q4k")] + W3_32bQ4K, + #[value(name = "32b_q6k")] + W3_32bQ6K, + #[value(name = "32b_q80")] + W3_32bQ80, +} + +#[derive(Parser, Debug)] +#[command(author, version, about, long_about = None)] +struct Args { + /// GGUF file to load, typically a .gguf file generated by the quantize command from llama.cpp + #[arg(long)] + model: Option, + + /// The initial prompt, use 'interactive' for entering multiple prompts in an interactive way + /// and 'chat' for an interactive model where history of previous prompts and generated tokens + /// is preserved. + #[arg(long)] + prompt: Option, + + /// The length of the sample to generate (in tokens). + #[arg(short = 'n', long, default_value_t = 1000)] + sample_len: usize, + + /// The tokenizer config in json format. + #[arg(long)] + tokenizer: Option, + + /// The temperature used to generate samples, use 0 for greedy sampling. + #[arg(long, default_value_t = 0.8)] + temperature: f64, + + /// Nucleus sampling probability cutoff. + #[arg(long)] + top_p: Option, + + /// Only sample among the top K samples. + #[arg(long)] + top_k: Option, + + /// The seed to use when generating random samples. + #[arg(long, default_value_t = 299792458)] + seed: u64, + + /// Enable tracing (generates a trace-timestamp.json file). + #[arg(long)] + tracing: bool, + + /// Process prompt elements separately. + #[arg(long)] + split_prompt: bool, + + /// Run on CPU rather than GPU even if a GPU is available. + #[arg(long)] + cpu: bool, + + /// Penalty to be applied for repeating tokens, 1. means no penalty. + #[arg(long, default_value_t = 1.1)] + repeat_penalty: f32, + + /// The context size to consider for the repeat penalty. + #[arg(long, default_value_t = 64)] + repeat_last_n: usize, + + /// The model size to use. 
+ #[arg(long, default_value = "16b_q2k")] + which: Which, + + #[arg(long, default_value = "bf16")] + dtype: String, +} + +impl Args { + fn tokenizer(&self) -> anyhow::Result { + let tokenizer_path = match &self.tokenizer { + Some(config) => std::path::PathBuf::from(config), + None => { + let api = hf_hub::api::sync::Api::new()?; + let repo = "Qwen/Qwen3-30B-A3B-Instruct-2507"; + let api = api.model(repo.to_string()); + api.get("tokenizer.json")? + } + }; + Tokenizer::from_file(tokenizer_path).map_err(anyhow::Error::msg) + } + + fn model(&self) -> anyhow::Result { + let model_path = match &self.model { + Some(config) => std::path::PathBuf::from(config), + None => { + let (repo, filename, revision) = match self.which { + Which::W3_16bQ2K => ( + "unsloth/Qwen3-16B-A3B-GGUF", + "Qwen3-16B-A3B-Q2_K.gguf", + "main", + ), + Which::W3_16bQ4K => ( + "unsloth/Qwen3-16B-A3B-GGUF", + "Qwen3-16B-A3B-Q4_K_M.gguf", + "main", + ), + Which::W3_16bQ6K => ( + "unsloth/Qwen3-16B-A3B-GGUF", + "Qwen3-16B-A3B-Q6_K.gguf", + "main", + ), + Which::W3_16bQ80 => ( + "unsloth/Qwen3-16B-A3B-GGUF", + "Qwen3-16B-A3B-Q8_0.gguf", + "main", + ), + + Which::W3_32bQ2K => ( + "unsloth/Qwen3-30B-A3B-Instruct-2507-GGUF", + "Qwen3-30B-A3B-Instruct-2507-Q2_K.gguf", + "main", + ), + Which::W3_32bQ4K => ( + "unsloth/Qwen3-30B-A3B-Instruct-2507-GGUF", + "Qwen3-30B-A3B-Instruct-2507-Q4_K_M.gguf", + "main", + ), + Which::W3_32bQ6K => ( + "unsloth/Qwen3-30B-A3B-Instruct-2507-GGUF", + "Qwen3-30B-A3B-Instruct-2507-Q6_K.gguf", + "main", + ), + Which::W3_32bQ80 => ( + "unsloth/Qwen3-30B-A3B-Instruct-2507-GGUF", + "Qwen3-30B-A3B-Instruct-2507-Q8_0.gguf", + "main", + ), + }; + let api = hf_hub::api::sync::Api::new()?; + api.repo(hf_hub::Repo::with_revision( + repo.to_string(), + hf_hub::RepoType::Model, + revision.to_string(), + )) + .get(filename)? 
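+                // (the call above fetches the GGUF file into the local Hugging Face cache if needed and returns its path)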
+ } + }; + Ok(model_path) + } +} + +fn format_size(size_in_bytes: usize) -> String { + if size_in_bytes < 1_000 { + format!("{size_in_bytes}B") + } else if size_in_bytes < 1_000_000 { + format!("{:.2}KB", size_in_bytes as f64 / 1e3) + } else if size_in_bytes < 1_000_000_000 { + format!("{:.2}MB", size_in_bytes as f64 / 1e6) + } else { + format!("{:.2}GB", size_in_bytes as f64 / 1e9) + } +} + +fn main() -> anyhow::Result<()> { + use tracing_chrome::ChromeLayerBuilder; + use tracing_subscriber::prelude::*; + + let args = Args::parse(); + let _guard = if args.tracing { + let (chrome_layer, guard) = ChromeLayerBuilder::new().build(); + tracing_subscriber::registry().with(chrome_layer).init(); + Some(guard) + } else { + None + }; + + println!( + "avx: {}, neon: {}, simd128: {}, f16c: {}", + candle::utils::with_avx(), + candle::utils::with_neon(), + candle::utils::with_simd128(), + candle::utils::with_f16c() + ); + println!( + "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}", + args.temperature, args.repeat_penalty, args.repeat_last_n + ); + + let dtype = match args.dtype.as_str() { + "bf16" => DType::BF16, + "f16" => DType::F16, // Used for V100 + _ => { + panic!("Not supported dtype!") + } + }; + + let model_path = args.model()?; + let mut file = std::fs::File::open(&model_path)?; + let start = std::time::Instant::now(); + let device = candle_examples::device(args.cpu)?; + + let mut model = { + let model = gguf_file::Content::read(&mut file).map_err(|e| e.with_path(model_path))?; + let mut total_size_in_bytes = 0; + for (_, tensor) in model.tensor_infos.iter() { + let elem_count = tensor.shape.elem_count(); + total_size_in_bytes += + elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.block_size(); + } + println!( + "loaded {:?} tensors ({}) in {:.2}s", + model.tensor_infos.len(), + &format_size(total_size_in_bytes), + start.elapsed().as_secs_f32(), + ); + Qwen3_MoE::from_gguf(model, &mut file, &device, dtype)? + }; + println!("model built"); + + let tokenizer = args.tokenizer()?; + let mut tos = TokenOutputStream::new(tokenizer); + let prompt_str = args + .prompt + .clone() + .unwrap_or_else(|| DEFAULT_PROMPT.to_string()); + + let prompt_str = format!("<|im_start|>user\n{prompt_str}<|im_end|>\n<|im_start|>assistant\n"); + print!("formatted prompt: {}", &prompt_str); + + let tokens = tos + .tokenizer() + .encode(prompt_str, true) + .map_err(anyhow::Error::msg)?; + + let tokens = tokens.get_ids(); + + let to_sample = args.sample_len.saturating_sub(1); + + let mut all_tokens = vec![]; + + let mut logits_processor = { + let temperature = args.temperature; + let sampling = if temperature <= 0. { + Sampling::ArgMax + } else { + match (args.top_k, args.top_p) { + (None, None) => Sampling::All { temperature }, + (Some(k), None) => Sampling::TopK { k, temperature }, + (None, Some(p)) => Sampling::TopP { p, temperature }, + (Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature }, + } + }; + LogitsProcessor::from_sampling(args.seed, sampling) + }; + + let start_prompt_processing = std::time::Instant::now(); + + let mut next_token = if !args.split_prompt { + let input = Tensor::new(tokens, &device)?.unsqueeze(0)?; + let logits = model.forward(&input, 0)?; + let logits = logits.squeeze(0)?; + logits_processor.sample(&logits)? 
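+        // The whole prompt was processed in a single forward pass above; the --split-prompt branch
+        // below feeds prompt tokens one at a time instead.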
+ } else { + let mut next_token = 0; + for (pos, token) in tokens.iter().enumerate() { + let input = Tensor::new(&[*token], &device)?.unsqueeze(0)?; + let logits = model.forward(&input, pos)?; + let logits = logits.squeeze(0)?; + next_token = logits_processor.sample(&logits)? + } + next_token + }; + + let prompt_dt = start_prompt_processing.elapsed(); + + all_tokens.push(next_token); + + if let Some(t) = tos.next_token(next_token)? { + print!("{t}"); + std::io::stdout().flush()?; + } + + let eos_token = *tos.tokenizer().get_vocab(true).get("<|im_end|>").unwrap(); + + let start_post_prompt = std::time::Instant::now(); + + let mut sampled = 0; + for index in 0..to_sample { + let input = Tensor::new(&[next_token], &device)?.unsqueeze(0)?; + let logits = model.forward(&input, tokens.len() + index)?; + let logits = logits.squeeze(0)?; + let logits = if args.repeat_penalty == 1. { + logits + } else { + let start_at = all_tokens.len().saturating_sub(args.repeat_last_n); + candle_transformers::utils::apply_repeat_penalty( + &logits, + args.repeat_penalty, + &all_tokens[start_at..], + )? + }; + next_token = logits_processor.sample(&logits)?; + all_tokens.push(next_token); + if let Some(t) = tos.next_token(next_token)? { + print!("{t}"); + std::io::stdout().flush()?; + } + sampled += 1; + if next_token == eos_token { + break; + }; + } + + if let Some(rest) = tos.decode_rest().map_err(candle::Error::msg)? { + print!("{rest}"); + } + + std::io::stdout().flush()?; + let dt = start_post_prompt.elapsed(); + println!( + "\n\n{:4} prompt tokens processed: {:.2} token/s", + tokens.len(), + tokens.len() as f64 / prompt_dt.as_secs_f64(), + ); + println!( + "{sampled:4} tokens generated: {:.2} token/s", + sampled as f64 / dt.as_secs_f64(), + ); + Ok(()) +} diff --git a/candle-examples/examples/qwen/README.md b/candle-examples/examples/qwen/README.md index d81cd6660a..92fa90e96a 100644 --- a/candle-examples/examples/qwen/README.md +++ b/candle-examples/examples/qwen/README.md @@ -50,3 +50,8 @@ $ cargo run --example qwen --features metal --release -- --prompt "Write a poem > Their beauty lives where hearts can fly. > 161 tokens generated (3.00 token/s) ``` + +```shell +# Local unquantized 32B MoE model (with Fused MoE kernel) (~80GB GPU memory) +cargo run --example qwen --features cuda --release -- --prompt "Write a poem about butterflies. ." --model "3-moe-a3b" --weight-path /path/Qwen3-30B-A3B-Instruct-2507 +``` \ No newline at end of file diff --git a/candle-examples/examples/qwen/main.rs b/candle-examples/examples/qwen/main.rs index 796f3a1d1f..4c6a1e76d6 100644 --- a/candle-examples/examples/qwen/main.rs +++ b/candle-examples/examples/qwen/main.rs @@ -217,7 +217,7 @@ struct Args { tokenizer_file: Option, #[arg(long)] - weight_files: Option, + weight_path: Option, /// Penalty to be applied for repeating tokens, 1. means no penalty. 
#[arg(long, default_value_t = 1.1)] @@ -288,15 +288,29 @@ fn main() -> Result<()> { RepoType::Model, args.revision, )); - let tokenizer_filename = match args.tokenizer_file { - Some(file) => std::path::PathBuf::from(file), - None => repo.get("tokenizer.json")?, + + let tokenizer_filename = match (args.weight_path.as_ref(), args.tokenizer_file.as_ref()) { + (Some(_), Some(file)) => std::path::PathBuf::from(file), + (None, Some(file)) => std::path::PathBuf::from(file), + (Some(path), None) => std::path::Path::new(path).join("tokenizer.json"), + (None, None) => repo.get("tokenizer.json")?, + }; + let config_file = match &args.weight_path { + Some(path) => std::path::Path::new(path).join("config.json"), + _ => repo.get("config.json")?, }; - let filenames = match args.weight_files { - Some(files) => files - .split(',') - .map(std::path::PathBuf::from) - .collect::>(), + + let filenames = match args.weight_path { + Some(path) => { + if std::path::Path::new(&path) + .join("model.safetensors.index.json") + .exists() + { + candle_examples::hub_load_local_safetensors(path, "model.safetensors.index.json")? + } else { + vec!["model.safetensors".into()].into() + } + } None => match args.model { WhichModel::W0_5b | WhichModel::W2_0_5b @@ -324,7 +338,6 @@ fn main() -> Result<()> { let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; let start = std::time::Instant::now(); - let config_file = repo.get("config.json")?; let device = candle_examples::device(args.cpu)?; let dtype = if device.is_cuda() || device.is_metal() { DType::BF16 diff --git a/candle-kernels/Cargo.toml b/candle-kernels/Cargo.toml index 5ea1e07928..b571d05055 100644 --- a/candle-kernels/Cargo.toml +++ b/candle-kernels/Cargo.toml @@ -12,4 +12,4 @@ license = "MIT OR Apache-2.0" [dependencies] [build-dependencies] -bindgen_cuda = "0.1.5" +bindgen_cuda = { git = "https://github.com/guoqingbao/bindgen_cuda.git", version= "0.1.7" } diff --git a/candle-kernels/build.rs b/candle-kernels/build.rs index e1813cd010..035345f86c 100644 --- a/candle-kernels/build.rs +++ b/candle-kernels/build.rs @@ -7,10 +7,54 @@ fn main() { println!("cargo::rerun-if-changed=src/cuda_utils.cuh"); println!("cargo::rerun-if-changed=src/binary_op_macros.cuh"); + // Build for PTX let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap()); let ptx_path = out_dir.join("ptx.rs"); - let builder = bindgen_cuda::Builder::default(); + let mut builder = bindgen_cuda::Builder::default() + .arg("--expt-relaxed-constexpr") + .arg("-std=c++17") + .arg("-O3") + .arg("--use_fast_math"); println!("cargo::warning={builder:?}"); let bindings = builder.build_ptx().unwrap(); - bindings.write(ptx_path).unwrap(); + bindings.write(&ptx_path).unwrap(); + + // Remove unwanted MOE PTX constants from ptx.rs + remove_lines(&ptx_path, &["MOE_GGUF", "MOE_WMMA", "MOE_WMMA_GGUF"]); + + // Build for FFI binding (must use custom bindgen_cuda, which supports simutanously build PTX and lib) + let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap()); + let mut is_target_msvc = false; + if let Ok(target) = std::env::var("TARGET") { + if target.contains("msvc") { + is_target_msvc = true; + builder = builder.arg("-D_USE_MATH_DEFINES"); + } + } + + if !is_target_msvc { + builder = builder.arg("-Xcompiler").arg("-fPIC"); + } + + let builder = builder.kernel_paths(vec![ + "src/moe/moe_gguf.cu", + "src/moe/moe_wmma.cu", + "src/moe/moe_wmma_gguf.cu", + ]); + println!("cargo::warning={builder:?}"); + builder.build_lib(out_dir.join("libmoe.a")); + println!("cargo:rustc-link-search={}", 
out_dir.display()); + println!("cargo:rustc-link-lib=moe"); + println!("cargo:rustc-link-lib=dylib=cudart"); + println!("cargo:rustc-link-lib=stdc++"); +} + +fn remove_lines>(file: P, patterns: &[&str]) { + let content = std::fs::read_to_string(&file).unwrap(); + let filtered = content + .lines() + .filter(|line| !patterns.iter().any(|p| line.contains(p))) + .collect::>() + .join("\n"); + std::fs::write(file, filtered).unwrap(); } diff --git a/candle-kernels/src/ffi.rs b/candle-kernels/src/ffi.rs new file mode 100644 index 0000000000..ac50392721 --- /dev/null +++ b/candle-kernels/src/ffi.rs @@ -0,0 +1,56 @@ +use core::ffi::c_void; +#[allow(dead_code)] +extern "C" { + // for unquntized models + pub fn moe_gemm_wmma( + input: *const c_void, // device pointer [size_m, size_k] + weights: *const c_void, // device pointer [num_experts, size_n, size_k] + sorted_token_ids: *const i32, // device pointer [size_m] + expert_ids: *const i32, // host array [size_m] (expert id per sorted token) + topk_weights: *const f32, + output: *mut c_void, // device pointer [size_m, size_n] + expert_counts: *mut i32, // pre-allocated buffer [num_experts] + expert_offsets: *mut i32, // pre-allocated buffer [num_experts + 1] + num_experts: i32, + topk: i32, + size_m: i32, + size_n: i32, + size_k: i32, + dtype: i32, // 0=float16, 1=bf16 (for input/output) + is_prefill: bool, + stream: i64, + ); + + pub fn moe_gemm_gguf( + input: *const f32, // input [size_m, size_k] + weights: *const c_void, // weights [num_experts, size_n, size_k] + sorted_token_ids: *const i32, + expert_ids: *const i32, + topk_weights: *const f32, // device ptr or nullptr + output: *mut c_void, // float output [size_m, size_n] + num_experts: i32, + topk: i32, + size_m: i32, + size_n: i32, + size_k: i32, + gguf_dtype: i32, // Q8_0: 0, Q4K: 1, Q2K: 2, Q3k: 3, Q5K: 4, Q6K: 5 (for weights) + stream: i64, + ); + + pub fn moe_gemm_gguf_prefill( + input: *const c_void, // input [size_m, size_k] + weights: *const u8, // weights [num_experts, size_n, size_k] + sorted_token_ids: *const i32, + expert_ids: *const i32, //must be host ptr + topk_weights: *const f32, // device ptr or nullptr + output: *mut c_void, // float output [size_m, size_n] + num_experts: i32, + topk: i32, + size_m: i32, + size_n: i32, + size_k: i32, + input_dtype: i32, // 0=f16, 1=bf16 (for inputs) + gguf_dtype: i32, //Q8_0: 0, Q4K: 1, Q2K: 2, Q3k: 3, Q5K: 4, Q6K: 5 (for weights) + stream: i64, + ); +} diff --git a/candle-kernels/src/lib.rs b/candle-kernels/src/lib.rs index 9b66403475..cfc5732652 100644 --- a/candle-kernels/src/lib.rs +++ b/candle-kernels/src/lib.rs @@ -78,3 +78,5 @@ mdl!(REDUCE, Reduce); mdl!(SORT, Sort); mdl!(TERNARY, Ternary); mdl!(UNARY, Unary); + +pub mod ffi; diff --git a/candle-kernels/src/moe/gguf.cuh b/candle-kernels/src/moe/gguf.cuh new file mode 100644 index 0000000000..3e50e9e9e8 --- /dev/null +++ b/candle-kernels/src/moe/gguf.cuh @@ -0,0 +1,1438 @@ +// Kernels adapted from llama.cpp ggml-cuda.cu +// https://github.com/ggerganov/llama.cpp/blob/master/ggml-cuda.cu +#include "cuda_fp16.h" +#include "cuda_bf16.h" +#include + +#define GGML_UNUSED(x) (void)(x) +#define GGML_CUDA_ASSUME(x) + +#ifdef GGML_QKK_64 +#define QK_K 64 +#define K_SCALE_SIZE 4 +#else +#define QK_K 256 +#define K_SCALE_SIZE 12 +#endif + +#undef GGML_CUDA_F16 +#define GGML_CUDA_DMMV_X 32 +#define CUDA_QUANTIZE_BLOCK_SIZE 256 +#define CUDA_DEQUANTIZE_BLOCK_SIZE 256 +#define K_QUANTS_PER_ITERATION 2 + +typedef uint16_t ggml_fp16_t; +typedef float dfloat; // dequantize float +typedef float2 dfloat2; 
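+// Note: GGML_CUDA_F16 is #undef'd above, so dequantization here works with fp32 values
+// (dfloat/dfloat2 are plain float/float2).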
+typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, dfloat2 & v); + +static __device__ __forceinline__ float warp_reduce_sum(float x) { +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + x += __shfl_xor_sync(0xffffffff, x, mask, 32); + } + return x; +} + +static __device__ __forceinline__ float warp_reduce_max(float x) { +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, mask, 32)); + } + return x; +} + +static __device__ __forceinline__ int get_int_from_int8(const int8_t * x8, const int & i32) { + const uint16_t * x16 = (const uint16_t *) (x8 + sizeof(int) * i32); // assume at least 2 byte alignment + + int x32 = 0; + x32 |= x16[0] << 0; + x32 |= x16[1] << 16; + + return x32; +} + +static __device__ __forceinline__ int get_int_from_uint8(const uint8_t * x8, const int & i32) { + const uint16_t * x16 = (const uint16_t *) (x8 + sizeof(int) * i32); // assume at least 2 byte alignment + + int x32 = 0; + x32 |= x16[0] << 0; + x32 |= x16[1] << 16; + + return x32; +} + +static __device__ __forceinline__ int get_int_from_int8_aligned(const int8_t * x8, const int & i32) { + return *((const int *) (x8 + sizeof(int) * i32)); // assume at least 4 byte alignment +} + +static __device__ __forceinline__ int get_int_from_uint8_aligned(const uint8_t * x8, const int & i32) { + return *((const int *) (x8 + sizeof(int) * i32)); // assume at least 4 byte alignment +} + + +#define WARP_SIZE 32 +#define CUDART_HMAX 11070 // CUDA 11.7, min. ver. for which __hmax and __hmax2 are known to work (may be higher than needed) + +#define CC_PASCAL 600 +#define MIN_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products +#define CC_VOLTA 700 +#define CC_OFFSET_AMD 1000000 +#define CC_RDNA1 (CC_OFFSET_AMD + 1010) +#define CC_RDNA2 (CC_OFFSET_AMD + 1030) +#define CC_RDNA3 (CC_OFFSET_AMD + 1100) + +static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, int c) { +#if __CUDA_ARCH__ >= MIN_CC_DP4A + return __dp4a(a, b, c); +#else // __CUDA_ARCH__ >= MIN_CC_DP4A + const int8_t * a8 = (const int8_t *) &a; + const int8_t * b8 = (const int8_t *) &b; + return c + a8[0]*b8[0] + a8[1]*b8[1] + a8[2]*b8[2] + a8[3]*b8[3]; +#endif // __CUDA_ARCH__ >= MIN_CC_DP4A +} + + +#define MMQ_X_Q4_0_RDNA2 64 +#define MMQ_Y_Q4_0_RDNA2 128 +#define NWARPS_Q4_0_RDNA2 8 +#define MMQ_X_Q4_0_RDNA1 64 +#define MMQ_Y_Q4_0_RDNA1 64 +#define NWARPS_Q4_0_RDNA1 8 +#if defined(CUDA_USE_TENSOR_CORES) +#define MMQ_X_Q4_0_AMPERE 4 +#define MMQ_Y_Q4_0_AMPERE 32 +#define NWARPS_Q4_0_AMPERE 4 +#else +#define MMQ_X_Q4_0_AMPERE 64 +#define MMQ_Y_Q4_0_AMPERE 128 +#define NWARPS_Q4_0_AMPERE 4 +#endif +#define MMQ_X_Q4_0_PASCAL 64 +#define MMQ_Y_Q4_0_PASCAL 64 +#define NWARPS_Q4_0_PASCAL 8 + +#define MMQ_X_Q4_1_RDNA2 64 +#define MMQ_Y_Q4_1_RDNA2 128 +#define NWARPS_Q4_1_RDNA2 8 +#define MMQ_X_Q4_1_RDNA1 64 +#define MMQ_Y_Q4_1_RDNA1 64 +#define NWARPS_Q4_1_RDNA1 8 +#if defined(CUDA_USE_TENSOR_CORES) +#define MMQ_X_Q4_1_AMPERE 4 +#define MMQ_Y_Q4_1_AMPERE 32 +#define NWARPS_Q4_1_AMPERE 4 +#else +#define MMQ_X_Q4_1_AMPERE 64 +#define MMQ_Y_Q4_1_AMPERE 128 +#define NWARPS_Q4_1_AMPERE 4 +#endif +#define MMQ_X_Q4_1_PASCAL 64 +#define MMQ_Y_Q4_1_PASCAL 64 +#define NWARPS_Q4_1_PASCAL 8 + +#define MMQ_X_Q5_0_RDNA2 64 +#define MMQ_Y_Q5_0_RDNA2 128 +#define NWARPS_Q5_0_RDNA2 8 +#define MMQ_X_Q5_0_RDNA1 64 +#define MMQ_Y_Q5_0_RDNA1 64 +#define NWARPS_Q5_0_RDNA1 8 +#if defined(CUDA_USE_TENSOR_CORES) +#define MMQ_X_Q5_0_AMPERE 
4 +#define MMQ_Y_Q5_0_AMPERE 32 +#define NWARPS_Q5_0_AMPERE 4 +#else +#define MMQ_X_Q5_0_AMPERE 128 +#define MMQ_Y_Q5_0_AMPERE 64 +#define NWARPS_Q5_0_AMPERE 4 +#endif +#define MMQ_X_Q5_0_PASCAL 64 +#define MMQ_Y_Q5_0_PASCAL 64 +#define NWARPS_Q5_0_PASCAL 8 + +#define MMQ_X_Q5_1_RDNA2 64 +#define MMQ_Y_Q5_1_RDNA2 128 +#define NWARPS_Q5_1_RDNA2 8 +#define MMQ_X_Q5_1_RDNA1 64 +#define MMQ_Y_Q5_1_RDNA1 64 +#define NWARPS_Q5_1_RDNA1 8 +#if defined(CUDA_USE_TENSOR_CORES) +#define MMQ_X_Q5_1_AMPERE 4 +#define MMQ_Y_Q5_1_AMPERE 32 +#define NWARPS_Q5_1_AMPERE 4 +#else +#define MMQ_X_Q5_1_AMPERE 128 +#define MMQ_Y_Q5_1_AMPERE 64 +#define NWARPS_Q5_1_AMPERE 4 +#endif +#define MMQ_X_Q5_1_PASCAL 64 +#define MMQ_Y_Q5_1_PASCAL 64 +#define NWARPS_Q5_1_PASCAL 8 + +#define MMQ_X_Q8_0_RDNA2 64 +#define MMQ_Y_Q8_0_RDNA2 128 +#define NWARPS_Q8_0_RDNA2 8 +#define MMQ_X_Q8_0_RDNA1 64 +#define MMQ_Y_Q8_0_RDNA1 64 +#define NWARPS_Q8_0_RDNA1 8 +#if defined(CUDA_USE_TENSOR_CORES) +#define MMQ_X_Q8_0_AMPERE 4 +#define MMQ_Y_Q8_0_AMPERE 32 +#define NWARPS_Q8_0_AMPERE 4 +#else +#define MMQ_X_Q8_0_AMPERE 128 +#define MMQ_Y_Q8_0_AMPERE 64 +#define NWARPS_Q8_0_AMPERE 4 +#endif +#define MMQ_X_Q8_0_PASCAL 64 +#define MMQ_Y_Q8_0_PASCAL 64 +#define NWARPS_Q8_0_PASCAL 8 + +#define MMQ_X_Q2_K_RDNA2 64 +#define MMQ_Y_Q2_K_RDNA2 128 +#define NWARPS_Q2_K_RDNA2 8 +#define MMQ_X_Q2_K_RDNA1 128 +#define MMQ_Y_Q2_K_RDNA1 32 +#define NWARPS_Q2_K_RDNA1 8 +#if defined(CUDA_USE_TENSOR_CORES) +#define MMQ_X_Q2_K_AMPERE 4 +#define MMQ_Y_Q2_K_AMPERE 32 +#define NWARPS_Q2_K_AMPERE 4 +#else +#define MMQ_X_Q2_K_AMPERE 64 +#define MMQ_Y_Q2_K_AMPERE 128 +#define NWARPS_Q2_K_AMPERE 4 +#endif +#define MMQ_X_Q2_K_PASCAL 64 +#define MMQ_Y_Q2_K_PASCAL 64 +#define NWARPS_Q2_K_PASCAL 8 + +#define MMQ_X_Q3_K_RDNA2 128 +#define MMQ_Y_Q3_K_RDNA2 64 +#define NWARPS_Q3_K_RDNA2 8 +#define MMQ_X_Q3_K_RDNA1 32 +#define MMQ_Y_Q3_K_RDNA1 128 +#define NWARPS_Q3_K_RDNA1 8 +#if defined(CUDA_USE_TENSOR_CORES) +#define MMQ_X_Q3_K_AMPERE 4 +#define MMQ_Y_Q3_K_AMPERE 32 +#define NWARPS_Q3_K_AMPERE 4 +#else +#define MMQ_X_Q3_K_AMPERE 128 +#define MMQ_Y_Q3_K_AMPERE 128 +#define NWARPS_Q3_K_AMPERE 4 +#endif +#define MMQ_X_Q3_K_PASCAL 64 +#define MMQ_Y_Q3_K_PASCAL 64 +#define NWARPS_Q3_K_PASCAL 8 + +#define MMQ_X_Q4_K_RDNA2 64 +#define MMQ_Y_Q4_K_RDNA2 128 +#define NWARPS_Q4_K_RDNA2 8 +#define MMQ_X_Q4_K_RDNA1 32 +#define MMQ_Y_Q4_K_RDNA1 64 +#define NWARPS_Q4_K_RDNA1 8 +#if defined(CUDA_USE_TENSOR_CORES) +#define MMQ_X_Q4_K_AMPERE 4 +#define MMQ_Y_Q4_K_AMPERE 32 +#define NWARPS_Q4_K_AMPERE 4 +#else +#define MMQ_X_Q4_K_AMPERE 64 +#define MMQ_Y_Q4_K_AMPERE 128 +#define NWARPS_Q4_K_AMPERE 4 +#endif +#define MMQ_X_Q4_K_PASCAL 64 +#define MMQ_Y_Q4_K_PASCAL 64 +#define NWARPS_Q4_K_PASCAL 8 + +#define MMQ_X_Q5_K_RDNA2 64 +#define MMQ_Y_Q5_K_RDNA2 128 +#define NWARPS_Q5_K_RDNA2 8 +#define MMQ_X_Q5_K_RDNA1 32 +#define MMQ_Y_Q5_K_RDNA1 64 +#define NWARPS_Q5_K_RDNA1 8 +#if defined(CUDA_USE_TENSOR_CORES) +#define MMQ_X_Q5_K_AMPERE 4 +#define MMQ_Y_Q5_K_AMPERE 32 +#define NWARPS_Q5_K_AMPERE 4 +#else +#define MMQ_X_Q5_K_AMPERE 64 +#define MMQ_Y_Q5_K_AMPERE 128 +#define NWARPS_Q5_K_AMPERE 4 +#endif +#define MMQ_X_Q5_K_PASCAL 64 +#define MMQ_Y_Q5_K_PASCAL 64 +#define NWARPS_Q5_K_PASCAL 8 + +#define MMQ_X_Q6_K_RDNA2 64 +#define MMQ_Y_Q6_K_RDNA2 128 +#define NWARPS_Q6_K_RDNA2 8 +#define MMQ_X_Q6_K_RDNA1 32 +#define MMQ_Y_Q6_K_RDNA1 64 +#define NWARPS_Q6_K_RDNA1 8 +#if defined(CUDA_USE_TENSOR_CORES) +#define MMQ_X_Q6_K_AMPERE 4 +#define MMQ_Y_Q6_K_AMPERE 32 +#define NWARPS_Q6_K_AMPERE 4 
+#else +#define MMQ_X_Q6_K_AMPERE 64 +#define MMQ_Y_Q6_K_AMPERE 64 +#define NWARPS_Q6_K_AMPERE 4 +#endif +#define MMQ_X_Q6_K_PASCAL 64 +#define MMQ_Y_Q6_K_PASCAL 64 +#define NWARPS_Q6_K_PASCAL 8 + + +// QK = number of values after dequantization +// QR = QK / number of values before dequantization +// QI = number of 32 bit integers before dequantization + +#define QK4_0 32 +#define QR4_0 2 +#define QI4_0 (QK4_0 / (4 * QR4_0)) +typedef struct { + half d; // delta + uint8_t qs[QK4_0 / 2]; // nibbles / quants +} block_q4_0; +static_assert(sizeof(block_q4_0) == sizeof(ggml_fp16_t) + QK4_0 / 2, "wrong q4_0 block size/padding"); + +#define QK4_1 32 +#define QR4_1 2 +#define QI4_1 (QK4_1 / (4 * QR4_1)) +typedef struct { + half2 dm; // dm.x = delta, dm.y = min + uint8_t qs[QK4_1 / 2]; // nibbles / quants +} block_q4_1; +static_assert(sizeof(block_q4_1) == sizeof(ggml_fp16_t) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding"); + +#define QK5_0 32 +#define QR5_0 2 +#define QI5_0 (QK5_0 / (4 * QR5_0)) +typedef struct { + half d; // delta + uint8_t qh[4]; // 5-th bit of quants + uint8_t qs[QK5_0 / 2]; // nibbles / quants +} block_q5_0; +static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding"); + +#define QK5_1 32 +#define QR5_1 2 +#define QI5_1 (QK5_1 / (4 * QR5_1)) +typedef struct { + half2 dm; // dm.x = delta, dm.y = min + uint8_t qh[4]; // 5-th bit of quants + uint8_t qs[QK5_1 / 2]; // nibbles / quants +} block_q5_1; +static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding"); + +#define QK8_0 32 +#define QR8_0 1 +#define QI8_0 (QK8_0 / (4 * QR8_0)) +typedef struct { + half d; // delta + int8_t qs[QK8_0]; // quants +} block_q8_0; +static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 block size/padding"); + +#define QK8_1 32 +#define QR8_1 1 +#define QI8_1 (QK8_1 / (4 * QR8_1)) +typedef struct { + half2 ds; // ds.x = delta, ds.y = sum + int8_t qs[QK8_0]; // quants +} block_q8_1; +static_assert(sizeof(block_q8_1) == 2*sizeof(ggml_fp16_t) + QK8_0, "wrong q8_1 block size/padding"); + +typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs); + +#define QR2_K 4 +#define QI2_K (QK_K / (4*QR2_K)) +typedef struct { + uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits + uint8_t qs[QK_K/4]; // quants + half2 dm; // super-block scale for quantized scales/mins +} block_q2_K; +static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_fp16_t) + QK_K/16 + QK_K/4, "wrong q2_K block size/padding"); + +#define QR3_K 4 +#define QI3_K (QK_K / (4*QR3_K)) +typedef struct { + uint8_t hmask[QK_K/8]; // quants - high bit + uint8_t qs[QK_K/4]; // quants - low 2 bits +#ifdef GGML_QKK_64 + uint8_t scales[2]; // scales, quantized with 8 bits +#else + uint8_t scales[K_SCALE_SIZE]; // scales, quantized with 6 bits +#endif + half d; // super-block scale +} block_q3_K; +//static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + K_SCALE_SIZE, "wrong q3_K block size/padding"); + +#define QR4_K 2 +#define QI4_K (QK_K / (4*QR4_K)) +#ifdef GGML_QKK_64 +typedef struct { + half dm[2]; // super-block scales/mins + uint8_t scales[2]; // 4-bit block scales/mins + uint8_t qs[QK_K/2]; // 4--bit quants +} block_q4_K; +static_assert(sizeof(block_q4_K) == sizeof(half2) + QK_K/2 + 2, "wrong q4_K block size/padding"); +#else +typedef struct { + half2 dm; // super-block scale for quantized scales/mins + 
uint8_t scales[3*QK_K/64]; // scales, quantized with 6 bits + uint8_t qs[QK_K/2]; // 4--bit quants +} block_q4_K; +static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2, "wrong q4_K block size/padding"); +#endif + +#define QR5_K 2 +#define QI5_K (QK_K / (4*QR5_K)) +#ifdef GGML_QKK_64 +typedef struct { + half d; // super-block scale + int8_t scales[QK_K/16]; // block scales + uint8_t qh[QK_K/8]; // quants, high bit + uint8_t qs[QK_K/2]; // quants, low 4 bits +} block_q5_K; +static_assert(sizeof(block_q5_K) == sizeof(ggml_fp16_t) + QK_K/2 + QK_K/8 + QK_K/16, "wrong q5_K block size/padding"); +#else +typedef struct { + half2 dm; // super-block scale for quantized scales/mins + uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits + uint8_t qh[QK_K/8]; // quants, high bit + uint8_t qs[QK_K/2]; // quants, low 4 bits +} block_q5_K; +static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_fp16_t) + K_SCALE_SIZE + QK_K/2 + QK_K/8, "wrong q5_K block size/padding"); +#endif + +#define QR6_K 2 +#define QI6_K (QK_K / (4*QR6_K)) +typedef struct { + uint8_t ql[QK_K/2]; // quants, lower 4 bits + uint8_t qh[QK_K/4]; // quants, upper 2 bits + int8_t scales[QK_K/16]; // scales + half d; // delta +} block_q6_K; +static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_K block size/padding"); + +// In llama.cpp this is only used for intermediate quantization and dot products +typedef struct { + float d; // delta + int8_t qs[QK_K]; // quants + int16_t bsums[QK_K/16]; // sum of quants in groups of 16 +} block_q8_K; +static_assert(sizeof(block_q8_K) == sizeof(float) + QK_K + QK_K/16*sizeof(int16_t), "wrong q8_K block size/padding"); + + +// VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called +// MMVQ = mul_mat_vec_q, MMQ = mul_mat_q + +#define VDR_Q4_0_Q8_1_MMVQ 2 +#define VDR_Q4_0_Q8_1_MMQ 4 + +template static __device__ __forceinline__ float vec_dot_q4_0_q8_1_impl( + const int * v, const int * u, const float & d4, const half2 & ds8) { + + int sumi = 0; + +#pragma unroll + for (int i = 0; i < vdr; ++i) { + const int vi0 = (v[i] >> 0) & 0x0F0F0F0F; + const int vi1 = (v[i] >> 4) & 0x0F0F0F0F; + + // SIMD dot product of quantized values + sumi = ggml_cuda_dp4a(vi0, u[2*i+0], sumi); + sumi = ggml_cuda_dp4a(vi1, u[2*i+1], sumi); + } + + const float2 ds8f = __half22float2(ds8); + + // second part effectively subtracts 8 from each quant value + return d4 * (sumi * ds8f.x - (8*vdr/QI4_0) * ds8f.y); +} + +#define VDR_Q4_1_Q8_1_MMVQ 2 +#define VDR_Q4_1_Q8_1_MMQ 4 + +template static __device__ __forceinline__ float vec_dot_q4_1_q8_1_impl( + const int * v, const int * u, const half2 & dm4, const half2 & ds8) { + int sumi = 0; + +#pragma unroll + for (int i = 0; i < vdr; ++i) { + const int vi0 = (v[i] >> 0) & 0x0F0F0F0F; + const int vi1 = (v[i] >> 4) & 0x0F0F0F0F; + + // SIMD dot product of quantized values + sumi = ggml_cuda_dp4a(vi0, u[2*i+0], sumi); + sumi = ggml_cuda_dp4a(vi1, u[2*i+1], sumi); + } + +#ifdef GGML_CUDA_F16 + const float2 tmp = __half22float2(__hmul2(dm4, ds8)); + const float d4d8 = tmp.x; + const float m4s8 = tmp.y; +#else + const float2 dm4f = __half22float2(dm4); + const float2 ds8f = __half22float2(ds8); + const float d4d8 = dm4f.x * ds8f.x; + const float m4s8 = dm4f.y * ds8f.y; +#endif // GGML_CUDA_F16 + + // scale second part of sum by QI8_1/(vdr * QR4_1) to compensate for multiple threads adding it + return sumi * d4d8 + m4s8 / (QI8_1 / (vdr * QR4_1)); +} + +#define VDR_Q5_0_Q8_1_MMVQ 2 
+#define VDR_Q5_0_Q8_1_MMQ 4 + +template static __device__ __forceinline__ float vec_dot_q5_0_q8_1_impl( + const int * vl, const int * vh, const int * u, const float & d5, const half2 & ds8) { + + int sumi = 0; + +#pragma unroll + for (int i = 0; i < vdr; ++i) { + int vi0 = (vl[i] >> 0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh as 5th bits + vi0 |= (vh[i] << 4) & 0x00000010; // 0 -> 4 + vi0 |= (vh[i] << 11) & 0x00001000; // 1 -> 12 + vi0 |= (vh[i] << 18) & 0x00100000; // 2 -> 20 + vi0 |= (vh[i] << 25) & 0x10000000; // 3 -> 28 + sumi = ggml_cuda_dp4a(vi0, u[2*i+0], sumi); // SIMD dot product of quantized values + + int vi1 = (vl[i] >> 4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh as 5th bits + vi1 |= (vh[i] >> 12) & 0x00000010; // 16 -> 4 + vi1 |= (vh[i] >> 5) & 0x00001000; // 17 -> 12 + vi1 |= (vh[i] << 2) & 0x00100000; // 18 -> 20 + vi1 |= (vh[i] << 9) & 0x10000000; // 19 -> 28 + sumi = ggml_cuda_dp4a(vi1, u[2*i+1], sumi); // SIMD dot product of quantized values + } + + const float2 ds8f = __half22float2(ds8); + + // second part effectively subtracts 16 from each quant value + return d5 * (sumi * ds8f.x - (16*vdr/QI5_0) * ds8f.y); +} + +#define VDR_Q5_1_Q8_1_MMVQ 2 +#define VDR_Q5_1_Q8_1_MMQ 4 + +template static __device__ __forceinline__ float vec_dot_q5_1_q8_1_impl( + const int * vl, const int * vh, const int * u, const half2 & dm5, const half2 & ds8) { + + int sumi = 0; + +#pragma unroll + for (int i = 0; i < vdr; ++i) { + int vi0 = (vl[i] >> 0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh as 5th bits + vi0 |= (vh[i] << 4) & 0x00000010; // 0 -> 4 + vi0 |= (vh[i] << 11) & 0x00001000; // 1 -> 12 + vi0 |= (vh[i] << 18) & 0x00100000; // 2 -> 20 + vi0 |= (vh[i] << 25) & 0x10000000; // 3 -> 28 + sumi = ggml_cuda_dp4a(vi0, u[2*i+0], sumi); // SIMD dot product of quantized values + + int vi1 = (vl[i] >> 4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh as 5th bits + vi1 |= (vh[i] >> 12) & 0x00000010; // 16 -> 4 + vi1 |= (vh[i] >> 5) & 0x00001000; // 17 -> 12 + vi1 |= (vh[i] << 2) & 0x00100000; // 18 -> 20 + vi1 |= (vh[i] << 9) & 0x10000000; // 19 -> 28 + sumi = ggml_cuda_dp4a(vi1, u[2*i+1], sumi); // SIMD dot product of quantized values + } + +#ifdef GGML_CUDA_F16 + const float2 tmp = __half22float2(__hmul2(dm5, ds8)); + const float d5d8 = tmp.x; + const float m5s8 = tmp.y; +#else + const float2 dm5f = __half22float2(dm5); + const float2 ds8f = __half22float2(ds8); + const float d5d8 = dm5f.x * ds8f.x; + const float m5s8 = dm5f.y * ds8f.y; +#endif // GGML_CUDA_F16 + + // scale second part of sum by QI5_1 / vdr to compensate for multiple threads adding it + return sumi*d5d8 + m5s8 / (QI5_1 / vdr); +} + +#define VDR_Q8_0_Q8_1_MMVQ 2 +#define VDR_Q8_0_Q8_1_MMQ 8 + +template static __device__ __forceinline__ float vec_dot_q8_0_q8_1_impl( + const int * v, const int * u, const float & d8_0, const float & d8_1) { + + int sumi = 0; + +#pragma unroll + for (int i = 0; i < vdr; ++i) { + // SIMD dot product of quantized values + sumi = ggml_cuda_dp4a(v[i], u[i], sumi); + } + + return d8_0*d8_1 * sumi; +} + +template static __device__ __forceinline__ float vec_dot_q8_1_q8_1_impl( + const int * v, const int * u, const half2 & dm8, const half2 & ds8) { + + int sumi = 0; + +#pragma unroll + for (int i = 0; i < vdr; ++i) { + // SIMD dot product of quantized values + sumi = ggml_cuda_dp4a(v[i], u[i], sumi); + } + +#ifdef GGML_CUDA_F16 + const float2 tmp = __half22float2(__hmul2(dm8, ds8)); + const float d8d8 = tmp.x; + const float m8s8 = tmp.y; +#else + const float2 dm8f = __half22float2(dm8); + 
const float2 ds8f = __half22float2(ds8); + const float d8d8 = dm8f.x * ds8f.x; + const float m8s8 = dm8f.y * ds8f.y; +#endif // GGML_CUDA_F16 + + // scale second part of sum by QI8_1/ vdr to compensate for multiple threads adding it + return sumi*d8d8 + m8s8 / (QI8_1 / vdr); +} + +#define VDR_Q2_K_Q8_1_MMVQ 1 +#define VDR_Q2_K_Q8_1_MMQ 2 + +// contiguous v/x values +static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmvq( + const int & v, const int * __restrict__ u, const uint8_t * __restrict__ scales, + const half2 & dm2, const float * __restrict__ d8) { + + float sumf_d = 0.0f; + float sumf_m = 0.0f; + +#pragma unroll + for (int i = 0; i < QR2_K; ++i) { + const int sc = scales[2*i]; + + const int vi = (v >> (2*i)) & 0x03030303; + + sumf_d += d8[i] * (ggml_cuda_dp4a(vi, u[i], 0) * (sc & 0xF)); // SIMD dot product + + // fill int with 4x m + int m = sc >> 4; + m |= m << 8; + m |= m << 16; + sumf_m += d8[i] * ggml_cuda_dp4a(m, u[i], 0); // multiply constant q2_K part with sum of q8_1 values + } + + const float2 dm2f = __half22float2(dm2); + + return dm2f.x*sumf_d - dm2f.y*sumf_m; +} + +// contiguous u/y values +static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmq( + const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ scales, + const half2 & dm2, const float & d8) { + + int sumi_d = 0; + int sumi_m = 0; + +#pragma unroll + for (int i0 = 0; i0 < QI8_1; i0 += QI8_1/2) { + int sumi_d_sc = 0; + + const int sc = scales[i0 / (QI8_1/2)]; + + // fill int with 4x m + int m = sc >> 4; + m |= m << 8; + m |= m << 16; + +#pragma unroll + for (int i = i0; i < i0 + QI8_1/2; ++i) { + sumi_d_sc = ggml_cuda_dp4a(v[i], u[i], sumi_d_sc); // SIMD dot product + sumi_m = ggml_cuda_dp4a(m, u[i], sumi_m); // multiply sum of q8_1 values with m + } + + sumi_d += sumi_d_sc * (sc & 0xF); + } + + const float2 dm2f = __half22float2(dm2); + + return d8 * (dm2f.x*sumi_d - dm2f.y*sumi_m); +} + +#define VDR_Q3_K_Q8_1_MMVQ 1 +#define VDR_Q3_K_Q8_1_MMQ 2 + +// contiguous v/x values +static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmvq( + const int & vl, const int & vh, const int * __restrict__ u, const uint8_t * __restrict__ scales, + const int & scale_offset, const float & d3, const float * __restrict__ d8) { + + float sumf = 0.0f; + +#pragma unroll + for (int i = 0; i < QR3_K; ++i) { + const int isc = scale_offset + 2*i; + + const int isc_low = isc % (QK_K/32); + const int sc_shift_low = 4 * (isc / (QK_K/32)); + const int sc_low = (scales[isc_low] >> sc_shift_low) & 0xF; + + const int isc_high = isc % (QK_K/64); + const int sc_shift_high = 2 * (isc / (QK_K/64)); + const int sc_high = ((scales[(QK_K/32) + isc_high] >> sc_shift_high) & 3) << 4; + + const int sc = (sc_low | sc_high) - 32; + + const int vil = (vl >> (2*i)) & 0x03030303; + + const int vih = ((vh >> i) << 2) & 0x04040404; + + const int vi = __vsubss4(vil, vih); + + sumf += d8[i] * (ggml_cuda_dp4a(vi, u[i], 0) * sc); // SIMD dot product + } + + return d3 * sumf; +} + +// contiguous u/y values +static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmq( + const int * __restrict__ v, const int * __restrict__ u, const int8_t * __restrict__ scales, + const float & d3, const float & d8) { + + int sumi = 0; + +#pragma unroll + for (int i0 = 0; i0 < QR3_K*VDR_Q3_K_Q8_1_MMQ; i0 += QI8_1/2) { + int sumi_sc = 0; + + for (int i = i0; i < i0 + QI8_1/2; ++i) { + sumi_sc = ggml_cuda_dp4a(v[i], u[i], sumi_sc); // SIMD dot product + } + + sumi += sumi_sc * scales[i0 / (QI8_1/2)]; + } + + return d3*d8 
* sumi; +} + +#define VDR_Q4_K_Q8_1_MMVQ 2 +#define VDR_Q4_K_Q8_1_MMQ 8 + +// contiguous v/x values +static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_vmmq( + const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc, + const uint8_t * __restrict__ m, const half2 & dm4, const float * __restrict__ d8) { + + float sumf_d = 0.0f; + float sumf_m = 0.0f; + +#pragma unroll + for (int i = 0; i < QR4_K; ++i) { + const int v0i = (v[0] >> (4*i)) & 0x0F0F0F0F; + const int v1i = (v[1] >> (4*i)) & 0x0F0F0F0F; + + const int dot1 = ggml_cuda_dp4a(v1i, u[2*i+1], ggml_cuda_dp4a(v0i, u[2*i+0], 0)); // SIMD dot product + const int dot2 = ggml_cuda_dp4a(0x01010101, u[2*i+1], ggml_cuda_dp4a(0x01010101, u[2*i+0], 0)); // sum of u + + sumf_d += d8[i] * (dot1 * sc[i]); + sumf_m += d8[i] * (dot2 * m[i]); // multiply constant part of q4_K with sum of q8_1 values + } + + const float2 dm4f = __half22float2(dm4); + + return dm4f.x*sumf_d - dm4f.y*sumf_m; +} + +// contiguous u/y values +static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_mmq( + const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc, + const uint8_t * __restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) { + + float sumf_d = 0.0f; + float sumf_m = 0.0f; + +#pragma unroll + for (int i = 0; i < QR4_K*VDR_Q4_K_Q8_1_MMQ/QI8_1; ++i) { + int sumi_d = 0; + +#pragma unroll + for (int j = 0; j < QI8_1; ++j) { + sumi_d = ggml_cuda_dp4a((v[j] >> (4*i)) & 0x0F0F0F0F, u[i*QI8_1 + j], sumi_d); // SIMD dot product + } + + const float2 ds8f = __half22float2(ds8[i]); + + sumf_d += ds8f.x * (sc[i] * sumi_d); + sumf_m += ds8f.y * m[i]; // sum of q8_1 block * q4_K min val + } + + const float2 dm4f = __half22float2(dm4); + + return dm4f.x*sumf_d - dm4f.y*sumf_m; +} + +#define VDR_Q5_K_Q8_1_MMVQ 2 +#define VDR_Q5_K_Q8_1_MMQ 8 + +// contiguous v/x values +static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_vmmq( + const int * __restrict__ vl, const int * __restrict__ vh, const int * __restrict__ u, const uint8_t * __restrict__ sc, + const uint8_t * __restrict__ m, const half2 & dm5, const float * __restrict__ d8) { + + float sumf_d = 0.0f; + float sumf_m = 0.0f; + +#pragma unroll + for (int i = 0; i < QR5_K; ++i) { + const int vl0i = (vl[0] >> (4*i)) & 0x0F0F0F0F; + const int vl1i = (vl[1] >> (4*i)) & 0x0F0F0F0F; + + const int vh0i = ((vh[0] >> i) << 4) & 0x10101010; + const int vh1i = ((vh[1] >> i) << 4) & 0x10101010; + + const int v0i = vl0i | vh0i; + const int v1i = vl1i | vh1i; + + const int dot1 = ggml_cuda_dp4a(v0i, u[2*i+0], ggml_cuda_dp4a(v1i, u[2*i+1], 0)); // SIMD dot product + const int dot2 = ggml_cuda_dp4a(0x01010101, u[2*i+0], ggml_cuda_dp4a(0x01010101, u[2*i+1], 0)); // sum of u + + sumf_d += d8[i] * (dot1 * sc[i]); + sumf_m += d8[i] * (dot2 * m[i]); + + } + + const float2 dm5f = __half22float2(dm5); + + return dm5f.x*sumf_d - dm5f.y*sumf_m; +} + +// contiguous u/y values +static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_mmq( + const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc, + const uint8_t * __restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) { + + float sumf_d = 0.0f; + float sumf_m = 0.0f; + +#pragma unroll + for (int i = 0; i < QR5_K*VDR_Q5_K_Q8_1_MMQ/QI8_1; ++i) { + int sumi_d = 0; + +#pragma unroll + for (int j = 0; j < QI8_1; ++j) { + sumi_d = ggml_cuda_dp4a(v[i*QI8_1 + j], u[i*QI8_1 + j], sumi_d); // SIMD dot product + } + + const float2 ds8f = __half22float2(ds8[i]); + + 
sumf_d += ds8f.x * (sc[i] * sumi_d); + sumf_m += ds8f.y * m[i]; // sum of q8_1 block * q4_K min val + } + + const float2 dm4f = __half22float2(dm4); + + return dm4f.x*sumf_d - dm4f.y*sumf_m; +} + +#define VDR_Q6_K_Q8_1_MMVQ 1 +#define VDR_Q6_K_Q8_1_MMQ 8 + +// contiguous v/x values +static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmvq( + const int & vl, const int & vh, const int * __restrict__ u, const int8_t * __restrict__ scales, + const float & d, const float * __restrict__ d8) { + + float sumf = 0.0f; + +#pragma unroll + for (int i = 0; i < QR6_K; ++i) { + const int sc = scales[4*i]; + + const int vil = (vl >> (4*i)) & 0x0F0F0F0F; + + const int vih = ((vh >> (4*i)) << 4) & 0x30303030; + + const int vi = __vsubss4((vil | vih), 0x20202020); // vi = (vil | vih) - 32 + + sumf += d8[i] * (ggml_cuda_dp4a(vi, u[i], 0) * sc); // SIMD dot product + } + + return d*sumf; +} + +// contiguous u/y values +static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq( + const int * __restrict__ v, const int * __restrict__ u, const int8_t * __restrict__ sc, + const float & d6, const float * __restrict__ d8) { + + float sumf_d = 0.0f; + +#pragma unroll + for (int i0 = 0; i0 < VDR_Q6_K_Q8_1_MMQ; i0 += 4) { + int2 sumi_d = {0, 0}; // 2 q6_K scales per q8_1 scale + +#pragma unroll + for (int i = i0; i < i0 + 2; ++i) { + sumi_d.x = ggml_cuda_dp4a(v[2*i+0], u[2*i+0], sumi_d.x); // SIMD dot product + sumi_d.x = ggml_cuda_dp4a(v[2*i+1], u[2*i+1], sumi_d.x); // SIMD dot product + + sumi_d.y = ggml_cuda_dp4a(v[2*i+4], u[2*i+4], sumi_d.y); // SIMD dot product + sumi_d.y = ggml_cuda_dp4a(v[2*i+5], u[2*i+5], sumi_d.y); // SIMD dot product + } + + sumf_d += d8[i0/4] * (sc[i0/2+0]*sumi_d.x + sc[i0/2+1]*sumi_d.y); + } + + return d6 * sumf_d; +} + +static __device__ __forceinline__ float vec_dot_q4_0_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + + const block_q4_0 * bq4_0 = (const block_q4_0 *) vbq; + + int v[VDR_Q4_0_Q8_1_MMVQ]; + int u[2*VDR_Q4_0_Q8_1_MMVQ]; + +#pragma unroll + for (int i = 0; i < VDR_Q4_0_Q8_1_MMVQ; ++i) { + v[i] = get_int_from_uint8(bq4_0->qs, iqs + i); + u[2*i+0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i); + u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI4_0); + } + + return vec_dot_q4_0_q8_1_impl(v, u, bq4_0->d, bq8_1->ds); +} + + +static __device__ __forceinline__ float vec_dot_q4_1_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + + const block_q4_1 * bq4_1 = (const block_q4_1 *) vbq; + + int v[VDR_Q4_1_Q8_1_MMVQ]; + int u[2*VDR_Q4_1_Q8_1_MMVQ]; + +#pragma unroll + for (int i = 0; i < VDR_Q4_1_Q8_1_MMVQ; ++i) { + v[i] = get_int_from_uint8_aligned(bq4_1->qs, iqs + i); + u[2*i+0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i); + u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI4_1); + } + + return vec_dot_q4_1_q8_1_impl(v, u, bq4_1->dm, bq8_1->ds); +} + +static __device__ __forceinline__ float vec_dot_q5_0_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + + const block_q5_0 * bq5_0 = (const block_q5_0 *) vbq; + + int vl[VDR_Q5_0_Q8_1_MMVQ]; + int vh[VDR_Q5_0_Q8_1_MMVQ]; + int u[2*VDR_Q5_0_Q8_1_MMVQ]; + +#pragma unroll + for (int i = 0; i < VDR_Q5_0_Q8_1_MMVQ; ++i) { + vl[i] = get_int_from_uint8(bq5_0->qs, iqs + i); + vh[i] = get_int_from_uint8(bq5_0->qh, 0) >> (4 * (iqs + i)); + u[2*i+0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i); + u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + 
QI5_0); + } + + return vec_dot_q5_0_q8_1_impl(vl, vh, u, bq5_0->d, bq8_1->ds); +} + +static __device__ __forceinline__ float vec_dot_q5_1_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + + const block_q5_1 * bq5_1 = (const block_q5_1 *) vbq; + + int vl[VDR_Q5_1_Q8_1_MMVQ]; + int vh[VDR_Q5_1_Q8_1_MMVQ]; + int u[2*VDR_Q5_1_Q8_1_MMVQ]; + +#pragma unroll + for (int i = 0; i < VDR_Q5_1_Q8_1_MMVQ; ++i) { + vl[i] = get_int_from_uint8_aligned(bq5_1->qs, iqs + i); + vh[i] = get_int_from_uint8_aligned(bq5_1->qh, 0) >> (4 * (iqs + i)); + u[2*i+0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i); + u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI5_1); + } + + return vec_dot_q5_1_q8_1_impl(vl, vh, u, bq5_1->dm, bq8_1->ds); +} + +static __device__ __forceinline__ float vec_dot_q8_0_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + + const block_q8_0 * bq8_0 = (const block_q8_0 *) vbq; + + int v[VDR_Q8_0_Q8_1_MMVQ]; + int u[VDR_Q8_0_Q8_1_MMVQ]; + +#pragma unroll + for (int i = 0; i < VDR_Q8_0_Q8_1_MMVQ; ++i) { + v[i] = get_int_from_int8(bq8_0->qs, iqs + i); + u[i] = get_int_from_int8_aligned(bq8_1->qs, iqs + i); + } + + return vec_dot_q8_0_q8_1_impl(v, u, bq8_0->d, __low2half(bq8_1->ds)); +} + +static __device__ __forceinline__ float vec_dot_q2_K_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + + const block_q2_K * bq2_K = (const block_q2_K *) vbq; + + const int bq8_offset = QR2_K * (iqs / QI8_1); + const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1/2); + + const uint8_t * scales = bq2_K->scales + scale_offset; + + const int v = get_int_from_uint8_aligned(bq2_K->qs, iqs); + int u[QR2_K]; + float d8[QR2_K]; + +#pragma unroll + for (int i = 0; i < QR2_K; ++ i) { + u[i] = get_int_from_int8_aligned(bq8_1[bq8_offset + i].qs, iqs % QI8_1); + d8[i] = __low2float(bq8_1[bq8_offset + i].ds); + } + + return vec_dot_q2_K_q8_1_impl_mmvq(v, u, scales, bq2_K->dm, d8); +} + +static __device__ __forceinline__ float vec_dot_q3_K_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + + const block_q3_K * bq3_K = (const block_q3_K *) vbq; + + const int bq8_offset = QR3_K * (iqs / (QI3_K/2)); + const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1/2); + + const float d = bq3_K->d; + + const int vl = get_int_from_uint8(bq3_K->qs, iqs); + + // invert the mask with ~ so that a 0/1 results in 4/0 being subtracted + const int vh = ~get_int_from_uint8(bq3_K->hmask, iqs % (QI3_K/2)) >> bq8_offset; + + int u[QR3_K]; + float d8[QR3_K]; + +#pragma unroll + for (int i = 0; i < QR3_K; ++i) { + u[i] = get_int_from_int8_aligned(bq8_1[bq8_offset + i].qs, iqs % QI8_1); + d8[i] = __low2float(bq8_1[bq8_offset + i].ds); + } + + return vec_dot_q3_K_q8_1_impl_mmvq(vl, vh, u, bq3_K->scales, scale_offset, d, d8); +} + +static __device__ __forceinline__ float vec_dot_q4_K_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + +#ifndef GGML_QKK_64 + const block_q4_K * bq4_K = (const block_q4_K *) vbq; + + int v[2]; + int u[2*QR4_K]; + float d8[QR4_K]; + + // iqs is in 0,2..30. 
bq8_offset = iqs/4 -> bq8_offset = 0, 2, 4, 6 + const int bq8_offset = QR4_K * ((iqs/2) / (QI8_1/2)); + + // iqs = 0....3 -> bq8_offset = 0, want q4_offset = 0, 4, 8, 12 + // iqs = 4....7 -> bq8_offset = 2, want q4_offset = 32, 36, 40, 44 + // iqs = 8...11 -> bq8_offset = 4, want q4_offset = 64, 68, 72, 76 + // iqs = 12..15 -> bq8_offset = 6, want q4_offset = 96, 100, 104, 108 + + const int * q4 = (const int *)(bq4_K->qs + 16 * bq8_offset + 4 * ((iqs/2)%4)); + v[0] = q4[0]; + v[1] = q4[4]; + + const uint16_t * scales = (const uint16_t *)bq4_K->scales; + uint16_t aux[2]; + const int j = bq8_offset/2; + if (j < 2) { + aux[0] = scales[j+0] & 0x3f3f; + aux[1] = scales[j+2] & 0x3f3f; + } else { + aux[0] = ((scales[j+2] >> 0) & 0x0f0f) | ((scales[j-2] & 0xc0c0) >> 2); + aux[1] = ((scales[j+2] >> 4) & 0x0f0f) | ((scales[j-0] & 0xc0c0) >> 2); + } + const uint8_t * sc = (const uint8_t *)aux; + const uint8_t * m = sc + 2; + + for (int i = 0; i < QR4_K; ++i) { + const block_q8_1 * bq8i = bq8_1 + bq8_offset + i; + d8[i] = __low2float(bq8i->ds); + + const int * q8 = (const int *)bq8i->qs + ((iqs/2)%4); + u[2*i+0] = q8[0]; + u[2*i+1] = q8[4]; + } + + return vec_dot_q4_K_q8_1_impl_vmmq(v, u, sc, m, bq4_K->dm, d8); + +#else + + const block_q4_K * bq4_K = (const block_q4_K *) vbq; + + float sumf_d = 0.0f; + float sumf_m = 0.0f; + + uint16_t aux16[2]; + const uint8_t * s = (const uint8_t *)aux16; + + const uint16_t * a = (const uint16_t *)bq4_K->scales; + aux16[0] = a[0] & 0x0f0f; + aux16[1] = (a[0] >> 4) & 0x0f0f; + + const float dall = bq4_K->dm[0]; + const float dmin = bq4_K->dm[1]; + + const float d8_1 = __low2float(bq8_1[0].ds); + const float d8_2 = __low2float(bq8_1[1].ds); + + const int ui1 = *((const int *)bq8_1[0].qs + (iqs/2)); + const int ui2 = *((const int *)bq8_1[0].qs + (iqs/2) + 4); + const int ui3 = *((const int *)bq8_1[1].qs + (iqs/2)); + const int ui4 = *((const int *)bq8_1[1].qs + (iqs/2) + 4); + + const int * q4 = (const int *)bq4_K->qs + (iqs/2); + const int v1 = q4[0]; + const int v2 = q4[4]; + + const int dot1 = ggml_cuda_dp4a(ui2, v2 & 0x0f0f0f0f, ggml_cuda_dp4a(ui1, v1 & 0x0f0f0f0f, 0)); + const int dot2 = ggml_cuda_dp4a(ui4, (v2 >> 4) & 0x0f0f0f0f, ggml_cuda_dp4a(ui3, (v1 >> 4) & 0x0f0f0f0f, 0)); + const int dot3 = ggml_cuda_dp4a(0x01010101, ui2, ggml_cuda_dp4a(0x01010101, ui1, 0)); + const int dot4 = ggml_cuda_dp4a(0x01010101, ui4, ggml_cuda_dp4a(0x01010101, ui3, 0)); + + sumf_d += d8_1 * (dot1 * s[0]) + d8_2 * (dot2 * s[1]); + sumf_m += d8_1 * (dot3 * s[2]) + d8_2 * (dot4 * s[3]); + + return dall * sumf_d - dmin * sumf_m; +#endif +} + +static __device__ __forceinline__ float vec_dot_q5_K_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + +#ifndef GGML_QKK_64 + const block_q5_K * bq5_K = (const block_q5_K *) vbq; + + int vl[2]; + int vh[2]; + int u[2*QR5_K]; + float d8[QR5_K]; + + const int bq8_offset = QR5_K * ((iqs/2) / (QI8_1/2)); + const int * ql = (const int *)(bq5_K->qs + 16 * bq8_offset + 4 * ((iqs/2)%4)); + const int * qh = (const int *)(bq5_K->qh + 4 * ((iqs/2)%4)); + + vl[0] = ql[0]; + vl[1] = ql[4]; + + vh[0] = qh[0] >> bq8_offset; + vh[1] = qh[4] >> bq8_offset; + + const uint16_t * scales = (const uint16_t *)bq5_K->scales; + uint16_t aux[2]; + const int j = bq8_offset/2; + if (j < 2) { + aux[0] = scales[j+0] & 0x3f3f; + aux[1] = scales[j+2] & 0x3f3f; + } else { + aux[0] = ((scales[j+2] >> 0) & 0x0f0f) | ((scales[j-2] & 0xc0c0) >> 2); + aux[1] = ((scales[j+2] >> 4) & 0x0f0f) | ((scales[j-0] & 0xc0c0) >> 2); + } + const 
uint8_t * sc = (const uint8_t *)aux; + const uint8_t * m = sc + 2; + +#pragma unroll + for (int i = 0; i < QR5_K; ++i) { + const block_q8_1 * bq8i = bq8_1 + bq8_offset + i; + d8[i] = __low2float(bq8i->ds); + + const int * q8 = (const int *)bq8i->qs + ((iqs/2)%4); + u[2*i+0] = q8[0]; + u[2*i+1] = q8[4]; + } + + return vec_dot_q5_K_q8_1_impl_vmmq(vl, vh, u, sc, m, bq5_K->dm, d8); + +#else + + const block_q5_K * bq5_K = (const block_q5_K *) vbq; + + const int8_t * s = bq5_K->scales; + + const float d = bq5_K->d; + + const float d8_1 = __low2half(bq8_1[0].ds); + const float d8_2 = __low2half(bq8_1[1].ds); + + const int ui1 = *((const int *)bq8_1[0].qs + (iqs/2)); + const int ui2 = *((const int *)bq8_1[0].qs + (iqs/2) + 4); + const int ui3 = *((const int *)bq8_1[1].qs + (iqs/2)); + const int ui4 = *((const int *)bq8_1[1].qs + (iqs/2) + 4); + + const int * ql = (const int *)bq5_K->qs + (iqs/2); + const int vl1 = ql[0]; + const int vl2 = ql[4]; + + const int step = 4 * (iqs/2); // 0, 4, 8, 12 + const int im = step/8; // = 0 for iqs = 0, 2, = 1 for iqs = 4, 6 + const int in = step%8; // 0, 4, 0, 4 + const int vh = (*((const int *)(bq5_K->qh + in))) >> im; + + const int v1 = (((vh << 4) & 0x10101010) ^ 0x10101010) | ((vl1 >> 0) & 0x0f0f0f0f); + const int v2 = (((vh << 2) & 0x10101010) ^ 0x10101010) | ((vl2 >> 0) & 0x0f0f0f0f); + const int v3 = (((vh >> 0) & 0x10101010) ^ 0x10101010) | ((vl1 >> 4) & 0x0f0f0f0f); + const int v4 = (((vh >> 2) & 0x10101010) ^ 0x10101010) | ((vl2 >> 4) & 0x0f0f0f0f); + + const float sumf_d = d8_1 * (ggml_cuda_dp4a(ui1, v1, 0) * s[0] + ggml_cuda_dp4a(ui2, v2, 0) * s[1]) + + d8_2 * (ggml_cuda_dp4a(ui3, v3, 0) * s[2] + ggml_cuda_dp4a(ui4, v4, 0) * s[3]); + + return d * sumf_d; +#endif +} + +static __device__ __forceinline__ float vec_dot_q6_K_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + + const block_q6_K * bq6_K = (const block_q6_K *) vbq; + + const int bq8_offset = 2 * QR6_K * (iqs / (QI6_K/2)) + (iqs % (QI6_K/2)) / (QI6_K/4); + const int scale_offset = (QI6_K/4) * (iqs / (QI6_K/2)) + (iqs % (QI6_K/2)) / (QI6_K/8); + const int vh_shift = 2 * ((iqs % (QI6_K/2)) / (QI6_K/4)); + + const int vl = get_int_from_uint8(bq6_K->ql, iqs); + const int vh = get_int_from_uint8(bq6_K->qh, (QI6_K/4) * (iqs / (QI6_K/2)) + iqs % (QI6_K/4)) >> vh_shift; + + const int8_t * scales = bq6_K->scales + scale_offset; + + int u[QR6_K]; + float d8[QR6_K]; + +#pragma unroll + for (int i = 0; i < QR6_K; ++i) { + u[i] = get_int_from_int8_aligned(bq8_1[bq8_offset + 2*i].qs, iqs % QI8_1); + d8[i] = __low2float(bq8_1[bq8_offset + 2*i].ds); + } + + return vec_dot_q6_K_q8_1_impl_mmvq(vl, vh, u, scales, bq6_K->d, d8); +} + +static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int kx, const int kx_padded) { + const int ix = blockDim.x*blockIdx.x + threadIdx.x; + if (ix >= kx_padded) { + return; + } + const int iy = blockDim.y*blockIdx.y + threadIdx.y; + const int i_padded = iy*kx_padded + ix; + block_q8_1 * y = (block_q8_1 *) vy; + + const int ib = i_padded / QK8_1; // block index + const int iqs = i_padded % QK8_1; // quant index + + const float xi = ix < kx ? x[iy*kx + ix] : 0.0f; + float amax = fabsf(xi); + float sum = xi; + + amax = warp_reduce_max(amax); + sum = warp_reduce_sum(sum); + + const float d = amax / 127; + const int8_t q = amax == 0.0f ? 
0 : roundf(xi / d); + + y[ib].qs[iqs] = q; + if (iqs > 0) { + return; + } + reinterpret_cast(y[ib].ds.x) = d; + reinterpret_cast(y[ib].ds.y) = sum; +} + +template +static __device__ __forceinline__ dst_t convert_from_half(half val) { + return val; +} + +template<> +__device__ __forceinline__ nv_bfloat16 convert_from_half(half val) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + return __float2bfloat16(__half2float(val)); +#else + return __half2float(val); +#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +} + +template<> +__device__ __forceinline__ float convert_from_half(half val) { + return __half2float(val); +} + +template +inline __device__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t * __restrict__ yy) { + + const auto i = 0; //we only need dequant one block in each call + const block_q2_K * x = (const block_q2_K *) vx; + + const auto tid = threadIdx.x; + const int n = tid/32; + const int l = tid - 32*n; + const int is = 8*n + l/16; + + const uint8_t q = x[i].qs[32*n + l]; + dst_t * y = yy + i*QK_K + 128*n; + + half dall = __low2half(x[i].dm); + half dmin = __high2half(x[i].dm); + y[l+ 0] = convert_from_half(__hsub(__hmul(dall, __int2half_rn((x[i].scales[is+0] & 0xF) * ((q >> 0) & 3))), __hmul(dmin, __int2half_rn(x[i].scales[is+0] >> 4)))); + y[l+32] = convert_from_half(__hsub(__hmul(dall, __int2half_rn((x[i].scales[is+2] & 0xF) * ((q >> 2) & 3))), __hmul(dmin, __int2half_rn(x[i].scales[is+2] >> 4)))); + y[l+64] = convert_from_half(__hsub(__hmul(dall, __int2half_rn((x[i].scales[is+4] & 0xF) * ((q >> 4) & 3))), __hmul(dmin, __int2half_rn(x[i].scales[is+4] >> 4)))); + y[l+96] = convert_from_half(__hsub(__hmul(dall, __int2half_rn((x[i].scales[is+6] & 0xF) * ((q >> 6) & 3))), __hmul(dmin, __int2half_rn(x[i].scales[is+6] >> 4)))); +} + +template +inline __device__ void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restrict__ yy) { + + const auto i = 0; + const block_q3_K * x = (const block_q3_K *) vx; + + const auto r = threadIdx.x/4; + const int tid = r/2; + const int is0 = r%2; + const int l0 = 16*is0 + 4*(threadIdx.x%4); + const int n = tid / 4; + const int j = tid - 4*n; + + uint8_t m = 1 << (4*n + j); + int is = 8*n + 2*j + is0; + int shift = 2*j; + + int8_t us = is < 4 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+8] >> 0) & 3) << 4) : + is < 8 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+4] >> 2) & 3) << 4) : + is < 12 ? (x[i].scales[is-8] >> 4) | (((x[i].scales[is+0] >> 4) & 3) << 4) : + (x[i].scales[is-8] >> 4) | (((x[i].scales[is-4] >> 6) & 3) << 4); + half d_all = x[i].d; + half dl = __hmul(d_all, __int2half_rn(us - 32)); + + dst_t * y = yy + i*QK_K + 128*n + 32*j; + const uint8_t * q = x[i].qs + 32*n; + const uint8_t * hm = x[i].hmask; + + for (int l = l0; l < l0+4; ++l) { + y[l] = convert_from_half(__hmul(dl, __int2half_rn((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 
0 : 4)))); + } +} + +static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) { + if (j < 4) { + d = q[j] & 63; m = q[j + 4] & 63; + } else { + d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4); + m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4); + } +} + +template +inline __device__ void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy) { + const block_q4_K * x = (const block_q4_K *) vx; + + const auto i = 0; + + // assume 32 threads + const auto tid = threadIdx.x; + const int il = tid/8; + const int ir = tid%8; + const int is = 2*il; + const int n = 4; + + dst_t * y = yy + i*QK_K + 64*il + n*ir; + + const half dall = __low2half(x[i].dm); + const half dmin = __high2half(x[i].dm); + + const uint8_t * q = x[i].qs + 32*il + n*ir; + + uint8_t sc, m; + get_scale_min_k4(is + 0, x[i].scales, sc, m); + const half d1 = __hmul(dall, __int2half_rn(sc)); + const half m1 = __hmul(dmin, __int2half_rn(m)); + get_scale_min_k4(is + 1, x[i].scales, sc, m); + const half d2 = __hmul(dall, __int2half_rn(sc)); + const half m2 = __hmul(dmin, __int2half_rn(m)); + for (int l = 0; l < n; ++l) { + y[l + 0] = convert_from_half(__hsub(__hmul(d1, __int2half_rn(q[l] & 0xF)), m1)); + y[l +32] = convert_from_half(__hsub(__hmul(d2, __int2half_rn(q[l] >> 4)), m2)); + } +} + +template +inline __device__ void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restrict__ yy) { + const block_q5_K * x = (const block_q5_K *) vx; + + const auto i = 0; + + // assume 64 threads - this is very slightly better than the one below + const auto tid = threadIdx.x; + const int il = tid/16; // il is in 0...3 + const int ir = tid%16; // ir is in 0...15 + const int is = 2*il; // is is in 0...6 + + dst_t * y = yy + i*QK_K + 64*il + 2*ir; + + const half dall = __low2half(x[i].dm); + const half dmin = __high2half(x[i].dm); + + const uint8_t * ql = x[i].qs + 32*il + 2*ir; + const uint8_t * qh = x[i].qh + 2*ir; + + uint8_t sc, m; + get_scale_min_k4(is + 0, x[i].scales, sc, m); + const half d1 = __hmul(dall, __int2half_rn(sc)); const half m1 = __hmul(dmin, __int2half_rn(m)); + get_scale_min_k4(is + 1, x[i].scales, sc, m); + const half d2 = __hmul(dall, __int2half_rn(sc)); const half m2 = __hmul(dmin, __int2half_rn(m)); + + uint8_t hm = 1 << (2*il); + y[ 0] = convert_from_half(__hsub(__hmul(d1, __int2half_rn((ql[0] & 0xF) + (qh[0] & hm ? 16 : 0))), m1)); + y[ 1] = convert_from_half(__hsub(__hmul(d1, __int2half_rn((ql[1] & 0xF) + (qh[1] & hm ? 16 : 0))), m1)); + hm <<= 1; + y[32] = convert_from_half(__hsub(__hmul(d2, __int2half_rn((ql[0] >> 4) + (qh[0] & hm ? 16 : 0))), m2)); + y[33] = convert_from_half(__hsub(__hmul(d2, __int2half_rn((ql[1] >> 4) + (qh[1] & hm ? 
16 : 0))), m2)); +} + +template +inline __device__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restrict__ yy) { + const block_q6_K * x = (const block_q6_K *) vx; + + const auto i = 0; + + // assume 64 threads - this is very slightly better than the one below + const auto tid = threadIdx.x; + const int ip = tid/32; // ip is 0 or 1 + const int il = tid - 32*ip; // 0...32 + const int is = 8*ip + il/16; + + dst_t * y = yy + i*QK_K + 128*ip + il; + + const half d = x[i].d; + + const uint8_t * ql = x[i].ql + 64*ip + il; + const uint8_t qh = x[i].qh[32*ip + il]; + const int8_t * sc = x[i].scales + is; + + y[ 0] = convert_from_half(__hmul(d, __int2half_rn(sc[0] * ((int8_t)((ql[ 0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32)))); + y[32] = convert_from_half(__hmul(d, __int2half_rn(sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32)))); + y[64] = convert_from_half(__hmul(d, __int2half_rn(sc[4] * ((int8_t)((ql[ 0] >> 4) | (((qh >> 4) & 3) << 4)) - 32)))); + y[96] = convert_from_half(__hmul(d, __int2half_rn(sc[6] * ((int8_t)((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32)))); +} \ No newline at end of file diff --git a/candle-kernels/src/moe/moe_gguf.cu b/candle-kernels/src/moe/moe_gguf.cu new file mode 100644 index 0000000000..92704e6aad --- /dev/null +++ b/candle-kernels/src/moe/moe_gguf.cu @@ -0,0 +1,216 @@ +/** + * @brief CUDA kernel for Mixture-of-Experts (MoE) GEMM using GGUF quantized weights. + * + * This kernel performs a dot-product between quantized input tokens and + * quantized expert weight matrices, accumulating into float outputs. + * It supports per-token top-k weighting and tiling along the K dimension + * for efficient vectorized execution. + * + * Adapted from: https://github.com/guoqingbao/attention.rs/tree/main/src/kernels/src/moe_gemm_gguf.cu + */ +#include "gguf.cuh" +#include +#include +#include +#include +#include +#include +constexpr int MATRIX_ROW_PADDING = 512; + +constexpr int pad(int size, int padding) { + if (padding == 0) return size; // avoid divide-by-zero + return ((size + padding - 1) / padding) * padding; +} + +// Optional helper if you want ceil division explicitly +constexpr int ceil_div(int a, int b) { + return (a + b - 1) / b; +} + +namespace vllm_rs { + +/* +* Template Parameters: + * @tparam T Type of output elements (float, half, etc.) 
+ * @tparam qk Quantization block size for weights (e.g., 32) + * @tparam qi Quantization block size for inputs (e.g., 32) + * @tparam block_q_t Type of quantized weight block (e.g., block_q8_0) + * @tparam vdr Vectorization factor (number of elements per lane) + * @tparam vec_dot_q_cuda Function for computing vectorized dot-product between quantized blocks + * + * Kernel Parameters: + * @param all_weights Pointer to all expert weight matrices, [num_experts, N, K] (quantized) + * @param all_inputs Pointer to all input tokens, [M_total, K] (quantized) + * @param sorted_token_ids Sorted token indices for batch processing + * @param expert_ids Expert ID for each token + * @param topk_weights Optional top-k MoE weight per token + * @param all_outputs Output buffer [M_total, N] (float) + * @param num_experts Number of experts + * @param topk Top-k experts selected per token + * @param size_m Number of tokens processed (M dimension) + * @param size_n Output feature dimension (N dimension) + * @param size_k Input feature dimension (K dimension) + * @param k_padded Padded K dimension for GGUF stride +*/ +template +__global__ void moe_gemm_gguf_kernel( + const void * __restrict__ all_weights, // [num_experts, N, K] (quantized) + const void * __restrict__ all_inputs, // [M_total, K] (quantized, M_total is total tokens) + const int32_t* __restrict__ sorted_token_ids,// [M] (M = num tokens processed) + const int32_t* __restrict__ expert_ids, // [M] + const float* __restrict__ topk_weights, // [M] + float * __restrict__ all_outputs, // [M_total, N] (float) + int num_experts, + int topk, + int size_m, int size_n, int size_k, // M, N, K are the logical dims + int k_padded // Padded K-dim for GGUF stride +) { + const int laneId = threadIdx.x; + const int wrapId = threadIdx.y; + const int nWraps = blockDim.y; + const int row = blockIdx.x * nWraps + wrapId; // This is the 'n' dimension (output row) + const int m_idx = blockIdx.y; // This is the 'm' dimension (token index) + + // This block computes the dot product for `output[token_id][n_row]` + + if (row >= size_n || m_idx >= size_m) { + return; + } + + // strides + const size_t weight_expert_stride_bytes = (size_t)(size_n * size_k) / qk * sizeof(block_q_t); + const size_t input_task_stride_bytes = (size_t)k_padded / QK8_1 * sizeof(block_q8_1); + const size_t output_task_stride_elems = (size_t)size_n; + + const int token_id = sorted_token_ids[m_idx]; // The *actual* row in input/output tensors + const int expert = expert_ids[m_idx]; + + // If expert is invalid, this token does not participate. + if (expert < 0 || expert >= num_experts) return; + + // Get the scaling factor for this token/expert pair + const float scale = (topk_weights) ? topk_weights[token_id] : 1.0f; + + const block_q_t * __restrict__ w_expert = + (const block_q_t *)((const char *)all_weights + (size_t)expert * weight_expert_stride_bytes); + + const int input_index = topk_weights ? 
token_id : (token_id / topk); + const block_q8_1 * __restrict__ y_ptr = + (const block_q8_1 *)((const char *)all_inputs + (size_t)input_index * input_task_stride_bytes); + + // dot-product tiling along k + const int blocks_per_row_x = size_k / qk; + const int blocks_per_iter = vdr * WARP_SIZE / qi; // no nwarps factor: one warp per batch item + + extern __shared__ int8_t shared_bytes[]; + block_q_t* w_shared_row = reinterpret_cast(shared_bytes); + for (int i = laneId; i < blocks_per_row_x; i += WARP_SIZE) { + w_shared_row[wrapId * blocks_per_row_x + i] = w_expert[row * blocks_per_row_x + i]; + } + __syncthreads(); + + // accumulators for rows_per_block rows (usually 1) + float acc = 0.0f; + + #pragma unroll + for (int kbx = laneId / (qi / vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) { + const int kby = kbx * (qk / QK8_1); + const int kqs = vdr * (laneId % (qi / vdr)); + acc += vec_dot_q_cuda( + // &w_expert[kbx + row * blocks_per_row_x], + &w_shared_row[wrapId * blocks_per_row_x + kbx], + &y_ptr[kby], + kqs); + } + + float v = warp_reduce_sum(acc) * scale; + if (laneId == 0) { + float * __restrict__ out_ptr = + all_outputs + ((size_t)token_id) * output_task_stride_elems; + out_ptr[row] = v; + } +} + +} + +#define LAUNCH_MOE_GGUF(qk, qi, block_q_t, vdr, vec_dot_q_cuda) \ + const int shared_bytes = size_k / qk * sizeof(block_q_t) * nWraps + 1024;\ + vllm_rs::moe_gemm_gguf_kernel \ + <<>>(\ + weights, y_q8_1,\ + sorted_token_ids, expert_ids, topk_weights,\ + outputs,\ + num_experts, topk,\ + size_m, size_n, size_k,\ + kx_padded\ + );\ + + +extern "C" void moe_gemm_gguf( + const float* inputs, //must be float + const void* weights, + const int32_t* sorted_token_ids, + const int32_t* expert_ids, + const float* topk_weights, + float* outputs, + int num_experts, + int topk, + int size_m, // M (num tokens to process) + int size_n, // N (output dim) + int size_k, // K (input dim) + int quant_type, // Q8_0: 0, Q4K: 1, Q2K: 2, Q3k: 3, Q5K: 4, Q6K: 5, + cudaStream_t stream +) { + const int QUANTIZE_BLOCK_SIZE = CUDA_QUANTIZE_BLOCK_SIZE; + const int kx_padded = pad(size_k, MATRIX_ROW_PADDING); + const int num_blocks = ceil_div(kx_padded, QUANTIZE_BLOCK_SIZE); + int m = topk_weights ? 
size_m : size_m / topk; + dim3 grid_dim_quant(num_blocks, m, 1); + dim3 block_dim_quant(QUANTIZE_BLOCK_SIZE, 1, 1); + int y_size_in_bytes = + m * (kx_padded / QK8_1 * sizeof(block_q8_1)); + void* y_q8_1 = nullptr; + cudaMallocAsync(&y_q8_1, y_size_in_bytes, stream); + quantize_q8_1<<>>(inputs, y_q8_1, size_k, kx_padded); + + const int nWraps = 4; + dim3 grid_dim(ceil_div(size_n, nWraps), size_m, 1); + dim3 block_dim(WARP_SIZE, nWraps, 1); + + //Q8_0: 0, Q4K: 1, Q2K: 2, Q3k: 3, Q5K: 4, Q6K: 5, + switch (quant_type) { + case 0: // Q8_0 + { + LAUNCH_MOE_GGUF(QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1); + break; + } + case 1: // Q4K + { + LAUNCH_MOE_GGUF(QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1); + break; + } + case 2: // Q2_K + { + LAUNCH_MOE_GGUF(QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1); + break; + } + case 3: // Q3_K + { + LAUNCH_MOE_GGUF(QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1); + break; + } + case 4: // Q5_K + { + LAUNCH_MOE_GGUF(QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1); + break; + } + case 5: // Q6K + { + LAUNCH_MOE_GGUF(QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1); + break; + } + default: + break; + } + cudaFreeAsync(y_q8_1, stream); +} \ No newline at end of file diff --git a/candle-kernels/src/moe/moe_utils.cuh b/candle-kernels/src/moe/moe_utils.cuh new file mode 100644 index 0000000000..596434088c --- /dev/null +++ b/candle-kernels/src/moe/moe_utils.cuh @@ -0,0 +1,188 @@ +#undef __CUDA_FP8_TYPES_EXIST__ +#include +#include +#include +#include +#include + +/** + * @brief Counts the number of tokens assigned to each expert. + * + * @param expert_ids Device pointer to the sorted expert IDs [size_m]. + * @param expert_counts Device pointer to the output counts [num_experts] + * (must be pre-initialized to zero). + * @param size_m Total number of tokens. + */ +static __global__ void count_tokens_per_expert_kernel( + const int32_t* expert_ids, + int32_t* expert_counts, + int size_m) +{ + int i = blockIdx.x * blockDim.x + threadIdx.x; + if (i < size_m) { + int32_t expert_id = expert_ids[i]; + // expert_id is from a sorted list, so we assume it's valid + // (i.e., 0 <= expert_id < num_experts) + atomicAdd(&expert_counts[expert_id], 1); + } +} + +/** + * @brief Calculates expert offsets array on the GPU. + * + * @param d_expert_ids Device pointer to sorted expert IDs [size_m]. + * @param size_m Total number of tokens. + * @param d_expert_offsets Device pointer for output offsets [num_experts + 1]. + * @param num_experts Number of experts. + * @param stream CUDA stream. + */ +static void calculate_expert_offsets( + const int32_t* d_expert_ids, + int size_m, + int32_t* d_expert_counts, + int32_t* d_expert_offsets, + int num_experts, + cudaStream_t stream +) { + // 1. Zero-initialize the counts buffer + cudaMemsetAsync(d_expert_counts, 0, num_experts * sizeof(int32_t), stream); + + // 2. Launch kernel to count tokens per expert + int threads = 256; + int blocks = (size_m + threads - 1) / threads; + count_tokens_per_expert_kernel<<>>( + d_expert_ids, d_expert_counts, size_m + ); + + // 3. Perform prefix sum (scan) + // We will use inclusive_scan on [counts] and store results in [offsets + 1] + // This is a common and efficient pattern. + + // Wrap raw pointers for Thrust + thrust::device_ptr d_counts_ptr(d_expert_counts); + thrust::device_ptr d_offsets_ptr(d_expert_offsets); + + // Run inclusive scan. + // Input: [c0, c1, c2, ...] 
(size num_experts) + // Output: [c0, c0+c1, c0+c1+c2, ...] (stored at offsets[1]) + thrust::inclusive_scan( + thrust::cuda::par.on(stream), // Execute on the specified stream + d_counts_ptr, // Input start + d_counts_ptr + num_experts, // Input end + d_offsets_ptr + 1 // Output start (shifted by 1) + ); + + // 4. Set the first offset (offsets[0]) to 0 + // This completes the exclusive scan. + cudaMemsetAsync(d_expert_offsets, 0, sizeof(int32_t), stream); +} + + +// This performs an EXCLUSIVE scan: [c0, c1] -> [0, c0, c0+c1] +// Assumptions: num_experts <= 1024 (fits in one block) +static __global__ void expert_prefix_sum_kernel( + const int32_t* __restrict__ counts, + int32_t* __restrict__ offsets, + int num_experts +) { + // Use shared memory for fast scanning + // Size needs to be enough for num_experts + extern __shared__ int32_t temp_storage[]; + + int tid = threadIdx.x; + + // We pad with 0 if tid >= num_experts + int val = (tid < num_experts) ? counts[tid] : 0; + temp_storage[tid] = val; + + __syncthreads(); + + // Hillis-Steele Parallel Scan (Inclusive in shared mem) + for (int offset = 1; offset < blockDim.x; offset <<= 1) { + int temp_val = 0; + if (tid >= offset) { + temp_val = temp_storage[tid - offset]; + } + __syncthreads(); + if (tid >= offset) { + temp_storage[tid] += temp_val; + } + __syncthreads(); + } + + // The result at temp_storage[i] is the inclusive sum of counts[0..i] + // We want offsets[i] = inclusive_sum[i-1] + // We want offsets[0] = 0 + + if (tid < num_experts) { + // Shift right: Offset[i+1] gets the inclusive sum up to i + offsets[tid + 1] = temp_storage[tid]; + + // Handle the first element separately + if (tid == 0) { + offsets[0] = 0; + } + } +} + +static void calculate_expert_offsets_light( + const int32_t* d_expert_ids, + int size_m, + int32_t* d_expert_counts, + int32_t* d_expert_offsets, + int num_experts, + cudaStream_t stream +) { + cudaMemsetAsync(d_expert_counts, 0, num_experts * sizeof(int32_t), stream); + + int threads = 256; + int blocks = (size_m + threads - 1) / threads; + count_tokens_per_expert_kernel<<>>( + d_expert_ids, d_expert_counts, size_m + ); + + // We launch exactly one block with 'num_experts' threads (or next power of 2) + // We need shared memory size = threads * sizeof(int32_t) + int scan_threads = num_experts; + + // Round up scan_threads to next power of 2 if needed, + // or just use a fixed size like 1024 if num_experts is small enough. 
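+    // Worked example (illustrative only, not executed): for num_experts = 4 and
+    // per-expert counts [3, 0, 5, 2], the count kernel plus the exclusive scan above
+    // produce d_expert_offsets = [0, 3, 3, 8, 10], so expert e owns the sorted
+    // positions [offsets[e], offsets[e + 1]).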
+ if (scan_threads < 32) scan_threads = 32; + else if (scan_threads > 1024) { + // Error: This custom kernel only supports up to 1024 experts + // Handle error or assert here + } + + size_t smem_size = scan_threads * sizeof(int32_t); + + expert_prefix_sum_kernel<<<1, scan_threads, smem_size, stream>>>( + d_expert_counts, + d_expert_offsets, + num_experts + ); +} + +namespace vllm_rs { + +inline __device__ uint16_t float_to_half(float f) { + union { + uint32_t u32; + uint16_t u16[2]; + } tmp; +#ifndef USE_ROCM + asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f)); +#else + asm volatile("v_cvt_f16_f32 %0, %1;\n" : "=v"(tmp.u32) : "v"(f)); +#endif + return tmp.u16[0]; +} + +inline __device__ void from_float(half& dst, float src) { + dst = static_cast(float_to_half(src)); +} + +inline __device__ void from_float(__nv_bfloat16& dst, float src) { + dst = __float2bfloat16(src); +} + +} \ No newline at end of file diff --git a/candle-kernels/src/moe/moe_wmma.cu b/candle-kernels/src/moe/moe_wmma.cu new file mode 100644 index 0000000000..de6a90993b --- /dev/null +++ b/candle-kernels/src/moe/moe_wmma.cu @@ -0,0 +1,283 @@ +/** + * @brief WMMA-based grouped MoE GEMM kernel. + * + * Each block computes a tile of the output corresponding to: + * - One expert segment (group of tokens routed to the same expert) + * - One N-dimension tile (a sub-block of the expert's output features) + * + * The kernel loads input activations and expert weights in tiles using shared memory, + * performs matrix multiplication using Tensor Cores (WMMA), and accumulates results + * into a shared C tile. The final results are written atomically into the global + * output buffer to support multi-expert (top-k > 1) routing where tokens appear in + * multiple experts’ outputs. + * + * Adapted from https://github.com/guoqingbao/attention.rs/tree/main/src/kernels/src/moe_gemm_wmma.cu + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include "moe_utils.cuh" +using namespace nvcuda::wmma; + +namespace vllm_rs { + +#define CEILDIV(x,y) (((x) + (y) - 1) / (y)) + +constexpr int WMMA_K = 16; +using VecT = float4; + +// Vectorized load size (float4 = 128 bits = 8 half/bfloat16 values) +constexpr int VEC_SIZE = 8; +constexpr int NUM_VECS = 32; + +// We use 4 Warps (128 threads) per block +constexpr int WARPS_PER_BLOCK = 4; // 4 warps +constexpr int BLOCK_THREADS = 128; // 128 threads + +constexpr int M_BLK = 32; +constexpr int N_BLK = 32; +constexpr int K_BLK = WMMA_K; // 16 + + +/** + * @brief WMMA-based grouped MoE GEMM kernel. 
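+ *
+ * Illustrative launch shape (see moe_gemm_wmma below; the concrete numbers are only an
+ * example): with num_experts = 8, size_n = 128 and N_BLK = 32, grid = dim3(8, 4), so
+ * block (e, t) handles expert e and output columns [t * N_BLK, (t + 1) * N_BLK).
+ * expert_offsets holds the exclusive prefix sum of per-expert token counts, so expert e's
+ * rows are sorted_token_ids[expert_offsets[e] .. expert_offsets[e + 1]).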
+ * + * @tparam T Data type: half or nv_bfloat16 + * + * @param input [size_m or size_m/topk, size_k] + * @param weights [num_experts, size_n, size_k] compacted expert weights + * @param sorted_token_ids [size_m] mapping of per-token row indices (sorted by expert) + * @param expert_offsets [num_experts] array of {start, len} tokens indices for each expert + * @param topk_weights [size_m] optional per-token scaling weights (nullptr if unused) + * @param output [size_m, size_n] global output buffer (must be zero-initialized) + * @param num_experts Total number of experts + * @param topk Number of experts each token is routed to + * @param size_m Number of tokens + * @param size_n Output hidden dimension (per expert) + * @param size_k Input hidden dimension +*/ +template +__global__ void moe_gemm_grouped_kernel( + const T* __restrict__ input, // [size_m, size_k] + const T* __restrict__ weights, // [num_experts, size_n, size_k] + const int32_t* __restrict__ sorted_token_ids, // [size_m] + const int32_t* __restrict__ expert_offsets, // [num_experts] + const float* __restrict__ topk_weights, // [size_m] + T* __restrict__ output, // [size_m, size_n] (Zero-initialized) + const int num_experts, const int topk, + const int32_t size_m, + const int32_t size_n, + const int32_t size_k +) { + // Get Segment and N-Tile for this Block + const int expert_id = blockIdx.x; + const int n_tile_idx = blockIdx.y; + if (expert_id < 0 || expert_id >= num_experts) return; + const int segment_start = expert_offsets[expert_id]; + const int segment_end = expert_offsets[expert_id + 1]; + const int num_rows_in_segment = segment_end - segment_start; + + if (num_rows_in_segment == 0) return; + + const int n_base = n_tile_idx * N_BLK; + if (n_base >= size_n) return; + + const T* expert_w = weights + (size_t)expert_id * (size_t)size_n * (size_t)size_k; + + extern __shared__ uint8_t smem_bytes[]; + + // A tile: [M_BLK, K_BLK] (row-major) + T* A_sh = reinterpret_cast(smem_bytes); + // B tile: [N_BLK, K_BLK] (row-major) + T* B_sh = reinterpret_cast(A_sh + M_BLK * K_BLK); + uint8_t* C_ptr = reinterpret_cast(B_sh + N_BLK * K_BLK); + + // align next pointer to float alignment + size_t offset = reinterpret_cast(C_ptr) % alignof(float); + if (offset != 0) { + C_ptr += (alignof(float) - offset); + } + float* C_sh = reinterpret_cast(C_ptr); // shared scratch for final per-block tile writes + + const int threadId = threadIdx.x; + const int warpId = threadId / 32; + const int laneId = threadId % 32; + const int warp_m_idx = warpId / WARPS_N; + const int warp_n_idx = warpId % WARPS_N; + + const int B_ELEMS_PER_BLOCK = N_BLK * K_BLK; + const int VEC_ELEMS_B = B_ELEMS_PER_BLOCK / VEC_SIZE; // 512 / 8 = 64 + const int A_ELEMS_PER_BLOCK = M_BLK * K_BLK; + const int VEC_ELEMS_A = A_ELEMS_PER_BLOCK / VEC_SIZE; // 512 / 8 = 64 + VecT zero_vec; + zero_vec.x = zero_vec.y = zero_vec.z = zero_vec.w = 0.0f; + + for (int m_base = 0; m_base < num_rows_in_segment; m_base += M_BLK) { + // We'll accumulate full-K results in per-warp fragments (initialized here) + fragment c_frag; + fill_fragment(c_frag, 0.0f); + + // For every k_block we will load B_sh and A_sh for this m_base subsequently + for (int k_base = 0; k_base < size_k; k_base += K_BLK) { + // Load B Tile (Weights) into B_sh + for (int i = threadId; i < VEC_ELEMS_B; i += BLOCK_THREADS) { + int idx = i * VEC_SIZE; // element index (0..511) + int n_local = idx / K_BLK; + int k_local = idx % K_BLK; + + int n_global = n_base + n_local; + int k_global = k_base + k_local; + + // this should be always 
satisfied since k dim aligned to 8 + if (n_global < size_n && k_global < size_k) { + *reinterpret_cast(&B_sh[n_local * K_BLK + k_local]) = *reinterpret_cast( + &expert_w[(size_t)n_global * size_k + k_global] + ); + } else { + *reinterpret_cast(&B_sh[n_local * K_BLK + k_local]) = zero_vec; + } + } + + // Load A Tile (Inputs) into A_sh for this m_base and this k_base + for (int i = threadId; i < VEC_ELEMS_A; i += BLOCK_THREADS) { + int idx = i * VEC_SIZE; // element index + int m_local = idx / K_BLK; + int k_local = idx % K_BLK; + + int m_seg = m_base + m_local; // row index within segment + int k_global = k_base + k_local; + + if (m_seg < num_rows_in_segment && k_global < size_k) { + int token_pair_index = segment_start + m_seg; + int token_index = sorted_token_ids[token_pair_index]; + int input_index = token_index / (topk_weights? 1: topk); + *reinterpret_cast(&A_sh[m_local * K_BLK + k_local]) = *reinterpret_cast( + &input[(size_t)input_index * size_k + k_global] + ); + } else { + // in case m dim in this segment not aligned to 8 + *reinterpret_cast(&A_sh[m_local * K_BLK + k_local]) = zero_vec; + } + } + + __syncthreads(); + + // Compute (Warp-level) : update c_frag for this k_block + fragment a_frag; + fragment b_frag; + + // Point this warp to its tile in shared memory + const T* A_sh_ptr = A_sh + (warp_m_idx * WMMA_M * K_BLK); + const T* B_sh_ptr = B_sh + (warp_n_idx * WMMA_N * K_BLK); + + load_matrix_sync(a_frag, A_sh_ptr, K_BLK); + load_matrix_sync(b_frag, B_sh_ptr, K_BLK); + + // Accumulate into c_frag (which persists across k_base iterations) + mma_sync(c_frag, a_frag, b_frag, c_frag); + } // end k_base loop (we have a fully-accumulated c_frag for this m_base tile) + + // Store the accumulated c_frag to C_sh (shared) once per warp + // Point this warp to its 16x16 tile *within* the 32x32 C_sh + float* C_sh_ptr = C_sh + (warp_m_idx * WMMA_M * N_BLK) + (warp_n_idx * WMMA_N); + // store the full accumulated 16x16 tile (note ld = N_BLK, result in row-major in C_sh) + store_matrix_sync(C_sh_ptr, c_frag, N_BLK, mem_row_major); + + __syncthreads(); + + // Cooperative Store from C_sh to Global + // 128 threads write [M_BLK, N_BLK] = [32, 32] = 1024 elements + const int C_ELEMS_PER_BLOCK = M_BLK * N_BLK; + for (int i = threadId; i < C_ELEMS_PER_BLOCK; i += BLOCK_THREADS) { + int m_local_c = i / N_BLK; // row in C_sh (0..31) + int n_local_c = i % N_BLK; // col in C_sh (0..31) + + int m_seg = m_base + m_local_c; // row index within segment + int n_global = n_base + n_local_c; // col index in output + + if (m_seg < num_rows_in_segment && n_global < size_n) { + int token_pair_index = segment_start + m_seg; + if (token_pair_index < size_m) { + int token_index = sorted_token_ids[token_pair_index]; + float val = C_sh[m_local_c * N_BLK + n_local_c]; + if (topk_weights) { + val *= topk_weights[token_index]; + } + from_float(output[(size_t)token_index * size_n + n_global], val); + } + } + } + } // end m_base loop +} + +} + +#define LAUNCH_MOE_WMMA(DTYPE, WMMA_M, WMMA_N, WARPS_N)\ + vllm_rs::moe_gemm_grouped_kernel<<>>(\ + reinterpret_cast(input),\ + reinterpret_cast(weights),\ + sorted_token_ids,\ + expert_offsets,\ + topk_weights,\ + reinterpret_cast(output),\ + num_experts, topk,\ + size_m, size_n, size_k \ + );\ + +extern "C" void moe_gemm_wmma( + const void* input, // [size_m, size_k] + const void* weights, // [num_experts, size_n, size_k] + const int32_t* sorted_token_ids, // [size_m] (Device) + const int32_t* expert_ids, // [size_m * topk] + const float* topk_weights, // [size_m] (Device, can be 
nullptr) + void* output, // [size_m, size_n] + int32_t* expert_counts, // prealloc [num_experts] + int32_t* expert_offsets, // prealloc [num_experts + 1] + int num_experts, + int topk, + int size_m, + int size_n, + int size_k, + int data_type, // 0 = half, 1 = bfloat16 + bool is_prefill, + cudaStream_t stream +) { + if (is_prefill) { + calculate_expert_offsets(expert_ids, size_m, expert_counts, expert_offsets, num_experts, stream); + } else { + calculate_expert_offsets_light(expert_ids, size_m, expert_counts, expert_offsets, num_experts, stream); + } + + int grid_n = CEILDIV(size_n, vllm_rs::N_BLK); + dim3 grid(num_experts, grid_n, 1); + dim3 block(vllm_rs::BLOCK_THREADS, 1, 1); + + // Shared memory: A_sh[M_BLK, K_BLK] + B_sh[N_BLK, K_BLK] + size_t A_sh_bytes = vllm_rs::M_BLK * vllm_rs::K_BLK * 2; // (32*16 * 2) = 1024 + size_t B_sh_bytes = vllm_rs::N_BLK * vllm_rs::K_BLK * 2; // (32*16 * 2) = 1024 + size_t C_sh_bytes = vllm_rs::M_BLK * vllm_rs::N_BLK * sizeof(float); + size_t AB_bytes = A_sh_bytes + B_sh_bytes; + size_t pad = (16 - (AB_bytes % 16)) % 16; + size_t smem_bytes = AB_bytes + pad + C_sh_bytes; // ~6KB total needed + + if (data_type == 0) { // half + if (is_prefill) { + LAUNCH_MOE_WMMA(half, 16, 16, 2) + } else { + // we use smaller M_tile and larger N_tile for decoding + LAUNCH_MOE_WMMA(half, 8, 32, 1) + } + } else if (data_type == 1) { // bfloat16 + if (is_prefill) { + LAUNCH_MOE_WMMA(nv_bfloat16, 16, 16, 2) + } else { + LAUNCH_MOE_WMMA(nv_bfloat16, 8, 32, 1) + } + } +} \ No newline at end of file diff --git a/candle-kernels/src/moe/moe_wmma_gguf.cu b/candle-kernels/src/moe/moe_wmma_gguf.cu new file mode 100644 index 0000000000..0d3701ee82 --- /dev/null +++ b/candle-kernels/src/moe/moe_wmma_gguf.cu @@ -0,0 +1,422 @@ +/** + * @brief CUDA kernel for Mixture-of-Experts (MoE) GEMM with GGUF quantized weights and Tensor Core. + * + * This kernel performs batched GEMM where the weight matrix is stored in GGUF + * quantized format (uint8_t blocks). It supports top-k expert selection and + * segmented expert layouts. Uses shared memory tiles and WMMA (tensor cores) + * for efficient computation. + * + * Adapted from: https://github.com/guoqingbao/attention.rs/tree/main/src/kernels/src/moe_wmma_gguf.cu + */ +#include "gguf.cuh" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "moe_utils.cuh" +using namespace nvcuda::wmma; + +// Constants from original kernel +constexpr int WMMA_M = 16; +constexpr int WMMA_N = 16; +constexpr int WMMA_K = 16; // This is fixed by the hardware instruction +using VecT = float4; + +constexpr int VEC_SIZE = 8; +constexpr int WARPS_M = 2; +constexpr int WARPS_N = 2; +constexpr int WARPS_PER_BLOCK = WARPS_M * WARPS_N; // 4 warps + +constexpr int M_BLK = WARPS_M * WMMA_M; // 32 +constexpr int N_BLK = WARPS_N * WMMA_N; // 32 + +// Helper for ceiling division +#define CEILDIV(A, B) (((A) + (B)-1) / (B)) + +// --- GGUF Dequantization Function (Warp-level) --- +/** + * @brief Dequantizes a single GGUF block using one warp (32 threads). 
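+ *
+ * Informal sketch of the Q8_0 (qk = 32) path below: lane 0 loads the half scale d and
+ * broadcasts it with __shfl_sync, then each lane i < 32 writes
+ * dequant_out[i] = T(qs[i] * d). The K-quant cases defer to the shared
+ * dequantize_block_q2_K..q6_K helpers from gguf.cuh, which assume 32 or 64 lanes.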
+ * + * @tparam T Output type (half or nv_bfloat16) + * @param dequant_out Pointer to output in shared mem [qk] + * @param quant_in Pointer to input GGUF block in shared mem + * @param type GGUF type + * @param qk Quantization group size (32 or 256) + * @param laneId threadIdx.x % 32 + */ +template +__forceinline__ __device__ void dequantize_block_warp( + T* dequant_out, + const uint8_t* quant_in, + int gguf_dtype //Q8_0: 0, Q4K: 1, Q2K: 2, Q3k: 3, Q5K: 4, Q6K: 5, +) { + using namespace nvcuda; + switch (gguf_dtype) { + case 0: { // qk = 32, q8_0 + // Block: half d (2B), int8_t qs[32] (32B) + int laneId = threadIdx.x; + const half* d_ptr = (const half*)quant_in; + const int8_t* qs = (const int8_t*)(quant_in + 2); + + // Lane 0 loads scale and broadcasts to all other lanes + half d_val = (laneId == 0) ? *d_ptr : (half)0.0f; + d_val = __shfl_sync(0xFFFFFFFF, d_val, 0); + float d_f = __half2float(d_val); + + // 32 lanes dequantize 32 values + if (laneId < QK8_0) { // qk should be 32 + dequant_out[laneId] = T( (float)qs[laneId] * d_f ); + } + break; + } + case 1: { // q4k, 32 lanes + dequantize_block_q4_K(quant_in, dequant_out); + break; + } + case 2: { // q2k, 64 lanes + dequantize_block_q2_K(quant_in, dequant_out); + break; + } + case 3: { // q3k, 64 lanes + dequantize_block_q3_K(quant_in, dequant_out); + break; + } + case 4: { // q5k, 64 lanes + dequantize_block_q5_K(quant_in, dequant_out); + break; + } + case 5: { // q6k, 64 lanes + dequantize_block_q6_K(quant_in, dequant_out); + break; + } + default: + break; + } +} + +/* +* Template Parameters: + * @tparam T Type of input/output (float, half, etc.) + * @tparam qk Quantization block size (e.g., 32) + * @tparam block_q_t Type representing a single GGUF block (e.g., block_q8_0) + * @tparam wrap_size Warp size used for thread tiling (usually 32) + * + * Kernel Parameters: + * @param input Input matrix [size_m, size_k] + * @param weights GGUF quantized weights buffer (uint8_t blocks) + * @param sorted_token_ids Array of sorted token indices for MoE routing + * @param expert_offsets [num_experts] array of {start, len} tokens indices for each expert + * @param topk_weights Top-k MoE weights per token (optional) + * @param output Output matrix [size_m, size_n] + * @param num_experts Number of experts in the MoE + * @param topk Number of top experts selected per token + * @param size_m Number of input rows / tokens + * @param size_n Output feature dimension + * @param size_k Input feature dimension + * @param gguf_dtype GGUF quantization type ID (e.g., Q8_0) +*/ +template +__global__ void moe_gemm_gguf_prefill_kernel( + const T* __restrict__ input, + const uint8_t* __restrict__ weights, // Now uint8_t* + const int32_t* __restrict__ sorted_token_ids, + const int32_t* __restrict__ expert_offsets, + const float* __restrict__ topk_weights, + float* __restrict__ output, + const int num_experts, const int topk, + const int32_t size_m, + const int32_t size_n, + const int32_t size_k, + const int gguf_dtype +) { + const int expert_id = blockIdx.x; + const int n_tile_idx = blockIdx.y; + + if (expert_id < 0 || expert_id >= num_experts) return; + const int segment_start = expert_offsets[expert_id]; + const int segment_end = expert_offsets[expert_id + 1]; + const int num_rows_in_segment = segment_end - segment_start; + + if (num_rows_in_segment == 0) return; + constexpr int BLOCK_THREADS = WARPS_PER_BLOCK * wrap_size; // 128 threads + + const int n_base = n_tile_idx * N_BLK; + if (n_base >= size_n) return; + + const size_t block_size_bytes = 
sizeof(block_q_t); + const size_t expert_w_row_stride_bytes = (size_k / qk) * block_size_bytes; + const uint8_t* expert_w = weights + (size_t)expert_id * size_n * expert_w_row_stride_bytes; + + extern __shared__ uint8_t smem_bytes[]; + + // 1. A tile: [M_BLK, qk] (dequantized) + T* A_sh = reinterpret_cast(smem_bytes); + size_t A_sh_bytes = (size_t)M_BLK * qk * sizeof(T); + + // 2. B tile: [N_BLK, qk] (dequantized) + uint8_t* B_sh_ptr = smem_bytes + A_sh_bytes; + size_t B_sh_bytes = (size_t)N_BLK * qk * sizeof(T); + + // 3. B quantized tile: [N_BLK * block_size_bytes] (raw GGUF) + uint8_t* B_quant_sh_ptr = B_sh_ptr + B_sh_bytes; + size_t B_quant_sh_bytes = (size_t)N_BLK * block_size_bytes; + + // 4. C tile: [M_BLK, N_BLK] (float accumulator) + uint8_t* C_sh_ptr = B_quant_sh_ptr + B_quant_sh_bytes; + size_t C_sh_offset = reinterpret_cast(C_sh_ptr) % alignof(float); + if (C_sh_offset != 0) C_sh_ptr += (alignof(float) - C_sh_offset); + + // Final aligned shared memory pointers + T* B_sh = reinterpret_cast(B_sh_ptr); + uint8_t* B_quant_sh = reinterpret_cast(B_quant_sh_ptr); + float* C_sh = reinterpret_cast(C_sh_ptr); + + const int laneId = threadIdx.x; + const int warpId = threadIdx.y; + const int threadId = warpId * wrap_size + laneId; + const int warp_m_idx = warpId / WARPS_N; + const int warp_n_idx = warpId % WARPS_N; + + const size_t A_ELEMS_PER_BLOCK = (size_t)M_BLK * qk; + const size_t VEC_ELEMS_A = A_ELEMS_PER_BLOCK / VEC_SIZE; + VecT zero_vec; + zero_vec.x = zero_vec.y = zero_vec.z = zero_vec.w = 0.0f; + + for (int m_base = 0; m_base < num_rows_in_segment; m_base += M_BLK) { + + // Per-warp accumulator fragment + fragment c_frag; + fill_fragment(c_frag, 0.0f); + + // K-Loop: Strides by GGUF block size `qk` + for (int k_base = 0; k_base < size_k; k_base += qk) { + + // Load A Tile (Inputs) into A_sh + #pragma unroll + for (size_t i = threadId; i < VEC_ELEMS_A; i += BLOCK_THREADS) { + size_t idx = i * VEC_SIZE; // element index + size_t m_local = idx / qk; + size_t k_local = idx % qk; + + int m_seg = m_base + m_local; + int k_global = k_base + k_local; + + if (m_seg < num_rows_in_segment && k_global < size_k) { + int token_pair_index = segment_start + m_seg; + int token_index = sorted_token_ids[token_pair_index]; + int input_index = token_index / (topk_weights? 
1: topk); + *reinterpret_cast(&A_sh[m_local * qk + k_local]) = *reinterpret_cast( + &input[(size_t)input_index * size_k + k_global] + ); + } else { + *reinterpret_cast(&A_sh[m_local * qk + k_local]) = zero_vec; + } + } + + // Load B Tile (Quantized) into B_quant_sh + const size_t k_base_offset_bytes = (k_base / qk) * block_size_bytes; + constexpr int ROWS_PER_WARP = N_BLK / WARPS_PER_BLOCK; + + #pragma unroll + for (int row = 0; row < ROWS_PER_WARP; ++row) { + int n_local = warpId * ROWS_PER_WARP + row; + int n_global = n_base + n_local; + if (n_local < N_BLK && n_global < size_n) { + block_q_t* dest_ptr = reinterpret_cast(B_quant_sh + n_local * block_size_bytes); + const block_q_t* src_ptr = reinterpret_cast(expert_w + (size_t)n_global * expert_w_row_stride_bytes + k_base_offset_bytes); + *dest_ptr = *src_ptr; + } + } + + __syncthreads(); + + // Dequantize B from B_quant_sh to B_sh + #pragma unroll + for (int row = 0; row < ROWS_PER_WARP; ++row) { + int n_local = warpId * ROWS_PER_WARP + row; + int n_global = n_base + n_local; + if (n_local < N_BLK && n_global < size_n) { + const uint8_t* quant_ptr = B_quant_sh + n_local * block_size_bytes; + T* dequant_ptr = B_sh + n_local * qk; // Stride by qk + // Dequantize one block using this warp + dequantize_block_warp(dequant_ptr, quant_ptr, gguf_dtype); + } + } + + __syncthreads(); + + // Inner WMMA Loop + // A_sh and B_sh are now dequantized and in shared mem + // We loop over the K-dim (now `qk`) using the hardware `WMMA_K` + #pragma unroll + for (int k_tile = 0; k_tile < qk; k_tile += WMMA_K) { + fragment a_frag; + fragment b_frag; + + // Point to the correct 16x16 tile inside the [M_BLK, qk] / [N_BLK, qk] buffers + const T* A_sh_ptr = A_sh + (warp_m_idx * WMMA_M * qk) + k_tile; + const T* B_sh_ptr = B_sh + (warp_n_idx * WMMA_N * qk) + k_tile; + + load_matrix_sync(a_frag, A_sh_ptr, qk); // Stride is qk + load_matrix_sync(b_frag, B_sh_ptr, qk); // Stride is qk + + mma_sync(c_frag, a_frag, b_frag, c_frag); + } + } // end k_base loop + + // Store C_frag to C_sh + float* C_sh_ptr_warp = C_sh + (warp_m_idx * WMMA_M * N_BLK) + (warp_n_idx * WMMA_N); + store_matrix_sync(C_sh_ptr_warp, c_frag, N_BLK, mem_row_major); + __syncthreads(); + + // Cooperative Store to Global + const int C_ELEMS_PER_BLOCK = M_BLK * N_BLK; + #pragma unroll + for (int i = threadId; i < C_ELEMS_PER_BLOCK; i += BLOCK_THREADS) { + int m_local_c = i / N_BLK; + int n_local_c = i % N_BLK; + int m_seg = m_base + m_local_c; + int n_global = n_base + n_local_c; + + if (m_seg < num_rows_in_segment && n_global < size_n) { + int token_pair_index = segment_start + m_seg; + if (token_pair_index < size_m) { + int token_index = sorted_token_ids[token_pair_index]; + float val = C_sh[m_local_c * N_BLK + n_local_c]; + if (topk_weights) { + val *= topk_weights[token_index]; + } + output[(size_t)token_index * size_n + n_global] = val; + } + } + } + } // end m_base loop +} + +#define LAUNCH_MOE_GGUF_PREFILL(DTYPE) \ + if (gguf_type == 0) {\ + dim3 block(32, WARPS_PER_BLOCK, 1);\ + moe_gemm_gguf_prefill_kernel<<>>(\ + reinterpret_cast(input),\ + reinterpret_cast(weights),\ + sorted_token_ids, expert_offsets, topk_weights,\ + output, num_experts, topk, size_m, size_n, size_k, gguf_type\ + );\ + } else if (gguf_type == 1) {\ + dim3 block(32, WARPS_PER_BLOCK, 1);\ + moe_gemm_gguf_prefill_kernel<<>>(\ + reinterpret_cast(input),\ + reinterpret_cast(weights),\ + sorted_token_ids, expert_offsets, topk_weights,\ + output, num_experts, topk, size_m, size_n, size_k, gguf_type\ + );\ + } else if (gguf_type == 
2) {\ + dim3 block(64, WARPS_PER_BLOCK, 1);\ + moe_gemm_gguf_prefill_kernel<<>>(\ + reinterpret_cast(input),\ + reinterpret_cast(weights),\ + sorted_token_ids, expert_offsets, topk_weights,\ + output, num_experts, topk, size_m, size_n, size_k, gguf_type\ + );\ + } else if (gguf_type == 3) {\ + dim3 block(64, WARPS_PER_BLOCK, 1);\ + moe_gemm_gguf_prefill_kernel<<>>(\ + reinterpret_cast(input),\ + reinterpret_cast(weights),\ + sorted_token_ids, expert_offsets, topk_weights,\ + output, num_experts, topk, size_m, size_n, size_k, gguf_type\ + );\ + } else if (gguf_type == 4) { \ + dim3 block(64, WARPS_PER_BLOCK, 1);\ + moe_gemm_gguf_prefill_kernel<<>>(\ + reinterpret_cast(input),\ + reinterpret_cast(weights),\ + sorted_token_ids, expert_offsets, topk_weights,\ + output, num_experts, topk, size_m, size_n, size_k, gguf_type\ + );\ + } else if (gguf_type == 5) { \ + dim3 block(64, WARPS_PER_BLOCK, 1);\ + moe_gemm_gguf_prefill_kernel<<>>(\ + reinterpret_cast(input),\ + reinterpret_cast(weights),\ + sorted_token_ids, expert_offsets, topk_weights,\ + output, num_experts, topk, size_m, size_n, size_k, gguf_type\ + );\ + } + + +extern "C" void moe_gemm_gguf_prefill( + const void* input, + const uint8_t* weights, + const int32_t* sorted_token_ids, + const int32_t* expert_ids, + const float* topk_weights, + float* output, + int num_experts, + int topk, + int size_m, + int size_n, + int size_k, + int input_dtype, // 0 = half, 1 = bfloat16 + int gguf_type, //Q8_0: 0, Q4K: 1, Q2K: 2, Q3k: 3, Q5K: 4, Q6K: 5, + cudaStream_t stream +) { + int32_t* expert_counts; + cudaMallocAsync(&expert_counts, num_experts * sizeof(int32_t), stream); + + int32_t* expert_offsets; + cudaMallocAsync(&expert_offsets, (num_experts + 1) * sizeof(int32_t), stream); + calculate_expert_offsets(expert_ids, size_m, expert_counts, expert_offsets, num_experts, stream); + + int grid_n = CEILDIV(size_n, N_BLK); + dim3 grid(num_experts, grid_n, 1); + + size_t qk = QK_K; + size_t block_size_bytes = sizeof(block_q6_K); + if (gguf_type == 0) { //Q8_0: 0, + block_size_bytes = sizeof(block_q8_0); + qk = QK8_0; + } else if (gguf_type == 1) {// Q4K: 1, + block_size_bytes = sizeof(block_q4_K); + } else if (gguf_type == 2) {// Q2K: 2, + block_size_bytes = sizeof(block_q2_K); + } else if (gguf_type == 3) {//Q3K: 3, + block_size_bytes = sizeof(block_q3_K); + } else if (gguf_type == 4) {//Q5K: 4, + block_size_bytes = sizeof(block_q5_K); + } + + // 1. A tile: [M_BLK, qk] (dequantized) + size_t A_sh_bytes = (size_t)M_BLK * qk * 2; // 2 for half/bfloat16 + + // 2. B tile: [N_BLK, qk] (dequantized) + size_t B_sh_bytes = (size_t)N_BLK * qk * 2; + + // 3. B quantized tile: [N_BLK * block_size_bytes] + size_t B_quant_sh_bytes = (size_t)N_BLK * block_size_bytes; + + // 4. 
C tile: [M_BLK, N_BLK] (float accumulator) + size_t C_sh_bytes = (size_t)M_BLK * N_BLK * sizeof(float); + + // Add up, with padding for C + size_t smem_bytes = A_sh_bytes + B_sh_bytes + B_quant_sh_bytes; + size_t C_sh_offset = smem_bytes % alignof(float); + if (C_sh_offset != 0) smem_bytes += (alignof(float) - C_sh_offset); + smem_bytes += C_sh_bytes; + + if (input_dtype == 0) { + LAUNCH_MOE_GGUF_PREFILL(half); + } else { +#ifndef NO_BF16_KERNEL + LAUNCH_MOE_GGUF_PREFILL(nv_bfloat16); +#endif + } + cudaFreeAsync(expert_counts, stream); + cudaFreeAsync(expert_offsets, stream); +} diff --git a/candle-nn/src/lib.rs b/candle-nn/src/lib.rs index c7a76fbd7a..febd73a2d6 100644 --- a/candle-nn/src/lib.rs +++ b/candle-nn/src/lib.rs @@ -28,6 +28,7 @@ pub mod kv_cache; pub mod layer_norm; pub mod linear; pub mod loss; +pub mod moe; pub mod ops; pub mod optim; pub mod rnn; diff --git a/candle-nn/src/moe.rs b/candle-nn/src/moe.rs new file mode 100644 index 0000000000..4f62af8f66 --- /dev/null +++ b/candle-nn/src/moe.rs @@ -0,0 +1,349 @@ +// Adapted from https://github.com/guoqingbao/attention.rs/blob/main/src/moe.rs +#[cfg(feature = "cuda")] +use candle::cuda_backend::kernels::ffi; +#[allow(unused_imports)] +use candle::quantized::{self, QTensor}; +use candle::{Result, Tensor}; + +#[cfg(feature = "cuda")] +pub fn moe_gemm( + input: &Tensor, + weights: &Tensor, + topk_weights: &Option, + sorted_token_ids: &Tensor, + experts_ids: &Tensor, + topk: usize, + is_prefill: bool, +) -> Result { + use candle::cuda_backend::cudarc::driver::DevicePtr; + use candle::DType; + use half::{bf16, f16}; + + fn cuda_fwd< + T: candle::cuda_backend::CudaDType + candle::cuda_backend::cudarc::driver::DeviceRepr, + >( + input: &Tensor, + weights: &Tensor, + topk_weights: &Option, + sorted_token_ids: &Tensor, + experts_ids: &Tensor, + topk: usize, + is_prefill: bool, + ) -> Result { + let (mut size_m, size_k1) = input.dims2()?; + if topk_weights.is_none() { + size_m *= topk; + } + let (num_experts, size_n, size_k) = weights.dims3()?; + assert!( + size_k == size_k1, + "input {:?} and weight {:?} last dim mismatch!", + size_k1, + size_k + ); + let dev = input.device().as_cuda_device()?; + let data_type = match input.dtype() { + DType::F16 => 0, + DType::BF16 => 1, + _ => { + candle::bail!("moe_gemm_wmma only accept f16/bf16 inputs!") + } + }; + + let (input, _) = input.storage_and_layout(); + let input = match &*input { + candle::Storage::Cuda(c) => c.as_cuda_slice::()?, + _ => candle::bail!("input must be a cuda tensor"), + }; + + let (weights, _) = weights.storage_and_layout(); + let weights = match &*weights { + candle::Storage::Cuda(c) => c.as_cuda_slice::()?, + _ => candle::bail!("weight must be a cuda tensor"), + }; + + let (sorted_token_ids, _) = sorted_token_ids.storage_and_layout(); + let sorted_token_ids = match &*sorted_token_ids { + candle::Storage::Cuda(c) => c.as_cuda_slice::()?, + _ => candle::bail!("sorted_token_ids must be a cuda tensor"), + }; + + let (experts_ids, _) = experts_ids.storage_and_layout(); + let experts_ids = match &*experts_ids { + candle::Storage::Cuda(c) => c.as_cuda_slice::()?, + _ => candle::bail!("experts_ids must be a cuda tensor"), + }; + + let topk_weights_ptr = if let Some(topk_weights) = &topk_weights { + let (topk_weights, _) = topk_weights.storage_and_layout(); + let topk_weights = match &*topk_weights { + candle::Storage::Cuda(c) => c.as_cuda_slice::()?, + _ => candle::bail!("topk_weights must be a cuda tensor"), + }; + let weights_ptr = 
topk_weights.device_ptr(topk_weights.stream()).0 as *const f32; + weights_ptr + } else { + std::ptr::null() as *const f32 + }; + + let output = unsafe { dev.alloc::(size_m * size_n) }?; + let expert_counts = unsafe { dev.alloc::(num_experts) }?; + let expert_offsets = unsafe { dev.alloc::(num_experts + 1) }?; + + let stream = dev.cuda_stream().cu_stream() as i64; + use core::ffi::c_void; + + unsafe { + ffi::moe_gemm_wmma( + input.device_ptr(input.stream()).0 as *const c_void, // [size_m, size_k] + weights.device_ptr(weights.stream()).0 as *const c_void, // [num_experts, size_n, size_k] + sorted_token_ids.device_ptr(sorted_token_ids.stream()).0 as *const i32, + experts_ids.device_ptr(experts_ids.stream()).0 as *const i32, + topk_weights_ptr, + output.device_ptr(output.stream()).0 as *mut c_void, // [size_m, size_n] + expert_counts.device_ptr(expert_counts.stream()).0 as *mut i32, // pre-allocated buffer [num_experts] + expert_offsets.device_ptr(expert_offsets.stream()).0 as *mut i32, // pre-allocated buffer [num_experts + 1] + num_experts as i32, + topk as i32, + size_m as i32, + size_n as i32, + size_k as i32, + data_type as i32, // 0=float16, 1=bf16 (for input/output) + is_prefill, + stream as i64, + ); + } + + use candle::op::BackpropOp; + let output = candle::CudaStorage::wrap_cuda_slice(output, dev.clone()); + let output = Tensor::from_storage( + candle::Storage::Cuda(output), + (size_m, size_n), + BackpropOp::none(), + false, + ); + + Ok(output) + } + + match input.dtype() { + DType::F16 => cuda_fwd::( + input, + weights, + topk_weights, + sorted_token_ids, + experts_ids, + topk, + is_prefill, + ), + DType::BF16 => cuda_fwd::( + input, + weights, + topk_weights, + sorted_token_ids, + experts_ids, + topk, + is_prefill, + ), + _ => { + candle::bail!("moe_gemm only accept f16/bf16 inputs!") + } + } +} + +#[cfg(not(feature = "cuda"))] +pub fn moe_gemm( + _: &Tensor, + _: &Tensor, + _: &Option, + _: &Tensor, + _: &Tensor, + _: usize, + _: bool, +) -> Result { + candle::bail!("moe_gemm is not implemented on this platform!") +} + +#[cfg(feature = "cuda")] +pub fn moe_gemm_gguf( + input: &Tensor, + weights: &QTensor, + topk_weights: &Option, + sorted_token_ids: &Tensor, + experts_ids: &Tensor, + topk: usize, + is_prefill: bool, + dtype: candle::DType, +) -> Result { + use candle::cuda_backend::cudarc::driver::DevicePtr; + use candle::quantized::GgmlDType; + use candle::DType; + use half::{bf16, f16}; + + fn cuda_fwd( + input: &Tensor, + weights: &QTensor, + topk_weights: &Option, + sorted_token_ids: &Tensor, + experts_ids: &Tensor, + topk: usize, + is_prefill: bool, + dtype: DType, + ) -> Result { + let (mut size_m, size_k) = input.dims2()?; + if topk_weights.is_none() { + size_m *= topk; + } + let (num_experts, size_n, size_k1) = weights.shape().dims3()?; + assert!( + size_k == size_k1, + "input {:?} and weight {:?} last dim mismatch!", + size_k, + size_k1, + ); + let dev = input.device().as_cuda_device()?; + + // Q8_0: 0, Q4K: 1, Q2K: 2, Q3k: 3, Q5K: 4, Q6K: 5 + let gguf_dtype = match weights.dtype() { + GgmlDType::Q8_0 => 0, + GgmlDType::Q4K => 1, + GgmlDType::Q2K => 2, + GgmlDType::Q3K => 3, + GgmlDType::Q5K => 4, + GgmlDType::Q6K => 5, + _ => { + candle::bail!( + "moe_gemm_gguf `ISQ` only accept q2k, q3k, q4k, q5k, q6k or q8_0 weights!" 
+ ) + } + }; + + let weight_ptr = weights.device_ptr()?; + + let topk_weights_ptr = if let Some(topk_weights) = &topk_weights { + let (topk_weights, _) = topk_weights.storage_and_layout(); + let topk_weights = match &*topk_weights { + candle::Storage::Cuda(c) => c.as_cuda_slice::()?, + _ => candle::bail!("topk_weights must be a cuda tensor"), + }; + let w_ptr = topk_weights.device_ptr(topk_weights.stream()).0 as *const f32; + w_ptr + } else { + std::ptr::null() as *const f32 + }; + + let (sorted_token_ids, _) = sorted_token_ids.storage_and_layout(); + let sorted_token_ids = match &*sorted_token_ids { + candle::Storage::Cuda(c) => c.as_cuda_slice::()?, + _ => candle::bail!("sorted_token_ids must be a cuda tensor"), + }; + let (experts_ids, _) = experts_ids.storage_and_layout(); + let experts_ids = match &*experts_ids { + candle::Storage::Cuda(c) => c.as_cuda_slice::()?, + _ => candle::bail!("experts_ids must be a cuda tensor"), + }; + + let output = unsafe { dev.alloc::(size_m * size_n) }?; + let stream = dev.cuda_stream().cu_stream() as i64; + use candle::op::BackpropOp; + use core::ffi::c_void; + + assert!(size_k % 8 == 0, "size_k must divisible by 8"); + unsafe { + if is_prefill { + let input = input.to_dtype(dtype)?; + let (input, _) = input.storage_and_layout(); + let (input_ptr, input_dtype) = match &*input { + candle::Storage::Cuda(c) => { + if dtype == DType::F16 { + let c = c.as_cuda_slice::()?; + (c.device_ptr(c.stream()).0 as *const c_void, 0) + } else { + let c = c.as_cuda_slice::()?; + (c.device_ptr(c.stream()).0 as *const c_void, 1) + } + } + _ => candle::bail!("input must be a cuda tensor"), + }; + ffi::moe_gemm_gguf_prefill( + input_ptr, // [size_m or size_m/topk, size_k] + weight_ptr as *const u8, // [num_experts, size_n, size_k] + sorted_token_ids.device_ptr(sorted_token_ids.stream()).0 as *const i32, + experts_ids.device_ptr(experts_ids.stream()).0 as *const i32, + topk_weights_ptr, + output.device_ptr(output.stream()).0 as *mut c_void, // [size_m, size_n] + num_experts as i32, + topk as i32, + size_m as i32, + size_n as i32, + size_k as i32, + input_dtype as i32, + gguf_dtype as i32, // Q8_0: 0, Q4K: 1, Q2K: 2, Q3k: 3, Q5K: 4, Q6K: 5 (for weight) + stream as i64, + ); + } else { + let (input, _) = input.storage_and_layout(); + let input = match &*input { + candle::Storage::Cuda(c) => c.as_cuda_slice::()?, + _ => candle::bail!("input must be a cuda tensor"), + }; + + ffi::moe_gemm_gguf( + input.device_ptr(input.stream()).0 as *const f32, // [size_m or size_m/topk, size_k] + weight_ptr as *const c_void, // [num_experts, size_n, size_k] + sorted_token_ids.device_ptr(sorted_token_ids.stream()).0 as *const i32, + experts_ids.device_ptr(experts_ids.stream()).0 as *const i32, + topk_weights_ptr, + output.device_ptr(output.stream()).0 as *mut c_void, // [size_m, size_n] + num_experts as i32, + topk as i32, + size_m as i32, + size_n as i32, + size_k as i32, + gguf_dtype as i32, // Q8_0: 0, Q4K: 1, Q2K: 2, Q3k: 3, Q5K: 4, Q6K: 5 (for weight) + stream as i64, + ); + } + } + + let output = candle::CudaStorage::wrap_cuda_slice(output, dev.clone()); + let output = Tensor::from_storage( + candle::Storage::Cuda(output), + (size_m, size_n), + BackpropOp::none(), + false, + ); + + Ok(output) + } + + match input.dtype() { + DType::F32 => cuda_fwd( + input, + weights, + topk_weights, + sorted_token_ids, + experts_ids, + topk, + is_prefill, + dtype, + ), + _ => { + candle::bail!("moe_gemm_gguf only accept f16/bf16 inputs!") + } + } +} + +#[cfg(not(feature = "cuda"))] +pub fn moe_gemm_gguf( + 
_: &Tensor, + _: &QTensor, + _: &Option, + _: &Tensor, + _: &Tensor, + _: usize, + _: bool, + _: candle::DType, +) -> Result { + candle::bail!("moe_gemm_gguf is not implemented on this platform!") +} diff --git a/candle-transformers/src/fused_moe.rs b/candle-transformers/src/fused_moe.rs new file mode 100644 index 0000000000..91eb5f217b --- /dev/null +++ b/candle-transformers/src/fused_moe.rs @@ -0,0 +1,302 @@ +// Adapted from: https://github.com/guoqingbao/vllm.rs/blob/main/src/models/layers/moe.rs +use candle::Module; +use candle::{quantized::QTensor, DType, Result, Tensor, D}; +use candle_nn::{linear_no_bias, moe, Activation, Linear, VarBuilder}; +use std::sync::Arc; + +pub struct MoeCfg { + pub hidden_size: usize, + pub num_experts: usize, + pub num_experts_per_tok: usize, + pub moe_intermediate_size: usize, + pub norm_topk_prob: bool, + pub act: Activation, + pub decoder_sparse_step: Option, +} + +#[allow(dead_code)] +#[derive(Debug, Clone)] +pub struct FusedMoe { + gate: Linear, + gate_up_w: Tensor, + down_w: Tensor, + w_size_n: usize, + act: Activation, + norm_topk_prob: bool, + num_experts_per_tok: usize, + // world_size: usize, + dtype: DType, +} + +impl FusedMoe { + pub fn new(cfg: &MoeCfg, vb: VarBuilder, dtype: DType) -> Result { + let num_experts = cfg.num_experts; + + let gate = linear_no_bias(cfg.hidden_size, num_experts, vb.pp("gate"))?; + + let experts_vb = vb.pp("experts"); + let mut gate_up_experts = Vec::with_capacity(num_experts); + let mut down_experts = Vec::with_capacity(num_experts); + + //pack experts + for i in 0..num_experts { + let experts_vb = experts_vb.pp(format!("{}", i).as_str()); + + let (gate_up_expert, down_expert) = { + // n x k format + let init_ws = candle_nn::init::DEFAULT_KAIMING_NORMAL; + let gate_expert = experts_vb.pp("gate_proj").get_with_hints( + (cfg.moe_intermediate_size, cfg.hidden_size), + "weight", + init_ws, + )?; + let up_expert = experts_vb.pp("up_proj").get_with_hints( + (cfg.moe_intermediate_size, cfg.hidden_size), + "weight", + init_ws, + )?; + let down_expert = experts_vb.pp("down_proj").get_with_hints( + (cfg.hidden_size, cfg.moe_intermediate_size), + "weight", + init_ws, + )?; + //pack gate_proj and up_proj + let gate_up_expert = Tensor::cat(&[&gate_expert, &up_expert], 0)?; + + (gate_up_expert, down_expert) + }; + + gate_up_experts.push(gate_up_expert); + down_experts.push(down_expert); + } + + let gate_up_w = Tensor::stack(&gate_up_experts, 0)?; + let down_w = Tensor::stack(&down_experts, 0)?; + // let world_size = comm.world_size(); + let w_size_n = gate_up_w.dim(1)? / 2; + + Ok(Self { + gate, + gate_up_w, + down_w, + w_size_n, + act: cfg.act, + norm_topk_prob: cfg.norm_topk_prob, + num_experts_per_tok: cfg.num_experts_per_tok, + // world_size, + dtype, + }) + } + + pub fn forward(&self, xs: &Tensor, is_prefill: bool) -> Result { + let (batch, seq_len, hidden_dim) = xs.dims3()?; + let xs = xs.reshape(((), hidden_dim))?; + let (num_tokens, hidden_dim) = xs.dims2()?; + + let router_logits = self.gate.forward(&xs)?; + + let routing_weights = + candle_nn::ops::softmax_last_dim(&router_logits.to_dtype(DType::F32)?)?; + + let topk_ids = routing_weights + .arg_sort_last_dim(false)? + .narrow(D::Minus1, 0, self.num_experts_per_tok)? 
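+            // `arg_sort_last_dim(false)` sorts each row of routing scores in descending
+            // order, so the first `num_experts_per_tok` columns are the chosen expert
+            // indices for every token.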
+ .contiguous()?; + + let mut topk_weights = routing_weights.gather(&topk_ids, D::Minus1)?; + + if self.norm_topk_prob { + topk_weights = topk_weights.broadcast_div(&topk_weights.sum_keepdim(D::Minus1)?)?; + } + + let (expert_ids, sorted_token_ids) = if is_prefill { + // For long-context (32K+), need to use custom sort kernel + // #[cfg(feature = "cuda")] + // { + // use attention_rs::sort::ArgSortOp; + // topk_ids.flatten_all()?.sort(true)? + // } + // #[cfg(not(feature = "cuda"))] + topk_ids.flatten_all()?.sort_last_dim(true)? + } else { + topk_ids.flatten_all()?.sort_last_dim(true)? + }; + + //out (M, top_k, N) + let gate_up = moe::moe_gemm( + &xs, + &self.gate_up_w, + &None, + &sorted_token_ids, + &expert_ids, + self.num_experts_per_tok, + is_prefill, + )?; + + let gate = gate_up + .narrow(candle::D::Minus1, 0, self.w_size_n)? + .contiguous()?; + let up = gate_up + .narrow(candle::D::Minus1, self.w_size_n, self.w_size_n)? + .contiguous()?; + + //(M * top_k, N // 2) + let down_inputs = (up * gate.apply(&self.act)?)?.reshape(((), self.w_size_n))?; + + //view(M, top_k, K) -> sum -> (M, K) + let ys = moe::moe_gemm( + &down_inputs, + &self.down_w, + &Some(topk_weights), + &sorted_token_ids, + &expert_ids, + self.num_experts_per_tok, + is_prefill, + )? + .reshape((num_tokens, (), hidden_dim))? + .sum(D::Minus2)?; + + Ok(ys.reshape((batch, seq_len, hidden_dim))?) + } +} + +pub struct FusedMoeGGUF { + pub gate: Linear, + pub gate_experts: Arc, + pub up_experts: Arc, + pub down_experts: Arc, + pub act: Activation, + pub norm_topk_prob: bool, + pub num_experts_per_tok: usize, + // all_reduce: AllReduce, + // world_size: usize, + pub dtype: DType, +} + +impl FusedMoeGGUF { + pub fn new( + cfg: &MoeCfg, + vb: crate::quantized_var_builder::VarBuilder, + dtype: DType, + ) -> Result { + let num_experts = cfg.num_experts; + let gate_ws = vb + .pp("ffn_gate_inp") + .get((num_experts, cfg.hidden_size), "weight")? + .dequantize_f16(&vb.device())? + .to_dtype(DType::F32)?; + + let gate = Linear::new(gate_ws, None); + + let (gate_experts, up_experts, down_experts) = { + ( + vb.pp("ffn_gate_exps").get( + (num_experts, cfg.moe_intermediate_size, cfg.hidden_size), + "weight", + )?, + vb.pp("ffn_up_exps").get( + (num_experts, cfg.moe_intermediate_size, cfg.hidden_size), + "weight", + )?, + vb.pp("ffn_down_exps").get( + (num_experts, cfg.hidden_size, cfg.moe_intermediate_size), + "weight", + )?, + ) + }; + + Ok(Self { + gate, + gate_experts, + up_experts, + down_experts, + act: cfg.act, + norm_topk_prob: cfg.norm_topk_prob, + num_experts_per_tok: cfg.num_experts_per_tok, + // all_reduce: AllReduce::new(comm), + // world_size: 1, + dtype, + }) + } + + pub fn forward(&self, xs: &Tensor, is_prefill: bool) -> Result { + let (batch, seq_len, hidden_dim) = xs.dims3()?; + let xs = xs.reshape(((), hidden_dim))?; + let (num_tokens, hidden_dim) = xs.dims2()?; + let original_dtype = xs.dtype(); + let xs = if xs.dtype() != DType::F32 { + xs.to_dtype(DType::F32)? + } else { + xs.to_owned() + }; + + let router_logits = self.gate.forward(&xs)?; + + let routing_weights = + candle_nn::ops::softmax_last_dim(&router_logits.to_dtype(DType::F32)?)?; + + let topk_ids = routing_weights + .arg_sort_last_dim(false)? + .narrow(D::Minus1, 0, self.num_experts_per_tok)? 
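+            // Same top-k selection as the unquantized path: a descending arg-sort of the
+            // routing scores followed by a narrow to `num_experts_per_tok` experts.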
+ .contiguous()?; + + let mut topk_weights = routing_weights.gather(&topk_ids, D::Minus1)?; + + if self.norm_topk_prob { + topk_weights = topk_weights.broadcast_div(&topk_weights.sum_keepdim(D::Minus1)?)?; + } + + let (expert_ids, sorted_token_ids) = if is_prefill { + // For long-context (32K+), need to use custom sort kernel + // #[cfg(feature = "cuda")] + // { + // use attention_rs::sort::ArgSortOp; + // topk_ids.flatten_all()?.sort(true)? + // } + // #[cfg(not(feature = "cuda"))] + topk_ids.flatten_all()?.sort_last_dim(true)? + } else { + topk_ids.flatten_all()?.sort_last_dim(true)? + }; + + let ys = { + let gate = moe::moe_gemm_gguf( + &xs, + &self.gate_experts, + &None, + &sorted_token_ids, + &expert_ids, + self.num_experts_per_tok, + is_prefill, + self.dtype, + )?; + let up = moe::moe_gemm_gguf( + &xs, + &self.up_experts, + &None, + &sorted_token_ids, + &expert_ids, + self.num_experts_per_tok, + is_prefill, + self.dtype, + )?; + + let down_inputs = (up * gate.apply(&self.act)?)?; + moe::moe_gemm_gguf( + &down_inputs, + &self.down_experts, + &Some(topk_weights), + &sorted_token_ids, + &expert_ids, + self.num_experts_per_tok, + is_prefill, + self.dtype, + )? + }; + let mut ys = ys.reshape((num_tokens, (), hidden_dim))?.sum(D::Minus2)?; + if ys.dtype() != original_dtype { + ys = ys.to_dtype(original_dtype)?; + } + Ok(ys.reshape((batch, seq_len, hidden_dim))?) + } +} diff --git a/candle-transformers/src/lib.rs b/candle-transformers/src/lib.rs index b2b062a9d7..bae7699a09 100644 --- a/candle-transformers/src/lib.rs +++ b/candle-transformers/src/lib.rs @@ -1,3 +1,4 @@ +pub mod fused_moe; pub mod generation; pub mod models; pub mod object_detection; diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs index e77ba4a36f..2d93833581 100644 --- a/candle-transformers/src/models/mod.rs +++ b/candle-transformers/src/models/mod.rs @@ -94,6 +94,7 @@ pub mod quantized_phi; pub mod quantized_phi3; pub mod quantized_qwen2; pub mod quantized_qwen3; +pub mod quantized_qwen3_moe; pub mod quantized_recurrent_gemma; pub mod quantized_rwkv_v5; pub mod quantized_rwkv_v6; diff --git a/candle-transformers/src/models/quantized_qwen3.rs b/candle-transformers/src/models/quantized_qwen3.rs index 5d9f414658..85ccbb0edd 100644 --- a/candle-transformers/src/models/quantized_qwen3.rs +++ b/candle-transformers/src/models/quantized_qwen3.rs @@ -14,32 +14,32 @@ use candle_nn::{kv_cache::ConcatKvCache, Activation, Embedding, Module}; use std::io::{Read, Seek}; use std::sync::Arc; -struct Gguf { +pub struct Gguf { ct: gguf_file::Content, reader: R, device: Device, } impl Gguf { - fn new(ct: gguf_file::Content, reader: R, device: Device) -> Self { + pub fn new(ct: gguf_file::Content, reader: R, device: Device) -> Self { Self { ct, reader, device } } - fn qmatmul(&mut self, name: &str) -> Result { + pub fn qmatmul(&mut self, name: &str) -> Result { let ws = self.ct.tensor(&mut self.reader, name, &self.device)?; QMatMul::from_weights(ws.into()) } - fn rms_norm(&mut self, name: &str, eps: f64) -> Result { + pub fn rms_norm(&mut self, name: &str, eps: f64) -> Result { let ws = self.ct.tensor(&mut self.reader, name, &self.device)?; RmsNorm::from_qtensor(ws, eps) } - fn metadata(&self) -> &std::collections::HashMap { + pub fn metadata(&self) -> &std::collections::HashMap { &self.ct.metadata } - fn tensor(&mut self, name: &str) -> Result { + pub fn tensor(&mut self, name: &str) -> Result { self.ct.tensor(&mut self.reader, name, &self.device) } } @@ -81,13 +81,13 @@ impl Module for MlpWeights { 
} #[derive(Debug, Clone)] -struct RotaryEmbedding { +pub struct RotaryEmbedding { sin: Tensor, cos: Tensor, } impl RotaryEmbedding { - fn new( + pub fn new( dtype: DType, head_dim: usize, max_position_embeddings: usize, @@ -113,7 +113,7 @@ impl RotaryEmbedding { } /// Apply RoPE (q, k shape: B x H x L x D) - fn apply(&self, q: &Tensor, k: &Tensor, offset: usize) -> Result<(Tensor, Tensor)> { + pub fn apply(&self, q: &Tensor, k: &Tensor, offset: usize) -> Result<(Tensor, Tensor)> { let (_, _, seq_len, _) = q.dims4()?; let cos = self.cos.narrow(0, offset, seq_len)?.to_dtype(q.dtype())?; let sin = self.sin.narrow(0, offset, seq_len)?.to_dtype(q.dtype())?; diff --git a/candle-transformers/src/models/quantized_qwen3_moe.rs b/candle-transformers/src/models/quantized_qwen3_moe.rs new file mode 100644 index 0000000000..37de42081c --- /dev/null +++ b/candle-transformers/src/models/quantized_qwen3_moe.rs @@ -0,0 +1,465 @@ +use super::quantized_qwen3::{Gguf, RotaryEmbedding}; +use super::with_tracing::QMatMul; +use crate::fused_moe::{FusedMoeGGUF, MoeCfg}; +use crate::quantized_nn::RmsNorm; +use crate::utils::repeat_kv; +use candle::quantized::gguf_file; +use candle::{DType, Device, Result, Tensor}; +use candle_nn::kv_cache::ConcatKvCache; +use candle_nn::Linear; +use candle_nn::{Embedding, Module}; +use std::sync::Arc; +#[derive(Debug, Clone)] +struct Mlp { + feed_forward_w1: QMatMul, + feed_forward_w2: QMatMul, + feed_forward_w3: QMatMul, +} + +impl Module for Mlp { + fn forward(&self, xs: &Tensor) -> Result { + let w1 = self.feed_forward_w1.forward(xs)?; + let w3 = self.feed_forward_w3.forward(xs)?; + self.feed_forward_w2 + .forward(&(candle_nn::ops::silu(&w1)? * w3)?) + } +} + +enum MoeOrMlp { + FusedMoe(FusedMoeGGUF), + Mlp(Mlp), +} + +impl MoeOrMlp { + fn forward(&self, xs: &Tensor, is_prefill: bool) -> Result { + match self { + Self::Mlp(m) => m.forward(xs), + Self::FusedMoe(m) => m.forward(xs, is_prefill), + } + } +} + +pub struct QuantizedAttention { + attention_wq: QMatMul, + attention_wk: QMatMul, + attention_wv: QMatMul, + attention_bq: Option, + attention_bk: Option, + attention_bv: Option, + attention_wo: QMatMul, + q_norm: Option, + k_norm: Option, + n_head: usize, + n_kv_head: usize, + head_dim: usize, + num_kv_groups: usize, + rotary_emb: Arc, + dtype: DType, + kv_cache: ConcatKvCache, +} + +impl QuantizedAttention { + pub fn new( + gg: &mut Gguf, + prefix: &str, + dtype: DType, + num_heads: usize, + num_kv_heads: usize, + head_dim: usize, + rms_norm_eps: f64, + device: &Device, + rotary_emb: Arc, + ) -> Result { + let num_kv_groups = num_heads / num_kv_heads; + let attention_wq = gg.qmatmul(&format!("{prefix}.attn_q.weight"))?; + let attention_wk = gg.qmatmul(&format!("{prefix}.attn_k.weight"))?; + let attention_wv = gg.qmatmul(&format!("{prefix}.attn_v.weight"))?; + + let attention_bq = gg.tensor(&format!("{prefix}.attn_q.bias")); + let attention_bk = gg.tensor(&format!("{prefix}.attn_k.bias")); + let attention_bv = gg.tensor(&format!("{prefix}.attn_v.bias")); + + let attention_bq = if attention_bq.is_ok() { + Some( + attention_bq + .unwrap() + .dequantize(device)? + .to_dtype(DType::F32)?, + ) + } else { + None + }; + + let attention_bk = if attention_bk.is_ok() { + Some( + attention_bk + .unwrap() + .dequantize(device)? + .to_dtype(DType::F32)?, + ) + } else { + None + }; + + let attention_bv = if attention_bv.is_ok() { + Some( + attention_bv + .unwrap() + .dequantize(device)? 
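+                    // QKV bias tensors are optional in these GGUF exports (Qwen3
+                    // checkpoints usually omit them); when present they are dequantized
+                    // once here and broadcast-added after the projections.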
+ .to_dtype(DType::F32)?, + ) + } else { + None + }; + + let attention_wo = gg.qmatmul(&format!("{prefix}.attn_output.weight"))?; + let q_norm = Some(gg.rms_norm(&format!("{prefix}.attn_q_norm.weight"), rms_norm_eps)?); + let k_norm = Some(gg.rms_norm(&format!("{prefix}.attn_k_norm.weight"), rms_norm_eps)?); + let kv_cache = ConcatKvCache::new(2); + Ok(QuantizedAttention { + attention_wq, + attention_wk, + attention_wv, + attention_bq, + attention_bk, + attention_bv, + attention_wo, + q_norm, + k_norm, + n_head: num_heads, + n_kv_head: num_kv_heads, + head_dim, + num_kv_groups, + rotary_emb: rotary_emb.clone(), + dtype, + kv_cache, + }) + } + + pub fn forward( + &mut self, + x: &Tensor, + mask: Option<&Tensor>, + input_pos: usize, + ) -> Result { + let (b, seq_len, _) = x.dims3()?; + let in_dtype = x.dtype(); + let q = self.attention_wq.forward(x)?; + let k = self.attention_wk.forward(x)?; + let v = self.attention_wv.forward(x)?; + + let q = if self.attention_bq.is_some() { + q.broadcast_add(self.attention_bq.as_ref().unwrap())? + } else { + q + }; + + let k = if self.attention_bk.is_some() { + k.broadcast_add(self.attention_bk.as_ref().unwrap())? + } else { + k + }; + + let v = if self.attention_bv.is_some() { + v.broadcast_add(self.attention_bv.as_ref().unwrap())? + } else { + v + }; + + let q = q + .reshape((1, seq_len, self.n_head, self.head_dim))? + .transpose(1, 2)? + .contiguous()?; + let k = k + .reshape((1, seq_len, self.n_kv_head, self.head_dim))? + .transpose(1, 2)? + .contiguous()?; + let v = v + .reshape((1, seq_len, self.n_kv_head, self.head_dim))? + .transpose(1, 2)? + .contiguous()?; + + let (q, k) = if let (Some(q_norm), Some(k_norm)) = (&self.q_norm, &self.k_norm) { + // Per‑head RMSNorm in qwen3 + let q_flat = q.flatten(0, 2)?; // (B*H, L, D) -> (BHL, D) after transpose later + let k_flat = k.flatten(0, 2)?; + + // q_norm and k_norm weights stored in f32 format in qwen3 gguf + let q_flat = q_norm.forward(&q_flat)?; + let k_flat = k_norm.forward(&k_flat)?; + + let q = q_flat.reshape((1, self.n_head, seq_len, self.head_dim))?; + let k = k_flat.reshape((1, self.n_kv_head, seq_len, self.head_dim))?; + + (q, k) + } else { + (q, k) + }; + + let (q, k, v) = ( + q.to_dtype(self.dtype)?, + k.to_dtype(self.dtype)?, + v.to_dtype(self.dtype)?, + ); + + let (q, k) = self.rotary_emb.apply(&q, &k, input_pos)?; + + let (k, v) = self.kv_cache.append(&k, &v)?; + + let k = repeat_kv(k, self.num_kv_groups)?.contiguous()?; + let v = repeat_kv(v, self.num_kv_groups)?.contiguous()?; + + let scale = 1.0 / (self.head_dim as f64).sqrt(); + let mut scores = (q.matmul(&k.transpose(2, 3)?)? * scale)?; + + if let Some(m) = mask { + let m_dtype = m.dtype(); + let scores_dtype = scores.dtype(); + let mask = if m_dtype != scores_dtype { + m.to_dtype(scores_dtype)? + } else { + m.clone() + }; + scores = scores.broadcast_add(&mask)?; + } + + let probs = candle_nn::ops::softmax_last_dim(&scores)?; + let ctx = probs.matmul(&v)?; // (B, H, L, D) + let reshaped_ctx = + ctx.transpose(1, 2)? + .reshape((b, seq_len, self.n_head * self.head_dim))?; + + self.attention_wo.forward(&reshaped_ctx.to_dtype(in_dtype)?) 
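+        // (The attention math above ran in `self.dtype` (f16/bf16); the context was cast
+        // back to the original activation dtype just before the quantized output projection.)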
+ } +} + +struct LayerWeights { + self_attn: QuantizedAttention, + attention_norm: RmsNorm, + mlp: MoeOrMlp, + ffn_norm: RmsNorm, +} + +impl LayerWeights { + fn forward_attn(&mut self, x: &Tensor, mask: Option<&Tensor>, offset: usize) -> Result { + self.self_attn.forward(x, mask, offset) + } +} + +pub struct GGUFQWenMoE { + tok_embeddings: Embedding, + layers: Vec, + norm: RmsNorm, + output: QMatMul, + dtype: DType, + device: Device, +} + +impl GGUFQWenMoE { + pub fn from_gguf( + ct: gguf_file::Content, + reader: &mut R, + device: &Device, + dtype: DType, + ) -> Result { + let mut gg = Gguf::new(ct, reader, device.clone()); + let md_get = |s: &str| match gg.metadata().get(s) { + None => candle::bail!("cannot find {s} in metadata"), + Some(v) => Ok(v), + }; + let arch = md_get("general.architecture")?.to_string()?; + + let head_count = + md_get(format!("{arch}.attention.head_count").as_str())?.to_u32()? as usize; + let head_count_kv = + md_get(format!("{arch}.attention.head_count_kv").as_str())?.to_u32()? as usize; + + let head_dim = md_get(format!("{arch}.attention.key_length").as_str()); + let embedding_length = + md_get(format!("{arch}.embedding_length").as_str())?.to_u32()? as usize; + let head_dim = if head_dim.is_ok() { + head_dim.unwrap().to_u32()? as usize + } else { + embedding_length / head_count + }; + let context_length = md_get(format!("{arch}.context_length").as_str())?.to_u32()? as usize; + let block_count = md_get(format!("{arch}.block_count").as_str())?.to_u32()? as usize; + let rms_norm_eps = + md_get(format!("{arch}.attention.layer_norm_rms_epsilon").as_str())?.to_f32()? as f64; + let rope_freq_base = md_get(format!("{arch}.rope.freq_base").as_str()) + .and_then(|m| m.to_f32()) + .unwrap_or(10000f32); + let expert_shared_feed_forward_length = + md_get(format!("{arch}.expert_shared_feed_forward_length").as_str()); + let shared_expert_intermediate_size = match expert_shared_feed_forward_length { + Ok(length) => { + if length.to_u32()? > 0 { + Some(length.to_u32()? as usize) + } else { + None + } + } + _ => None, + }; + + let moe_cfg = MoeCfg { + moe_intermediate_size: md_get(format!("{arch}.expert_feed_forward_length").as_str())? + .to_u32()? as usize, + num_experts: md_get(format!("{arch}.expert_count").as_str())?.to_u32()? as usize, + norm_topk_prob: shared_expert_intermediate_size.is_none(), + num_experts_per_tok: md_get(format!("{arch}.expert_used_count").as_str())?.to_u32()? + as usize, + hidden_size: head_dim, + act: candle_nn::Activation::Silu, + decoder_sparse_step: None, + }; + + let tok_embeddings = gg.tensor("token_embd.weight")?; + let tok_embeddings = tok_embeddings.dequantize(device)?; + let norm = gg.rms_norm("output_norm.weight", rms_norm_eps)?; + let output = match gg.qmatmul("output.weight") { + Ok(v) => v, + _ => { + // use tie_word_embeddings + gg.qmatmul("token_embd.weight")? + } + }; + + let rotary_emb = Arc::new(RotaryEmbedding::new( + dtype, + head_dim, + context_length, + rope_freq_base as f64, + device, + )?); + let mut layers = Vec::with_capacity(block_count); + for layer_idx in 0..block_count { + let prefix = format!("blk.{layer_idx}"); + let mlp = if moe_cfg.num_experts > 0 + && (layer_idx + 1) % moe_cfg.decoder_sparse_step.unwrap_or(1) == 0 + { + let gate_ws = gg + .tensor(&format!("{prefix}.ffn_gate_inp.weight"))? + .dequantize(&device)? 
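+                    // The router (`ffn_gate_inp`) is small, so it is dequantized and run
+                    // as a dense f32 Linear; the large expert tensors below stay quantized
+                    // and are consumed directly by the fused GGUF MoE kernels.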
+ .to_dtype(DType::F32)?; + let gate = Linear::new(gate_ws, None); + let gate_experts = Arc::new(gg.tensor(&format!("{prefix}.ffn_gate_exps.weight"))?); + let up_experts = Arc::new(gg.tensor(&format!("{prefix}.ffn_up_exps.weight"))?); + let down_experts = Arc::new(gg.tensor(&format!("{prefix}.ffn_down_exps.weight"))?); + let moe = FusedMoeGGUF { + gate, + gate_experts, + up_experts, + down_experts, + act: candle_nn::Activation::Silu, + norm_topk_prob: moe_cfg.norm_topk_prob, + num_experts_per_tok: moe_cfg.num_experts_per_tok, + dtype, + }; + + MoeOrMlp::FusedMoe(moe) + } else { + let mlp = { + let feed_forward_w1 = gg.qmatmul(&format!("{prefix}.ffn_gate.weight"))?; + let feed_forward_w2 = gg.qmatmul(&format!("{prefix}.ffn_down.weight"))?; + let feed_forward_w3 = gg.qmatmul(&format!("{prefix}.ffn_up.weight"))?; + Mlp { + feed_forward_w1, + feed_forward_w2, + feed_forward_w3, + } + }; + MoeOrMlp::Mlp(mlp) + }; + + let attention_norm = + gg.rms_norm(&format!("{prefix}.attn_norm.weight"), rms_norm_eps)?; + let ffn_norm = gg.rms_norm(&format!("{prefix}.ffn_norm.weight"), rms_norm_eps)?; + + let self_attn = QuantizedAttention::new( + &mut gg, + &prefix, + dtype, + head_count, + head_count_kv, + head_dim, + rms_norm_eps, + device, + rotary_emb.clone(), + )?; + layers.push(LayerWeights { + self_attn, + attention_norm, + mlp, + ffn_norm, + }); + } + + Ok(Self { + tok_embeddings: Embedding::new(tok_embeddings, embedding_length), + layers, + norm, + output, + dtype, + device: device.clone(), + }) + } + + fn causal_mask( + &self, + b: usize, + tgt: usize, + offset: usize, + sw: Option, + ) -> Result { + let minf = f32::NEG_INFINITY; + let mask: Vec<_> = (0..tgt) + .flat_map(|i| { + (0..(tgt + offset)).map(move |j| { + let past_ok = j <= i + offset; + let sw_ok = match sw { + Some(w) => (i + offset) as i64 - j as i64 <= w as i64, + None => true, + }; + if past_ok && sw_ok { + 0. + } else { + minf + } + }) + }) + .collect(); + Tensor::from_slice(&mask, (b, 1, tgt, tgt + offset), &self.device)?.to_dtype(self.dtype) + } + + pub fn forward(&mut self, x: &Tensor, offset: usize) -> Result { + let mut xs = self.tok_embeddings.forward(x)?; + let (b, l) = x.dims2()?; + + let causal_mask = if l == 1 { + None + } else { + Some(self.causal_mask(b, l, offset, None)?) 
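+            // A mask is only built when more than one query token is present; its
+            // presence also doubles as the `is_prefill` flag passed to the MoE layers.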
+ }; + + for layer in self.layers.iter_mut() { + let x = xs; + let residual = &x; + + let x = layer.attention_norm.forward(&x)?; + let attn = layer.forward_attn(&x, causal_mask.as_ref(), offset)?; + let x = (attn + residual)?; + + // MLP + let residual = &x; + let x = layer.ffn_norm.forward(&x)?; + let x = layer.mlp.forward(&x, causal_mask.is_some())?; + let x = (x + residual)?; + xs = x + } + + let xs = xs.narrow(1, l - 1, 1)?; + let xs = self.norm.forward(&xs)?; + self.output.forward(&xs)?.to_dtype(DType::F32)?.squeeze(1) + } +} diff --git a/candle-transformers/src/models/qwen3_moe.rs b/candle-transformers/src/models/qwen3_moe.rs index b76ce92de4..0576b4c075 100644 --- a/candle-transformers/src/models/qwen3_moe.rs +++ b/candle-transformers/src/models/qwen3_moe.rs @@ -1,6 +1,9 @@ -use crate::models::{ - qwen3::{Config as Qwen3Config, Qwen3Attention, Qwen3MLP, Qwen3RotaryEmbedding}, - with_tracing::{linear_no_bias, Linear, RmsNorm}, +use crate::{ + fused_moe::{FusedMoe, MoeCfg}, + models::{ + qwen3::{Config as Qwen3Config, Qwen3Attention, Qwen3MLP, Qwen3RotaryEmbedding}, + with_tracing::{linear_no_bias, Linear, RmsNorm}, + }, }; use candle::{DType, Device, Module, Result, Tensor, D}; use candle_nn::{Activation, VarBuilder}; @@ -176,14 +179,16 @@ impl Module for Qwen3SparseMoeBlock { #[derive(Debug, Clone)] enum Qwen3FeedForward { Mlp(Qwen3MLP), - MoE(Qwen3SparseMoeBlock), + NaiveMoE(Qwen3SparseMoeBlock), + FusedMoE(FusedMoe), } -impl Module for Qwen3FeedForward { - fn forward(&self, xs: &Tensor) -> Result { +impl Qwen3FeedForward { + fn forward(&self, xs: &Tensor, is_prefill: bool) -> Result { match self { Self::Mlp(m) => m.forward(xs), - Self::MoE(m) => m.forward(xs), + Self::NaiveMoE(m) => m.forward(xs), + Self::FusedMoE(m) => m.forward(xs, is_prefill), } } } @@ -205,10 +210,24 @@ impl DecoderLayer { ) -> Result { let self_attn = Qwen3Attention::new(&cfg.into(), rotary, vb.pp("self_attn"))?; + let moe_cfg = MoeCfg { + hidden_size: cfg.hidden_size, + num_experts: cfg.num_experts, + num_experts_per_tok: cfg.num_experts_per_tok, + moe_intermediate_size: cfg.moe_intermediate_size, + norm_topk_prob: cfg.norm_topk_prob, + act: cfg.hidden_act, + decoder_sparse_step: None, + }; // Decide whether to use MoE or regular MLP based on layer_idx and decoder_sparse_step let feed_forward = if cfg.num_experts > 0 && (layer_idx + 1).is_multiple_of(cfg.decoder_sparse_step) { - Qwen3FeedForward::MoE(Qwen3SparseMoeBlock::new(cfg, vb.pp("mlp"))?) + if cfg!(feature = "cuda") { + // Use fused MoE kernel on CUDA + Qwen3FeedForward::FusedMoE(FusedMoe::new(&moe_cfg, vb.pp("mlp"), vb.dtype())?) + } else { + Qwen3FeedForward::NaiveMoE(Qwen3SparseMoeBlock::new(cfg, vb.pp("mlp"))?) + } } else { Qwen3FeedForward::Mlp(Qwen3MLP::new(&cfg.into(), vb.pp("mlp"))?) 
}; @@ -233,7 +252,7 @@ impl DecoderLayer { let h = self.self_attn.forward(&h, mask, offset)?; let x = (x + h)?; let h2 = self.ln2.forward(&x)?; - let h2 = h2.apply(&self.feed_forward)?; + let h2 = self.feed_forward.forward(&h2, mask.is_some())?; x + h2 } From 9e8e9f0b1ed457735a133f661bf8d3f6972fb540 Mon Sep 17 00:00:00 2001 From: Guoqing Bao Date: Wed, 3 Dec 2025 02:33:50 +0000 Subject: [PATCH 2/3] Typo and cargo clippy fix --- candle-core/benches/benchmarks/mod.rs | 4 ++-- candle-core/benches/benchmarks/qmatmul.rs | 2 +- candle-core/benches/benchmarks/unary.rs | 2 +- candle-core/src/quantized/imatrix_file.rs | 2 +- candle-examples/examples/qwen/main.rs | 2 +- candle-nn/benches/benchmarks/mod.rs | 4 ++-- candle-nn/src/moe.rs | 11 ++++++----- candle-transformers/src/fused_moe.rs | 8 ++++---- candle-transformers/src/models/quantized_qwen3_moe.rs | 3 ++- 9 files changed, 20 insertions(+), 18 deletions(-) diff --git a/candle-core/benches/benchmarks/mod.rs b/candle-core/benches/benchmarks/mod.rs index bc98eb2ff8..167c29eb7f 100644 --- a/candle-core/benches/benchmarks/mod.rs +++ b/candle-core/benches/benchmarks/mod.rs @@ -28,13 +28,13 @@ impl BenchDevice for Device { return Ok(device.synchronize()?); } #[cfg(not(feature = "cuda"))] - panic!("Cuda device without cuda feature enabled: {:?}", device) + panic!("Cuda device without cuda feature enabled: {device:?}") } Device::Metal(device) => { #[cfg(feature = "metal")] return device.wait_until_completed(); #[cfg(not(feature = "metal"))] - panic!("Metal device without metal feature enabled: {:?}", device) + panic!("Metal device without metal feature enabled: {device:?}") } } } diff --git a/candle-core/benches/benchmarks/qmatmul.rs b/candle-core/benches/benchmarks/qmatmul.rs index 6b46fb83e9..be1d2ad021 100644 --- a/candle-core/benches/benchmarks/qmatmul.rs +++ b/candle-core/benches/benchmarks/qmatmul.rs @@ -32,7 +32,7 @@ fn run_bench(c: &mut Criterion, device: &Device, dtype: GgmlDType) { let flops = b * m * n * k; - let mut group = c.benchmark_group(device.bench_name(format!("qmatmul_{:?}", dtype))); + let mut group = c.benchmark_group(device.bench_name(format!("qmatmul_{dtype:?}"))); group.sample_size(200); group.throughput(Throughput::Bytes(flops as u64)); group.bench_function("iter", move |b| { diff --git a/candle-core/benches/benchmarks/unary.rs b/candle-core/benches/benchmarks/unary.rs index 145878f206..1072a7fc5a 100644 --- a/candle-core/benches/benchmarks/unary.rs +++ b/candle-core/benches/benchmarks/unary.rs @@ -41,7 +41,7 @@ fn criterion_benchmark(c: &mut Criterion) { let handler = BenchDeviceHandler::new().unwrap(); for device in handler.devices { for dtype in [DType::F32, DType::BF16, DType::F16] { - let name = format!("sqrt_{:?}", dtype); + let name = format!("sqrt_{dtype:?}"); run_unary_benchmark(c, &device, dtype, &name); } } diff --git a/candle-core/src/quantized/imatrix_file.rs b/candle-core/src/quantized/imatrix_file.rs index db434f7f3e..ed228b74ce 100644 --- a/candle-core/src/quantized/imatrix_file.rs +++ b/candle-core/src/quantized/imatrix_file.rs @@ -30,7 +30,7 @@ pub fn load_imatrix>(fname: P) -> Result let n_entries = cursor .read_i32::() - .map_err(|e| crate::Error::msg(format!("Failed to read number of entries: {}", e)))? + .map_err(|e| crate::Error::msg(format!("Failed to read number of entries: {e}")))? 
as usize; if n_entries < 1 { diff --git a/candle-examples/examples/qwen/main.rs b/candle-examples/examples/qwen/main.rs index 4c6a1e76d6..f6765411c1 100644 --- a/candle-examples/examples/qwen/main.rs +++ b/candle-examples/examples/qwen/main.rs @@ -308,7 +308,7 @@ fn main() -> Result<()> { { candle_examples::hub_load_local_safetensors(path, "model.safetensors.index.json")? } else { - vec!["model.safetensors".into()].into() + vec!["model.safetensors".into()] } } None => match args.model { diff --git a/candle-nn/benches/benchmarks/mod.rs b/candle-nn/benches/benchmarks/mod.rs index 2bff47cb12..fcc83b8e3f 100644 --- a/candle-nn/benches/benchmarks/mod.rs +++ b/candle-nn/benches/benchmarks/mod.rs @@ -21,13 +21,13 @@ impl BenchDevice for Device { return Ok(device.synchronize()?); } #[cfg(not(feature = "cuda"))] - panic!("Cuda device without cuda feature enabled: {:?}", device) + panic!("Cuda device without cuda feature enabled: {device:?}") } Device::Metal(device) => { #[cfg(feature = "metal")] return device.wait_until_completed(); #[cfg(not(feature = "metal"))] - panic!("Metal device without metal feature enabled: {:?}", device) + panic!("Metal device without metal feature enabled: {device:?}") } } } diff --git a/candle-nn/src/moe.rs b/candle-nn/src/moe.rs index 4f62af8f66..2f2bd9c2db 100644 --- a/candle-nn/src/moe.rs +++ b/candle-nn/src/moe.rs @@ -46,7 +46,7 @@ pub fn moe_gemm( DType::F16 => 0, DType::BF16 => 1, _ => { - candle::bail!("moe_gemm_wmma only accept f16/bf16 inputs!") + candle::bail!("moe_gemm_wmma only accepts f16/bf16 inputs") } }; @@ -146,7 +146,7 @@ pub fn moe_gemm( is_prefill, ), _ => { - candle::bail!("moe_gemm only accept f16/bf16 inputs!") + candle::bail!("moe_gemm only accepts f16/bf16 inputs") } } } @@ -161,7 +161,7 @@ pub fn moe_gemm( _: usize, _: bool, ) -> Result { - candle::bail!("moe_gemm is not implemented on this platform!") + candle::bail!("moe_gemm is only implemented for the cuda backend") } #[cfg(feature = "cuda")] @@ -329,12 +329,13 @@ pub fn moe_gemm_gguf( dtype, ), _ => { - candle::bail!("moe_gemm_gguf only accept f16/bf16 inputs!") + candle::bail!("moe_gemm_gguf only accepts f32 inputs") } } } #[cfg(not(feature = "cuda"))] +#[allow(clippy::too_many_arguments)] pub fn moe_gemm_gguf( _: &Tensor, _: &QTensor, @@ -345,5 +346,5 @@ pub fn moe_gemm_gguf( _: bool, _: candle::DType, ) -> Result { - candle::bail!("moe_gemm_gguf is not implemented on this platform!") + candle::bail!("moe_gemm_gguf is only implemented for the cuda backend") } diff --git a/candle-transformers/src/fused_moe.rs b/candle-transformers/src/fused_moe.rs index 91eb5f217b..da2c6cf912 100644 --- a/candle-transformers/src/fused_moe.rs +++ b/candle-transformers/src/fused_moe.rs @@ -40,7 +40,7 @@ impl FusedMoe { //pack experts for i in 0..num_experts { - let experts_vb = experts_vb.pp(format!("{}", i).as_str()); + let experts_vb = experts_vb.pp(format!("{i}").as_str()); let (gate_up_expert, down_expert) = { // n x k format @@ -156,7 +156,7 @@ impl FusedMoe { .reshape((num_tokens, (), hidden_dim))? .sum(D::Minus2)?; - Ok(ys.reshape((batch, seq_len, hidden_dim))?) + ys.reshape((batch, seq_len, hidden_dim)) } } @@ -183,7 +183,7 @@ impl FusedMoeGGUF { let gate_ws = vb .pp("ffn_gate_inp") .get((num_experts, cfg.hidden_size), "weight")? - .dequantize_f16(&vb.device())? + .dequantize(vb.device())? 
.to_dtype(DType::F32)?; let gate = Linear::new(gate_ws, None); @@ -297,6 +297,6 @@ impl FusedMoeGGUF { if ys.dtype() != original_dtype { ys = ys.to_dtype(original_dtype)?; } - Ok(ys.reshape((batch, seq_len, hidden_dim))?) + ys.reshape((batch, seq_len, hidden_dim)) } } diff --git a/candle-transformers/src/models/quantized_qwen3_moe.rs b/candle-transformers/src/models/quantized_qwen3_moe.rs index 37de42081c..0e2749be90 100644 --- a/candle-transformers/src/models/quantized_qwen3_moe.rs +++ b/candle-transformers/src/models/quantized_qwen3_moe.rs @@ -59,6 +59,7 @@ pub struct QuantizedAttention { } impl QuantizedAttention { + #[allow(clippy::too_many_arguments)] pub fn new( gg: &mut Gguf, prefix: &str, @@ -340,7 +341,7 @@ impl GGUFQWenMoE { { let gate_ws = gg .tensor(&format!("{prefix}.ffn_gate_inp.weight"))? - .dequantize(&device)? + .dequantize(device)? .to_dtype(DType::F32)?; let gate = Linear::new(gate_ws, None); let gate_experts = Arc::new(gg.tensor(&format!("{prefix}.ffn_gate_exps.weight"))?); From 0a3988fa439268101596b0a81149d5c8226da1e4 Mon Sep 17 00:00:00 2001 From: Guoqing Bao Date: Thu, 18 Dec 2025 10:14:09 +0000 Subject: [PATCH 3/3] Clippy fix --- candle-core/benches/benchmarks/binary.rs | 2 +- candle-core/benches/benchmarks/unary.rs | 2 +- .../src/models/quantized_qwen3_moe.rs | 31 +++++-------------- 3 files changed, 10 insertions(+), 25 deletions(-) diff --git a/candle-core/benches/benchmarks/binary.rs b/candle-core/benches/benchmarks/binary.rs index 46e2cf7f7f..a4953c6332 100644 --- a/candle-core/benches/benchmarks/binary.rs +++ b/candle-core/benches/benchmarks/binary.rs @@ -48,7 +48,7 @@ fn criterion_benchmark(c: &mut Criterion) { let handler = BenchDeviceHandler::new().unwrap(); for device in handler.devices { for dtype in [DType::F32, DType::BF16, DType::F16] { - let name = format!("binary_mul_{:?}", dtype); + let name = format!("binary_mul_{dtype:?}"); run_unary_benchmark(c, &device, dtype, &name); } } diff --git a/candle-core/benches/benchmarks/unary.rs b/candle-core/benches/benchmarks/unary.rs index 287e2341f7..65723bb3fd 100644 --- a/candle-core/benches/benchmarks/unary.rs +++ b/candle-core/benches/benchmarks/unary.rs @@ -89,7 +89,7 @@ fn criterion_benchmark(c: &mut Criterion) { run_cast_benchmark(c, &device, dtype, to_dtype, &name); } for dtype in [DType::F32, DType::BF16, DType::F16] { - let name = format!("sqrt_{:?}", dtype); + let name = format!("sqrt_{dtype:?}"); run_unary_benchmark(c, &device, dtype, &name); } } diff --git a/candle-transformers/src/models/quantized_qwen3_moe.rs b/candle-transformers/src/models/quantized_qwen3_moe.rs index 0e2749be90..2daa84e062 100644 --- a/candle-transformers/src/models/quantized_qwen3_moe.rs +++ b/candle-transformers/src/models/quantized_qwen3_moe.rs @@ -80,35 +80,20 @@ impl QuantizedAttention { let attention_bk = gg.tensor(&format!("{prefix}.attn_k.bias")); let attention_bv = gg.tensor(&format!("{prefix}.attn_v.bias")); - let attention_bq = if attention_bq.is_ok() { - Some( - attention_bq - .unwrap() - .dequantize(device)? - .to_dtype(DType::F32)?, - ) + let attention_bq = if let Ok(attention_bq) = attention_bq { + Some(attention_bq.dequantize(device)?.to_dtype(DType::F32)?) } else { None }; - let attention_bk = if attention_bk.is_ok() { - Some( - attention_bk - .unwrap() - .dequantize(device)? - .to_dtype(DType::F32)?, - ) + let attention_bk = if let Ok(attention_bk) = attention_bk { + Some(attention_bk.dequantize(device)?.to_dtype(DType::F32)?) 
} else { None }; - let attention_bv = if attention_bv.is_ok() { - Some( - attention_bv - .unwrap() - .dequantize(device)? - .to_dtype(DType::F32)?, - ) + let attention_bv = if let Ok(attention_bv) = attention_bv { + Some(attention_bv.dequantize(device)?.to_dtype(DType::F32)?) } else { None }; @@ -278,8 +263,8 @@ impl GGUFQWenMoE { let head_dim = md_get(format!("{arch}.attention.key_length").as_str()); let embedding_length = md_get(format!("{arch}.embedding_length").as_str())?.to_u32()? as usize; - let head_dim = if head_dim.is_ok() { - head_dim.unwrap().to_u32()? as usize + let head_dim = if let Ok(head_dim) = head_dim { + head_dim.to_u32()? as usize } else { embedding_length / head_count };
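For reference, below is a minimal sketch of how the new `GGUFQWenMoE` loader might be driven from Rust code, assuming the `cuda` feature is enabled and a GGUF checkpoint is available locally (using `anyhow` for error handling, as the candle examples do). The file path and token ids are placeholders, and tokenization/sampling are omitted; see the quantized-qwen3-moe example for the full pipeline.

```rust
use candle::quantized::gguf_file;
use candle::{DType, Device, Tensor};
use candle_transformers::models::quantized_qwen3_moe::GGUFQWenMoE;

fn main() -> anyhow::Result<()> {
    let device = Device::new_cuda(0)?;
    // Placeholder path: point this at a Qwen3 MoE GGUF checkpoint.
    let mut file = std::fs::File::open("/path/to/qwen3-moe.gguf")?;
    let content = gguf_file::Content::read(&mut file)?;
    // Attention and the fused MoE kernels run in f16 on CUDA.
    let mut model = GGUFQWenMoE::from_gguf(content, &mut file, &device, DType::F16)?;

    // Pre-tokenized prompt ids (placeholder values); `offset` is the number of tokens
    // already held in the KV cache, so it is 0 for the initial prefill step.
    let prompt_ids: Vec<u32> = vec![1, 2, 3];
    let input = Tensor::new(prompt_ids.as_slice(), &device)?.unsqueeze(0)?;
    let logits = model.forward(&input, 0)?; // logits for the last position, shape (1, vocab)
    println!("{:?}", logits.dims());
    Ok(())
}
```

Note that the `is_prefill` flag flowing into the MoE layers is simply `causal_mask.is_some()`, i.e. whether the step processes more than one token; inside `candle_nn::moe` it selects the prefill entry point (`moe_gemm_gguf_prefill`, f16/bf16 activations) over the single-token decode kernel (`moe_gemm_gguf`, f32 activations).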