2 changes: 1 addition & 1 deletion tests/ut/test_platform.py
@@ -444,7 +444,7 @@ def test_get_attn_backend_cls_use_v1_and_torchair(self,
)
self.assertEqual(
result,
"vllm_ascend.attention.attention_v1_torchair.AscendAttentionTorchairBackend"
"vllm_ascend.torchair.torchair_attention.AscendAttentionTorchairBackend"
)

@patch('vllm_ascend.platform.get_ascend_config')
4 changes: 3 additions & 1 deletion vllm_ascend/attention/attention_v1.py
@@ -169,7 +169,9 @@ def build(self,
num_actual_tokens,
max_query_len,
enable_dbo_across_dp: bool = False,
-is_only_prefill: bool = False):
+is_only_prefill: bool = False,
+*args,
+**kwargs):

block_table = self.runner.input_batch.block_table[0].get_device_tensor(
)
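Note: accepting *args/**kwargs in the base build() lets the model runner pass subclass-specific keyword arguments (such as graph_pad_size below) through one call site without breaking the base builder. A minimal sketch of the pattern, using hypothetical BaseBuilder/TorchairBuilder names rather than the actual vllm_ascend classes:

# Minimal sketch of the *args/**kwargs pattern; BaseBuilder and TorchairBuilder
# are hypothetical stand-ins, not the real vllm_ascend builder classes.
class BaseBuilder:
    def build(self, num_reqs, num_actual_tokens, max_query_len,
              enable_dbo_across_dp=False, is_only_prefill=False,
              *args, **kwargs):
        # The base builder simply ignores extra keyword arguments such as
        # graph_pad_size, so a single call site works for every subclass.
        return {"num_reqs": num_reqs, "tokens": num_actual_tokens}

class TorchairBuilder(BaseBuilder):
    def build(self, num_reqs, num_actual_tokens, max_query_len,
              enable_dbo_across_dp=False, is_only_prefill=False,
              *args, **kwargs):
        # The subclass picks up the extra argument when it is present.
        graph_pad_size = kwargs.get("graph_pad_size", -1)
        meta = super().build(num_reqs, num_actual_tokens, max_query_len,
                             enable_dbo_across_dp, is_only_prefill)
        meta["graph_pad_size"] = graph_pad_size
        return meta

# Both builders accept the same call, e.g.:
#   builder.build(4, 128, 32, graph_pad_size=512)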
2 changes: 1 addition & 1 deletion vllm_ascend/platform.py
@@ -221,7 +221,7 @@ def get_attn_backend_cls(cls,
if use_mla:
return "vllm_ascend.attention.mla_v1.AscendMLABackend"
elif use_torchair:
return "vllm_ascend.attention.attention_v1_torchair.AscendAttentionTorchairBackend"
return "vllm_ascend.torchair.torchair_attention.AscendAttentionTorchairBackend"
else:
return "vllm_ascend.attention.attention_v1.AscendAttentionBackend"

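Note: get_attn_backend_cls returns the backend as a fully qualified class-path string rather than a class object, so the caller imports it lazily. A minimal sketch of how such a dotted path can be resolved with importlib (a generic pattern, not necessarily vLLM's exact helper):

import importlib

def resolve_qualname(qualname: str):
    # Split "package.module.ClassName" into module path and attribute,
    # import the module, then fetch the class from it.
    module_name, _, attr = qualname.rpartition(".")
    module = importlib.import_module(module_name)
    return getattr(module, attr)

# e.g. resolve_qualname(
#     "vllm_ascend.torchair.torchair_attention.AscendAttentionTorchairBackend")
# would import the class from the new module location used above.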
vllm_ascend/attention/attention_v1_torchair.py → vllm_ascend/torchair/torchair_attention.py
@@ -21,18 +21,19 @@
import numpy as np
import torch
import torch_npu
-from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
-AttentionLayer, AttentionType)
-from vllm.attention.backends.utils import PAD_SLOT_ID, CommonAttentionState
-from vllm.v1.core.sched.output import SchedulerOutput

-from vllm_ascend.attention.attention_v1 import AscendAttentionState
+from vllm.attention.backends.abstract import (AttentionImpl, AttentionLayer,
+AttentionType)
+from vllm.attention.backends.utils import PAD_SLOT_ID

+from vllm_ascend.attention.attention_v1 import (AscendAttentionBackend,
+AscendAttentionMetadataBuilder,
+AscendAttentionState,
+AscendMetadata)
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
nd_to_nz_2d)
-from vllm_ascend.worker.npu_input_batch import InputBatch


-class AscendAttentionTorchairBackend(AttentionBackend):
+class AscendAttentionTorchairBackend(AscendAttentionBackend):
accept_output_buffer: bool = True

@staticmethod
Expand All @@ -47,10 +48,6 @@ def get_impl_cls() -> Type["AscendAttentionTorchairBackendImpl"]:
def get_metadata_cls() -> Type["AscendTorchairMetadata"]:
return AscendTorchairMetadata

-@staticmethod
-def get_state_cls() -> Type["CommonAttentionState"]:
-return CommonAttentionState

@staticmethod
def get_builder_cls() -> type["AscendAttentionTorchairMetadataBuilder"]:
return AscendAttentionTorchairMetadataBuilder
@@ -73,36 +70,6 @@ def get_bsh_kv_cache_shape(
) -> Tuple[int, ...]:
return (2, num_blocks, block_size, num_kv_heads * head_size)

-@staticmethod
-def swap_blocks(
-src_kv_cache: List[torch.Tensor],
-dst_kv_cache: List[torch.Tensor],
-src_to_dst: torch.Tensor,
-) -> None:
-src_key_cache, src_value_cache = src_kv_cache[0], src_kv_cache[1]
-dst_key_cache, dst_value_cache = dst_kv_cache[0], dst_kv_cache[1]
-src_indices = src_to_dst[:, 0]
-dst_indices = src_to_dst[:, 1]

-dst_key_cache[dst_indices] = src_key_cache[src_indices].to(
-dst_key_cache.device)
-dst_value_cache[dst_indices] = src_value_cache[src_indices].to(
-dst_key_cache.device)

-@staticmethod
-def copy_blocks(
-kv_caches: List[torch.Tensor],
-src_to_dists: torch.Tensor,
-) -> None:
-src_indices = src_to_dists[:, 0]
-dst_indices = src_to_dists[:, 1]

-for kv_cache in kv_caches:
-key_caches = kv_cache[0]
-value_caches = kv_cache[1]
-key_caches[dst_indices] = key_caches[src_indices]
-value_caches[dst_indices] = value_caches[src_indices]


@dataclass
class AscendDecodeMetadata:
@@ -117,40 +84,15 @@ class AscendDecodeMetadata:


@dataclass
-class AscendTorchairMetadata:
-num_actual_tokens: int  # Number of tokens excluding padding.
-# (batch_size, max_blocks_per_seq).
-# Block addresses per sequence. (Seq id -> list of physical block)
-block_tables: torch.Tensor
-# (batch_size,). The sequence length per sequence. Sequence length means
-# the computed tokens + new tokens None if it is a decoding.
-query_start_loc: torch.Tensor
-query_lens: torch.Tensor
-seq_lens: torch.Tensor
-# Maximum query length in the batch. None for decoding.
-max_query_len: Optional[int] = None
-# (num_tokens,). The indices of the token slots that input tokens will be
-# stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
-# is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
-# in block 0, and 1st slot in block 1, respectively.
-slot_mapping: torch.Tensor = None
-# Current state of this attention run.
-attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
-attn_mask: Optional[torch.Tensor] = None
+class AscendTorchairMetadata(AscendMetadata):

decode: Optional[AscendDecodeMetadata] = None

-enable_dbo_across_dp: bool = False


-class AscendAttentionTorchairMetadataBuilder:
+class AscendAttentionTorchairMetadataBuilder(AscendAttentionMetadataBuilder):

def __init__(self, runner):
-self.runner = runner

-def reorder_batch(self, input_batch: "InputBatch",
-scheduler_output: "SchedulerOutput") -> bool:
-return False
+super().__init__(runner)

def _get_graph_runner_block_tables(
self, num_seqs: int, block_tables: torch.Tensor) -> torch.Tensor:
@@ -222,11 +164,16 @@ def build(self,
num_reqs,
num_actual_tokens,
max_query_len,
-graph_pad_size: int = -1,
enable_dbo_across_dp: bool = False,
-is_only_prefill: bool = False):
+is_only_prefill: bool = False,
+*args,
+**kwargs):

+if 'graph_pad_size' in kwargs:
+graph_pad_size = kwargs['graph_pad_size']
+else:
+graph_pad_size = -1  # default value

device = self.runner.device

block_table = self.runner.input_batch.block_table[0].get_device_tensor(
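Note: the if/else above is equivalent to a kwargs.get with a default, so call sites that never pass graph_pad_size keep working. A small self-contained sketch of the behaviour (build_stub is a hypothetical stand-in, not the real builder method):

# Minimal sketch: the if/else above behaves like kwargs.get with a default.
def build_stub(num_reqs: int, num_actual_tokens: int, max_query_len: int,
               *args, **kwargs) -> int:
    # Returns the effective graph_pad_size, -1 when the caller did not pass one.
    return kwargs.get('graph_pad_size', -1)

assert build_stub(4, 128, 32) == -1
assert build_stub(4, 128, 32, graph_pad_size=512) == 512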
2 changes: 1 addition & 1 deletion vllm_ascend/worker/model_runner_v1.py
@@ -78,14 +78,14 @@
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
from vllm_ascend.attention.attention_v1 import (AscendAttentionState,
AscendMetadata)
-from vllm_ascend.attention.attention_v1_torchair import AscendTorchairMetadata
from vllm_ascend.attention.mla_v1 import AscendMLAMetadata
from vllm_ascend.distributed.moe_comm_method import (AllGatherCommImpl,
DummyCommImpl,
MoECommMethod)
from vllm_ascend.multistream.ms_split import compute_split_seq_index
from vllm_ascend.platform import NPUPlatform
from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
+from vllm_ascend.torchair.torchair_attention import AscendTorchairMetadata
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
ProfileExecuteDuration, is_310p,
maybe_converting_weight_acl_format)