6 changes: 0 additions & 6 deletions src/parallax/metal/indexer/kernel.py
@@ -57,7 +57,6 @@ def store_indexer_cache(
block_tables: mx.array,
context_lengths: mx.array,
block_size: int,
layer_idx: int,
slot_mapping: Optional[mx.array] = None,
):
dtype = key.dtype
@@ -114,7 +113,6 @@ def mk_int(val):
mk_int(num_heads),
mk_int(head_dim),
mk_int(block_size),
mk_int(layer_idx),
mk_int(num_layers),
mk_int(num_blocks),
]
@@ -127,7 +125,6 @@ def mk_int(val):
"num_heads",
"head_dim",
"block_size",
"layer_idx",
"num_layers",
"num_blocks",
]
@@ -160,7 +157,6 @@ def q_dot_k(
block_table: mx.array, # (max_blocks)
context_length: mx.array, # scalar
block_size: int,
layer_idx: int,
) -> mx.array:

if q.ndim > 2:
@@ -186,7 +182,6 @@ def mk_int(val):
mk_int(block_size),
mk_int(num_heads),
mk_int(head_dim),
mk_int(layer_idx),
mk_int(num_layers),
mk_int(num_total_blocks),
mk_int(max_blocks),
@@ -200,7 +195,6 @@ def mk_int(val):
"block_size",
"num_heads",
"head_dim",
"layer_idx",
"num_layers",
"num_total_blocks",
"max_blocks",
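The wrapper change above follows the pairing used throughout these kernels: each scalar is packed as a 0-d int32 `mx.array` by `mk_int`, and the `inputs` list mirrors the `input_names` declaration order, so `layer_idx` has to be removed from both lists together. A minimal sketch of that pairing; the numeric values are illustrative, not the repository's defaults:

```python
import mlx.core as mx

def mk_int(val):
    # Scalars are handed to the Metal kernel as 0-d int32 arrays.
    return mx.array(val, dtype=mx.int32)

# Illustrative sizes only; the real ones come from the cache shapes.
block_size, num_heads, head_dim, num_layers, num_blocks = 16, 8, 64, 32, 128

inputs = [
    mk_int(block_size),
    mk_int(num_heads),
    mk_int(head_dim),
    mk_int(num_layers),
    mk_int(num_blocks),
]
input_names = ["block_size", "num_heads", "head_dim", "num_layers", "num_blocks"]

# Dropping layer_idx from one list but not the other would silently shift
# every following scalar, so both change in the same commit.
assert len(inputs) == len(input_names)
```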
7 changes: 2 additions & 5 deletions src/parallax/metal/indexer/q_dot_k.metal
@@ -1,6 +1,6 @@
// Inputs provided by MLX wrapper:
// q, key_cache, block_table, output (pointers)
// context_len, block_size, num_heads, head_dim, layer_idx, num_layers, num_total_blocks, max_blocks (scalars)
// context_len, block_size, num_heads, head_dim, num_layers, num_total_blocks, max_blocks (scalars)

uint3 gid = thread_position_in_grid;

@@ -15,16 +15,13 @@ long k_block_stride = num_heads * block_size * head_dim;
long k_head_stride = block_size * head_dim;
long k_layer_stride = (long)num_total_blocks * k_block_stride;

long layer_offset = (long)layer_idx * k_layer_stride;

for (int b = 0; b < num_valid_blocks; b++) {
int block_num = block_table[b];
int logical_idx = b * block_size + token_in_block;

if (logical_idx >= context_len) continue;

long k_base = layer_offset +
(long)block_num * k_block_stride +
long k_base = (long)block_num * k_block_stride +
head_idx * k_head_stride +
token_in_block * head_dim;

5 changes: 2 additions & 3 deletions src/parallax/metal/indexer/store_key.metal
@@ -1,6 +1,6 @@
// Inputs provided by MLX wrapper:
// key, key_cache, slot_mapping (pointers)
// key_stride, num_heads, head_dim, block_size, layer_idx, num_layers, num_blocks (scalars)
// key_stride, num_heads, head_dim, block_size, num_layers, num_blocks (scalars)

device {{T}} *key_cache_mut = (device {{T}} *)key_cache;

@@ -27,8 +27,7 @@ long k_head_stride = block_size * head_dim;

long k_layer_stride = (long)num_blocks * k_block_stride;

long dest_idx = (long)layer_idx * k_layer_stride +
block_idx * k_block_stride +
long dest_idx = block_idx * k_block_stride +
head_idx * k_head_stride +
block_offset * head_dim +
d_idx;
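Both indexer kernels now address the key cache without a layer term. A minimal Python restatement of the addressing implied by the new stride math, assuming a per-layer cache laid out as (num_blocks, num_heads, block_size, head_dim); that shape is an inference from the strides, not something the diff states:

```python
def key_cache_index(block_num, head_idx, token_in_block, d_idx,
                    num_heads, block_size, head_dim):
    # Same strides as q_dot_k.metal / store_key.metal after this change;
    # the old layer_idx * k_layer_stride term is gone because each layer
    # is assumed to own its own cache buffer.
    k_head_stride = block_size * head_dim
    k_block_stride = num_heads * k_head_stride
    return (block_num * k_block_stride
            + head_idx * k_head_stride
            + token_in_block * head_dim
            + d_idx)
```

store_key.metal uses this as the destination index for writes; q_dot_k.metal uses the same base (without the final d_idx term) when reading keys back for the dot product.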
30 changes: 0 additions & 30 deletions src/parallax/metal/paged_attention/kernel.py
@@ -63,7 +63,6 @@ def reshape_and_cache(
block_tables: mx.array, # (batch, max_blocks)
context_lengths: mx.array, # (batch,)
block_size: int,
layer_idx: int,
slot_mapping: Optional[mx.array] = None, # (batch,) or (batch * target_len,)
):
"""
@@ -130,25 +129,14 @@ def reshape_and_cache(
if slot_mapping.shape[0] != num_tokens:
raise ValueError(f"Slot mapping length {slot_mapping.shape[0]} != tokens {num_tokens}")

num_layers = key_cache.shape[0]
num_blocks = key_cache.shape[1]

# 2. Prepare Constants
key_stride = num_kv_heads * k_head_dim
value_stride = num_kv_heads * v_head_dim

def mk_int(val):
return mx.array(val, dtype=mx.int32)

c_key_stride = mk_int(key_stride)
c_val_stride = mk_int(value_stride)
c_num_kv = mk_int(num_kv_heads)
c_k_head_dim = mk_int(k_head_dim)
c_v_head_dim = mk_int(v_head_dim)
c_block_size = mk_int(block_size)
c_layer_idx = mk_int(layer_idx)
c_num_layers = mk_int(num_layers)
c_num_blocks = mk_int(num_blocks)

# Inputs list
inputs = [
@@ -157,15 +145,10 @@ def mk_int(val):
key_cache,
value_cache,
slot_mapping,
c_key_stride,
c_val_stride,
c_num_kv,
c_k_head_dim,
c_v_head_dim,
c_block_size,
c_layer_idx,
c_num_layers,
c_num_blocks,
]

# Input names (just for declaration)
@@ -175,15 +158,10 @@ def mk_int(val):
"key_cache",
"value_cache",
"slot_mapping",
"key_stride",
"value_stride",
"num_kv_heads",
"k_head_dim",
"v_head_dim",
"block_size",
"layer_idx",
"num_layers",
"num_blocks",
]

# 3. Get and Launch Kernel
@@ -225,7 +203,6 @@ def paged_attention(
block_size: int,
scale: float,
num_kv_heads: int,
layer_idx: int,
v_head_dim: Optional[int] = None,
top_k_indices: Optional[mx.array] = None,
window_size: Optional[int] = None,
@@ -261,7 +238,6 @@ def mk_int(val):
c_v_head_dim = mk_int(v_head_dim)
c_block_size = mk_int(block_size)
c_max_blocks = mk_int(max_blocks)
c_layer_idx = mk_int(layer_idx)
c_num_layers = mk_int(num_layers)
c_num_total_blocks = mk_int(num_total_blocks)
c_scale = mx.array(scale, dtype=queries.dtype)
@@ -283,7 +259,6 @@ def mk_int(val):
c_v_head_dim,
c_block_size,
c_max_blocks,
c_layer_idx,
c_num_layers,
c_num_total_blocks,
c_scale,
@@ -303,7 +278,6 @@ def mk_int(val):
"v_head_dim",
"block_size",
"max_blocks",
"layer_idx",
"num_layers",
"num_total_blocks",
"scale",
@@ -327,7 +301,6 @@ def mk_int(val):
c_v_head_dim,
c_block_size,
c_max_blocks,
c_layer_idx,
c_num_layers,
c_num_total_blocks,
c_scale,
@@ -347,7 +320,6 @@ def mk_int(val):
"v_head_dim",
"block_size",
"max_blocks",
"layer_idx",
"num_layers",
"num_total_blocks",
"scale",
@@ -368,7 +340,6 @@ def mk_int(val):
c_v_head_dim,
c_block_size,
c_max_blocks,
c_layer_idx,
c_num_layers,
c_num_total_blocks,
c_scale,
@@ -386,7 +357,6 @@ def mk_int(val):
"v_head_dim",
"block_size",
"max_blocks",
"layer_idx",
"num_layers",
"num_total_blocks",
"scale",
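In reshape_and_cache the wrapper also stops deriving num_layers and num_blocks from the cache and stops forwarding key_stride, value_stride, and layer_idx, which reads as the caller now handing the kernels a cache that already belongs to a single layer. A hedged sketch of that per-layer view; the shapes, names, and the list-of-caches idea are illustrative assumptions, not code from the repository:

```python
import mlx.core as mx

# Illustrative sizes only.
num_layers, num_blocks, num_kv_heads, block_size, k_head_dim = 2, 4, 2, 16, 64

# One key cache per layer, each (num_blocks, num_kv_heads, block_size, k_head_dim).
# Previously a single cache carried a leading layer dimension and every kernel
# offset into it with layer_idx * layer_stride.
key_caches = [
    mx.zeros((num_blocks, num_kv_heads, block_size, k_head_dim))
    for _ in range(num_layers)
]

layer_idx = 1
key_cache = key_caches[layer_idx]  # the kernel only ever sees this one layer's slice
```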
15 changes: 3 additions & 12 deletions src/parallax/metal/paged_attention/paged_attention.metal
@@ -2,7 +2,7 @@
// Inputs:
// queries, key_cache, value_cache, block_tables, context_lengths
// output (output array)
// num_heads, num_kv_heads, k_head_dim, v_head_dim, block_size, max_blocks, layer_idx,
// num_heads, num_kv_heads, k_head_dim, v_head_dim, block_size, max_blocks,
// num_layers, num_total_blocks, scale (All pointers)

uint3 gid = thread_position_in_grid;
@@ -24,7 +24,6 @@ int _k_head_dim = k_head_dim;
int _v_head_dim = v_head_dim;
int _block_size = block_size;
int _max_blocks = max_blocks;
int _layer_idx = layer_idx;
int _num_total_blocks = num_total_blocks;
float _scale = scale;

@@ -57,29 +56,21 @@ int context_len = context_lengths[batch_idx];
int num_context_blocks = (context_len + _block_size - 1) / _block_size;

// Strides for Key
long k_layer_stride =
(long)_num_total_blocks * _num_kv_heads * _block_size * _k_head_dim;
long k_block_stride = _num_kv_heads * _block_size * _k_head_dim;
long k_head_stride = _block_size * _k_head_dim;

long k_layer_offset = _layer_idx * k_layer_stride;

// Strides for Value
long v_layer_stride =
(long)_num_total_blocks * _num_kv_heads * _block_size * _v_head_dim;
long v_block_stride = _num_kv_heads * _block_size * _v_head_dim;
long v_head_stride = _block_size * _v_head_dim;

long v_layer_offset = _layer_idx * v_layer_stride;

// Iterate over blocks
for (int b = 0; b < num_context_blocks; b++) {
int block_num = block_tables[batch_idx * _max_blocks + b];

long k_block_base =
k_layer_offset + block_num * k_block_stride + kv_head_idx * k_head_stride;
block_num * k_block_stride + kv_head_idx * k_head_stride;
long v_block_base =
v_layer_offset + block_num * v_block_stride + kv_head_idx * v_head_stride;
block_num * v_block_stride + kv_head_idx * v_head_stride;

int tokens_in_block = _block_size;
if (b == num_context_blocks - 1) {
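The decode kernel applies the same simplification to both caches: with per-layer buffers there is no layer stride or layer offset, and the K/V block bases reduce to block and head strides. A small Python restatement of the new base offsets, under the same per-layer layout assumption as above:

```python
def kv_block_bases(block_num, kv_head_idx,
                   num_kv_heads, block_size, k_head_dim, v_head_dim):
    # Mirrors the stride math in paged_attention.metal after this change.
    k_block_stride = num_kv_heads * block_size * k_head_dim
    k_head_stride = block_size * k_head_dim
    v_block_stride = num_kv_heads * block_size * v_head_dim
    v_head_stride = block_size * v_head_dim
    k_block_base = block_num * k_block_stride + kv_head_idx * k_head_stride
    v_block_base = block_num * v_block_stride + kv_head_idx * v_head_stride
    return k_block_base, v_block_base
```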
@@ -1,7 +1,7 @@
// Inputs:
// queries, key_cache, value_cache, block_tables, context_lengths, top_k_indices
// output (output array)
// num_heads, num_kv_heads, k_head_dim, v_head_dim, block_size, max_blocks, layer_idx,
// num_heads, num_kv_heads, k_head_dim, v_head_dim, block_size, max_blocks,
// num_layers, num_total_blocks, scale, index_topk (All pointers)

uint3 gid = thread_position_in_grid;
@@ -23,7 +23,6 @@ int _k_head_dim = k_head_dim;
int _v_head_dim = v_head_dim;
int _block_size = block_size;
int _max_blocks = max_blocks;
int _layer_idx = layer_idx;
int _num_total_blocks = num_total_blocks;
float _scale = scale;
int _index_topk = index_topk;
@@ -56,21 +55,13 @@ float acc_vec[8] = {0.0f};
int context_len = context_lengths[batch_idx];

// Strides for Key
long k_layer_stride =
(long)_num_total_blocks * _num_kv_heads * _block_size * _k_head_dim;
long k_block_stride = _num_kv_heads * _block_size * _k_head_dim;
long k_head_stride = _block_size * _k_head_dim;

long k_layer_offset = _layer_idx * k_layer_stride;

// Strides for Value
long v_layer_stride =
(long)_num_total_blocks * _num_kv_heads * _block_size * _v_head_dim;
long v_block_stride = _num_kv_heads * _block_size * _v_head_dim;
long v_head_stride = _block_size * _v_head_dim;

long v_layer_offset = _layer_idx * v_layer_stride;

// Check if we use Full Attention or Sparse Attention
// We check the first element of top_k_indices for this batch
int first_topk_idx = top_k_indices[batch_idx * _index_topk];
@@ -85,9 +76,9 @@ if (first_topk_idx == -1) {
int block_num = block_tables[batch_idx * _max_blocks + b];

long k_block_base =
k_layer_offset + block_num * k_block_stride + kv_head_idx * k_head_stride;
block_num * k_block_stride + kv_head_idx * k_head_stride;
long v_block_base =
v_layer_offset + block_num * v_block_stride + kv_head_idx * v_head_stride;
block_num * v_block_stride + kv_head_idx * v_head_stride;

int tokens_in_block = _block_size;
if (b == num_context_blocks - 1) {
@@ -140,9 +131,9 @@ if (first_topk_idx == -1) {
int block_num = block_tables[batch_idx * _max_blocks + block_idx_in_table];

long k_block_base =
k_layer_offset + block_num * k_block_stride + kv_head_idx * k_head_stride;
block_num * k_block_stride + kv_head_idx * k_head_stride;
long v_block_base =
v_layer_offset + block_num * v_block_stride + kv_head_idx * v_head_stride;
block_num * v_block_stride + kv_head_idx * v_head_stride;

// Compute Dot Product Q * K[token_idx]
float score = 0.0f;
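The top-k variant keeps its dispatch rule: the kernel reads the first top_k_indices entry for the sequence and falls back to dense attention over all context blocks when it is -1. A hedged Python sketch of that control flow; the function name and return values are illustrative only:

```python
def attention_mode(top_k_indices, batch_idx, index_topk, context_len, block_size):
    # A leading -1 is the sentinel for "no sparse index for this sequence",
    # so the kernel walks every context block (dense path); otherwise it
    # only visits the tokens named in top_k_indices (sparse path).
    first_topk_idx = top_k_indices[batch_idx * index_topk]
    if first_topk_idx == -1:
        num_context_blocks = (context_len + block_size - 1) // block_size
        return "dense", num_context_blocks
    return "sparse", index_topk
```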
15 changes: 3 additions & 12 deletions src/parallax/metal/paged_attention/paged_attention_gpt_oss.metal
@@ -2,7 +2,7 @@
// Inputs:
// queries, key_cache, value_cache, block_tables, context_lengths, sinks
// output (output array)
// num_heads, num_kv_heads, k_head_dim, v_head_dim, block_size, max_blocks, layer_idx,
// num_heads, num_kv_heads, k_head_dim, v_head_dim, block_size, max_blocks,
// num_layers, num_total_blocks, scale, window_size (All pointers)

uint3 gid = thread_position_in_grid;
@@ -24,7 +24,6 @@ int _k_head_dim = k_head_dim;
int _v_head_dim = v_head_dim;
int _block_size = block_size;
int _max_blocks = max_blocks;
int _layer_idx = layer_idx;
int _num_total_blocks = num_total_blocks;
float _scale = scale;
int _window_size = window_size;
@@ -78,21 +77,13 @@ int context_len = context_lengths[batch_idx];
int num_context_blocks = (context_len + _block_size - 1) / _block_size;

// Strides for Key
long k_layer_stride =
(long)_num_total_blocks * _num_kv_heads * _block_size * _k_head_dim;
long k_block_stride = _num_kv_heads * _block_size * _k_head_dim;
long k_head_stride = _block_size * _k_head_dim;

long k_layer_offset = _layer_idx * k_layer_stride;

// Strides for Value
long v_layer_stride =
(long)_num_total_blocks * _num_kv_heads * _block_size * _v_head_dim;
long v_block_stride = _num_kv_heads * _block_size * _v_head_dim;
long v_head_stride = _block_size * _v_head_dim;

long v_layer_offset = _layer_idx * v_layer_stride;

// Iterate over blocks
for (int b = 0; b < num_context_blocks; b++) {
int block_num = block_tables[batch_idx * _max_blocks + b];
@@ -117,9 +108,9 @@ for (int b = 0; b < num_context_blocks; b++) {
}

long k_block_base =
k_layer_offset + block_num * k_block_stride + kv_head_idx * k_head_stride;
block_num * k_block_stride + kv_head_idx * k_head_stride;
long v_block_base =
v_layer_offset + block_num * v_block_stride + kv_head_idx * v_head_stride;
block_num * v_block_stride + kv_head_idx * v_head_stride;

int tokens_in_block = _block_size;
if (b == num_context_blocks - 1) {