
Commit 4c2e10e

[Bugfix] Fix cuda graph sizes when running with speculative decoding (#30330)
Signed-off-by: Patryk Saffer <patryk.saffer99@gmail.com>
Signed-off-by: PatrykSaffer <patryk.saffer@mistral.ai>
Co-authored-by: Patryk Saffer <patryk.saffer99@gmail.com>
1 parent 03b5f94 commit 4c2e10e

1 file changed: +7 -1 lines changed

vllm/config/vllm.py

Lines changed: 7 additions & 1 deletion

@@ -1047,8 +1047,14 @@ def _set_cudagraph_sizes(self):
             self.compilation_config.max_cudagraph_capture_size
         )
         if max_cudagraph_capture_size is None:
+            decode_query_len = 1
+            if (
+                self.speculative_config
+                and self.speculative_config.num_speculative_tokens
+            ):
+                decode_query_len += self.speculative_config.num_speculative_tokens
             max_cudagraph_capture_size = min(
-                self.scheduler_config.max_num_seqs * 2, 512
+                self.scheduler_config.max_num_seqs * decode_query_len * 2, 512
             )
         max_num_tokens = self.scheduler_config.max_num_batched_tokens
         max_cudagraph_capture_size = min(max_num_tokens, max_cudagraph_capture_size)
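For context, here is a minimal standalone sketch of the sizing rule after this change. The function name and direct parameters are hypothetical stand-ins for the values the real code reads from self.scheduler_config, self.compilation_config, and self.speculative_config:

def compute_max_cudagraph_capture_size(
    max_num_seqs: int,
    max_num_batched_tokens: int,
    num_speculative_tokens: int | None = None,
) -> int:
    # With speculative decoding, each decode step carries the verified
    # token plus the draft tokens, so every sequence contributes more
    # than one query token per step.
    decode_query_len = 1
    if num_speculative_tokens:
        decode_query_len += num_speculative_tokens

    # Scale the capture size by the per-step query length, cap it at 512,
    # then clamp it to the scheduler's token budget.
    size = min(max_num_seqs * decode_query_len * 2, 512)
    return min(max_num_batched_tokens, size)

# Example: 64 sequences with 4 speculative tokens per step.
# Before the fix: min(64 * 2, 512) = 128 captured tokens, too few to
# cover a full speculative decode batch.
# After the fix:  min(64 * (1 + 4) * 2, 512) = 512.
print(compute_max_cudagraph_capture_size(64, 8192, 4))  # 512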
