
Commit 9ccacde

feat: support v1 engine on 310P
Signed-off-by: Vincent Yuan <farawayboat@gmail.com>
1 parent 7244ebb commit 9ccacde

4 files changed: +76 additions, -7 deletions

vllm_ascend/attention/attention_v1.py

Lines changed: 45 additions & 0 deletions
@@ -30,6 +30,8 @@
 from vllm.v1.worker.gpu_input_batch import InputBatch

 from vllm_ascend.ops.attention import vanilla_chunked_prefill
+from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
+                               nd_to_nz_2d, nd_to_nz_spec)


 class AscendAttentionBackend(AttentionBackend):
@@ -62,6 +64,9 @@ def get_kv_cache_shape(
         num_kv_heads: int,
         head_size: int,
     ) -> Tuple[int, ...]:
+        if is_310p():
+            return (2, num_blocks, num_kv_heads * head_size // 16, block_size,
+                    16)
         return (2, num_blocks, block_size, num_kv_heads, head_size)

     @staticmethod
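
On 310P the KV cache is allocated directly in the 16-aligned fractal (NZ) layout, so the combined num_kv_heads * head_size dimension is split into 16-wide tiles instead of being stored as block_size x num_kv_heads x head_size. A minimal sketch (example sizes are arbitrary and not from the commit) showing that the two shapes describe the same amount of storage:

    # sketch only: element-count check for the ND vs. 310P NZ KV-cache layouts
    from math import prod

    num_blocks, block_size, num_kv_heads, head_size = 512, 128, 8, 128

    nd_shape = (2, num_blocks, block_size, num_kv_heads, head_size)
    nz_shape = (2, num_blocks, num_kv_heads * head_size // 16, block_size, 16)

    assert prod(nd_shape) == prod(nz_shape)  # same storage, different tiling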
@@ -166,6 +171,16 @@ def build(self,
         query_start_loc = query_start_loc_cpu.to(self.runner.device,
                                                  non_blocking=True)

+        if is_310p():
+            if attn_state == AscendAttentionState.PrefillNoCache:
+                mask_nz = nd_to_nz_2d(attn_mask)
+                attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(),
+                                                      ACL_FORMAT_FRACTAL_NZ)
+            elif attn_state == AscendAttentionState.ChunkedPrefill:
+                mask_nz = nd_to_nz_spec(attn_mask)
+                attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(),
+                                                      ACL_FORMAT_FRACTAL_NZ)
+
         attn_metadata = AscendMetadata(
             num_actual_tokens=num_actual_tokens,
             block_tables=block_table,
@@ -249,6 +264,7 @@ def forward(
             self.head_size,
             dtype=query.dtype,
             device=query.device)
+        ori_output = output
         if trace_flag:
             torch.ops.vllm.unified_ascend_attention_with_output(
                 query=query,
@@ -293,6 +309,18 @@ def forward(
             assert attn_metadata is not None
             assert attn_metadata.attn_mask is not None
             mask = attn_metadata.attn_mask
+            if is_310p():
+                # align q k v output tensors
+                query = aligned_16(query)
+                key = aligned_16(key)
+                value = aligned_16(value)
+                output = aligned_16(output)
+
+                # do reformat in case of broadcasted tensors
+                mask = mask.repeat(attn_metadata.seq_lens.size(0), 1, 1, 1)
+                mask = torch_npu.npu_format_cast(mask.contiguous(),
+                                                 ACL_FORMAT_FRACTAL_NZ)
+
             torch_npu._npu_flash_attention(query=query,
                                            key=key,
                                            value=value,
@@ -302,6 +330,7 @@ def forward(
                                            num_heads=self.num_heads,
                                            num_kv_heads=self.num_kv_heads,
                                            out=output)
+            output = output[:num_tokens, :, :]
         elif attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit:
             assert attn_metadata is not None
             assert attn_metadata.attn_mask is not None
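
aligned_16 is only referenced by name in this diff; its body is not shown, only its docstring in vllm_ascend/utils.py below ("Aligned tensor for 310P"). Presumably it pads the leading token dimension up to a multiple of 16, which is why the flash-attention result is sliced back to output[:num_tokens, :, :] afterwards. A hedged sketch of that assumed behavior (the _sketch name is hypothetical):

    # sketch only: assumed behavior of aligned_16 on 310P
    import torch

    def aligned_16_sketch(t: torch.Tensor) -> torch.Tensor:
        padded = (t.size(0) + 15) // 16 * 16
        if padded == t.size(0):
            return t                      # already aligned: return the input as-is
        out = torch.zeros((padded, *t.shape[1:]), dtype=t.dtype, device=t.device)
        out[:t.size(0)] = t               # copy real tokens, leave padding rows zero
        return out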
@@ -319,6 +348,10 @@ def forward(
                 scale_value=self.scale,
                 out=output)
         elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
+            if is_310p():
+                # seq_lens_tensor needs to be transferred to the device for 310P
+                attn_metadata.seq_lens = \
+                    attn_metadata.seq_lens.to(device=query.device)
             torch_npu._npu_paged_attention(
                 query=query,
                 key_cache=self.key_cache,
@@ -352,6 +385,14 @@ def forward(
                 self.scale, None, True)
         else:
             # use paged attention
+            assert attn_metadata is not None
+            assert attn_metadata.attn_mask is not None
+            if is_310p():
+                # do reformat in case of broadcasted tensors
+                attn_metadata.attn_mask = \
+                    torch_npu.npu_format_cast(attn_metadata.attn_mask.contiguous(), ACL_FORMAT_FRACTAL_NZ)
+                attn_metadata.seq_lens = \
+                    attn_metadata.seq_lens.to(device=query.device)
             torch_npu._npu_paged_attention_splitfuse(
                 query=query,
                 key_cache=self.key_cache,
@@ -364,6 +405,10 @@ def forward(
                 num_heads=self.num_heads,
                 scale_value=self.scale,
                 out=output)
+
+        # to make in-place change to the output tensor
+        if not id(ori_output) == id(output):
+            ori_output[:, :, :] = output[:num_tokens, :, :]
         return output.view(num_tokens, self.hidden_size)


vllm_ascend/platform.py

Lines changed: 5 additions & 3 deletions
@@ -27,7 +27,8 @@

 import vllm_ascend.envs as ascend_envs
 from vllm_ascend.ascend_config import check_ascend_config, init_ascend_config
-from vllm_ascend.utils import ASCEND_QUATIZATION_METHOD, update_aclgraph_sizes
+from vllm_ascend.utils import (ASCEND_QUATIZATION_METHOD, is_310p,
+                               update_aclgraph_sizes)

 CUSTOM_OP_ENABLED = False
 try:
@@ -202,8 +203,9 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
             cache_config.block_size = 128

         if envs.VLLM_USE_V1:
-            # Activate custom ops for v1.
-            compilation_config.custom_ops = ["all"]
+            # Activate custom ops for v1, except on 310P
+            if not is_310p():
+                compilation_config.custom_ops = ["all"]

             # If ascend_scheduler_config is enabled,
             # extents original scheduler_config to use AscendScheduler.
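
With this change the V1 path only forces compilation_config.custom_ops = ["all"] on non-310P devices; on 310P the default op implementations are kept. A minimal sketch of the same gate in isolation (the helper name is hypothetical; the real logic lives in check_and_update_config above):

    # sketch only: the custom-ops gate extracted from check_and_update_config
    from typing import List, Optional

    from vllm_ascend.utils import is_310p

    def custom_ops_for_v1(current: Optional[List[str]]) -> Optional[List[str]]:
        # on 310P leave the setting untouched; elsewhere enable all custom ops
        return current if is_310p() else ["all"]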

vllm_ascend/utils.py

Lines changed: 15 additions & 0 deletions
@@ -104,6 +104,21 @@ def nd_to_nz_2d(in_tensor: torch.Tensor) -> torch.Tensor:
                        2).contiguous()


+def nd_to_nz_spec(mask_tensor: torch.Tensor) -> torch.Tensor:
+    num_tokens = mask_tensor.shape[0]
+    max_seq_len = mask_tensor.shape[1]
+
+    tokens_pad = (num_tokens + 15) // 16 * 16
+    max_seq_len_pad = (max_seq_len + 15) // 16 * 16
+
+    mask_tensor_pad = \
+        torch.zeros((1, tokens_pad, max_seq_len_pad), dtype=mask_tensor.dtype, device=mask_tensor.device)
+    mask_tensor_pad[0][:num_tokens, :max_seq_len] = mask_tensor
+    mask = mask_tensor_pad.reshape(
+        (1, tokens_pad, max_seq_len_pad // 16, 16)).permute(0, 2, 1, 3)
+    return mask
+
+
 def aligned_16(tensor: torch.Tensor):
     """Aligned tensor for 310P"""
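
nd_to_nz_spec pads both mask dimensions up to multiples of 16 and retiles the last dimension into 16-wide fragments; it is used for the chunked-prefill mask in attention_v1.py above, while nd_to_nz_2d handles the square prefill mask. A small shape check with arbitrary example sizes, assuming vllm_ascend and its dependencies are importable:

    # sketch only: expected output shape of nd_to_nz_spec for a [num_tokens, max_seq_len] mask
    import torch
    from vllm_ascend.utils import nd_to_nz_spec

    mask = torch.zeros(100, 300, dtype=torch.float16)        # neither dim is 16-aligned
    nz = nd_to_nz_spec(mask)
    # dims are padded to 112 and 304, then the padded seq dim is tiled in 16s
    assert nz.shape == (1, (300 + 15) // 16, (100 + 15) // 16 * 16, 16)  # (1, 19, 112, 16)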

vllm_ascend/worker/model_runner_v1.py

Lines changed: 11 additions & 4 deletions
@@ -68,7 +68,9 @@
 from vllm_ascend.attention.mla_v1 import CommonAttentionMetadata
 from vllm_ascend.platform import NPUPlatform
 from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
-from vllm_ascend.utils import ProfileExecuteDuration, vllm_version_is
+from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
+                               ProfileExecuteDuration, is_310p,
+                               vllm_version_is)
 from vllm_ascend.worker.mtp_proposer_v1 import MtpProposer

 if TYPE_CHECKING:
@@ -1330,6 +1332,8 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
             cache size of each layer
         """
         import torch_npu
+        acl_format = ACL_FORMAT_FRACTAL_NZ if is_310p(
+        ) else ACL_FORMAT_FRACTAL_ND
         kv_caches: Dict[str, torch.Tensor] = {}

         # Remove this after we drop 0.9.0 support
@@ -1404,13 +1408,16 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
                         device=self.device)
                     kv_caches[layer_name] = (layer_kv_cache_nope,
                                              layer_kv_cache_pe)
-                    torch_npu.npu_format_cast(kv_caches[layer_name][0], 2)
-                    torch_npu.npu_format_cast(kv_caches[layer_name][1], 2)
+                    kv_caches[layer_name][0] = \
+                        torch_npu.npu_format_cast(kv_caches[layer_name][0], acl_format)
+                    kv_caches[layer_name][1] = \
+                        torch_npu.npu_format_cast(kv_caches[layer_name][1], acl_format)
                 else:
                     kv_caches[layer_name] = torch.zeros(kv_cache_shape,
                                                         dtype=dtype,
                                                         device=self.device)
-                    torch_npu.npu_format_cast(kv_caches[layer_name], 2)
+                    kv_caches[layer_name] = \
+                        torch_npu.npu_format_cast(kv_caches[layer_name], acl_format)
             else:
                 # TODO: add new branches when introducing more types of
                 # KV cache specs.
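
Two things change here: the ACL format is chosen per SoC (fractal NZ on 310P, plain ND elsewhere, via the constants imported from vllm_ascend.utils), and the result of torch_npu.npu_format_cast is assigned back, since the cast returns a new tensor rather than converting in place (the old code discarded the return value). A minimal usage sketch, assuming an Ascend environment with torch_npu installed; the cache shape is just an example:

    # sketch only: allocate a cache tensor and keep the format-cast result
    import torch
    import torch_npu  # registers the "npu" device

    from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
                                   is_310p)

    acl_format = ACL_FORMAT_FRACTAL_NZ if is_310p() else ACL_FORMAT_FRACTAL_ND
    cache = torch.zeros((2, 64, 128, 8, 128), dtype=torch.float16, device="npu")
    cache = torch_npu.npu_format_cast(cache, acl_format)  # out of place: keep the result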
