|
17 | 17 | import vllm.envs as envs
|
18 | 18 | from vllm.attention import AttentionMetadata, get_attn_backend
|
19 | 19 | from vllm.attention.backends.abstract import AttentionState
|
| 20 | +from vllm.attention.backends.utils import CommonAttentionState |
20 | 21 | from vllm.compilation.compile_context import set_compile_context
|
21 | 22 | from vllm.compilation.levels import CompilationLevel
|
22 | 23 | from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
|
@@ -1001,16 +1002,30 @@ def __init__(
|
1001 | 1002 | self.graph_block_tables = np.zeros(
|
1002 | 1003 | (self.max_batchsize_to_capture, self.get_max_block_per_batch()),
|
1003 | 1004 | dtype=np.int32)
|
| 1005 | + |
| 1006 | + # Attention-free but stateful models like Mamba need a placeholder attn
| 1007 | + # backend, as the attention metadata is needed to manage internal state.
| 1008 | + # However, we must bypass attention selection altogether for some models
| 1009 | + # used for speculative decoding to avoid a divide-by-zero in
| 1010 | + # model_config.get_head_size().
| 1011 | + num_attn_heads = self.model_config.get_num_attention_heads( |
| 1012 | + self.parallel_config) |
| 1013 | + needs_attn_backend = (num_attn_heads != 0 |
| 1014 | + or self.model_config.is_attention_free) |
| 1015 | + |
1004 | 1016 | self.attn_backend = get_attn_backend(
|
1005 | 1017 | self.model_config.get_head_size(),
|
1006 | 1018 | self.model_config.get_sliding_window(),
|
1007 | 1019 | self.model_config.dtype,
|
1008 | 1020 | self.kv_cache_dtype,
|
1009 | 1021 | self.block_size,
|
1010 | 1022 | self.model_config.is_attention_free,
|
1011 |
| - ) |
1012 |
| - self.attn_state = self.attn_backend.get_state_cls()( |
1013 |
| - weakref.proxy(self)) |
| 1023 | + ) if needs_attn_backend else None |
| 1024 | + if self.attn_backend: |
| 1025 | + self.attn_state = self.attn_backend.get_state_cls()( |
| 1026 | + weakref.proxy(self)) |
| 1027 | + else: |
| 1028 | + self.attn_state = CommonAttentionState(weakref.proxy(self)) |
1014 | 1029 |
|
1015 | 1030 | # Multi-modal data support
|
1016 | 1031 | self.input_registry = input_registry
|
|
0 commit comments