Commit 03b5f94

[V1][Spec Decode] Optimize Medusa proposer to avoid GPU-CPU sync (#29723)
Signed-off-by: dongbo910220 <1275604947@qq.com>
1 parent 2e7054d commit 03b5f94

File tree

1 file changed: vllm/v1/spec_decode/medusa.py (+6, -6 lines changed)

vllm/v1/spec_decode/medusa.py

Lines changed: 6 additions & 6 deletions
@@ -38,16 +38,16 @@ def propose(
         self,
         target_hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
-    ) -> list[list[int]]:
+    ) -> torch.Tensor:
         # Generate blocks and compute logits
         blocks = self.model(target_hidden_states)
         logits = self.model.compute_logits(blocks)
 
-        # Get draft tokens and transpose the result
-        # TODO(woosuk): OPTIMIZATION: Return GPU tensor without GPU-CPU
-        # synchronization.
-        draft_tokens = [logit.argmax(dim=-1).tolist() for logit in logits]
-        return [list(row) for row in zip(*draft_tokens)]
+        # Compute argmax for each Medusa head and stack into a single tensor
+        # Shape: [batch_size, num_heads]
+        draft_tokens = torch.stack([logit.argmax(dim=-1) for logit in logits], dim=1)
+
+        return draft_tokens
 
     def load_model(self, target_model: nn.Module) -> None:
         from vllm.compilation.backends import set_model_tag
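
Why this avoids a GPU-CPU sync, as a minimal standalone sketch (not part of the commit; the tensor sizes and random logits below are assumptions for illustration): calling .tolist() on a CUDA tensor copies it to the host and blocks until the GPU finishes, once per Medusa head, whereas stacking the per-head argmaxes keeps the result on the device.

import torch

# Stand-ins for the per-head logits that self.model.compute_logits(blocks)
# would produce: one [batch_size, vocab_size] tensor per Medusa head.
# (Hypothetical sizes, chosen only for illustration.)
num_heads, batch_size, vocab_size = 3, 4, 128
logits = [torch.randn(batch_size, vocab_size) for _ in range(num_heads)]

# Old path: .tolist() moves each argmax result to the host. On a CUDA
# tensor this blocks until the kernel finishes, i.e. one GPU-CPU
# synchronization per head, plus Python-list bookkeeping for the transpose.
draft_tokens_old = [logit.argmax(dim=-1).tolist() for logit in logits]
transposed_old = [list(row) for row in zip(*draft_tokens_old)]

# New path: stack the per-head argmaxes into one [batch_size, num_heads]
# tensor. Everything stays on the device; the caller decides when (or
# whether) to transfer the result to the CPU.
draft_tokens_new = torch.stack([logit.argmax(dim=-1) for logit in logits], dim=1)

assert draft_tokens_new.shape == (batch_size, num_heads)
# Both paths propose the same tokens; only where they live differs.
assert draft_tokens_new.tolist() == transposed_old

Returning the tensor defers the device-to-host copy to the caller, which can batch it with other transfers or skip it entirely; callers that previously consumed list[list[int]] now receive a [batch_size, num_heads] tensor instead.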
