
Conversation

@danielclough
Contributor

Summary

Adds KV cache management and fixes a critical causal mask bug for Qwen2 multi-turn inference. Also includes numerical precision improvements for RoPE and attention.

Changes

  • Causal mask bug fix: Corrects the mask shape for cached decoding (was [tgt, tgt], now [tgt, total]) - critical for multi-turn conversations; see the sketch after this list
  • Precision improvements: RoPE and softmax now use F32 intermediates to match PyTorch behavior
  • KV cache API: Adds extract_kv_cache/restore_kv_cache methods for cache manipulation and inspection
  • Selective attention: New prepare_4d_causal_attention_mask_with_cache_position for non-contiguous cache positions
  • Embedding injection: forward_from_embeds methods enable custom embedding workflows (e.g., multimodal)
  • Stability fix: Replaces NEG_INFINITY with f32::MIN to avoid NaN propagation when combining masks
  • Cache manipulation: Adds shift_kv_cache_first_to_last for advanced patterns (e.g., negative prompt refresh)
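
To make the mask-shape fix concrete, here is a rough, illustrative sketch (not the PR's actual code) of how a cached-decoding mask of shape [tgt, total] can be built, where total = past_kv_len + tgt; the names cached_causal_mask, past_kv_len, and min_value are made up for illustration:

/// Illustrative sketch only: builds a causal mask of shape [tgt_len, total_len],
/// where total_len = past_kv_len + tgt_len, so each new query row can attend to
/// all cached keys plus earlier positions in the current chunk.
/// `min_value` plays the role of f32::MIN from the PR.
fn cached_causal_mask(tgt_len: usize, past_kv_len: usize, min_value: f32) -> Vec<f32> {
    let total_len = past_kv_len + tgt_len;
    let mut mask = Vec::with_capacity(tgt_len * total_len);
    for i in 0..tgt_len {
        // Absolute position of this query row in the full sequence.
        let query_pos = past_kv_len + i;
        for j in 0..total_len {
            // Mask only keys that lie in the future of the query position.
            mask.push(if j > query_pos { min_value } else { 0.0 });
        }
    }
    mask // row-major [tgt_len, total_len]
}

With past_kv_len = 0 this reduces to the usual [tgt, tgt] lower-triangular mask, which is why the bug only shows up once the cache is populated.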

Motivation

The causal mask bug prevented correct multi-turn decoding with the KV cache. The new cache-management APIs enable advanced inference patterns such as streaming audio generation (VibeVoice) and speculative decoding, while preserving precision for F16/BF16 inference.

Breaking Changes

None - all changes are backward compatible additions or bug fixes.

✅ Validation

Routine

cargo fmt --all
cargo test -p candle-transformers
cargo clippy -p candle-transformers

Test Qwen2 Example

Simple Query

cargo run --example qwen --features metal --release -- --prompt "Write a poem about butterflies. ." --model "2-1.5b"

Test with a very short prompt to ensure single-token decoding works

cargo run --example qwen --features metal --release -- --prompt "Hi" --sample-len 10 --model "2-1.5b"

- Fix causal mask shape for cached decoding (critical for multi-turn)
- Add extract/restore methods for KV cache manipulation
- Add support for non-contiguous cache positions via `cache_position`
- Add forward_from_embeds for custom embedding workflows
- Improve RoPE and softmax precision with F32 intermediates (matching PyTorch)
- Replace NEG_INFINITY with f32::MIN to avoid NaN propagation

@ivarflakstad (Member) left a comment

There's a lot of good stuff here!
A slight issue is that on CUDA, performance drops by 20%. It would be nice to figure out whether we can easily avoid this drop before merging (I'm ok with accuracy > performance).

Comment on lines +422 to +436
for &abs_query_pos in cache_pos_vec.iter().take(query_length) {
let abs_query_pos = abs_query_pos as usize;
for j in 0..key_length {
// Causal: can't attend to future positions
let is_future = j > abs_query_pos;
// Sliding window: can't attend to positions too far in the past
let is_too_old = j + self.sliding_window < abs_query_pos;

if is_future || is_too_old {
mask_data.push(min_dtype);
} else {
mask_data.push(0.0);
}
}
}
Member

This looks pretty close to a where_cond with t: min_dtype and f: 0.0. It would be nice if we didn't have to move cache_position onto the CPU for this, though is_future and is_too_old are hard to express.
I'll give it a think.
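
For reference, a rough, untested sketch of that idea (not from the PR): the function name is made up, and it assumes candle exposes broadcast_gt/broadcast_lt comparison ops and where_cond with the shapes used below.

use candle_core::{DType, Device, Result, Tensor};

// Hypothetical sketch: build the same causal + sliding-window mask with tensor
// ops, so `cache_position` could stay on-device instead of round-tripping to CPU.
fn mask_via_where_cond(
    cache_position: &Tensor, // [query_length], absolute positions (u32)
    key_length: usize,
    sliding_window: usize,
    min_dtype: f32,
    device: &Device,
) -> Result<Tensor> {
    let query_length = cache_position.dim(0)?;
    let q = cache_position.to_dtype(DType::F32)?.unsqueeze(1)?; // [q, 1]
    let k = Tensor::arange(0u32, key_length as u32, device)?
        .to_dtype(DType::F32)?
        .unsqueeze(0)?; // [1, k]
    // is_future: key position strictly after the query position.
    let is_future = k.broadcast_gt(&q)?; // [q, k], u8
    // is_too_old: key position more than `sliding_window` behind the query.
    let is_too_old = k.affine(1.0, sliding_window as f64)?.broadcast_lt(&q)?; // [q, k], u8
    let neg = Tensor::full(min_dtype, (query_length, key_length), device)?;
    let zero = Tensor::zeros((query_length, key_length), DType::F32, device)?;
    // Nested where_cond acts as the logical OR of the two predicates.
    let inner = is_too_old.where_cond(&neg, &zero)?;
    is_future.where_cond(&neg, &inner)
}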
