diff --git a/src/parallax/metal/indexer/kernel.py b/src/parallax/metal/indexer/kernel.py index 39ff2c48..32f16287 100644 --- a/src/parallax/metal/indexer/kernel.py +++ b/src/parallax/metal/indexer/kernel.py @@ -57,7 +57,6 @@ def store_indexer_cache( block_tables: mx.array, context_lengths: mx.array, block_size: int, - layer_idx: int, slot_mapping: Optional[mx.array] = None, ): dtype = key.dtype @@ -114,7 +113,6 @@ def mk_int(val): mk_int(num_heads), mk_int(head_dim), mk_int(block_size), - mk_int(layer_idx), mk_int(num_layers), mk_int(num_blocks), ] @@ -127,7 +125,6 @@ def mk_int(val): "num_heads", "head_dim", "block_size", - "layer_idx", "num_layers", "num_blocks", ] @@ -160,7 +157,6 @@ def q_dot_k( block_table: mx.array, # (max_blocks) context_length: mx.array, # scalar block_size: int, - layer_idx: int, ) -> mx.array: if q.ndim > 2: @@ -186,7 +182,6 @@ def mk_int(val): mk_int(block_size), mk_int(num_heads), mk_int(head_dim), - mk_int(layer_idx), mk_int(num_layers), mk_int(num_total_blocks), mk_int(max_blocks), @@ -200,7 +195,6 @@ def mk_int(val): "block_size", "num_heads", "head_dim", - "layer_idx", "num_layers", "num_total_blocks", "max_blocks", diff --git a/src/parallax/metal/indexer/q_dot_k.metal b/src/parallax/metal/indexer/q_dot_k.metal index c1d17f6d..00cb22d4 100644 --- a/src/parallax/metal/indexer/q_dot_k.metal +++ b/src/parallax/metal/indexer/q_dot_k.metal @@ -1,6 +1,6 @@ // Inputs provided by MLX wrapper: // q, key_cache, block_table, output (pointers) -// context_len, block_size, num_heads, head_dim, layer_idx, num_layers, num_total_blocks, max_blocks (scalars) +// context_len, block_size, num_heads, head_dim, num_layers, num_total_blocks, max_blocks (scalars) uint3 gid = thread_position_in_grid; @@ -15,16 +15,13 @@ long k_block_stride = num_heads * block_size * head_dim; long k_head_stride = block_size * head_dim; long k_layer_stride = (long)num_total_blocks * k_block_stride; -long layer_offset = (long)layer_idx * k_layer_stride; - for (int b = 0; b < num_valid_blocks; b++) { int block_num = block_table[b]; int logical_idx = b * block_size + token_in_block; if (logical_idx >= context_len) continue; - long k_base = layer_offset + - (long)block_num * k_block_stride + + long k_base = (long)block_num * k_block_stride + head_idx * k_head_stride + token_in_block * head_dim; diff --git a/src/parallax/metal/indexer/store_key.metal b/src/parallax/metal/indexer/store_key.metal index 8b6ca5fc..d58029f9 100644 --- a/src/parallax/metal/indexer/store_key.metal +++ b/src/parallax/metal/indexer/store_key.metal @@ -1,6 +1,6 @@ // Inputs provided by MLX wrapper: // key, key_cache, slot_mapping (pointers) -// key_stride, num_heads, head_dim, block_size, layer_idx, num_layers, num_blocks (scalars) +// key_stride, num_heads, head_dim, block_size, num_layers, num_blocks (scalars) device {{T}} *key_cache_mut = (device {{T}} *)key_cache; @@ -27,8 +27,7 @@ long k_head_stride = block_size * head_dim; long k_layer_stride = (long)num_blocks * k_block_stride; -long dest_idx = (long)layer_idx * k_layer_stride + - block_idx * k_block_stride + +long dest_idx = block_idx * k_block_stride + head_idx * k_head_stride + block_offset * head_dim + d_idx; diff --git a/src/parallax/metal/paged_attention/kernel.py b/src/parallax/metal/paged_attention/kernel.py index 9a9c4da4..a0302d55 100644 --- a/src/parallax/metal/paged_attention/kernel.py +++ b/src/parallax/metal/paged_attention/kernel.py @@ -63,7 +63,6 @@ def reshape_and_cache( block_tables: mx.array, # (batch, max_blocks) context_lengths: mx.array, # (batch,) 
block_size: int, - layer_idx: int, slot_mapping: Optional[mx.array] = None, # (batch,) or (batch * target_len,) ): """ @@ -130,25 +129,14 @@ def reshape_and_cache( if slot_mapping.shape[0] != num_tokens: raise ValueError(f"Slot mapping length {slot_mapping.shape[0]} != tokens {num_tokens}") - num_layers = key_cache.shape[0] - num_blocks = key_cache.shape[1] - # 2. Prepare Constants - key_stride = num_kv_heads * k_head_dim - value_stride = num_kv_heads * v_head_dim - def mk_int(val): return mx.array(val, dtype=mx.int32) - c_key_stride = mk_int(key_stride) - c_val_stride = mk_int(value_stride) c_num_kv = mk_int(num_kv_heads) c_k_head_dim = mk_int(k_head_dim) c_v_head_dim = mk_int(v_head_dim) c_block_size = mk_int(block_size) - c_layer_idx = mk_int(layer_idx) - c_num_layers = mk_int(num_layers) - c_num_blocks = mk_int(num_blocks) # Inputs list inputs = [ @@ -157,15 +145,10 @@ def mk_int(val): key_cache, value_cache, slot_mapping, - c_key_stride, - c_val_stride, c_num_kv, c_k_head_dim, c_v_head_dim, c_block_size, - c_layer_idx, - c_num_layers, - c_num_blocks, ] # Input names (just for declaration) @@ -175,15 +158,10 @@ def mk_int(val): "key_cache", "value_cache", "slot_mapping", - "key_stride", - "value_stride", "num_kv_heads", "k_head_dim", "v_head_dim", "block_size", - "layer_idx", - "num_layers", - "num_blocks", ] # 3. Get and Launch Kernel @@ -225,7 +203,6 @@ def paged_attention( block_size: int, scale: float, num_kv_heads: int, - layer_idx: int, v_head_dim: Optional[int] = None, top_k_indices: Optional[mx.array] = None, window_size: Optional[int] = None, @@ -261,7 +238,6 @@ def mk_int(val): c_v_head_dim = mk_int(v_head_dim) c_block_size = mk_int(block_size) c_max_blocks = mk_int(max_blocks) - c_layer_idx = mk_int(layer_idx) c_num_layers = mk_int(num_layers) c_num_total_blocks = mk_int(num_total_blocks) c_scale = mx.array(scale, dtype=queries.dtype) @@ -283,7 +259,6 @@ def mk_int(val): c_v_head_dim, c_block_size, c_max_blocks, - c_layer_idx, c_num_layers, c_num_total_blocks, c_scale, @@ -303,7 +278,6 @@ def mk_int(val): "v_head_dim", "block_size", "max_blocks", - "layer_idx", "num_layers", "num_total_blocks", "scale", @@ -327,7 +301,6 @@ def mk_int(val): c_v_head_dim, c_block_size, c_max_blocks, - c_layer_idx, c_num_layers, c_num_total_blocks, c_scale, @@ -347,7 +320,6 @@ def mk_int(val): "v_head_dim", "block_size", "max_blocks", - "layer_idx", "num_layers", "num_total_blocks", "scale", @@ -368,7 +340,6 @@ def mk_int(val): c_v_head_dim, c_block_size, c_max_blocks, - c_layer_idx, c_num_layers, c_num_total_blocks, c_scale, @@ -386,7 +357,6 @@ def mk_int(val): "v_head_dim", "block_size", "max_blocks", - "layer_idx", "num_layers", "num_total_blocks", "scale", diff --git a/src/parallax/metal/paged_attention/paged_attention.metal b/src/parallax/metal/paged_attention/paged_attention.metal index 527ee6f2..6780a411 100644 --- a/src/parallax/metal/paged_attention/paged_attention.metal +++ b/src/parallax/metal/paged_attention/paged_attention.metal @@ -2,7 +2,7 @@ // Inputs: // queries, key_cache, value_cache, block_tables, context_lengths // output (output array) -// num_heads, num_kv_heads, k_head_dim, v_head_dim, block_size, max_blocks, layer_idx, +// num_heads, num_kv_heads, k_head_dim, v_head_dim, block_size, max_blocks, // num_layers, num_total_blocks, scale (All pointers) uint3 gid = thread_position_in_grid; @@ -24,7 +24,6 @@ int _k_head_dim = k_head_dim; int _v_head_dim = v_head_dim; int _block_size = block_size; int _max_blocks = max_blocks; -int _layer_idx = layer_idx; int _num_total_blocks 
= num_total_blocks; float _scale = scale; @@ -57,29 +56,21 @@ int context_len = context_lengths[batch_idx]; int num_context_blocks = (context_len + _block_size - 1) / _block_size; // Strides for Key -long k_layer_stride = - (long)_num_total_blocks * _num_kv_heads * _block_size * _k_head_dim; long k_block_stride = _num_kv_heads * _block_size * _k_head_dim; long k_head_stride = _block_size * _k_head_dim; -long k_layer_offset = _layer_idx * k_layer_stride; - // Strides for Value -long v_layer_stride = - (long)_num_total_blocks * _num_kv_heads * _block_size * _v_head_dim; long v_block_stride = _num_kv_heads * _block_size * _v_head_dim; long v_head_stride = _block_size * _v_head_dim; -long v_layer_offset = _layer_idx * v_layer_stride; - // Iterate over blocks for (int b = 0; b < num_context_blocks; b++) { int block_num = block_tables[batch_idx * _max_blocks + b]; long k_block_base = - k_layer_offset + block_num * k_block_stride + kv_head_idx * k_head_stride; + block_num * k_block_stride + kv_head_idx * k_head_stride; long v_block_base = - v_layer_offset + block_num * v_block_stride + kv_head_idx * v_head_stride; + block_num * v_block_stride + kv_head_idx * v_head_stride; int tokens_in_block = _block_size; if (b == num_context_blocks - 1) { diff --git a/src/parallax/metal/paged_attention/paged_attention_deepseek_v32.metal b/src/parallax/metal/paged_attention/paged_attention_deepseek_v32.metal index 1823e4d6..900fa104 100644 --- a/src/parallax/metal/paged_attention/paged_attention_deepseek_v32.metal +++ b/src/parallax/metal/paged_attention/paged_attention_deepseek_v32.metal @@ -1,7 +1,7 @@ // Inputs: // queries, key_cache, value_cache, block_tables, context_lengths, top_k_indices // output (output array) -// num_heads, num_kv_heads, k_head_dim, v_head_dim, block_size, max_blocks, layer_idx, +// num_heads, num_kv_heads, k_head_dim, v_head_dim, block_size, max_blocks, // num_layers, num_total_blocks, scale, index_topk (All pointers) uint3 gid = thread_position_in_grid; @@ -23,7 +23,6 @@ int _k_head_dim = k_head_dim; int _v_head_dim = v_head_dim; int _block_size = block_size; int _max_blocks = max_blocks; -int _layer_idx = layer_idx; int _num_total_blocks = num_total_blocks; float _scale = scale; int _index_topk = index_topk; @@ -56,21 +55,13 @@ float acc_vec[8] = {0.0f}; int context_len = context_lengths[batch_idx]; // Strides for Key -long k_layer_stride = - (long)_num_total_blocks * _num_kv_heads * _block_size * _k_head_dim; long k_block_stride = _num_kv_heads * _block_size * _k_head_dim; long k_head_stride = _block_size * _k_head_dim; -long k_layer_offset = _layer_idx * k_layer_stride; - // Strides for Value -long v_layer_stride = - (long)_num_total_blocks * _num_kv_heads * _block_size * _v_head_dim; long v_block_stride = _num_kv_heads * _block_size * _v_head_dim; long v_head_stride = _block_size * _v_head_dim; -long v_layer_offset = _layer_idx * v_layer_stride; - // Check if we use Full Attention or Sparse Attention // We check the first element of top_k_indices for this batch int first_topk_idx = top_k_indices[batch_idx * _index_topk]; @@ -85,9 +76,9 @@ if (first_topk_idx == -1) { int block_num = block_tables[batch_idx * _max_blocks + b]; long k_block_base = - k_layer_offset + block_num * k_block_stride + kv_head_idx * k_head_stride; + block_num * k_block_stride + kv_head_idx * k_head_stride; long v_block_base = - v_layer_offset + block_num * v_block_stride + kv_head_idx * v_head_stride; + block_num * v_block_stride + kv_head_idx * v_head_stride; int tokens_in_block = _block_size; if (b == 
num_context_blocks - 1) { @@ -140,9 +131,9 @@ if (first_topk_idx == -1) { int block_num = block_tables[batch_idx * _max_blocks + block_idx_in_table]; long k_block_base = - k_layer_offset + block_num * k_block_stride + kv_head_idx * k_head_stride; + block_num * k_block_stride + kv_head_idx * k_head_stride; long v_block_base = - v_layer_offset + block_num * v_block_stride + kv_head_idx * v_head_stride; + block_num * v_block_stride + kv_head_idx * v_head_stride; // Compute Dot Product Q * K[token_idx] float score = 0.0f; diff --git a/src/parallax/metal/paged_attention/paged_attention_gpt_oss.metal b/src/parallax/metal/paged_attention/paged_attention_gpt_oss.metal index 0f331b6e..4a63d62a 100644 --- a/src/parallax/metal/paged_attention/paged_attention_gpt_oss.metal +++ b/src/parallax/metal/paged_attention/paged_attention_gpt_oss.metal @@ -2,7 +2,7 @@ // Inputs: // queries, key_cache, value_cache, block_tables, context_lengths, sinks // output (output array) -// num_heads, num_kv_heads, k_head_dim, v_head_dim, block_size, max_blocks, layer_idx, +// num_heads, num_kv_heads, k_head_dim, v_head_dim, block_size, max_blocks, // num_layers, num_total_blocks, scale, window_size (All pointers) uint3 gid = thread_position_in_grid; @@ -24,7 +24,6 @@ int _k_head_dim = k_head_dim; int _v_head_dim = v_head_dim; int _block_size = block_size; int _max_blocks = max_blocks; -int _layer_idx = layer_idx; int _num_total_blocks = num_total_blocks; float _scale = scale; int _window_size = window_size; @@ -78,21 +77,13 @@ int context_len = context_lengths[batch_idx]; int num_context_blocks = (context_len + _block_size - 1) / _block_size; // Strides for Key -long k_layer_stride = - (long)_num_total_blocks * _num_kv_heads * _block_size * _k_head_dim; long k_block_stride = _num_kv_heads * _block_size * _k_head_dim; long k_head_stride = _block_size * _k_head_dim; -long k_layer_offset = _layer_idx * k_layer_stride; - // Strides for Value -long v_layer_stride = - (long)_num_total_blocks * _num_kv_heads * _block_size * _v_head_dim; long v_block_stride = _num_kv_heads * _block_size * _v_head_dim; long v_head_stride = _block_size * _v_head_dim; -long v_layer_offset = _layer_idx * v_layer_stride; - // Iterate over blocks for (int b = 0; b < num_context_blocks; b++) { int block_num = block_tables[batch_idx * _max_blocks + b]; @@ -117,9 +108,9 @@ for (int b = 0; b < num_context_blocks; b++) { } long k_block_base = - k_layer_offset + block_num * k_block_stride + kv_head_idx * k_head_stride; + block_num * k_block_stride + kv_head_idx * k_head_stride; long v_block_base = - v_layer_offset + block_num * v_block_stride + kv_head_idx * v_head_stride; + block_num * v_block_stride + kv_head_idx * v_head_stride; int tokens_in_block = _block_size; if (b == num_context_blocks - 1) { diff --git a/src/parallax/metal/paged_attention/reshape_and_cache.metal b/src/parallax/metal/paged_attention/reshape_and_cache.metal index c3acbad7..a9e46f54 100644 --- a/src/parallax/metal/paged_attention/reshape_and_cache.metal +++ b/src/parallax/metal/paged_attention/reshape_and_cache.metal @@ -33,9 +33,6 @@ int b_size = block_size; int block_idx = slot / b_size; int block_offset = slot % b_size; -int l_idx = layer_idx; -int n_blocks = num_blocks; - // Handle Key if (dim_idx < k_dim) { // Calculate source index @@ -46,9 +43,8 @@ if (dim_idx < k_dim) { // Calculate destination index int64_t head_stride = b_size * k_dim; int64_t block_stride = n_kv_heads * head_stride; - int64_t layer_stride = n_blocks * block_stride; - int64_t dest_idx = (int64_t)l_idx * 
layer_stride + (int64_t)block_idx * block_stride + + int64_t dest_idx = (int64_t)block_idx * block_stride + (int64_t)head_idx * head_stride + block_offset * k_dim + dim_idx; @@ -65,9 +61,8 @@ if (dim_idx < v_dim) { // Calculate destination index int64_t head_stride = b_size * v_dim; int64_t block_stride = n_kv_heads * head_stride; - int64_t layer_stride = n_blocks * block_stride; - int64_t dest_idx = (int64_t)l_idx * layer_stride + (int64_t)block_idx * block_stride + + int64_t dest_idx = (int64_t)block_idx * block_stride + (int64_t)head_idx * head_stride + block_offset * v_dim + dim_idx; diff --git a/src/parallax/models/deepseek_v2.py b/src/parallax/models/deepseek_v2.py index 67903997..fe55ab60 100644 --- a/src/parallax/models/deepseek_v2.py +++ b/src/parallax/models/deepseek_v2.py @@ -2,7 +2,7 @@ hidden_dimefines the Qwen3 model. """ -from typing import Optional, Tuple +from typing import Any, List, Optional import mlx.core as mx from mlx_lm.models.base import scaled_dot_product_attention @@ -11,6 +11,7 @@ from mlx_lm.models.deepseek_v2 import ModelArgs from parallax.metal.paged_attention.kernel import paged_attention, reshape_and_cache +from parallax.server.cache.base import BaseCache class ParallaxDeepSeekV2Attention(MLXDeepseekV2Attention): @@ -24,11 +25,11 @@ def __call__( self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[BaseCache] = None, block_tables: Optional[mx.array] = None, context_lengths: Optional[mx.array] = None, slot_mapping: Optional[mx.array] = None, - layer_idx: int = 0, + **kwargs, ) -> mx.array: """ Attention forward pass with explicit KV cache handling. @@ -36,10 +37,10 @@ def __call__( Args: x: (batch, target_len, hidden_dim) - Input hidden states for the current query segment. mask: (batch, n_q_heads, target_len, source_len) - cache: contains (key_cache, value_cache) global. + cache: BaseCache object containing the layer cache. block_tables: (batch, max_blocks) - PagedKV block tables. context_lengths: (batch,) - PagedKV sequence lengths. - layer_idx: Layer index for PagedKV access. + slot_mapping: (batch * target_len,) - Flattened slot mapping. Returns: output_h: (batch, target_len, hidden_dim) - Output hidden states. @@ -63,7 +64,7 @@ def __call__( # q_pe = self.rope(q_pe, offset=offset) # k_pe = self.rope(k_pe, offset=offset) - key_cache_global, value_cache_global = cache + key_cache_global, value_cache_global = cache.get_cache() q_pe_list = [] k_pe_list = [] for i in range(batch): @@ -93,7 +94,6 @@ def __call__( block_tables, context_lengths, block_size, - layer_idx=layer_idx, slot_mapping=slot_mapping, ) @@ -107,7 +107,6 @@ def __call__( block_size, self.scale, self.num_heads, # num_kv_heads (MQA/MLA, here num_heads == num_kv_heads effectively after repeat?) - layer_idx, v_head_dim=values.shape[-1], ) output = output.transpose(0, 2, 1, 3).reshape(batch, target_len, -1) @@ -130,16 +129,17 @@ class ParallaxDeepSeekV2Block(MLXDeepseekV2Block): This version handles the KV cache explicitly and returns new K and V states. 
""" - def __init__(self, args: ModelArgs, layer_idx: int): + def __init__(self, args: ModelArgs, layer_idx: int, local_layer_idx: int): super().__init__(args, layer_idx=layer_idx) self.self_attn = ParallaxDeepSeekV2Attention(args) self.layer_idx = layer_idx + self.local_layer_idx = local_layer_idx def __call__( self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[List[Any]] = None, block_tables: Optional[mx.array] = None, context_lengths: Optional[mx.array] = None, slot_mapping: Optional[mx.array] = None, @@ -148,11 +148,11 @@ def __call__( r = self.self_attn( self.input_layernorm(x), mask, - cache, + cache[self.local_layer_idx], block_tables=block_tables, context_lengths=context_lengths, slot_mapping=slot_mapping, - layer_idx=self.layer_idx, + **kwargs, ) h = x + r r = self.mlp(self.post_attention_layernorm(h)) diff --git a/src/parallax/models/deepseek_v3.py b/src/parallax/models/deepseek_v3.py index cc6ecb53..01049ab9 100644 --- a/src/parallax/models/deepseek_v3.py +++ b/src/parallax/models/deepseek_v3.py @@ -2,7 +2,7 @@ hidden_dimefines the Qwen3 model. """ -from typing import Optional, Tuple +from typing import Any, List, Optional import mlx.core as mx from mlx_lm.models.base import scaled_dot_product_attention @@ -11,6 +11,7 @@ from mlx_lm.models.deepseek_v3 import ModelArgs from parallax.metal.paged_attention.kernel import paged_attention, reshape_and_cache +from parallax.server.cache.base import BaseCache class ParallaxDeepSeekV3Attention(MLXDeepseekV3Attention): @@ -24,13 +25,13 @@ def __call__( self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[BaseCache] = None, offset: int = 0, lengths: Optional[mx.array] = None, block_tables: Optional[mx.array] = None, context_lengths: Optional[mx.array] = None, slot_mapping: Optional[mx.array] = None, - layer_idx: int = 0, + **kwargs, ) -> mx.array: """ Attention forward pass with explicit KV cache handling. @@ -38,10 +39,10 @@ def __call__( Args: x: (batch, target_len, hidden_dim) - Input hidden states for the current query segment. mask: (batch, n_q_heads, target_len, source_len) - cache: contains (key_cache, value_cache) global. + cache: BaseCache object containing the layer cache. block_tables: (batch, max_blocks) - PagedKV block tables. context_lengths: (batch,) - PagedKV sequence lengths. - layer_idx: Layer index for PagedKV access. + slot_mapping: (batch * target_len,) - Flattened slot mapping. Returns: output_h: (batch, target_len, hidden_dim) - Output hidden states. @@ -65,7 +66,7 @@ def __call__( # q_pe = self.rope(q_pe, offset=offset) # k_pe = self.rope(k_pe, offset=offset) - key_cache_global, value_cache_global = cache + key_cache_global, value_cache_global = cache.get_cache() q_pe_list = [] k_pe_list = [] @@ -95,7 +96,6 @@ def __call__( block_tables, context_lengths, block_size, - layer_idx=layer_idx, slot_mapping=slot_mapping, ) @@ -110,7 +110,6 @@ def __call__( block_size, self.scale, self.num_heads, - layer_idx, v_head_dim=values.shape[-1], ) output = output.transpose(0, 2, 1, 3).reshape(batch, target_len, -1) @@ -137,16 +136,17 @@ class ParallaxDeepSeekV3Block(MLXDeepseekV3Block): This version handles the KV cache explicitly and returns new K and V states. 
""" - def __init__(self, args: ModelArgs, layer_idx: int): + def __init__(self, args: ModelArgs, layer_idx: int, local_layer_idx: int): super().__init__(args, layer_idx=layer_idx) self.self_attn = ParallaxDeepSeekV3Attention(args) self.layer_idx = layer_idx + self.local_layer_idx = local_layer_idx def __call__( self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[List[Any]] = None, lengths: Optional[mx.array] = None, block_tables: Optional[mx.array] = None, context_lengths: Optional[mx.array] = None, @@ -156,11 +156,11 @@ def __call__( r = self.self_attn( self.input_layernorm(x), mask, - cache, + cache[self.local_layer_idx], block_tables=block_tables, context_lengths=context_lengths, slot_mapping=slot_mapping, - layer_idx=self.layer_idx, + **kwargs, ) h = x + r r = self.mlp(self.post_attention_layernorm(h)) diff --git a/src/parallax/models/deepseek_v32.py b/src/parallax/models/deepseek_v32.py index 373d24a6..162fd047 100644 --- a/src/parallax/models/deepseek_v32.py +++ b/src/parallax/models/deepseek_v32.py @@ -1,5 +1,5 @@ # Copyright © 2025 Apple Inc. -from typing import Any, Optional, Tuple +from typing import Any, List, Optional import mlx.core as mx from mlx_lm.models.base import scaled_dot_product_attention @@ -10,6 +10,7 @@ from parallax.metal.indexer.kernel import q_dot_k, store_indexer_cache from parallax.metal.paged_attention.kernel import paged_attention, reshape_and_cache +from parallax.server.cache.base import BaseCache class ParallaxDeepSeekV32Indexer(MLXDeepseekV32Indexer): @@ -22,8 +23,8 @@ def __call__( block_tables: Optional[mx.array] = None, context_lengths: Optional[mx.array] = None, block_size: int = 1024, - layer_idx: int = 0, slot_mapping: Optional[mx.array] = None, + **kwargs, ): # Computes top_k indices for attention batch, target_len, _ = x.shape @@ -56,7 +57,6 @@ def __call__( block_tables, context_lengths, block_size=block_size, - layer_idx=layer_idx, slot_mapping=slot_mapping, ) @@ -73,7 +73,6 @@ def __call__( block_size=block_size, block_table=block_tables[i], context_length=context_lengths[i], - layer_idx=layer_idx, ) # shape: (n_heads, context_len) score = score[:, None, :] # shape: (n_heads, 1, context_len) score = mx.maximum(score, 0) @@ -115,12 +114,11 @@ def __call__( self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[BaseCache] = None, block_tables: Optional[mx.array] = None, context_lengths: Optional[mx.array] = None, slot_mapping: Optional[mx.array] = None, - layer_idx: int = 0, - indexer_cache: Optional[mx.array] = None, + **kwargs, ) -> mx.array: batch, target_len, _ = x.shape @@ -141,7 +139,8 @@ def __call__( k_nope, values = mx.split(kv, [self.qk_nope_head_dim], axis=-1) k_nope = k_nope.transpose(0, 2, 1, 3) - key_cache_global, value_cache_global = cache + key_cache_global, value_cache_global = cache.get_cache() + indexer_cache = cache.get_indexer_cache() q_pe_list = [] k_pe_list = [] for i in range(batch): @@ -168,7 +167,6 @@ def __call__( block_tables, context_lengths, block_size, - layer_idx=layer_idx, slot_mapping=slot_mapping, ) @@ -180,7 +178,6 @@ def __call__( block_tables=block_tables, context_lengths=context_lengths, block_size=block_size, - layer_idx=layer_idx, slot_mapping=slot_mapping, ) @@ -194,7 +191,6 @@ def __call__( block_size, self.scale, self.num_heads, - layer_idx, v_head_dim=values.shape[-1], top_k_indices=topk_indices, ) @@ -223,31 +219,31 @@ def __call__( class 
ParallaxDeepSeekV32Block(MLXDeepseekV32Block): - def __init__(self, args: ModelArgs, layer_idx: int): + def __init__(self, args: ModelArgs, layer_idx: int, local_layer_idx: int): super().__init__(args, layer_idx=layer_idx) self.self_attn = ParallaxDeepSeekV32Attention(args) self.layer_idx = layer_idx + self.local_layer_idx = local_layer_idx def __call__( self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Any] = None, + cache: Optional[List[Any]] = None, block_tables: Optional[mx.array] = None, context_lengths: Optional[mx.array] = None, slot_mapping: Optional[mx.array] = None, **kwargs, ): - indexer_cache = kwargs.get("indexer_cache") + r = self.self_attn( self.input_layernorm(x), mask, - cache, + cache[self.local_layer_idx], block_tables=block_tables, context_lengths=context_lengths, slot_mapping=slot_mapping, - layer_idx=self.layer_idx, - indexer_cache=indexer_cache, + **kwargs, ) h = x + r r = self.mlp(self.post_attention_layernorm(h)) diff --git a/src/parallax/models/glm4_moe.py b/src/parallax/models/glm4_moe.py index 0c44a51d..4d94c392 100644 --- a/src/parallax/models/glm4_moe.py +++ b/src/parallax/models/glm4_moe.py @@ -1,4 +1,4 @@ -from typing import Optional, Tuple +from typing import Any, List, Optional import mlx.core as mx from mlx_lm.models.base import scaled_dot_product_attention @@ -7,6 +7,7 @@ from mlx_lm.models.glm4_moe import ModelArgs from parallax.metal.paged_attention.kernel import paged_attention, reshape_and_cache +from parallax.server.cache.base import BaseCache class ParallaxGLM4MoeAttention(MLXGLM4MoeAttention): @@ -14,11 +15,11 @@ def __call__( self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[BaseCache] = None, block_tables: Optional[mx.array] = None, context_lengths: Optional[mx.array] = None, slot_mapping: Optional[mx.array] = None, - layer_idx: int = 0, + **kwargs, ) -> mx.array: batch, target_len, _ = x.shape @@ -35,7 +36,7 @@ def __call__( keys_new = keys_new.transpose(0, 2, 1, 3) values_new = values.reshape(batch, target_len, self.n_kv_heads, -1) - key_cache_global, value_cache_global = cache + key_cache_global, value_cache_global = cache.get_cache() queries_rotated_list = [] keys_rotated_list = [] @@ -62,7 +63,6 @@ def __call__( block_tables, context_lengths, block_size, - layer_idx, slot_mapping=slot_mapping, ) @@ -78,7 +78,6 @@ def __call__( block_size, self.scale, self.n_kv_heads, - layer_idx, ) output = output.transpose(0, 2, 1, 3).reshape(batch, target_len, -1) else: @@ -98,15 +97,16 @@ def __call__( class ParallaxGLM4MoeBlock(MLXGLM4MoeBlock): - def __init__(self, args: ModelArgs, layer_idx: int): + def __init__(self, args: ModelArgs, layer_idx: int, local_layer_idx: int): super().__init__(args, layer_idx) self.self_attn = ParallaxGLM4MoeAttention(args) + self.local_layer_idx = local_layer_idx def __call__( self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[List[Any]] = None, block_tables: Optional[mx.array] = None, context_lengths: Optional[mx.array] = None, slot_mapping: Optional[mx.array] = None, @@ -115,11 +115,11 @@ def __call__( r = self.self_attn( self.input_layernorm(x), mask, - cache, + cache[self.local_layer_idx], block_tables=block_tables, context_lengths=context_lengths, slot_mapping=slot_mapping, - layer_idx=self.layer_idx, + **kwargs, ) h = x + r r = self.mlp(self.post_attention_layernorm(h)) diff --git a/src/parallax/models/gpt_oss.py b/src/parallax/models/gpt_oss.py index 
b6b14a1a..00d8de59 100644 --- a/src/parallax/models/gpt_oss.py +++ b/src/parallax/models/gpt_oss.py @@ -2,7 +2,7 @@ hidden_dimefines the Qwen3 model. """ -from typing import Optional, Tuple +from typing import Any, List, Optional import mlx.core as mx from mlx_lm.models.base import create_causal_mask, scaled_dot_product_attention @@ -11,6 +11,7 @@ from mlx_lm.models.gpt_oss import TransformerBlock as MLXGPTOSSBlock from parallax.metal.paged_attention.kernel import paged_attention, reshape_and_cache +from parallax.server.cache.base import BaseCache class ParallaxGPTOSSAttention(MLXGPTOSSAttention): @@ -24,12 +25,12 @@ def __call__( self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[BaseCache] = None, block_tables: Optional[mx.array] = None, context_lengths: Optional[mx.array] = None, slot_mapping: Optional[mx.array] = None, - layer_idx: int = 0, window_size: Optional[int] = None, + **kwargs, ) -> mx.array: """ Attention forward pass with PagedAttention integration. @@ -48,7 +49,7 @@ def __call__( ) values_new = values_new.reshape(batch, target_len, self.num_key_value_heads, -1) - key_cache_global, value_cache_global = cache + key_cache_global, value_cache_global = cache.get_cache() queries_rotated_list = [] keys_rotated_list = [] @@ -75,7 +76,6 @@ def __call__( block_tables, context_lengths, block_size, - layer_idx, slot_mapping=slot_mapping, ) @@ -90,7 +90,6 @@ def __call__( block_size, self.sm_scale, self.num_key_value_heads, - layer_idx, window_size=window_size, sinks=self.sinks, ) @@ -121,11 +120,12 @@ class ParallaxGPTOSSBlock(MLXGPTOSSBlock): This version handles the KV cache explicitly and returns new K and V states. """ - def __init__(self, args: ModelArgs, layer_idx: int): + def __init__(self, args: ModelArgs, layer_idx: int, local_layer_idx: int): super().__init__(args) self.self_attn = ParallaxGPTOSSAttention(args) self.sliding_window = args.sliding_window self.layer_idx = layer_idx + self.local_layer_idx = local_layer_idx if args.layer_types: self.layer_type = args.layer_types[layer_idx] else: @@ -138,7 +138,7 @@ def __call__( self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[List[Any]] = None, block_tables: Optional[mx.array] = None, context_lengths: Optional[mx.array] = None, slot_mapping: Optional[mx.array] = None, @@ -153,12 +153,12 @@ def __call__( r = self.self_attn( self.input_layernorm(x), mask=mask, - cache=cache, + cache=cache[self.local_layer_idx], block_tables=block_tables, context_lengths=context_lengths, slot_mapping=slot_mapping, - layer_idx=self.layer_idx, window_size=window_size, + **kwargs, ) h = x + r r = self.mlp(self.post_attention_layernorm(h)) diff --git a/src/parallax/models/llama.py b/src/parallax/models/llama.py index c35e9381..c4320476 100644 --- a/src/parallax/models/llama.py +++ b/src/parallax/models/llama.py @@ -6,7 +6,7 @@ `ShardedModel` can drive it uniformly. 
""" -from typing import Optional, Tuple +from typing import Optional import mlx.core as mx from mlx_lm.models.base import scaled_dot_product_attention @@ -15,6 +15,7 @@ from mlx_lm.models.llama import TransformerBlock as MLXLlamaBlock from parallax.metal.paged_attention.kernel import paged_attention, reshape_and_cache +from parallax.server.cache.base import BaseCache class ParallaxLlamaAttention(MLXLlamaAttention): @@ -28,7 +29,7 @@ def __call__( self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[BaseCache] = None, block_tables: Optional[mx.array] = None, context_lengths: Optional[mx.array] = None, slot_mapping: Optional[mx.array] = None, @@ -40,10 +41,10 @@ def __call__( Args: x: (batch, target_len, hidden_dim) - Input hidden states for the current query segment. mask: (batch, n_q_heads, target_len, source_len) - cache: contains (key_cache, value_cache) global. + cache: BaseCache object containing the layer cache. block_tables: (batch, max_blocks) - PagedKV block tables. context_lengths: (batch,) - PagedKV sequence lengths. - layer_idx: Layer index for PagedKV access. + slot_mapping: (batch * target_len,) - Flattened slot mapping. Returns: output: (batch, target_len, hidden_dim) - Output hidden states. @@ -58,7 +59,7 @@ def __call__( keys_new = keys_new.reshape(batch, target_len, self.n_kv_heads, -1).transpose(0, 2, 1, 3) values_new = values_new.reshape(batch, target_len, self.n_kv_heads, -1) - key_cache_global, value_cache_global = cache + key_cache_global, value_cache_global = cache.get_cache() queries_rotated_list = [] keys_rotated_list = [] @@ -85,7 +86,6 @@ def __call__( block_tables, context_lengths, block_size, - layer_idx, slot_mapping=slot_mapping, ) @@ -101,7 +101,6 @@ def __call__( block_size, self.scale, self.n_kv_heads, - layer_idx, ) output = output.transpose(0, 2, 1, 3).reshape(batch, target_len, -1) else: @@ -122,16 +121,17 @@ def __call__( class ParallaxLlamaBlock(MLXLlamaBlock): """Transformer block wrapper returning explicit KV cache updates.""" - def __init__(self, args: ModelArgs, layer_idx: int): + def __init__(self, args: ModelArgs, layer_idx: int, local_layer_idx: int): super().__init__(args) self.self_attn = ParallaxLlamaAttention(args) self.layer_idx = layer_idx + self.local_layer_idx = local_layer_idx def __call__( self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[BaseCache] = None, block_tables: Optional[mx.array] = None, context_lengths: Optional[mx.array] = None, slot_mapping: Optional[mx.array] = None, @@ -140,11 +140,10 @@ def __call__( r = self.self_attn( self.input_layernorm(x), mask, - cache, + cache[self.local_layer_idx], block_tables=block_tables, context_lengths=context_lengths, slot_mapping=slot_mapping, - layer_idx=self.layer_idx, ) h = x + r r = self.mlp(self.post_attention_layernorm(h)) diff --git a/src/parallax/models/minimax.py b/src/parallax/models/minimax.py index 71f29e24..287c144c 100644 --- a/src/parallax/models/minimax.py +++ b/src/parallax/models/minimax.py @@ -1,6 +1,6 @@ # Copyright © 2025 Apple Inc. 
-from typing import Optional, Tuple +from typing import Any, List, Optional import mlx.core as mx from mlx_lm.models.base import scaled_dot_product_attention @@ -9,6 +9,7 @@ from mlx_lm.models.minimax import ModelArgs from parallax.metal.paged_attention.kernel import paged_attention, reshape_and_cache +from parallax.server.cache.base import BaseCache class ParallaxMiniMaxAttention(MLXMiniMaxAttention): @@ -17,11 +18,11 @@ def __call__( self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[BaseCache] = None, block_tables: Optional[mx.array] = None, context_lengths: Optional[mx.array] = None, slot_mapping: Optional[mx.array] = None, - layer_idx: int = 0, + **kwargs, ) -> mx.array: batch, target_len, _ = x.shape @@ -42,7 +43,7 @@ def __call__( ) values_new = values.reshape(batch, target_len, self.num_key_value_heads, -1) - key_cache_global, value_cache_global = cache + key_cache_global, value_cache_global = cache.get_cache() queries_rotated_list = [] keys_rotated_list = [] @@ -69,7 +70,6 @@ def __call__( block_tables, context_lengths, block_size, - layer_idx, slot_mapping=slot_mapping, ) @@ -85,7 +85,6 @@ def __call__( block_size, self.scale, self.num_key_value_heads, - layer_idx, ) output = output.transpose(0, 2, 1, 3).reshape(batch, target_len, -1) else: @@ -108,16 +107,17 @@ class ParallaxMiniMaxBlock(MLXMiniMaxBlock): This version handles the KV cache explicitly and returns new K and V states. """ - def __init__(self, args: ModelArgs, layer_idx: int): + def __init__(self, args: ModelArgs, layer_idx: int, local_layer_idx: int): super().__init__(args) self.self_attn = ParallaxMiniMaxAttention(args) self.layer_idx = layer_idx + self.local_layer_idx = local_layer_idx def __call__( self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[List[Any]] = None, block_tables: Optional[mx.array] = None, context_lengths: Optional[mx.array] = None, slot_mapping: Optional[mx.array] = None, @@ -126,11 +126,11 @@ def __call__( r = self.self_attn( self.input_layernorm(x), mask, - cache, + cache[self.local_layer_idx], block_tables=block_tables, context_lengths=context_lengths, slot_mapping=slot_mapping, - layer_idx=self.layer_idx, + **kwargs, ) h = x + r r = self.block_sparse_moe(self.post_attention_layernorm(h)) diff --git a/src/parallax/models/qwen2.py b/src/parallax/models/qwen2.py index a2bab4b4..12d2f1ca 100644 --- a/src/parallax/models/qwen2.py +++ b/src/parallax/models/qwen2.py @@ -2,7 +2,7 @@ hidden_dimefines the Qwen3 model. """ -from typing import Optional, Tuple +from typing import Optional import mlx.core as mx from mlx_lm.models.base import scaled_dot_product_attention @@ -11,6 +11,7 @@ from mlx_lm.models.qwen2 import TransformerBlock as MLXQwen2Block from parallax.metal.paged_attention.kernel import paged_attention, reshape_and_cache +from parallax.server.cache.base import BaseCache class ParallaxQwen2Attention(MLXQwen2Attention): @@ -24,11 +25,10 @@ def __call__( self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[BaseCache] = None, block_tables: Optional[mx.array] = None, context_lengths: Optional[mx.array] = None, slot_mapping: Optional[mx.array] = None, - layer_idx: int = 0, ) -> mx.array: """ Attention forward pass with explicit KV cache handling. @@ -36,10 +36,10 @@ def __call__( Args: x: (batch, target_len, hidden_dim) - Input hidden states for the current query segment. 
mask: (batch, n_q_heads, target_len, source_len) - cache: contains (key_cache, value_cache) global. + cache: BaseCache object containing the layer cache. block_tables: (batch, max_blocks) - PagedKV block tables. context_lengths: (batch,) - PagedKV sequence lengths. - layer_idx: Layer index for PagedKV access. + slot_mapping: (batch * target_len,) - Flattened slot mapping. Returns: output: (batch, target_len, hidden_dim) - Output hidden states. @@ -54,7 +54,7 @@ def __call__( keys_new = keys_new.reshape(batch, target_len, self.n_kv_heads, -1).transpose(0, 2, 1, 3) values_new = values_new.reshape(batch, target_len, self.n_kv_heads, -1) - key_cache_global, value_cache_global = cache + key_cache_global, value_cache_global = cache.get_cache() queries_rotated_list = [] keys_rotated_list = [] @@ -81,7 +81,6 @@ def __call__( block_tables, context_lengths, block_size, - layer_idx, slot_mapping=slot_mapping, ) @@ -97,7 +96,6 @@ def __call__( block_size, self.scale, self.n_kv_heads, - layer_idx, ) output = output.transpose(0, 2, 1, 3).reshape(batch, target_len, -1) else: @@ -120,16 +118,17 @@ class ParallaxQwen2Block(MLXQwen2Block): This version handles the KV cache explicitly and returns new K and V states. """ - def __init__(self, args: ModelArgs, layer_idx: int): + def __init__(self, args: ModelArgs, layer_idx: int, local_layer_idx: int): super().__init__(args) self.self_attn = ParallaxQwen2Attention(args) self.layer_idx = layer_idx + self.local_layer_idx = local_layer_idx def __call__( self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[BaseCache] = None, block_tables: Optional[mx.array] = None, context_lengths: Optional[mx.array] = None, slot_mapping: Optional[mx.array] = None, @@ -138,11 +137,10 @@ def __call__( r = self.self_attn( self.input_layernorm(x), mask, - cache, + cache[self.local_layer_idx], block_tables=block_tables, context_lengths=context_lengths, slot_mapping=slot_mapping, - layer_idx=self.layer_idx, ) h = x + r r = self.mlp(self.post_attention_layernorm(h)) diff --git a/src/parallax/models/qwen3.py b/src/parallax/models/qwen3.py index 8d8e1078..3e93ea4d 100644 --- a/src/parallax/models/qwen3.py +++ b/src/parallax/models/qwen3.py @@ -2,7 +2,7 @@ hidden_dimefines the Qwen3 model. """ -from typing import Optional, Tuple +from typing import Any, List, Optional import mlx.core as mx from mlx_lm.models.base import scaled_dot_product_attention @@ -11,6 +11,7 @@ from mlx_lm.models.qwen3 import TransformerBlock as MLXQwen3Block from parallax.metal.paged_attention.kernel import paged_attention, reshape_and_cache +from parallax.server.cache.base import BaseCache class ParallaxQwen3Attention(MLXQwen3Attention): @@ -24,11 +25,11 @@ def __call__( self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[BaseCache] = None, block_tables: Optional[mx.array] = None, context_lengths: Optional[mx.array] = None, slot_mapping: Optional[mx.array] = None, - layer_idx: int = 0, + **kwargs, ) -> mx.array: """ Attention forward pass with explicit KV cache handling. @@ -36,10 +37,10 @@ def __call__( Args: x: (batch, target_len, hidden_dim) - Input hidden states for the current query segment. mask: (batch, n_q_heads, target_len, source_len) - cache: contains (key_cache, value_cache) global. + cache: BaseCache object containing the layer cache. block_tables: (batch, max_blocks) - PagedKV block tables. context_lengths: (batch,) - PagedKV sequence lengths. 
- layer_idx: Layer index for PagedKV access. + slot_mapping: (batch * target_len,) - Flattened slot mapping. Returns: output: (batch, target_len, hidden_dim) - Output hidden states. @@ -58,7 +59,7 @@ def __call__( ) values_new = values_new.reshape(batch, target_len, self.n_kv_heads, -1) - key_cache_global, value_cache_global = cache + key_cache_global, value_cache_global = cache.get_cache() queries_rotated_list = [] keys_rotated_list = [] @@ -87,7 +88,6 @@ def __call__( block_tables, context_lengths, block_size, - layer_idx, slot_mapping=slot_mapping, ) @@ -103,7 +103,6 @@ def __call__( block_size, self.scale, self.n_kv_heads, - layer_idx, ) output = output.transpose(0, 2, 1, 3).reshape(batch, target_len, -1) else: @@ -126,16 +125,17 @@ class ParallaxQwen3Block(MLXQwen3Block): This version handles the KV cache explicitly and returns new K and V states. """ - def __init__(self, args: ModelArgs, layer_idx: int): + def __init__(self, args: ModelArgs, layer_idx: int, local_layer_idx: int): super().__init__(args) self.self_attn = ParallaxQwen3Attention(args) self.layer_idx = layer_idx + self.local_layer_idx = local_layer_idx def __call__( self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[List[Any]] = None, block_tables: Optional[mx.array] = None, context_lengths: Optional[mx.array] = None, slot_mapping: Optional[mx.array] = None, @@ -144,11 +144,11 @@ def __call__( r = self.self_attn( self.input_layernorm(x), mask, - cache, + cache[self.local_layer_idx], block_tables=block_tables, context_lengths=context_lengths, slot_mapping=slot_mapping, - layer_idx=self.layer_idx, + **kwargs, ) h = x + r r = self.mlp(self.post_attention_layernorm(h)) diff --git a/src/parallax/models/qwen3_moe.py b/src/parallax/models/qwen3_moe.py index fe3bbf1b..7c51fb93 100644 --- a/src/parallax/models/qwen3_moe.py +++ b/src/parallax/models/qwen3_moe.py @@ -2,7 +2,7 @@ hidden_dimefines the Qwen3 model. """ -from typing import Optional, Tuple +from typing import Any, List, Optional import mlx.core as mx from mlx_lm.models.base import scaled_dot_product_attention @@ -11,6 +11,7 @@ from mlx_lm.models.qwen3_moe import Qwen3MoeDecoderLayer as MLXQwen3MoeBlock from parallax.metal.paged_attention.kernel import paged_attention, reshape_and_cache +from parallax.server.cache.base import BaseCache class ParallaxQwen3MoeAttention(MLXQwen3MoeAttention): @@ -24,11 +25,11 @@ def __call__( self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[BaseCache] = None, block_tables: Optional[mx.array] = None, context_lengths: Optional[mx.array] = None, slot_mapping: Optional[mx.array] = None, - layer_idx: int = 0, + **kwargs, ) -> mx.array: """ Attention forward pass with explicit KV cache handling. @@ -36,10 +37,10 @@ def __call__( Args: x: (batch, target_len, hidden_dim) - Input hidden states for the current query segment. mask: (batch, n_q_heads, target_len, source_len) - cache: contains (key_cache, value_cache) global. + cache: BaseCache object containing the layer cache. block_tables: (batch, max_blocks) - PagedKV block tables. context_lengths: (batch,) - PagedKV sequence lengths. - layer_idx: Layer index for PagedKV access. + slot_mapping: (batch * target_len,) - Flattened slot mapping. Returns: output: (batch, target_len, hidden_dim) - Output hidden states. 
@@ -58,7 +59,7 @@ def __call__( ) values_new = values_new.reshape(batch, target_len, self.n_kv_heads, -1) - key_cache_global, value_cache_global = cache + key_cache_global, value_cache_global = cache.get_cache() queries_rotated_list = [] keys_rotated_list = [] @@ -85,7 +86,6 @@ def __call__( block_tables, context_lengths, block_size, - layer_idx, slot_mapping=slot_mapping, ) @@ -101,7 +101,6 @@ def __call__( block_size, self.scale, self.n_kv_heads, - layer_idx, ) output = output.transpose(0, 2, 1, 3).reshape(batch, target_len, -1) else: @@ -124,16 +123,17 @@ class ParallaxQwen3MoeBlock(MLXQwen3MoeBlock): This version handles the KV cache explicitly and returns new K and V states. """ - def __init__(self, args: ModelArgs, layer_idx: int): + def __init__(self, args: ModelArgs, layer_idx: int, local_layer_idx: int): super().__init__(args, layer_idx) self.self_attn = ParallaxQwen3MoeAttention(args, layer_idx) self.layer_idx = layer_idx + self.local_layer_idx = local_layer_idx def __call__( self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[List[Any]] = None, block_tables: Optional[mx.array] = None, context_lengths: Optional[mx.array] = None, slot_mapping: Optional[mx.array] = None, @@ -142,11 +142,11 @@ def __call__( r = self.self_attn( self.input_layernorm(x), mask, - cache, + cache[self.local_layer_idx], block_tables=block_tables, context_lengths=context_lengths, slot_mapping=slot_mapping, - layer_idx=self.layer_idx, + **kwargs, ) h = x + r r = self.mlp(self.post_attention_layernorm(h)) diff --git a/src/parallax/models/qwen3_next.py b/src/parallax/models/qwen3_next.py index ba408865..cc11ee40 100644 --- a/src/parallax/models/qwen3_next.py +++ b/src/parallax/models/qwen3_next.py @@ -2,7 +2,7 @@ hidden_dimefines the Qwen3 model. """ -from typing import Optional, Tuple +from typing import Any, List, Optional import mlx.core as mx import mlx.nn as nn @@ -13,52 +13,23 @@ from mlx_lm.models.qwen3_next import Qwen3NextDecoderLayer as MLXQwen3NextBlock from mlx_lm.models.qwen3_next import Qwen3NextGatedDeltaNet as MLXQwen3NextGatedDeltaNet +from parallax.metal.paged_attention.kernel import paged_attention, reshape_and_cache +from parallax.server.cache.base import BaseCache + class ParallaxQwen3NextAttention(MLXQwen3NextAttention): - """A custom attention module for Parallax, extending the Qwen3 Attention class. - - We apply explicit KV cache handling and passing in `offset` directly from Request. - This version returns the new K and V states for external caching. - """ - - def __init__(self, args: ModelArgs): - super().__init__(args) - self.hidden_size = args.hidden_size - self.num_v_heads = args.linear_num_value_heads - self.num_k_heads = args.linear_num_key_heads - self.head_k_dim = args.linear_key_head_dim - self.head_v_dim = args.linear_value_head_dim - self.conv_kernel_size = args.linear_conv_kernel_dim - self.key_dim = self.head_k_dim * self.num_k_heads - self.value_dim = self.head_v_dim * self.num_v_heads - self.conv_dim = self.key_dim * 2 + self.value_dim def __call__( self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, - offset: int = 0, - state_cache: Optional[Tuple[mx.array, mx.array]] = None, - ) -> Tuple[mx.array, Tuple[mx.array, mx.array]]: - """ - Attention forward pass with explicit KV cache handling. - - Args: - x: (batch, target_len, hidden_dim) - Input hidden states for the current query segment. 
- mask: (batch, n_q_heads, target_len, source_len) - cache: Optional tuple (past_k, past_v). - shape: (batch, n_kv_heads, S_past_padded, head_dim) - offset: source_len_padded (scalar, used for RoPE calculation). - - Returns: - output_h: (batch, target_len, hidden_dim) - Output hidden states. - new_k: (batch, n_kv_heads, target_len, head_dim) - New keys for this segment. - new_v: (batch, n_kv_heads, target_len, head_dim) - New values for this segment. - """ + cache: Optional[BaseCache] = None, + block_tables: Optional[mx.array] = None, + context_lengths: Optional[mx.array] = None, + slot_mapping: Optional[mx.array] = None, + **kwargs, + ) -> mx.array: batch, target_len, _ = x.shape - # print("inputs shape:", x.shape) - # print(f"x.value --- IGNORE --- {x}") queries_new = self.q_proj(x) keys_new = self.k_proj(x) @@ -72,86 +43,89 @@ def __call__( keys_new = self.k_norm( keys_new.reshape(batch, target_len, self.num_key_value_heads, -1) ).transpose(0, 2, 1, 3) - values_new = values_new.reshape(batch, target_len, self.num_key_value_heads, -1).transpose( - 0, 2, 1, 3 + values_new = values_new.reshape(batch, target_len, self.num_key_value_heads, -1) + + key_cache_global, value_cache_global = cache.get_cache() + + queries_rotated_list = [] + keys_rotated_list = [] + for i in range(batch): + current_pos = int(context_lengths[i]) - 1 if target_len == 1 else 0 + q_slice = queries_new[i : i + 1] + k_slice = keys_new[i : i + 1] + q_rot = self.rope(q_slice, offset=current_pos) + k_rot = self.rope(k_slice, offset=current_pos) + queries_rotated_list.append(q_rot) + keys_rotated_list.append(k_rot) + queries_rotated = mx.concatenate(queries_rotated_list, axis=0) + keys_rotated = mx.concatenate(keys_rotated_list, axis=0) + + block_size = key_cache_global.shape[3] + reshape_and_cache( + keys_rotated.transpose(0, 2, 1, 3), + values_new, + key_cache_global, + value_cache_global, + block_tables, + context_lengths, + block_size, + slot_mapping=slot_mapping, ) - - queries_rotated = self.rope(queries_new, offset=offset) - keys_rotated = self.rope(keys_new, offset=offset) - - if cache is not None: - past_k, past_v = cache - if past_k is not None and past_v is not None: - if past_k.shape[2] != offset: - raise ValueError( - f"ParallaxAttention: Expected past_k sequence length {past_k.shape[2]} " - f"to match RoPE offset {offset} (S_past_padded)." 
- ) - final_keys_for_attn = mx.concatenate([past_k, keys_rotated], axis=2) - final_values_for_attn = mx.concatenate([past_v, values_new], axis=2) - else: - raise ValueError("cache was provided but one of k/v was None.") + if target_len == 1: + output = paged_attention( + queries_rotated, + key_cache_global, + value_cache_global, + block_tables, + context_lengths, + block_size, + self.scale, + self.num_key_value_heads, + ) + output = output.transpose(0, 2, 1, 3).reshape(batch, target_len, -1) else: - final_keys_for_attn = keys_rotated - final_values_for_attn = values_new - - output = scaled_dot_product_attention( - queries_rotated, - final_keys_for_attn, - final_values_for_attn, - scale=self.scale, - mask=mask, - cache=None, - ) + output = scaled_dot_product_attention( + queries_rotated, + keys_rotated, + values_new.transpose(0, 2, 1, 3), + scale=self.scale, + mask=mask, + cache=None, + ) + output = output.transpose(0, 2, 1, 3).reshape(batch, target_len, -1) - output = output.transpose(0, 2, 1, 3).reshape(batch, target_len, -1) - - return self.o_proj(output * mx.sigmoid(gate)), ( - keys_rotated, - values_new, - ( - state_cache[0] - if (state_cache is not None) - else mx.zeros((batch, self.conv_kernel_size - 1, self.conv_dim), dtype=x.dtype) - ), - ( - state_cache[1] - if (state_cache is not None) - else mx.zeros( - (batch, self.num_v_heads, self.head_k_dim, self.head_v_dim), dtype=x.dtype - ) - ), - ) + return self.o_proj(output * mx.sigmoid(gate)) class ParallaxQwen3NextGatedDeltaNet(MLXQwen3NextGatedDeltaNet): - def __init__(self, args: ModelArgs): - super().__init__(args) - self.num_key_value_heads = args.num_key_value_heads - self.head_dim = args.head_dim - def __call__( self, - inputs, - cache: Optional[Tuple[mx.array, mx.array]] = None, - state_cache: Optional[Tuple[mx.array, mx.array]] = None, + x: mx.array, + cache: Optional[BaseCache] = None, + state_slot_mapping: Optional[mx.array] = None, + **kwargs, ): - B, S, _ = inputs.shape - # print(f"inputs.value --- IGNORE --- {inputs}") + batch, target_len, _ = x.shape q, k, v, z, b, a = self.fix_query_key_value_ordering( - self.in_proj_qkvz(inputs), self.in_proj_ba(inputs) + self.in_proj_qkvz(x), self.in_proj_ba(x) ) - if state_cache is not None and state_cache[0] is not None: - conv_state = state_cache[0] + if target_len == 1: + conv_state, state1 = cache.read_states(state_slot_mapping) else: conv_state = mx.zeros( - (B, self.conv_kernel_size - 1, self.conv_dim), - dtype=inputs.dtype, + (batch, self.conv_kernel_size - 1, self.conv_dim), + dtype=x.dtype, ) + state1 = None mixed_qkv = mx.concatenate( - [q.reshape(B, S, -1), k.reshape(B, S, -1), v.reshape(B, S, -1)], axis=-1 + [ + q.reshape(batch, target_len, -1), + k.reshape(batch, target_len, -1), + v.reshape(batch, target_len, -1), + ], + axis=-1, ) conv_input = mx.concatenate([conv_state, mixed_qkv], axis=1) @@ -159,49 +133,31 @@ def __call__( conv_out = nn.silu(self.conv1d(conv_input)) q, k, v = [ - t.reshape(B, S, h, d) + t.reshape(batch, target_len, h, d) for t, h, d in zip( mx.split(conv_out, [self.key_dim, 2 * self.key_dim], -1), [self.num_k_heads, self.num_k_heads, self.num_v_heads], [self.head_k_dim, self.head_k_dim, self.head_v_dim], ) ] - if state_cache is not None: - state1 = state_cache[1] - else: - state1 = None inv_scale = k.shape[-1] ** -0.5 q = (inv_scale**2) * mx.fast.rms_norm(q, None, 1e-6) k = inv_scale * mx.fast.rms_norm(k, None, 1e-6) - out, state1 = gated_delta_update(q, k, v, a, b, self.A_log, self.dt_bias, state1) - out = self.norm(out, z) - return 
self.out_proj(out.reshape(B, S, -1)), ( - ( - cache[0][..., :S, :] - if cache is not None - else mx.zeros((B, self.num_key_value_heads, S, self.head_dim), dtype=inputs.dtype) - ), - ( - cache[1][..., :S, :] - if cache is not None - else mx.zeros((B, self.num_key_value_heads, S, self.head_dim), dtype=inputs.dtype) - ), - state0, - state1, - ) + + cache.write_states(state_slot_mapping, state0, state1) + + return self.out_proj(out.reshape(batch, target_len, -1)) class ParallaxQwen3NextBlock(MLXQwen3NextBlock): - """A custom transformer block for Parallax, extending the Qwen3 Block class. - This version handles the KV cache explicitly and returns new K and V states. - """ - def __init__(self, args: ModelArgs, layer_idx: int): + def __init__(self, args: ModelArgs, layer_idx: int, local_layer_idx: int): super().__init__(args, layer_idx) self.layer_idx = layer_idx + self.local_layer_idx = local_layer_idx if self.is_linear: self.linear_attn = ParallaxQwen3NextGatedDeltaNet(args) else: @@ -211,24 +167,31 @@ def __call__( self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, - offset: int = 0, - lengths: Optional[mx.array] = None, - state_cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[List[Any]] = None, + block_tables: Optional[mx.array] = None, + context_lengths: Optional[mx.array] = None, + slot_mapping: Optional[mx.array] = None, **kwargs, ): if self.is_linear: - r, (k_cache, v_cache, state0, state1) = self.linear_attn( - self.input_layernorm(x), cache, state_cache + state_slot_mapping = kwargs.pop("state_slot_mapping", None) + r = self.linear_attn( + self.input_layernorm(x), cache[self.local_layer_idx], state_slot_mapping, **kwargs ) else: - r, (k_cache, v_cache, state0, state1) = self.self_attn( - self.input_layernorm(x), mask, cache, offset, state_cache + r = self.self_attn( + self.input_layernorm(x), + mask, + cache[self.local_layer_idx], + block_tables=block_tables, + context_lengths=context_lengths, + slot_mapping=slot_mapping, + **kwargs, ) h = x + r r = self.mlp(self.post_attention_layernorm(h)) out = h + r - return out, (k_cache, v_cache, state0, state1) + return out @classmethod def get_architecture(cls): @@ -236,4 +199,4 @@ def get_architecture(cls): return "Qwen3NextForCausalLM" -# EntryClass = ParallaxQwen3NextBlock +EntryClass = ParallaxQwen3NextBlock diff --git a/src/parallax/server/cache/allocator.py b/src/parallax/server/cache/allocator.py new file mode 100644 index 00000000..b0698750 --- /dev/null +++ b/src/parallax/server/cache/allocator.py @@ -0,0 +1,71 @@ +from typing import List, Set + +from parallax_utils.logging_config import get_logger + +logger = get_logger(__name__) + + +class BlockAllocator: + """Manages allocation of physical block indices.""" + + def __init__(self, num_blocks: int, block_size: int): + self.num_blocks = num_blocks + self.block_size = block_size + # Initialize free blocks stack + self.free_blocks: List[int] = list(range(num_blocks)) + self.used_blocks: Set[int] = set() + + def allocate(self, num_blocks_needed: int) -> List[int]: + """Allocates `num_blocks_needed` physical blocks.""" + if len(self.free_blocks) < num_blocks_needed: + return [] + + # Pop blocks from the stack + split_idx = len(self.free_blocks) - num_blocks_needed + allocated = self.free_blocks[split_idx:] + self.free_blocks = self.free_blocks[:split_idx] + + for b in allocated: + self.used_blocks.add(b) + + return allocated + + def free(self, blocks: List[int]): + """Frees the given physical blocks.""" + for b in 
blocks: + if b in self.used_blocks: + self.used_blocks.remove(b) + self.free_blocks.append(b) + else: + logger.warning(f"Double free detected for block {b}") + + def get_num_free_blocks(self) -> int: + return len(self.free_blocks) + + +class SlotAllocator: + """Manages allocation of request slots (indices).""" + + def __init__(self, num_slots: int): + self.num_slots = num_slots + self.free_slots: List[int] = list(range(num_slots)) + self.used_slots: Set[int] = set() + + def allocate(self) -> int: + """Allocates a single slot.""" + if not self.free_slots: + return -1 + slot = self.free_slots.pop() + self.used_slots.add(slot) + return slot + + def free(self, slot: int): + """Frees the given slot.""" + if slot in self.used_slots: + self.used_slots.remove(slot) + self.free_slots.append(slot) + else: + logger.warning(f"Double free detected for slot {slot}") + + def get_num_free_slots(self) -> int: + return len(self.free_slots) diff --git a/src/parallax/server/cache/base.py b/src/parallax/server/cache/base.py new file mode 100644 index 00000000..9535978f --- /dev/null +++ b/src/parallax/server/cache/base.py @@ -0,0 +1,10 @@ +from abc import ABC, abstractmethod +from typing import Any + + +class BaseCache(ABC): + """Abstract base class for layer-level cache.""" + + @abstractmethod + def get_cache(self) -> Any: + pass diff --git a/src/parallax/server/cache/dsa_cache.py b/src/parallax/server/cache/dsa_cache.py new file mode 100644 index 00000000..8e352f68 --- /dev/null +++ b/src/parallax/server/cache/dsa_cache.py @@ -0,0 +1,38 @@ +from typing import Optional + +import mlx.core as mx + +from parallax.server.cache.kv_cache import KVCache + + +class DeepSeekSparseCache(KVCache): + """ + KVCache with additional indexer cache for DeepSeek Sparse Attention. + """ + + def __init__( + self, + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_dim: int, + head_dim_v: int, + dtype: mx.Dtype, + indexer_key_head_dim: int, + indexer_num_kv_heads: int, + ): + super().__init__(num_blocks, block_size, num_kv_heads, head_dim, head_dim_v, dtype) + self.indexer_key_cache = mx.zeros( + ( + 1, + num_blocks, + indexer_num_kv_heads, + block_size, + indexer_key_head_dim, + ), + dtype=dtype, + ) + mx.eval(self.indexer_key_cache) + + def get_indexer_cache(self) -> Optional[mx.array]: + return self.indexer_key_cache diff --git a/src/parallax/server/cache/kv_cache.py b/src/parallax/server/cache/kv_cache.py new file mode 100644 index 00000000..dba9c45a --- /dev/null +++ b/src/parallax/server/cache/kv_cache.py @@ -0,0 +1,37 @@ +from typing import Tuple + +import mlx.core as mx + +from parallax.server.cache.base import BaseCache + + +class KVCache(BaseCache): + """ + Standard Paged KV Cache for a single layer. 
+ Shape: (1, num_blocks, num_kv_heads, block_size, head_dim) + """ + + def __init__( + self, + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_dim: int, + head_dim_v: int, + dtype: mx.Dtype, + ): + self.num_blocks = num_blocks + self.block_size = block_size + self.num_kv_heads = num_kv_heads + self.head_dim = head_dim + self.head_dim_v = head_dim_v + self.dtype = dtype + + self.key_cache = mx.zeros((1, num_blocks, num_kv_heads, block_size, head_dim), dtype=dtype) + self.value_cache = mx.zeros( + (1, num_blocks, num_kv_heads, block_size, head_dim_v), dtype=dtype + ) + mx.eval(self.key_cache, self.value_cache) + + def get_cache(self) -> Tuple[mx.array, mx.array]: + return self.key_cache, self.value_cache diff --git a/src/parallax/server/cache/linear_cache.py b/src/parallax/server/cache/linear_cache.py new file mode 100644 index 00000000..fbd14bc9 --- /dev/null +++ b/src/parallax/server/cache/linear_cache.py @@ -0,0 +1,89 @@ +from typing import Optional, Tuple + +import mlx.core as mx + +from parallax.server.cache.base import BaseCache + + +class LinearCache(BaseCache): + + def __init__( + self, + max_num_seqs: int = 128, + conv_dim: Optional[int] = None, + conv_kernel_size: Optional[int] = None, + linear_k_dim: Optional[int] = None, + linear_v_dim: Optional[int] = None, + linear_num_k_heads: Optional[int] = None, + linear_num_v_heads: Optional[int] = None, + dtype: mx.Dtype = mx.float16, + ): + self.max_num_seqs = max_num_seqs + self.dtype = dtype + + self.conv_state_cache = None + self.linear_state_cache = None + + if conv_dim is not None and conv_kernel_size is not None: + conv_state_len = conv_kernel_size - 1 + self.conv_state_cache = mx.zeros( + (1, max_num_seqs, conv_state_len, conv_dim), dtype=dtype + ) + mx.eval(self.conv_state_cache) + + if ( + linear_k_dim is not None + and linear_v_dim is not None + and linear_num_k_heads is not None + and linear_num_v_heads is not None + ): + self.linear_state_cache = mx.zeros( + ( + 1, + max_num_seqs, + linear_num_v_heads, + linear_v_dim, + linear_k_dim, + ), + dtype=dtype, + ) + mx.eval(self.linear_state_cache) + + def get_cache(self) -> Tuple[Optional[mx.array], Optional[mx.array]]: + return self.conv_state_cache, self.linear_state_cache + + def get_indexer_cache(self) -> Optional[mx.array]: + return None + + def read_states(self, slot_mapping: mx.array) -> Tuple[Optional[mx.array], Optional[mx.array]]: + conv_state_list = [] + linear_state_list = [] + + for slot_idx in slot_mapping: + slot_idx = int(slot_idx) + if self.conv_state_cache is not None: + conv_state_slice = self.conv_state_cache[0, slot_idx] + conv_state_list.append(conv_state_slice[None, :, :]) + + if self.linear_state_cache is not None: + linear_state_slice = self.linear_state_cache[0, slot_idx] + linear_state_list.append(linear_state_slice[None, :, :, :]) + + conv_states = mx.concatenate(conv_state_list, axis=0) if conv_state_list else None + linear_states = mx.concatenate(linear_state_list, axis=0) if linear_state_list else None + + return conv_states, linear_states + + def write_states( + self, + slot_mapping: mx.array, + conv_states: mx.array, + linear_states: Optional[mx.array], + ): + for i, slot_idx in enumerate(slot_mapping): + slot_idx = int(slot_idx) + if self.conv_state_cache is not None: + self.conv_state_cache[0, slot_idx] = conv_states[i] + + if self.linear_state_cache is not None and linear_states is not None: + self.linear_state_cache[0, slot_idx] = linear_states[i] diff --git a/src/parallax/server/cache_manager.py 
b/src/parallax/server/cache_manager.py new file mode 100644 index 00000000..090e6d11 --- /dev/null +++ b/src/parallax/server/cache_manager.py @@ -0,0 +1,340 @@ +from typing import Dict, List, Optional + +import mlx.core as mx + +from parallax.server.cache.allocator import BlockAllocator, SlotAllocator +from parallax.server.cache.base import BaseCache +from parallax.server.cache.dsa_cache import DeepSeekSparseCache +from parallax.server.cache.kv_cache import KVCache +from parallax.server.cache.linear_cache import LinearCache +from parallax_utils.logging_config import get_logger + +logger = get_logger(__name__) + + +class CacheManager: + """ + Manages the Layer Caches (KV and Linear) and their memory allocation for requests. + Supports hybrid models with mix of Attention and Linear layers. + """ + + def __init__( + self, + num_layers: int, + num_kv_heads: int, + head_dim: int, + dtype: mx.Dtype, + block_size: int = 16, + cache_memory_fraction: float = 0.8, + num_gpu_blocks: Optional[int] = None, + max_num_seqs: int = 256, # Max concurrent requests hint + head_dim_v: Optional[int] = None, + indexer_key_head_dim: Optional[int] = None, + indexer_num_kv_heads: Optional[int] = None, + # Hybrid Config: List of 'attention' or 'linear' or None (default 'attention') + layer_types: Optional[List[str]] = None, + # Linear Model / State Cache Params + conv_dim: Optional[int] = None, + conv_kernel_size: Optional[int] = None, + linear_k_dim: Optional[int] = None, + linear_v_dim: Optional[int] = None, + linear_num_k_heads: Optional[int] = None, + linear_num_v_heads: Optional[int] = None, + ): + self.num_layers = num_layers + self.num_kv_heads = num_kv_heads + self.head_dim = head_dim + self.head_dim_v = head_dim_v if head_dim_v is not None else head_dim + self.indexer_key_head_dim = indexer_key_head_dim + self.indexer_num_kv_heads = indexer_num_kv_heads + self.dtype = dtype + self.block_size = block_size + self.max_num_seqs = max_num_seqs + + # Linear cache params (store for memory calculation) + self.conv_dim = conv_dim + self.conv_kernel_size = conv_kernel_size + self.linear_k_dim = linear_k_dim + self.linear_v_dim = linear_v_dim + self.linear_num_k_heads = linear_num_k_heads + self.linear_num_v_heads = linear_num_v_heads + + # Determine layer types + if layer_types is None: + self.layer_types = ["attention"] * num_layers + else: + assert len(layer_types) == num_layers, "layer_types length must match num_layers" + self.layer_types = layer_types + + # Check if we need blocks (any attention layer) and slots (any linear layer) + self.needs_blocks = any(t == "attention" for t in self.layer_types) + self.needs_slots = any(t == "linear" for t in self.layer_types) + + if num_gpu_blocks is None and self.needs_blocks: + num_gpu_blocks = self._calculate_num_blocks(cache_memory_fraction, dtype) + elif not self.needs_blocks: + num_gpu_blocks = 0 + + self.num_gpu_blocks = num_gpu_blocks + + # 1. Initialize Allocators + self.allocator = BlockAllocator(num_gpu_blocks, block_size) if self.needs_blocks else None + self.slot_allocator = SlotAllocator(max_num_seqs) if self.needs_slots else None + + # 2. 
Initialize Layer Caches + self.caches: List[BaseCache] = [] + + for layer_type in self.layer_types: + self.caches.append(self._create_cache(layer_type)) + + if self.needs_blocks: + logger.info( + f"Allocated Paged KV Cache for {self.layer_types.count('attention')} layers: " + f"{num_gpu_blocks} blocks, {block_size} block_size" + ) + if self.needs_slots: + logger.info( + f"Allocated Linear State Cache for {self.layer_types.count('linear')} layers: " + f"{max_num_seqs} max slots" + ) + + # 3. Request State Management + # Mapping: request_id -> List of physical block indices + self.block_tables: Dict[str, List[int]] = {} + # Mapping: request_id -> current context length (number of tokens) + self.context_lengths: Dict[str, int] = {} + # Mapping: request_id -> state slot index + self.request_slots: Dict[str, int] = {} + + def _create_cache(self, layer_type: str) -> BaseCache: + if layer_type == "attention": + if self.indexer_key_head_dim is not None and self.indexer_num_kv_heads is not None: + return DeepSeekSparseCache( + num_blocks=self.num_gpu_blocks, + block_size=self.block_size, + num_kv_heads=self.num_kv_heads, + head_dim=self.head_dim, + head_dim_v=self.head_dim_v, + dtype=self.dtype, + indexer_key_head_dim=self.indexer_key_head_dim, + indexer_num_kv_heads=self.indexer_num_kv_heads, + ) + else: + return KVCache( + num_blocks=self.num_gpu_blocks, + block_size=self.block_size, + num_kv_heads=self.num_kv_heads, + head_dim=self.head_dim, + head_dim_v=self.head_dim_v, + dtype=self.dtype, + ) + + elif layer_type == "linear": + # We assume uniform linear config for all linear layers for now + return LinearCache( + max_num_seqs=self.max_num_seqs, + conv_dim=self.conv_dim, + conv_kernel_size=self.conv_kernel_size, + linear_k_dim=self.linear_k_dim, + linear_v_dim=self.linear_v_dim, + linear_num_k_heads=self.linear_num_k_heads, + linear_num_v_heads=self.linear_num_v_heads, + dtype=self.dtype, + ) + else: + raise ValueError(f"Unknown layer type: {layer_type}") + + def _calculate_linear_cache_bytes(self, dtype_size: int) -> int: + """Calculate total memory needed for linear cache across all linear layers.""" + num_linear_layers = self.layer_types.count("linear") + if num_linear_layers == 0: + return 0 + + one_layer_bytes = 0 + + # conv_state: (1, max_num_seqs, conv_kernel_size - 1, conv_dim) + if self.conv_dim is not None and self.conv_kernel_size is not None: + conv_state_len = self.conv_kernel_size - 1 + one_layer_bytes += self.max_num_seqs * conv_state_len * self.conv_dim * dtype_size + + # linear_state: (1, max_num_seqs, linear_num_v_heads, linear_v_dim, linear_k_dim) + if ( + self.linear_k_dim is not None + and self.linear_v_dim is not None + and self.linear_num_v_heads is not None + ): + one_layer_bytes += ( + self.max_num_seqs + * self.linear_num_v_heads + * self.linear_v_dim + * self.linear_k_dim + * dtype_size + ) + + total_bytes = one_layer_bytes * num_linear_layers + + if total_bytes > 0: + logger.info( + f"Linear cache will use {total_bytes / 1024**3:.2f} GB " + f"for {num_linear_layers} layers" + ) + + return total_bytes + + def _calculate_num_blocks(self, cache_memory_fraction: float, dtype: mx.Dtype) -> int: + device_info = mx.metal.device_info() + total_mem = device_info["max_recommended_working_set_size"] + current_mem = mx.metal.get_active_memory() + free_mem = total_mem - current_mem + available_for_cache = free_mem * cache_memory_fraction + + dtype_size = 2 if dtype in [mx.float16, mx.bfloat16] else 4 + + # First, calculate linear cache memory (fixed size, allocated upfront) + 
linear_cache_bytes = self._calculate_linear_cache_bytes(dtype_size) + + # Remaining memory for KV cache + available_for_kv = available_for_cache - linear_cache_bytes + if available_for_kv <= 0: + logger.warning("Linear cache uses all available memory. No room for KV cache blocks.") + return 0 + + # Calculate bytes per block for ONE attention layer + one_layer_block_bytes = ( + self.num_kv_heads * self.block_size * (self.head_dim + self.head_dim_v) * dtype_size + ) + if self.indexer_key_head_dim is not None and self.indexer_num_kv_heads is not None: + one_layer_block_bytes += ( + self.indexer_num_kv_heads * self.block_size * self.indexer_key_head_dim * dtype_size + ) + + # Total bytes per block = Sum over all attention layers + num_attention_layers = self.layer_types.count("attention") + total_block_bytes = one_layer_block_bytes * num_attention_layers + + if total_block_bytes == 0: + return 0 + + num_gpu_blocks = int(available_for_kv // total_block_bytes) + + if num_gpu_blocks <= 0: + logger.warning("Not enough memory for KV cache. Defaulting to 16 blocks.") + num_gpu_blocks = 16 + + logger.info( + f"KV cache will use {num_gpu_blocks * total_block_bytes / 1024**3:.2f} GB " + f"for {num_attention_layers} layers ({num_gpu_blocks} blocks)" + ) + + return num_gpu_blocks + + def can_allocate(self, num_tokens: int) -> bool: + if not self.needs_blocks: + return ( + self.slot_allocator.get_num_free_slots() > 0 if self.needs_slots else True + ) # Should check slots + + num_blocks = (num_tokens + self.block_size - 1) // self.block_size + blocks_ok = self.allocator.get_num_free_blocks() >= num_blocks + + slots_ok = True + if self.needs_slots: + slots_ok = self.slot_allocator.get_num_free_slots() > 0 + + return blocks_ok and slots_ok + + def allocate_request(self, request_id: str, prompt_len: int) -> bool: + if request_id in self.block_tables: + return True + + # 1. Allocate Slot (if needed) + slot = -1 + if self.needs_slots: + slot = self.slot_allocator.allocate() + if slot == -1: + return False + + # 2. Allocate Blocks (if needed) + blocks = [] + if self.needs_blocks: + num_blocks = (prompt_len + self.block_size - 1) // self.block_size + blocks = self.allocator.allocate(num_blocks) + if len(blocks) < num_blocks: + if blocks: + self.allocator.free(blocks) + if slot != -1: + self.slot_allocator.free(slot) + return False + + # 3. 
Commit + if self.needs_blocks: + self.block_tables[request_id] = blocks + self.context_lengths[request_id] = prompt_len + + if self.needs_slots: + self.request_slots[request_id] = slot + # Zero out state caches for this slot + for cache in self.caches: + if isinstance(cache, LinearCache): + # Zero out conv and linear states + if cache.conv_state_cache is not None: + cache.conv_state_cache[..., slot, :, :] = 0 + if cache.linear_state_cache is not None: + cache.linear_state_cache[..., slot, :, :, :] = 0 + + return True + + def free_request(self, request_id: str): + if self.needs_blocks and request_id in self.block_tables: + blocks = self.block_tables[request_id] + self.allocator.free(blocks) + del self.block_tables[request_id] + if request_id in self.context_lengths: + del self.context_lengths[request_id] + + if self.needs_slots and request_id in self.request_slots: + slot = self.request_slots[request_id] + self.slot_allocator.free(slot) + del self.request_slots[request_id] + + def release_request(self, request_id: str): + self.free_request(request_id) + + def has_request(self, request_id: str) -> bool: + if self.needs_blocks: + return request_id in self.block_tables + if self.needs_slots: + return request_id in self.request_slots + return False + + def append_slot(self, request_id: str) -> bool: + """Decode step allocation.""" + if not self.needs_blocks: + # Linear layers don't grow context + return True + + if request_id not in self.block_tables: + raise ValueError(f"Request {request_id} not found") + + current_len = self.context_lengths[request_id] + if current_len % self.block_size == 0: + new_blocks = self.allocator.allocate(1) + if not new_blocks: + return False + self.block_tables[request_id].extend(new_blocks) + + self.context_lengths[request_id] += 1 + return True + + def get_block_table(self, request_id: str) -> List[int]: + return self.block_tables.get(request_id, []) + + def get_context_length(self, request_id: str) -> int: + return self.context_lengths.get(request_id, 0) + + def get_slot(self, request_id: str) -> int: + return self.request_slots.get(request_id, -1) + + def get_caches(self) -> List[BaseCache]: + """Returns the list of layer caches.""" + return self.caches diff --git a/src/parallax/server/executor/base_executor.py b/src/parallax/server/executor/base_executor.py index a092ac1e..bc80e1f6 100755 --- a/src/parallax/server/executor/base_executor.py +++ b/src/parallax/server/executor/base_executor.py @@ -147,7 +147,7 @@ def __init__( is_first_peer=self.is_first_peer, tokenizer=self.tokenizer, eos_token_id=self.eos_token_id, - kv_cache_manager=self.kv_cache_manager if self.device == "mlx" else None, + cache_manager=self.cache_manager if self.device == "mlx" else None, request_timeout_s=request_timeout_s, shared_state=self.shared_state, ) diff --git a/src/parallax/server/executor/mlx_executor.py b/src/parallax/server/executor/mlx_executor.py index 67c14ca7..6e39d7c0 100755 --- a/src/parallax/server/executor/mlx_executor.py +++ b/src/parallax/server/executor/mlx_executor.py @@ -7,8 +7,8 @@ import mlx.core as mx +from parallax.server.cache_manager import CacheManager from parallax.server.executor.base_executor import BaseExecutor -from parallax.server.paged_kv_cache import PagedKVCacheManager from parallax.server.request import ( InitialRequest, IntermediateRequest, @@ -21,6 +21,7 @@ combine_padding_and_causal_masks, create_causal_mask, get_device_dtype, + get_layer_types, pad_inputs, ) from parallax_utils.logging_config import get_logger @@ -135,17 +136,19 @@ def __init__( 
value_dim = linear_value_head_dim * linear_num_value_heads if key_dim is not None and value_dim is not None: conv_dim = key_dim * 2 + value_dim - self.using_state_cache = linear_conv_kernel_dim is not None and conv_dim is not None indexer_key_head_dim = self.config.get("indexer_key_head_dim", None) indexer_num_kv_heads = self.config.get("indexer_num_kv_heads", None) + layer_types = get_layer_types(self.config, start_layer, end_layer) + logger.debug(f"layer_types: {layer_types}") logger.debug( - "Initializing PagedKVCacheManager (mlx) with block_size=%d, layers=%d", + "Initializing CacheManager (mlx) with block_size=%d, layers=%d", kv_block_size, self.num_shard_layers, ) - self.kv_cache_manager = PagedKVCacheManager( + self.cache_manager = CacheManager( num_layers=self.num_shard_layers, num_kv_heads=num_key_value_heads, head_dim=head_dim, @@ -155,6 +158,14 @@ def __init__( head_dim_v=v_head_dim, indexer_key_head_dim=indexer_key_head_dim, indexer_num_kv_heads=indexer_num_kv_heads, + layer_types=layer_types, + max_num_seqs=max_batch_size // micro_batch_ratio, + conv_dim=conv_dim, + conv_kernel_size=linear_conv_kernel_dim, + linear_k_dim=linear_key_head_dim, + linear_v_dim=linear_value_head_dim, + linear_num_k_heads=linear_num_key_heads, + linear_num_v_heads=linear_num_value_heads, ) super().__init__( start_layer=start_layer, @@ -195,7 +206,7 @@ def __init__( # ) logger.debug( - f"KVCacheManager ready; wired_limit set; prefix_cache={'on' if self.enable_prefix_cache else 'off'}" + f"CacheManager ready; wired_limit set; prefix_cache={'on' if self.enable_prefix_cache else 'off'}" ) def handle_input_requests(self, requests: List[Request]): @@ -217,7 +228,7 @@ def handle_input_requests(self, requests: List[Request]): "It might have been cancelled or finished." ) continue - if not self.kv_cache_manager.has_request(req.request_id): + if not self.cache_manager.has_request(req.request_id): logger.warning( f"Received IntermediateRequest {req.request_id}. " "But no corresponding request found in cache manager. " @@ -232,7 +243,7 @@ def handle_input_requests(self, requests: List[Request]): # Check for termination. if self.scheduler.check_and_update_request_status(original_req): - self.kv_cache_manager.release_request(original_req.request_id) + self.cache_manager.release_request(original_req.request_id) logger.debug( f"Released resources for finished request {req.request_id}, " f"memory usage: {mx.get_active_memory() / 1024**3 :.3f} GB" ) @@ -266,14 +277,13 @@ def handle_input_requests(self, requests: List[Request]): ), "Non-first peers must receive IntermediateRequests."
if req.is_finished or req.hidden_states is None: if self.enable_prefix_cache: - keys, values = self.kv_cache_manager.gather_kv_cache(req.request_id) + keys, values = self.cache_manager.gather_kv_cache(req.request_id) self.prefix_cache.cache_finished_request(req, keys, values) self.prefix_cache.evict_request(req.request_id) - self.kv_cache_manager.release_request(req.request_id) + self.cache_manager.release_request(req.request_id) logger.debug( f"Released resources for finished request {req.request_id}, " - f"kv cache manager has {self.kv_cache_manager.tokens_in_cache} tokens, " f"memory usage: {mx.get_active_memory() / 1024**3 :.3f} GB" ) self.scheduler.evict_request(req.request_id) @@ -296,7 +306,7 @@ def process_batch(self, prepared_inputs: Dict[str, Any], return_decoded_tokens: block_tables=prepared_inputs.get("block_tables"), context_lengths=prepared_inputs.get("context_lengths"), slot_mapping=prepared_inputs.get("slot_mapping"), - indexer_cache=prepared_inputs.get("indexer_cache"), + state_slot_mapping=prepared_inputs.get("state_slot_mapping"), ) logger.debug( @@ -317,14 +327,14 @@ def process_batch(self, prepared_inputs: Dict[str, Any], return_decoded_tokens: # Note: With PagedAttention, we don't need to explicitly update requests with new K/V # because they are written in-place to the global cache. - # self.kv_cache_manager.update_requests(...) is REMOVED. + # self.cache_manager.update_requests(...) is REMOVED. # Update prefix cache (TODO: Adapt to PagedKV) if self.enable_prefix_cache: pass # for _, req in enumerate(requests): # if req.is_prefill: - # keys, values = self.kv_cache_manager.gather_kv_cache(req.request_id) + # keys, values = self.cache_manager.gather_kv_cache(req.request_id) # self.prefix_cache.cache_unfinished_request(req, keys, values) # Process last peer: need additional sampling + detokenization @@ -339,8 +349,8 @@ def process_batch(self, prepared_inputs: Dict[str, Any], return_decoded_tokens: def _release_request(self, rid: str): """Release per-request resources in MLX.""" try: - if hasattr(self, "kv_cache_manager") and self.kv_cache_manager is not None: - self.kv_cache_manager.release_request(rid) + if hasattr(self, "cache_manager") and self.cache_manager is not None: + self.cache_manager.release_request(rid) except Exception: pass @@ -375,11 +385,11 @@ def _prepare_prefill_batch(self, batched_requests: List[Request]) -> Dict[str, A # Allocate Paged KV blocks # For first peer and intermediate peers, we allocate based on prompt length - success = self.kv_cache_manager.allocate_request(req.request_id, req.total_length) + success = self.cache_manager.allocate_request(req.request_id, req.total_length) if not success: raise RuntimeError(f"OOM during prefill allocation for {req.request_id}") - block_table = self.kv_cache_manager.get_block_table(req.request_id) + block_table = self.cache_manager.get_block_table(req.request_id) block_tables_list.append(block_table) # For prefill, context length after this step will be total_length context_lengths_list.append(req.total_length) @@ -402,10 +412,10 @@ def _prepare_prefill_batch(self, batched_requests: List[Request]) -> Dict[str, A for seq_idx in range(max_len): if seq_idx < length: # Valid token - block_idx = seq_idx // self.kv_cache_manager.block_size - block_offset = seq_idx % self.kv_cache_manager.block_size + block_idx = seq_idx // self.cache_manager.block_size + block_offset = seq_idx % self.cache_manager.block_size physical_block = block_table[block_idx] - slot = physical_block * self.kv_cache_manager.block_size + 
block_offset + slot = physical_block * self.cache_manager.block_size + block_offset slot_mapping_flat.append(slot) else: # Padding token @@ -427,16 +437,22 @@ def _prepare_prefill_batch(self, batched_requests: List[Request]) -> Dict[str, A causal_mask = create_causal_mask(padded_inputs.shape[1], padded_inputs.shape[1], self.dtype) mask = combine_padding_and_causal_masks(padding_mask, causal_mask, self.dtype) + # Prepare state slot mapping if needed + state_slot_mapping = None + if self.cache_manager.needs_slots: + req_ids = [r.request_id for r in batched_requests] + slots = [self.cache_manager.get_slot(rid) for rid in req_ids] + state_slot_mapping = mx.array(slots, dtype=mx.int32) + ret = { "h_or_tokens": padded_inputs, - "cache": self.kv_cache_manager.get_cache(), - "indexer_cache": self.kv_cache_manager.get_indexer_cache(), + "cache": self.cache_manager.get_caches(), "mask": mask, "requests": batched_requests, "block_tables": block_tables_tensor, "context_lengths": context_lengths_tensor, "slot_mapping": slot_mapping_tensor, - "state_cache": None, + "state_slot_mapping": state_slot_mapping, } logger.debug(f"Prepared MLX prefill batch (size={batch_size})") return ret @@ -463,13 +479,13 @@ def _prepare_decode_batch(self, batched_requests: List[Request]) -> Optional[Dic # TODO: Prefix cache update # Allocate slot for new token - success = self.kv_cache_manager.append_slot(req.request_id) + success = self.cache_manager.append_slot(req.request_id) if not success: raise RuntimeError(f"OOM during decode for {req.request_id}") - block_table = self.kv_cache_manager.get_block_table(req.request_id) + block_table = self.cache_manager.get_block_table(req.request_id) block_tables_list.append(block_table) - context_lengths_list.append(self.kv_cache_manager.get_context_length(req.request_id)) + context_lengths_list.append(self.cache_manager.get_context_length(req.request_id)) if isinstance(h_or_tokens_list[0], list): # First peer case: h_or_tokens_list is list of list of ints [[token_id], ...] @@ -488,16 +504,22 @@ def _prepare_decode_batch(self, batched_requests: List[Request]) -> Optional[Dic block_tables_tensor = mx.array(padded_block_tables, dtype=mx.int32) context_lengths_tensor = mx.array(context_lengths_list, dtype=mx.int32) + # Prepare state slot mapping if needed + state_slot_mapping = None + if self.cache_manager.needs_slots: + req_ids = [r.request_id for r in batched_requests] + slots = [self.cache_manager.get_slot(rid) for rid in req_ids] + state_slot_mapping = mx.array(slots, dtype=mx.int32) + ret = { "h_or_tokens": padded_inputs, - "cache": self.kv_cache_manager.get_cache(), - "indexer_cache": self.kv_cache_manager.get_indexer_cache(), + "cache": self.cache_manager.get_caches(), "mask": None, "requests": batched_requests, "block_tables": block_tables_tensor, "context_lengths": context_lengths_tensor, "slot_mapping": None, - "state_cache": None, + "state_slot_mapping": state_slot_mapping, } logger.debug(f"Prepared MLX decode batch (size={batch_size})") return ret diff --git a/src/parallax/server/kv_cache.py b/src/parallax/server/kv_cache.py deleted file mode 100644 index 0aa708ea..00000000 --- a/src/parallax/server/kv_cache.py +++ /dev/null @@ -1,382 +0,0 @@ -""" -Simplified KV Cache Manager for Parallax Server - -This module implements a simplified key-value (KV) cache system to -avoid materializing the entier KV cache pool. -This is a dictionary-based approach where each request has its own growing KV cache. 
- -Core Components: - -KVCache: - - MLX-LM style growing cache that dynamically allocates memory as needed - - Supports efficient update and fetch operations - - Automatically handles memory expansion in chunks - -KVCacheManager: - - Uses a dictionary mapping request_id to KVCache instances - - Supports adding, updating, releasing requests' KV Cache - - Performs necessary memory checks to avoid exceeding limits -""" - -from typing import Dict, List, Optional, Tuple - -import mlx.core as mx - -from parallax.server.request import Request, RequestStatus -from parallax_utils.logging_config import get_logger -from parallax_utils.utils import compute_max_tokens_in_cache - -logger = get_logger(__name__) - - -class KVCache: - """Per-Request KV cache for a single request. - Dynamically grows the cache in chunks of block_size. - """ - - def __init__( - self, - num_kv_heads: int, - head_dim_k: int, - head_dim_v: int, - num_layers: int, - dtype: mx.Dtype, - block_size: int = 64, - conv_dim: Optional[int] = None, - conv_kernel_size: Optional[int] = None, - linear_k_dim: Optional[int] = None, - linear_v_dim: Optional[int] = None, - linear_num_k_heads: Optional[int] = None, - linear_num_v_heads: Optional[int] = None, - qk_nope_head_dim: Optional[int] = None, - qk_rope_head_dim: Optional[int] = None, - num_initial_tokens: int = 0, - ): - """ - Args: - num_kv_heads: The number of key-value heads. - head_dim: The dimension of each head. - num_layers: The number of layers. - dtype: The data type of the cache. - block_size: Source length dim growth step size. - num_initial_tokens: The number of tokens to initialize the cache with. - """ - self.num_kv_heads = num_kv_heads - self.dtype = dtype - self.block_size = block_size - self.conv_dim = conv_dim - self.conv_kernel_size = conv_kernel_size - self.linear_k_dim = linear_k_dim - self.linear_v_dim = linear_v_dim - self.linear_num_k_heads = linear_num_k_heads - self.linear_num_v_heads = linear_num_v_heads - self.qk_nope_head_dim = qk_nope_head_dim - self.qk_rope_head_dim = qk_rope_head_dim - self.head_dim_v = head_dim_v - self.head_dim_k = head_dim_k - - num_initial_tokens = self.round_up_to_step(num_initial_tokens) - # (num_layers, num_kv_heads, seq_len, head_dim) - - self.keys = mx.zeros((num_layers, num_kv_heads, num_initial_tokens, self.head_dim_k), dtype) - self.values = mx.zeros( - (num_layers, num_kv_heads, num_initial_tokens, self.head_dim_v), dtype - ) - self.state0 = ( - mx.zeros((num_layers, conv_kernel_size - 1, conv_dim), dtype) if conv_dim else None - ) - - self.state1 = ( - mx.zeros((num_layers, linear_num_v_heads, linear_k_dim, linear_v_dim), dtype) - if (linear_k_dim and linear_v_dim and linear_num_k_heads and linear_num_v_heads) - else None - ) - self.num_tokens = num_initial_tokens - self.offset = 0 - - def round_up_to_step(self, seq_len: int) -> int: - """ - Rounds up to the nearest multiple of the block_size. 
- """ - return (seq_len + self.block_size - 1) // self.block_size * self.block_size - - def needs_grow(self, seq_len: int) -> bool: - """Checks if the cache needs to grow.""" - return (self.offset + seq_len) > self.num_tokens - - def fetch(self) -> Tuple[mx.array, mx.array]: - """Fetches the KV cache for the request.""" - return ( - self.keys[..., : self.offset, :], - self.values[..., : self.offset, :], - self.state0 if self.state0 is not None else None, - self.state1 if self.state1 is not None else None, - ) - - def update( - self, - keys: mx.array, - values: mx.array, - state0: Optional[mx.array], - state1: Optional[mx.array], - ) -> int: - """ - Updates the cache with new key-value pairs. - - Args: - keys: New keys to add, shape (num_layers, num_kv_heads, target_len, head_dim_k) - values: New values to add, shape (num_layers, num_kv_heads, target_len, head_dim_v) - """ - if state0 is not None and self.state0 is not None: - self.state0 = state0 - if state1 is not None and self.state1 is not None: - self.state1 = state1 - - prev = self.offset - seq_len = keys.shape[2] - prev_tokens = self.num_tokens - # Grow the cache based on the block_size size - if self.needs_grow(seq_len): - num_layers, num_kv_heads, _, head_dim_k = keys.shape - _, _, _, head_dim_v = values.shape - n_steps = (self.block_size + seq_len - 1) // self.block_size - k_shape = (num_layers, num_kv_heads, n_steps * self.block_size, head_dim_k) - v_shape = (num_layers, num_kv_heads, n_steps * self.block_size, head_dim_v) - new_k = mx.zeros(k_shape, keys.dtype) - new_v = mx.zeros(v_shape, values.dtype) - - if prev % self.block_size != 0: - self.keys = self.keys[..., :prev, :] - self.values = self.values[..., :prev, :] - self.keys = mx.concatenate([self.keys, new_k], axis=2) - self.values = mx.concatenate([self.values, new_v], axis=2) - self.num_tokens = self.keys.shape[2] - - # Update with new keys and values - self.offset += seq_len - self.keys[..., prev : self.offset, :] = keys - self.values[..., prev : self.offset, :] = values - return self.num_tokens - prev_tokens - - -class KVCacheManager: - """Manager for KVCache instances.""" - - def __init__( - self, - num_kv_heads: int, - head_dim: int, - num_layers: int, - dtype: mx.Dtype, - block_size: int = 64, - max_num_tokens: Optional[int] = None, - cache_memory_fraction: float = 0.5, - conv_dim: Optional[int] = None, - conv_kernel_size: Optional[int] = None, - linear_k_dim: Optional[int] = None, - linear_v_dim: Optional[int] = None, - linear_num_k_heads: Optional[int] = None, - linear_num_v_heads: Optional[int] = None, - qk_nope_head_dim: Optional[int] = None, - qk_rope_head_dim: Optional[int] = None, - v_head_dim: Optional[int] = None, - ): - """ - Args: - num_kv_heads: The number of key-value heads. - head_dim: The dimension of each head. - num_layers: The number of layers. - dtype: The data type of the cache. - block_size: Source length dim growth step size. - max_num_tokens: The maximum number of tokens in the cache. - cache_memory_fraction: The fraction of memory to use for the cache. 
- """ - self.num_kv_heads = num_kv_heads - self.num_layers = num_layers - self.dtype = dtype - self.block_size = block_size - self.conv_dim = conv_dim - self.conv_kernel_size = conv_kernel_size - self.linear_k_dim = linear_k_dim - self.linear_v_dim = linear_v_dim - self.linear_num_k_heads = linear_num_k_heads - self.linear_num_v_heads = linear_num_v_heads - self.qk_nope_head_dim = qk_nope_head_dim - self.qk_rope_head_dim = qk_rope_head_dim - self.v_head_dim = v_head_dim - if qk_nope_head_dim and qk_rope_head_dim: - self.head_dim_k = qk_nope_head_dim + qk_rope_head_dim - else: - self.head_dim_k = head_dim - self.head_dim_v = v_head_dim if v_head_dim is not None else head_dim - - self.request_caches: Dict[str, KVCache] = {} - self.tokens_in_cache = 0 - - self.max_num_tokens = compute_max_tokens_in_cache( - device="mlx", - kv_cache_memory_fraction=cache_memory_fraction, - num_shard_layers=num_layers, - num_key_value_heads=num_kv_heads, - head_dim_k=self.head_dim_k, - head_dim_v=self.head_dim_v, - elem_bytes=dtype.size, - ) - if max_num_tokens is not None: - self.max_num_tokens = min(self.max_num_tokens, max_num_tokens) - - def round_up_to_step(self, seq_len: int) -> int: - """ - Rounds up to the nearest multiple of the block_size. - """ - return (seq_len + self.block_size - 1) // self.block_size * self.block_size - - def has_request(self, request_id: str) -> bool: - """ - Checks if the request is in the cache. - """ - return request_id in self.request_caches - - def request_length(self, request_id: str) -> int: - """ - Returns the length of key/value in the request. - """ - return self.request_caches[request_id].offset - - def request_num_tokens(self, request_id: str) -> int: - """ - Returns the number of tokens (including slots not yet filled) in the request. - """ - assert self.has_request(request_id), "request not in cache" - return self.request_caches[request_id].num_tokens - - def gather_kv_cache(self, request_id: str) -> Tuple[mx.array, mx.array]: - """ - Gathers the KV cache for the request. - """ - assert self.has_request(request_id), "request not in cache" - return self.request_caches[request_id].fetch() - - def add_request(self, request: Request, num_tokens: int = 128) -> bool: - """Adds a request to the cache. - - Args: - request: The request to add. - num_tokens: The number of tokens in the request. - - Returns: - True if the request is added. 
- """ - assert ( - request.status == RequestStatus.PREFILLING - ), "add_request can only be called for prefilling requests" - - if request.request_id in self.request_caches: - logger.warning(f"Request {request.request_id} already in cache") - return True - - num_tokens = self.round_up_to_step(num_tokens) - if self.tokens_in_cache + num_tokens > self.max_num_tokens: - logger.warning( - f"can't add request {request.request_id} to cache: {self.tokens_in_cache} + " - f"{num_tokens} > {self.max_num_tokens}" - ) - return False - - self.request_caches[request.request_id] = KVCache( - num_kv_heads=self.num_kv_heads, - head_dim_k=self.head_dim_k, - head_dim_v=self.head_dim_v, - num_layers=self.num_layers, - dtype=self.dtype, - block_size=self.block_size, - num_initial_tokens=num_tokens, - conv_dim=self.conv_dim, - conv_kernel_size=self.conv_kernel_size, - linear_k_dim=self.linear_k_dim, - linear_v_dim=self.linear_v_dim, - linear_num_k_heads=self.linear_num_k_heads, - linear_num_v_heads=self.linear_num_v_heads, - ) - self.tokens_in_cache += self.request_num_tokens(request.request_id) - return True - - # def add_request_with_prefix_cache(): - - def release_request(self, request_id: str) -> bool: - """ - Releases the request from the cache. - """ - assert self.has_request(request_id), "request not in cache" - self.tokens_in_cache -= self.request_num_tokens(request_id) - del self.request_caches[request_id] - return True - - def update_requests( - self, - requests: List[Request], - keys: mx.array, - values: mx.array, - lengths: List[int], - states0: Optional[mx.array], - states1: Optional[mx.array], - ) -> bool: - """ - Updates the requests in the cache. - - Args: - requests: The requests to update. - keys: The keys to update. - values: The values to update. - lengths: The lengths of the requests. - - Returns: - True if requests are updated. - """ - batch_size, num_layers, n_kv_heads, _, head_dim_k = keys.shape - _, _, _, _, head_dim_v = values.shape - # Validate - # assert keys.shape == values.shape, "key and value must have the same shape" - assert num_layers == self.num_layers, "key and value must have the same number of layers" - assert batch_size == len(requests), "key and value must have the same batch size" - assert len(lengths) == batch_size, "lengths must have the same batch size as requests" - assert ( - n_kv_heads == self.num_kv_heads - ), "key and value must have the same number of key-value heads" - assert head_dim_k == self.head_dim_k, "key and value must have the same head dimension" - assert head_dim_v == self.head_dim_v, "key and value must have the same head dimension" - # TODO: Use vmap for better performance - for request, key, value, length, state0, state1 in zip( - requests, keys, values, lengths, states0, states1 - ): - length = length.item() - assert self.has_request(request.request_id), "request not in cache" - # TODO: fix this - # actual length? double-counted prefill len - # decode length 1, why rounding up? 
- if self.tokens_in_cache + self.round_up_to_step(length) > self.max_num_tokens: - logger.warning( - f"can't add request {request.request_id} to cache: " - f"{self.tokens_in_cache} + {length} > {self.max_num_tokens}" - ) - return False - self.tokens_in_cache += self.request_caches[request.request_id].update( - key[..., :length, :], value[..., :length, :], state0, state1 - ) - return True - - def add_matched_prefix_request( - self, request: Request, key: mx.array, value: mx.array, length: int - ): - """If a request matches prefix, add it back to the running kv-cache manager""" - assert self.has_request(request.request_id), "request not in cache" - if self.tokens_in_cache + self.round_up_to_step(length) > self.max_num_tokens: - logger.warning( - f"can't add request {request.request_id} to cache: " - f"{self.tokens_in_cache} + {length} > {self.max_num_tokens}" - ) - return False - self.tokens_in_cache += self.request_caches[request.request_id].update( - key[..., :length, :], value[..., :length, :] - ) - return True diff --git a/src/parallax/server/model.py b/src/parallax/server/model.py index d4021ce6..6aeefe7a 100644 --- a/src/parallax/server/model.py +++ b/src/parallax/server/model.py @@ -2,7 +2,7 @@ Defines the ShardedModel class for distributing MLX models across multiple devices. """ -from typing import Optional, Tuple, Type +from typing import Any, List, Optional, Type import mlx.core as mx from mlx import nn @@ -58,7 +58,8 @@ def __init__( self.norm_in = None self.layers = [ - block_class(config, layer_idx) for layer_idx in range(start_layer, end_layer) + block_class(config, layer_idx, layer_idx - start_layer) + for layer_idx in range(start_layer, end_layer) ] if self.is_last_shard: @@ -109,12 +110,11 @@ def logits_to_tokens( def __call__( self, h_or_tokens: mx.array, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[List[Any]] = None, mask: Optional[mx.array] = None, block_tables: Optional[mx.array] = None, context_lengths: Optional[mx.array] = None, slot_mapping: Optional[mx.array] = None, - window_size: Optional[int] = None, **kwargs, ) -> mx.array: """ @@ -122,9 +122,8 @@ def __call__( h_or_tokens: (batch, target_len_padded, D) or (batch, target_len_padded) for prefill, (batch, 1, D) or (batch, 1) for decode. - cache: PagedAttention: - (key_cache_global, value_cache_global) - has for shape: (num_layers, num_blocks, num_kv_heads, block_size, head_dim) + cache: List of layer caches (KVCache or LinearCache). + Legacy mode: (key_cache_global, value_cache_global) tuple. lengths: (batch,) true lengths of each sequence in batch. mask: Optional causal mask for the current segment. window_size: Optional int, if provided, will use a sliding window attention mask. 
diff --git a/src/parallax/server/paged_kv_cache.py b/src/parallax/server/paged_kv_cache.py deleted file mode 100644 index b7e715d0..00000000 --- a/src/parallax/server/paged_kv_cache.py +++ /dev/null @@ -1,266 +0,0 @@ -from typing import Dict, List, Optional, Set, Tuple - -import mlx.core as mx - -from parallax_utils.logging_config import get_logger - -logger = get_logger(__name__) - - -class BlockAllocator: - """Manages allocation of physical block indices.""" - - def __init__(self, num_blocks: int, block_size: int): - self.num_blocks = num_blocks - self.block_size = block_size - # Initialize free blocks stack - # Using a list as a stack is efficient for LIFO allocation - self.free_blocks: List[int] = list(range(num_blocks)) - # Keep track of used blocks for safety/debugging - self.used_blocks: Set[int] = set() - - def allocate(self, num_blocks_needed: int) -> List[int]: - """Allocates `num_blocks_needed` physical blocks.""" - if len(self.free_blocks) < num_blocks_needed: - # Out of memory - return [] - - # Pop blocks from the stack - split_idx = len(self.free_blocks) - num_blocks_needed - allocated = self.free_blocks[split_idx:] - self.free_blocks = self.free_blocks[:split_idx] - - for b in allocated: - self.used_blocks.add(b) - - return allocated - - def free(self, blocks: List[int]): - """Frees the given physical blocks.""" - for b in blocks: - if b in self.used_blocks: - self.used_blocks.remove(b) - self.free_blocks.append(b) - else: - logger.warning(f"Double free detected for block {b}") - - def get_num_free_blocks(self) -> int: - return len(self.free_blocks) - - def get_num_used_blocks(self) -> int: - return len(self.used_blocks) - - -class PagedKVCacheManager: - """ - Manages the Paged KV Cache tensors and block tables for requests. - """ - - def __init__( - self, - num_layers: int, - num_kv_heads: int, - head_dim: int, - dtype: mx.Dtype, - block_size: int = 16, - cache_memory_fraction: float = 0.8, - num_gpu_blocks: Optional[int] = None, - max_num_seqs: int = 256, # Max concurrent requests hint - head_dim_v: Optional[int] = None, - indexer_key_head_dim: Optional[int] = None, - indexer_num_kv_heads: Optional[int] = None, - ): - self.num_layers = num_layers - self.num_kv_heads = num_kv_heads - self.head_dim = head_dim - self.head_dim_v = head_dim_v if head_dim_v is not None else head_dim - self.indexer_key_head_dim = indexer_key_head_dim - self.indexer_num_kv_heads = indexer_num_kv_heads - self.dtype = dtype - self.block_size = block_size - self.max_num_seqs = max_num_seqs - - if num_gpu_blocks is None: - num_gpu_blocks = self._calculate_num_blocks(cache_memory_fraction, dtype) - - self.num_gpu_blocks = num_gpu_blocks - - # 1. Initialize Allocator - self.allocator = BlockAllocator(num_gpu_blocks, block_size) - - # 2. 
Allocate Global Cache Tensors - # Shape: (num_layers, num_blocks, num_kv_heads, block_size, head_dim) - logger.info( - f"Allocating Paged KV Cache: {num_gpu_blocks} blocks, {block_size} block_size, " - f"k_head_dim={self.head_dim}, v_head_dim={self.head_dim_v}" - ) - - self.key_cache = mx.zeros( - (num_layers, num_gpu_blocks, num_kv_heads, block_size, self.head_dim), dtype=dtype - ) - self.value_cache = mx.zeros( - (num_layers, num_gpu_blocks, num_kv_heads, block_size, self.head_dim_v), dtype=dtype - ) - - if self.indexer_key_head_dim is not None and self.indexer_num_kv_heads is not None: - logger.info( - f"Allocating Indexer Key Cache: {self.indexer_key_head_dim} head_dim, " - f"{self.indexer_num_kv_heads} heads" - ) - self.indexer_key_cache = mx.zeros( - ( - num_layers, - num_gpu_blocks, - self.indexer_num_kv_heads, - block_size, - self.indexer_key_head_dim, - ), - dtype=dtype, - ) - mx.eval(self.indexer_key_cache) - else: - self.indexer_key_cache = None - - # Ensure memory is materialized - mx.eval(self.key_cache, self.value_cache) - - # 3. Request State Management - # Mapping: request_id -> List of physical block indices - self.block_tables: Dict[str, List[int]] = {} - # Mapping: request_id -> current context length (number of tokens) - self.context_lengths: Dict[str, int] = {} - - def _calculate_num_blocks(self, cache_memory_fraction: float, dtype: mx.Dtype) -> int: - - device_info = mx.metal.device_info() - total_mem = device_info["max_recommended_working_set_size"] - current_mem = mx.metal.get_active_memory() - free_mem = total_mem - current_mem - - # We use a fraction of FREE memory, but for safety in multi-process/multi-model - # scenarios, we might want to base it on TOTAL memory fraction if we know - # what we are doing (as in Executor logic). - # However, to be safe and consistent with previous logic that tried to avoid OOM: - # Let's stick to the logic that available_for_kv is based on free memory - # OR total_memory * fraction if we trust the fraction to be per-process adjusted. - - # If fraction is small (e.g. < 0.2), it likely means it's per-process adjusted. - # But here we stick to "use what is available" to be safe. - available_for_kv = free_mem * cache_memory_fraction - - dtype_size = 2 if dtype in [mx.float16, mx.bfloat16] else 4 - - # Calculate bytes per block considering potentially different K and V head dimensions - key_block_bytes = ( - self.num_layers * self.num_kv_heads * self.block_size * self.head_dim * dtype_size - ) - value_block_bytes = ( - self.num_layers * self.num_kv_heads * self.block_size * self.head_dim_v * dtype_size - ) - indexer_block_bytes = 0 - if self.indexer_key_head_dim is not None and self.indexer_num_kv_heads is not None: - indexer_block_bytes = ( - self.num_layers - * self.indexer_num_kv_heads - * self.block_size - * self.indexer_key_head_dim - * dtype_size - ) - - block_bytes = key_block_bytes + value_block_bytes + indexer_block_bytes - - num_gpu_blocks = int(available_for_kv // block_bytes) - - if num_gpu_blocks <= 0: - logger.warning( - f"Not enough memory for KV cache. Total: {total_mem / 1024**3:.2f} GB, " - f"Used: {current_mem / 1024**3:.2f} GB, Free: {free_mem / 1024**3:.2f} GB. " - f"Defaulting to minimal 16 blocks." 
- ) - num_gpu_blocks = 16 - - logger.info( - f"PagedKVCache: Calculated num_gpu_blocks={num_gpu_blocks} based on " - f"fraction={cache_memory_fraction:.2f}, " - f"total_mem={total_mem/1024**3:.2f} GB, " - f"used_mem={current_mem/1024**3:.2f} GB, " - f"free_mem={free_mem/1024**3:.2f} GB, " - f"available_for_kv={available_for_kv/1024**3:.2f} GB" - ) - return num_gpu_blocks - - def get_num_free_blocks(self) -> int: - return self.allocator.get_num_free_blocks() - - def can_allocate(self, num_tokens: int) -> bool: - num_blocks = (num_tokens + self.block_size - 1) // self.block_size - return self.allocator.get_num_free_blocks() >= num_blocks - - def allocate_request(self, request_id: str, prompt_len: int) -> bool: - """ - Allocates initial blocks for a new request (Prefill). - Returns True if successful, False if OOM. - """ - if request_id in self.block_tables: - return True - - num_blocks = (prompt_len + self.block_size - 1) // self.block_size - blocks = self.allocator.allocate(num_blocks) - - if len(blocks) < num_blocks: - # Allocation failed - if blocks: - self.allocator.free(blocks) - return False - - self.block_tables[request_id] = blocks - self.context_lengths[request_id] = prompt_len - return True - - def has_request(self, request_id: str) -> bool: - return request_id in self.block_tables - - def free_request(self, request_id: str): - """Frees all blocks associated with a request.""" - if request_id in self.block_tables: - blocks = self.block_tables[request_id] - self.allocator.free(blocks) - del self.block_tables[request_id] - del self.context_lengths[request_id] - - def release_request(self, request_id: str): - """Alias for free_request to match Executor expectation.""" - self.free_request(request_id) - - def append_slot(self, request_id: str) -> bool: - """ - Allocates a new slot for the next token generation (Decode). - If the last block is full, allocates a new block. 
- """ - if request_id not in self.block_tables: - raise ValueError(f"Request {request_id} not found") - - current_len = self.context_lengths[request_id] - - if current_len % self.block_size == 0: - new_blocks = self.allocator.allocate(1) - if not new_blocks: - return False # OOM - self.block_tables[request_id].extend(new_blocks) - - self.context_lengths[request_id] += 1 - return True - - def get_block_table(self, request_id: str) -> List[int]: - return self.block_tables.get(request_id, []) - - def get_context_length(self, request_id: str) -> int: - return self.context_lengths.get(request_id, 0) - - def get_cache(self) -> Tuple[mx.array, mx.array]: - """Returns the global cache tensors.""" - return self.key_cache, self.value_cache - - def get_indexer_cache(self) -> Optional[mx.array]: - """Returns the global indexer key cache tensor.""" - return self.indexer_key_cache diff --git a/src/parallax/server/radix_cache.py b/src/parallax/server/radix_cache.py index 0bd07d27..9b52f4b8 100755 --- a/src/parallax/server/radix_cache.py +++ b/src/parallax/server/radix_cache.py @@ -12,7 +12,7 @@ import mlx.core as mx -from parallax.server.kv_cache import KVCache +from parallax.server.cache.kv_cache import KVCache from parallax.server.request import Request diff --git a/src/parallax/server/scheduler.py b/src/parallax/server/scheduler.py index 411388f8..20de0bda 100644 --- a/src/parallax/server/scheduler.py +++ b/src/parallax/server/scheduler.py @@ -23,7 +23,7 @@ from collections import OrderedDict from typing import Dict, List, Optional -from parallax.server.kv_cache import KVCacheManager +from parallax.server.cache_manager import CacheManager from parallax.server.request import InitialRequest, Request, RequestStatus from parallax.utils.shared_state import SharedState from parallax_utils.logging_config import get_logger @@ -45,7 +45,7 @@ def __init__( scheduler_wait_ms: int = 200, micro_batch_ratio: int = 2, is_first_peer: bool = False, - kv_cache_manager: Optional[KVCacheManager] = None, + cache_manager: Optional[CacheManager] = None, request_timeout_s: Optional[int] = 600, shared_state: Optional[SharedState] = None, **kwargs, @@ -57,7 +57,7 @@ def __init__( scheduler_wait_ms: The minimum time to wait before dispatching a batch; micro_batch_ratio: micro_batch_size = max_batch_size // micro_batch_ratio; tokenizer: The tokenizer to use for the model; - kv_cache_manager: The KV cache manager to use for the scheduler. + cache_manager: The KV cache manager to use for the scheduler. request_timeout_s: timeout for each inflight request (default 10mins). """ self.max_batch_size = max_batch_size @@ -77,7 +77,7 @@ def __init__( # Keeps track of all in-flight requests self._running_requests: Dict[str, Request] = OrderedDict() - self.kv_cache_manager = kv_cache_manager + self.cache_manager = cache_manager self.shared_state = shared_state # Default timeout for requests if not set on request object self.request_timeout_s = request_timeout_s @@ -225,10 +225,10 @@ def admit_requests(self): continue # Check kv cache pool - if self.kv_cache_manager is not None: - if not self.kv_cache_manager.has_request(req.request_id): + if self.cache_manager is not None: + if not self.cache_manager.has_request(req.request_id): # TODO: Handle chunked prefill, and support preemption. - if not self.kv_cache_manager.allocate_request(req.request_id, req.total_length): + if not self.cache_manager.allocate_request(req.request_id, req.total_length): logger.warning( f"Request {rid} can't be admit to running batch due to KV cache size." 
) diff --git a/src/parallax/utils/utils.py b/src/parallax/utils/utils.py index 4a5264ec..1bdabfe5 100644 --- a/src/parallax/utils/utils.py +++ b/src/parallax/utils/utils.py @@ -312,3 +312,40 @@ def initialize_nccl_port(): else: nccl_port -= 43 return nccl_port + + +def get_layer_types(config: dict, start_layer: int, end_layer: int) -> List[str]: + num_shard_layers = end_layer - start_layer + + # Case 1: Explicit layer types (e.g., DeepSeek with layers_block_type) + layer_types = config.get("layers_block_type", None) + if layer_types is not None: + if len(layer_types) >= end_layer: + layer_types = layer_types[start_layer:end_layer] + return [ + "linear" if t in ["mamba", "linear_attention"] else "attention" for t in layer_types + ] + + # Case 2: linear_attn_config with full_attn_layers (e.g., Kimi) + linear_attn_config = config.get("linear_attn_config") + if linear_attn_config: + full_attn_layers = set(linear_attn_config.get("full_attn_layers", [])) + layer_types = [] + for i in range(start_layer, end_layer): + if i in full_attn_layers: + layer_types.append("attention") + else: + layer_types.append("linear") + return layer_types + + # Case 3: full_attention_interval (e.g., Qwen3Next) + full_attention_interval = config.get("full_attention_interval") + if full_attention_interval: + layer_types = [] + for i in range(start_layer, end_layer): + is_linear = (i + 1) % full_attention_interval != 0 + layer_types.append("linear" if is_linear else "attention") + return layer_types + + # Default: all attention layers + return ["attention"] * num_shard_layers diff --git a/tests/test_batch_scheduler.py b/tests/test_batch_scheduler.py index f2b94079..2c902f0d 100644 --- a/tests/test_batch_scheduler.py +++ b/tests/test_batch_scheduler.py @@ -2,7 +2,7 @@ from parallax.server.scheduler import Scheduler -class FakeKVCacheManager: +class FakeCacheManager: def __init__(self, allow: bool = True): self.allow = allow self._reqs = set() @@ -92,12 +92,12 @@ def test_token_budget_prefill_skipped_decode_taken(): def test_kv_cache_admission_guard_blocks_prefill(): # A KV manager that rejects additions - kv_mgr = FakeKVCacheManager(allow=False) + cache_mgr = FakeCacheManager(allow=False) sched = Scheduler( max_batch_size=2, max_num_tokens_per_batch=100, micro_batch_ratio=1, - kv_cache_manager=kv_mgr, + cache_manager=cache_mgr, ) p = make_prefill("p", 4) sched.enque_request(p) diff --git a/tests/test_model.py b/tests/test_model.py index a7c5718e..f66fa729 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -9,7 +9,7 @@ from mlx_lm.models.base import create_attention_mask from mlx_lm.utils import _download, load_model -from parallax.server.paged_kv_cache import PagedKVCacheManager +from parallax.server.cache_manager import CacheManager from parallax.server.shard_loader import MLXModelLoader from parallax.utils.tokenizer_utils import load_tokenizer from parallax.utils.utils import pad_inputs @@ -87,12 +87,12 @@ def _call_with_mask(self, inputs, cache=None, mask=None): "num_attention_heads" ) - kv_cache_managers = [] + cache_managers = [] cache_memory_fraction = 0 for shard in model_shards: num_shard_layers = shard.end_layer - shard.start_layer cache_memory_fraction += 0.1 - kv_mgr = PagedKVCacheManager( + cache_mgr = CacheManager( num_layers=num_shard_layers, num_kv_heads=num_kv_heads, head_dim=head_dim, @@ -100,7 +100,7 @@ def _call_with_mask(self, inputs, cache=None, mask=None): block_size=64, num_gpu_blocks=200, ) - kv_cache_managers.append(kv_mgr) + cache_managers.append(cache_mgr) # Prepare common inputs 
         padding_mask = (ref_ids != ref_pad_token_id).astype(dtype)
@@ -109,7 +109,7 @@ def _call_with_mask(self, inputs, cache=None, mask=None):
         # Run sharded models
         x = None
         for shard_idx, shard in enumerate(model_shards):
-            kv_cache_manager = kv_cache_managers[shard_idx]
+            cache_manager = cache_managers[shard_idx]
 
             # Allocate blocks and prepare metadata
             block_tables_list = []
@@ -121,19 +121,19 @@ def _call_with_mask(self, inputs, cache=None, mask=None):
                 seq_len = actual_seq_lengths[i]
                 context_lengths_list.append(seq_len)
 
-                success = kv_cache_manager.allocate_request(req_id, seq_len)
+                success = cache_manager.allocate_request(req_id, seq_len)
                 assert success, f"Failed to allocate blocks for request {i} in shard {shard_idx}"
 
-                block_table = kv_cache_manager.get_block_table(req_id)
+                block_table = cache_manager.get_block_table(req_id)
                 block_tables_list.append(block_table)
 
                 # Generate slot mapping
                 for seq_idx in range(max_seq_len):
                     if seq_idx < seq_len:
-                        block_idx = seq_idx // kv_cache_manager.block_size
-                        block_offset = seq_idx % kv_cache_manager.block_size
+                        block_idx = seq_idx // cache_manager.block_size
+                        block_offset = seq_idx % cache_manager.block_size
                         physical_block = block_table[block_idx]
-                        slot = physical_block * kv_cache_manager.block_size + block_offset
+                        slot = physical_block * cache_manager.block_size + block_offset
                         slot_mapping_flat.append(slot)
                     else:
                         slot_mapping_flat.append(-1)
@@ -145,7 +145,7 @@ def _call_with_mask(self, inputs, cache=None, mask=None):
             block_tables = mx.array(padded_block_tables, dtype=mx.int32)
             context_lengths = mx.array(context_lengths_list, dtype=mx.int32)
             slot_mapping = mx.array(slot_mapping_flat, dtype=mx.int64)
-            cache = kv_cache_manager.get_cache()
+            cache = cache_manager.get_caches()
 
             # Forward pass
             input_data = ref_ids if shard.start_layer == 0 else x
diff --git a/tests/test_paged_attention.py b/tests/test_paged_attention.py
index a70b5682..26b78b3f 100644
--- a/tests/test_paged_attention.py
+++ b/tests/test_paged_attention.py
@@ -68,19 +68,13 @@ def test_basic_functionality(self, dtype):
         NUM_KV_HEADS = 4
         HEAD_DIM = 32
         BLOCK_SIZE = 16
-        NUM_LAYERS = 1
         NUM_BLOCKS = 1024
-        LAYER_IDX = 0
         SCALE = 1.0 / math.sqrt(HEAD_DIM)
         atol = 1e-2 if dtype != mx.float32 else 1e-4
 
-        # Setup Memory
-        key_cache = mx.zeros(
-            (NUM_LAYERS, NUM_BLOCKS, NUM_KV_HEADS, BLOCK_SIZE, HEAD_DIM), dtype=dtype
-        )
-        value_cache = mx.zeros(
-            (NUM_LAYERS, NUM_BLOCKS, NUM_KV_HEADS, BLOCK_SIZE, HEAD_DIM), dtype=dtype
-        )
+        # Setup Memory (single layer cache, shape: (1, num_blocks, num_kv_heads, block_size, head_dim))
+        key_cache = mx.zeros((1, NUM_BLOCKS, NUM_KV_HEADS, BLOCK_SIZE, HEAD_DIM), dtype=dtype)
+        value_cache = mx.zeros((1, NUM_BLOCKS, NUM_KV_HEADS, BLOCK_SIZE, HEAD_DIM), dtype=dtype)
 
         # Mock Block Tables
         max_blocks_per_req = 2
@@ -104,7 +98,6 @@ def test_basic_functionality(self, dtype):
             block_tables,
             context_lengths,
             BLOCK_SIZE,
-            LAYER_IDX,
         )
 
         mx.eval(new_k_cache, new_v_cache)
@@ -139,7 +132,6 @@ def test_basic_functionality(self, dtype):
             BLOCK_SIZE,
             SCALE,
             NUM_KV_HEADS,
-            LAYER_IDX,
         )
 
         mx.eval(output)
@@ -207,7 +199,7 @@ def test_large_scale_correctness(self, params):
         num_blocks_per_req = (seq_len + block_size - 1) // block_size
         total_blocks = num_blocks_per_req * batch_size
 
-        # Setup Cache
+        # Setup Cache (single layer, shape: (1, total_blocks, num_kv_heads, block_size, head_dim))
         key_cache = mx.zeros((1, total_blocks, num_kv_heads, block_size, head_dim), dtype=dtype)
         value_cache = mx.zeros((1, total_blocks, num_kv_heads, block_size, head_dim), dtype=dtype)
 
@@ -254,7 +246,7 @@ def test_large_scale_correctness(self, params):
         value_cache = v_ready[None, ...]
         mx.eval(key_cache, value_cache)
 
-        # Run Kernel
+        # Run Kernel (no layer_idx needed)
         out = paged_attention(
             q,
             key_cache,
@@ -264,7 +256,6 @@ def test_large_scale_correctness(self, params):
             block_size,
             scale,
             num_kv_heads,
-            0,
         )
 
         mx.eval(out)
@@ -356,7 +347,6 @@ def test_benchmark_paged_vs_native(self):
             block_size,
             scale,
             num_kv_heads,
-            0,
         )
 
         mx.eval(_)
@@ -371,7 +361,6 @@ def test_benchmark_paged_vs_native(self):
             block_size,
             scale,
             num_kv_heads,
-            0,
         )
         mx.eval(out)
         end = time.perf_counter()
diff --git a/tests/test_paged_kv_integration.py b/tests/test_paged_kv_integration.py
index 6158be91..6fc2a8d1 100644
--- a/tests/test_paged_kv_integration.py
+++ b/tests/test_paged_kv_integration.py
@@ -4,7 +4,7 @@
 import numpy as np
 
 from parallax.metal.paged_attention.kernel import reshape_and_cache
-from parallax.server.paged_kv_cache import PagedKVCacheManager
+from parallax.server.cache_manager import CacheManager
 
 
 class TestPagedKVIntegration(unittest.TestCase):
@@ -19,7 +19,7 @@ def setUp(self):
         # Mocking device info to avoid OOM or device dependency in test env
         # Assuming cache_memory_fraction results in enough blocks
         # We will manually override num_gpu_blocks if needed or rely on default fallback
-        self.cache_manager = PagedKVCacheManager(
+        self.cache_manager = CacheManager(
             num_layers=self.num_layers,
             num_kv_heads=self.num_kv_heads,
             head_dim=self.head_dim,
@@ -103,8 +103,9 @@ def test_prefill_slot_mapping(self):
 
         slot_mapping_tensor = mx.array(slot_mapping_flat, dtype=mx.int64)
 
-        # 4. Run Kernel
-        key_cache, value_cache = self.cache_manager.get_cache()
+        # 4. Run Kernel (get cache for layer 0)
+        layer_cache = self.cache_manager.get_caches()[0]
+        key_cache, value_cache = layer_cache.get_cache()
 
         reshape_and_cache(
             keys_flat,
@@ -114,7 +115,6 @@ def test_prefill_slot_mapping(self):
             block_tables_tensor,
             context_lengths_tensor,
             self.block_size,
-            layer_idx=0,
             slot_mapping=slot_mapping_tensor,
         )
 
@@ -125,7 +125,10 @@ def test_prefill_slot_mapping(self):
 
         # req1 fits in 1 block (block_size=16)
         block_idx_req1 = block_tables[0][0]
         # Check first token
-        cached_k_0 = key_cache[0, block_idx_req1, :, 0, :]  # layer 0, block, heads, offset 0, dim
+        # key_cache shape: (1, num_blocks, num_kv_heads, block_size, head_dim)
+        cached_k_0 = key_cache[
+            0, block_idx_req1, :, 0, :
+        ]  # dim0=placeholder, block, heads, offset 0, dim
         expected_k_0 = mx.array(keys_np[0, 0, :, :])
         self.assertTrue(mx.allclose(cached_k_0, expected_k_0).item(), "Req1 Token 0 Key Mismatch")
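
For reference, a minimal usage sketch of the get_layer_types helper added in src/parallax/utils/utils.py above, assuming an illustrative Qwen3Next-style config with full_attention_interval set to 4 (the config values below are invented for the example, not taken from a real checkpoint):

from parallax.utils.utils import get_layer_types

# Illustrative config (invented for this example): with full_attention_interval=4,
# every 4th layer (1-indexed) uses full attention and the rest use linear attention.
config = {"full_attention_interval": 4}

# Layer types for a hypothetical shard covering layers 0..7.
print(get_layer_types(config, start_layer=0, end_layer=8))
# -> ['linear', 'linear', 'linear', 'attention', 'linear', 'linear', 'linear', 'attention']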