
Commit 2d88586

[KVCache][Bugfix] Fix kv cache initialization error of attention layer (vllm-project#3113)
### What this PR does / why we need it?

Fixes vllm-project#3096

1. Fix the kv cache initialization error of the attention layer. Some models name their attention layers like `attn.attn` instead of `self_attn`, but the initialization of kv cache tensors only checked for `self_attn` and `linear_attn`, leading to the error `AssertionError: Some layers are not correctly initialized`.
2. Set a default value for the input arg `sampling_metadata` in `compute_logits` for the modeling files in vllm-ascend. This fixes the error `Qwen3NextForCausalLM.compute_logits() missing 1 required positional argument: 'sampling_metadata'`.

### Does this PR introduce _any_ user-facing change?

N/A

### How was this patch tested?

Tested locally with internlm.

- vLLM version: v0.10.2
- vLLM main: vllm-project/vllm@5aeb925

Signed-off-by: MengqingCao <cmq0113@163.com>
1 parent 6aa4253 commit 2d88586
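
The core of fix 1 is a broader layer-name check. Below is a minimal sketch of the before/after predicate; the helper names and layer names are illustrative, not part of the patch:

```python
def matches_kv_attn_old(layer_name: str) -> bool:
    # Before the fix: only "self_attn"-style names were recognized.
    return "self_attn" in layer_name


def matches_kv_attn_new(layer_name: str) -> bool:
    # After the fix: any "attn" layer is recognized, except Mamba-style
    # "linear_attn" layers, which receive their own state tensors.
    return "attn" in layer_name and "linear_attn" not in layer_name


# Models such as internlm expose layers named like "attn.attn"; the old
# predicate missed them, leaving their kv cache tensors unallocated:
assert not matches_kv_attn_old("model.layers.0.attn.attn")
assert matches_kv_attn_new("model.layers.0.attn.attn")
assert matches_kv_attn_new("model.layers.0.self_attn.attn")
assert not matches_kv_attn_new("model.layers.0.linear_attn")
```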

6 files changed, +10 −8 lines changed

vllm_ascend/models/deepseek_mtp.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -166,7 +166,7 @@ def forward(
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata,  # type: ignore
+        sampling_metadata=None,  # type: ignore
         spec_step_idx: int = 0,
     ) -> torch.Tensor:
         current_step_idx = (spec_step_idx % self.num_mtp_layers)
```
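
Fix 2 is the same one-line change across five modeling files: the caller may invoke `compute_logits` with only the hidden states. A minimal stand-in class (not the real model) shows the calling pattern that used to fail:

```python
import torch


class TinyModel:
    """Stand-in for the vllm-ascend models patched in this commit."""

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata=None) -> torch.Tensor:
        # The real models forward both arguments to their logits
        # processor, which tolerates sampling_metadata being None.
        return hidden_states


# Before the fix, omitting the second argument raised:
#   TypeError: compute_logits() missing 1 required positional
#   argument: 'sampling_metadata'
logits = TinyModel().compute_logits(torch.randn(2, 8))
```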

vllm_ascend/models/qwen3_next.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -986,7 +986,7 @@ def get_mamba_state_shape_from_config(
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata,  # type: ignore
+        sampling_metadata=None,  # type: ignore
     ) -> Optional[torch.Tensor]:
         return self.logits_processor(self.lm_head, hidden_states,
                                      sampling_metadata)
```

vllm_ascend/torchair/models/qwen2.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -344,7 +344,7 @@ def forward(
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata,  # type: ignore
+        sampling_metadata=None,  # type: ignore
     ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
```

vllm_ascend/torchair/models/torchair_deepseek_mtp.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -170,7 +170,7 @@ def forward(
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata,  # type: ignore
+        sampling_metadata=None,  # type: ignore
         spec_step_idx: int = 0,
     ) -> torch.Tensor:
         current_step_idx = (spec_step_idx % self.num_mtp_layers)
```

vllm_ascend/torchair/models/torchair_pangu_moe.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -936,7 +936,7 @@ def forward(
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata,  # type: ignore
+        sampling_metadata=None,  # type: ignore
     ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
```

vllm_ascend/worker/model_runner_v1.py

Lines changed: 5 additions & 3 deletions

```diff
@@ -2784,9 +2784,10 @@ def initialize_kv_cache_tensors(
             for idx in range(len(kv_cache_tensor.shared_by)):
                 layer_name = kv_cache_tensor.shared_by[idx]
                 if "linear_attn" in layer_name:
+                    # for mamba linear attention
                     for layer_name_inner in kv_cache_tensor.shared_by:
-                        if "self_attn" in layer_name_inner or layer_name_inner in kv_cache_raw_tensors.keys(
-                        ):
+                        if ("attn" in layer_name_inner and "linear_attn" not in layer_name_inner) or \
+                            layer_name_inner in kv_cache_raw_tensors.keys():
                             continue
                         if self.vllm_config.kv_transfer_config is None:
                             tensor = torch.zeros(kv_cache_tensor.size,
@@ -2800,7 +2801,8 @@ def initialize_kv_cache_tensors(
                             tensor = self._align_memory(
                                 tensor, alignment)[:kv_cache_tensor.size]
                         kv_cache_raw_tensors[layer_name_inner] = tensor
-                elif "self_attn" in layer_name:
+                elif "attn" in layer_name:
+                    # for other attentions, e.g., self_attn, sliding window attn
                     if self.vllm_config.kv_transfer_config is None:
                         k_tensor = torch.zeros(kv_cache_tensor.size // 2,
                                                dtype=torch.int8,
```
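
Taken together, the allocation loop now routes layers by name roughly as follows. This is a simplified sketch with illustrative layer names; the real code additionally handles memory alignment and `kv_transfer_config`:

```python
def classify(layer_name: str) -> str:
    # Simplified mirror of the branch structure in the diff above.
    if "linear_attn" in layer_name:
        return "mamba state tensor"   # one contiguous state buffer
    elif "attn" in layer_name:
        return "kv cache tensor"      # split into k and v halves
    return "no cache"


for name in ("model.layers.0.linear_attn",
             "model.layers.1.self_attn.attn",
             "model.layers.2.attn.attn"):  # matched only after this fix
    print(name, "->", classify(name))
```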
