diff --git a/README.md b/README.md
index 632afdd782..4a62a27593 100644
--- a/README.md
+++ b/README.md
@@ -92,6 +92,7 @@ We also provide some command line based examples using state of the art models:
- [Quantized LLaMA](./candle-examples/examples/quantized/): quantized version of
the LLaMA model using the same quantization techniques as
[llama.cpp](https://github.com/ggerganov/llama.cpp).
+- [Quantized Qwen3 MoE](./candle-examples/examples/quantized-qwen3-moe/): GGUF quantized versions of the Qwen3 MoE models.
@@ -190,6 +191,7 @@ And then head over to
- [`candle-einops`](https://github.com/tomsanbear/candle-einops): A pure rust implementation of the python [einops](https://github.com/arogozhnikov/einops) library.
- [`atoma-infer`](https://github.com/atoma-network/atoma-infer): A Rust library for fast inference at scale, leveraging FlashAttention2 for efficient attention computation, PagedAttention for efficient KV-cache memory management, and multi-GPU support. It is OpenAI api compatible.
- [`llms-from-scratch-rs`](https://github.com/nerdai/llms-from-scratch-rs): A comprehensive Rust translation of the code from Sebastian Raschka's Build an LLM from Scratch book.
+- [`vllm.rs`](https://github.com/guoqingbao/vllm.rs): A minimalist vLLM implementation in Rust based on Candle.
If you have an addition to this list, please submit a pull request.
@@ -220,7 +222,7 @@ If you have an addition to this list, please submit a pull request.
- Replit-code-v1.5-3B.
- Bert.
- Yi-6B and Yi-34B.
- - Qwen1.5, Qwen1.5 MoE.
+ - Qwen1.5, Qwen1.5 MoE, Qwen3 MoE.
- RWKV v5 and v6.
- Quantized LLMs.
- Llama 7b, 13b, 70b, as well as the chat and code variants.
@@ -228,6 +230,7 @@ If you have an addition to this list, please submit a pull request.
- Mixtral 8x7b.
- Zephyr 7b a and b (Mistral-7b based).
- OpenChat 3.5 (Mistral-7b based).
+ - Qwen3 MoE (16B-A3B, 32B-A3B).
- Text to text.
- T5 and its variants: FlanT5, UL2, MADLAD400 (translation), CoEdit (Grammar correction).
- Marian MT (Machine Translation).
diff --git a/candle-core/benches/benchmarks/binary.rs b/candle-core/benches/benchmarks/binary.rs
index 46e2cf7f7f..a4953c6332 100644
--- a/candle-core/benches/benchmarks/binary.rs
+++ b/candle-core/benches/benchmarks/binary.rs
@@ -48,7 +48,7 @@ fn criterion_benchmark(c: &mut Criterion) {
let handler = BenchDeviceHandler::new().unwrap();
for device in handler.devices {
for dtype in [DType::F32, DType::BF16, DType::F16] {
- let name = format!("binary_mul_{:?}", dtype);
+ let name = format!("binary_mul_{dtype:?}");
run_unary_benchmark(c, &device, dtype, &name);
}
}
diff --git a/candle-core/benches/benchmarks/mod.rs b/candle-core/benches/benchmarks/mod.rs
index 3b45a83e5f..9cc6767a4d 100644
--- a/candle-core/benches/benchmarks/mod.rs
+++ b/candle-core/benches/benchmarks/mod.rs
@@ -29,13 +29,13 @@ impl BenchDevice for Device {
return Ok(device.synchronize()?);
}
#[cfg(not(feature = "cuda"))]
- panic!("Cuda device without cuda feature enabled: {:?}", device)
+ panic!("Cuda device without cuda feature enabled: {device:?}")
}
Device::Metal(device) => {
#[cfg(feature = "metal")]
return device.wait_until_completed();
#[cfg(not(feature = "metal"))]
- panic!("Metal device without metal feature enabled: {:?}", device)
+ panic!("Metal device without metal feature enabled: {device:?}")
}
}
}
diff --git a/candle-core/benches/benchmarks/qmatmul.rs b/candle-core/benches/benchmarks/qmatmul.rs
index 6b46fb83e9..be1d2ad021 100644
--- a/candle-core/benches/benchmarks/qmatmul.rs
+++ b/candle-core/benches/benchmarks/qmatmul.rs
@@ -32,7 +32,7 @@ fn run_bench(c: &mut Criterion, device: &Device, dtype: GgmlDType) {
let flops = b * m * n * k;
- let mut group = c.benchmark_group(device.bench_name(format!("qmatmul_{:?}", dtype)));
+ let mut group = c.benchmark_group(device.bench_name(format!("qmatmul_{dtype:?}")));
group.sample_size(200);
group.throughput(Throughput::Bytes(flops as u64));
group.bench_function("iter", move |b| {
diff --git a/candle-core/benches/benchmarks/unary.rs b/candle-core/benches/benchmarks/unary.rs
index 287e2341f7..65723bb3fd 100644
--- a/candle-core/benches/benchmarks/unary.rs
+++ b/candle-core/benches/benchmarks/unary.rs
@@ -89,7 +89,7 @@ fn criterion_benchmark(c: &mut Criterion) {
run_cast_benchmark(c, &device, dtype, to_dtype, &name);
}
for dtype in [DType::F32, DType::BF16, DType::F16] {
- let name = format!("sqrt_{:?}", dtype);
+ let name = format!("sqrt_{dtype:?}");
run_unary_benchmark(c, &device, dtype, &name);
}
}
diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs
index a4d5d6cb97..3c3ffb1097 100644
--- a/candle-core/src/op.rs
+++ b/candle-core/src/op.rs
@@ -1031,7 +1031,7 @@ impl UnaryOpT for Relu {
pub struct BackpropOp(Option<Op>);
impl BackpropOp {
- pub(crate) fn none() -> Self {
+ pub fn none() -> Self {
BackpropOp(None)
}
diff --git a/candle-core/src/quantized/cuda.rs b/candle-core/src/quantized/cuda.rs
index 3faf9f695f..563c9ce3db 100644
--- a/candle-core/src/quantized/cuda.rs
+++ b/candle-core/src/quantized/cuda.rs
@@ -742,6 +742,11 @@ impl QCudaStorage {
.memcpy_dtoh(&self.data.inner.slice(..self.data.len), &mut out)?;
Ok(out)
}
+
+ /// Returns the raw CUDA device pointer to the underlying quantized data.
+ pub fn device_ptr(&self) -> Result<*const u8> {
+ use cudarc::driver::DevicePtr;
+ Ok(self.data.inner.device_ptr(self.data.inner.stream()).0 as *const u8)
+ }
}
impl QCudaStorage {
diff --git a/candle-core/src/quantized/dummy_cuda.rs b/candle-core/src/quantized/dummy_cuda.rs
index 7194439a09..04f19f9fcb 100644
--- a/candle-core/src/quantized/dummy_cuda.rs
+++ b/candle-core/src/quantized/dummy_cuda.rs
@@ -54,6 +54,10 @@ impl QCudaStorage {
Err(Error::NotCompiledWithCudaSupport)
}
+ pub fn device_ptr(&self) -> Result<*const u8> {
+ Err(Error::NotCompiledWithCudaSupport)
+ }
+
pub fn storage_size_in_bytes(&self) -> usize {
0
}
diff --git a/candle-core/src/quantized/imatrix_file.rs b/candle-core/src/quantized/imatrix_file.rs
index db434f7f3e..ed228b74ce 100644
--- a/candle-core/src/quantized/imatrix_file.rs
+++ b/candle-core/src/quantized/imatrix_file.rs
@@ -30,7 +30,7 @@ pub fn load_imatrix<P: AsRef<Path>>(fname: P) -> Result
let n_entries = cursor
.read_i32::<LittleEndian>()
- .map_err(|e| crate::Error::msg(format!("Failed to read number of entries: {}", e)))?
+ .map_err(|e| crate::Error::msg(format!("Failed to read number of entries: {e}")))?
as usize;
if n_entries < 1 {
diff --git a/candle-core/src/quantized/mod.rs b/candle-core/src/quantized/mod.rs
index cee8ccc2ad..7316d29871 100644
--- a/candle-core/src/quantized/mod.rs
+++ b/candle-core/src/quantized/mod.rs
@@ -239,6 +239,15 @@ impl QStorage {
QStorage::Metal(storage) => Ok(Cow::from(storage.data()?)),
}
}
+
+ /// Returns the raw CUDA device pointer to the quantized data (CUDA storage only).
+ pub fn device_ptr(&self) -> Result<*const u8> {
+ match self {
+ QStorage::Cuda(storage) => storage.device_ptr(),
+ QStorage::Metal(_) | QStorage::Cpu(_) => {
+ crate::bail!("not implemented");
+ }
+ }
+ }
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
@@ -670,6 +679,15 @@ impl QTensor {
}
}
}
+
+ /// Returns the raw CUDA device pointer to the quantized data (CUDA storage only).
+ pub fn device_ptr(&self) -> Result<*const u8> {
+ match &self.storage {
+ QStorage::Cuda(storage) => storage.device_ptr(),
+ QStorage::Metal(_) | QStorage::Cpu(_) => {
+ crate::bail!("not implemented");
+ }
+ }
+ }
}
#[derive(Clone, Debug)]
diff --git a/candle-examples/examples/quantized-qwen3-moe/README.md b/candle-examples/examples/quantized-qwen3-moe/README.md
new file mode 100644
index 0000000000..8f82051a31
--- /dev/null
+++ b/candle-examples/examples/quantized-qwen3-moe/README.md
@@ -0,0 +1,18 @@
+# candle-quantized-qwen3-moe
+
+[Qwen3 MoE GGUF](https://huggingface.co/unsloth/Qwen3-30B-A3B-Instruct-2507-GGUF) hosts GGUF quantizations of the Qwen3 MoE models developed by Alibaba Cloud.
+
+## Running the example
+
+```bash
+# Local GGUF file
+cargo run --features cuda --example quantized-qwen3-moe --release -- --model /path/Qwen3-30B-A3B-Instruct-2507-Q4_K_M.gguf --prompt "Write a function to count prime numbers up to N."
+```
+
+Models available via the `--which` argument: `16b_q2k`, `16b_q4k`, `16b_q6k`, `16b_q80`, `32b_q2k`, `32b_q4k`, `32b_q6k`, `32b_q80`. For example:
+
+```bash
+# Downloaded from the Hugging Face Hub
+cargo run --features cuda --example quantized-qwen3-moe --release -- --which 32b_q4k --prompt "A train is travelling at 120mph, how far does it travel in 3 minutes 30 seconds?"
+```
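+
+The compute dtype defaults to `bf16`; on GPUs without bf16 support (e.g. V100), passing `--dtype f16` should work instead (the prompt below is only an illustration):
+
+```bash
+# f16 compute for GPUs without bf16 support
+cargo run --features cuda --example quantized-qwen3-moe --release -- --which 16b_q4k --dtype f16 --prompt "Explain the Monty Hall problem."
+```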
+
diff --git a/candle-examples/examples/quantized-qwen3-moe/main.rs b/candle-examples/examples/quantized-qwen3-moe/main.rs
new file mode 100644
index 0000000000..8fdfca39ef
--- /dev/null
+++ b/candle-examples/examples/quantized-qwen3-moe/main.rs
@@ -0,0 +1,357 @@
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
+use clap::{Parser, ValueEnum};
+use std::io::Write;
+use tokenizers::Tokenizer;
+
+use candle::Tensor;
+use candle::{quantized::gguf_file, DType};
+use candle_transformers::generation::{LogitsProcessor, Sampling};
+
+use candle_examples::token_output_stream::TokenOutputStream;
+use candle_transformers::models::quantized_qwen3_moe::GGUFQWenMoE as Qwen3_MoE;
+
+const DEFAULT_PROMPT: &str = "Write a Rust function to calculate the factorial of a given number.";
+
+#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
+enum Which {
+ #[value(name = "16b_q2k")]
+ W3_16bQ2K,
+ #[value(name = "16b_q4k")]
+ W3_16bQ4K,
+ #[value(name = "16b_q6k")]
+ W3_16bQ6K,
+ #[value(name = "16b_q80")]
+ W3_16bQ80,
+ #[value(name = "32b_q2k")]
+ W3_32bQ2K,
+ #[value(name = "32b_q4k")]
+ W3_32bQ4K,
+ #[value(name = "32b_q6k")]
+ W3_32bQ6K,
+ #[value(name = "32b_q80")]
+ W3_32bQ80,
+}
+
+#[derive(Parser, Debug)]
+#[command(author, version, about, long_about = None)]
+struct Args {
+ /// GGUF file to load, typically a .gguf file generated by the quantize command from llama.cpp
+ #[arg(long)]
+ model: Option<String>,
+
+ /// The initial prompt to generate from.
+ #[arg(long)]
+ prompt: Option<String>,
+
+ /// The length of the sample to generate (in tokens).
+ #[arg(short = 'n', long, default_value_t = 1000)]
+ sample_len: usize,
+
+ /// The tokenizer config in json format.
+ #[arg(long)]
+ tokenizer: Option<String>,
+
+ /// The temperature used to generate samples, use 0 for greedy sampling.
+ #[arg(long, default_value_t = 0.8)]
+ temperature: f64,
+
+ /// Nucleus sampling probability cutoff.
+ #[arg(long)]
+ top_p: Option<f64>,
+
+ /// Only sample among the top K samples.
+ #[arg(long)]
+ top_k: Option<usize>,
+
+ /// The seed to use when generating random samples.
+ #[arg(long, default_value_t = 299792458)]
+ seed: u64,
+
+ /// Enable tracing (generates a trace-timestamp.json file).
+ #[arg(long)]
+ tracing: bool,
+
+ /// Process prompt elements separately.
+ #[arg(long)]
+ split_prompt: bool,
+
+ /// Run on CPU rather than GPU even if a GPU is available.
+ #[arg(long)]
+ cpu: bool,
+
+ /// Penalty to be applied for repeating tokens, 1. means no penalty.
+ #[arg(long, default_value_t = 1.1)]
+ repeat_penalty: f32,
+
+ /// The context size to consider for the repeat penalty.
+ #[arg(long, default_value_t = 64)]
+ repeat_last_n: usize,
+
+ /// The model size to use.
+ #[arg(long, default_value = "16b_q2k")]
+ which: Which,
+
+ #[arg(long, default_value = "bf16")]
+ dtype: String,
+}
+
+impl Args {
+ fn tokenizer(&self) -> anyhow::Result<Tokenizer> {
+ let tokenizer_path = match &self.tokenizer {
+ Some(config) => std::path::PathBuf::from(config),
+ None => {
+ let api = hf_hub::api::sync::Api::new()?;
+ let repo = "Qwen/Qwen3-30B-A3B-Instruct-2507";
+ let api = api.model(repo.to_string());
+ api.get("tokenizer.json")?
+ }
+ };
+ Tokenizer::from_file(tokenizer_path).map_err(anyhow::Error::msg)
+ }
+
+ fn model(&self) -> anyhow::Result<std::path::PathBuf> {
+ let model_path = match &self.model {
+ Some(config) => std::path::PathBuf::from(config),
+ None => {
+ let (repo, filename, revision) = match self.which {
+ Which::W3_16bQ2K => (
+ "unsloth/Qwen3-16B-A3B-GGUF",
+ "Qwen3-16B-A3B-Q2_K.gguf",
+ "main",
+ ),
+ Which::W3_16bQ4K => (
+ "unsloth/Qwen3-16B-A3B-GGUF",
+ "Qwen3-16B-A3B-Q4_K_M.gguf",
+ "main",
+ ),
+ Which::W3_16bQ6K => (
+ "unsloth/Qwen3-16B-A3B-GGUF",
+ "Qwen3-16B-A3B-Q6_K.gguf",
+ "main",
+ ),
+ Which::W3_16bQ80 => (
+ "unsloth/Qwen3-16B-A3B-GGUF",
+ "Qwen3-16B-A3B-Q8_0.gguf",
+ "main",
+ ),
+
+ Which::W3_32bQ2K => (
+ "unsloth/Qwen3-30B-A3B-Instruct-2507-GGUF",
+ "Qwen3-30B-A3B-Instruct-2507-Q2_K.gguf",
+ "main",
+ ),
+ Which::W3_32bQ4K => (
+ "unsloth/Qwen3-30B-A3B-Instruct-2507-GGUF",
+ "Qwen3-30B-A3B-Instruct-2507-Q4_K_M.gguf",
+ "main",
+ ),
+ Which::W3_32bQ6K => (
+ "unsloth/Qwen3-30B-A3B-Instruct-2507-GGUF",
+ "Qwen3-30B-A3B-Instruct-2507-Q6_K.gguf",
+ "main",
+ ),
+ Which::W3_32bQ80 => (
+ "unsloth/Qwen3-30B-A3B-Instruct-2507-GGUF",
+ "Qwen3-30B-A3B-Instruct-2507-Q8_0.gguf",
+ "main",
+ ),
+ };
+ let api = hf_hub::api::sync::Api::new()?;
+ api.repo(hf_hub::Repo::with_revision(
+ repo.to_string(),
+ hf_hub::RepoType::Model,
+ revision.to_string(),
+ ))
+ .get(filename)?
+ }
+ };
+ Ok(model_path)
+ }
+}
+
+fn format_size(size_in_bytes: usize) -> String {
+ if size_in_bytes < 1_000 {
+ format!("{size_in_bytes}B")
+ } else if size_in_bytes < 1_000_000 {
+ format!("{:.2}KB", size_in_bytes as f64 / 1e3)
+ } else if size_in_bytes < 1_000_000_000 {
+ format!("{:.2}MB", size_in_bytes as f64 / 1e6)
+ } else {
+ format!("{:.2}GB", size_in_bytes as f64 / 1e9)
+ }
+}
+
+fn main() -> anyhow::Result<()> {
+ use tracing_chrome::ChromeLayerBuilder;
+ use tracing_subscriber::prelude::*;
+
+ let args = Args::parse();
+ let _guard = if args.tracing {
+ let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
+ tracing_subscriber::registry().with(chrome_layer).init();
+ Some(guard)
+ } else {
+ None
+ };
+
+ println!(
+ "avx: {}, neon: {}, simd128: {}, f16c: {}",
+ candle::utils::with_avx(),
+ candle::utils::with_neon(),
+ candle::utils::with_simd128(),
+ candle::utils::with_f16c()
+ );
+ println!(
+ "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
+ args.temperature, args.repeat_penalty, args.repeat_last_n
+ );
+
+ let dtype = match args.dtype.as_str() {
+ "bf16" => DType::BF16,
+ "f16" => DType::F16, // Used for V100
+ _ => {
+ panic!("Not supported dtype!")
+ }
+ };
+
+ let model_path = args.model()?;
+ let mut file = std::fs::File::open(&model_path)?;
+ let start = std::time::Instant::now();
+ let device = candle_examples::device(args.cpu)?;
+
+ let mut model = {
+ let model = gguf_file::Content::read(&mut file).map_err(|e| e.with_path(model_path))?;
+ let mut total_size_in_bytes = 0;
+ for (_, tensor) in model.tensor_infos.iter() {
+ let elem_count = tensor.shape.elem_count();
+ total_size_in_bytes +=
+ elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.block_size();
+ }
+ println!(
+ "loaded {:?} tensors ({}) in {:.2}s",
+ model.tensor_infos.len(),
+ &format_size(total_size_in_bytes),
+ start.elapsed().as_secs_f32(),
+ );
+ Qwen3_MoE::from_gguf(model, &mut file, &device, dtype)?
+ };
+ println!("model built");
+
+ let tokenizer = args.tokenizer()?;
+ let mut tos = TokenOutputStream::new(tokenizer);
+ let prompt_str = args
+ .prompt
+ .clone()
+ .unwrap_or_else(|| DEFAULT_PROMPT.to_string());
+
+ let prompt_str = format!("<|im_start|>user\n{prompt_str}<|im_end|>\n<|im_start|>assistant\n");
+ print!("formatted prompt: {}", &prompt_str);
+
+ let tokens = tos
+ .tokenizer()
+ .encode(prompt_str, true)
+ .map_err(anyhow::Error::msg)?;
+
+ let tokens = tokens.get_ids();
+
+ let to_sample = args.sample_len.saturating_sub(1);
+
+ let mut all_tokens = vec![];
+
+ let mut logits_processor = {
+ let temperature = args.temperature;
+ let sampling = if temperature <= 0. {
+ Sampling::ArgMax
+ } else {
+ match (args.top_k, args.top_p) {
+ (None, None) => Sampling::All { temperature },
+ (Some(k), None) => Sampling::TopK { k, temperature },
+ (None, Some(p)) => Sampling::TopP { p, temperature },
+ (Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
+ }
+ };
+ LogitsProcessor::from_sampling(args.seed, sampling)
+ };
+
+ let start_prompt_processing = std::time::Instant::now();
+
+ let mut next_token = if !args.split_prompt {
+ let input = Tensor::new(tokens, &device)?.unsqueeze(0)?;
+ let logits = model.forward(&input, 0)?;
+ let logits = logits.squeeze(0)?;
+ logits_processor.sample(&logits)?
+ } else {
+ let mut next_token = 0;
+ for (pos, token) in tokens.iter().enumerate() {
+ let input = Tensor::new(&[*token], &device)?.unsqueeze(0)?;
+ let logits = model.forward(&input, pos)?;
+ let logits = logits.squeeze(0)?;
+ next_token = logits_processor.sample(&logits)?
+ }
+ next_token
+ };
+
+ let prompt_dt = start_prompt_processing.elapsed();
+
+ all_tokens.push(next_token);
+
+ if let Some(t) = tos.next_token(next_token)? {
+ print!("{t}");
+ std::io::stdout().flush()?;
+ }
+
+ let eos_token = *tos.tokenizer().get_vocab(true).get("<|im_end|>").unwrap();
+
+ let start_post_prompt = std::time::Instant::now();
+
+ let mut sampled = 0;
+ for index in 0..to_sample {
+ let input = Tensor::new(&[next_token], &device)?.unsqueeze(0)?;
+ let logits = model.forward(&input, tokens.len() + index)?;
+ let logits = logits.squeeze(0)?;
+ let logits = if args.repeat_penalty == 1. {
+ logits
+ } else {
+ let start_at = all_tokens.len().saturating_sub(args.repeat_last_n);
+ candle_transformers::utils::apply_repeat_penalty(
+ &logits,
+ args.repeat_penalty,
+ &all_tokens[start_at..],
+ )?
+ };
+ next_token = logits_processor.sample(&logits)?;
+ all_tokens.push(next_token);
+ if let Some(t) = tos.next_token(next_token)? {
+ print!("{t}");
+ std::io::stdout().flush()?;
+ }
+ sampled += 1;
+ if next_token == eos_token {
+ break;
+ };
+ }
+
+ if let Some(rest) = tos.decode_rest().map_err(candle::Error::msg)? {
+ print!("{rest}");
+ }
+
+ std::io::stdout().flush()?;
+ let dt = start_post_prompt.elapsed();
+ println!(
+ "\n\n{:4} prompt tokens processed: {:.2} token/s",
+ tokens.len(),
+ tokens.len() as f64 / prompt_dt.as_secs_f64(),
+ );
+ println!(
+ "{sampled:4} tokens generated: {:.2} token/s",
+ sampled as f64 / dt.as_secs_f64(),
+ );
+ Ok(())
+}
diff --git a/candle-examples/examples/qwen/README.md b/candle-examples/examples/qwen/README.md
index d81cd6660a..92fa90e96a 100644
--- a/candle-examples/examples/qwen/README.md
+++ b/candle-examples/examples/qwen/README.md
@@ -50,3 +50,8 @@ $ cargo run --example qwen --features metal --release -- --prompt "Write a poem
> Their beauty lives where hearts can fly.
> 161 tokens generated (3.00 token/s)
```
+
+```shell
+# Local unquantized Qwen3 30B-A3B MoE model (with the fused MoE kernel, requires ~80GB of GPU memory)
+cargo run --example qwen --features cuda --release -- --prompt "Write a poem about butterflies." --model "3-moe-a3b" --weight-path /path/Qwen3-30B-A3B-Instruct-2507
+```
\ No newline at end of file
diff --git a/candle-examples/examples/qwen/main.rs b/candle-examples/examples/qwen/main.rs
index 796f3a1d1f..f6765411c1 100644
--- a/candle-examples/examples/qwen/main.rs
+++ b/candle-examples/examples/qwen/main.rs
@@ -217,7 +217,7 @@ struct Args {
tokenizer_file: Option<String>,
#[arg(long)]
- weight_files: Option<String>,
+ weight_path: Option<String>,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
@@ -288,15 +288,29 @@ fn main() -> Result<()> {
RepoType::Model,
args.revision,
));
- let tokenizer_filename = match args.tokenizer_file {
- Some(file) => std::path::PathBuf::from(file),
- None => repo.get("tokenizer.json")?,
+
+ let tokenizer_filename = match (args.weight_path.as_ref(), args.tokenizer_file.as_ref()) {
+ (Some(_), Some(file)) => std::path::PathBuf::from(file),
+ (None, Some(file)) => std::path::PathBuf::from(file),
+ (Some(path), None) => std::path::Path::new(path).join("tokenizer.json"),
+ (None, None) => repo.get("tokenizer.json")?,
+ };
+ let config_file = match &args.weight_path {
+ Some(path) => std::path::Path::new(path).join("config.json"),
+ _ => repo.get("config.json")?,
};
- let filenames = match args.weight_files {
- Some(files) => files
- .split(',')
- .map(std::path::PathBuf::from)
- .collect::<Vec<_>>(),
+
+ let filenames = match args.weight_path {
+ Some(path) => {
+ if std::path::Path::new(&path)
+ .join("model.safetensors.index.json")
+ .exists()
+ {
+ candle_examples::hub_load_local_safetensors(path, "model.safetensors.index.json")?
+ } else {
+ vec!["model.safetensors".into()]
+ }
+ }
None => match args.model {
WhichModel::W0_5b
| WhichModel::W2_0_5b
@@ -324,7 +338,6 @@ fn main() -> Result<()> {
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let start = std::time::Instant::now();
- let config_file = repo.get("config.json")?;
let device = candle_examples::device(args.cpu)?;
let dtype = if device.is_cuda() || device.is_metal() {
DType::BF16
diff --git a/candle-kernels/Cargo.toml b/candle-kernels/Cargo.toml
index c8b8d3de18..f727cada5b 100644
--- a/candle-kernels/Cargo.toml
+++ b/candle-kernels/Cargo.toml
@@ -12,4 +12,4 @@ license = "MIT OR Apache-2.0"
[dependencies]
[build-dependencies]
-bindgen_cuda = "0.1.5"
+bindgen_cuda = { git = "https://github.com/guoqingbao/bindgen_cuda.git", version = "0.1.7" }
diff --git a/candle-kernels/build.rs b/candle-kernels/build.rs
index e1813cd010..035345f86c 100644
--- a/candle-kernels/build.rs
+++ b/candle-kernels/build.rs
@@ -7,10 +7,54 @@ fn main() {
println!("cargo::rerun-if-changed=src/cuda_utils.cuh");
println!("cargo::rerun-if-changed=src/binary_op_macros.cuh");
+ // Build for PTX
let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap());
let ptx_path = out_dir.join("ptx.rs");
- let builder = bindgen_cuda::Builder::default();
+ let mut builder = bindgen_cuda::Builder::default()
+ .arg("--expt-relaxed-constexpr")
+ .arg("-std=c++17")
+ .arg("-O3")
+ .arg("--use_fast_math");
println!("cargo::warning={builder:?}");
let bindings = builder.build_ptx().unwrap();
- bindings.write(ptx_path).unwrap();
+ bindings.write(&ptx_path).unwrap();
+
+ // Remove unwanted MOE PTX constants from ptx.rs
+ remove_lines(&ptx_path, &["MOE_GGUF", "MOE_WMMA", "MOE_WMMA_GGUF"]);
+
+ // Build the FFI library (requires the custom bindgen_cuda fork, which supports building the PTX and a static library simultaneously)
+ let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap());
+ let mut is_target_msvc = false;
+ if let Ok(target) = std::env::var("TARGET") {
+ if target.contains("msvc") {
+ is_target_msvc = true;
+ builder = builder.arg("-D_USE_MATH_DEFINES");
+ }
+ }
+
+ if !is_target_msvc {
+ builder = builder.arg("-Xcompiler").arg("-fPIC");
+ }
+
+ let builder = builder.kernel_paths(vec![
+ "src/moe/moe_gguf.cu",
+ "src/moe/moe_wmma.cu",
+ "src/moe/moe_wmma_gguf.cu",
+ ]);
+ println!("cargo::warning={builder:?}");
+ builder.build_lib(out_dir.join("libmoe.a"));
+ println!("cargo:rustc-link-search={}", out_dir.display());
+ println!("cargo:rustc-link-lib=moe");
+ println!("cargo:rustc-link-lib=dylib=cudart");
+ println!("cargo:rustc-link-lib=stdc++");
+}
+
+fn remove_lines<P: AsRef<std::path::Path>>(file: P, patterns: &[&str]) {
+ let content = std::fs::read_to_string(&file).unwrap();
+ let filtered = content
+ .lines()
+ .filter(|line| !patterns.iter().any(|p| line.contains(p)))
+ .collect::<Vec<_>>()
+ .join("\n");
+ std::fs::write(file, filtered).unwrap();
}
diff --git a/candle-kernels/src/ffi.rs b/candle-kernels/src/ffi.rs
new file mode 100644
index 0000000000..ac50392721
--- /dev/null
+++ b/candle-kernels/src/ffi.rs
@@ -0,0 +1,56 @@
+use core::ffi::c_void;
+#[allow(dead_code)]
+extern "C" {
+ // for unquantized models
+ pub fn moe_gemm_wmma(
+ input: *const c_void, // device pointer [size_m, size_k]
+ weights: *const c_void, // device pointer [num_experts, size_n, size_k]
+ sorted_token_ids: *const i32, // device pointer [size_m]
+ expert_ids: *const i32, // host array [size_m] (expert id per sorted token)
+ topk_weights: *const f32,
+ output: *mut c_void, // device pointer [size_m, size_n]
+ expert_counts: *mut i32, // pre-allocated buffer [num_experts]
+ expert_offsets: *mut i32, // pre-allocated buffer [num_experts + 1]
+ num_experts: i32,
+ topk: i32,
+ size_m: i32,
+ size_n: i32,
+ size_k: i32,
+ dtype: i32, // 0=float16, 1=bf16 (for input/output)
+ is_prefill: bool,
+ stream: i64,
+ );
+
+ pub fn moe_gemm_gguf(
+ input: *const f32, // input [size_m, size_k]
+ weights: *const c_void, // weights [num_experts, size_n, size_k]
+ sorted_token_ids: *const i32,
+ expert_ids: *const i32,
+ topk_weights: *const f32, // device ptr or nullptr
+ output: *mut c_void, // float output [size_m, size_n]
+ num_experts: i32,
+ topk: i32,
+ size_m: i32,
+ size_n: i32,
+ size_k: i32,
+ gguf_dtype: i32, // Q8_0: 0, Q4K: 1, Q2K: 2, Q3K: 3, Q5K: 4, Q6K: 5 (for weights)
+ stream: i64,
+ );
+
+ pub fn moe_gemm_gguf_prefill(
+ input: *const c_void, // input [size_m, size_k]
+ weights: *const u8, // weights [num_experts, size_n, size_k]
+ sorted_token_ids: *const i32,
+ expert_ids: *const i32, // must be a host pointer
+ topk_weights: *const f32, // device ptr or nullptr
+ output: *mut c_void, // float output [size_m, size_n]
+ num_experts: i32,
+ topk: i32,
+ size_m: i32,
+ size_n: i32,
+ size_k: i32,
+ input_dtype: i32, // 0=f16, 1=bf16 (for inputs)
+ gguf_dtype: i32, // Q8_0: 0, Q4K: 1, Q2K: 2, Q3K: 3, Q5K: 4, Q6K: 5 (for weights)
+ stream: i64,
+ );
+}
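+
+// Illustrative sketch (not part of the FFI surface above): the `expert_counts` /
+// `expert_offsets` buffers suggest per-expert token counts plus their exclusive
+// prefix sums; a host-side equivalent might look like the following. The function
+// name and signature are assumptions for illustration only.
+#[allow(dead_code)]
+fn expert_counts_and_offsets(expert_ids: &[i32], num_experts: usize) -> (Vec<i32>, Vec<i32>) {
+    // Count how many sorted tokens are routed to each expert.
+    let mut counts = vec![0i32; num_experts];
+    for &e in expert_ids {
+        counts[e as usize] += 1;
+    }
+    // Exclusive prefix sum: offsets[e]..offsets[e + 1] indexes expert e's tokens.
+    let mut offsets = vec![0i32; num_experts + 1];
+    for e in 0..num_experts {
+        offsets[e + 1] = offsets[e] + counts[e];
+    }
+    (counts, offsets)
+}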
diff --git a/candle-kernels/src/lib.rs b/candle-kernels/src/lib.rs
index 9b66403475..cfc5732652 100644
--- a/candle-kernels/src/lib.rs
+++ b/candle-kernels/src/lib.rs
@@ -78,3 +78,5 @@ mdl!(REDUCE, Reduce);
mdl!(SORT, Sort);
mdl!(TERNARY, Ternary);
mdl!(UNARY, Unary);
+
+pub mod ffi;
diff --git a/candle-kernels/src/moe/gguf.cuh b/candle-kernels/src/moe/gguf.cuh
new file mode 100644
index 0000000000..3e50e9e9e8
--- /dev/null
+++ b/candle-kernels/src/moe/gguf.cuh
@@ -0,0 +1,1438 @@
+// Kernels adapted from llama.cpp ggml-cuda.cu
+// https://github.com/ggerganov/llama.cpp/blob/master/ggml-cuda.cu
+#include "cuda_fp16.h"
+#include "cuda_bf16.h"
+#include <cstdint>
+
+#define GGML_UNUSED(x) (void)(x)
+#define GGML_CUDA_ASSUME(x)
+
+#ifdef GGML_QKK_64
+#define QK_K 64
+#define K_SCALE_SIZE 4
+#else
+#define QK_K 256
+#define K_SCALE_SIZE 12
+#endif
+
+#undef GGML_CUDA_F16
+#define GGML_CUDA_DMMV_X 32
+#define CUDA_QUANTIZE_BLOCK_SIZE 256
+#define CUDA_DEQUANTIZE_BLOCK_SIZE 256
+#define K_QUANTS_PER_ITERATION 2
+
+typedef uint16_t ggml_fp16_t;
+typedef float dfloat; // dequantize float
+typedef float2 dfloat2;
+typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, dfloat2 & v);
+
+static __device__ __forceinline__ float warp_reduce_sum(float x) {
+#pragma unroll
+ for (int mask = 16; mask > 0; mask >>= 1) {
+ x += __shfl_xor_sync(0xffffffff, x, mask, 32);
+ }
+ return x;
+}
+
+static __device__ __forceinline__ float warp_reduce_max(float x) {
+#pragma unroll
+ for (int mask = 16; mask > 0; mask >>= 1) {
+ x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
+ }
+ return x;
+}
+
+static __device__ __forceinline__ int get_int_from_int8(const int8_t * x8, const int & i32) {
+ const uint16_t * x16 = (const uint16_t *) (x8 + sizeof(int) * i32); // assume at least 2 byte alignment
+
+ int x32 = 0;
+ x32 |= x16[0] << 0;
+ x32 |= x16[1] << 16;
+
+ return x32;
+}
+
+static __device__ __forceinline__ int get_int_from_uint8(const uint8_t * x8, const int & i32) {
+ const uint16_t * x16 = (const uint16_t *) (x8 + sizeof(int) * i32); // assume at least 2 byte alignment
+
+ int x32 = 0;
+ x32 |= x16[0] << 0;
+ x32 |= x16[1] << 16;
+
+ return x32;
+}
+
+static __device__ __forceinline__ int get_int_from_int8_aligned(const int8_t * x8, const int & i32) {
+ return *((const int *) (x8 + sizeof(int) * i32)); // assume at least 4 byte alignment
+}
+
+static __device__ __forceinline__ int get_int_from_uint8_aligned(const uint8_t * x8, const int & i32) {
+ return *((const int *) (x8 + sizeof(int) * i32)); // assume at least 4 byte alignment
+}
+
+
+#define WARP_SIZE 32
+#define CUDART_HMAX 11070 // CUDA 11.7, min. ver. for which __hmax and __hmax2 are known to work (may be higher than needed)
+
+#define CC_PASCAL 600
+#define MIN_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
+#define CC_VOLTA 700
+#define CC_OFFSET_AMD 1000000
+#define CC_RDNA1 (CC_OFFSET_AMD + 1010)
+#define CC_RDNA2 (CC_OFFSET_AMD + 1030)
+#define CC_RDNA3 (CC_OFFSET_AMD + 1100)
+
+static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, int c) {
+#if __CUDA_ARCH__ >= MIN_CC_DP4A
+ return __dp4a(a, b, c);
+#else // __CUDA_ARCH__ >= MIN_CC_DP4A
+ const int8_t * a8 = (const int8_t *) &a;
+ const int8_t * b8 = (const int8_t *) &b;
+ return c + a8[0]*b8[0] + a8[1]*b8[1] + a8[2]*b8[2] + a8[3]*b8[3];
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
+}
+
+
+#define MMQ_X_Q4_0_RDNA2 64
+#define MMQ_Y_Q4_0_RDNA2 128
+#define NWARPS_Q4_0_RDNA2 8
+#define MMQ_X_Q4_0_RDNA1 64
+#define MMQ_Y_Q4_0_RDNA1 64
+#define NWARPS_Q4_0_RDNA1 8
+#if defined(CUDA_USE_TENSOR_CORES)
+#define MMQ_X_Q4_0_AMPERE 4
+#define MMQ_Y_Q4_0_AMPERE 32
+#define NWARPS_Q4_0_AMPERE 4
+#else
+#define MMQ_X_Q4_0_AMPERE 64
+#define MMQ_Y_Q4_0_AMPERE 128
+#define NWARPS_Q4_0_AMPERE 4
+#endif
+#define MMQ_X_Q4_0_PASCAL 64
+#define MMQ_Y_Q4_0_PASCAL 64
+#define NWARPS_Q4_0_PASCAL 8
+
+#define MMQ_X_Q4_1_RDNA2 64
+#define MMQ_Y_Q4_1_RDNA2 128
+#define NWARPS_Q4_1_RDNA2 8
+#define MMQ_X_Q4_1_RDNA1 64
+#define MMQ_Y_Q4_1_RDNA1 64
+#define NWARPS_Q4_1_RDNA1 8
+#if defined(CUDA_USE_TENSOR_CORES)
+#define MMQ_X_Q4_1_AMPERE 4
+#define MMQ_Y_Q4_1_AMPERE 32
+#define NWARPS_Q4_1_AMPERE 4
+#else
+#define MMQ_X_Q4_1_AMPERE 64
+#define MMQ_Y_Q4_1_AMPERE 128
+#define NWARPS_Q4_1_AMPERE 4
+#endif
+#define MMQ_X_Q4_1_PASCAL 64
+#define MMQ_Y_Q4_1_PASCAL 64
+#define NWARPS_Q4_1_PASCAL 8
+
+#define MMQ_X_Q5_0_RDNA2 64
+#define MMQ_Y_Q5_0_RDNA2 128
+#define NWARPS_Q5_0_RDNA2 8
+#define MMQ_X_Q5_0_RDNA1 64
+#define MMQ_Y_Q5_0_RDNA1 64
+#define NWARPS_Q5_0_RDNA1 8
+#if defined(CUDA_USE_TENSOR_CORES)
+#define MMQ_X_Q5_0_AMPERE 4
+#define MMQ_Y_Q5_0_AMPERE 32
+#define NWARPS_Q5_0_AMPERE 4
+#else
+#define MMQ_X_Q5_0_AMPERE 128
+#define MMQ_Y_Q5_0_AMPERE 64
+#define NWARPS_Q5_0_AMPERE 4
+#endif
+#define MMQ_X_Q5_0_PASCAL 64
+#define MMQ_Y_Q5_0_PASCAL 64
+#define NWARPS_Q5_0_PASCAL 8
+
+#define MMQ_X_Q5_1_RDNA2 64
+#define MMQ_Y_Q5_1_RDNA2 128
+#define NWARPS_Q5_1_RDNA2 8
+#define MMQ_X_Q5_1_RDNA1 64
+#define MMQ_Y_Q5_1_RDNA1 64
+#define NWARPS_Q5_1_RDNA1 8
+#if defined(CUDA_USE_TENSOR_CORES)
+#define MMQ_X_Q5_1_AMPERE 4
+#define MMQ_Y_Q5_1_AMPERE 32
+#define NWARPS_Q5_1_AMPERE 4
+#else
+#define MMQ_X_Q5_1_AMPERE 128
+#define MMQ_Y_Q5_1_AMPERE 64
+#define NWARPS_Q5_1_AMPERE 4
+#endif
+#define MMQ_X_Q5_1_PASCAL 64
+#define MMQ_Y_Q5_1_PASCAL 64
+#define NWARPS_Q5_1_PASCAL 8
+
+#define MMQ_X_Q8_0_RDNA2 64
+#define MMQ_Y_Q8_0_RDNA2 128
+#define NWARPS_Q8_0_RDNA2 8
+#define MMQ_X_Q8_0_RDNA1 64
+#define MMQ_Y_Q8_0_RDNA1 64
+#define NWARPS_Q8_0_RDNA1 8
+#if defined(CUDA_USE_TENSOR_CORES)
+#define MMQ_X_Q8_0_AMPERE 4
+#define MMQ_Y_Q8_0_AMPERE 32
+#define NWARPS_Q8_0_AMPERE 4
+#else
+#define MMQ_X_Q8_0_AMPERE 128
+#define MMQ_Y_Q8_0_AMPERE 64
+#define NWARPS_Q8_0_AMPERE 4
+#endif
+#define MMQ_X_Q8_0_PASCAL 64
+#define MMQ_Y_Q8_0_PASCAL 64
+#define NWARPS_Q8_0_PASCAL 8
+
+#define MMQ_X_Q2_K_RDNA2 64
+#define MMQ_Y_Q2_K_RDNA2 128
+#define NWARPS_Q2_K_RDNA2 8
+#define MMQ_X_Q2_K_RDNA1 128
+#define MMQ_Y_Q2_K_RDNA1 32
+#define NWARPS_Q2_K_RDNA1 8
+#if defined(CUDA_USE_TENSOR_CORES)
+#define MMQ_X_Q2_K_AMPERE 4
+#define MMQ_Y_Q2_K_AMPERE 32
+#define NWARPS_Q2_K_AMPERE 4
+#else
+#define MMQ_X_Q2_K_AMPERE 64
+#define MMQ_Y_Q2_K_AMPERE 128
+#define NWARPS_Q2_K_AMPERE 4
+#endif
+#define MMQ_X_Q2_K_PASCAL 64
+#define MMQ_Y_Q2_K_PASCAL 64
+#define NWARPS_Q2_K_PASCAL 8
+
+#define MMQ_X_Q3_K_RDNA2 128
+#define MMQ_Y_Q3_K_RDNA2 64
+#define NWARPS_Q3_K_RDNA2 8
+#define MMQ_X_Q3_K_RDNA1 32
+#define MMQ_Y_Q3_K_RDNA1 128
+#define NWARPS_Q3_K_RDNA1 8
+#if defined(CUDA_USE_TENSOR_CORES)
+#define MMQ_X_Q3_K_AMPERE 4
+#define MMQ_Y_Q3_K_AMPERE 32
+#define NWARPS_Q3_K_AMPERE 4
+#else
+#define MMQ_X_Q3_K_AMPERE 128
+#define MMQ_Y_Q3_K_AMPERE 128
+#define NWARPS_Q3_K_AMPERE 4
+#endif
+#define MMQ_X_Q3_K_PASCAL 64
+#define MMQ_Y_Q3_K_PASCAL 64
+#define NWARPS_Q3_K_PASCAL 8
+
+#define MMQ_X_Q4_K_RDNA2 64
+#define MMQ_Y_Q4_K_RDNA2 128
+#define NWARPS_Q4_K_RDNA2 8
+#define MMQ_X_Q4_K_RDNA1 32
+#define MMQ_Y_Q4_K_RDNA1 64
+#define NWARPS_Q4_K_RDNA1 8
+#if defined(CUDA_USE_TENSOR_CORES)
+#define MMQ_X_Q4_K_AMPERE 4
+#define MMQ_Y_Q4_K_AMPERE 32
+#define NWARPS_Q4_K_AMPERE 4
+#else
+#define MMQ_X_Q4_K_AMPERE 64
+#define MMQ_Y_Q4_K_AMPERE 128
+#define NWARPS_Q4_K_AMPERE 4
+#endif
+#define MMQ_X_Q4_K_PASCAL 64
+#define MMQ_Y_Q4_K_PASCAL 64
+#define NWARPS_Q4_K_PASCAL 8
+
+#define MMQ_X_Q5_K_RDNA2 64
+#define MMQ_Y_Q5_K_RDNA2 128
+#define NWARPS_Q5_K_RDNA2 8
+#define MMQ_X_Q5_K_RDNA1 32
+#define MMQ_Y_Q5_K_RDNA1 64
+#define NWARPS_Q5_K_RDNA1 8
+#if defined(CUDA_USE_TENSOR_CORES)
+#define MMQ_X_Q5_K_AMPERE 4
+#define MMQ_Y_Q5_K_AMPERE 32
+#define NWARPS_Q5_K_AMPERE 4
+#else
+#define MMQ_X_Q5_K_AMPERE 64
+#define MMQ_Y_Q5_K_AMPERE 128
+#define NWARPS_Q5_K_AMPERE 4
+#endif
+#define MMQ_X_Q5_K_PASCAL 64
+#define MMQ_Y_Q5_K_PASCAL 64
+#define NWARPS_Q5_K_PASCAL 8
+
+#define MMQ_X_Q6_K_RDNA2 64
+#define MMQ_Y_Q6_K_RDNA2 128
+#define NWARPS_Q6_K_RDNA2 8
+#define MMQ_X_Q6_K_RDNA1 32
+#define MMQ_Y_Q6_K_RDNA1 64
+#define NWARPS_Q6_K_RDNA1 8
+#if defined(CUDA_USE_TENSOR_CORES)
+#define MMQ_X_Q6_K_AMPERE 4
+#define MMQ_Y_Q6_K_AMPERE 32
+#define NWARPS_Q6_K_AMPERE 4
+#else
+#define MMQ_X_Q6_K_AMPERE 64
+#define MMQ_Y_Q6_K_AMPERE 64
+#define NWARPS_Q6_K_AMPERE 4
+#endif
+#define MMQ_X_Q6_K_PASCAL 64
+#define MMQ_Y_Q6_K_PASCAL 64
+#define NWARPS_Q6_K_PASCAL 8
+
+
+// QK = number of values after dequantization
+// QR = QK / number of values before dequantization
+// QI = number of 32 bit integers before dequantization
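+//
+// Worked example: for Q4_0 below, QK4_0 = 32 values are produced per block after
+// dequantization and QR4_0 = 2, so QI4_0 = QK4_0 / (4 * QR4_0) = 32 / 8 = 4
+// 32-bit integers of quantized data per block.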
+
+#define QK4_0 32
+#define QR4_0 2
+#define QI4_0 (QK4_0 / (4 * QR4_0))
+typedef struct {
+ half d; // delta
+ uint8_t qs[QK4_0 / 2]; // nibbles / quants
+} block_q4_0;
+static_assert(sizeof(block_q4_0) == sizeof(ggml_fp16_t) + QK4_0 / 2, "wrong q4_0 block size/padding");
+
+#define QK4_1 32
+#define QR4_1 2
+#define QI4_1 (QK4_1 / (4 * QR4_1))
+typedef struct {
+ half2 dm; // dm.x = delta, dm.y = min
+ uint8_t qs[QK4_1 / 2]; // nibbles / quants
+} block_q4_1;
+static_assert(sizeof(block_q4_1) == sizeof(ggml_fp16_t) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding");
+
+#define QK5_0 32
+#define QR5_0 2
+#define QI5_0 (QK5_0 / (4 * QR5_0))
+typedef struct {
+ half d; // delta
+ uint8_t qh[4]; // 5-th bit of quants
+ uint8_t qs[QK5_0 / 2]; // nibbles / quants
+} block_q5_0;
+static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding");
+
+#define QK5_1 32
+#define QR5_1 2
+#define QI5_1 (QK5_1 / (4 * QR5_1))
+typedef struct {
+ half2 dm; // dm.x = delta, dm.y = min
+ uint8_t qh[4]; // 5-th bit of quants
+ uint8_t qs[QK5_1 / 2]; // nibbles / quants
+} block_q5_1;
+static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding");
+
+#define QK8_0 32
+#define QR8_0 1
+#define QI8_0 (QK8_0 / (4 * QR8_0))
+typedef struct {
+ half d; // delta
+ int8_t qs[QK8_0]; // quants
+} block_q8_0;
+static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 block size/padding");
+
+#define QK8_1 32
+#define QR8_1 1
+#define QI8_1 (QK8_1 / (4 * QR8_1))
+typedef struct {
+ half2 ds; // ds.x = delta, ds.y = sum
+ int8_t qs[QK8_0]; // quants
+} block_q8_1;
+static_assert(sizeof(block_q8_1) == 2*sizeof(ggml_fp16_t) + QK8_0, "wrong q8_1 block size/padding");
+
+typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs);
+
+#define QR2_K 4
+#define QI2_K (QK_K / (4*QR2_K))
+typedef struct {
+ uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
+ uint8_t qs[QK_K/4]; // quants
+ half2 dm; // super-block scale for quantized scales/mins
+} block_q2_K;
+static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_fp16_t) + QK_K/16 + QK_K/4, "wrong q2_K block size/padding");
+
+#define QR3_K 4
+#define QI3_K (QK_K / (4*QR3_K))
+typedef struct {
+ uint8_t hmask[QK_K/8]; // quants - high bit
+ uint8_t qs[QK_K/4]; // quants - low 2 bits
+#ifdef GGML_QKK_64
+ uint8_t scales[2]; // scales, quantized with 8 bits
+#else
+ uint8_t scales[K_SCALE_SIZE]; // scales, quantized with 6 bits
+#endif
+ half d; // super-block scale
+} block_q3_K;
+//static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + K_SCALE_SIZE, "wrong q3_K block size/padding");
+
+#define QR4_K 2
+#define QI4_K (QK_K / (4*QR4_K))
+#ifdef GGML_QKK_64
+typedef struct {
+ half dm[2]; // super-block scales/mins
+ uint8_t scales[2]; // 4-bit block scales/mins
+ uint8_t qs[QK_K/2]; // 4--bit quants
+} block_q4_K;
+static_assert(sizeof(block_q4_K) == sizeof(half2) + QK_K/2 + 2, "wrong q4_K block size/padding");
+#else
+typedef struct {
+ half2 dm; // super-block scale for quantized scales/mins
+ uint8_t scales[3*QK_K/64]; // scales, quantized with 6 bits
+ uint8_t qs[QK_K/2]; // 4--bit quants
+} block_q4_K;
+static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2, "wrong q4_K block size/padding");
+#endif
+
+#define QR5_K 2
+#define QI5_K (QK_K / (4*QR5_K))
+#ifdef GGML_QKK_64
+typedef struct {
+ half d; // super-block scale
+ int8_t scales[QK_K/16]; // block scales
+ uint8_t qh[QK_K/8]; // quants, high bit
+ uint8_t qs[QK_K/2]; // quants, low 4 bits
+} block_q5_K;
+static_assert(sizeof(block_q5_K) == sizeof(ggml_fp16_t) + QK_K/2 + QK_K/8 + QK_K/16, "wrong q5_K block size/padding");
+#else
+typedef struct {
+ half2 dm; // super-block scale for quantized scales/mins
+ uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
+ uint8_t qh[QK_K/8]; // quants, high bit
+ uint8_t qs[QK_K/2]; // quants, low 4 bits
+} block_q5_K;
+static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_fp16_t) + K_SCALE_SIZE + QK_K/2 + QK_K/8, "wrong q5_K block size/padding");
+#endif
+
+#define QR6_K 2
+#define QI6_K (QK_K / (4*QR6_K))
+typedef struct {
+ uint8_t ql[QK_K/2]; // quants, lower 4 bits
+ uint8_t qh[QK_K/4]; // quants, upper 2 bits
+ int8_t scales[QK_K/16]; // scales
+ half d; // delta
+} block_q6_K;
+static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_K block size/padding");
+
+// In llama.cpp this is only used for intermediate quantization and dot products
+typedef struct {
+ float d; // delta
+ int8_t qs[QK_K]; // quants
+ int16_t bsums[QK_K/16]; // sum of quants in groups of 16
+} block_q8_K;
+static_assert(sizeof(block_q8_K) == sizeof(float) + QK_K + QK_K/16*sizeof(int16_t), "wrong q8_K block size/padding");
+
+
+// VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called
+// MMVQ = mul_mat_vec_q, MMQ = mul_mat_q
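+//
+// For example, VDR_Q4_0_Q8_1_MMVQ = 2 below means that each thread processes two
+// contiguous 32-bit integers of q4_0 data per call in the mul_mat_vec_q path.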
+
+#define VDR_Q4_0_Q8_1_MMVQ 2
+#define VDR_Q4_0_Q8_1_MMQ 4
+
+template <int vdr> static __device__ __forceinline__ float vec_dot_q4_0_q8_1_impl(
+ const int * v, const int * u, const float & d4, const half2 & ds8) {
+
+ int sumi = 0;
+
+#pragma unroll
+ for (int i = 0; i < vdr; ++i) {
+ const int vi0 = (v[i] >> 0) & 0x0F0F0F0F;
+ const int vi1 = (v[i] >> 4) & 0x0F0F0F0F;
+
+ // SIMD dot product of quantized values
+ sumi = ggml_cuda_dp4a(vi0, u[2*i+0], sumi);
+ sumi = ggml_cuda_dp4a(vi1, u[2*i+1], sumi);
+ }
+
+ const float2 ds8f = __half22float2(ds8);
+
+ // second part effectively subtracts 8 from each quant value
+ return d4 * (sumi * ds8f.x - (8*vdr/QI4_0) * ds8f.y);
+}
+
+#define VDR_Q4_1_Q8_1_MMVQ 2
+#define VDR_Q4_1_Q8_1_MMQ 4
+
+template <int vdr> static __device__ __forceinline__ float vec_dot_q4_1_q8_1_impl(
+ const int * v, const int * u, const half2 & dm4, const half2 & ds8) {
+ int sumi = 0;
+
+#pragma unroll
+ for (int i = 0; i < vdr; ++i) {
+ const int vi0 = (v[i] >> 0) & 0x0F0F0F0F;
+ const int vi1 = (v[i] >> 4) & 0x0F0F0F0F;
+
+ // SIMD dot product of quantized values
+ sumi = ggml_cuda_dp4a(vi0, u[2*i+0], sumi);
+ sumi = ggml_cuda_dp4a(vi1, u[2*i+1], sumi);
+ }
+
+#ifdef GGML_CUDA_F16
+ const float2 tmp = __half22float2(__hmul2(dm4, ds8));
+ const float d4d8 = tmp.x;
+ const float m4s8 = tmp.y;
+#else
+ const float2 dm4f = __half22float2(dm4);
+ const float2 ds8f = __half22float2(ds8);
+ const float d4d8 = dm4f.x * ds8f.x;
+ const float m4s8 = dm4f.y * ds8f.y;
+#endif // GGML_CUDA_F16
+
+ // scale second part of sum by QI8_1/(vdr * QR4_1) to compensate for multiple threads adding it
+ return sumi * d4d8 + m4s8 / (QI8_1 / (vdr * QR4_1));
+}
+
+#define VDR_Q5_0_Q8_1_MMVQ 2
+#define VDR_Q5_0_Q8_1_MMQ 4
+
+template <int vdr> static __device__ __forceinline__ float vec_dot_q5_0_q8_1_impl(
+ const int * vl, const int * vh, const int * u, const float & d5, const half2 & ds8) {
+
+ int sumi = 0;
+
+#pragma unroll
+ for (int i = 0; i < vdr; ++i) {
+ int vi0 = (vl[i] >> 0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh as 5th bits
+ vi0 |= (vh[i] << 4) & 0x00000010; // 0 -> 4
+ vi0 |= (vh[i] << 11) & 0x00001000; // 1 -> 12
+ vi0 |= (vh[i] << 18) & 0x00100000; // 2 -> 20
+ vi0 |= (vh[i] << 25) & 0x10000000; // 3 -> 28
+ sumi = ggml_cuda_dp4a(vi0, u[2*i+0], sumi); // SIMD dot product of quantized values
+
+ int vi1 = (vl[i] >> 4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh as 5th bits
+ vi1 |= (vh[i] >> 12) & 0x00000010; // 16 -> 4
+ vi1 |= (vh[i] >> 5) & 0x00001000; // 17 -> 12
+ vi1 |= (vh[i] << 2) & 0x00100000; // 18 -> 20
+ vi1 |= (vh[i] << 9) & 0x10000000; // 19 -> 28
+ sumi = ggml_cuda_dp4a(vi1, u[2*i+1], sumi); // SIMD dot product of quantized values
+ }
+
+ const float2 ds8f = __half22float2(ds8);
+
+ // second part effectively subtracts 16 from each quant value
+ return d5 * (sumi * ds8f.x - (16*vdr/QI5_0) * ds8f.y);
+}
+
+#define VDR_Q5_1_Q8_1_MMVQ 2
+#define VDR_Q5_1_Q8_1_MMQ 4
+
+template <int vdr> static __device__ __forceinline__ float vec_dot_q5_1_q8_1_impl(
+ const int * vl, const int * vh, const int * u, const half2 & dm5, const half2 & ds8) {
+
+ int sumi = 0;
+
+#pragma unroll
+ for (int i = 0; i < vdr; ++i) {
+ int vi0 = (vl[i] >> 0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh as 5th bits
+ vi0 |= (vh[i] << 4) & 0x00000010; // 0 -> 4
+ vi0 |= (vh[i] << 11) & 0x00001000; // 1 -> 12
+ vi0 |= (vh[i] << 18) & 0x00100000; // 2 -> 20
+ vi0 |= (vh[i] << 25) & 0x10000000; // 3 -> 28
+ sumi = ggml_cuda_dp4a(vi0, u[2*i+0], sumi); // SIMD dot product of quantized values
+
+ int vi1 = (vl[i] >> 4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh as 5th bits
+ vi1 |= (vh[i] >> 12) & 0x00000010; // 16 -> 4
+ vi1 |= (vh[i] >> 5) & 0x00001000; // 17 -> 12
+ vi1 |= (vh[i] << 2) & 0x00100000; // 18 -> 20
+ vi1 |= (vh[i] << 9) & 0x10000000; // 19 -> 28
+ sumi = ggml_cuda_dp4a(vi1, u[2*i+1], sumi); // SIMD dot product of quantized values
+ }
+
+#ifdef GGML_CUDA_F16
+ const float2 tmp = __half22float2(__hmul2(dm5, ds8));
+ const float d5d8 = tmp.x;
+ const float m5s8 = tmp.y;
+#else
+ const float2 dm5f = __half22float2(dm5);
+ const float2 ds8f = __half22float2(ds8);
+ const float d5d8 = dm5f.x * ds8f.x;
+ const float m5s8 = dm5f.y * ds8f.y;
+#endif // GGML_CUDA_F16
+
+ // scale second part of sum by QI5_1 / vdr to compensate for multiple threads adding it
+ return sumi*d5d8 + m5s8 / (QI5_1 / vdr);
+}
+
+#define VDR_Q8_0_Q8_1_MMVQ 2
+#define VDR_Q8_0_Q8_1_MMQ 8
+
+template <int vdr> static __device__ __forceinline__ float vec_dot_q8_0_q8_1_impl(
+ const int * v, const int * u, const float & d8_0, const float & d8_1) {
+
+ int sumi = 0;
+
+#pragma unroll
+ for (int i = 0; i < vdr; ++i) {
+ // SIMD dot product of quantized values
+ sumi = ggml_cuda_dp4a(v[i], u[i], sumi);
+ }
+
+ return d8_0*d8_1 * sumi;
+}
+
+template <int vdr> static __device__ __forceinline__ float vec_dot_q8_1_q8_1_impl(
+ const int * v, const int * u, const half2 & dm8, const half2 & ds8) {
+
+ int sumi = 0;
+
+#pragma unroll
+ for (int i = 0; i < vdr; ++i) {
+ // SIMD dot product of quantized values
+ sumi = ggml_cuda_dp4a(v[i], u[i], sumi);
+ }
+
+#ifdef GGML_CUDA_F16
+ const float2 tmp = __half22float2(__hmul2(dm8, ds8));
+ const float d8d8 = tmp.x;
+ const float m8s8 = tmp.y;
+#else
+ const float2 dm8f = __half22float2(dm8);
+ const float2 ds8f = __half22float2(ds8);
+ const float d8d8 = dm8f.x * ds8f.x;
+ const float m8s8 = dm8f.y * ds8f.y;
+#endif // GGML_CUDA_F16
+
+ // scale second part of sum by QI8_1/ vdr to compensate for multiple threads adding it
+ return sumi*d8d8 + m8s8 / (QI8_1 / vdr);
+}
+
+#define VDR_Q2_K_Q8_1_MMVQ 1
+#define VDR_Q2_K_Q8_1_MMQ 2
+
+// contiguous v/x values
+static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmvq(
+ const int & v, const int * __restrict__ u, const uint8_t * __restrict__ scales,
+ const half2 & dm2, const float * __restrict__ d8) {
+
+ float sumf_d = 0.0f;
+ float sumf_m = 0.0f;
+
+#pragma unroll
+ for (int i = 0; i < QR2_K; ++i) {
+ const int sc = scales[2*i];
+
+ const int vi = (v >> (2*i)) & 0x03030303;
+
+ sumf_d += d8[i] * (ggml_cuda_dp4a(vi, u[i], 0) * (sc & 0xF)); // SIMD dot product
+
+ // fill int with 4x m
+ int m = sc >> 4;
+ m |= m << 8;
+ m |= m << 16;
+ sumf_m += d8[i] * ggml_cuda_dp4a(m, u[i], 0); // multiply constant q2_K part with sum of q8_1 values
+ }
+
+ const float2 dm2f = __half22float2(dm2);
+
+ return dm2f.x*sumf_d - dm2f.y*sumf_m;
+}
+
+// contiguous u/y values
+static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmq(
+ const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ scales,
+ const half2 & dm2, const float & d8) {
+
+ int sumi_d = 0;
+ int sumi_m = 0;
+
+#pragma unroll
+ for (int i0 = 0; i0 < QI8_1; i0 += QI8_1/2) {
+ int sumi_d_sc = 0;
+
+ const int sc = scales[i0 / (QI8_1/2)];
+
+ // fill int with 4x m
+ int m = sc >> 4;
+ m |= m << 8;
+ m |= m << 16;
+
+#pragma unroll
+ for (int i = i0; i < i0 + QI8_1/2; ++i) {
+ sumi_d_sc = ggml_cuda_dp4a(v[i], u[i], sumi_d_sc); // SIMD dot product
+ sumi_m = ggml_cuda_dp4a(m, u[i], sumi_m); // multiply sum of q8_1 values with m
+ }
+
+ sumi_d += sumi_d_sc * (sc & 0xF);
+ }
+
+ const float2 dm2f = __half22float2(dm2);
+
+ return d8 * (dm2f.x*sumi_d - dm2f.y*sumi_m);
+}
+
+#define VDR_Q3_K_Q8_1_MMVQ 1
+#define VDR_Q3_K_Q8_1_MMQ 2
+
+// contiguous v/x values
+static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmvq(
+ const int & vl, const int & vh, const int * __restrict__ u, const uint8_t * __restrict__ scales,
+ const int & scale_offset, const float & d3, const float * __restrict__ d8) {
+
+ float sumf = 0.0f;
+
+#pragma unroll
+ for (int i = 0; i < QR3_K; ++i) {
+ const int isc = scale_offset + 2*i;
+
+ const int isc_low = isc % (QK_K/32);
+ const int sc_shift_low = 4 * (isc / (QK_K/32));
+ const int sc_low = (scales[isc_low] >> sc_shift_low) & 0xF;
+
+ const int isc_high = isc % (QK_K/64);
+ const int sc_shift_high = 2 * (isc / (QK_K/64));
+ const int sc_high = ((scales[(QK_K/32) + isc_high] >> sc_shift_high) & 3) << 4;
+
+ const int sc = (sc_low | sc_high) - 32;
+
+ const int vil = (vl >> (2*i)) & 0x03030303;
+
+ const int vih = ((vh >> i) << 2) & 0x04040404;
+
+ const int vi = __vsubss4(vil, vih);
+
+ sumf += d8[i] * (ggml_cuda_dp4a(vi, u[i], 0) * sc); // SIMD dot product
+ }
+
+ return d3 * sumf;
+}
+
+// contiguous u/y values
+static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmq(
+ const int * __restrict__ v, const int * __restrict__ u, const int8_t * __restrict__ scales,
+ const float & d3, const float & d8) {
+
+ int sumi = 0;
+
+#pragma unroll
+ for (int i0 = 0; i0 < QR3_K*VDR_Q3_K_Q8_1_MMQ; i0 += QI8_1/2) {
+ int sumi_sc = 0;
+
+ for (int i = i0; i < i0 + QI8_1/2; ++i) {
+ sumi_sc = ggml_cuda_dp4a(v[i], u[i], sumi_sc); // SIMD dot product
+ }
+
+ sumi += sumi_sc * scales[i0 / (QI8_1/2)];
+ }
+
+ return d3*d8 * sumi;
+}
+
+#define VDR_Q4_K_Q8_1_MMVQ 2
+#define VDR_Q4_K_Q8_1_MMQ 8
+
+// contiguous v/x values
+static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_vmmq(
+ const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc,
+ const uint8_t * __restrict__ m, const half2 & dm4, const float * __restrict__ d8) {
+
+ float sumf_d = 0.0f;
+ float sumf_m = 0.0f;
+
+#pragma unroll
+ for (int i = 0; i < QR4_K; ++i) {
+ const int v0i = (v[0] >> (4*i)) & 0x0F0F0F0F;
+ const int v1i = (v[1] >> (4*i)) & 0x0F0F0F0F;
+
+ const int dot1 = ggml_cuda_dp4a(v1i, u[2*i+1], ggml_cuda_dp4a(v0i, u[2*i+0], 0)); // SIMD dot product
+ const int dot2 = ggml_cuda_dp4a(0x01010101, u[2*i+1], ggml_cuda_dp4a(0x01010101, u[2*i+0], 0)); // sum of u
+
+ sumf_d += d8[i] * (dot1 * sc[i]);
+ sumf_m += d8[i] * (dot2 * m[i]); // multiply constant part of q4_K with sum of q8_1 values
+ }
+
+ const float2 dm4f = __half22float2(dm4);
+
+ return dm4f.x*sumf_d - dm4f.y*sumf_m;
+}
+
+// contiguous u/y values
+static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_mmq(
+ const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc,
+ const uint8_t * __restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) {
+
+ float sumf_d = 0.0f;
+ float sumf_m = 0.0f;
+
+#pragma unroll
+ for (int i = 0; i < QR4_K*VDR_Q4_K_Q8_1_MMQ/QI8_1; ++i) {
+ int sumi_d = 0;
+
+#pragma unroll
+ for (int j = 0; j < QI8_1; ++j) {
+ sumi_d = ggml_cuda_dp4a((v[j] >> (4*i)) & 0x0F0F0F0F, u[i*QI8_1 + j], sumi_d); // SIMD dot product
+ }
+
+ const float2 ds8f = __half22float2(ds8[i]);
+
+ sumf_d += ds8f.x * (sc[i] * sumi_d);
+ sumf_m += ds8f.y * m[i]; // sum of q8_1 block * q4_K min val
+ }
+
+ const float2 dm4f = __half22float2(dm4);
+
+ return dm4f.x*sumf_d - dm4f.y*sumf_m;
+}
+
+#define VDR_Q5_K_Q8_1_MMVQ 2
+#define VDR_Q5_K_Q8_1_MMQ 8
+
+// contiguous v/x values
+static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_vmmq(
+ const int * __restrict__ vl, const int * __restrict__ vh, const int * __restrict__ u, const uint8_t * __restrict__ sc,
+ const uint8_t * __restrict__ m, const half2 & dm5, const float * __restrict__ d8) {
+
+ float sumf_d = 0.0f;
+ float sumf_m = 0.0f;
+
+#pragma unroll
+ for (int i = 0; i < QR5_K; ++i) {
+ const int vl0i = (vl[0] >> (4*i)) & 0x0F0F0F0F;
+ const int vl1i = (vl[1] >> (4*i)) & 0x0F0F0F0F;
+
+ const int vh0i = ((vh[0] >> i) << 4) & 0x10101010;
+ const int vh1i = ((vh[1] >> i) << 4) & 0x10101010;
+
+ const int v0i = vl0i | vh0i;
+ const int v1i = vl1i | vh1i;
+
+ const int dot1 = ggml_cuda_dp4a(v0i, u[2*i+0], ggml_cuda_dp4a(v1i, u[2*i+1], 0)); // SIMD dot product
+ const int dot2 = ggml_cuda_dp4a(0x01010101, u[2*i+0], ggml_cuda_dp4a(0x01010101, u[2*i+1], 0)); // sum of u
+
+ sumf_d += d8[i] * (dot1 * sc[i]);
+ sumf_m += d8[i] * (dot2 * m[i]);
+
+ }
+
+ const float2 dm5f = __half22float2(dm5);
+
+ return dm5f.x*sumf_d - dm5f.y*sumf_m;
+}
+
+// contiguous u/y values
+static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_mmq(
+ const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc,
+ const uint8_t * __restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) {
+
+ float sumf_d = 0.0f;
+ float sumf_m = 0.0f;
+
+#pragma unroll
+ for (int i = 0; i < QR5_K*VDR_Q5_K_Q8_1_MMQ/QI8_1; ++i) {
+ int sumi_d = 0;
+
+#pragma unroll
+ for (int j = 0; j < QI8_1; ++j) {
+ sumi_d = ggml_cuda_dp4a(v[i*QI8_1 + j], u[i*QI8_1 + j], sumi_d); // SIMD dot product
+ }
+
+ const float2 ds8f = __half22float2(ds8[i]);
+
+ sumf_d += ds8f.x * (sc[i] * sumi_d);
+ sumf_m += ds8f.y * m[i]; // sum of q8_1 block * q4_K min val
+ }
+
+ const float2 dm4f = __half22float2(dm4);
+
+ return dm4f.x*sumf_d - dm4f.y*sumf_m;
+}
+
+#define VDR_Q6_K_Q8_1_MMVQ 1
+#define VDR_Q6_K_Q8_1_MMQ 8
+
+// contiguous v/x values
+static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmvq(
+ const int & vl, const int & vh, const int * __restrict__ u, const int8_t * __restrict__ scales,
+ const float & d, const float * __restrict__ d8) {
+
+ float sumf = 0.0f;
+
+#pragma unroll
+ for (int i = 0; i < QR6_K; ++i) {
+ const int sc = scales[4*i];
+
+ const int vil = (vl >> (4*i)) & 0x0F0F0F0F;
+
+ const int vih = ((vh >> (4*i)) << 4) & 0x30303030;
+
+ const int vi = __vsubss4((vil | vih), 0x20202020); // vi = (vil | vih) - 32
+
+ sumf += d8[i] * (ggml_cuda_dp4a(vi, u[i], 0) * sc); // SIMD dot product
+ }
+
+ return d*sumf;
+}
+
+// contiguous u/y values
+static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq(
+ const int * __restrict__ v, const int * __restrict__ u, const int8_t * __restrict__ sc,
+ const float & d6, const float * __restrict__ d8) {
+
+ float sumf_d = 0.0f;
+
+#pragma unroll
+ for (int i0 = 0; i0 < VDR_Q6_K_Q8_1_MMQ; i0 += 4) {
+ int2 sumi_d = {0, 0}; // 2 q6_K scales per q8_1 scale
+
+#pragma unroll
+ for (int i = i0; i < i0 + 2; ++i) {
+ sumi_d.x = ggml_cuda_dp4a(v[2*i+0], u[2*i+0], sumi_d.x); // SIMD dot product
+ sumi_d.x = ggml_cuda_dp4a(v[2*i+1], u[2*i+1], sumi_d.x); // SIMD dot product
+
+ sumi_d.y = ggml_cuda_dp4a(v[2*i+4], u[2*i+4], sumi_d.y); // SIMD dot product
+ sumi_d.y = ggml_cuda_dp4a(v[2*i+5], u[2*i+5], sumi_d.y); // SIMD dot product
+ }
+
+ sumf_d += d8[i0/4] * (sc[i0/2+0]*sumi_d.x + sc[i0/2+1]*sumi_d.y);
+ }
+
+ return d6 * sumf_d;
+}
+
+static __device__ __forceinline__ float vec_dot_q4_0_q8_1(
+ const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+
+ const block_q4_0 * bq4_0 = (const block_q4_0 *) vbq;
+
+ int v[VDR_Q4_0_Q8_1_MMVQ];
+ int u[2*VDR_Q4_0_Q8_1_MMVQ];
+
+#pragma unroll
+ for (int i = 0; i < VDR_Q4_0_Q8_1_MMVQ; ++i) {
+ v[i] = get_int_from_uint8(bq4_0->qs, iqs + i);
+ u[2*i+0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);
+ u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI4_0);
+ }
+
+ return vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMVQ>(v, u, bq4_0->d, bq8_1->ds);
+}
+
+
+static __device__ __forceinline__ float vec_dot_q4_1_q8_1(
+ const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+
+ const block_q4_1 * bq4_1 = (const block_q4_1 *) vbq;
+
+ int v[VDR_Q4_1_Q8_1_MMVQ];
+ int u[2*VDR_Q4_1_Q8_1_MMVQ];
+
+#pragma unroll
+ for (int i = 0; i < VDR_Q4_1_Q8_1_MMVQ; ++i) {
+ v[i] = get_int_from_uint8_aligned(bq4_1->qs, iqs + i);
+ u[2*i+0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);
+ u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI4_1);
+ }
+
+ return vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMVQ>(v, u, bq4_1->dm, bq8_1->ds);
+}
+
+static __device__ __forceinline__ float vec_dot_q5_0_q8_1(
+ const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+
+ const block_q5_0 * bq5_0 = (const block_q5_0 *) vbq;
+
+ int vl[VDR_Q5_0_Q8_1_MMVQ];
+ int vh[VDR_Q5_0_Q8_1_MMVQ];
+ int u[2*VDR_Q5_0_Q8_1_MMVQ];
+
+#pragma unroll
+ for (int i = 0; i < VDR_Q5_0_Q8_1_MMVQ; ++i) {
+ vl[i] = get_int_from_uint8(bq5_0->qs, iqs + i);
+ vh[i] = get_int_from_uint8(bq5_0->qh, 0) >> (4 * (iqs + i));
+ u[2*i+0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);
+ u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI5_0);
+ }
+
+ return vec_dot_q5_0_q8_1_impl<VDR_Q5_0_Q8_1_MMVQ>(vl, vh, u, bq5_0->d, bq8_1->ds);
+}
+
+static __device__ __forceinline__ float vec_dot_q5_1_q8_1(
+ const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+
+ const block_q5_1 * bq5_1 = (const block_q5_1 *) vbq;
+
+ int vl[VDR_Q5_1_Q8_1_MMVQ];
+ int vh[VDR_Q5_1_Q8_1_MMVQ];
+ int u[2*VDR_Q5_1_Q8_1_MMVQ];
+
+#pragma unroll
+ for (int i = 0; i < VDR_Q5_1_Q8_1_MMVQ; ++i) {
+ vl[i] = get_int_from_uint8_aligned(bq5_1->qs, iqs + i);
+ vh[i] = get_int_from_uint8_aligned(bq5_1->qh, 0) >> (4 * (iqs + i));
+ u[2*i+0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);
+ u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI5_1);
+ }
+
+ return vec_dot_q5_1_q8_1_impl<VDR_Q5_1_Q8_1_MMVQ>(vl, vh, u, bq5_1->dm, bq8_1->ds);
+}
+
+static __device__ __forceinline__ float vec_dot_q8_0_q8_1(
+ const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+
+ const block_q8_0 * bq8_0 = (const block_q8_0 *) vbq;
+
+ int v[VDR_Q8_0_Q8_1_MMVQ];
+ int u[VDR_Q8_0_Q8_1_MMVQ];
+
+#pragma unroll
+ for (int i = 0; i < VDR_Q8_0_Q8_1_MMVQ; ++i) {
+ v[i] = get_int_from_int8(bq8_0->qs, iqs + i);
+ u[i] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);
+ }
+
+ return vec_dot_q8_0_q8_1_impl<VDR_Q8_0_Q8_1_MMVQ>(v, u, bq8_0->d, __low2half(bq8_1->ds));
+}
+
+static __device__ __forceinline__ float vec_dot_q2_K_q8_1(
+ const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+
+ const block_q2_K * bq2_K = (const block_q2_K *) vbq;
+
+ const int bq8_offset = QR2_K * (iqs / QI8_1);
+ const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1/2);
+
+ const uint8_t * scales = bq2_K->scales + scale_offset;
+
+ const int v = get_int_from_uint8_aligned(bq2_K->qs, iqs);
+ int u[QR2_K];
+ float d8[QR2_K];
+
+#pragma unroll
+ for (int i = 0; i < QR2_K; ++ i) {
+ u[i] = get_int_from_int8_aligned(bq8_1[bq8_offset + i].qs, iqs % QI8_1);
+ d8[i] = __low2float(bq8_1[bq8_offset + i].ds);
+ }
+
+ return vec_dot_q2_K_q8_1_impl_mmvq(v, u, scales, bq2_K->dm, d8);
+}
+
+static __device__ __forceinline__ float vec_dot_q3_K_q8_1(
+ const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+
+ const block_q3_K * bq3_K = (const block_q3_K *) vbq;
+
+ const int bq8_offset = QR3_K * (iqs / (QI3_K/2));
+ const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1/2);
+
+ const float d = bq3_K->d;
+
+ const int vl = get_int_from_uint8(bq3_K->qs, iqs);
+
+ // invert the mask with ~ so that a 0/1 results in 4/0 being subtracted
+ const int vh = ~get_int_from_uint8(bq3_K->hmask, iqs % (QI3_K/2)) >> bq8_offset;
+
+ int u[QR3_K];
+ float d8[QR3_K];
+
+#pragma unroll
+ for (int i = 0; i < QR3_K; ++i) {
+ u[i] = get_int_from_int8_aligned(bq8_1[bq8_offset + i].qs, iqs % QI8_1);
+ d8[i] = __low2float(bq8_1[bq8_offset + i].ds);
+ }
+
+ return vec_dot_q3_K_q8_1_impl_mmvq(vl, vh, u, bq3_K->scales, scale_offset, d, d8);
+}
+
+static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
+ const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+
+#ifndef GGML_QKK_64
+ const block_q4_K * bq4_K = (const block_q4_K *) vbq;
+
+ int v[2];
+ int u[2*QR4_K];
+ float d8[QR4_K];
+
+ // iqs is in 0,2..30. bq8_offset = iqs/4 -> bq8_offset = 0, 2, 4, 6
+ const int bq8_offset = QR4_K * ((iqs/2) / (QI8_1/2));
+
+ // iqs = 0....3 -> bq8_offset = 0, want q4_offset = 0, 4, 8, 12
+ // iqs = 4....7 -> bq8_offset = 2, want q4_offset = 32, 36, 40, 44
+ // iqs = 8...11 -> bq8_offset = 4, want q4_offset = 64, 68, 72, 76
+ // iqs = 12..15 -> bq8_offset = 6, want q4_offset = 96, 100, 104, 108
+
+ const int * q4 = (const int *)(bq4_K->qs + 16 * bq8_offset + 4 * ((iqs/2)%4));
+ v[0] = q4[0];
+ v[1] = q4[4];
+
+ const uint16_t * scales = (const uint16_t *)bq4_K->scales;
+ uint16_t aux[2];
+ const int j = bq8_offset/2;
+ if (j < 2) {
+ aux[0] = scales[j+0] & 0x3f3f;
+ aux[1] = scales[j+2] & 0x3f3f;
+ } else {
+ aux[0] = ((scales[j+2] >> 0) & 0x0f0f) | ((scales[j-2] & 0xc0c0) >> 2);
+ aux[1] = ((scales[j+2] >> 4) & 0x0f0f) | ((scales[j-0] & 0xc0c0) >> 2);
+ }
+ const uint8_t * sc = (const uint8_t *)aux;
+ const uint8_t * m = sc + 2;
+
+ for (int i = 0; i < QR4_K; ++i) {
+ const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
+ d8[i] = __low2float(bq8i->ds);
+
+ const int * q8 = (const int *)bq8i->qs + ((iqs/2)%4);
+ u[2*i+0] = q8[0];
+ u[2*i+1] = q8[4];
+ }
+
+ return vec_dot_q4_K_q8_1_impl_vmmq(v, u, sc, m, bq4_K->dm, d8);
+
+#else
+
+ const block_q4_K * bq4_K = (const block_q4_K *) vbq;
+
+ float sumf_d = 0.0f;
+ float sumf_m = 0.0f;
+
+ uint16_t aux16[2];
+ const uint8_t * s = (const uint8_t *)aux16;
+
+ const uint16_t * a = (const uint16_t *)bq4_K->scales;
+ aux16[0] = a[0] & 0x0f0f;
+ aux16[1] = (a[0] >> 4) & 0x0f0f;
+
+ const float dall = bq4_K->dm[0];
+ const float dmin = bq4_K->dm[1];
+
+ const float d8_1 = __low2float(bq8_1[0].ds);
+ const float d8_2 = __low2float(bq8_1[1].ds);
+
+ const int ui1 = *((const int *)bq8_1[0].qs + (iqs/2));
+ const int ui2 = *((const int *)bq8_1[0].qs + (iqs/2) + 4);
+ const int ui3 = *((const int *)bq8_1[1].qs + (iqs/2));
+ const int ui4 = *((const int *)bq8_1[1].qs + (iqs/2) + 4);
+
+ const int * q4 = (const int *)bq4_K->qs + (iqs/2);
+ const int v1 = q4[0];
+ const int v2 = q4[4];
+
+ const int dot1 = ggml_cuda_dp4a(ui2, v2 & 0x0f0f0f0f, ggml_cuda_dp4a(ui1, v1 & 0x0f0f0f0f, 0));
+ const int dot2 = ggml_cuda_dp4a(ui4, (v2 >> 4) & 0x0f0f0f0f, ggml_cuda_dp4a(ui3, (v1 >> 4) & 0x0f0f0f0f, 0));
+ const int dot3 = ggml_cuda_dp4a(0x01010101, ui2, ggml_cuda_dp4a(0x01010101, ui1, 0));
+ const int dot4 = ggml_cuda_dp4a(0x01010101, ui4, ggml_cuda_dp4a(0x01010101, ui3, 0));
+
+ sumf_d += d8_1 * (dot1 * s[0]) + d8_2 * (dot2 * s[1]);
+ sumf_m += d8_1 * (dot3 * s[2]) + d8_2 * (dot4 * s[3]);
+
+ return dall * sumf_d - dmin * sumf_m;
+#endif
+}
+
+static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
+ const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+
+#ifndef GGML_QKK_64
+ const block_q5_K * bq5_K = (const block_q5_K *) vbq;
+
+ int vl[2];
+ int vh[2];
+ int u[2*QR5_K];
+ float d8[QR5_K];
+
+ const int bq8_offset = QR5_K * ((iqs/2) / (QI8_1/2));
+ const int * ql = (const int *)(bq5_K->qs + 16 * bq8_offset + 4 * ((iqs/2)%4));
+ const int * qh = (const int *)(bq5_K->qh + 4 * ((iqs/2)%4));
+
+ vl[0] = ql[0];
+ vl[1] = ql[4];
+
+ vh[0] = qh[0] >> bq8_offset;
+ vh[1] = qh[4] >> bq8_offset;
+
+ const uint16_t * scales = (const uint16_t *)bq5_K->scales;
+ uint16_t aux[2];
+ const int j = bq8_offset/2;
+ if (j < 2) {
+ aux[0] = scales[j+0] & 0x3f3f;
+ aux[1] = scales[j+2] & 0x3f3f;
+ } else {
+ aux[0] = ((scales[j+2] >> 0) & 0x0f0f) | ((scales[j-2] & 0xc0c0) >> 2);
+ aux[1] = ((scales[j+2] >> 4) & 0x0f0f) | ((scales[j-0] & 0xc0c0) >> 2);
+ }
+ const uint8_t * sc = (const uint8_t *)aux;
+ const uint8_t * m = sc + 2;
+
+#pragma unroll
+ for (int i = 0; i < QR5_K; ++i) {
+ const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
+ d8[i] = __low2float(bq8i->ds);
+
+ const int * q8 = (const int *)bq8i->qs + ((iqs/2)%4);
+ u[2*i+0] = q8[0];
+ u[2*i+1] = q8[4];
+ }
+
+ return vec_dot_q5_K_q8_1_impl_vmmq(vl, vh, u, sc, m, bq5_K->dm, d8);
+
+#else
+
+ const block_q5_K * bq5_K = (const block_q5_K *) vbq;
+
+ const int8_t * s = bq5_K->scales;
+
+ const float d = bq5_K->d;
+
+ const float d8_1 = __low2half(bq8_1[0].ds);
+ const float d8_2 = __low2half(bq8_1[1].ds);
+
+ const int ui1 = *((const int *)bq8_1[0].qs + (iqs/2));
+ const int ui2 = *((const int *)bq8_1[0].qs + (iqs/2) + 4);
+ const int ui3 = *((const int *)bq8_1[1].qs + (iqs/2));
+ const int ui4 = *((const int *)bq8_1[1].qs + (iqs/2) + 4);
+
+ const int * ql = (const int *)bq5_K->qs + (iqs/2);
+ const int vl1 = ql[0];
+ const int vl2 = ql[4];
+
+ const int step = 4 * (iqs/2); // 0, 4, 8, 12
+ const int im = step/8; // = 0 for iqs = 0, 2, = 1 for iqs = 4, 6
+ const int in = step%8; // 0, 4, 0, 4
+ const int vh = (*((const int *)(bq5_K->qh + in))) >> im;
+
+ const int v1 = (((vh << 4) & 0x10101010) ^ 0x10101010) | ((vl1 >> 0) & 0x0f0f0f0f);
+ const int v2 = (((vh << 2) & 0x10101010) ^ 0x10101010) | ((vl2 >> 0) & 0x0f0f0f0f);
+ const int v3 = (((vh >> 0) & 0x10101010) ^ 0x10101010) | ((vl1 >> 4) & 0x0f0f0f0f);
+ const int v4 = (((vh >> 2) & 0x10101010) ^ 0x10101010) | ((vl2 >> 4) & 0x0f0f0f0f);
+
+ const float sumf_d = d8_1 * (ggml_cuda_dp4a(ui1, v1, 0) * s[0] + ggml_cuda_dp4a(ui2, v2, 0) * s[1])
+ + d8_2 * (ggml_cuda_dp4a(ui3, v3, 0) * s[2] + ggml_cuda_dp4a(ui4, v4, 0) * s[3]);
+
+ return d * sumf_d;
+#endif
+}
+
+static __device__ __forceinline__ float vec_dot_q6_K_q8_1(
+ const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+
+ const block_q6_K * bq6_K = (const block_q6_K *) vbq;
+
+ const int bq8_offset = 2 * QR6_K * (iqs / (QI6_K/2)) + (iqs % (QI6_K/2)) / (QI6_K/4);
+ const int scale_offset = (QI6_K/4) * (iqs / (QI6_K/2)) + (iqs % (QI6_K/2)) / (QI6_K/8);
+ const int vh_shift = 2 * ((iqs % (QI6_K/2)) / (QI6_K/4));
+
+ const int vl = get_int_from_uint8(bq6_K->ql, iqs);
+ const int vh = get_int_from_uint8(bq6_K->qh, (QI6_K/4) * (iqs / (QI6_K/2)) + iqs % (QI6_K/4)) >> vh_shift;
+
+ const int8_t * scales = bq6_K->scales + scale_offset;
+
+ int u[QR6_K];
+ float d8[QR6_K];
+
+#pragma unroll
+ for (int i = 0; i < QR6_K; ++i) {
+ u[i] = get_int_from_int8_aligned(bq8_1[bq8_offset + 2*i].qs, iqs % QI8_1);
+ d8[i] = __low2float(bq8_1[bq8_offset + 2*i].ds);
+ }
+
+ return vec_dot_q6_K_q8_1_impl_mmvq(vl, vh, u, scales, bq6_K->d, d8);
+}
+
+static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int kx, const int kx_padded) {
+ const int ix = blockDim.x*blockIdx.x + threadIdx.x;
+ if (ix >= kx_padded) {
+ return;
+ }
+ const int iy = blockDim.y*blockIdx.y + threadIdx.y;
+ const int i_padded = iy*kx_padded + ix;
+ block_q8_1 * y = (block_q8_1 *) vy;
+
+ const int ib = i_padded / QK8_1; // block index
+ const int iqs = i_padded % QK8_1; // quant index
+
+ const float xi = ix < kx ? x[iy*kx + ix] : 0.0f;
+ float amax = fabsf(xi);
+ float sum = xi;
+
+ amax = warp_reduce_max(amax);
+ sum = warp_reduce_sum(sum);
+
+ const float d = amax / 127;
+ const int8_t q = amax == 0.0f ? 0 : roundf(xi / d);
+
+ y[ib].qs[iqs] = q;
+ if (iqs > 0) {
+ return;
+ }
+ reinterpret_cast<half&>(y[ib].ds.x) = d;
+ reinterpret_cast<half&>(y[ib].ds.y) = sum;
+}
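+
+// Layout note (illustrative, numbers hypothetical): each block_q8_1 covers QK8_1 = 32
+// consecutive inputs and stores their int8 quants plus ds = (d, sum). For a block whose
+// largest magnitude is 2.54, d = 2.54 / 127 = 0.02 and the value 1.27 is stored as
+// round(1.27 / 0.02) = 64. The per-block sum is kept so the *_q8_1 dot products can
+// apply the q4_1/q5_1/K-quant offset terms without re-reading the activations.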
+
+template<typename dst_t>
+static __device__ __forceinline__ dst_t convert_from_half(half val) {
+ return val;
+}
+
+template<>
+__device__ __forceinline__ nv_bfloat16 convert_from_half<nv_bfloat16>(half val) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+ return __float2bfloat16(__half2float(val));
+#else
+ return __half2float(val);
+#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+}
+
+template<>
+__device__ __forceinline__ float convert_from_half<float>(half val) {
+ return __half2float(val);
+}
+
+template<typename dst_t>
+inline __device__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+ const auto i = 0; // we only dequantize one block per call
+ const block_q2_K * x = (const block_q2_K *) vx;
+
+ const auto tid = threadIdx.x;
+ const int n = tid/32;
+ const int l = tid - 32*n;
+ const int is = 8*n + l/16;
+
+ const uint8_t q = x[i].qs[32*n + l];
+ dst_t * y = yy + i*QK_K + 128*n;
+
+ half dall = __low2half(x[i].dm);
+ half dmin = __high2half(x[i].dm);
+ y[l+ 0] = convert_from_half<dst_t>(__hsub(__hmul(dall, __int2half_rn((x[i].scales[is+0] & 0xF) * ((q >> 0) & 3))), __hmul(dmin, __int2half_rn(x[i].scales[is+0] >> 4))));
+ y[l+32] = convert_from_half<dst_t>(__hsub(__hmul(dall, __int2half_rn((x[i].scales[is+2] & 0xF) * ((q >> 2) & 3))), __hmul(dmin, __int2half_rn(x[i].scales[is+2] >> 4))));
+ y[l+64] = convert_from_half<dst_t>(__hsub(__hmul(dall, __int2half_rn((x[i].scales[is+4] & 0xF) * ((q >> 4) & 3))), __hmul(dmin, __int2half_rn(x[i].scales[is+4] >> 4))));
+ y[l+96] = convert_from_half<dst_t>(__hsub(__hmul(dall, __int2half_rn((x[i].scales[is+6] & 0xF) * ((q >> 6) & 3))), __hmul(dmin, __int2half_rn(x[i].scales[is+6] >> 4))));
+}
+
+template<typename dst_t>
+inline __device__ void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+ const auto i = 0;
+ const block_q3_K * x = (const block_q3_K *) vx;
+
+ const auto r = threadIdx.x/4;
+ const int tid = r/2;
+ const int is0 = r%2;
+ const int l0 = 16*is0 + 4*(threadIdx.x%4);
+ const int n = tid / 4;
+ const int j = tid - 4*n;
+
+ uint8_t m = 1 << (4*n + j);
+ int is = 8*n + 2*j + is0;
+ int shift = 2*j;
+
+ int8_t us = is < 4 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+8] >> 0) & 3) << 4) :
+ is < 8 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+4] >> 2) & 3) << 4) :
+ is < 12 ? (x[i].scales[is-8] >> 4) | (((x[i].scales[is+0] >> 4) & 3) << 4) :
+ (x[i].scales[is-8] >> 4) | (((x[i].scales[is-4] >> 6) & 3) << 4);
+ half d_all = x[i].d;
+ half dl = __hmul(d_all, __int2half_rn(us - 32));
+
+ dst_t * y = yy + i*QK_K + 128*n + 32*j;
+ const uint8_t * q = x[i].qs + 32*n;
+ const uint8_t * hm = x[i].hmask;
+
+ for (int l = l0; l < l0+4; ++l) {
+ y[l] = convert_from_half<dst_t>(__hmul(dl, __int2half_rn((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4))));
+ }
+}
+
+static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) {
+ if (j < 4) {
+ d = q[j] & 63; m = q[j + 4] & 63;
+ } else {
+ d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
+ m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
+ }
+}
+
+template<typename dst_t>
+inline __device__ void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+ const block_q4_K * x = (const block_q4_K *) vx;
+
+ const auto i = 0;
+
+ // assume 32 threads
+ const auto tid = threadIdx.x;
+ const int il = tid/8;
+ const int ir = tid%8;
+ const int is = 2*il;
+ const int n = 4;
+
+ dst_t * y = yy + i*QK_K + 64*il + n*ir;
+
+ const half dall = __low2half(x[i].dm);
+ const half dmin = __high2half(x[i].dm);
+
+ const uint8_t * q = x[i].qs + 32*il + n*ir;
+
+ uint8_t sc, m;
+ get_scale_min_k4(is + 0, x[i].scales, sc, m);
+ const half d1 = __hmul(dall, __int2half_rn(sc));
+ const half m1 = __hmul(dmin, __int2half_rn(m));
+ get_scale_min_k4(is + 1, x[i].scales, sc, m);
+ const half d2 = __hmul(dall, __int2half_rn(sc));
+ const half m2 = __hmul(dmin, __int2half_rn(m));
+ for (int l = 0; l < n; ++l) {
+ y[l + 0] = convert_from_half<dst_t>(__hsub(__hmul(d1, __int2half_rn(q[l] & 0xF)), m1));
+ y[l +32] = convert_from_half<dst_t>(__hsub(__hmul(d2, __int2half_rn(q[l] >> 4)), m2));
+ }
+}
+
+template<typename dst_t>
+inline __device__ void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+ const block_q5_K * x = (const block_q5_K *) vx;
+
+ const auto i = 0;
+
+ // assume 64 threads per block
+ const auto tid = threadIdx.x;
+ const int il = tid/16; // il is in 0...3
+ const int ir = tid%16; // ir is in 0...15
+ const int is = 2*il; // is is in 0...6
+
+ dst_t * y = yy + i*QK_K + 64*il + 2*ir;
+
+ const half dall = __low2half(x[i].dm);
+ const half dmin = __high2half(x[i].dm);
+
+ const uint8_t * ql = x[i].qs + 32*il + 2*ir;
+ const uint8_t * qh = x[i].qh + 2*ir;
+
+ uint8_t sc, m;
+ get_scale_min_k4(is + 0, x[i].scales, sc, m);
+ const half d1 = __hmul(dall, __int2half_rn(sc)); const half m1 = __hmul(dmin, __int2half_rn(m));
+ get_scale_min_k4(is + 1, x[i].scales, sc, m);
+ const half d2 = __hmul(dall, __int2half_rn(sc)); const half m2 = __hmul(dmin, __int2half_rn(m));
+
+ uint8_t hm = 1 << (2*il);
+ y[ 0] = convert_from_half<dst_t>(__hsub(__hmul(d1, __int2half_rn((ql[0] & 0xF) + (qh[0] & hm ? 16 : 0))), m1));
+ y[ 1] = convert_from_half<dst_t>(__hsub(__hmul(d1, __int2half_rn((ql[1] & 0xF) + (qh[1] & hm ? 16 : 0))), m1));
+ hm <<= 1;
+ y[32] = convert_from_half<dst_t>(__hsub(__hmul(d2, __int2half_rn((ql[0] >> 4) + (qh[0] & hm ? 16 : 0))), m2));
+ y[33] = convert_from_half<dst_t>(__hsub(__hmul(d2, __int2half_rn((ql[1] >> 4) + (qh[1] & hm ? 16 : 0))), m2));
+}
+
+template<typename dst_t>
+inline __device__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+ const block_q6_K * x = (const block_q6_K *) vx;
+
+ const auto i = 0;
+
+ // assume 64 threads per block
+ const auto tid = threadIdx.x;
+ const int ip = tid/32; // ip is 0 or 1
+ const int il = tid - 32*ip; // 0...32
+ const int is = 8*ip + il/16;
+
+ dst_t * y = yy + i*QK_K + 128*ip + il;
+
+ const half d = x[i].d;
+
+ const uint8_t * ql = x[i].ql + 64*ip + il;
+ const uint8_t qh = x[i].qh[32*ip + il];
+ const int8_t * sc = x[i].scales + is;
+
+ y[ 0] = convert_from_half<dst_t>(__hmul(d, __int2half_rn(sc[0] * ((int8_t)((ql[ 0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32))));
+ y[32] = convert_from_half<dst_t>(__hmul(d, __int2half_rn(sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32))));
+ y[64] = convert_from_half<dst_t>(__hmul(d, __int2half_rn(sc[4] * ((int8_t)((ql[ 0] >> 4) | (((qh >> 4) & 3) << 4)) - 32))));
+ y[96] = convert_from_half<dst_t>(__hmul(d, __int2half_rn(sc[6] * ((int8_t)((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32))));
+}
\ No newline at end of file
diff --git a/candle-kernels/src/moe/moe_gguf.cu b/candle-kernels/src/moe/moe_gguf.cu
new file mode 100644
index 0000000000..92704e6aad
--- /dev/null
+++ b/candle-kernels/src/moe/moe_gguf.cu
@@ -0,0 +1,216 @@
+/**
+ * @brief CUDA kernel for Mixture-of-Experts (MoE) GEMM using GGUF quantized weights.
+ *
+ * This kernel performs a dot-product between quantized input tokens and
+ * quantized expert weight matrices, accumulating into float outputs.
+ * It supports per-token top-k weighting and tiling along the K dimension
+ * for efficient vectorized execution.
+ *
+ * Adapted from: https://github.com/guoqingbao/attention.rs/tree/main/src/kernels/src/moe_gemm_gguf.cu
+ */
+#include "gguf.cuh"
+#include <cuda_runtime.h>
+#include <cuda_fp16.h>
+#include <cuda_bf16.h>
+#include <cstdint>
+#include <cstdio>
+constexpr int MATRIX_ROW_PADDING = 512;
+
+constexpr int pad(int size, int padding) {
+ if (padding == 0) return size; // avoid divide-by-zero
+ return ((size + padding - 1) / padding) * padding;
+}
+
+// Optional helper if you want ceil division explicitly
+constexpr int ceil_div(int a, int b) {
+ return (a + b - 1) / b;
+}
+
+namespace vllm_rs {
+
+/*
+ * Template Parameters:
+ * @tparam qk Quantization block size of the weights (e.g., 32 for Q8_0, 256 for K-quants)
+ * @tparam qi Number of 32-bit integers of quantized data per weight block (e.g., QI8_0)
+ * @tparam block_q_t Type of quantized weight block (e.g., block_q8_0)
+ * @tparam vdr Vectorization factor (number of elements per lane)
+ * @tparam vec_dot_q_cuda Function for computing vectorized dot-product between quantized blocks
+ *
+ * Kernel Parameters:
+ * @param all_weights Pointer to all expert weight matrices, [num_experts, N, K] (quantized)
+ * @param all_inputs Pointer to all input tokens, [M_total, K] (quantized)
+ * @param sorted_token_ids Sorted token indices for batch processing
+ * @param expert_ids Expert ID for each token
+ * @param topk_weights Optional top-k MoE weight per token
+ * @param all_outputs Output buffer [M_total, N] (float)
+ * @param num_experts Number of experts
+ * @param topk Top-k experts selected per token
+ * @param size_m Number of tokens processed (M dimension)
+ * @param size_n Output feature dimension (N dimension)
+ * @param size_k Input feature dimension (K dimension)
+ * @param k_padded Padded K dimension for GGUF stride
+*/
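+/*
+ * Indexing sketch (hypothetical values, mirrors the code below): with topk = 2 and
+ * topk_weights == nullptr, the token at sorted position m with sorted_token_ids[m] = 5
+ * reads its quantized activations from row 5 / 2 = 2 of all_inputs and writes row 5 of
+ * all_outputs; if topk_weights is provided, the activations come from row 5 directly
+ * and the result is scaled by topk_weights[5].
+*/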
+template <int qk, int qi, typename block_q_t, int vdr, float (*vec_dot_q_cuda)(const void *, const block_q8_1 *, const int &)>
+__global__ void moe_gemm_gguf_kernel(
+ const void * __restrict__ all_weights, // [num_experts, N, K] (quantized)
+ const void * __restrict__ all_inputs, // [M_total, K] (quantized, M_total is total tokens)
+ const int32_t* __restrict__ sorted_token_ids,// [M] (M = num tokens processed)
+ const int32_t* __restrict__ expert_ids, // [M]
+ const float* __restrict__ topk_weights, // [M]
+ float * __restrict__ all_outputs, // [M_total, N] (float)
+ int num_experts,
+ int topk,
+ int size_m, int size_n, int size_k, // M, N, K are the logical dims
+ int k_padded // Padded K-dim for GGUF stride
+) {
+ const int laneId = threadIdx.x;
+ const int wrapId = threadIdx.y;
+ const int nWraps = blockDim.y;
+ const int row = blockIdx.x * nWraps + wrapId; // This is the 'n' dimension (output row)
+ const int m_idx = blockIdx.y; // This is the 'm' dimension (token index)
+
+ // This block computes the dot product for `output[token_id][n_row]`
+
+ if (row >= size_n || m_idx >= size_m) {
+ return;
+ }
+
+ // strides
+ const size_t weight_expert_stride_bytes = (size_t)(size_n * size_k) / qk * sizeof(block_q_t);
+ const size_t input_task_stride_bytes = (size_t)k_padded / QK8_1 * sizeof(block_q8_1);
+ const size_t output_task_stride_elems = (size_t)size_n;
+
+ const int token_id = sorted_token_ids[m_idx]; // The *actual* row in input/output tensors
+ const int expert = expert_ids[m_idx];
+
+ // If expert is invalid, this token does not participate.
+ if (expert < 0 || expert >= num_experts) return;
+
+ // Get the scaling factor for this token/expert pair
+ const float scale = (topk_weights) ? topk_weights[token_id] : 1.0f;
+
+ const block_q_t * __restrict__ w_expert =
+ (const block_q_t *)((const char *)all_weights + (size_t)expert * weight_expert_stride_bytes);
+
+ const int input_index = topk_weights ? token_id : (token_id / topk);
+ const block_q8_1 * __restrict__ y_ptr =
+ (const block_q8_1 *)((const char *)all_inputs + (size_t)input_index * input_task_stride_bytes);
+
+ // dot-product tiling along k
+ const int blocks_per_row_x = size_k / qk;
+ const int blocks_per_iter = vdr * WARP_SIZE / qi; // no nwarps factor: one warp per batch item
+
+ extern __shared__ int8_t shared_bytes[];
+ block_q_t* w_shared_row = reinterpret_cast<block_q_t*>(shared_bytes);
+ for (int i = laneId; i < blocks_per_row_x; i += WARP_SIZE) {
+ w_shared_row[wrapId * blocks_per_row_x + i] = w_expert[row * blocks_per_row_x + i];
+ }
+ __syncthreads();
+
+ // accumulators for rows_per_block rows (usually 1)
+ float acc = 0.0f;
+
+ #pragma unroll
+ for (int kbx = laneId / (qi / vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) {
+ const int kby = kbx * (qk / QK8_1);
+ const int kqs = vdr * (laneId % (qi / vdr));
+ acc += vec_dot_q_cuda(
+ // &w_expert[kbx + row * blocks_per_row_x],
+ &w_shared_row[wrapId * blocks_per_row_x + kbx],
+ &y_ptr[kby],
+ kqs);
+ }
+
+ float v = warp_reduce_sum(acc) * scale;
+ if (laneId == 0) {
+ float * __restrict__ out_ptr =
+ all_outputs + ((size_t)token_id) * output_task_stride_elems;
+ out_ptr[row] = v;
+ }
+}
+
+}
+
+#define LAUNCH_MOE_GGUF(qk, qi, block_q_t, vdr, vec_dot_q_cuda) \
+ const int shared_bytes = size_k / qk * sizeof(block_q_t) * nWraps + 1024;\
+ vllm_rs::moe_gemm_gguf_kernel<qk, qi, block_q_t, vdr, vec_dot_q_cuda> \
+ <<<grid_dim, block_dim, shared_bytes, stream>>>(\
+ weights, y_q8_1,\
+ sorted_token_ids, expert_ids, topk_weights,\
+ outputs,\
+ num_experts, topk,\
+ size_m, size_n, size_k,\
+ kx_padded\
+ );\
+
+
+extern "C" void moe_gemm_gguf(
+ const float* inputs, //must be float
+ const void* weights,
+ const int32_t* sorted_token_ids,
+ const int32_t* expert_ids,
+ const float* topk_weights,
+ float* outputs,
+ int num_experts,
+ int topk,
+ int size_m, // M (num tokens to process)
+ int size_n, // N (output dim)
+ int size_k, // K (input dim)
+ int quant_type, // Q8_0: 0, Q4K: 1, Q2K: 2, Q3k: 3, Q5K: 4, Q6K: 5,
+ cudaStream_t stream
+) {
+ const int QUANTIZE_BLOCK_SIZE = CUDA_QUANTIZE_BLOCK_SIZE;
+ const int kx_padded = pad(size_k, MATRIX_ROW_PADDING);
+ const int num_blocks = ceil_div(kx_padded, QUANTIZE_BLOCK_SIZE);
+ int m = topk_weights ? size_m : size_m / topk;
+ dim3 grid_dim_quant(num_blocks, m, 1);
+ dim3 block_dim_quant(QUANTIZE_BLOCK_SIZE, 1, 1);
+ int y_size_in_bytes =
+ m * (kx_padded / QK8_1 * sizeof(block_q8_1));
+ void* y_q8_1 = nullptr;
+ cudaMallocAsync(&y_q8_1, y_size_in_bytes, stream);
+ quantize_q8_1<<<grid_dim_quant, block_dim_quant, 0, stream>>>(inputs, y_q8_1, size_k, kx_padded);
+
+ const int nWraps = 4;
+ dim3 grid_dim(ceil_div(size_n, nWraps), size_m, 1);
+ dim3 block_dim(WARP_SIZE, nWraps, 1);
+
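+ // Shared-memory sizing example (hypothetical shapes): for Q8_0 with size_k = 2048,
+ // each of the nWraps = 4 warps stages 2048 / 32 = 64 blocks of 34 bytes (~2.1 KiB),
+ // so LAUNCH_MOE_GGUF requests 4 * 2176 + 1024 = 9728 bytes (~9.5 KiB) of dynamic shared memory.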
+ //Q8_0: 0, Q4K: 1, Q2K: 2, Q3k: 3, Q5K: 4, Q6K: 5,
+ switch (quant_type) {
+ case 0: // Q8_0
+ {
+ LAUNCH_MOE_GGUF(QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1);
+ break;
+ }
+ case 1: // Q4K
+ {
+ LAUNCH_MOE_GGUF(QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1);
+ break;
+ }
+ case 2: // Q2_K
+ {
+ LAUNCH_MOE_GGUF(QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1);
+ break;
+ }
+ case 3: // Q3_K
+ {
+ LAUNCH_MOE_GGUF(QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1);
+ break;
+ }
+ case 4: // Q5_K
+ {
+ LAUNCH_MOE_GGUF(QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1);
+ break;
+ }
+ case 5: // Q6K
+ {
+ LAUNCH_MOE_GGUF(QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1);
+ break;
+ }
+ default:
+ break;
+ }
+ cudaFreeAsync(y_q8_1, stream);
+}
\ No newline at end of file
diff --git a/candle-kernels/src/moe/moe_utils.cuh b/candle-kernels/src/moe/moe_utils.cuh
new file mode 100644
index 0000000000..596434088c
--- /dev/null
+++ b/candle-kernels/src/moe/moe_utils.cuh
@@ -0,0 +1,188 @@
+#undef __CUDA_FP8_TYPES_EXIST__
+#include <cuda_runtime.h>
+#include <cuda_fp16.h>
+#include <cuda_bf16.h>
+#include <cstdint>
+#include <thrust/device_ptr.h>
+#include <thrust/scan.h>
+#include <thrust/execution_policy.h>
+
+/**
+ * @brief Counts the number of tokens assigned to each expert.
+ *
+ * @param expert_ids Device pointer to the sorted expert IDs [size_m].
+ * @param expert_counts Device pointer to the output counts [num_experts]
+ * (must be pre-initialized to zero).
+ * @param size_m Total number of tokens.
+ */
+static __global__ void count_tokens_per_expert_kernel(
+ const int32_t* expert_ids,
+ int32_t* expert_counts,
+ int size_m)
+{
+ int i = blockIdx.x * blockDim.x + threadIdx.x;
+ if (i < size_m) {
+ int32_t expert_id = expert_ids[i];
+ // expert_id is from a sorted list, so we assume it's valid
+ // (i.e., 0 <= expert_id < num_experts)
+ atomicAdd(&expert_counts[expert_id], 1);
+ }
+}
+
+/**
+ * @brief Calculates expert offsets array on the GPU.
+ *
+ * @param d_expert_ids Device pointer to sorted expert IDs [size_m].
+ * @param size_m Total number of tokens.
+ * @param d_expert_counts Device pointer to scratch space for per-expert counts [num_experts].
+ * @param d_expert_offsets Device pointer for output offsets [num_experts + 1].
+ * @param num_experts Number of experts.
+ * @param stream CUDA stream.
+ */
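+// Worked example (illustrative): expert_ids = [0, 0, 1, 3, 3, 3] with num_experts = 4
+// yields counts = [2, 1, 0, 3] and offsets = [0, 2, 3, 3, 6], so expert e's tokens are
+// the sorted positions [offsets[e], offsets[e+1]).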
+static void calculate_expert_offsets(
+ const int32_t* d_expert_ids,
+ int size_m,
+ int32_t* d_expert_counts,
+ int32_t* d_expert_offsets,
+ int num_experts,
+ cudaStream_t stream
+) {
+ // 1. Zero-initialize the counts buffer
+ cudaMemsetAsync(d_expert_counts, 0, num_experts * sizeof(int32_t), stream);
+
+ // 2. Launch kernel to count tokens per expert
+ int threads = 256;
+ int blocks = (size_m + threads - 1) / threads;
+ count_tokens_per_expert_kernel<<>>(
+ d_expert_ids, d_expert_counts, size_m
+ );
+
+ // 3. Perform prefix sum (scan)
+ // We will use inclusive_scan on [counts] and store results in [offsets + 1]
+ // This is a common and efficient pattern.
+
+ // Wrap raw pointers for Thrust
+ thrust::device_ptr<int32_t> d_counts_ptr(d_expert_counts);
+ thrust::device_ptr<int32_t> d_offsets_ptr(d_expert_offsets);
+
+ // Run inclusive scan.
+ // Input: [c0, c1, c2, ...] (size num_experts)
+ // Output: [c0, c0+c1, c0+c1+c2, ...] (stored at offsets[1])
+ thrust::inclusive_scan(
+ thrust::cuda::par.on(stream), // Execute on the specified stream
+ d_counts_ptr, // Input start
+ d_counts_ptr + num_experts, // Input end
+ d_offsets_ptr + 1 // Output start (shifted by 1)
+ );
+
+ // 4. Set the first offset (offsets[0]) to 0
+ // This completes the exclusive scan.
+ cudaMemsetAsync(d_expert_offsets, 0, sizeof(int32_t), stream);
+}
+
+
+// This performs an EXCLUSIVE scan: [c0, c1] -> [0, c0, c0+c1]
+// Assumptions: num_experts <= 1024 (fits in one block)
+static __global__ void expert_prefix_sum_kernel(
+ const int32_t* __restrict__ counts,
+ int32_t* __restrict__ offsets,
+ int num_experts
+) {
+ // Use shared memory for fast scanning
+ // Size needs to be enough for num_experts
+ extern __shared__ int32_t temp_storage[];
+
+ int tid = threadIdx.x;
+
+ // We pad with 0 if tid >= num_experts
+ int val = (tid < num_experts) ? counts[tid] : 0;
+ temp_storage[tid] = val;
+
+ __syncthreads();
+
+ // Hillis-Steele Parallel Scan (Inclusive in shared mem)
+ for (int offset = 1; offset < blockDim.x; offset <<= 1) {
+ int temp_val = 0;
+ if (tid >= offset) {
+ temp_val = temp_storage[tid - offset];
+ }
+ __syncthreads();
+ if (tid >= offset) {
+ temp_storage[tid] += temp_val;
+ }
+ __syncthreads();
+ }
+
+ // The result at temp_storage[i] is the inclusive sum of counts[0..i]
+ // We want offsets[i] = inclusive_sum[i-1]
+ // We want offsets[0] = 0
+
+ if (tid < num_experts) {
+ // Shift right: Offset[i+1] gets the inclusive sum up to i
+ offsets[tid + 1] = temp_storage[tid];
+
+ // Handle the first element separately
+ if (tid == 0) {
+ offsets[0] = 0;
+ }
+ }
+}
+
+static void calculate_expert_offsets_light(
+ const int32_t* d_expert_ids,
+ int size_m,
+ int32_t* d_expert_counts,
+ int32_t* d_expert_offsets,
+ int num_experts,
+ cudaStream_t stream
+) {
+ cudaMemsetAsync(d_expert_counts, 0, num_experts * sizeof(int32_t), stream);
+
+ int threads = 256;
+ int blocks = (size_m + threads - 1) / threads;
+ count_tokens_per_expert_kernel<<>>(
+ d_expert_ids, d_expert_counts, size_m
+ );
+
+ // We launch exactly one block with 'num_experts' threads (or next power of 2)
+ // We need shared memory size = threads * sizeof(int32_t)
+ int scan_threads = num_experts;
+
+ // Round up scan_threads to next power of 2 if needed,
+ // or just use a fixed size like 1024 if num_experts is small enough.
+ if (scan_threads < 32) scan_threads = 32;
+ else if (scan_threads > 1024) {
+ // Error: This custom kernel only supports up to 1024 experts
+ // Handle error or assert here
+ }
+
+ size_t smem_size = scan_threads * sizeof(int32_t);
+
+ expert_prefix_sum_kernel<<<1, scan_threads, smem_size, stream>>>(
+ d_expert_counts,
+ d_expert_offsets,
+ num_experts
+ );
+}
+
+namespace vllm_rs {
+
+inline __device__ uint16_t float_to_half(float f) {
+ union {
+ uint32_t u32;
+ uint16_t u16[2];
+ } tmp;
+#ifndef USE_ROCM
+ asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f));
+#else
+ asm volatile("v_cvt_f16_f32 %0, %1;\n" : "=v"(tmp.u32) : "v"(f));
+#endif
+ return tmp.u16[0];
+}
+
+inline __device__ void from_float(half& dst, float src) {
+ // float_to_half returns the raw bit pattern of the half; reinterpret the bits
+ // instead of converting the integer value numerically.
+ const uint16_t h = float_to_half(src);
+ dst = *reinterpret_cast<const half*>(&h);
+}
+
+inline __device__ void from_float(__nv_bfloat16& dst, float src) {
+ dst = __float2bfloat16(src);
+}
+
+}
\ No newline at end of file
diff --git a/candle-kernels/src/moe/moe_wmma.cu b/candle-kernels/src/moe/moe_wmma.cu
new file mode 100644
index 0000000000..de6a90993b
--- /dev/null
+++ b/candle-kernels/src/moe/moe_wmma.cu
@@ -0,0 +1,283 @@
+/**
+ * @brief WMMA-based grouped MoE GEMM kernel.
+ *
+ * Each block computes a tile of the output corresponding to:
+ * - One expert segment (group of tokens routed to the same expert)
+ * - One N-dimension tile (a sub-block of the expert's output features)
+ *
+ * The kernel loads input activations and expert weights in tiles using shared memory,
+ * performs matrix multiplication using Tensor Cores (WMMA), and accumulates results
+ * into a shared C tile. The finished tile is then scattered into the global output
+ * buffer through sorted_token_ids, so each (token, expert) pair owns its own output
+ * row and no atomics are needed for top-k > 1 routing.
+ *
+ * Adapted from https://github.com/guoqingbao/attention.rs/tree/main/src/kernels/src/moe_gemm_wmma.cu
+ */
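+/*
+ * Launch-shape sketch (hypothetical sizes): with num_experts = 8 and size_n = 768 the
+ * host launches an 8 x ceil(768 / N_BLK) = 8 x 24 grid; block (e, j) walks expert e's
+ * token segment in M_BLK = 32 row chunks and produces output columns [32*j, 32*j + 32).
+ */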
+
+#include <mma.h>
+#include <cuda_runtime.h>
+#include <cuda_fp16.h>
+#include <cuda_bf16.h>
+#include <cstdint>
+#include <cstdio>
+#include "moe_utils.cuh"
+using namespace nvcuda::wmma;
+
+namespace vllm_rs {
+
+#define CEILDIV(x,y) (((x) + (y) - 1) / (y))
+
+constexpr int WMMA_K = 16;
+using VecT = float4;
+
+// Vectorized load size (float4 = 128 bits = 8 half/bfloat16 values)
+constexpr int VEC_SIZE = 8;
+constexpr int NUM_VECS = 32;
+
+// We use 4 Warps (128 threads) per block
+constexpr int WARPS_PER_BLOCK = 4; // 4 warps
+constexpr int BLOCK_THREADS = 128; // 128 threads
+
+constexpr int M_BLK = 32;
+constexpr int N_BLK = 32;
+constexpr int K_BLK = WMMA_K; // 16
+
+
+/**
+ * @brief WMMA-based grouped MoE GEMM kernel.
+ *
+ * @tparam T Data type: half or nv_bfloat16
+ * @tparam WMMA_M, WMMA_N Per-warp WMMA tile shape (WMMA_K is fixed at 16)
+ * @tparam WARPS_N Number of warps tiling the N dimension of each block
+ *
+ * @param input [size_m or size_m/topk, size_k]
+ * @param weights [num_experts, size_n, size_k] compacted expert weights
+ * @param sorted_token_ids [size_m] mapping of per-token row indices (sorted by expert)
+ * @param expert_offsets [num_experts + 1] prefix-sum offsets; expert e owns sorted tokens [offsets[e], offsets[e+1])
+ * @param topk_weights [size_m] optional per-token scaling weights (nullptr if unused)
+ * @param output [size_m, size_n] global output buffer (must be zero-initialized)
+ * @param num_experts Total number of experts
+ * @param topk Number of experts each token is routed to
+ * @param size_m Number of tokens
+ * @param size_n Output hidden dimension (per expert)
+ * @param size_k Input hidden dimension
+*/
+template <typename T, int WMMA_M, int WMMA_N, int WARPS_N>
+__global__ void moe_gemm_grouped_kernel(
+ const T* __restrict__ input, // [size_m, size_k]
+ const T* __restrict__ weights, // [num_experts, size_n, size_k]
+ const int32_t* __restrict__ sorted_token_ids, // [size_m]
+ const int32_t* __restrict__ expert_offsets, // [num_experts]
+ const float* __restrict__ topk_weights, // [size_m]
+ T* __restrict__ output, // [size_m, size_n] (Zero-initialized)
+ const int num_experts, const int topk,
+ const int32_t size_m,
+ const int32_t size_n,
+ const int32_t size_k
+) {
+ // Get Segment and N-Tile for this Block
+ const int expert_id = blockIdx.x;
+ const int n_tile_idx = blockIdx.y;
+ if (expert_id < 0 || expert_id >= num_experts) return;
+ const int segment_start = expert_offsets[expert_id];
+ const int segment_end = expert_offsets[expert_id + 1];
+ const int num_rows_in_segment = segment_end - segment_start;
+
+ if (num_rows_in_segment == 0) return;
+
+ const int n_base = n_tile_idx * N_BLK;
+ if (n_base >= size_n) return;
+
+ const T* expert_w = weights + (size_t)expert_id * (size_t)size_n * (size_t)size_k;
+
+ extern __shared__ uint8_t smem_bytes[];
+
+ // A tile: [M_BLK, K_BLK] (row-major)
+ T* A_sh = reinterpret_cast<T*>(smem_bytes);
+ // B tile: [N_BLK, K_BLK] (row-major)
+ T* B_sh = reinterpret_cast<T*>(A_sh + M_BLK * K_BLK);
+ uint8_t* C_ptr = reinterpret_cast<uint8_t*>(B_sh + N_BLK * K_BLK);
+
+ // align next pointer to float alignment
+ size_t offset = reinterpret_cast<size_t>(C_ptr) % alignof(float);
+ if (offset != 0) {
+ C_ptr += (alignof(float) - offset);
+ }
+ float* C_sh = reinterpret_cast<float*>(C_ptr); // shared scratch for final per-block tile writes
+
+ const int threadId = threadIdx.x;
+ const int warpId = threadId / 32;
+ const int laneId = threadId % 32;
+ const int warp_m_idx = warpId / WARPS_N;
+ const int warp_n_idx = warpId % WARPS_N;
+
+ const int B_ELEMS_PER_BLOCK = N_BLK * K_BLK;
+ const int VEC_ELEMS_B = B_ELEMS_PER_BLOCK / VEC_SIZE; // 512 / 8 = 64
+ const int A_ELEMS_PER_BLOCK = M_BLK * K_BLK;
+ const int VEC_ELEMS_A = A_ELEMS_PER_BLOCK / VEC_SIZE; // 512 / 8 = 64
+ VecT zero_vec;
+ zero_vec.x = zero_vec.y = zero_vec.z = zero_vec.w = 0.0f;
+
+ for (int m_base = 0; m_base < num_rows_in_segment; m_base += M_BLK) {
+ // We'll accumulate full-K results in per-warp fragments (initialized here)
+ fragment<accumulator, WMMA_M, WMMA_N, WMMA_K, float> c_frag;
+ fill_fragment(c_frag, 0.0f);
+
+ // For every k_block we will load B_sh and A_sh for this m_base subsequently
+ for (int k_base = 0; k_base < size_k; k_base += K_BLK) {
+ // Load B Tile (Weights) into B_sh
+ for (int i = threadId; i < VEC_ELEMS_B; i += BLOCK_THREADS) {
+ int idx = i * VEC_SIZE; // element index (0..511)
+ int n_local = idx / K_BLK;
+ int k_local = idx % K_BLK;
+
+ int n_global = n_base + n_local;
+ int k_global = k_base + k_local;
+
+ // this should be always satisfied since k dim aligned to 8
+ if (n_global < size_n && k_global < size_k) {
+ *reinterpret_cast<VecT*>(&B_sh[n_local * K_BLK + k_local]) = *reinterpret_cast<const VecT*>(
+ &expert_w[(size_t)n_global * size_k + k_global]
+ );
+ } else {
+ *reinterpret_cast<VecT*>(&B_sh[n_local * K_BLK + k_local]) = zero_vec;
+ }
+ }
+
+ // Load A Tile (Inputs) into A_sh for this m_base and this k_base
+ for (int i = threadId; i < VEC_ELEMS_A; i += BLOCK_THREADS) {
+ int idx = i * VEC_SIZE; // element index
+ int m_local = idx / K_BLK;
+ int k_local = idx % K_BLK;
+
+ int m_seg = m_base + m_local; // row index within segment
+ int k_global = k_base + k_local;
+
+ if (m_seg < num_rows_in_segment && k_global < size_k) {
+ int token_pair_index = segment_start + m_seg;
+ int token_index = sorted_token_ids[token_pair_index];
+ int input_index = token_index / (topk_weights? 1: topk);
+ *reinterpret_cast<VecT*>(&A_sh[m_local * K_BLK + k_local]) = *reinterpret_cast<const VecT*>(
+ &input[(size_t)input_index * size_k + k_global]
+ );
+ } else {
+ // in case m dim in this segment not aligned to 8
+ *reinterpret_cast<VecT*>(&A_sh[m_local * K_BLK + k_local]) = zero_vec;
+ }
+ }
+
+ __syncthreads();
+
+ // Compute (Warp-level) : update c_frag for this k_block
+ fragment<matrix_a, WMMA_M, WMMA_N, WMMA_K, T, row_major> a_frag;
+ fragment<matrix_b, WMMA_M, WMMA_N, WMMA_K, T, col_major> b_frag;
+
+ // Point this warp to its tile in shared memory
+ const T* A_sh_ptr = A_sh + (warp_m_idx * WMMA_M * K_BLK);
+ const T* B_sh_ptr = B_sh + (warp_n_idx * WMMA_N * K_BLK);
+
+ load_matrix_sync(a_frag, A_sh_ptr, K_BLK);
+ load_matrix_sync(b_frag, B_sh_ptr, K_BLK);
+
+ // Accumulate into c_frag (which persists across k_base iterations)
+ mma_sync(c_frag, a_frag, b_frag, c_frag);
+ } // end k_base loop (we have a fully-accumulated c_frag for this m_base tile)
+
+ // Store the accumulated c_frag to C_sh (shared) once per warp
+ // Point this warp to its 16x16 tile *within* the 32x32 C_sh
+ float* C_sh_ptr = C_sh + (warp_m_idx * WMMA_M * N_BLK) + (warp_n_idx * WMMA_N);
+ // store the full accumulated 16x16 tile (note ld = N_BLK, result in row-major in C_sh)
+ store_matrix_sync(C_sh_ptr, c_frag, N_BLK, mem_row_major);
+
+ __syncthreads();
+
+ // Cooperative Store from C_sh to Global
+ // 128 threads write [M_BLK, N_BLK] = [32, 32] = 1024 elements
+ const int C_ELEMS_PER_BLOCK = M_BLK * N_BLK;
+ for (int i = threadId; i < C_ELEMS_PER_BLOCK; i += BLOCK_THREADS) {
+ int m_local_c = i / N_BLK; // row in C_sh (0..31)
+ int n_local_c = i % N_BLK; // col in C_sh (0..31)
+
+ int m_seg = m_base + m_local_c; // row index within segment
+ int n_global = n_base + n_local_c; // col index in output
+
+ if (m_seg < num_rows_in_segment && n_global < size_n) {
+ int token_pair_index = segment_start + m_seg;
+ if (token_pair_index < size_m) {
+ int token_index = sorted_token_ids[token_pair_index];
+ float val = C_sh[m_local_c * N_BLK + n_local_c];
+ if (topk_weights) {
+ val *= topk_weights[token_index];
+ }
+ from_float(output[(size_t)token_index * size_n + n_global], val);
+ }
+ }
+ }
+ } // end m_base loop
+}
+
+}
+
+#define LAUNCH_MOE_WMMA(DTYPE, WMMA_M, WMMA_N, WARPS_N)\
+ vllm_rs::moe_gemm_grouped_kernel<DTYPE, WMMA_M, WMMA_N, WARPS_N><<<grid, block, smem_bytes, stream>>>(\
+ reinterpret_cast<const DTYPE*>(input),\
+ reinterpret_cast<const DTYPE*>(weights),\
+ sorted_token_ids,\
+ expert_offsets,\
+ topk_weights,\
+ reinterpret_cast<DTYPE*>(output),\
+ num_experts, topk,\
+ size_m, size_n, size_k \
+ );\
+
+extern "C" void moe_gemm_wmma(
+ const void* input, // [size_m, size_k]
+ const void* weights, // [num_experts, size_n, size_k]
+ const int32_t* sorted_token_ids, // [size_m] (Device)
+ const int32_t* expert_ids, // [size_m * topk]
+ const float* topk_weights, // [size_m] (Device, can be nullptr)
+ void* output, // [size_m, size_n]
+ int32_t* expert_counts, // prealloc [num_experts]
+ int32_t* expert_offsets, // prealloc [num_experts + 1]
+ int num_experts,
+ int topk,
+ int size_m,
+ int size_n,
+ int size_k,
+ int data_type, // 0 = half, 1 = bfloat16
+ bool is_prefill,
+ cudaStream_t stream
+) {
+ if (is_prefill) {
+ calculate_expert_offsets(expert_ids, size_m, expert_counts, expert_offsets, num_experts, stream);
+ } else {
+ calculate_expert_offsets_light(expert_ids, size_m, expert_counts, expert_offsets, num_experts, stream);
+ }
+
+ int grid_n = CEILDIV(size_n, vllm_rs::N_BLK);
+ dim3 grid(num_experts, grid_n, 1);
+ dim3 block(vllm_rs::BLOCK_THREADS, 1, 1);
+
+ // Shared memory: A_sh[M_BLK, K_BLK] + B_sh[N_BLK, K_BLK]
+ size_t A_sh_bytes = vllm_rs::M_BLK * vllm_rs::K_BLK * 2; // (32*16 * 2) = 1024
+ size_t B_sh_bytes = vllm_rs::N_BLK * vllm_rs::K_BLK * 2; // (32*16 * 2) = 1024
+ size_t C_sh_bytes = vllm_rs::M_BLK * vllm_rs::N_BLK * sizeof(float);
+ size_t AB_bytes = A_sh_bytes + B_sh_bytes;
+ size_t pad = (16 - (AB_bytes % 16)) % 16;
+ size_t smem_bytes = AB_bytes + pad + C_sh_bytes; // ~6KB total needed
+
+ if (data_type == 0) { // half
+ if (is_prefill) {
+ LAUNCH_MOE_WMMA(half, 16, 16, 2)
+ } else {
+ // we use smaller M_tile and larger N_tile for decoding
+ LAUNCH_MOE_WMMA(half, 8, 32, 1)
+ }
+ } else if (data_type == 1) { // bfloat16
+ if (is_prefill) {
+ LAUNCH_MOE_WMMA(nv_bfloat16, 16, 16, 2)
+ } else {
+ LAUNCH_MOE_WMMA(nv_bfloat16, 8, 32, 1)
+ }
+ }
+}
\ No newline at end of file
diff --git a/candle-kernels/src/moe/moe_wmma_gguf.cu b/candle-kernels/src/moe/moe_wmma_gguf.cu
new file mode 100644
index 0000000000..0d3701ee82
--- /dev/null
+++ b/candle-kernels/src/moe/moe_wmma_gguf.cu
@@ -0,0 +1,422 @@
+/**
+ * @brief CUDA kernel for Mixture-of-Experts (MoE) GEMM with GGUF quantized weights and Tensor Core.
+ *
+ * This kernel performs batched GEMM where the weight matrix is stored in GGUF
+ * quantized format (uint8_t blocks). It supports top-k expert selection and
+ * segmented expert layouts. Uses shared memory tiles and WMMA (tensor cores)
+ * for efficient computation.
+ *
+ * Adapted from: https://github.com/guoqingbao/attention.rs/tree/main/src/kernels/src/moe_wmma_gguf.cu
+ */
+#include "gguf.cuh"
+#include <mma.h>
+#include <cuda_runtime.h>
+#include <cuda_fp16.h>
+#include <cuda_bf16.h>
+#include <cstdint>
+#include <cstdio>
+#include "moe_utils.cuh"
+using namespace nvcuda::wmma;
+
+// Constants from original kernel
+constexpr int WMMA_M = 16;
+constexpr int WMMA_N = 16;
+constexpr int WMMA_K = 16; // This is fixed by the hardware instruction
+using VecT = float4;
+
+constexpr int VEC_SIZE = 8;
+constexpr int WARPS_M = 2;
+constexpr int WARPS_N = 2;
+constexpr int WARPS_PER_BLOCK = WARPS_M * WARPS_N; // 4 warps
+
+constexpr int M_BLK = WARPS_M * WMMA_M; // 32
+constexpr int N_BLK = WARPS_N * WMMA_N; // 32
+
+// Helper for ceiling division
+#define CEILDIV(A, B) (((A) + (B)-1) / (B))
+
+// --- GGUF Dequantization Function (Warp-level) ---
+/**
+ * @brief Dequantizes a single GGUF block using one warp (32 threads).
+ *
+ * @tparam T Output type (half or nv_bfloat16)
+ * @param dequant_out Pointer to output in shared mem [qk]
+ * @param quant_in Pointer to input GGUF block in shared mem
+ * @param type GGUF type
+ * @param qk Quantization group size (32 or 256)
+ * @param laneId threadIdx.x % 32
+ */
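+/*
+ * Example (Q8_0, illustrative): a block is 34 bytes -- a half scale d followed by 32
+ * int8 quants -- so lane i of the warp writes dequant_out[i] = d * qs[i] and one call
+ * expands exactly one block; the K-quant cases reuse the dequantize_block_* helpers
+ * from gguf.cuh, which operate on 256-wide super-blocks.
+ */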
+template <typename T>
+__forceinline__ __device__ void dequantize_block_warp(
+ T* dequant_out,
+ const uint8_t* quant_in,
+ int gguf_dtype //Q8_0: 0, Q4K: 1, Q2K: 2, Q3k: 3, Q5K: 4, Q6K: 5,
+) {
+ using namespace nvcuda;
+ switch (gguf_dtype) {
+ case 0: { // qk = 32, q8_0
+ // Block: half d (2B), int8_t qs[32] (32B)
+ int laneId = threadIdx.x;
+ const half* d_ptr = (const half*)quant_in;
+ const int8_t* qs = (const int8_t*)(quant_in + 2);
+
+ // Lane 0 loads scale and broadcasts to all other lanes
+ half d_val = (laneId == 0) ? *d_ptr : (half)0.0f;
+ d_val = __shfl_sync(0xFFFFFFFF, d_val, 0);
+ float d_f = __half2float(d_val);
+
+ // 32 lanes dequantize 32 values
+ if (laneId < QK8_0) { // qk should be 32
+ dequant_out[laneId] = T( (float)qs[laneId] * d_f );
+ }
+ break;
+ }
+ case 1: { // q4k, 32 lanes
+ dequantize_block_q4_K(quant_in, dequant_out);
+ break;
+ }
+ case 2: { // q2k, 64 lanes
+ dequantize_block_q2_K(quant_in, dequant_out);
+ break;
+ }
+ case 3: { // q3k, 64 lanes
+ dequantize_block_q3_K(quant_in, dequant_out);
+ break;
+ }
+ case 4: { // q5k, 64 lanes
+ dequantize_block_q5_K(quant_in, dequant_out);
+ break;
+ }
+ case 5: { // q6k, 64 lanes
+ dequantize_block_q6_K(quant_in, dequant_out);
+ break;
+ }
+ default:
+ break;
+ }
+}
+
+/*
+* Template Parameters:
+ * @tparam T Type of input/output (float, half, etc.)
+ * @tparam qk Quantization block size (e.g., 32)
+ * @tparam block_q_t Type representing a single GGUF block (e.g., block_q8_0)
+ * @tparam wrap_size Warp size used for thread tiling (usually 32)
+ *
+ * Kernel Parameters:
+ * @param input Input matrix [size_m, size_k]
+ * @param weights GGUF quantized weights buffer (uint8_t blocks)
+ * @param sorted_token_ids Array of sorted token indices for MoE routing
+ * @param expert_offsets [num_experts + 1] prefix-sum offsets; expert e owns sorted tokens [offsets[e], offsets[e+1])
+ * @param topk_weights Top-k MoE weights per token (optional)
+ * @param output Output matrix [size_m, size_n]
+ * @param num_experts Number of experts in the MoE
+ * @param topk Number of top experts selected per token
+ * @param size_m Number of input rows / tokens
+ * @param size_n Output feature dimension
+ * @param size_k Input feature dimension
+ * @param gguf_dtype GGUF quantization type ID (e.g., Q8_0)
+*/
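+/*
+ * Shared-memory sketch (hypothetical Q4_K config, half activations): qk = 256 and a
+ * 144-byte block_q4_K give A_sh = 32*256*2 = 16 KiB, B_sh = 16 KiB, B_quant_sh =
+ * 32*144 = 4.5 KiB and C_sh = 32*32*4 = 4 KiB, roughly 40.5 KiB of dynamic shared
+ * memory per block (see the matching computation in moe_gemm_gguf_prefill below).
+*/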
+template <typename T, int qk, typename block_q_t, int wrap_size>
+__global__ void moe_gemm_gguf_prefill_kernel(
+ const T* __restrict__ input,
+ const uint8_t* __restrict__ weights, // Now uint8_t*
+ const int32_t* __restrict__ sorted_token_ids,
+ const int32_t* __restrict__ expert_offsets,
+ const float* __restrict__ topk_weights,
+ float* __restrict__ output,
+ const int num_experts, const int topk,
+ const int32_t size_m,
+ const int32_t size_n,
+ const int32_t size_k,
+ const int gguf_dtype
+) {
+ const int expert_id = blockIdx.x;
+ const int n_tile_idx = blockIdx.y;
+
+ if (expert_id < 0 || expert_id >= num_experts) return;
+ const int segment_start = expert_offsets[expert_id];
+ const int segment_end = expert_offsets[expert_id + 1];
+ const int num_rows_in_segment = segment_end - segment_start;
+
+ if (num_rows_in_segment == 0) return;
+ constexpr int BLOCK_THREADS = WARPS_PER_BLOCK * wrap_size; // 128 threads
+
+ const int n_base = n_tile_idx * N_BLK;
+ if (n_base >= size_n) return;
+
+ const size_t block_size_bytes = sizeof(block_q_t);
+ const size_t expert_w_row_stride_bytes = (size_k / qk) * block_size_bytes;
+ const uint8_t* expert_w = weights + (size_t)expert_id * size_n * expert_w_row_stride_bytes;
+
+ extern __shared__ uint8_t smem_bytes[];
+
+ // 1. A tile: [M_BLK, qk] (dequantized)
+ T* A_sh = reinterpret_cast<T*>(smem_bytes);
+ size_t A_sh_bytes = (size_t)M_BLK * qk * sizeof(T);
+
+ // 2. B tile: [N_BLK, qk] (dequantized)
+ uint8_t* B_sh_ptr = smem_bytes + A_sh_bytes;
+ size_t B_sh_bytes = (size_t)N_BLK * qk * sizeof(T);
+
+ // 3. B quantized tile: [N_BLK * block_size_bytes] (raw GGUF)
+ uint8_t* B_quant_sh_ptr = B_sh_ptr + B_sh_bytes;
+ size_t B_quant_sh_bytes = (size_t)N_BLK * block_size_bytes;
+
+ // 4. C tile: [M_BLK, N_BLK] (float accumulator)
+ uint8_t* C_sh_ptr = B_quant_sh_ptr + B_quant_sh_bytes;
+ size_t C_sh_offset = reinterpret_cast<size_t>(C_sh_ptr) % alignof(float);
+ if (C_sh_offset != 0) C_sh_ptr += (alignof(float) - C_sh_offset);
+
+ // Final aligned shared memory pointers
+ T* B_sh = reinterpret_cast<T*>(B_sh_ptr);
+ uint8_t* B_quant_sh = reinterpret_cast<uint8_t*>(B_quant_sh_ptr);
+ float* C_sh = reinterpret_cast<float*>(C_sh_ptr);
+
+ const int laneId = threadIdx.x;
+ const int warpId = threadIdx.y;
+ const int threadId = warpId * wrap_size + laneId;
+ const int warp_m_idx = warpId / WARPS_N;
+ const int warp_n_idx = warpId % WARPS_N;
+
+ const size_t A_ELEMS_PER_BLOCK = (size_t)M_BLK * qk;
+ const size_t VEC_ELEMS_A = A_ELEMS_PER_BLOCK / VEC_SIZE;
+ VecT zero_vec;
+ zero_vec.x = zero_vec.y = zero_vec.z = zero_vec.w = 0.0f;
+
+ for (int m_base = 0; m_base < num_rows_in_segment; m_base += M_BLK) {
+
+ // Per-warp accumulator fragment
+ fragment<accumulator, WMMA_M, WMMA_N, WMMA_K, float> c_frag;
+ fill_fragment(c_frag, 0.0f);
+
+ // K-Loop: Strides by GGUF block size `qk`
+ for (int k_base = 0; k_base < size_k; k_base += qk) {
+
+ // Load A Tile (Inputs) into A_sh
+ #pragma unroll
+ for (size_t i = threadId; i < VEC_ELEMS_A; i += BLOCK_THREADS) {
+ size_t idx = i * VEC_SIZE; // element index
+ size_t m_local = idx / qk;
+ size_t k_local = idx % qk;
+
+ int m_seg = m_base + m_local;
+ int k_global = k_base + k_local;
+
+ if (m_seg < num_rows_in_segment && k_global < size_k) {
+ int token_pair_index = segment_start + m_seg;
+ int token_index = sorted_token_ids[token_pair_index];
+ int input_index = token_index / (topk_weights? 1: topk);
+ *reinterpret_cast<VecT*>(&A_sh[m_local * qk + k_local]) = *reinterpret_cast<const VecT*>(
+ &input[(size_t)input_index * size_k + k_global]
+ );
+ } else {
+ *reinterpret_cast<VecT*>(&A_sh[m_local * qk + k_local]) = zero_vec;
+ }
+ }
+
+ // Load B Tile (Quantized) into B_quant_sh
+ const size_t k_base_offset_bytes = (k_base / qk) * block_size_bytes;
+ constexpr int ROWS_PER_WARP = N_BLK / WARPS_PER_BLOCK;
+
+ #pragma unroll
+ for (int row = 0; row < ROWS_PER_WARP; ++row) {
+ int n_local = warpId * ROWS_PER_WARP + row;
+ int n_global = n_base + n_local;
+ if (n_local < N_BLK && n_global < size_n) {
+ block_q_t* dest_ptr = reinterpret_cast<block_q_t*>(B_quant_sh + n_local * block_size_bytes);
+ const block_q_t* src_ptr = reinterpret_cast<const block_q_t*>(expert_w + (size_t)n_global * expert_w_row_stride_bytes + k_base_offset_bytes);
+ *dest_ptr = *src_ptr;
+ }
+ }
+
+ __syncthreads();
+
+ // Dequantize B from B_quant_sh to B_sh
+ #pragma unroll
+ for (int row = 0; row < ROWS_PER_WARP; ++row) {
+ int n_local = warpId * ROWS_PER_WARP + row;
+ int n_global = n_base + n_local;
+ if (n_local < N_BLK && n_global < size_n) {
+ const uint8_t* quant_ptr = B_quant_sh + n_local * block_size_bytes;
+ T* dequant_ptr = B_sh + n_local * qk; // Stride by qk
+ // Dequantize one block using this warp
+ dequantize_block_warp(dequant_ptr, quant_ptr, gguf_dtype);
+ }
+ }
+
+ __syncthreads();
+
+ // Inner WMMA Loop
+ // A_sh and B_sh are now dequantized and in shared mem
+ // We loop over the K-dim (now `qk`) using the hardware `WMMA_K`
+ #pragma unroll
+ for (int k_tile = 0; k_tile < qk; k_tile += WMMA_K) {
+ fragment<matrix_a, WMMA_M, WMMA_N, WMMA_K, T, row_major> a_frag;
+ fragment<matrix_b, WMMA_M, WMMA_N, WMMA_K, T, col_major> b_frag;
+
+ // Point to the correct 16x16 tile inside the [M_BLK, qk] / [N_BLK, qk] buffers
+ const T* A_sh_ptr = A_sh + (warp_m_idx * WMMA_M * qk) + k_tile;
+ const T* B_sh_ptr = B_sh + (warp_n_idx * WMMA_N * qk) + k_tile;
+
+ load_matrix_sync(a_frag, A_sh_ptr, qk); // Stride is qk
+ load_matrix_sync(b_frag, B_sh_ptr, qk); // Stride is qk
+
+ mma_sync(c_frag, a_frag, b_frag, c_frag);
+ }
+ } // end k_base loop
+
+ // Store C_frag to C_sh
+ float* C_sh_ptr_warp = C_sh + (warp_m_idx * WMMA_M * N_BLK) + (warp_n_idx * WMMA_N);
+ store_matrix_sync(C_sh_ptr_warp, c_frag, N_BLK, mem_row_major);
+ __syncthreads();
+
+ // Cooperative Store to Global
+ const int C_ELEMS_PER_BLOCK = M_BLK * N_BLK;
+ #pragma unroll
+ for (int i = threadId; i < C_ELEMS_PER_BLOCK; i += BLOCK_THREADS) {
+ int m_local_c = i / N_BLK;
+ int n_local_c = i % N_BLK;
+ int m_seg = m_base + m_local_c;
+ int n_global = n_base + n_local_c;
+
+ if (m_seg < num_rows_in_segment && n_global < size_n) {
+ int token_pair_index = segment_start + m_seg;
+ if (token_pair_index < size_m) {
+ int token_index = sorted_token_ids[token_pair_index];
+ float val = C_sh[m_local_c * N_BLK + n_local_c];
+ if (topk_weights) {
+ val *= topk_weights[token_index];
+ }
+ output[(size_t)token_index * size_n + n_global] = val;
+ }
+ }
+ }
+ } // end m_base loop
+}
+
+#define LAUNCH_MOE_GGUF_PREFILL(DTYPE) \
+ if (gguf_type == 0) {\
+ dim3 block(32, WARPS_PER_BLOCK, 1);\
+ moe_gemm_gguf_prefill_kernel<DTYPE, QK8_0, block_q8_0, 32><<<grid, block, smem_bytes, stream>>>(\
+ reinterpret_cast<const DTYPE*>(input),\
+ reinterpret_cast<const uint8_t*>(weights),\
+ sorted_token_ids, expert_offsets, topk_weights,\
+ output, num_experts, topk, size_m, size_n, size_k, gguf_type\
+ );\
+ } else if (gguf_type == 1) {\
+ dim3 block(32, WARPS_PER_BLOCK, 1);\
+ moe_gemm_gguf_prefill_kernel<DTYPE, QK_K, block_q4_K, 32><<<grid, block, smem_bytes, stream>>>(\
+ reinterpret_cast<const DTYPE*>(input),\
+ reinterpret_cast<const uint8_t*>(weights),\
+ sorted_token_ids, expert_offsets, topk_weights,\
+ output, num_experts, topk, size_m, size_n, size_k, gguf_type\
+ );\
+ } else if (gguf_type == 2) {\
+ dim3 block(64, WARPS_PER_BLOCK, 1);\
+ moe_gemm_gguf_prefill_kernel<DTYPE, QK_K, block_q2_K, 64><<<grid, block, smem_bytes, stream>>>(\
+ reinterpret_cast<const DTYPE*>(input),\
+ reinterpret_cast<const uint8_t*>(weights),\
+ sorted_token_ids, expert_offsets, topk_weights,\
+ output, num_experts, topk, size_m, size_n, size_k, gguf_type\
+ );\
+ } else if (gguf_type == 3) {\
+ dim3 block(64, WARPS_PER_BLOCK, 1);\
+ moe_gemm_gguf_prefill_kernel<DTYPE, QK_K, block_q3_K, 64><<<grid, block, smem_bytes, stream>>>(\
+ reinterpret_cast<const DTYPE*>(input),\
+ reinterpret_cast<const uint8_t*>(weights),\
+ sorted_token_ids, expert_offsets, topk_weights,\
+ output, num_experts, topk, size_m, size_n, size_k, gguf_type\
+ );\
+ } else if (gguf_type == 4) { \
+ dim3 block(64, WARPS_PER_BLOCK, 1);\
+ moe_gemm_gguf_prefill_kernel<DTYPE, QK_K, block_q5_K, 64><<<grid, block, smem_bytes, stream>>>(\
+ reinterpret_cast<const DTYPE*>(input),\
+ reinterpret_cast<const uint8_t*>(weights),\
+ sorted_token_ids, expert_offsets, topk_weights,\
+ output, num_experts, topk, size_m, size_n, size_k, gguf_type\
+ );\
+ } else if (gguf_type == 5) { \
+ dim3 block(64, WARPS_PER_BLOCK, 1);\
+ moe_gemm_gguf_prefill_kernel<DTYPE, QK_K, block_q6_K, 64><<<grid, block, smem_bytes, stream>>>(\
+ reinterpret_cast<const DTYPE*>(input),\
+ reinterpret_cast<const uint8_t*>(weights),\
+ sorted_token_ids, expert_offsets, topk_weights,\
+ output, num_experts, topk, size_m, size_n, size_k, gguf_type\
+ );\
+ }
+
+
+extern "C" void moe_gemm_gguf_prefill(
+ const void* input,
+ const uint8_t* weights,
+ const int32_t* sorted_token_ids,
+ const int32_t* expert_ids,
+ const float* topk_weights,
+ float* output,
+ int num_experts,
+ int topk,
+ int size_m,
+ int size_n,
+ int size_k,
+ int input_dtype, // 0 = half, 1 = bfloat16
+ int gguf_type, //Q8_0: 0, Q4K: 1, Q2K: 2, Q3k: 3, Q5K: 4, Q6K: 5,
+ cudaStream_t stream
+) {
+ int32_t* expert_counts;
+ cudaMallocAsync(&expert_counts, num_experts * sizeof(int32_t), stream);
+
+ int32_t* expert_offsets;
+ cudaMallocAsync(&expert_offsets, (num_experts + 1) * sizeof(int32_t), stream);
+ calculate_expert_offsets(expert_ids, size_m, expert_counts, expert_offsets, num_experts, stream);
+
+ int grid_n = CEILDIV(size_n, N_BLK);
+ dim3 grid(num_experts, grid_n, 1);
+
+ size_t qk = QK_K;
+ size_t block_size_bytes = sizeof(block_q6_K);
+ if (gguf_type == 0) { //Q8_0: 0,
+ block_size_bytes = sizeof(block_q8_0);
+ qk = QK8_0;
+ } else if (gguf_type == 1) {// Q4K: 1,
+ block_size_bytes = sizeof(block_q4_K);
+ } else if (gguf_type == 2) {// Q2K: 2,
+ block_size_bytes = sizeof(block_q2_K);
+ } else if (gguf_type == 3) {//Q3K: 3,
+ block_size_bytes = sizeof(block_q3_K);
+ } else if (gguf_type == 4) {//Q5K: 4,
+ block_size_bytes = sizeof(block_q5_K);
+ }
+
+ // 1. A tile: [M_BLK, qk] (dequantized)
+ size_t A_sh_bytes = (size_t)M_BLK * qk * 2; // 2 for half/bfloat16
+
+ // 2. B tile: [N_BLK, qk] (dequantized)
+ size_t B_sh_bytes = (size_t)N_BLK * qk * 2;
+
+ // 3. B quantized tile: [N_BLK * block_size_bytes]
+ size_t B_quant_sh_bytes = (size_t)N_BLK * block_size_bytes;
+
+ // 4. C tile: [M_BLK, N_BLK] (float accumulator)
+ size_t C_sh_bytes = (size_t)M_BLK * N_BLK * sizeof(float);
+
+ // Add up, with padding for C
+ size_t smem_bytes = A_sh_bytes + B_sh_bytes + B_quant_sh_bytes;
+ size_t C_sh_offset = smem_bytes % alignof(float);
+ if (C_sh_offset != 0) smem_bytes += (alignof(float) - C_sh_offset);
+ smem_bytes += C_sh_bytes;
+
+ if (input_dtype == 0) {
+ LAUNCH_MOE_GGUF_PREFILL(half);
+ } else {
+#ifndef NO_BF16_KERNEL
+ LAUNCH_MOE_GGUF_PREFILL(nv_bfloat16);
+#endif
+ }
+ cudaFreeAsync(expert_counts, stream);
+ cudaFreeAsync(expert_offsets, stream);
+}
diff --git a/candle-nn/benches/benchmarks/mod.rs b/candle-nn/benches/benchmarks/mod.rs
index 2bff47cb12..fcc83b8e3f 100644
--- a/candle-nn/benches/benchmarks/mod.rs
+++ b/candle-nn/benches/benchmarks/mod.rs
@@ -21,13 +21,13 @@ impl BenchDevice for Device {
return Ok(device.synchronize()?);
}
#[cfg(not(feature = "cuda"))]
- panic!("Cuda device without cuda feature enabled: {:?}", device)
+ panic!("Cuda device without cuda feature enabled: {device:?}")
}
Device::Metal(device) => {
#[cfg(feature = "metal")]
return device.wait_until_completed();
#[cfg(not(feature = "metal"))]
- panic!("Metal device without metal feature enabled: {:?}", device)
+ panic!("Metal device without metal feature enabled: {device:?}")
}
}
}
diff --git a/candle-nn/src/lib.rs b/candle-nn/src/lib.rs
index c7a76fbd7a..febd73a2d6 100644
--- a/candle-nn/src/lib.rs
+++ b/candle-nn/src/lib.rs
@@ -28,6 +28,7 @@ pub mod kv_cache;
pub mod layer_norm;
pub mod linear;
pub mod loss;
+pub mod moe;
pub mod ops;
pub mod optim;
pub mod rnn;
diff --git a/candle-nn/src/moe.rs b/candle-nn/src/moe.rs
new file mode 100644
index 0000000000..2f2bd9c2db
--- /dev/null
+++ b/candle-nn/src/moe.rs
@@ -0,0 +1,350 @@
+// Adapted from https://github.com/guoqingbao/attention.rs/blob/main/src/moe.rs
+#[cfg(feature = "cuda")]
+use candle::cuda_backend::kernels::ffi;
+#[allow(unused_imports)]
+use candle::quantized::{self, QTensor};
+use candle::{Result, Tensor};
+
+#[cfg(feature = "cuda")]
+pub fn moe_gemm(
+ input: &Tensor,
+ weights: &Tensor,
+ topk_weights: &Option<Tensor>,
+ sorted_token_ids: &Tensor,
+ experts_ids: &Tensor,
+ topk: usize,
+ is_prefill: bool,
+) -> Result<Tensor> {
+ use candle::cuda_backend::cudarc::driver::DevicePtr;
+ use candle::DType;
+ use half::{bf16, f16};
+
+ fn cuda_fwd<
+ T: candle::cuda_backend::CudaDType + candle::cuda_backend::cudarc::driver::DeviceRepr,
+ >(
+ input: &Tensor,
+ weights: &Tensor,
+ topk_weights: &Option<Tensor>,
+ sorted_token_ids: &Tensor,
+ experts_ids: &Tensor,
+ topk: usize,
+ is_prefill: bool,
+ ) -> Result<Tensor> {
+ let (mut size_m, size_k1) = input.dims2()?;
+ if topk_weights.is_none() {
+ size_m *= topk;
+ }
+ let (num_experts, size_n, size_k) = weights.dims3()?;
+ assert!(
+ size_k == size_k1,
+ "input {:?} and weight {:?} last dim mismatch!",
+ size_k1,
+ size_k
+ );
+ let dev = input.device().as_cuda_device()?;
+ let data_type = match input.dtype() {
+ DType::F16 => 0,
+ DType::BF16 => 1,
+ _ => {
+ candle::bail!("moe_gemm_wmma only accepts f16/bf16 inputs")
+ }
+ };
+
+ let (input, _) = input.storage_and_layout();
+ let input = match &*input {
+ candle::Storage::Cuda(c) => c.as_cuda_slice::<T>()?,
+ _ => candle::bail!("input must be a cuda tensor"),
+ };
+
+ let (weights, _) = weights.storage_and_layout();
+ let weights = match &*weights {
+ candle::Storage::Cuda(c) => c.as_cuda_slice::<T>()?,
+ _ => candle::bail!("weight must be a cuda tensor"),
+ };
+
+ let (sorted_token_ids, _) = sorted_token_ids.storage_and_layout();
+ let sorted_token_ids = match &*sorted_token_ids {
+ candle::Storage::Cuda(c) => c.as_cuda_slice::<u32>()?,
+ _ => candle::bail!("sorted_token_ids must be a cuda tensor"),
+ };
+
+ let (experts_ids, _) = experts_ids.storage_and_layout();
+ let experts_ids = match &*experts_ids {
+ candle::Storage::Cuda(c) => c.as_cuda_slice::<u32>()?,
+ _ => candle::bail!("experts_ids must be a cuda tensor"),
+ };
+
+ let topk_weights_ptr = if let Some(topk_weights) = &topk_weights {
+ let (topk_weights, _) = topk_weights.storage_and_layout();
+ let topk_weights = match &*topk_weights {
+ candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
+ _ => candle::bail!("topk_weights must be a cuda tensor"),
+ };
+ let weights_ptr = topk_weights.device_ptr(topk_weights.stream()).0 as *const f32;
+ weights_ptr
+ } else {
+ std::ptr::null() as *const f32
+ };
+
+ let output = unsafe { dev.alloc::<T>(size_m * size_n) }?;
+ let expert_counts = unsafe { dev.alloc::<u32>(num_experts) }?;
+ let expert_offsets = unsafe { dev.alloc::<u32>(num_experts + 1) }?;
+
+ let stream = dev.cuda_stream().cu_stream() as i64;
+ use core::ffi::c_void;
+
+ unsafe {
+ ffi::moe_gemm_wmma(
+ input.device_ptr(input.stream()).0 as *const c_void, // [size_m, size_k]
+ weights.device_ptr(weights.stream()).0 as *const c_void, // [num_experts, size_n, size_k]
+ sorted_token_ids.device_ptr(sorted_token_ids.stream()).0 as *const i32,
+ experts_ids.device_ptr(experts_ids.stream()).0 as *const i32,
+ topk_weights_ptr,
+ output.device_ptr(output.stream()).0 as *mut c_void, // [size_m, size_n]
+ expert_counts.device_ptr(expert_counts.stream()).0 as *mut i32, // pre-allocated buffer [num_experts]
+ expert_offsets.device_ptr(expert_offsets.stream()).0 as *mut i32, // pre-allocated buffer [num_experts + 1]
+ num_experts as i32,
+ topk as i32,
+ size_m as i32,
+ size_n as i32,
+ size_k as i32,
+ data_type as i32, // 0=float16, 1=bf16 (for input/output)
+ is_prefill,
+ stream as i64,
+ );
+ }
+
+ use candle::op::BackpropOp;
+ let output = candle::CudaStorage::wrap_cuda_slice(output, dev.clone());
+ let output = Tensor::from_storage(
+ candle::Storage::Cuda(output),
+ (size_m, size_n),
+ BackpropOp::none(),
+ false,
+ );
+
+ Ok(output)
+ }
+
+ match input.dtype() {
+ DType::F16 => cuda_fwd::<f16>(
+ input,
+ weights,
+ topk_weights,
+ sorted_token_ids,
+ experts_ids,
+ topk,
+ is_prefill,
+ ),
+ DType::BF16 => cuda_fwd::<bf16>(
+ input,
+ weights,
+ topk_weights,
+ sorted_token_ids,
+ experts_ids,
+ topk,
+ is_prefill,
+ ),
+ _ => {
+ candle::bail!("moe_gemm only accepts f16/bf16 inputs")
+ }
+ }
+}
+
+#[cfg(not(feature = "cuda"))]
+pub fn moe_gemm(
+ _: &Tensor,
+ _: &Tensor,
+ _: &Option<Tensor>,
+ _: &Tensor,
+ _: &Tensor,
+ _: usize,
+ _: bool,
+) -> Result<Tensor> {
+ candle::bail!("moe_gemm is only implemented for the cuda backend")
+}
+
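+/// Grouped expert GEMM over GGUF-quantized weights (CUDA only).
+///
+/// `input` must be f32 with shape `[size_m, size_k]`; `weights` holds the quantized
+/// experts as `[num_experts, size_n, size_k]` in q2k/q3k/q4k/q5k/q6k or q8_0.
+/// During prefill the activations are cast to `dtype` (f16/bf16) before the kernel
+/// runs; during decode the f32 activations are used directly.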
+#[cfg(feature = "cuda")]
+pub fn moe_gemm_gguf(
+ input: &Tensor,
+ weights: &QTensor,
+ topk_weights: &Option<Tensor>,
+ sorted_token_ids: &Tensor,
+ experts_ids: &Tensor,
+ topk: usize,
+ is_prefill: bool,
+ dtype: candle::DType,
+) -> Result<Tensor> {
+ use candle::cuda_backend::cudarc::driver::DevicePtr;
+ use candle::quantized::GgmlDType;
+ use candle::DType;
+ use half::{bf16, f16};
+
+ fn cuda_fwd(
+ input: &Tensor,
+ weights: &QTensor,
+ topk_weights: &Option<Tensor>,
+ sorted_token_ids: &Tensor,
+ experts_ids: &Tensor,
+ topk: usize,
+ is_prefill: bool,
+ dtype: DType,
+ ) -> Result<Tensor> {
+ let (mut size_m, size_k) = input.dims2()?;
+ if topk_weights.is_none() {
+ size_m *= topk;
+ }
+ let (num_experts, size_n, size_k1) = weights.shape().dims3()?;
+ assert!(
+ size_k == size_k1,
+ "input {:?} and weight {:?} last dim mismatch!",
+ size_k,
+ size_k1,
+ );
+ let dev = input.device().as_cuda_device()?;
+
+ // Q8_0: 0, Q4K: 1, Q2K: 2, Q3k: 3, Q5K: 4, Q6K: 5
+ let gguf_dtype = match weights.dtype() {
+ GgmlDType::Q8_0 => 0,
+ GgmlDType::Q4K => 1,
+ GgmlDType::Q2K => 2,
+ GgmlDType::Q3K => 3,
+ GgmlDType::Q5K => 4,
+ GgmlDType::Q6K => 5,
+ _ => {
+ candle::bail!(
+ "moe_gemm_gguf `ISQ` only accepts q2k, q3k, q4k, q5k, q6k or q8_0 weights!"
+ )
+ }
+ };
+
+ let weight_ptr = weights.device_ptr()?;
+
+ let topk_weights_ptr = if let Some(topk_weights) = &topk_weights {
+ let (topk_weights, _) = topk_weights.storage_and_layout();
+ let topk_weights = match &*topk_weights {
+ candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
+ _ => candle::bail!("topk_weights must be a cuda tensor"),
+ };
+ let w_ptr = topk_weights.device_ptr(topk_weights.stream()).0 as *const f32;
+ w_ptr
+ } else {
+ std::ptr::null() as *const f32
+ };
+
+ let (sorted_token_ids, _) = sorted_token_ids.storage_and_layout();
+ let sorted_token_ids = match &*sorted_token_ids {
+ candle::Storage::Cuda(c) => c.as_cuda_slice::<u32>()?,
+ _ => candle::bail!("sorted_token_ids must be a cuda tensor"),
+ };
+ let (experts_ids, _) = experts_ids.storage_and_layout();
+ let experts_ids = match &*experts_ids {
+ candle::Storage::Cuda(c) => c.as_cuda_slice::<u32>()?,
+ _ => candle::bail!("experts_ids must be a cuda tensor"),
+ };
+
+ let output = unsafe { dev.alloc::<f32>(size_m * size_n) }?;
+ let stream = dev.cuda_stream().cu_stream() as i64;
+ use candle::op::BackpropOp;
+ use core::ffi::c_void;
+
+ assert!(size_k % 8 == 0, "size_k must be divisible by 8");
+ unsafe {
+ if is_prefill {
+ let input = input.to_dtype(dtype)?;
+ let (input, _) = input.storage_and_layout();
+ let (input_ptr, input_dtype) = match &*input {
+ candle::Storage::Cuda(c) => {
+ if dtype == DType::F16 {
+ let c = c.as_cuda_slice::<f16>()?;
+ (c.device_ptr(c.stream()).0 as *const c_void, 0)
+ } else {
+ let c = c.as_cuda_slice::<bf16>()?;
+ (c.device_ptr(c.stream()).0 as *const c_void, 1)
+ }
+ }
+ _ => candle::bail!("input must be a cuda tensor"),
+ };
+ ffi::moe_gemm_gguf_prefill(
+ input_ptr, // [size_m or size_m/topk, size_k]
+ weight_ptr as *const u8, // [num_experts, size_n, size_k]
+ sorted_token_ids.device_ptr(sorted_token_ids.stream()).0 as *const i32,
+ experts_ids.device_ptr(experts_ids.stream()).0 as *const i32,
+ topk_weights_ptr,
+ output.device_ptr(output.stream()).0 as *mut c_void, // [size_m, size_n]
+ num_experts as i32,
+ topk as i32,
+ size_m as i32,
+ size_n as i32,
+ size_k as i32,
+ input_dtype as i32,
+ gguf_dtype as i32, // Q8_0: 0, Q4K: 1, Q2K: 2, Q3k: 3, Q5K: 4, Q6K: 5 (for weight)
+ stream as i64,
+ );
+ } else {
+ let (input, _) = input.storage_and_layout();
+ let input = match &*input {
+ candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
+ _ => candle::bail!("input must be a cuda tensor"),
+ };
+
+ ffi::moe_gemm_gguf(
+ input.device_ptr(input.stream()).0 as *const f32, // [size_m or size_m/topk, size_k]
+ weight_ptr as *const c_void, // [num_experts, size_n, size_k]
+ sorted_token_ids.device_ptr(sorted_token_ids.stream()).0 as *const i32,
+ experts_ids.device_ptr(experts_ids.stream()).0 as *const i32,
+ topk_weights_ptr,
+ output.device_ptr(output.stream()).0 as *mut c_void, // [size_m, size_n]
+ num_experts as i32,
+ topk as i32,
+ size_m as i32,
+ size_n as i32,
+ size_k as i32,
+ gguf_dtype as i32, // Q8_0: 0, Q4K: 1, Q2K: 2, Q3k: 3, Q5K: 4, Q6K: 5 (for weight)
+ stream as i64,
+ );
+ }
+ }
+
+ let output = candle::CudaStorage::wrap_cuda_slice(output, dev.clone());
+ let output = Tensor::from_storage(
+ candle::Storage::Cuda(output),
+ (size_m, size_n),
+ BackpropOp::none(),
+ false,
+ );
+
+ Ok(output)
+ }
+
+ match input.dtype() {
+ DType::F32 => cuda_fwd(
+ input,
+ weights,
+ topk_weights,
+ sorted_token_ids,
+ experts_ids,
+ topk,
+ is_prefill,
+ dtype,
+ ),
+ _ => {
+ candle::bail!("moe_gemm_gguf only accepts f32 inputs")
+ }
+ }
+}
+
+#[cfg(not(feature = "cuda"))]
+#[allow(clippy::too_many_arguments)]
+pub fn moe_gemm_gguf(
+ _: &Tensor,
+ _: &QTensor,
+ _: &Option<Tensor>,
+ _: &Tensor,
+ _: &Tensor,
+ _: usize,
+ _: bool,
+ _: candle::DType,
+) -> Result<Tensor> {
+ candle::bail!("moe_gemm_gguf is only implemented for the cuda backend")
+}
diff --git a/candle-transformers/src/fused_moe.rs b/candle-transformers/src/fused_moe.rs
new file mode 100644
index 0000000000..da2c6cf912
--- /dev/null
+++ b/candle-transformers/src/fused_moe.rs
@@ -0,0 +1,317 @@
+// Adapted from: https://github.com/guoqingbao/vllm.rs/blob/main/src/models/layers/moe.rs
+use candle::Module;
+use candle::{quantized::QTensor, DType, Result, Tensor, D};
+use candle_nn::{linear_no_bias, moe, Activation, Linear, VarBuilder};
+use std::sync::Arc;
+
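+/// Configuration shared by the fused MoE layers below; the fields mirror the
+/// corresponding entries of the Qwen3 MoE model configuration.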
+pub struct MoeCfg {
+ pub hidden_size: usize,
+ pub num_experts: usize,
+ pub num_experts_per_tok: usize,
+ pub moe_intermediate_size: usize,
+ pub norm_topk_prob: bool,
+ pub act: Activation,
+ pub decoder_sparse_step: Option<usize>,
+}
+
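+/// Fused MoE block for unquantized checkpoints. The per-expert `gate_proj` and
+/// `up_proj` weights are packed into a single
+/// `[num_experts, 2 * moe_intermediate_size, hidden_size]` tensor so the experts
+/// can be evaluated with grouped GEMMs instead of a per-expert loop.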
+#[allow(dead_code)]
+#[derive(Debug, Clone)]
+pub struct FusedMoe {
+ gate: Linear,
+ gate_up_w: Tensor,
+ down_w: Tensor,
+ w_size_n: usize,
+ act: Activation,
+ norm_topk_prob: bool,
+ num_experts_per_tok: usize,
+ // world_size: usize,
+ dtype: DType,
+}
+
+impl FusedMoe {
+ pub fn new(cfg: &MoeCfg, vb: VarBuilder, dtype: DType) -> Result<Self> {
+ let num_experts = cfg.num_experts;
+
+ let gate = linear_no_bias(cfg.hidden_size, num_experts, vb.pp("gate"))?;
+
+ let experts_vb = vb.pp("experts");
+ let mut gate_up_experts = Vec::with_capacity(num_experts);
+ let mut down_experts = Vec::with_capacity(num_experts);
+
+ //pack experts
+ for i in 0..num_experts {
+ let experts_vb = experts_vb.pp(format!("{i}").as_str());
+
+ let (gate_up_expert, down_expert) = {
+ // n x k format
+ let init_ws = candle_nn::init::DEFAULT_KAIMING_NORMAL;
+ let gate_expert = experts_vb.pp("gate_proj").get_with_hints(
+ (cfg.moe_intermediate_size, cfg.hidden_size),
+ "weight",
+ init_ws,
+ )?;
+ let up_expert = experts_vb.pp("up_proj").get_with_hints(
+ (cfg.moe_intermediate_size, cfg.hidden_size),
+ "weight",
+ init_ws,
+ )?;
+ let down_expert = experts_vb.pp("down_proj").get_with_hints(
+ (cfg.hidden_size, cfg.moe_intermediate_size),
+ "weight",
+ init_ws,
+ )?;
+ //pack gate_proj and up_proj
+ let gate_up_expert = Tensor::cat(&[&gate_expert, &up_expert], 0)?;
+
+ (gate_up_expert, down_expert)
+ };
+
+ gate_up_experts.push(gate_up_expert);
+ down_experts.push(down_expert);
+ }
+
+ let gate_up_w = Tensor::stack(&gate_up_experts, 0)?;
+ let down_w = Tensor::stack(&down_experts, 0)?;
+ // let world_size = comm.world_size();
+ let w_size_n = gate_up_w.dim(1)? / 2;
+
+ Ok(Self {
+ gate,
+ gate_up_w,
+ down_w,
+ w_size_n,
+ act: cfg.act,
+ norm_topk_prob: cfg.norm_topk_prob,
+ num_experts_per_tok: cfg.num_experts_per_tok,
+ // world_size,
+ dtype,
+ })
+ }
+
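+ /// Routes each token to its top-k experts and evaluates them with two grouped
+ /// GEMM calls: one for the packed gate/up projection and one for the down
+ /// projection (which applies the routing weights), then sums the per-expert
+ /// outputs back into a `(batch, seq_len, hidden_size)` tensor.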
+ pub fn forward(&self, xs: &Tensor, is_prefill: bool) -> Result<Tensor> {
+ let (batch, seq_len, hidden_dim) = xs.dims3()?;
+ let xs = xs.reshape(((), hidden_dim))?;
+ let (num_tokens, hidden_dim) = xs.dims2()?;
+
+ let router_logits = self.gate.forward(&xs)?;
+
+ let routing_weights =
+ candle_nn::ops::softmax_last_dim(&router_logits.to_dtype(DType::F32)?)?;
+
+ let topk_ids = routing_weights
+ .arg_sort_last_dim(false)?
+ .narrow(D::Minus1, 0, self.num_experts_per_tok)?
+ .contiguous()?;
+
+ let mut topk_weights = routing_weights.gather(&topk_ids, D::Minus1)?;
+
+ if self.norm_topk_prob {
+ topk_weights = topk_weights.broadcast_div(&topk_weights.sum_keepdim(D::Minus1)?)?;
+ }
+
+ let (expert_ids, sorted_token_ids) = if is_prefill {
+ // For long-context (32K+), need to use custom sort kernel
+ // #[cfg(feature = "cuda")]
+ // {
+ // use attention_rs::sort::ArgSortOp;
+ // topk_ids.flatten_all()?.sort(true)?
+ // }
+ // #[cfg(not(feature = "cuda"))]
+ topk_ids.flatten_all()?.sort_last_dim(true)?
+ } else {
+ topk_ids.flatten_all()?.sort_last_dim(true)?
+ };
+
+ //out (M, top_k, N)
+ let gate_up = moe::moe_gemm(
+ &xs,
+ &self.gate_up_w,
+ &None,
+ &sorted_token_ids,
+ &expert_ids,
+ self.num_experts_per_tok,
+ is_prefill,
+ )?;
+
+ let gate = gate_up
+ .narrow(candle::D::Minus1, 0, self.w_size_n)?
+ .contiguous()?;
+ let up = gate_up
+ .narrow(candle::D::Minus1, self.w_size_n, self.w_size_n)?
+ .contiguous()?;
+
+ //(M * top_k, N // 2)
+ let down_inputs = (up * gate.apply(&self.act)?)?.reshape(((), self.w_size_n))?;
+
+ //view(M, top_k, K) -> sum -> (M, K)
+ let ys = moe::moe_gemm(
+ &down_inputs,
+ &self.down_w,
+ &Some(topk_weights),
+ &sorted_token_ids,
+ &expert_ids,
+ self.num_experts_per_tok,
+ is_prefill,
+ )?
+ .reshape((num_tokens, (), hidden_dim))?
+ .sum(D::Minus2)?;
+
+ ys.reshape((batch, seq_len, hidden_dim))
+ }
+}
+
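+/// GGUF-quantized counterpart of [`FusedMoe`]: expert weights stay quantized as
+/// `QTensor`s and are dispatched through `moe::moe_gemm_gguf`.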
+pub struct FusedMoeGGUF {
+ pub gate: Linear,
+ pub gate_experts: Arc<QTensor>,
+ pub up_experts: Arc<QTensor>,
+ pub down_experts: Arc<QTensor>,
+ pub act: Activation,
+ pub norm_topk_prob: bool,
+ pub num_experts_per_tok: usize,
+ // all_reduce: AllReduce,
+ // world_size: usize,
+ pub dtype: DType,
+}
+
+impl FusedMoeGGUF {
+ pub fn new(
+ cfg: &MoeCfg,
+ vb: crate::quantized_var_builder::VarBuilder,
+ dtype: DType,
+ ) -> Result<Self> {
+ let num_experts = cfg.num_experts;
+ let gate_ws = vb
+ .pp("ffn_gate_inp")
+ .get((num_experts, cfg.hidden_size), "weight")?
+ .dequantize(vb.device())?
+ .to_dtype(DType::F32)?;
+
+ let gate = Linear::new(gate_ws, None);
+
+ let (gate_experts, up_experts, down_experts) = {
+ (
+ vb.pp("ffn_gate_exps").get(
+ (num_experts, cfg.moe_intermediate_size, cfg.hidden_size),
+ "weight",
+ )?,
+ vb.pp("ffn_up_exps").get(
+ (num_experts, cfg.moe_intermediate_size, cfg.hidden_size),
+ "weight",
+ )?,
+ vb.pp("ffn_down_exps").get(
+ (num_experts, cfg.hidden_size, cfg.moe_intermediate_size),
+ "weight",
+ )?,
+ )
+ };
+
+ Ok(Self {
+ gate,
+ gate_experts,
+ up_experts,
+ down_experts,
+ act: cfg.act,
+ norm_topk_prob: cfg.norm_topk_prob,
+ num_experts_per_tok: cfg.num_experts_per_tok,
+ // all_reduce: AllReduce::new(comm),
+ // world_size: 1,
+ dtype,
+ })
+ }
+
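+ /// Same routing scheme as [`FusedMoe::forward`], but the gate, up and down experts
+ /// are evaluated with three quantized grouped GEMMs on f32 activations and the
+ /// result is cast back to the caller's dtype.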
+ pub fn forward(&self, xs: &Tensor, is_prefill: bool) -> Result<Tensor> {
+ let (batch, seq_len, hidden_dim) = xs.dims3()?;
+ let xs = xs.reshape(((), hidden_dim))?;
+ let (num_tokens, hidden_dim) = xs.dims2()?;
+ let original_dtype = xs.dtype();
+ let xs = if xs.dtype() != DType::F32 {
+ xs.to_dtype(DType::F32)?
+ } else {
+ xs.to_owned()
+ };
+
+ let router_logits = self.gate.forward(&xs)?;
+
+ let routing_weights =
+ candle_nn::ops::softmax_last_dim(&router_logits.to_dtype(DType::F32)?)?;
+
+ let topk_ids = routing_weights
+ .arg_sort_last_dim(false)?
+ .narrow(D::Minus1, 0, self.num_experts_per_tok)?
+ .contiguous()?;
+
+ let mut topk_weights = routing_weights.gather(&topk_ids, D::Minus1)?;
+
+ if self.norm_topk_prob {
+ topk_weights = topk_weights.broadcast_div(&topk_weights.sum_keepdim(D::Minus1)?)?;
+ }
+
+ let (expert_ids, sorted_token_ids) = if is_prefill {
+ // For long-context (32K+), need to use custom sort kernel
+ // #[cfg(feature = "cuda")]
+ // {
+ // use attention_rs::sort::ArgSortOp;
+ // topk_ids.flatten_all()?.sort(true)?
+ // }
+ // #[cfg(not(feature = "cuda"))]
+ topk_ids.flatten_all()?.sort_last_dim(true)?
+ } else {
+ topk_ids.flatten_all()?.sort_last_dim(true)?
+ };
+
+ let ys = {
+ let gate = moe::moe_gemm_gguf(
+ &xs,
+ &self.gate_experts,
+ &None,
+ &sorted_token_ids,
+ &expert_ids,
+ self.num_experts_per_tok,
+ is_prefill,
+ self.dtype,
+ )?;
+ let up = moe::moe_gemm_gguf(
+ &xs,
+ &self.up_experts,
+ &None,
+ &sorted_token_ids,
+ &expert_ids,
+ self.num_experts_per_tok,
+ is_prefill,
+ self.dtype,
+ )?;
+
+ let down_inputs = (up * gate.apply(&self.act)?)?;
+ moe::moe_gemm_gguf(
+ &down_inputs,
+ &self.down_experts,
+ &Some(topk_weights),
+ &sorted_token_ids,
+ &expert_ids,
+ self.num_experts_per_tok,
+ is_prefill,
+ self.dtype,
+ )?
+ };
+ let mut ys = ys.reshape((num_tokens, (), hidden_dim))?.sum(D::Minus2)?;
+ if ys.dtype() != original_dtype {
+ ys = ys.to_dtype(original_dtype)?;
+ }
+ ys.reshape((batch, seq_len, hidden_dim))
+ }
+}
diff --git a/candle-transformers/src/lib.rs b/candle-transformers/src/lib.rs
index b2b062a9d7..bae7699a09 100644
--- a/candle-transformers/src/lib.rs
+++ b/candle-transformers/src/lib.rs
@@ -1,3 +1,4 @@
+pub mod fused_moe;
pub mod generation;
pub mod models;
pub mod object_detection;
diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs
index e77ba4a36f..2d93833581 100644
--- a/candle-transformers/src/models/mod.rs
+++ b/candle-transformers/src/models/mod.rs
@@ -94,6 +94,7 @@ pub mod quantized_phi;
pub mod quantized_phi3;
pub mod quantized_qwen2;
pub mod quantized_qwen3;
+pub mod quantized_qwen3_moe;
pub mod quantized_recurrent_gemma;
pub mod quantized_rwkv_v5;
pub mod quantized_rwkv_v6;
diff --git a/candle-transformers/src/models/quantized_qwen3.rs b/candle-transformers/src/models/quantized_qwen3.rs
index 5d9f414658..85ccbb0edd 100644
--- a/candle-transformers/src/models/quantized_qwen3.rs
+++ b/candle-transformers/src/models/quantized_qwen3.rs
@@ -14,32 +14,32 @@ use candle_nn::{kv_cache::ConcatKvCache, Activation, Embedding, Module};
use std::io::{Read, Seek};
use std::sync::Arc;
-struct Gguf<R: Read + Seek> {
+pub struct Gguf<R: Read + Seek> {
ct: gguf_file::Content,
reader: R,
device: Device,
}
impl<R: Read + Seek> Gguf<R> {
- fn new(ct: gguf_file::Content, reader: R, device: Device) -> Self {
+ pub fn new(ct: gguf_file::Content, reader: R, device: Device) -> Self {
Self { ct, reader, device }
}
- fn qmatmul(&mut self, name: &str) -> Result<QMatMul> {
+ pub fn qmatmul(&mut self, name: &str) -> Result<QMatMul> {
let ws = self.ct.tensor(&mut self.reader, name, &self.device)?;
QMatMul::from_weights(ws.into())
}
- fn rms_norm(&mut self, name: &str, eps: f64) -> Result<RmsNorm> {
+ pub fn rms_norm(&mut self, name: &str, eps: f64) -> Result<RmsNorm> {
let ws = self.ct.tensor(&mut self.reader, name, &self.device)?;
RmsNorm::from_qtensor(ws, eps)
}
- fn metadata(&self) -> &std::collections::HashMap<String, gguf_file::Value> {
+ pub fn metadata(&self) -> &std::collections::HashMap<String, gguf_file::Value> {
&self.ct.metadata
}
- fn tensor(&mut self, name: &str) -> Result<QTensor> {
+ pub fn tensor(&mut self, name: &str) -> Result<QTensor> {
self.ct.tensor(&mut self.reader, name, &self.device)
}
}
@@ -81,13 +81,13 @@ impl Module for MlpWeights {
}
#[derive(Debug, Clone)]
-struct RotaryEmbedding {
+pub struct RotaryEmbedding {
sin: Tensor,
cos: Tensor,
}
impl RotaryEmbedding {
- fn new(
+ pub fn new(
dtype: DType,
head_dim: usize,
max_position_embeddings: usize,
@@ -113,7 +113,7 @@ impl RotaryEmbedding {
}
/// Apply RoPE (q, k shape: B x H x L x D)
- fn apply(&self, q: &Tensor, k: &Tensor, offset: usize) -> Result<(Tensor, Tensor)> {
+ pub fn apply(&self, q: &Tensor, k: &Tensor, offset: usize) -> Result<(Tensor, Tensor)> {
let (_, _, seq_len, _) = q.dims4()?;
let cos = self.cos.narrow(0, offset, seq_len)?.to_dtype(q.dtype())?;
let sin = self.sin.narrow(0, offset, seq_len)?.to_dtype(q.dtype())?;
diff --git a/candle-transformers/src/models/quantized_qwen3_moe.rs b/candle-transformers/src/models/quantized_qwen3_moe.rs
new file mode 100644
index 0000000000..2daa84e062
--- /dev/null
+++ b/candle-transformers/src/models/quantized_qwen3_moe.rs
@@ -0,0 +1,460 @@
+use super::quantized_qwen3::{Gguf, RotaryEmbedding};
+use super::with_tracing::QMatMul;
+use crate::fused_moe::{FusedMoeGGUF, MoeCfg};
+use crate::quantized_nn::RmsNorm;
+use crate::utils::repeat_kv;
+use candle::quantized::gguf_file;
+use candle::{DType, Device, Result, Tensor};
+use candle_nn::kv_cache::ConcatKvCache;
+use candle_nn::Linear;
+use candle_nn::{Embedding, Module};
+use std::sync::Arc;
+#[derive(Debug, Clone)]
+struct Mlp {
+ feed_forward_w1: QMatMul,
+ feed_forward_w2: QMatMul,
+ feed_forward_w3: QMatMul,
+}
+
+impl Module for Mlp {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let w1 = self.feed_forward_w1.forward(xs)?;
+ let w3 = self.feed_forward_w3.forward(xs)?;
+ self.feed_forward_w2
+ .forward(&(candle_nn::ops::silu(&w1)? * w3)?)
+ }
+}
+
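+/// Either a fused MoE block or a dense gated MLP, selected per decoder layer.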
+enum MoeOrMlp {
+ FusedMoe(FusedMoeGGUF),
+ Mlp(Mlp),
+}
+
+impl MoeOrMlp {
+ fn forward(&self, xs: &Tensor, is_prefill: bool) -> Result<Tensor> {
+ match self {
+ Self::Mlp(m) => m.forward(xs),
+ Self::FusedMoe(m) => m.forward(xs, is_prefill),
+ }
+ }
+}
+
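+/// Multi-head attention over quantized weights, with optional q/k RMSNorm and
+/// optional attention biases as found in Qwen-style GGUF exports.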
+pub struct QuantizedAttention {
+ attention_wq: QMatMul,
+ attention_wk: QMatMul,
+ attention_wv: QMatMul,
+ attention_bq: Option<Tensor>,
+ attention_bk: Option<Tensor>,
+ attention_bv: Option<Tensor>,
+ attention_wo: QMatMul,
+ q_norm: Option<RmsNorm>,
+ k_norm: Option<RmsNorm>,
+ n_head: usize,
+ n_kv_head: usize,
+ head_dim: usize,
+ num_kv_groups: usize,
+ rotary_emb: Arc<RotaryEmbedding>,
+ dtype: DType,
+ kv_cache: ConcatKvCache,
+}
+
+impl QuantizedAttention {
+ #[allow(clippy::too_many_arguments)]
+ pub fn new<R: std::io::Seek + std::io::Read>(
+ gg: &mut Gguf<R>,
+ prefix: &str,
+ dtype: DType,
+ num_heads: usize,
+ num_kv_heads: usize,
+ head_dim: usize,
+ rms_norm_eps: f64,
+ device: &Device,
+ rotary_emb: Arc,
+ ) -> Result<Self> {
+ let num_kv_groups = num_heads / num_kv_heads;
+ let attention_wq = gg.qmatmul(&format!("{prefix}.attn_q.weight"))?;
+ let attention_wk = gg.qmatmul(&format!("{prefix}.attn_k.weight"))?;
+ let attention_wv = gg.qmatmul(&format!("{prefix}.attn_v.weight"))?;
+
+ let attention_bq = gg.tensor(&format!("{prefix}.attn_q.bias"));
+ let attention_bk = gg.tensor(&format!("{prefix}.attn_k.bias"));
+ let attention_bv = gg.tensor(&format!("{prefix}.attn_v.bias"));
+
+ let attention_bq = if let Ok(attention_bq) = attention_bq {
+ Some(attention_bq.dequantize(device)?.to_dtype(DType::F32)?)
+ } else {
+ None
+ };
+
+ let attention_bk = if let Ok(attention_bk) = attention_bk {
+ Some(attention_bk.dequantize(device)?.to_dtype(DType::F32)?)
+ } else {
+ None
+ };
+
+ let attention_bv = if let Ok(attention_bv) = attention_bv {
+ Some(attention_bv.dequantize(device)?.to_dtype(DType::F32)?)
+ } else {
+ None
+ };
+
+ let attention_wo = gg.qmatmul(&format!("{prefix}.attn_output.weight"))?;
+ let q_norm = Some(gg.rms_norm(&format!("{prefix}.attn_q_norm.weight"), rms_norm_eps)?);
+ let k_norm = Some(gg.rms_norm(&format!("{prefix}.attn_k_norm.weight"), rms_norm_eps)?);
+ let kv_cache = ConcatKvCache::new(2);
+ Ok(QuantizedAttention {
+ attention_wq,
+ attention_wk,
+ attention_wv,
+ attention_bq,
+ attention_bk,
+ attention_bv,
+ attention_wo,
+ q_norm,
+ k_norm,
+ n_head: num_heads,
+ n_kv_head: num_kv_heads,
+ head_dim,
+ num_kv_groups,
+ rotary_emb: rotary_emb.clone(),
+ dtype,
+ kv_cache,
+ })
+ }
+
+ pub fn forward(
+ &mut self,
+ x: &Tensor,
+ mask: Option<&Tensor>,
+ input_pos: usize,
+ ) -> Result<Tensor> {
+ let (b, seq_len, _) = x.dims3()?;
+ let in_dtype = x.dtype();
+ let q = self.attention_wq.forward(x)?;
+ let k = self.attention_wk.forward(x)?;
+ let v = self.attention_wv.forward(x)?;
+
+ let q = if self.attention_bq.is_some() {
+ q.broadcast_add(self.attention_bq.as_ref().unwrap())?
+ } else {
+ q
+ };
+
+ let k = if self.attention_bk.is_some() {
+ k.broadcast_add(self.attention_bk.as_ref().unwrap())?
+ } else {
+ k
+ };
+
+ let v = if self.attention_bv.is_some() {
+ v.broadcast_add(self.attention_bv.as_ref().unwrap())?
+ } else {
+ v
+ };
+
+ let q = q
+ .reshape((1, seq_len, self.n_head, self.head_dim))?
+ .transpose(1, 2)?
+ .contiguous()?;
+ let k = k
+ .reshape((1, seq_len, self.n_kv_head, self.head_dim))?
+ .transpose(1, 2)?
+ .contiguous()?;
+ let v = v
+ .reshape((1, seq_len, self.n_kv_head, self.head_dim))?
+ .transpose(1, 2)?
+ .contiguous()?;
+
+ let (q, k) = if let (Some(q_norm), Some(k_norm)) = (&self.q_norm, &self.k_norm) {
+ // Per‑head RMSNorm in qwen3
+ let q_flat = q.flatten(0, 2)?; // (B*H, L, D) -> (BHL, D) after transpose later
+ let k_flat = k.flatten(0, 2)?;
+
+ // q_norm and k_norm weights stored in f32 format in qwen3 gguf
+ let q_flat = q_norm.forward(&q_flat)?;
+ let k_flat = k_norm.forward(&k_flat)?;
+
+ let q = q_flat.reshape((1, self.n_head, seq_len, self.head_dim))?;
+ let k = k_flat.reshape((1, self.n_kv_head, seq_len, self.head_dim))?;
+
+ (q, k)
+ } else {
+ (q, k)
+ };
+
+ let (q, k, v) = (
+ q.to_dtype(self.dtype)?,
+ k.to_dtype(self.dtype)?,
+ v.to_dtype(self.dtype)?,
+ );
+
+ let (q, k) = self.rotary_emb.apply(&q, &k, input_pos)?;
+
+ let (k, v) = self.kv_cache.append(&k, &v)?;
+
+ let k = repeat_kv(k, self.num_kv_groups)?.contiguous()?;
+ let v = repeat_kv(v, self.num_kv_groups)?.contiguous()?;
+
+ let scale = 1.0 / (self.head_dim as f64).sqrt();
+ let mut scores = (q.matmul(&k.transpose(2, 3)?)? * scale)?;
+
+ if let Some(m) = mask {
+ let m_dtype = m.dtype();
+ let scores_dtype = scores.dtype();
+ let mask = if m_dtype != scores_dtype {
+ m.to_dtype(scores_dtype)?
+ } else {
+ m.clone()
+ };
+ scores = scores.broadcast_add(&mask)?;
+ }
+
+ let probs = candle_nn::ops::softmax_last_dim(&scores)?;
+ let ctx = probs.matmul(&v)?; // (B, H, L, D)
+ let reshaped_ctx =
+ ctx.transpose(1, 2)?
+ .reshape((b, seq_len, self.n_head * self.head_dim))?;
+
+ self.attention_wo.forward(&reshaped_ctx.to_dtype(in_dtype)?)
+ }
+}
+
+struct LayerWeights {
+ self_attn: QuantizedAttention,
+ attention_norm: RmsNorm,
+ mlp: MoeOrMlp,
+ ffn_norm: RmsNorm,
+}
+
+impl LayerWeights {
+ fn forward_attn(&mut self, x: &Tensor, mask: Option<&Tensor>, offset: usize) -> Result<Tensor> {
+ self.self_attn.forward(x, mask, offset)
+ }
+}
+
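+/// Qwen3 MoE model loaded from a GGUF file. Sparse layers run through the fused
+/// [`FusedMoeGGUF`](crate::fused_moe::FusedMoeGGUF) path while dense layers fall
+/// back to a regular gated MLP.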
+pub struct GGUFQWenMoE {
+ tok_embeddings: Embedding,
+ layers: Vec<LayerWeights>,
+ norm: RmsNorm,
+ output: QMatMul,
+ dtype: DType,
+ device: Device,
+}
+
+impl GGUFQWenMoE {
+ pub fn from_gguf<R: std::io::Seek + std::io::Read>(
+ ct: gguf_file::Content,
+ reader: &mut R,
+ device: &Device,
+ dtype: DType,
+ ) -> Result<Self> {
+ let mut gg = Gguf::new(ct, reader, device.clone());
+ let md_get = |s: &str| match gg.metadata().get(s) {
+ None => candle::bail!("cannot find {s} in metadata"),
+ Some(v) => Ok(v),
+ };
+ let arch = md_get("general.architecture")?.to_string()?;
+
+ let head_count =
+ md_get(format!("{arch}.attention.head_count").as_str())?.to_u32()? as usize;
+ let head_count_kv =
+ md_get(format!("{arch}.attention.head_count_kv").as_str())?.to_u32()? as usize;
+
+ let head_dim = md_get(format!("{arch}.attention.key_length").as_str());
+ let embedding_length =
+ md_get(format!("{arch}.embedding_length").as_str())?.to_u32()? as usize;
+ let head_dim = if let Ok(head_dim) = head_dim {
+ head_dim.to_u32()? as usize
+ } else {
+ embedding_length / head_count
+ };
+ let context_length = md_get(format!("{arch}.context_length").as_str())?.to_u32()? as usize;
+ let block_count = md_get(format!("{arch}.block_count").as_str())?.to_u32()? as usize;
+ let rms_norm_eps =
+ md_get(format!("{arch}.attention.layer_norm_rms_epsilon").as_str())?.to_f32()? as f64;
+ let rope_freq_base = md_get(format!("{arch}.rope.freq_base").as_str())
+ .and_then(|m| m.to_f32())
+ .unwrap_or(10000f32);
+ let expert_shared_feed_forward_length =
+ md_get(format!("{arch}.expert_shared_feed_forward_length").as_str());
+ let shared_expert_intermediate_size = match expert_shared_feed_forward_length {
+ Ok(length) => {
+ if length.to_u32()? > 0 {
+ Some(length.to_u32()? as usize)
+ } else {
+ None
+ }
+ }
+ _ => None,
+ };
+
+ let moe_cfg = MoeCfg {
+ moe_intermediate_size: md_get(format!("{arch}.expert_feed_forward_length").as_str())?
+ .to_u32()? as usize,
+ num_experts: md_get(format!("{arch}.expert_count").as_str())?.to_u32()? as usize,
+ norm_topk_prob: shared_expert_intermediate_size.is_none(),
+ num_experts_per_tok: md_get(format!("{arch}.expert_used_count").as_str())?.to_u32()?
+ as usize,
+ hidden_size: head_dim,
+ act: candle_nn::Activation::Silu,
+ decoder_sparse_step: None,
+ };
+
+ let tok_embeddings = gg.tensor("token_embd.weight")?;
+ let tok_embeddings = tok_embeddings.dequantize(device)?;
+ let norm = gg.rms_norm("output_norm.weight", rms_norm_eps)?;
+ let output = match gg.qmatmul("output.weight") {
+ Ok(v) => v,
+ _ => {
+ // use tie_word_embeddings
+ gg.qmatmul("token_embd.weight")?
+ }
+ };
+
+ let rotary_emb = Arc::new(RotaryEmbedding::new(
+ dtype,
+ head_dim,
+ context_length,
+ rope_freq_base as f64,
+ device,
+ )?);
+ let mut layers = Vec::with_capacity(block_count);
+ for layer_idx in 0..block_count {
+ let prefix = format!("blk.{layer_idx}");
+ let mlp = if moe_cfg.num_experts > 0
+ && (layer_idx + 1) % moe_cfg.decoder_sparse_step.unwrap_or(1) == 0
+ {
+ let gate_ws = gg
+ .tensor(&format!("{prefix}.ffn_gate_inp.weight"))?
+ .dequantize(device)?
+ .to_dtype(DType::F32)?;
+ let gate = Linear::new(gate_ws, None);
+ let gate_experts = Arc::new(gg.tensor(&format!("{prefix}.ffn_gate_exps.weight"))?);
+ let up_experts = Arc::new(gg.tensor(&format!("{prefix}.ffn_up_exps.weight"))?);
+ let down_experts = Arc::new(gg.tensor(&format!("{prefix}.ffn_down_exps.weight"))?);
+ let moe = FusedMoeGGUF {
+ gate,
+ gate_experts,
+ up_experts,
+ down_experts,
+ act: candle_nn::Activation::Silu,
+ norm_topk_prob: moe_cfg.norm_topk_prob,
+ num_experts_per_tok: moe_cfg.num_experts_per_tok,
+ dtype,
+ };
+
+ MoeOrMlp::FusedMoe(moe)
+ } else {
+ let mlp = {
+ let feed_forward_w1 = gg.qmatmul(&format!("{prefix}.ffn_gate.weight"))?;
+ let feed_forward_w2 = gg.qmatmul(&format!("{prefix}.ffn_down.weight"))?;
+ let feed_forward_w3 = gg.qmatmul(&format!("{prefix}.ffn_up.weight"))?;
+ Mlp {
+ feed_forward_w1,
+ feed_forward_w2,
+ feed_forward_w3,
+ }
+ };
+ MoeOrMlp::Mlp(mlp)
+ };
+
+ let attention_norm =
+ gg.rms_norm(&format!("{prefix}.attn_norm.weight"), rms_norm_eps)?;
+ let ffn_norm = gg.rms_norm(&format!("{prefix}.ffn_norm.weight"), rms_norm_eps)?;
+
+ let self_attn = QuantizedAttention::new(
+ &mut gg,
+ &prefix,
+ dtype,
+ head_count,
+ head_count_kv,
+ head_dim,
+ rms_norm_eps,
+ device,
+ rotary_emb.clone(),
+ )?;
+ layers.push(LayerWeights {
+ self_attn,
+ attention_norm,
+ mlp,
+ ffn_norm,
+ });
+ }
+
+ Ok(Self {
+ tok_embeddings: Embedding::new(tok_embeddings, embedding_length),
+ layers,
+ norm,
+ output,
+ dtype,
+ device: device.clone(),
+ })
+ }
+
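+ /// Builds an additive causal mask; `sw` optionally restricts attention to a
+ /// sliding window of that many past positions (callers in this file currently
+ /// pass `None`).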
+ fn causal_mask(
+ &self,
+ b: usize,
+ tgt: usize,
+ offset: usize,
+ sw: Option<usize>,
+ ) -> Result<Tensor> {
+ let minf = f32::NEG_INFINITY;
+ let mask: Vec<_> = (0..tgt)
+ .flat_map(|i| {
+ (0..(tgt + offset)).map(move |j| {
+ let past_ok = j <= i + offset;
+ let sw_ok = match sw {
+ Some(w) => (i + offset) as i64 - j as i64 <= w as i64,
+ None => true,
+ };
+ if past_ok && sw_ok {
+ 0.
+ } else {
+ minf
+ }
+ })
+ })
+ .collect();
+ Tensor::from_slice(&mask, (b, 1, tgt, tgt + offset), &self.device)?.to_dtype(self.dtype)
+ }
+
+ pub fn forward(&mut self, x: &Tensor, offset: usize) -> Result<Tensor> {
+ let mut xs = self.tok_embeddings.forward(x)?;
+ let (b, l) = x.dims2()?;
+
+ let causal_mask = if l == 1 {
+ None
+ } else {
+ Some(self.causal_mask(b, l, offset, None)?)
+ };
+
+ for layer in self.layers.iter_mut() {
+ let x = xs;
+ let residual = &x;
+
+ let x = layer.attention_norm.forward(&x)?;
+ let attn = layer.forward_attn(&x, causal_mask.as_ref(), offset)?;
+ let x = (attn + residual)?;
+
+ // MLP
+ let residual = &x;
+ let x = layer.ffn_norm.forward(&x)?;
+ let x = layer.mlp.forward(&x, causal_mask.is_some())?;
+ let x = (x + residual)?;
+ xs = x
+ }
+
+ let xs = xs.narrow(1, l - 1, 1)?;
+ let xs = self.norm.forward(&xs)?;
+ self.output.forward(&xs)?.to_dtype(DType::F32)?.squeeze(1)
+ }
+}
diff --git a/candle-transformers/src/models/qwen3_moe.rs b/candle-transformers/src/models/qwen3_moe.rs
index b76ce92de4..0576b4c075 100644
--- a/candle-transformers/src/models/qwen3_moe.rs
+++ b/candle-transformers/src/models/qwen3_moe.rs
@@ -1,6 +1,9 @@
-use crate::models::{
- qwen3::{Config as Qwen3Config, Qwen3Attention, Qwen3MLP, Qwen3RotaryEmbedding},
- with_tracing::{linear_no_bias, Linear, RmsNorm},
+use crate::{
+ fused_moe::{FusedMoe, MoeCfg},
+ models::{
+ qwen3::{Config as Qwen3Config, Qwen3Attention, Qwen3MLP, Qwen3RotaryEmbedding},
+ with_tracing::{linear_no_bias, Linear, RmsNorm},
+ },
};
use candle::{DType, Device, Module, Result, Tensor, D};
use candle_nn::{Activation, VarBuilder};
@@ -176,14 +179,16 @@ impl Module for Qwen3SparseMoeBlock {
#[derive(Debug, Clone)]
enum Qwen3FeedForward {
Mlp(Qwen3MLP),
- MoE(Qwen3SparseMoeBlock),
+ NaiveMoE(Qwen3SparseMoeBlock),
+ FusedMoE(FusedMoe),
}
-impl Module for Qwen3FeedForward {
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+impl Qwen3FeedForward {
+ fn forward(&self, xs: &Tensor, is_prefill: bool) -> Result<Tensor> {
match self {
Self::Mlp(m) => m.forward(xs),
- Self::MoE(m) => m.forward(xs),
+ Self::NaiveMoE(m) => m.forward(xs),
+ Self::FusedMoE(m) => m.forward(xs, is_prefill),
}
}
}
@@ -205,10 +210,24 @@ impl DecoderLayer {
) -> Result<Self> {
let self_attn = Qwen3Attention::new(&cfg.into(), rotary, vb.pp("self_attn"))?;
+ let moe_cfg = MoeCfg {
+ hidden_size: cfg.hidden_size,
+ num_experts: cfg.num_experts,
+ num_experts_per_tok: cfg.num_experts_per_tok,
+ moe_intermediate_size: cfg.moe_intermediate_size,
+ norm_topk_prob: cfg.norm_topk_prob,
+ act: cfg.hidden_act,
+ decoder_sparse_step: None,
+ };
// Decide whether to use MoE or regular MLP based on layer_idx and decoder_sparse_step
let feed_forward =
if cfg.num_experts > 0 && (layer_idx + 1).is_multiple_of(cfg.decoder_sparse_step) {
- Qwen3FeedForward::MoE(Qwen3SparseMoeBlock::new(cfg, vb.pp("mlp"))?)
+ if cfg!(feature = "cuda") {
+ // Use fused MoE kernel on CUDA
+ Qwen3FeedForward::FusedMoE(FusedMoe::new(&moe_cfg, vb.pp("mlp"), vb.dtype())?)
+ } else {
+ Qwen3FeedForward::NaiveMoE(Qwen3SparseMoeBlock::new(cfg, vb.pp("mlp"))?)
+ }
} else {
Qwen3FeedForward::Mlp(Qwen3MLP::new(&cfg.into(), vb.pp("mlp"))?)
};
@@ -233,7 +252,7 @@ impl DecoderLayer {
let h = self.self_attn.forward(&h, mask, offset)?;
let x = (x + h)?;
let h2 = self.ln2.forward(&x)?;
- let h2 = h2.apply(&self.feed_forward)?;
+ let h2 = self.feed_forward.forward(&h2, mask.is_some())?;
x + h2
}