
Commit 16b24e7

[Bugfix] Bandaid fix for speculative decoding tests (#9327)
1 parent: f519902

File tree

1 file changed: +18 -3 lines changed

vllm/worker/model_runner.py

Lines changed: 18 additions & 3 deletions
```diff
@@ -17,6 +17,7 @@
 import vllm.envs as envs
 from vllm.attention import AttentionMetadata, get_attn_backend
 from vllm.attention.backends.abstract import AttentionState
+from vllm.attention.backends.utils import CommonAttentionState
 from vllm.compilation.compile_context import set_compile_context
 from vllm.compilation.levels import CompilationLevel
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
@@ -1001,16 +1002,30 @@ def __init__(
         self.graph_block_tables = np.zeros(
             (self.max_batchsize_to_capture, self.get_max_block_per_batch()),
             dtype=np.int32)
+
+        # Attention-free but stateful models like Mamba need a placeholder attn
+        # backend, as the attention metadata is needed to manage internal state.
+        # However we must bypass attention selection altogether for some models
+        # used for speculative decoding to avoid a divide-by-zero in
+        # model_config.get_head_size()
+        num_attn_heads = self.model_config.get_num_attention_heads(
+            self.parallel_config)
+        needs_attn_backend = (num_attn_heads != 0
+                              or self.model_config.is_attention_free)
+
         self.attn_backend = get_attn_backend(
             self.model_config.get_head_size(),
             self.model_config.get_sliding_window(),
             self.model_config.dtype,
             self.kv_cache_dtype,
             self.block_size,
             self.model_config.is_attention_free,
-        )
-        self.attn_state = self.attn_backend.get_state_cls()(
-            weakref.proxy(self))
+        ) if needs_attn_backend else None
+        if self.attn_backend:
+            self.attn_state = self.attn_backend.get_state_cls()(
+                weakref.proxy(self))
+        else:
+            self.attn_state = CommonAttentionState(weakref.proxy(self))
 
         # Multi-modal data support
         self.input_registry = input_registry
```
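Why the gate is needed: `ModelConfig.get_head_size()` effectively divides the hidden size by the number of attention heads, so a draft model with zero attention heads (as exercised by the speculative decoding tests) raises `ZeroDivisionError` during backend selection. Below is a minimal, self-contained sketch of the gating logic; `ToyModelConfig` and `select_attn_backend` are hypothetical stand-ins for illustration, not vLLM's actual API.

```python
from dataclasses import dataclass


@dataclass
class ToyModelConfig:
    """Hypothetical stand-in for vLLM's ModelConfig (illustration only)."""
    hidden_size: int
    num_attention_heads: int
    is_attention_free: bool

    def get_head_size(self) -> int:
        # Mirrors the failure mode: dividing by zero heads raises
        # ZeroDivisionError, which is what this commit works around.
        return self.hidden_size // self.num_attention_heads


def select_attn_backend(cfg: ToyModelConfig) -> str | None:
    """Simplified sketch of the gate this commit adds."""
    needs_attn_backend = (cfg.num_attention_heads != 0
                          or cfg.is_attention_free)
    if not needs_attn_backend:
        # Bypass backend selection entirely so get_head_size() is never called.
        return None
    return f"backend(head_size={cfg.get_head_size()})"


# Zero-head draft model, as in the failing tests: no crash, no backend.
assert select_attn_backend(
    ToyModelConfig(hidden_size=1024, num_attention_heads=0,
                   is_attention_free=False)) is None

# Ordinary attention model: backend selected as before.
assert select_attn_backend(
    ToyModelConfig(hidden_size=1024, num_attention_heads=16,
                   is_attention_free=False)) == "backend(head_size=64)"
```

Note the two-sided fallback in the real diff: when no backend is selected, `self.attn_state` falls back to `CommonAttentionState`, so attention-free but stateful models (e.g. Mamba) still get the placeholder attention metadata they rely on, while zero-head draft models skip attention setup entirely.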
