File tree Expand file tree Collapse file tree 1 file changed +6
-6
lines changed
Expand file tree Collapse file tree 1 file changed +6
-6
def propose(
    self,
    target_hidden_states: torch.Tensor,
    sampling_metadata: "SamplingMetadata",
) -> torch.Tensor:
    """Propose draft tokens from the target model's hidden states.

    Runs the Medusa heads over the target hidden states and greedily
    (argmax) selects one token per head.

    Returns:
        Tensor of shape [batch_size, num_heads]: one draft token per
        Medusa head for each sequence in the batch.
    """
    # Forward pass through the Medusa blocks, then project each head's
    # output to vocabulary logits (one logits tensor per head).
    head_outputs = self.model(target_hidden_states)
    head_logits = self.model.compute_logits(head_outputs)

    # Greedy-pick a token id per head, then stack head-wise so every
    # row holds the draft tokens for a single sequence.
    per_head_tokens = [head.argmax(dim=-1) for head in head_logits]
    return torch.stack(per_head_tokens, dim=1)
5151
5252 def load_model (self , target_model : nn .Module ) -> None :
5353 from vllm .compilation .backends import set_model_tag
You can’t perform that action at this time.
0 commit comments