Commits
40 commits
2a21f14
[hybrid kv] init support of hybrid kv
MengqingCao Sep 11, 2025
90dc42f
update
MengqingCao Sep 11, 2025
5b81e48
init kv cache pass
MengqingCao Sep 11, 2025
c8b02fc
update
MengqingCao Sep 12, 2025
82c472e
GDN support
Angazenn Sep 12, 2025
4fe8cf9
Update qwen3 moe
wangxiyuan Sep 12, 2025
648a3fc
fix registry bugs && remove print
Angazenn Sep 12, 2025
ea89e5d
repeat qk
Angazenn Sep 12, 2025
9832095
fix recurrent_gated_delta_rule
Angazenn Sep 12, 2025
47ef6e6
fix
Angazenn Sep 12, 2025
fd3e0dd
fix seq_lens
Angazenn Sep 12, 2025
0a2beb4
local
Sep 13, 2025
f62dda3
local 2
zzzzwwjj Sep 13, 2025
98a3dbb
local 3
zzzzwwjj Sep 14, 2025
4c9d94f
add vllm patch
wangxiyuan Sep 14, 2025
c0b8af7
fix may_reinitialize_input_batch bug
wangxiyuan Sep 14, 2025
c78e773
revert unnecessary
wangxiyuan Sep 14, 2025
8214f81
fix lint
wangxiyuan Sep 14, 2025
f57e114
fix lint
wangxiyuan Sep 14, 2025
1e79ee1
fix lint
wangxiyuan Sep 14, 2025
1b103cb
fix lint
wangxiyuan Sep 14, 2025
136583f
fix patch
wangxiyuan Sep 14, 2025
c1a99c6
fix lint
wangxiyuan Sep 14, 2025
cbc577e
fix torchair
wangxiyuan Sep 14, 2025
e92d690
fix lint
wangxiyuan Sep 14, 2025
2e58820
fix lint
wangxiyuan Sep 14, 2025
9e788d6
fix ut and e2e
wangxiyuan Sep 14, 2025
87c7822
fix ut and add fla ops
wangxiyuan Sep 14, 2025
b25d549
fix attention get_supported_block_size error
wangxiyuan Sep 14, 2025
384f36c
fix kv cache blcok shape
wangxiyuan Sep 14, 2025
b4a3566
fix deepseek oom
wangxiyuan Sep 14, 2025
706348d
fix deepseek oom
wangxiyuan Sep 14, 2025
252ddb2
fix deepseek func
wangxiyuan Sep 14, 2025
8bee3d7
fix ds
wangxiyuan Sep 14, 2025
88a689d
fix deepseek
wangxiyuan Sep 14, 2025
81bb245
fix deepseek mla
wangxiyuan Sep 14, 2025
af06aed
remove patch way
wangxiyuan Sep 15, 2025
63265f6
make code clean
wangxiyuan Sep 15, 2025
cd2b4b8
disable sharing kvcache in group && modify _convert_physical_to_logic…
MengqingCao Sep 15, 2025
12a0c58
[bugfix] fix torchair and mtp functionality
linfeng-yuan Sep 15, 2025
21 changes: 14 additions & 7 deletions tests/ut/attention/test_attention_v1.py
@@ -72,7 +72,8 @@ def setUp(self):
self.mock_vllm_config.model_config.max_model_len = 640
self.mock_vllm_config.cache_config.block_size = 64
self.mock_device = 'cpu:0'
self.builder = AscendAttentionMetadataBuilder(self.mock_vllm_config,
self.builder = AscendAttentionMetadataBuilder(None, None,
self.mock_vllm_config,
self.mock_device)

def test_reorder_batch(self):
@@ -105,14 +106,16 @@ def test_build_prefill_no_cache(self, mock_is_310p, mock_nd_to_nz_2d,
positions=torch.tensor([10, 10]),
attn_mask=torch.ones((10, 10)),
spec_attn_mask=None,
attn_state=AscendAttentionState.PrefillNoCache)
attn_state=AscendAttentionState.PrefillNoCache,
num_computed_tokens_cpu=None,
seq_lens=None)

mock_nz_tensor = MagicMock()
mock_model = MagicMock()
mock_nd_to_nz_2d.return_value = mock_nz_tensor
mock_npu_format_cast.return_value = mock_nz_tensor

self.builder.build(common_attn_metadata, mock_model)
self.builder.build(1, common_attn_metadata, mock_model)

@patch('vllm_ascend.attention.attention_v1.AscendMetadata')
@patch('torch_npu.npu_format_cast')
@@ -136,7 +139,9 @@ def test_build_chunked_prefill(self, mock_ascend_attention_state,
positions=torch.tensor([10, 10]),
attn_mask=torch.ones((15, 15)),
spec_attn_mask=None,
attn_state=AscendAttentionState.ChunkedPrefill)
attn_state=AscendAttentionState.ChunkedPrefill,
num_computed_tokens_cpu=None,
seq_lens=None)

mock_ascend_attention_state = MagicMock()
mock_ascend_attention_state.PrefillNoCache = 0
Expand All @@ -146,7 +151,7 @@ def test_build_chunked_prefill(self, mock_ascend_attention_state,
mock_nd_to_nz_spec.return_value = mock_nz_tensor
mock_npu_format_cast.return_value = mock_nz_tensor

self.builder.build(common_attn_metadata, mock_model)
self.builder.build(1, common_attn_metadata, mock_model)

@patch('vllm_ascend.attention.attention_v1.AscendMetadata')
@patch('vllm_ascend.attention.attention_v1.is_310p', return_value=False)
@@ -165,10 +170,12 @@ def test_build_non_310p(self, mock_is_310p, mock_ascend_metadata):
positions=torch.tensor([10, 10]),
attn_mask=torch.ones((15, 15)),
spec_attn_mask=None,
attn_state=AscendAttentionState.ChunkedPrefill)
attn_state=AscendAttentionState.ChunkedPrefill,
num_computed_tokens_cpu=None,
seq_lens=None)
mock_model = MagicMock()

self.builder.build(common_attn_metadata, mock_model)
self.builder.build(1, common_attn_metadata, mock_model)


class TestAscendAttentionBackendImpl(TestBase):
6 changes: 4 additions & 2 deletions tests/ut/attention/test_mla_v1.py
@@ -189,7 +189,8 @@ def test_ascend_mla_metadata_builder_default(self):
ascend_config = MagicMock()
with patch("vllm_ascend.attention.mla_v1.get_ascend_config",
return_value=ascend_config):
builder = AscendMLAMetadataBuilder(mock_vllm_config, mock_device)
builder = AscendMLAMetadataBuilder(None, None, mock_vllm_config,
mock_device)

self.assertEqual(builder.block_size,
mock_vllm_config.cache_config.block_size)
@@ -209,7 +210,8 @@ def test_reorder_batch(self):

with patch("vllm_ascend.attention.mla_v1.get_ascend_config",
return_value=ascend_config):
builder = AscendMLAMetadataBuilder(mock_vllm_config, mock_device)
builder = AscendMLAMetadataBuilder(None, None, mock_vllm_config,
mock_device)
builder.decode_threshold = 1

input_batch = MagicMock()
28 changes: 20 additions & 8 deletions tests/ut/torchair/test_torchair_mla.py
@@ -195,7 +195,8 @@ def test_ascend_mla_metadata_builder_default(self):
ascend_config.torchair_graph_config.enabled = True
with patch("vllm_ascend.torchair.torchair_mla.get_ascend_config",
return_value=ascend_config):
builder = AscendMLATorchairMetadataBuilder(mock_vllm_config,
builder = AscendMLATorchairMetadataBuilder(None, None,
mock_vllm_config,
mock_device)

self.assertEqual(builder.block_size,
@@ -216,7 +217,8 @@ def test_reorder_batch_with_torchair_graph(self, ascend_config):
ascend_config.torchair_graph_config = MagicMock()
ascend_config.torchair_graph_config.enabled = True

builder = AscendMLATorchairMetadataBuilder(mock_vllm_config,
builder = AscendMLATorchairMetadataBuilder(None, None,
mock_vllm_config,
mock_device)

input_batch = MagicMock()
@@ -252,7 +254,8 @@ def test_reorder_batch_without_torchair_graph(self):

with patch("vllm_ascend.torchair.torchair_mla.get_ascend_config",
return_value=ascend_config):
builder = AscendMLATorchairMetadataBuilder(mock_vllm_config,
builder = AscendMLATorchairMetadataBuilder(None, None,
mock_vllm_config,
mock_device)

input_batch = MagicMock()
@@ -285,7 +288,8 @@ def test_get_graph_runner_block_tables_normal(self, mock_ascend_config):
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
mock_device = 'cpu'

builder = AscendMLATorchairMetadataBuilder(mock_vllm_config,
builder = AscendMLATorchairMetadataBuilder(None, None,
mock_vllm_config,
mock_device)
block_tables = torch.randint(0, 100, (3, 10), dtype=torch.int32)

@@ -305,7 +309,8 @@ def test_get_graph_runner_block_tables_truncated(self, mock_ascend_config):
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
mock_device = 'cpu'

builder = AscendMLATorchairMetadataBuilder(mock_vllm_config,
builder = AscendMLATorchairMetadataBuilder(None, None,
mock_vllm_config,
mock_device)
block_tables = torch.randint(0, 100, (3, 10), dtype=torch.int32)

@@ -326,7 +331,8 @@ def test_get_graph_runner_block_tables_from_numpy(self,
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
mock_device = 'cpu'

builder = AscendMLATorchairMetadataBuilder(mock_vllm_config,
builder = AscendMLATorchairMetadataBuilder(None, None,
mock_vllm_config,
mock_device)

block_tables = torch.randint(0, 100, (3, 10), dtype=torch.int32)
@@ -352,6 +358,8 @@ def test_build_dummy(self, mock_ascend_config):
mock_device = 'cpu'

builder = AscendMLATorchairMetadataBuilder(
None,
None,
mock_vllm_config,
mock_device,
metadata_cls=AscendMLATorchairMetadata)
@@ -417,6 +425,8 @@ def test_build_decode(self, mock_ascend_config):
model.model = MagicMock(spec=nn.Module)

builder = AscendMLATorchairMetadataBuilder(
None,
None,
mock_vllm_config,
mock_device,
metadata_cls=AscendMLATorchairMetadata)
@@ -442,9 +452,11 @@ def test_build_decode(self, mock_ascend_config):
positions=torch.tensor([1, 1]),
attn_mask=torch.ones((15, 15)),
spec_attn_mask=None,
attn_state=AscendAttentionState.ChunkedPrefill)
attn_state=AscendAttentionState.ChunkedPrefill,
num_computed_tokens_cpu=None,
seq_lens=None)

metadata = builder.build(common_attn_metadata, model)
metadata = builder.build(1, common_attn_metadata, model)

self.assertIsInstance(metadata, AscendMLATorchairMetadata)
self.assertEqual(metadata.num_input_tokens, 0)
2 changes: 1 addition & 1 deletion tests/ut/worker/test_input_batch.py
@@ -24,8 +24,8 @@
from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.logits_processor import LogitsProcessors
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.worker.block_table import BlockTable, MultiGroupBlockTable

from vllm_ascend.worker.block_table import BlockTable, MultiGroupBlockTable
from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch

VOCAB_SIZE = 1024
25 changes: 16 additions & 9 deletions vllm_ascend/attention/attention_v1.py
@@ -17,7 +17,7 @@

from dataclasses import dataclass
from enum import Enum
from typing import List, Optional, Tuple, Type
from typing import ClassVar, List, Optional, Tuple, Type

import torch
import torch.nn as nn
@@ -32,12 +32,12 @@
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.utils import cdiv, direct_register_custom_op
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import AttentionSpec

from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
from vllm_ascend.ops.attention import vanilla_chunked_prefill
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
nd_to_nz_2d, nd_to_nz_spec)
from vllm_ascend.worker.npu_input_batch import InputBatch


def wait_for_kv_layer_from_connector(layer_name: str):
@@ -145,6 +145,10 @@ def copy_blocks(
key_caches[dst_indices] = key_caches[src_indices]
value_caches[dst_indices] = value_caches[src_indices]

@staticmethod
def get_supported_block_size() -> list[int]:
return [64]


class AscendAttentionState(Enum):
PrefillNoCache = 0
@@ -193,24 +197,29 @@ class AscendMetadata:


class AscendAttentionMetadataBuilder:
reorder_batch_threshold: ClassVar[int] = 1

def __init__(
self,
kv_cache_spec: AttentionSpec,
layer_names: list[str],
vllm_config: VllmConfig,
device: torch.device,
):
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.device = device
self.max_num_blocks_per_req = cdiv(self.model_config.max_model_len,
vllm_config.cache_config.block_size)
self.max_num_blocks_per_req = cdiv(
self.model_config.max_model_len,
AscendAttentionBackend.get_supported_block_size()[0])

def reorder_batch(self, input_batch: "InputBatch",
def reorder_batch(self, input_batch,
scheduler_output: "SchedulerOutput") -> bool:
return False

def build(
self,
common_prefix_len: int,
common_attn_metadata: AscendCommonAttentionMetadata,
model: nn.Module,
):
@@ -219,11 +228,7 @@ def build(
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:
num_reqs
+ 1]

block_table = common_attn_metadata.block_table_tensor
block_table[:num_reqs, :self.max_num_blocks_per_req] = (
block_table[:num_reqs])

query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
slot_mapping = common_attn_metadata.slot_mapping_cpu[:
@@ -574,6 +579,8 @@ def unified_ascend_attention_with_output(
wait_for_kv_layer_from_connector(layer_name)
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[layer_name]
self = forward_context.no_compile_layers[layer_name]
kv_cache = self.kv_cache[forward_context.virtual_engine]
self.impl.forward(self,
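A minimal sketch of the builder interface this file moves to, using toy stand-ins (ToyBackend, ToyBuilder, the hard-coded 64 block size, and the 4096 max model length are illustrative assumptions, not the real vllm-ascend classes): the constructor now also receives the KV-cache spec and the layer names of its group, build() takes common_prefix_len as its first argument, and the blocks-per-request bound is derived from the backend's advertised block size.

# Sketch only: toy stand-ins mirroring the builder changes above; names and the
# 4096 max-model-len are illustrative assumptions, not vllm-ascend code.
from math import ceil


class ToyBackend:
    @staticmethod
    def get_supported_block_size() -> list[int]:
        # The Ascend backend above advertises a single supported block size.
        return [64]


class ToyBuilder:
    # The constructor now also receives the KV-cache spec and the layer names
    # of the group this builder serves, before the config and device.
    def __init__(self, kv_cache_spec, layer_names, vllm_config, device):
        self.kv_cache_spec = kv_cache_spec
        self.layer_names = layer_names
        block_size = ToyBackend.get_supported_block_size()[0]
        # Mirrors cdiv(max_model_len, block_size) in the diff (4096 assumed).
        self.max_num_blocks_per_req = ceil(4096 / block_size)

    # `build` now takes the common prefix length as its first argument.
    def build(self, common_prefix_len, common_attn_metadata, model):
        return {"common_prefix_len": common_prefix_len,
                "max_num_blocks_per_req": self.max_num_blocks_per_req}


builder = ToyBuilder(kv_cache_spec=None, layer_names=["layers.0.attn"],
                     vllm_config=None, device="cpu")
print(builder.build(1, common_attn_metadata=None, model=None))

This is also the call pattern the updated unit tests exercise: builders constructed with two extra leading arguments and build() invoked as build(1, common_attn_metadata, model).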
3 changes: 3 additions & 0 deletions vllm_ascend/attention/mla_v1.py
@@ -171,6 +171,8 @@ class AscendMLAMetadataBuilder:

# _attn_mask_builder = None
def __init__(self,
kv_cache_spec,
layer_names,
vllm_config: VllmConfig,
device: torch.device,
metadata_cls: Optional[AscendMLAMetadata] = None):
@@ -265,6 +267,7 @@ def reorder_batch(self, input_batch: "InputBatch",

def build(
self,
common_prefix_len: int,
common_attn_metadata: AscendCommonAttentionMetadata,
model: nn.Module,
) -> AscendMLAMetadata:
7 changes: 7 additions & 0 deletions vllm_ascend/attention/utils.py
@@ -21,6 +21,13 @@ class AscendCommonAttentionMetadata:
"""(batch_size,), the length of each request including both computed tokens
and newly scheduled tokens"""

seq_lens: torch.Tensor
"""same to seq_lens_cpu, for compatibility with some new attn metadata
(such as GDN)."""

num_computed_tokens_cpu: torch.Tensor
"""(batch_size,), the number of computed tokens for each request"""

num_reqs: int
"""Number of requests"""
num_actual_tokens: int
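A short sketch of how the two new fields slot into the common metadata, using a trimmed-down stand-in dataclass (ToyCommonAttentionMetadata and the literal values are assumptions for illustration; the real AscendCommonAttentionMetadata carries more fields):

# Sketch only: trimmed-down stand-in for AscendCommonAttentionMetadata that
# keeps just the fields discussed here.
from dataclasses import dataclass

import torch


@dataclass
class ToyCommonAttentionMetadata:
    seq_lens_cpu: torch.Tensor
    seq_lens: torch.Tensor                 # new: mirrors seq_lens_cpu for GDN-style metadata
    num_computed_tokens_cpu: torch.Tensor  # new: (batch_size,) computed tokens per request
    num_reqs: int
    num_actual_tokens: int


seq_lens_cpu = torch.tensor([7, 12])
meta = ToyCommonAttentionMetadata(
    seq_lens_cpu=seq_lens_cpu,
    seq_lens=seq_lens_cpu.clone(),
    num_computed_tokens_cpu=torch.tensor([4, 0]),
    num_reqs=2,
    num_actual_tokens=15,
)
print(meta.seq_lens, meta.num_computed_tokens_cpu)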
3 changes: 3 additions & 0 deletions vllm_ascend/models/__init__.py
@@ -53,3 +53,6 @@ def register_model():
"PanguProMoEForCausalLM",
"vllm_ascend.torchair.models.torchair_pangu_moe:PanguProMoEForCausalLM"
)
ModelRegistry.register_model(
"Qwen3NextForCausalLM",
"vllm_ascend.models.qwen3_next:Qwen3NextForCausalLM")
9 changes: 7 additions & 2 deletions vllm_ascend/models/layers/mla.py
@@ -132,8 +132,13 @@ def forward(
output = torch.empty(output_shape,
dtype=hidden_states.dtype,
device=hidden_states.device)
if forward_context.attn_metadata:
attn_metadata = forward_context.attn_metadata[
self.mla_attn.layer_name]
else:
attn_metadata = forward_context.attn_metadata
output = self.mla_attn.impl.forward(hidden_states, kv_cache,
forward_context.attn_metadata,
need_gather_q_kv, output)
attn_metadata, need_gather_q_kv,
output)
output = output.view(-1, output_shape[-1])
return output
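The per-layer lookup added above, restated as a standalone helper so the control flow is easier to see (resolve_layer_metadata is a hypothetical name, not part of the PR): during profiling/dummy runs forward_context.attn_metadata is falsy and is passed through unchanged; otherwise it is a dict keyed by layer name.

# Sketch only: standalone restatement of the per-layer metadata selection.
from typing import Any, Optional


def resolve_layer_metadata(attn_metadata: Optional[dict],
                           layer_name: str) -> Optional[Any]:
    # Profiling / dummy runs carry no metadata: pass the falsy value through.
    # Otherwise pick this layer's entry from the per-layer dict.
    if attn_metadata:
        return attn_metadata[layer_name]
    return attn_metadata


print(resolve_layer_metadata(None, "model.layers.0.self_attn.mla_attn"))
print(resolve_layer_metadata(
    {"model.layers.0.self_attn.mla_attn": "meta"},
    "model.layers.0.self_attn.mla_attn"))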