
Commit 27d9a92

Support pooling models
Signed-off-by: lianyibo <lianyibo1@kunlunit.com>
1 parent d01fd1d commit 27d9a92

File tree

4 files changed: +107 −24 lines changed

  vllm_ascend/attention/attention_mask.py
  vllm_ascend/attention/attention_v1.py
  vllm_ascend/platform.py
  vllm_ascend/worker/model_runner_v1.py

vllm_ascend/attention/attention_mask.py

Lines changed: 8 additions & 4 deletions

@@ -15,7 +15,9 @@
 import torch
 
 
-def _generate_attn_mask(max_seq_len, dtype):
+def _generate_attn_mask(max_seq_len, dtype, tril):
+    if not tril:
+        return torch.zeros(size=(max_seq_len, max_seq_len)).to(dtype)
     # Construct lower triangle matrix.
     mask_flag = torch.tril(
         torch.ones((max_seq_len, max_seq_len),
@@ -40,12 +42,13 @@ def __init__(
         max_seq_len: int,
         dtype: torch.dtype,
         device: torch.device = None,
+        tril: bool = True,
     ):
         # NOTE: The device argument specifies the target NPU
         # to be used for the newly added FIA operator.
         # Only pass this parameter when using the new FIA operator.
-
-        attn_mask = _generate_attn_mask(max_seq_len, dtype)
+        self.tril = tril
+        attn_mask = _generate_attn_mask(max_seq_len, dtype, self.tril)
 
         self._seq_len_cached = attn_mask.shape[0]
         self.attn_mask_cache = attn_mask
@@ -103,6 +106,7 @@ def get_splitfuse_attn_mask(
     def _update_attn_cache(self, seqlen: int, dtype: torch.dtype):
         if seqlen > self._seq_len_cached:
             self._seq_len_cached = seqlen
-            self.attn_mask_cache = _generate_attn_mask(seqlen, dtype)
+            self.attn_mask_cache = _generate_attn_mask(seqlen, dtype,
+                                                       self.tril)
         if self.attn_mask_cache.dtype != dtype:
             self.attn_mask_cache = self.attn_mask_cache.to(dtype)
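
For context on the new tril switch: tril=True keeps the causal lower-triangular mask used for decoder models, while tril=False returns an all-zero mask so pooling/encoder-only models attend bidirectionally. Below is a minimal, self-contained sketch of that behaviour; the helper name and the fill value used for masked slots are illustrative assumptions, not the module's actual values.

import torch

def sketch_generate_attn_mask(max_seq_len: int, dtype: torch.dtype,
                              tril: bool) -> torch.Tensor:
    # Illustrative only: mirrors the branch added to _generate_attn_mask.
    if not tril:
        # Pooling / encoder-only case: all-zero mask, nothing is masked out.
        return torch.zeros((max_seq_len, max_seq_len), dtype=dtype)
    # Decoder case: positions above the diagonal are masked (causal mask).
    keep = torch.tril(torch.ones((max_seq_len, max_seq_len), dtype=torch.bool))
    return torch.where(keep, torch.tensor(0, dtype=dtype),
                       torch.tensor(1, dtype=dtype))

causal = sketch_generate_attn_mask(4, torch.float16, tril=True)
bidirectional = sketch_generate_attn_mask(4, torch.float16, tril=False)  # all zeros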

vllm_ascend/attention/attention_v1.py

Lines changed: 31 additions & 5 deletions

@@ -294,6 +294,27 @@ def __init__(
         self.key_cache = None
         self.value_cache = None
 
+    def _forward_encoder(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        attn_metadata: AscendMetadata,
+        output: Optional[torch.Tensor] = None,
+        num_tokens=0,
+    ) -> torch.Tensor:
+        torch_npu._npu_flash_attention(query=query,
+                                       key=key,
+                                       value=value,
+                                       mask=attn_metadata.attn_mask,
+                                       seq_len=attn_metadata.seq_lens,
+                                       scale_value=self.scale,
+                                       num_heads=self.num_heads,
+                                       num_kv_heads=self.num_kv_heads,
+                                       out=output)
+        assert output is not None
+        return output[:num_tokens, :, :]
+
     def _forward_prefill_no_cache(
         self,
         query: torch.Tensor,
@@ -577,10 +598,11 @@ def forward(
         num_actual_tokens = attn_metadata.num_actual_tokens
         assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
         attn_type = self.attn_type
-        if attn_type != AttentionType.DECODER:
-            raise NotImplementedError("Encoder self-attention and "
-                                      "encoder/decoder cross-attention "
-                                      "are not implemented for "
+        if attn_type not in [
+                AttentionType.DECODER, AttentionType.ENCODER_ONLY
+        ]:
+            raise NotImplementedError("Encoder/Decoder cross-attention "
+                                      "is not implemented for "
                                       "PallasAttentionBackendImpl")
         # View q k v to BSH.
         query = query.view(-1, self.num_heads, self.head_size)
@@ -601,7 +623,11 @@ def forward(
                                        slot_indices=slots)
 
         # V0-Style scheduler situation.
-        if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
+        if attn_type == AttentionType.ENCODER_ONLY:
+            output = self._forward_encoder(query, key, value,
+                                           attn_metadata, output,
+                                           num_tokens)
+        elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
             output = self._forward_prefill_no_cache(
                 query, key, value, attn_metadata, output, num_tokens)
         elif attn_metadata.attn_state == \
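
In effect, forward() now routes encoder-only (pooling) requests to the new _forward_encoder before any of the KV-cache-state branches are considered. A rough sketch of that dispatch order, using plain strings in place of the real AttentionType / AscendAttentionState enums (the helper below is hypothetical, not backend code):

def pick_attention_path(attn_type: str, attn_state: str) -> str:
    # Illustrative dispatcher mirroring the order of checks in forward().
    if attn_type not in ("DECODER", "ENCODER_ONLY"):
        raise NotImplementedError("Encoder/Decoder cross-attention "
                                  "is not implemented")
    if attn_type == "ENCODER_ONLY":
        # Pooling / encoder-only models: one bidirectional pass over the
        # whole prompt via _forward_encoder; no KV-cache state is consulted.
        return "_forward_encoder"
    if attn_state == "PrefillNoCache":
        return "_forward_prefill_no_cache"
    return "remaining decoder paths (prefill with cache, decode, chunked prefill)"

assert pick_attention_path("ENCODER_ONLY", "PrefillNoCache") == "_forward_encoder"
assert pick_attention_path("DECODER", "PrefillNoCache") == "_forward_prefill_no_cache"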

vllm_ascend/platform.py

Lines changed: 2 additions & 1 deletion

@@ -144,7 +144,8 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
             structured_outputs_config.backend == "auto" and \
             not getattr(scheduler_config, "scheduler_delay_factor", 0) > 0 and \
             not scheduler_config.send_delta_data and \
-            scheduler_config.policy == "fcfs":
+            scheduler_config.policy == "fcfs" and \
+            model_config.runner_type == "generate":
             ascend_scheduler_config.enabled = True
         chunked_prefill_enabled_in_ascend_scheduler = getattr(
             ascend_scheduler_config, "enable_chunked_prefill", False)
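
The extra clause means the Ascend scheduler is only auto-enabled for generative models; pooling models keep the default vLLM V1 scheduler. A simplified sketch of the gating (the helper is illustrative and omits the structured-outputs check and the rest of the config plumbing):

def should_auto_enable_ascend_scheduler(policy: str,
                                        runner_type: str,
                                        send_delta_data: bool = False,
                                        delay_factor: float = 0.0) -> bool:
    # Simplified stand-in for the condition in check_and_update_config.
    # runner_type comes from vLLM's ModelConfig ("generate" for generative
    # models, "pooling" for pooling models).
    return (policy == "fcfs"
            and not send_delta_data
            and not delay_factor > 0
            and runner_type == "generate")

assert should_auto_enable_ascend_scheduler("fcfs", "generate")
assert not should_auto_enable_ascend_scheduler("fcfs", "pooling")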

vllm_ascend/worker/model_runner_v1.py

Lines changed: 66 additions & 14 deletions

@@ -75,9 +75,11 @@
 from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
 # yapf conflicts with isort for this block
 # yapf: disable
-from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec,
-                                        KVCacheConfig, KVCacheGroupSpec,
-                                        KVCacheSpec, MambaSpec)
+from vllm.v1.kv_cache_interface import (AttentionSpec,
+                                        EncoderOnlyAttentionSpec,
+                                        FullAttentionSpec, KVCacheConfig,
+                                        KVCacheGroupSpec, KVCacheSpec,
+                                        MambaSpec)
 # yapf: enable
 from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
                              DraftTokenIds, LogprobsTensors, ModelRunnerOutput)
@@ -317,10 +319,12 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
         if torch.version.cann.startswith("8.3"):
             self.attn_mask_builder = AttentionMaskBuilder(
                 self.scheduler_config.max_num_batched_tokens, self.dtype,
-                self.device)
+                self.device, self.model_config.runner_type == "generate")
         else:
             self.attn_mask_builder = AttentionMaskBuilder(
-                self.model_config.max_model_len, self.dtype)
+                self.model_config.max_model_len,
+                self.dtype,
+                tril=self.model_config.runner_type == "generate")
 
         # Set up speculative decoding.
         self.spec_attn_mask = None
@@ -1477,14 +1481,29 @@ def _prepare_inputs(
         # in the same group share the same metadata.
         for kv_cache_group_id, kv_cache_group_spec in enumerate(
                 self.kv_cache_config.kv_cache_groups):
-            blk_table = self.input_batch.block_table[kv_cache_group_id]
-            blk_table_tensor = blk_table.get_device_tensor()
-            slot_mapping = blk_table.slot_mapping_cpu[:
-                                                      total_num_scheduled_tokens]
-            self.slot_mapping[:total_num_scheduled_tokens].copy_(
-                slot_mapping[:total_num_scheduled_tokens],
-                non_blocking=True,
-            )
+            if isinstance(kv_cache_group_spec.kv_cache_spec,
+                          EncoderOnlyAttentionSpec):
+                # Encoder-only layers do not have KV cache, so we need to
+                # create a dummy block table and slot mapping for them.
+                blk_table_tensor = torch.zeros(
+                    (num_reqs, 1),
+                    dtype=torch.int32,
+                    device=self.device,
+                )
+                slot_mapping = torch.zeros(
+                    (total_num_scheduled_tokens, ),
+                    dtype=torch.int64,
+                    device=self.device,
+                )
+            else:
+                blk_table = self.input_batch.block_table[kv_cache_group_id]
+                blk_table_tensor = blk_table.get_device_tensor()
+                slot_mapping = blk_table.slot_mapping_cpu[:
+                                                          total_num_scheduled_tokens]
+                self.slot_mapping[:total_num_scheduled_tokens].copy_(
+                    slot_mapping[:total_num_scheduled_tokens],
+                    non_blocking=True,
+                )
 
             # Make AscendCommonAttentionMetadata
             common_attn_metadata = AscendCommonAttentionMetadata(
@@ -1533,6 +1552,11 @@ def _prepare_inputs(
                     common_prefix_len=common_prefix_len,
                     common_attn_metadata=common_attn_metadata,
                     **extra_attn_metadata_args)
+            elif self.model_config.runner_type == "pooling":
+                attn_metadata_i = builder.build(
+                    common_prefix_len=common_prefix_len,
+                    common_attn_metadata=common_attn_metadata,
+                    **extra_attn_metadata_args)
             else:
                 attn_metadata_i = builder.build(
                     common_prefix_len=common_prefix_len,
@@ -2639,6 +2663,33 @@ def _convert_torch_format(self, tensor):
         tensor = torch_npu.npu_format_cast(tensor, ACL_FORMAT)
         return tensor
 
+    def may_add_encoder_only_layers_to_kv_cache_config(self) -> None:
+        """
+        Add encoder-only layers to the KV cache config.
+        """
+        block_size = self.vllm_config.cache_config.block_size
+        use_mla = self.vllm_config.model_config.use_mla
+        encoder_only_attn_specs: dict[AttentionSpec,
+                                      list[str]] = defaultdict(list)
+        attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)
+        for layer_name, attn_module in attn_layers.items():
+            if attn_module.attn_type == AttentionType.ENCODER_ONLY:
+                attn_spec: AttentionSpec = EncoderOnlyAttentionSpec(
+                    block_size=block_size,
+                    num_kv_heads=attn_module.num_kv_heads,
+                    head_size=attn_module.head_size,
+                    dtype=self.kv_cache_dtype,
+                    use_mla=use_mla)
+                encoder_only_attn_specs[attn_spec].append(layer_name)
+                self.runner_only_attn_layers.add(layer_name)
+        if len(encoder_only_attn_specs) > 0:
+            assert len(
+                encoder_only_attn_specs
+            ) == 1, "Only support one encoder-only attention spec now"
+            spec, layer_names = encoder_only_attn_specs.popitem()
+            self.kv_cache_config.kv_cache_groups.append(
+                KVCacheGroupSpec(layer_names=layer_names, kv_cache_spec=spec))
+
     def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         """
         Initialize KV cache based on `kv_cache_config`.
@@ -2648,9 +2699,10 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         """
         kv_cache_config = deepcopy(kv_cache_config)
         self.kv_cache_config = kv_cache_config
+        self.may_reinitialize_input_batch(kv_cache_config)
+        self.may_add_encoder_only_layers_to_kv_cache_config()
         self.initialize_attn_backend(kv_cache_config)
         self.use_hybrid_blocks = (len(self.attn_groups) > 1)
-        self.may_reinitialize_input_batch(kv_cache_config)
 
         if self.model_config.is_deepseek_mla:
             kv_caches = self.initialize_kv_cache_tensors_deepseek(
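
The new may_add_encoder_only_layers_to_kv_cache_config collects encoder-only attention layers into a single runner-only KV cache group, so metadata building and the dummy block table / slot mapping path in _prepare_inputs have a group to iterate over even though no cache memory is allocated for these layers. A sketch of just the grouping step, with hypothetical stand-ins for the vLLM spec classes:

from collections import defaultdict
from typing import Dict, List, Tuple

def group_encoder_only_layers(
        layer_to_spec: Dict[str, Tuple]) -> List[Tuple[List[str], Tuple]]:
    # Illustrative stand-in: plain tuples replace vLLM's
    # EncoderOnlyAttentionSpec / KVCacheGroupSpec, and layer_to_spec replaces
    # get_layers_from_vllm_config(...). Only the grouping logic is shown.
    specs_to_layers: Dict[Tuple, List[str]] = defaultdict(list)
    for layer_name, spec in layer_to_spec.items():
        specs_to_layers[spec].append(layer_name)
    assert len(specs_to_layers) <= 1, \
        "Only support one encoder-only attention spec now"
    return [(layers, spec) for spec, layers in specs_to_layers.items()]

groups = group_encoder_only_layers({
    "model.layers.0.attn": ("block16", 8, 64),  # (block size, kv heads, head size)
    "model.layers.1.attn": ("block16", 8, 64),
})
# -> one group holding both layer names; these layers are also recorded as
#    runner-only, so no real KV-cache tensors are created for them.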

0 commit comments