diff --git a/tests/ut/attention/test_attention_v1.py b/tests/ut/attention/test_attention_v1.py index b742897098..94a34e9b8e 100644 --- a/tests/ut/attention/test_attention_v1.py +++ b/tests/ut/attention/test_attention_v1.py @@ -72,7 +72,8 @@ def setUp(self): self.mock_vllm_config.model_config.max_model_len = 640 self.mock_vllm_config.cache_config.block_size = 64 self.mock_device = 'cpu:0' - self.builder = AscendAttentionMetadataBuilder(self.mock_vllm_config, + self.builder = AscendAttentionMetadataBuilder(None, None, + self.mock_vllm_config, self.mock_device) def test_reorder_batch(self): @@ -105,14 +106,16 @@ def test_build_prefill_no_cache(self, mock_is_310p, mock_nd_to_nz_2d, positions=torch.tensor([10, 10]), attn_mask=torch.ones((10, 10)), spec_attn_mask=None, - attn_state=AscendAttentionState.PrefillNoCache) + attn_state=AscendAttentionState.PrefillNoCache, + num_computed_tokens_cpu=None, + seq_lens=None) mock_nz_tensor = MagicMock() mock_model = MagicMock() mock_nd_to_nz_2d.return_value = mock_nz_tensor mock_npu_format_cast.return_value = mock_nz_tensor - self.builder.build(common_attn_metadata, mock_model) + self.builder.build(1, common_attn_metadata, mock_model) @patch('vllm_ascend.attention.attention_v1.AscendMetadata') @patch('torch_npu.npu_format_cast') @@ -136,7 +139,9 @@ def test_build_chunked_prefill(self, mock_ascend_attention_state, positions=torch.tensor([10, 10]), attn_mask=torch.ones((15, 15)), spec_attn_mask=None, - attn_state=AscendAttentionState.ChunkedPrefill) + attn_state=AscendAttentionState.ChunkedPrefill, + num_computed_tokens_cpu=None, + seq_lens=None) mock_ascend_attention_state = MagicMock() mock_ascend_attention_state.PrefillNoCache = 0 @@ -146,7 +151,7 @@ def test_build_chunked_prefill(self, mock_ascend_attention_state, mock_nd_to_nz_spec.return_value = mock_nz_tensor mock_npu_format_cast.return_value = mock_nz_tensor - self.builder.build(common_attn_metadata, mock_model) + self.builder.build(1, common_attn_metadata, mock_model) @patch('vllm_ascend.attention.attention_v1.AscendMetadata') @patch('vllm_ascend.attention.attention_v1.is_310p', return_value=False) @@ -165,10 +170,12 @@ def test_build_non_310p(self, mock_is_310p, mock_ascend_metadata): positions=torch.tensor([10, 10]), attn_mask=torch.ones((15, 15)), spec_attn_mask=None, - attn_state=AscendAttentionState.ChunkedPrefill) + attn_state=AscendAttentionState.ChunkedPrefill, + num_computed_tokens_cpu=None, + seq_lens=None) mock_model = MagicMock() - self.builder.build(common_attn_metadata, mock_model) + self.builder.build(1, common_attn_metadata, mock_model) class TestAscendAttentionBackendImpl(TestBase): diff --git a/tests/ut/attention/test_mla_v1.py b/tests/ut/attention/test_mla_v1.py index 63605042e2..a1df85b950 100644 --- a/tests/ut/attention/test_mla_v1.py +++ b/tests/ut/attention/test_mla_v1.py @@ -189,7 +189,8 @@ def test_ascend_mla_metadata_builder_default(self): ascend_config = MagicMock() with patch("vllm_ascend.attention.mla_v1.get_ascend_config", return_value=ascend_config): - builder = AscendMLAMetadataBuilder(mock_vllm_config, mock_device) + builder = AscendMLAMetadataBuilder(None, None, mock_vllm_config, + mock_device) self.assertEqual(builder.block_size, mock_vllm_config.cache_config.block_size) @@ -209,7 +210,8 @@ def test_reorder_batch(self): with patch("vllm_ascend.attention.mla_v1.get_ascend_config", return_value=ascend_config): - builder = AscendMLAMetadataBuilder(mock_vllm_config, mock_device) + builder = AscendMLAMetadataBuilder(None, None, mock_vllm_config, + mock_device) 
builder.decode_threshold = 1 input_batch = MagicMock() diff --git a/tests/ut/torchair/test_torchair_mla.py b/tests/ut/torchair/test_torchair_mla.py index 6ee983a071..0990752984 100644 --- a/tests/ut/torchair/test_torchair_mla.py +++ b/tests/ut/torchair/test_torchair_mla.py @@ -195,7 +195,8 @@ def test_ascend_mla_metadata_builder_default(self): ascend_config.torchair_graph_config.enabled = True with patch("vllm_ascend.torchair.torchair_mla.get_ascend_config", return_value=ascend_config): - builder = AscendMLATorchairMetadataBuilder(mock_vllm_config, + builder = AscendMLATorchairMetadataBuilder(None, None, + mock_vllm_config, mock_device) self.assertEqual(builder.block_size, @@ -216,7 +217,8 @@ def test_reorder_batch_with_torchair_graph(self, ascend_config): ascend_config.torchair_graph_config = MagicMock() ascend_config.torchair_graph_config.enabled = True - builder = AscendMLATorchairMetadataBuilder(mock_vllm_config, + builder = AscendMLATorchairMetadataBuilder(None, None, + mock_vllm_config, mock_device) input_batch = MagicMock() @@ -252,7 +254,8 @@ def test_reorder_batch_without_torchair_graph(self): with patch("vllm_ascend.torchair.torchair_mla.get_ascend_config", return_value=ascend_config): - builder = AscendMLATorchairMetadataBuilder(mock_vllm_config, + builder = AscendMLATorchairMetadataBuilder(None, None, + mock_vllm_config, mock_device) input_batch = MagicMock() @@ -285,7 +288,8 @@ def test_get_graph_runner_block_tables_normal(self, mock_ascend_config): mock_vllm_config.scheduler_config.chunked_prefill_enabled = False mock_device = 'cpu' - builder = AscendMLATorchairMetadataBuilder(mock_vllm_config, + builder = AscendMLATorchairMetadataBuilder(None, None, + mock_vllm_config, mock_device) block_tables = torch.randint(0, 100, (3, 10), dtype=torch.int32) @@ -305,7 +309,8 @@ def test_get_graph_runner_block_tables_truncated(self, mock_ascend_config): mock_vllm_config.scheduler_config.chunked_prefill_enabled = False mock_device = 'cpu' - builder = AscendMLATorchairMetadataBuilder(mock_vllm_config, + builder = AscendMLATorchairMetadataBuilder(None, None, + mock_vllm_config, mock_device) block_tables = torch.randint(0, 100, (3, 10), dtype=torch.int32) @@ -326,7 +331,8 @@ def test_get_graph_runner_block_tables_from_numpy(self, mock_vllm_config.scheduler_config.chunked_prefill_enabled = False mock_device = 'cpu' - builder = AscendMLATorchairMetadataBuilder(mock_vllm_config, + builder = AscendMLATorchairMetadataBuilder(None, None, + mock_vllm_config, mock_device) block_tables = torch.randint(0, 100, (3, 10), dtype=torch.int32) @@ -352,6 +358,8 @@ def test_build_dummy(self, mock_ascend_config): mock_device = 'cpu' builder = AscendMLATorchairMetadataBuilder( + None, + None, mock_vllm_config, mock_device, metadata_cls=AscendMLATorchairMetadata) @@ -417,6 +425,8 @@ def test_build_decode(self, mock_ascend_config): model.model = MagicMock(spec=nn.Module) builder = AscendMLATorchairMetadataBuilder( + None, + None, mock_vllm_config, mock_device, metadata_cls=AscendMLATorchairMetadata) @@ -442,9 +452,11 @@ def test_build_decode(self, mock_ascend_config): positions=torch.tensor([1, 1]), attn_mask=torch.ones((15, 15)), spec_attn_mask=None, - attn_state=AscendAttentionState.ChunkedPrefill) + attn_state=AscendAttentionState.ChunkedPrefill, + num_computed_tokens_cpu=None, + seq_lens=None) - metadata = builder.build(common_attn_metadata, model) + metadata = builder.build(1, common_attn_metadata, model) self.assertIsInstance(metadata, AscendMLATorchairMetadata) self.assertEqual(metadata.num_input_tokens, 0) 
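# NOTE (review sketch, not part of the patch): the test hunks above reflect the updated
# metadata-builder interface -- builders now receive the kv-cache spec and the attention
# layer names ahead of the vllm config and device, and build() takes the common prefix
# length as its first argument. A minimal, hypothetical usage sketch; kv_cache_spec,
# layer_names, vllm_config, device, common_prefix_len, common_attn_metadata and model are
# placeholders for objects the caller already has (the unit tests above simply pass
# None, None and 1).
from vllm_ascend.attention.attention_v1 import AscendAttentionMetadataBuilder

builder = AscendAttentionMetadataBuilder(kv_cache_spec, layer_names,
                                         vllm_config, device)
attn_metadata = builder.build(common_prefix_len, common_attn_metadata, model)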
diff --git a/tests/ut/worker/test_input_batch.py b/tests/ut/worker/test_input_batch.py index a72dbdc1d9..703098d2c6 100644 --- a/tests/ut/worker/test_input_batch.py +++ b/tests/ut/worker/test_input_batch.py @@ -24,8 +24,8 @@ from vllm.v1.pool.metadata import PoolingMetadata from vllm.v1.sample.logits_processor import LogitsProcessors from vllm.v1.sample.metadata import SamplingMetadata -from vllm.v1.worker.block_table import BlockTable, MultiGroupBlockTable +from vllm_ascend.worker.block_table import BlockTable, MultiGroupBlockTable from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch VOCAB_SIZE = 1024 diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py index adc6c01f69..e0905faecb 100644 --- a/vllm_ascend/attention/attention_v1.py +++ b/vllm_ascend/attention/attention_v1.py @@ -17,7 +17,7 @@ from dataclasses import dataclass from enum import Enum -from typing import List, Optional, Tuple, Type +from typing import ClassVar, List, Optional, Tuple, Type import torch import torch.nn as nn @@ -32,12 +32,12 @@ from vllm.forward_context import ForwardContext, get_forward_context from vllm.utils import cdiv, direct_register_custom_op from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.kv_cache_interface import AttentionSpec from vllm_ascend.attention.utils import AscendCommonAttentionMetadata from vllm_ascend.ops.attention import vanilla_chunked_prefill from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p, nd_to_nz_2d, nd_to_nz_spec) -from vllm_ascend.worker.npu_input_batch import InputBatch def wait_for_kv_layer_from_connector(layer_name: str): @@ -145,6 +145,10 @@ def copy_blocks( key_caches[dst_indices] = key_caches[src_indices] value_caches[dst_indices] = value_caches[src_indices] + @staticmethod + def get_supported_block_size() -> list[int]: + return [64] + class AscendAttentionState(Enum): PrefillNoCache = 0 @@ -193,24 +197,29 @@ class AscendMetadata: class AscendAttentionMetadataBuilder: + reorder_batch_threshold: ClassVar[int] = 1 def __init__( self, + kv_cache_spec: AttentionSpec, + layer_names: list[str], vllm_config: VllmConfig, device: torch.device, ): self.vllm_config = vllm_config self.model_config = vllm_config.model_config self.device = device - self.max_num_blocks_per_req = cdiv(self.model_config.max_model_len, - vllm_config.cache_config.block_size) + self.max_num_blocks_per_req = cdiv( + self.model_config.max_model_len, + AscendAttentionBackend.get_supported_block_size()[0]) - def reorder_batch(self, input_batch: "InputBatch", + def reorder_batch(self, input_batch, scheduler_output: "SchedulerOutput") -> bool: return False def build( self, + common_prefix_len: int, common_attn_metadata: AscendCommonAttentionMetadata, model: nn.Module, ): @@ -219,11 +228,7 @@ def build( query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[: num_reqs + 1] - block_table = common_attn_metadata.block_table_tensor - block_table[:num_reqs, :self.max_num_blocks_per_req] = ( - block_table[:num_reqs]) - query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs] slot_mapping = common_attn_metadata.slot_mapping_cpu[: @@ -574,6 +579,8 @@ def unified_ascend_attention_with_output( wait_for_kv_layer_from_connector(layer_name) forward_context: ForwardContext = get_forward_context() attn_metadata = forward_context.attn_metadata + if isinstance(attn_metadata, dict): + attn_metadata = attn_metadata[layer_name] self = 
forward_context.no_compile_layers[layer_name] kv_cache = self.kv_cache[forward_context.virtual_engine] self.impl.forward(self, diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index 0031513742..84cf1a5471 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -171,6 +171,8 @@ class AscendMLAMetadataBuilder: # _attn_mask_builder = None def __init__(self, + kv_cache_spec, + layer_names, vllm_config: VllmConfig, device: torch.device, metadata_cls: Optional[AscendMLAMetadata] = None): @@ -265,6 +267,7 @@ def reorder_batch(self, input_batch: "InputBatch", def build( self, + common_prefix_len: int, common_attn_metadata: AscendCommonAttentionMetadata, model: nn.Module, ) -> AscendMLAMetadata: diff --git a/vllm_ascend/attention/utils.py b/vllm_ascend/attention/utils.py index 2ef537ff0c..65af109799 100644 --- a/vllm_ascend/attention/utils.py +++ b/vllm_ascend/attention/utils.py @@ -21,6 +21,13 @@ class AscendCommonAttentionMetadata: """(batch_size,), the length of each request including both computed tokens and newly scheduled tokens""" + seq_lens: torch.Tensor + """same as seq_lens_cpu, for compatibility with some new attn metadata + (such as GDN).""" + + num_computed_tokens_cpu: torch.Tensor + """(batch_size,), the number of computed tokens for each request""" + num_reqs: int """Number of requests""" num_actual_tokens: int diff --git a/vllm_ascend/models/__init__.py b/vllm_ascend/models/__init__.py index 7899fc1a4b..ac8bfbf2ed 100644 --- a/vllm_ascend/models/__init__.py +++ b/vllm_ascend/models/__init__.py @@ -53,3 +53,6 @@ def register_model(): "PanguProMoEForCausalLM", "vllm_ascend.torchair.models.torchair_pangu_moe:PanguProMoEForCausalLM" ) + ModelRegistry.register_model( + "Qwen3NextForCausalLM", + "vllm_ascend.models.qwen3_next:Qwen3NextForCausalLM") diff --git a/vllm_ascend/models/layers/mla.py b/vllm_ascend/models/layers/mla.py index f28e485933..0f9adf4ff4 100644 --- a/vllm_ascend/models/layers/mla.py +++ b/vllm_ascend/models/layers/mla.py @@ -132,8 +132,13 @@ def forward( output = torch.empty(output_shape, dtype=hidden_states.dtype, device=hidden_states.device) + if forward_context.attn_metadata: + attn_metadata = forward_context.attn_metadata[ + self.mla_attn.layer_name] + else: + attn_metadata = forward_context.attn_metadata output = self.mla_attn.impl.forward(hidden_states, kv_cache, - forward_context.attn_metadata, - need_gather_q_kv, output) + attn_metadata, need_gather_q_kv, + output) output = output.view(-1, output_shape[-1]) return output diff --git a/vllm_ascend/models/qwen3_next.py b/vllm_ascend/models/qwen3_next.py new file mode 100644 index 0000000000..a94e72d8ef --- /dev/null +++ b/vllm_ascend/models/qwen3_next.py @@ -0,0 +1,1361 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# mypy: ignore-errors +"""Inference-only Qwen3Next model.""" +from collections.abc import Iterable +from typing import Optional + +import torch +import torch.nn.functional as F +from einops import rearrange +from torch import nn +from transformers.activations import ACT2FN +from vllm import envs +from vllm.attention import Attention, AttentionBackend, AttentionMetadata +from vllm.compilation.decorators import support_torch_compile +from vllm.config import (CacheConfig, ModelConfig, SpeculativeConfig, + VllmConfig, get_current_vllm_config) +from vllm.distributed import (divide, get_ep_group, get_pp_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size)
+from vllm.forward_context import ForwardContext, get_forward_context +from vllm.model_executor.layers.fused_moe import FusedMoE +# yapf conflicts with isort for this block +# yapf: disable +from vllm.model_executor.layers.layernorm import \ + GemmaRMSNorm as Qwen3NextRMSNorm +# yapf: enable +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.mamba.abstract import MambaBase +from vllm.model_executor.layers.mamba.mamba_mixer2 import \ + mamba_v2_sharded_weight_loader +from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaStateDtypeCalculator, MambaStateShapeCalculator) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.quantization.gptq import GPTQConfig +from vllm.model_executor.layers.quantization.gptq_marlin import \ + GPTQMarlinConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, sharded_weight_loader) +from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid, + MixtureOfExperts, + SupportsLoRA, SupportsPP) +from vllm.model_executor.models.mamba_cache import MambaCacheParams +from vllm.model_executor.models.qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP +from vllm.model_executor.models.utils import ( + AutoWeightsLoader, PPMissingLayer, extract_layer_index, + is_pp_missing_parameter, make_empty_intermediate_tensors_factory, + make_layers, maybe_prefix) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.utils import set_weight_attrs +from vllm.platforms import current_platform +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.configs import Qwen3NextConfig +from vllm.utils import direct_register_custom_op +from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata + +from vllm_ascend.ops.casual_conv1d import (causal_conv1d_fn, + causal_conv1d_update_npu) +from vllm_ascend.ops.fla import RMSNormGated, fused_gdn_gating +from vllm_ascend.ops.sigmoid_gating import fused_recurrent_gated_delta_rule + + +class Qwen3NextSparseMoeBlock(nn.Module): + + def __init__( + self, + config: Qwen3NextConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + enable_eplb: bool = False, + ): + super().__init__() + self.tp_size = get_tensor_model_parallel_world_size() + + self.ep_group = get_ep_group().device_group + self.ep_rank = self.ep_group.rank() + self.ep_size = self.ep_group.size() + self.n_routed_experts = config.num_experts + + if self.tp_size > config.num_experts: + raise ValueError( + f"Tensor parallel size {self.tp_size} is greater than " + f"the number of experts {config.num_experts}.") + + # Load balancing settings. 
+ vllm_config = get_current_vllm_config() + eplb_config = vllm_config.parallel_config.eplb_config + self.enable_eplb = enable_eplb + + self.n_logical_experts = self.n_routed_experts + self.n_redundant_experts = eplb_config.num_redundant_experts + self.n_physical_experts = (self.n_logical_experts + + self.n_redundant_experts) + self.n_local_physical_experts = self.n_physical_experts // self.ep_size + + self.physical_expert_start = (self.ep_rank * + self.n_local_physical_experts) + self.physical_expert_end = (self.physical_expert_start + + self.n_local_physical_experts) + + self.experts = FusedMoE(num_experts=self.n_routed_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + prefix=f"{prefix}.experts", + enable_eplb=self.enable_eplb, + num_redundant_experts=self.n_redundant_experts) + + self.gate = ReplicatedLinear( + config.hidden_size, + config.num_experts, + bias=False, + quant_config=self._maybe_ignore_quant_config(quant_config), + prefix=f"{prefix}.gate") + + if config.shared_expert_intermediate_size > 0: + self.shared_expert = Qwen3NextMLP( + hidden_size=config.hidden_size, + intermediate_size=config.shared_expert_intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + reduce_results=self.experts.must_reduce_shared_expert_outputs( + ), + ) + else: + self.shared_expert = None + self.shared_expert_gate = torch.nn.Linear(config.hidden_size, + 1, + bias=False) + + def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig): + # GPTQ configs do not have a list of ignored modules, however AutoGPTQ + # seems to avoid gate quantization. + # See: https://huggingface.co/Qwen/Qwen3-30B-A3B-GPTQ-Int4 + if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)): + return None + return quant_config + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # NOTE: hidden_states can have either 1D or 2D shape. 
+ orig_shape = hidden_states.shape + hidden_dim = hidden_states.shape[-1] + hidden_states = hidden_states.view(-1, hidden_dim) + + shared_output = None + if self.shared_expert is not None: + shared_output = self.shared_expert(hidden_states) + if self.shared_expert_gate is not None: + shared_output = F.sigmoid( + self.shared_expert_gate(hidden_states)) * shared_output + + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + final_hidden_states = self.experts(hidden_states=hidden_states, + router_logits=router_logits) + + if shared_output is not None: + final_hidden_states = final_hidden_states + shared_output + if self.tp_size > 1: + final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( # noqa E501 + final_hidden_states) + + return final_hidden_states.view(orig_shape) + + +def torch_chunk_gated_delta_rule( + query, + key, + value, + g, + beta, + chunk_size=64, + initial_state=None, + output_final_state=False, + use_qk_l2norm_in_kernel=False, +): + initial_dtype = query.dtype + if use_qk_l2norm_in_kernel: + query = F.normalize(query, p=2, dim=-1) + key = F.normalize(key, p=2, dim=-1) + query, key, value, beta, g = [ + x.transpose(1, 2).contiguous().to(torch.float32) + for x in (query, key, value, beta, g) + ] + + batch_size, sequence_length, num_heads, k_head_dim = key.shape + v_head_dim = value.shape[-1] + pad_size = (chunk_size - num_heads % chunk_size) % chunk_size + query = F.pad(query, (0, 0, 0, pad_size)).repeat_interleave(2, dim=1) + key = F.pad(key, (0, 0, 0, pad_size)).repeat_interleave(2, dim=1) + value = F.pad(value, (0, 0, 0, pad_size)) + beta = F.pad(beta, (0, pad_size)) + g = F.pad(g, (0, pad_size)) + tot_heads = num_heads + pad_size + scale = 1 / (query.shape[-1]**0.5) + query = query * scale + + v_beta = value * beta.unsqueeze(-1) + k_beta = key * beta.unsqueeze(-1) + # reshape to chunks + query, key, value, k_beta, v_beta = [ + x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1]) + for x in (query, key, value, k_beta, v_beta) + ] + g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size) + mask = torch.triu(torch.ones(chunk_size, + chunk_size, + dtype=torch.bool, + device=query.device), + diagonal=0) + + # chunk decay + g = g.cumsum(dim=-1) + decay_mask = ((g.unsqueeze(-1) - + g.unsqueeze(-2)).tril().exp().float()).tril() + attn = -( + (k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0) + for i in range(1, chunk_size): + row = attn[..., i, :i].clone() + sub = attn[..., :i, :i].clone() + attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2) + attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device) + value = attn @ v_beta + k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1)) + + last_recurrent_state = (torch.zeros(batch_size, sequence_length, + k_head_dim, v_head_dim).to(value) if + initial_state is None else initial_state.to(value)) + + core_attn_out = torch.zeros_like(value) + mask = torch.triu(torch.ones(chunk_size, + chunk_size, + dtype=torch.bool, + device=query.device), + diagonal=1) + + # for each chunk + for i in range(0, tot_heads // chunk_size): + q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i] + attn = (q_i @ k_i.transpose(-1, -2) * + decay_mask[:, :, i]).masked_fill_(mask, 0) + v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state + v_new = v_i - v_prime + attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state + core_attn_out[:, :, i] = attn_inter + attn @ v_new + last_recurrent_state = ( + last_recurrent_state * g[:, :, i, -1, None, 
None].exp() + + (k_i * + (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None]).transpose( + -1, -2) @ v_new) + + if not output_final_state: + last_recurrent_state = None + core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], + core_attn_out.shape[1], -1, + core_attn_out.shape[-1]) + core_attn_out = core_attn_out[:, :, :num_heads] + core_attn_out = core_attn_out.transpose(1, + 2).contiguous().to(initial_dtype) + return core_attn_out, last_recurrent_state + + +class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): + + @property + def mamba_type(self) -> str: + return "linear_attention" + + def get_attn_backend(self) -> type["AttentionBackend"]: + from vllm.v1.attention.backends.gdn_attn import GDNAttentionBackend + return GDNAttentionBackend + + def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]: + return MambaStateDtypeCalculator.gated_delta_net_state_dtype( + self.model_config.dtype, self.cache_config.mamba_cache_dtype) + + def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]: + return MambaStateShapeCalculator.gated_delta_net_state_shape( + self.tp_size, + self.num_k_heads, + self.num_v_heads, + self.head_k_dim, + self.head_v_dim, + self.conv_kernel_size, + self.num_spec, + use_v1=True) + + def __init__( + self, + config: Qwen3NextConfig, + model_config: Optional[ModelConfig] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + speculative_config: Optional[SpeculativeConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + self.hidden_size = config.hidden_size + self.num_v_heads = config.linear_num_value_heads + self.num_k_heads = config.linear_num_key_heads + self.head_k_dim = config.linear_key_head_dim + self.head_v_dim = config.linear_value_head_dim + self.key_dim = self.head_k_dim * self.num_k_heads + self.value_dim = self.head_v_dim * self.num_v_heads + + self.conv_kernel_size = config.linear_conv_kernel_dim + self.layer_idx = extract_layer_index(prefix) + self.activation = config.hidden_act + self.act = ACT2FN[config.hidden_act] + self.layer_norm_epsilon = config.rms_norm_eps + self.prefix = prefix + + self.config = config + self.model_config = model_config + self.cache_config = cache_config + self.quant_config = quant_config + self.speculative_config = speculative_config + self.num_spec = (self.speculative_config.num_speculative_tokens + if self.speculative_config else 0) + + # QKV + self.conv_dim = self.key_dim * 2 + self.value_dim + self.conv1d = ColumnParallelLinear( + input_size=self.conv_kernel_size, + output_size=self.conv_dim, + bias=False, + prefix=f"{prefix}.conv1d", + ) + self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) + + # projection of the input hidden states + self.projection_size_qkvz = self.key_dim * 2 + self.value_dim * 2 + self.projection_size_ba = self.num_v_heads * 2 + self.in_proj = MergedColumnParallelLinear( + input_size=self.hidden_size, + output_sizes=[self.projection_size_qkvz, self.projection_size_ba], + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.in_proj", + ) + + query_key_settings = (self.key_dim, 0, False) + value_settings = (self.value_dim, 0, False) + + delattr(self.conv1d.weight, "weight_loader") + set_weight_attrs( + self.conv1d.weight, { + "weight_loader": + mamba_v2_sharded_weight_loader([ + query_key_settings, + query_key_settings, + value_settings, + ], self.tp_size, self.tp_rank) + }) + + # selective 
projection used to make dt, B and C input dependent + + # time step projection (discretization) + # instantiate once and copy inv_dt in init_weights of PretrainedModel + self.dt_bias = nn.Parameter( + torch.ones(self.num_v_heads // self.tp_size), ) + self.A_log = nn.Parameter( + torch.empty( + divide(self.num_v_heads, self.tp_size), + dtype=torch.float32, + )) + + set_weight_attrs(self.A_log, + {"weight_loader": sharded_weight_loader(0)}) + set_weight_attrs(self.dt_bias, + {"weight_loader": sharded_weight_loader(0)}) + + self.norm = RMSNormGated( + self.head_v_dim, + eps=self.layer_norm_epsilon, + ) + + self.out_proj = RowParallelLinear(self.value_dim, + self.hidden_size, + bias=False, + input_is_parallel=True, + quant_config=quant_config, + prefix=f"{prefix}.out_proj") + + compilation_config = get_current_vllm_config().compilation_config + if prefix in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {prefix}") + compilation_config.static_forward_context[prefix] = self + + def fix_query_key_value_ordering( + self, + mixed_qkvz, + mixed_ba, + ): + """ + Derives `query`, `key` and `value` tensors from `mixed_qkvzba`. + """ + new_tensor_shape_qkvz = mixed_qkvz.size()[:-1] + ( + self.num_k_heads // self.tp_size, + (self.head_k_dim + self.head_k_dim + + (self.head_v_dim + self.head_v_dim) * self.num_v_heads // + self.num_k_heads), + ) + new_tensor_shape_ba = mixed_qkvz.size()[:-1] + ( + self.num_k_heads // self.tp_size, + 2 * self.num_v_heads // self.num_k_heads, + ) + + mixed_qkvz = mixed_qkvz.view(*new_tensor_shape_qkvz) + mixed_ba = mixed_ba.view(*new_tensor_shape_ba) + + split_arg_list_qkvz = [ + self.head_k_dim, + self.head_k_dim, + (self.num_v_heads // self.num_k_heads * self.head_v_dim), + (self.num_v_heads // self.num_k_heads * self.head_v_dim), + ] + split_arg_list_ba = [ + self.num_v_heads // self.num_k_heads, + self.num_v_heads // self.num_k_heads + ] + + # [b, sq, ng, (hn + hn + np/ng * hn + np/ng + np/ng)] + # --> [b, sq, ng, hn], [b, sq, ng, hn], [b, sq, ng, np/ng * hn], + # [b, sq, ng, np/ng * hn], [b, sq, ng, np/ng], [b, sq, ng, np/ng] + (query, key, value, z) = torch.split(mixed_qkvz, + split_arg_list_qkvz, + dim=2) + (b, a) = torch.split(mixed_ba, split_arg_list_ba, dim=2) + + # [b, sq, ng, np/ng * hn] -> [b, sq, np, hn] + value = value.reshape(value.size(0), -1, self.head_v_dim) + z = z.reshape(z.size(0), -1, self.head_v_dim) + b = b.reshape(b.size(0), self.num_v_heads // self.tp_size) + a = a.reshape(a.size(0), self.num_v_heads // self.tp_size) + + return query, key, value, z, b, a + + def rearrange_mixed_qkv(self, mixed_qkv): + if mixed_qkv is None: + return None, None, None + query, key, value = torch.split( + mixed_qkv, + [ + self.key_dim // self.tp_size, + self.key_dim // self.tp_size, + self.value_dim // self.tp_size, + ], + dim=-1, + ) + query, key = map( + lambda x: rearrange(x, 'l (h d) -> 1 l h d', d=self.head_k_dim), + (query, key)) + value = rearrange(value, 'l (h d) -> 1 l h d', d=self.head_v_dim) + return query, key, value + + def forward( + self, + hidden_states: torch.Tensor, + output: torch.Tensor, + cache_params: Optional[MambaCacheParams] = None, + ): + return torch.ops.vllm.gdn_attention( + hidden_states, + output, + self.prefix, + ) + + def _forward( + self, + hidden_states: torch.Tensor, + output: torch.Tensor, + ): + forward_context = get_forward_context() + attn_metadata: AttentionMetadata = forward_context.attn_metadata + + if attn_metadata is None: + # V1 profile run + return + + assert isinstance(attn_metadata, 
dict) + attn_metadata = attn_metadata[self.prefix] + assert isinstance(attn_metadata, GDNAttentionMetadata) + has_initial_state = attn_metadata.has_initial_state + spec_query_start_loc = attn_metadata.spec_query_start_loc + non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc + spec_sequence_masks = attn_metadata.spec_sequence_masks + spec_token_masks = attn_metadata.spec_token_masks + spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor # noqa: E501 + non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor # noqa: E501 + self_kv_cache = self.kv_cache[forward_context.virtual_engine] + + conv_state = self_kv_cache[0].transpose(-1, -2) + ssm_state = self_kv_cache[1] + + num_actual_tokens = (attn_metadata.num_prefill_tokens + + attn_metadata.num_decode_tokens + + attn_metadata.num_spec_decode_tokens) + num_accepted_tokens = attn_metadata.num_accepted_tokens + + # 1. Set up dimensions for reshapes later + projected_states, _ = self.in_proj(hidden_states[:num_actual_tokens]) + if spec_token_masks is not None: + spec_token_masks = spec_token_masks[:num_actual_tokens] + projected_states_qkvz, projected_states_ba = torch.split( + projected_states, + [ + self.projection_size_qkvz // self.tp_size, + self.projection_size_ba // self.tp_size + ], + dim=-1, + ) + query, key, value, z, b, a = self.fix_query_key_value_ordering( + projected_states_qkvz, projected_states_ba) + query, key, value = map(lambda x: rearrange(x, 'l p d -> l (p d)'), + (query, key, value)) + mixed_qkv = torch.cat((query, key, value), dim=-1) + + # 2. Convolution sequence transformation + conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), + self.conv1d.weight.size(2)) + + if spec_sequence_masks is not None: + if (attn_metadata.num_prefills == 0 + and attn_metadata.num_decodes == 0): + mixed_qkv_spec = mixed_qkv + mixed_qkv_non_spec = None + else: + mixed_qkv_spec = mixed_qkv[spec_token_masks] + mixed_qkv_non_spec = mixed_qkv[~spec_token_masks] + else: + mixed_qkv_spec = None + mixed_qkv_non_spec = mixed_qkv + + # 2.1: process the multi-query part + # if spec_sequence_masks is not None: + # mixed_qkv_spec = mixed_qkv_spec.view( + # attn_metadata.num_spec_decodes, -1, mixed_qkv_spec.size(-1)) + # mixed_qkv_spec = rearrange(mixed_qkv_spec, 'b l d -> b d l') + # mixed_qkv_spec = causal_conv1d_update( + # mixed_qkv_spec, + # conv_state, + # conv_weights, + # self.conv1d.bias, + # self.activation, + # conv_state_indices=spec_state_indices_tensor[:, 0] + # [:attn_metadata.num_spec_decodes], + # num_accepted_tokens=num_accepted_tokens, + # validate_data=False, + # ) + # mixed_qkv_spec = rearrange(mixed_qkv_spec, 'b d l -> (b l) d') + + # 2.2: process the remaining part + if attn_metadata.num_prefills > 0: + # - "cache_indices" updates the conv_state cache in positions + # pointed to by "mamba_cache_params.state_indices_tensor" + mixed_qkv_non_spec = causal_conv1d_fn( + mixed_qkv_non_spec.transpose(0, 1), + conv_weights, + self.conv1d.bias, + activation=self.activation, + conv_states=conv_state, + has_initial_state=has_initial_state, + cache_indices=non_spec_state_indices_tensor, + query_start_loc=non_spec_query_start_loc, + ).transpose(0, 1) + elif attn_metadata.num_decodes > 0: + mixed_qkv_non_spec = causal_conv1d_update_npu( + mixed_qkv_non_spec, + conv_state, + conv_weights, + self.conv1d.bias, + self.activation, + conv_state_indices=non_spec_state_indices_tensor[:attn_metadata + .num_decodes], + # validate_data=True, + ) + else: + mixed_qkv_non_spec = None + + query_spec,
key_spec, value_spec = self.rearrange_mixed_qkv( + mixed_qkv_spec) + query_non_spec, key_non_spec, value_non_spec = self.rearrange_mixed_qkv( + mixed_qkv_non_spec) + + beta = b.sigmoid() + g = fused_gdn_gating(self.A_log, a, self.dt_bias) + g, beta = map(lambda x: rearrange(x, 'l d -> 1 l d'), (g, beta)) + + if spec_sequence_masks is not None: + if (attn_metadata.num_prefills == 0 + and attn_metadata.num_decodes == 0): + g_spec = g + beta_spec = beta + g_non_spec = None + beta_non_spec = None + else: + g_spec = g[:, spec_token_masks] + beta_spec = beta[:, spec_token_masks] + g_non_spec = g[:, ~spec_token_masks] + beta_non_spec = beta[:, ~spec_token_masks] + else: + g_spec = None + beta_spec = None + g_non_spec = g + beta_non_spec = beta + + # 3. Recurrent attention + # 3.1: process the multi-query part + if spec_sequence_masks is not None: + core_attn_out_spec, last_recurrent_state = ( + fused_recurrent_gated_delta_rule( + q=query_spec, + k=key_spec, + v=value_spec, + g=g_spec, + beta=beta_spec, + initial_state=ssm_state, + inplace_final_state=True, + cu_seqlens=spec_query_start_loc[:attn_metadata. + num_spec_decodes + 1], + ssm_state_indices=spec_state_indices_tensor, + num_accepted_tokens=num_accepted_tokens, + use_qk_l2norm_in_kernel=True, + )) + else: + core_attn_out_spec, last_recurrent_state = None, None + + # 3.2: process the remaining part + if attn_metadata.num_prefills > 0: + initial_state = ssm_state[ + non_spec_state_indices_tensor].contiguous() + initial_state[~has_initial_state, ...] = 0 + + batch_size = initial_state.shape[0] + core_attn_out = [] + last_recurrent_state = [] + + for b_idx in range(batch_size): + start, end = non_spec_query_start_loc[ + b_idx], non_spec_query_start_loc[b_idx + 1] + cur_q = query_non_spec[:, start:end, ...] + cur_k = key_non_spec[:, start:end, ...] + cur_v = value_non_spec[:, start:end, ...] + cur_g = g_non_spec[:, start:end, ...] + cur_b = beta_non_spec[:, start:end, ...] + cur_state = initial_state[b_idx].unsqueeze(0) + + ( + cur_core_attn_out_non_spec, + cur_last_recurrent_state, + ) = torch_chunk_gated_delta_rule( + query=cur_q, + key=cur_k, + value=cur_v, + g=cur_g, + beta=cur_b, + initial_state=cur_state, + output_final_state=True, + use_qk_l2norm_in_kernel=True, + ) + + core_attn_out.append(cur_core_attn_out_non_spec) + last_recurrent_state.append(cur_last_recurrent_state) + + tar_dtype = core_attn_out[0].dtype + tar_device = core_attn_out[0].device + tar_shape = list(core_attn_out[0].shape) + tar_shape[1] = non_spec_query_start_loc[-1] + core_attn_out_non_spec = torch.empty(tar_shape, + dtype=tar_dtype, + device=tar_device) + for b_idx in range(batch_size): + cur_core_attn_out = core_attn_out[b_idx] + start, end = non_spec_query_start_loc[ + b_idx], non_spec_query_start_loc[b_idx + 1] + core_attn_out_non_spec[:, start:end, ...] = cur_core_attn_out + last_recurrent_state = torch.cat(last_recurrent_state, dim=0) + + # Init cache + ssm_state[non_spec_state_indices_tensor] = last_recurrent_state.to( + ssm_state.dtype) + elif attn_metadata.num_decodes > 0: + core_attn_out_non_spec, last_recurrent_state = ( + fused_recurrent_gated_delta_rule( + q=query_non_spec, + k=key_non_spec, + v=value_non_spec, + g=g_non_spec, + beta=beta_non_spec, + initial_state=ssm_state, + inplace_final_state=True, + cu_seqlens=non_spec_query_start_loc[:attn_metadata.
+ num_decodes + 1], + ssm_state_indices=non_spec_state_indices_tensor, + use_qk_l2norm_in_kernel=True, + )) + else: + core_attn_out_non_spec, last_recurrent_state = None, None + + # Merge core attention output + if (spec_sequence_masks is not None + and core_attn_out_non_spec is not None): + core_attn_out = torch.empty( + (1, num_actual_tokens, *core_attn_out_spec.shape[2:]), + dtype=core_attn_out_non_spec.dtype, + device=core_attn_out_non_spec.device, + ) + core_attn_out[:, spec_token_masks] = core_attn_out_spec + core_attn_out[:, ~spec_token_masks] = core_attn_out_non_spec + elif spec_sequence_masks is not None: + core_attn_out = core_attn_out_spec + else: + core_attn_out = core_attn_out_non_spec + + z_shape_og = z.shape + # reshape input data into 2D tensor + core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1]) + z = z.reshape(-1, z.shape[-1]) + core_attn_out = self.norm(core_attn_out, z) + core_attn_out = core_attn_out.reshape(z_shape_og) + core_attn_out = rearrange(core_attn_out, '... h d -> ... (h d)') + + output[:num_actual_tokens], _ = self.out_proj(core_attn_out) + + +class Qwen3NextAttention(nn.Module): + + def __init__( + self, + config: Qwen3NextConfig, + model_config: Optional[ModelConfig] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = config.num_attention_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = config.num_key_value_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
+ assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = config.head_dim or (self.hidden_size // self.num_heads) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.dual_chunk_attention_config = getattr( + config, "dual_chunk_attention_config", None) + self.attn_output_gate = getattr(config, "attn_output_gate", True) + + self.qkv_proj = QKVParallelLinear( + config.hidden_size, + self.head_dim, + self.total_num_heads * (1 + self.attn_output_gate), + self.total_num_kv_heads, + bias=getattr(config, "qkv_bias", False), + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + config.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + self.rotary_emb = get_rope( + head_size=self.head_dim, + rotary_dim=self.head_dim, + max_position=config.max_position_embeddings, + base=config.rope_theta, + rope_scaling=config.rope_scaling, + partial_rotary_factor=config.partial_rotary_factor, + dual_chunk_attention_config=self.dual_chunk_attention_config, + ) + + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + **{ + "layer_idx": extract_layer_index(prefix), + "dual_chunk_attention_config": + self.dual_chunk_attention_config, + } if self.dual_chunk_attention_config else {}, + ) + + self.q_norm = Qwen3NextRMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.k_norm = Qwen3NextRMSNorm(self.head_dim, eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + output: torch.Tensor, + hidden_states: torch.Tensor, + ): + qkv, _ = self.qkv_proj(hidden_states) + + if self.attn_output_gate: + q_gate, k, v = qkv.split( + [self.q_size * 2, self.kv_size, self.kv_size], dim=-1) + orig_shape = q_gate.shape[:-1] + q_gate = q_gate.view(*orig_shape, self.num_heads, -1) + q, gate = torch.chunk(q_gate, 2, dim=-1) + q = q.reshape(*orig_shape, -1) + gate = gate.reshape(*orig_shape, -1) + else: + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], + dim=-1) + + q = self.q_norm(q.view(-1, self.num_heads, self.head_dim)).view( + -1, self.num_heads * self.head_dim) + k = self.k_norm(k.view(-1, self.num_kv_heads, self.head_dim)).view( + -1, self.num_kv_heads * self.head_dim) + + q, k = self.rotary_emb(positions, q, k) + + attn_output = self.attn(q, k, v) + + if self.attn_output_gate: + gate = torch.sigmoid(gate) + attn_output = attn_output * gate + + output[:], _ = self.o_proj(attn_output) + + +class Qwen3NextDecoderLayer(nn.Module): + + def __init__( + self, + config: Qwen3NextConfig, + layer_type: str, + model_config: Optional[ModelConfig] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + speculative_config: Optional[SpeculativeConfig] = None, + prefix: str = "", + enable_eplb: bool = False, + ) -> None: + super().__init__() + self.config = config + + self.layer_type = layer_type + self.layer_idx = extract_layer_index(prefix) + + if self.layer_type == "linear_attention": + self.linear_attn = Qwen3NextGatedDeltaNet( + config, + model_config=model_config, + cache_config=cache_config, + quant_config=quant_config, + speculative_config=speculative_config, + prefix=f'{prefix}.linear_attn') + elif self.layer_type == 
"full_attention": + self.self_attn = Qwen3NextAttention( + config, + model_config=model_config, + cache_config=cache_config, + quant_config=quant_config, + prefix=f'{prefix}.self_attn', + ) + else: + raise ValueError(f"Invalid layer_type {self.layer_type}") + + mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else + config.mlp_only_layers) + if (self.layer_idx not in mlp_only_layers) and ( + config.num_experts > 0 and + (self.layer_idx + 1) % config.decoder_sparse_step == 0): + self.mlp = Qwen3NextSparseMoeBlock( + config=config, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + enable_eplb=enable_eplb, + ) + else: + self.mlp = Qwen3NextMLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + ) + + self.input_layernorm = Qwen3NextRMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = Qwen3NextRMSNorm( + config.hidden_size, eps=config.rms_norm_eps) + + self.layer_scale = getattr(config, "layer_scale", False) + if self.layer_scale: + self.attn_layer_scale = torch.nn.Parameter( + torch.zeros( + 1, + 1, + self.config.hidden_size, + dtype=config.torch_dtype, + ), ) + self.ffn_layer_scale = torch.nn.Parameter( + torch.zeros( + 1, + 1, + self.config.hidden_size, + dtype=config.torch_dtype, + ), ) + + def forward( + self, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + positions: torch.Tensor = None, + **kwargs: object, + ): + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + + self_attention_output = torch.empty_like(hidden_states) + if self.layer_type == "linear_attention": + self.linear_attn( + hidden_states=hidden_states, + output=self_attention_output, + ) + elif self.layer_type == "full_attention": + self.self_attn( + hidden_states=hidden_states, + output=self_attention_output, + positions=positions, + ) + else: + raise ValueError("Invalid layer_type") + hidden_states = self_attention_output + + if self.layer_scale: + if len(hidden_states.shape) == 2: + hidden_states = hidden_states * ( + self.attn_layer_scale.to(hidden_states.dtype)[0] + 1) + else: + hidden_states = hidden_states * ( + self.attn_layer_scale.to(hidden_states.dtype) + 1) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + + if self.layer_scale: + if len(hidden_states.shape) == 2: + hidden_states = hidden_states * ( + self.ffn_layer_scale.to(hidden_states.dtype)[0] + 1) + else: + assert len(hidden_states.shape) == len( + self.ffn_layer_scale.shape + ), f'shape must be the same {len(hidden_states.shape)}, {len(self.ffn_layer_scale.shape)}' # noqa: E501 + hidden_states = hidden_states * ( + self.ffn_layer_scale.to(hidden_states.dtype) + 1) + + return hidden_states, residual + + +@support_torch_compile +class Qwen3NextModel(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config: Qwen3NextConfig = vllm_config.model_config.hf_config + model_config = vllm_config.model_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + parallel_config = vllm_config.parallel_config + lora_config = vllm_config.lora_config + speculative_config = vllm_config.speculative_config + enable_eplb = parallel_config.enable_eplb + eplb_config = 
parallel_config.eplb_config + self.num_redundant_experts = eplb_config.num_redundant_experts + + self.config = config + lora_vocab = ((lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0) + self.vocab_size = config.vocab_size + lora_vocab + + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + ) + + def get_layer(prefix: str): + return Qwen3NextDecoderLayer( + config, + layer_type=config.layer_types[extract_layer_index(prefix)], + model_config=model_config, + cache_config=cache_config, + quant_config=quant_config, + speculative_config=speculative_config, + prefix=prefix, + enable_eplb=enable_eplb, + ) + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers") + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + self.norm = Qwen3NextRMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for layer in self.layers: + hidden_states, residual = layer( + positions=positions, + hidden_states=hidden_states, + residual=residual, + ) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + return FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.num_experts, + num_redundant_experts=self.num_redundant_experts) + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ("in_proj", "in_proj_qkvz", 0), + ("in_proj", "in_proj_ba", 1), + ] + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + expert_params_mapping = self.get_expert_mapping() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + + if name.startswith("mtp."): + continue + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + + if "mlp.experts" in name: + continue + + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Skip layers on other devices. 
+ if is_pp_missing_parameter(name, self): + continue + # name = apply_attn_prefix(name, params_dict) + if name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + # Skip loading extra bias for GPTQ models. + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class Qwen3NextForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, + MixtureOfExperts, IsHybrid): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": ["gate_proj", "up_proj"], + "in_proj": ["in_proj_qkvz", "in_proj_ba"], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + config = vllm_config.model_config.hf_config + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + cache_config = vllm_config.cache_config + lora_config = vllm_config.lora_config + scheduler_config = vllm_config.scheduler_config + assert not cache_config.enable_prefix_caching, \ + "Qwen3Next currently does not support prefix caching" + assert envs.VLLM_USE_V1, "Qwen3Next requires VLLM_USE_V1" + self.quant_config = vllm_config.quant_config + + super().__init__() + self.config = config + self.scheduler_config = scheduler_config + self.model = Qwen3NextModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else lora_config.lora_vocab_padding_size, + ) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size) + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + # Set MoE hyperparameters + self.expert_weights = [] + + self.moe_layers: list[FusedMoE] = [] + example_layer = None + for layer in self.model.layers: + if isinstance(layer, PPMissingLayer): + continue + + assert isinstance(layer, Qwen3NextDecoderLayer) + if isinstance(layer.mlp, Qwen3NextSparseMoeBlock): + example_layer = layer.mlp + self.moe_layers.append(layer.mlp.experts) + + if example_layer is None: + raise RuntimeError("No Qwen3Next layer found in the model.layers.") + + self.num_moe_layers = len(self.moe_layers) + self.num_expert_groups = 1 + self.num_shared_experts = 0 + self.num_logical_experts = example_layer.n_logical_experts + self.num_physical_experts = 
example_layer.n_physical_experts + self.num_local_physical_experts = example_layer.n_local_physical_experts + self.num_routed_experts = example_layer.n_routed_experts + self.num_redundant_experts = example_layer.n_redundant_experts + + def set_eplb_state( + self, + expert_load_view: torch.Tensor, + logical_to_physical_map: torch.Tensor, + logical_replica_count: torch.Tensor, + ) -> None: + for layer_idx, layer in enumerate(self.moe_layers): + # Register the expert weights. + self.expert_weights.append(layer.get_expert_weights()) + layer.set_eplb_state( + moe_layer_idx=layer_idx, + expert_load_view=expert_load_view, + logical_to_physical_map=logical_to_physical_map, + logical_replica_count=logical_replica_count, + ) + + def update_physical_experts_metadata( + self, + num_physical_experts: int, + num_local_physical_experts: int, + ) -> None: + assert self.num_local_physical_experts == num_local_physical_experts + self.num_physical_experts = num_physical_experts + self.num_local_physical_experts = num_local_physical_experts + self.num_redundant_experts = (num_physical_experts - + self.num_logical_experts) + for layer in self.model.layers: + if isinstance(layer.mlp, Qwen3NextSparseMoeBlock): + moe = layer.mlp + moe.n_local_physical_experts = num_local_physical_experts + moe.n_physical_experts = num_physical_experts + moe.n_redundant_experts = self.num_redundant_experts + moe.experts.update_expert_map() + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ): + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + + return hidden_states + + @classmethod + def get_mamba_state_dtype_from_config( + cls, + vllm_config: "VllmConfig", + ) -> tuple[torch.dtype, torch.dtype]: + return MambaStateDtypeCalculator.gated_delta_net_state_dtype( + vllm_config.model_config.dtype, + vllm_config.cache_config.mamba_cache_dtype) + + @classmethod + def get_mamba_state_shape_from_config( + cls, vllm_config: "VllmConfig" + ) -> tuple[tuple[int, int], tuple[int, int]]: + parallel_config = vllm_config.parallel_config + hf_config = vllm_config.model_config.hf_config + tp_size = parallel_config.tensor_parallel_size + num_spec = (vllm_config.speculative_config.num_speculative_tokens + if vllm_config.speculative_config else 0) + return MambaStateShapeCalculator.gated_delta_net_state_shape( + tp_size, + hf_config.linear_num_key_heads, + hf_config.linear_num_value_heads, + hf_config.linear_key_head_dim, + hf_config.linear_value_head_dim, + hf_config.linear_conv_kernel_dim, + num_spec, + use_v1=True) + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + return self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=["mtp."], + ) + return loader.load_weights(weights) + + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: + return self.model.get_expert_mapping() + + +def gdn_attention( + hidden_states: torch.Tensor, + output: torch.Tensor, + layer_name: str, +) -> None: + forward_context: ForwardContext = get_forward_context() + self = 
forward_context.no_compile_layers[layer_name] + self._forward(hidden_states=hidden_states, output=output) + + +def gdn_attention_fake( + hidden_states: torch.Tensor, + output: torch.Tensor, + layer_name: str, +) -> None: + return + + +direct_register_custom_op( + op_name="gdn_attention", + op_func=gdn_attention, + mutates_args=["output"], + fake_impl=gdn_attention_fake, + dispatch_key=current_platform.dispatch_key, +) diff --git a/vllm_ascend/ops/casual_conv1d.py b/vllm_ascend/ops/casual_conv1d.py new file mode 100644 index 0000000000..68790b592d --- /dev/null +++ b/vllm_ascend/ops/casual_conv1d.py @@ -0,0 +1,597 @@ +# adapted from vllm/model_executor/layers/mamba/ops/casual_conv1d.py +# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/mamba/ops/causal_conv1d.py +# SPDX-License-Identifier: Apache-2.0 + +# Copyright (c) 2024, Tri Dao. +# Adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/causal_conv1d/causal_conv1d_interface.py +# and https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/mamba/ops/causal_conv1d.py +# mypy: ignore-errors + +from typing import Optional, Union + +import torch +import torch.nn.functional as F +import triton +import triton.language as tl + +PAD_SLOT_ID = -1 + + +def causal_conv1d_ref( + x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + initial_states: Optional[torch.Tensor] = None, + return_final_states: bool = False, + final_states_out: Optional[torch.Tensor] = None, + activation: Optional[str] = "silu", +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1) + + out: (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + x = x.to(weight.dtype) + seqlen = x.shape[-1] + dim, width = weight.shape + + if initial_states is None: + out = F.conv1d(x, + weight.unsqueeze(1), + bias, + padding=width - 1, + groups=dim) + else: + x = torch.cat([initial_states, x], dim=-1) + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim) + out = out[..., :seqlen] + if return_final_states: + final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to( + dtype_in) # (batch, dim, width - 1) + if final_states_out is not None: + final_states_out.copy_(final_states) + else: + final_states_out = final_states + out = (out if activation is None else F.silu(out)).to(dtype=dtype_in) + return (out, None) if not return_final_states else (out, final_states_out) + + +def causal_conv1d_fn( + x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + query_start_loc: Optional[torch.Tensor] = None, + cache_indices: Optional[torch.Tensor] = None, + has_initial_state: Optional[torch.Tensor] = None, + conv_states: Optional[torch.Tensor] = None, + activation: Optional[str] = "silu", + pad_slot_id: int = PAD_SLOT_ID, +): + """ + x: (batch, dim, seqlen) or (dim,cu_seq_len) for varlen + sequences are concatenated from left to right for varlen + weight: (dim, width) + bias: (dim,) + query_start_loc: (batch + 1) int32 + The cumulative sequence lengths of the sequences in + the batch, used to index into sequence. prepended by 0. 
+ for example: query_start_loc = torch.Tensor([0,10,16,17]), + x.shape=(dim,17) + cache_indices: (batch) int32 + indicates the corresponding state index, + like so: conv_state = conv_states[cache_indices[batch_id]] + has_initial_state: (batch) bool + indicates whether should the kernel take the current state as initial + state for the calculations + conv_states: (...,dim,width - 1) itype + updated inplace if provided + activation: either None or "silu" or "swish" + pad_slot_id: int + if cache_indices is passed, lets the kernel identify padded + entries that will not be processed, + for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id] + in this case, the kernel will not process entries at + indices 0 and 3 + + + out: (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + if x.stride(-1) != 1: + x = x.contiguous() + bias = bias.contiguous() if bias is not None else None + + out_ref = [] + out_ref_b = [] + seqlens = query_start_loc[1:] - query_start_loc[:-1] + seqlens = seqlens.tolist() + splits = torch.split(x, seqlens, dim=-1) + + for i in range(len(seqlens)): + x_s = splits[i] + if cache_indices[i] == PAD_SLOT_ID: + continue + out_ref_b.append( + causal_conv1d_ref( + x_s, + weight, + bias, + activation=activation, + return_final_states=True, + final_states_out=conv_states[cache_indices[i]].unsqueeze(0), + initial_states=conv_states[cache_indices[i]] + if has_initial_state[i] else None)) + out_ref.append(torch.cat([t[0] for t in out_ref_b], dim=-1)) + out_ref_tensor = torch.cat(out_ref, dim=0) + return out_ref_tensor + + +def causal_conv1d_update_ref(x, + conv_state, + weight, + bias=None, + activation=None, + cache_seqlens=None, + conv_state_indices=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the + conv_state starting at the index + @cache_seqlens % state_len before performing the convolution. 
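+    conv_state_indices: (batch,), dtype int32.
+        Only used when cache_seqlens is None; selects the rows of a larger
+        conv_state buffer to read and update for each sequence in the batch
+        (continuous batching).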
+ + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + batch, dim, seqlen = x.shape + width = weight.shape[1] + state_len = conv_state.shape[-1] + assert weight.shape == (dim, width) + if cache_seqlens is None: + x_new = torch.cat([conv_state[conv_state_indices], x], dim=-1).to( + weight.dtype) # (batch, dim, state_len + seqlen) + conv_state[conv_state_indices] = x_new[:, :, -state_len:] + else: + width_idx = torch.arange( + -(width - 1), 0, dtype=torch.long, + device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + width_idx = (torch.remainder(width_idx, state_len).unsqueeze(1).expand( + -1, dim, -1)) + x_new = torch.cat([conv_state.gather(2, width_idx), x], + dim=-1).to(weight.dtype) + copy_idx = torch.arange( + seqlen, dtype=torch.long, + device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + copy_idx = torch.remainder(copy_idx, + state_len).unsqueeze(1).expand(-1, dim, -1) + conv_state.scatter_(2, copy_idx, x) + out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, + groups=dim)[:, :, -seqlen:] + if unsqueeze: + out = out.squeeze(-1) + return (out if activation is None else F.silu(out)).to(dtype=dtype_in) + + +@triton.jit() +def _causal_conv1d_update_kernel( + # Pointers to matrices + x_ptr, # (batch, dim, seqlen) + w_ptr, # (dim, width) + bias_ptr, + conv_state_ptr, + cache_seqlens_ptr, # circular buffer + conv_state_indices_ptr, + num_accepted_tokens_ptr, + intermediate_conv_window_ptr, + o_ptr, # (batch, dim, seqlen) + # Matrix dimensions + batch: int, + dim: tl.constexpr, + seqlen: tl.constexpr, + state_len: tl.constexpr, + num_cache_lines: tl.constexpr, # added to support vLLM larger cache lines + # Strides + stride_x_seq: tl.constexpr, + stride_x_dim: tl.constexpr, + stride_x_token: tl.constexpr, + stride_w_dim: tl.constexpr, + stride_w_width: tl.constexpr, + stride_conv_state_seq: tl.constexpr, + stride_conv_state_dim: tl.constexpr, + stride_conv_state_tok: tl.constexpr, + stride_state_indices: tl.constexpr, + stride_inter_seq: tl.constexpr, + stride_inter_step: tl.constexpr, + stride_inter_dim: tl.constexpr, + stride_inter_win: tl.constexpr, + stride_o_seq: tl.constexpr, + stride_o_dim: tl.constexpr, + stride_o_token: tl.constexpr, + # others + pad_slot_id: tl.constexpr, + # Meta-parameters + HAS_BIAS: tl.constexpr, + KERNEL_WIDTH: tl.constexpr, + SILU_ACTIVATION: tl.constexpr, + IS_CONTINUOUS_BATCHING: tl.constexpr, + IS_SPEC_DECODING: tl.constexpr, + NP2_STATELEN: tl.constexpr, + USE_PAD_SLOT: tl.constexpr, + BLOCK_N: tl.constexpr, + SAVE_INTERMEDIATE: tl.constexpr, +): + # ruff: noqa: E501 + idx_seq = tl.program_id(0) + if idx_seq >= batch: + return + + # [BLOCK_N,] elements along the feature-dimension (channel) + idx_feats = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N) + + if IS_CONTINUOUS_BATCHING: + # mask = idx_seq < batch + conv_state_batch_coord = tl.load(conv_state_indices_ptr + + idx_seq * stride_state_indices).to( + tl.int64) + else: + conv_state_batch_coord = idx_seq + if USE_PAD_SLOT: # noqa + if conv_state_batch_coord == pad_slot_id: + # not processing as this is not the actual sequence + return + + if IS_SPEC_DECODING: + # The rolling of conv state: + # + # Before forward, the conv_state is: + # [history1, history2, ..., historyM]. + # + # After forward, the conv_state becomes: + # [history2, ..., historyM, draft1, draft2, ..., draftN]. 
+ # + # After acceptance, it becomes: + # + # - accept 1 tokens: [history2, ..., historyM, draft1] + # - accept 2 tokens: [history3, ..., historyM, draft1, draft2] + # - and so on. + conv_state_token_offset = tl.load(num_accepted_tokens_ptr + + idx_seq) - 1 + else: + conv_state_token_offset = 0 + + # STEP 1: READ init_state data + conv_states_base = (conv_state_ptr + + (conv_state_batch_coord * stride_conv_state_seq) + + (idx_feats * stride_conv_state_dim)) + mask_w = idx_feats < dim + + prior_tokens = conv_states_base + conv_state_token_offset * stride_conv_state_tok + if KERNEL_WIDTH >= 2: + conv_states_ptrs = prior_tokens # [BLOCK_N] + col0 = tl.load(conv_states_ptrs, mask_w, 0.0) + if KERNEL_WIDTH >= 3: + conv_states_ptrs = prior_tokens + 1 * stride_conv_state_tok # [BLOCK_N] + col1 = tl.load(conv_states_ptrs, mask_w, 0.0) + if KERNEL_WIDTH >= 4: + conv_states_ptrs = prior_tokens + 2 * stride_conv_state_tok # [BLOCK_N] + col2 = tl.load(conv_states_ptrs, mask_w, 0.0) + if KERNEL_WIDTH == 5: + conv_states_ptrs = prior_tokens + 3 * stride_conv_state_tok # [BLOCK_N] + #col3 = tl.load(conv_states_ptrs, mask_w, 0.0) + + # STEP 2: assume state_len > seqlen + idx_tokens = tl.arange(0, NP2_STATELEN) # [BLOCK_M] + + # The conv_state updates works in a sliding window manner, + # at each forward pass, the tokens are shift by 1, so we + # load since idx_tokens + 1. + conv_state_ptrs_source = ( + conv_state_ptr + (conv_state_batch_coord * stride_conv_state_seq) + + conv_state_token_offset * stride_conv_state_tok + + (idx_feats * stride_conv_state_dim)[None, :] + + ((idx_tokens + 1) * stride_conv_state_tok)[:, None] + ) # [BLOCK_M, BLOCK_N] + mask = ((conv_state_batch_coord < num_cache_lines) + & ((idx_tokens + seqlen) < state_len)[:, None] + & (idx_feats < dim)[None, :]) + conv_state = tl.load(conv_state_ptrs_source, mask, other=0.0) + + VAL = state_len - seqlen + x_base = x_ptr + (idx_seq * stride_x_seq) + (idx_feats * stride_x_dim + ) # [BLOCK_N] + + x_ptrs = (x_base[None, :] + ((idx_tokens - VAL) * stride_x_token)[:, None] + ) # [BLOCK_M, BLOCK_N] + + mask_x = ((idx_tokens - VAL >= 0)[:, None] + & (idx_tokens - VAL < seqlen)[:, None] + & (idx_feats < dim)[None, :] + ) # token-index # token-index # feature-index + loaded_x = tl.load(x_ptrs, mask_x, 0.0) + tl.debug_barrier() + + new_conv_state = tl.where(mask, conv_state, loaded_x) + + conv_state_base = (conv_state_ptr + + (conv_state_batch_coord * stride_conv_state_seq) + + (idx_feats * stride_conv_state_dim)) # [BLOCK_N,] + conv_state_ptrs_target = (conv_state_base + + (idx_tokens * stride_conv_state_tok)[:, None] + ) # [BLOCK_M, BLOCK_N] + mask = (idx_tokens < state_len)[:, None] & (idx_feats < dim)[None, :] + tl.store(conv_state_ptrs_target, new_conv_state, mask) + + # STEP 3: init accumulator + if HAS_BIAS: + bias = bias_ptr + idx_feats + mask_bias = idx_feats < dim + acc_preload = tl.load(bias, mask=mask_bias, + other=0.0).to(tl.float32) # [BLOCK_N] + else: + acc_preload = tl.zeros((BLOCK_N, ), dtype=tl.float32) + + # STEP 4: + # PRE-LOAD WEIGHTS + # first kernel column, configured for weights to handle BLOCK_N features in range + w_base = w_ptr + (idx_feats * stride_w_dim) # [BLOCK_N,] + mask_w = idx_feats < dim + if KERNEL_WIDTH >= 2: + w_ptrs = w_base + (0 * stride_w_width) # [BLOCK_N] tensor + w_col0 = tl.load(w_ptrs, mask_w, other=0.0) + w_ptrs = w_base + (1 * stride_w_width) # [BLOCK_N] tensor + w_col1 = tl.load(w_ptrs, mask_w, other=0.0) + if KERNEL_WIDTH >= 3: + w_ptrs = w_base + (2 * stride_w_width) # [BLOCK_N] tensor + w_col2 = 
tl.load(w_ptrs, mask_w, other=0.0) + if KERNEL_WIDTH >= 4: + w_ptrs = w_base + (3 * stride_w_width) # [BLOCK_N] tensor + w_col3 = tl.load(w_ptrs, mask_w, other=0.0) + + x_base_1d = x_base # starting of chunk [BLOCK_N] + mask_x_1d = idx_feats < dim + + # STEP 5: compute each token + for idx_token in tl.static_range(seqlen): + acc = acc_preload + + matrix_w = w_col0 + matrix_x = col0 + for j in tl.static_range(KERNEL_WIDTH): + if KERNEL_WIDTH == 2: + if j == 1: # KERNEL_WIDTH-1: + matrix_w = w_col1 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] + matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + elif KERNEL_WIDTH == 3: + if j == 1: + matrix_w = w_col1 + matrix_x = col1 + elif j == 2: + matrix_w = w_col2 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] + matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + elif KERNEL_WIDTH == 4: + if j == 1: + matrix_w = w_col1 + matrix_x = col1 + elif j == 2: + matrix_w = w_col2 + matrix_x = col2 + elif j == 3: + matrix_w = w_col3 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] + matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + + acc += matrix_x * matrix_w # [BLOCK_N] + + if KERNEL_WIDTH == 2: + col0 = matrix_x + elif KERNEL_WIDTH == 3: + col0 = col1 + col1 = matrix_x + elif KERNEL_WIDTH == 4: + col0 = col1 + col1 = col2 + col2 = matrix_x + + if SILU_ACTIVATION: + acc = acc / (1 + tl.exp(-acc)) + # mask_1d = (idx_token < seqlen) & ( + # idx_feats < dim + # ) # token-index # feature-index + maskL = idx_feats < dim + maskR = tl.full(maskL.shape, False, tl.int1) + mask_1d = tl.where(idx_token < seqlen, maskL, maskR) + + o_ptrs = (o_ptr + (idx_seq) * stride_o_seq + + idx_token * stride_o_token + (idx_feats * stride_o_dim)) + + tl.store(o_ptrs, acc, mask=mask_1d) + + if SAVE_INTERMEDIATE: + # Save the window state after consuming this token + # Layout: [seq(cache line), step, dim, win(K-1)] + base_ptr = (intermediate_conv_window_ptr + + conv_state_batch_coord * stride_inter_seq + + idx_token * stride_inter_step + + idx_feats * stride_inter_dim) + if KERNEL_WIDTH >= 2: + tl.store(base_ptr + 0 * stride_inter_win, col0, mask=mask_w) + if KERNEL_WIDTH >= 3: + tl.store(base_ptr + 1 * stride_inter_win, col1, mask=mask_w) + if KERNEL_WIDTH >= 4: + tl.store(base_ptr + 2 * stride_inter_win, col2, mask=mask_w) + + +def causal_conv1d_update_npu( + x: torch.Tensor, + conv_state: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + activation: Union[bool, str, None] = None, + cache_seqlens: Optional[torch.Tensor] = None, + conv_state_indices: Optional[torch.Tensor] = None, + num_accepted_tokens: Optional[torch.Tensor] = None, + intermediate_conv_window: Optional[torch.Tensor] = None, + pad_slot_id: int = PAD_SLOT_ID, + metadata=None, + validate_data=False, +): + """ + x: (batch, dim) or (batch, dim, seqlen) + [shape=2: single token prediction] + [shape=3: single or multiple tokens prediction] + conv_state: (..., dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state + starting at the index + @cache_seqlens % state_len. + conv_state_indices: (batch,), dtype int32 + If not None, the conv_state is a larger tensor along the batch dim, + and we are selecting the batch coords specified by conv_state_indices. + Useful for a continuous batching scenario. 
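+    num_accepted_tokens: (batch,), dtype int32
+        Only used with speculative decoding: the kernel offsets its read of
+        the cached window by (num_accepted_tokens - 1) tokens so that it
+        resumes from the last accepted draft token.
+    intermediate_conv_window: (cache_lines, seqlen, dim, width - 1)
+        If not None, the per-token sliding-window state is additionally
+        written into this buffer after every processed token.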
+ pad_slot_id: int + if cache_indices is passed, lets the kernel identify padded + entries that will not be processed, + for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id] + in this case, the kernel will not process entries at + indices 0 and 3 + out: (batch, dim) or (batch, dim, seqlen) + """ + if validate_data: + assert cache_seqlens is None # not implemented yet - ok for vLLM + assert pad_slot_id is not None + assert x.stride(1) == 1 + if isinstance(activation, bool): + activation = "silu" if activation is True else None + elif activation is not None: + assert activation in ["silu", "swish"] + unsqueeze = x.dim() == 2 + if unsqueeze: + # make it (batch, dim, seqlen) with seqlen == 1 + x = x.unsqueeze(-1) + batch, dim, seqlen = x.shape + _, width = weight.shape + # conv_state: (..., dim, state_len), where state_len >= width - 1 + num_cache_lines, _, state_len = conv_state.size() + + if validate_data: + assert dim == weight.size(0) + assert ( + conv_state.stride(-2) == 1 + ), f"ERROR: expect contiguous along feat-dim of conv_state (currently stride={conv_state.stride()})" + assert state_len >= width - 1 + # when above happens, we don't shift-left to keep any records in conv_state + assert dim == conv_state.size(1) + if conv_state_indices is None: + assert conv_state.size(0) >= batch + else: + assert (batch, ) == conv_state_indices.shape + + assert num_cache_lines >= batch + assert weight.stride(1) == 1 # Need this + assert cache_seqlens is None # not needed for vLLM - circular buffer + + # adopt the strategy in vLLM that overwrite on 'x' directly, rather than creating a new tensor 'o' + out = x + stride_w_dim, stride_w_width = weight.stride() + + stride_x_seq, stride_x_dim, stride_x_token = x.stride( + ) # X (batch, dim, seqlen) + + stride_o_seq, stride_o_dim, stride_o_token = out.stride() + stride_istate_seq, stride_istate_dim, stride_istate_token = conv_state.stride( + ) + stride_state_indices = (conv_state_indices.stride(0) + if conv_state_indices is not None else 0) + state_len = width - 1 + (seqlen - 1) # effective state_len needed + np2_statelen = triton.next_power_of_2(state_len) + + def grid(META): + return ( + batch, + triton.cdiv(dim, META["BLOCK_N"]), + ) + + # prepare intermediate buffer strides if provided + if intermediate_conv_window is not None: + stride_inter_seq, stride_inter_step, stride_inter_dim, stride_inter_win = ( + intermediate_conv_window.stride(0), + intermediate_conv_window.stride(1), + intermediate_conv_window.stride(2), + intermediate_conv_window.stride(3), + ) + else: + stride_inter_seq = stride_inter_step = stride_inter_dim = stride_inter_win = 0 + + _causal_conv1d_update_kernel[grid]( + # Pointers to matrices + x, + weight, + bias, + conv_state, + cache_seqlens, + conv_state_indices, + num_accepted_tokens, + intermediate_conv_window + if intermediate_conv_window is not None else x, + out, + # Matrix dimensions + batch, + dim, + seqlen, + state_len, + num_cache_lines, + # stride + stride_x_seq, + stride_x_dim, + stride_x_token, + stride_w_dim, + stride_w_width, + stride_istate_seq, + stride_istate_dim, + stride_istate_token, + stride_state_indices, + stride_inter_seq, + stride_inter_step, + stride_inter_dim, + stride_inter_win, + stride_o_seq, + stride_o_dim, + stride_o_token, + # others + pad_slot_id, + # META + HAS_BIAS=bias is not None, + KERNEL_WIDTH=width, + SILU_ACTIVATION=activation in ["silu", "swish"], + IS_CONTINUOUS_BATCHING=conv_state_indices is not None, + IS_SPEC_DECODING=num_accepted_tokens is not None, + 
NP2_STATELEN=np2_statelen, + USE_PAD_SLOT=pad_slot_id is not None, + BLOCK_N=128, + SAVE_INTERMEDIATE=intermediate_conv_window is not None, + ) + if unsqueeze: + out = out.squeeze(-1) + return out diff --git a/vllm_ascend/ops/fla.py b/vllm_ascend/ops/fla.py new file mode 100644 index 0000000000..df81dd5741 --- /dev/null +++ b/vllm_ascend/ops/fla.py @@ -0,0 +1,381 @@ +# Adapt from https://github.com/fla-org/flash-linear-attention/blob/main/fla/modules/layernorm_gated.py +# Copyright (c) 2024, Tri Dao. +# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html +# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate. +# This backward pass is faster for dimensions up to 8k, but after that it's much slower due to register spilling. +# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine. +# mypy: ignore-errors + +import torch +import torch.nn.functional as F +import triton +import triton.language as tl +from einops import rearrange + + +def rms_norm_ref( + x, + weight, + bias, + z=None, + eps=1e-6, + group_size=None, + norm_before_gate=True, + upcast=True, +): + dtype = x.dtype + #N = x.shape[-1] + weight = weight.float() + bias = bias.float() if bias is not None else None + if upcast: + x = x.float() + z = z.float() if z is not None else z + if z is not None and not norm_before_gate: + x = x * F.silu(z) + if group_size is None: + rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps) + out = (x * rstd * weight) + bias if bias is not None else (x * rstd * + weight) + else: + x_group = rearrange(x, "... (g d) -> ... g d", d=group_size) + rstd = 1 / torch.sqrt((x_group.square()).mean(dim=-1, keepdim=True) + + eps) + out = rearrange(x_group * rstd, "... g d -> ... (g d)") * weight + if bias is not None: + out = out + bias + if z is not None and norm_before_gate: + out *= F.silu(z) + return out.to(dtype) + + +@triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) +@triton.heuristics({"HAS_Z": lambda args: args["Z"] is not None}) +@triton.jit +def _layer_norm_fwd_1pass_kernel( + X, # pointer to the input + Y, # pointer to the output + W, # pointer to the weights + B, # pointer to the biases + Z, # pointer to the other branch + Mean, # pointer to the mean + Rstd, # pointer to the 1/std + stride_x_row, # how much to increase the pointer when moving by 1 row + stride_y_row, + stride_z_row, + M, # number of rows in X + N, # number of columns in X + eps, # epsilon to avoid division by zero + BLOCK_N: tl.constexpr, + HAS_BIAS: tl.constexpr, + HAS_Z: tl.constexpr, + NORM_BEFORE_GATE: tl.constexpr, + IS_RMS_NORM: tl.constexpr, +): + # Map the program id to the row of X and Y it should compute. 
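+    # One program instance handles one (row, group) pair: the launch grid in
+    # _layer_norm_fwd is (M, ngroups), so axis 0 selects the row and axis 1
+    # the group of N (= group_size) columns. Mean/Rstd are flattened
+    # group-major buffers of shape (ngroups * M,), hence the `group * M`
+    # offsets below.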
+ row = tl.program_id(0) + group = tl.program_id(1) + X += row * stride_x_row + group * N + Y += row * stride_y_row + group * N + if HAS_Z: + Z += row * stride_z_row + group * N + if not IS_RMS_NORM: + Mean += group * M + Rstd += group * M + W += group * N + if HAS_BIAS: + B += group * N + # Compute mean and variance + cols = tl.arange(0, BLOCK_N) + x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) + if HAS_Z and not NORM_BEFORE_GATE: + z = tl.load(Z + cols, mask=cols < N).to(tl.float32) + x *= z * tl.sigmoid(z) + if not IS_RMS_NORM: + mean = tl.sum(x, axis=0) / N + tl.store(Mean + row, mean) + xbar = tl.where(cols < N, x - mean, 0.0) + var = tl.sum(xbar * xbar, axis=0) / N + else: + xbar = tl.where(cols < N, x, 0.0) + var = tl.sum(xbar * xbar, axis=0) / N + rstd = 1 / tl.sqrt(var + eps) + tl.store(Rstd + row, rstd) + # Normalize and apply linear transformation + mask = cols < N + w = tl.load(W + cols, mask=mask).to(tl.float32) + if HAS_BIAS: + b = tl.load(B + cols, mask=mask).to(tl.float32) + x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd + y = x_hat * w + b if HAS_BIAS else x_hat * w + if HAS_Z and NORM_BEFORE_GATE: + z = tl.load(Z + cols, mask=mask).to(tl.float32) + y *= z * tl.sigmoid(z) + # Write output + tl.store(Y + cols, y, mask=mask) + + +def _layer_norm_fwd( + x, + weight, + bias, + eps, + z=None, + out=None, + group_size=None, + norm_before_gate=True, + is_rms_norm=False, +): + M, N = x.shape + if group_size is None: + group_size = N + assert N % group_size == 0 + ngroups = N // group_size + assert x.stride(-1) == 1 + if z is not None: + assert z.stride(-1) == 1 + assert z.shape == (M, N) + assert weight.shape == (N, ) + assert weight.stride(-1) == 1 + if bias is not None: + assert bias.stride(-1) == 1 + assert bias.shape == (N, ) + # allocate output + if out is not None: + assert out.shape == x.shape + else: + out = torch.empty_like(x) + assert out.stride(-1) == 1 + mean = (torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device) + if not is_rms_norm else None) + rstd = torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device) + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size)) + if group_size > BLOCK_N: + raise RuntimeError( + "This layer norm doesn't support feature dim >= 64KB.") + # heuristics for number of warps + num_warps = min(max(BLOCK_N // 256, 1), 8) + grid = (M, ngroups) + with torch.npu.device(x.device.index): + _layer_norm_fwd_1pass_kernel[grid]( + x, + out, + weight, + bias, + z, + mean, + rstd, + x.stride(0), + out.stride(0), + z.stride(0) if z is not None else 0, + M, + group_size, + eps, + BLOCK_N=BLOCK_N, + NORM_BEFORE_GATE=norm_before_gate, + IS_RMS_NORM=is_rms_norm, + num_warps=num_warps, + ) + return out, mean, rstd + + +class LayerNormFn(torch.autograd.Function): + + @staticmethod + def forward( + ctx, + x, + weight, + bias, + z=None, + eps=1e-6, + group_size=None, + norm_before_gate=True, + is_rms_norm=False, + ): + """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))""" + + x_shape_og = x.shape + # reshape input data into 2D tensor + x = x.reshape(-1, x.shape[-1]) + if x.stride(-1) != 1: + x = x.contiguous() + if z is not None: + assert z.shape == x_shape_og + z = z.reshape(-1, z.shape[-1]) + if z.stride(-1) != 1: + z = z.contiguous() + weight = weight.contiguous() + if bias is not None: + bias = bias.contiguous() + y, mean, rstd = _layer_norm_fwd( + x, + weight, 
+ bias, + eps, + z=z, + group_size=group_size, + norm_before_gate=norm_before_gate, + is_rms_norm=is_rms_norm, + ) + return y.reshape(x_shape_og) + + +def layernorm_fn( + x, + weight, + bias, + z=None, + eps=1e-6, + group_size=None, + norm_before_gate=True, + is_rms_norm=False, +): + return LayerNormFn.apply(x, weight, bias, z, eps, group_size, + norm_before_gate, is_rms_norm) + + +def rmsnorm_fn(x, + weight, + bias, + z=None, + eps=1e-6, + group_size=None, + norm_before_gate=True): + return LayerNormFn.apply(x, weight, bias, z, eps, group_size, + norm_before_gate, True) + + +class LayerNorm(torch.nn.Module): + + def __init__( + self, + hidden_size, + eps=1e-5, + group_size=None, + norm_before_gate=True, + device=None, + dtype=None, + ): + """If group_size is not None, we do GroupNorm with each group having group_size elements. + group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group). + """ + + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.eps = eps + self.weight = torch.nn.Parameter( + torch.empty(hidden_size, **factory_kwargs)) + self.bias = torch.nn.Parameter( + torch.empty(hidden_size, **factory_kwargs)) + self.group_size = group_size + self.norm_before_gate = norm_before_gate + self.reset_parameters() + + def reset_parameters(self): + torch.nn.init.ones_(self.weight) + torch.nn.init.zeros_(self.bias) + + def forward(self, x, z=None): + """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))""" + return layernorm_fn( + x, + self.weight, + self.bias, + z=z, + group_size=self.group_size, + eps=self.eps, + norm_before_gate=self.norm_before_gate, + ) + + +class RMSNormGated(torch.nn.Module): + + def __init__( + self, + hidden_size, + eps=1e-5, + group_size=None, + norm_before_gate=True, + device=None, + dtype=None, + ): + """If group_size is not None, we do GroupNorm with each group having group_size elements. + group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group). 
+ """ + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.eps = eps + self.weight = torch.nn.Parameter( + torch.empty(hidden_size, **factory_kwargs)) + self.register_parameter("bias", None) + self.group_size = group_size + self.norm_before_gate = norm_before_gate + self.reset_parameters() + + def reset_parameters(self): + torch.nn.init.ones_(self.weight) + + def forward(self, x, z=None): + """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))""" + return rmsnorm_fn( + x, + self.weight, + self.bias, + z=z, + eps=self.eps, + group_size=self.group_size, + norm_before_gate=self.norm_before_gate, + ) + + +@triton.jit +def fused_gdn_gating_kernel( + g, + A_log, + a, + dt_bias, + seq_len, + NUM_HEADS: tl.constexpr, + beta: tl.constexpr, + threshold: tl.constexpr, + BLK_HEADS: tl.constexpr, +): + i_b, i_s, i_d = tl.program_id(0), tl.program_id(1), tl.program_id(2) + head_off = i_d * BLK_HEADS + tl.arange(0, BLK_HEADS) + off = i_b * seq_len * NUM_HEADS + i_s * NUM_HEADS + head_off + mask = head_off < NUM_HEADS + blk_A_log = tl.load(A_log + head_off, mask=mask) + blk_a = tl.load(a + off, mask=mask) + blk_bias = tl.load(dt_bias + head_off, mask=mask) + # If the model is loaded in fp16, without the .float() here, A might be -inf + x = blk_a.to(tl.float32) + blk_bias.to(tl.float32) + softplus_x = tl.where(beta * x <= threshold, + (1 / beta) * tl.log(1 + tl.exp(beta * x)), x) + blk_g = -tl.exp(blk_A_log.to(tl.float32)) * softplus_x + tl.store(g + off, blk_g.to(g.dtype.element_ty), mask=mask) + + +def fused_gdn_gating( + A_log: torch.Tensor, + a: torch.Tensor, + dt_bias: torch.Tensor, + beta: float = 1.0, + threshold: float = 20.0, +) -> torch.Tensor: + batch, num_heads = a.shape + seq_len = 1 + grid = (batch, seq_len, triton.cdiv(num_heads, 8)) + g = torch.empty_like(a, dtype=torch.float32) + fused_gdn_gating_kernel[grid](g, + A_log, + a, + dt_bias, + seq_len, + num_heads, + beta, + threshold, + 8, + num_warps=1) + return g diff --git a/vllm_ascend/ops/sigmoid_gating.py b/vllm_ascend/ops/sigmoid_gating.py new file mode 100644 index 0000000000..d599287bef --- /dev/null +++ b/vllm_ascend/ops/sigmoid_gating.py @@ -0,0 +1,403 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. 
+# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 +# mypy: ignore-errors + +import os +from typing import Optional + +import torch +from vllm.triton_utils import tl, tldevice, triton + +if os.environ.get('FLA_USE_FAST_OPS', '0') == '1': + div = tldevice.fast_dividef + exp = tldevice.fast_expf + log = tldevice.fast_logf + log2 = tldevice.fast_log2f +else: + + @triton.jit + def div_normal(x, y): + return x / y + + div = div_normal + exp = tl.exp + log = tl.log + log2 = tl.log2 + + +@triton.heuristics({ + 'USE_INITIAL_STATE': + lambda args: args['h0'] is not None, + 'IS_VARLEN': + lambda args: args['cu_seqlens'] is not None, + "IS_CONTINUOUS_BATCHING": + lambda args: args['ssm_state_indices'] is not None, + "IS_SPEC_DECODING": + lambda args: args['num_accepted_tokens'] is not None, +}) +@triton.jit(do_not_specialize=['N', 'T']) +def fused_recurrent_gated_delta_rule_fwd_kernel( + q, + k, + v, + g, + beta, + o, + h0, + ht, + cu_seqlens, + ssm_state_indices, + num_accepted_tokens, + scale, + N: tl.constexpr, # num of sequences + T: tl.constexpr, # num of tokens + B: tl.constexpr, + H: tl.constexpr, + HV: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + stride_init_state_token: tl.constexpr, + stride_final_state_token: tl.constexpr, + stride_indices_seq: tl.constexpr, + stride_indices_tok: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, # whether to use initial state + INPLACE_FINAL_STATE: tl.constexpr, # whether to store final state inplace + IS_BETA_HEADWISE: tl. + constexpr, # whether beta is headwise vector or scalar, + USE_QK_L2NORM_IN_KERNEL: tl.constexpr, + IS_VARLEN: tl.constexpr, + IS_CONTINUOUS_BATCHING: tl.constexpr, + IS_SPEC_DECODING: tl.constexpr, +): + i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_n, i_hv = i_nh // HV, i_nh % HV + i_h = i_hv // (HV // H) + if IS_VARLEN: + bos, eos = tl.load(cu_seqlens + i_n).to( + tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64) + all = T + T = eos - bos + else: + bos, eos = i_n * T, i_n * T + T + all = B * T + + if T == 0: + # no tokens to process for this sequence + return + + o_k = i_k * BK + tl.arange(0, BK) + o_v = i_v * BV + tl.arange(0, BV) + + # p_q = q + (bos * H + i_h) * K + o_k + # p_k = k + (bos * H + i_h) * K + o_k + # p_v = v + (bos * HV + i_hv) * V + o_v + # if IS_BETA_HEADWISE: + # p_beta = beta + (bos * HV + i_hv) * V + o_v + # else: + # p_beta = beta + bos * HV + i_hv + # p_g = g + bos * HV + i_hv + # p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v + + mask_k = o_k < K + mask_v = o_v < V + mask_h = mask_k[:, None] & mask_v[None, :] + + b_h = tl.zeros([BK, BV], dtype=tl.float32) + if USE_INITIAL_STATE: + if IS_CONTINUOUS_BATCHING: + if IS_SPEC_DECODING: + i_t = tl.load(num_accepted_tokens + i_n).to(tl.int64) - 1 + else: + i_t = 0 + p_h0 = h0 + tl.load(ssm_state_indices + i_n * stride_indices_seq + + i_t).to(tl.int64) * stride_init_state_token + else: + p_h0 = h0 + bos * HV * K * V + p_h0 = p_h0 + i_hv * K * V + o_k[:, None] * V + o_v[None, :] + b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32) + + for i_t in range(0, T): + p_q = q + (bos * H + i_h) * K + o_k + H * K * i_t + p_k = k + (bos * H + i_h) * K + o_k + H * K * i_t + p_v = v + (bos * HV + i_hv) * V + o_v + HV * V * i_t + if IS_BETA_HEADWISE: + p_beta = beta + (bos * HV + i_hv) * V + o_v + HV * V * i_t + else: + p_beta = beta + bos * HV + i_hv + HV * i_t + 
p_g = g + bos * HV + i_hv + HV * i_t + p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v + HV * V * i_t + + b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) + b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32) + b_g = tl.load(p_g).to(tl.float32) + + if USE_QK_L2NORM_IN_KERNEL: + b_q = b_q / tl.sqrt(tl.sum(b_q * b_q) + 1e-6) + b_k = b_k / tl.sqrt(tl.sum(b_k * b_k) + 1e-6) + b_q = b_q * scale + # [BK, BV] + # b_h *= tl.exp(b_g) + b_h *= exp(b_g) + # [BV] + b_v -= tl.sum(b_h * b_k[:, None], 0) + if IS_BETA_HEADWISE: + b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32) + else: + b_beta = tl.load(p_beta).to(tl.float32) + b_v *= b_beta + # [BK, BV] + b_h += b_k[:, None] * b_v[None, :] + # [BV] + b_o = tl.sum(b_h * b_q[:, None], 0) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v) + + # keep the states for multi-query tokens + if INPLACE_FINAL_STATE: + p_ht = ht + tl.load(ssm_state_indices + i_n * stride_indices_seq + + i_t).to(tl.int64) * stride_final_state_token + else: + p_ht = ht + (bos + i_t) * stride_final_state_token + p_ht = p_ht + i_hv * K * V + o_k[:, None] * V + o_v[None, :] + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h) + + # p_q += H * K + # p_k += H * K + # p_o += HV * V + # p_v += HV * V + # p_g += HV + # p_beta += HV * (V if IS_BETA_HEADWISE else 1) + + +def fused_recurrent_gated_delta_rule_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + inplace_final_state: bool = True, + cu_seqlens: Optional[torch.LongTensor] = None, + ssm_state_indices: Optional[torch.Tensor] = None, + num_accepted_tokens: Optional[torch.Tensor] = None, + use_qk_l2norm_in_kernel: bool = False, +) -> tuple[torch.Tensor, torch.Tensor]: + B, T, H, K, V = *k.shape, v.shape[-1] + HV = v.shape[2] + N = B if cu_seqlens is None else len(cu_seqlens) - 1 + BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 8) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + assert NK == 1, "NK > 1 is not supported yet" + num_stages = 3 + num_warps = 1 + + o = q.new_empty(NK, *v.shape) + if inplace_final_state: + final_state = initial_state + else: + final_state = q.new_empty(T, HV, K, V, dtype=initial_state.dtype) + + stride_init_state_token = initial_state.stride(0) + stride_final_state_token = final_state.stride(0) + + if ssm_state_indices is None: + stride_indices_seq, stride_indices_tok = 1, 1 + elif ssm_state_indices.ndim == 1: + stride_indices_seq, stride_indices_tok = ssm_state_indices.stride(0), 1 + else: + stride_indices_seq, stride_indices_tok = ssm_state_indices.stride() + + # print("N: ", N) + # print("T: ", T) + # print("B: ", B) + # print("H: ", H) + # print("HV: ", HV) + # print("K: ", K) + # print("V: ", V) + # print("BK: ", BK) + # print("BV: ", BV) + + grid = (NK, NV, N * HV) + fused_recurrent_gated_delta_rule_fwd_kernel[grid]( + q=q, + k=k, + v=v, + g=g, + beta=beta, + o=o, + h0=initial_state, + ht=final_state, + cu_seqlens=cu_seqlens, + ssm_state_indices=ssm_state_indices, + num_accepted_tokens=num_accepted_tokens, + scale=scale, + N=N, + T=T, + B=B, + H=H, + HV=HV, + K=K, + V=V, + BK=BK, + BV=BV, + stride_init_state_token=stride_init_state_token, + stride_final_state_token=stride_final_state_token, + stride_indices_seq=stride_indices_seq, + stride_indices_tok=stride_indices_tok, + IS_BETA_HEADWISE=beta.ndim == v.ndim, + USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel, + 
INPLACE_FINAL_STATE=inplace_final_state, + num_warps=num_warps, + num_stages=num_stages, + ) + o = o.squeeze(0) + return o, final_state + + +class FusedRecurrentFunction(torch.autograd.Function): + + @staticmethod + def forward(ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + inplace_final_state: bool = True, + cu_seqlens: Optional[torch.LongTensor] = None, + ssm_state_indices: Optional[torch.Tensor] = None, + num_accepted_tokens: Optional[torch.Tensor] = None, + use_qk_l2norm_in_kernel: bool = False): + o, final_state = fused_recurrent_gated_delta_rule_fwd( + q=q.contiguous(), + k=k.contiguous(), + v=v.contiguous(), + g=g.contiguous(), + beta=beta.contiguous(), + scale=scale, + initial_state=initial_state, + inplace_final_state=inplace_final_state, + cu_seqlens=cu_seqlens, + ssm_state_indices=ssm_state_indices, + num_accepted_tokens=num_accepted_tokens, + use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, + ) + + return o, final_state + + +def fused_recurrent_gated_delta_rule( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor = None, + scale: float = None, + initial_state: torch.Tensor = None, + inplace_final_state: bool = True, + cu_seqlens: Optional[torch.LongTensor] = None, + ssm_state_indices: Optional[torch.Tensor] = None, + num_accepted_tokens: Optional[torch.Tensor] = None, + use_qk_l2norm_in_kernel: bool = False, +) -> tuple[torch.Tensor, torch.Tensor]: + r""" + Args: + q (torch.Tensor): + queries of shape `[B, T, H, K]`. + k (torch.Tensor): + keys of shape `[B, T, H, K]`. + v (torch.Tensor): + values of shape `[B, T, HV, V]`. + GVA is applied if `HV > H`. + g (torch.Tensor): + g (decays) of shape `[B, T, HV]`. + beta (torch.Tensor): + betas of shape `[B, T, HV]`. + scale (Optional[int]): + Scale factor for the RetNet attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + initial_state (Optional[torch.Tensor]): + Initial state of shape `[N, HV, K, V]` for `N` input sequences. + For equal-length input sequences, `N` equals the batch size `B`. + Default: `None`. + inplace_final_state: bool: + Whether to store the final state in-place to save memory. + Default: `True`. + cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + ssm_state_indices (Optional[torch.Tensor]): + Indices to map the input sequences to the initial/final states. + num_accepted_tokens (Optional[torch.Tensor]): + Number of accepted tokens for each sequence during decoding. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, T, HV, V]`. + final_state (torch.Tensor): + Final state of shape `[N, HV, K, V]`. 
+ + Examples:: + >>> import torch + >>> import torch.nn.functional as F + >>> from einops import rearrange + >>> from fla.ops.gated_delta_rule import fused_recurrent_gated_delta_rule + # inputs with equal lengths + >>> B, T, H, HV, K, V = 4, 2048, 4, 8, 512, 512 + >>> q = torch.randn(B, T, H, K, device='cuda') + >>> k = F.normalize(torch.randn(B, T, H, K, device='cuda'), p=2, dim=-1) + >>> v = torch.randn(B, T, HV, V, device='cuda') + >>> g = F.logsigmoid(torch.rand(B, T, HV, device='cuda')) + >>> beta = torch.rand(B, T, HV, device='cuda').sigmoid() + >>> h0 = torch.randn(B, HV, K, V, device='cuda') + >>> o, ht = fused_gated_recurrent_delta_rule( + q, k, v, g, beta, + initial_state=h0, + ) + # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required + >>> q, k, v, g, beta = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, g, beta)) + # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected + >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long) + >>> o_var, ht_var = fused_gated_recurrent_delta_rule( + q, k, v, g, beta, + initial_state=h0, + cu_seqlens=cu_seqlens + ) + """ + if cu_seqlens is not None and q.shape[0] != 1: + raise ValueError( + f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." + f"Please flatten variable-length inputs before processing.") + if scale is None: + scale = k.shape[-1]**-0.5 + else: + assert scale > 0, "scale must be positive" + if beta is None: + beta = torch.ones_like(q[..., 0]) + o, final_state = FusedRecurrentFunction.apply( + q, + k, + v, + g, + beta, + scale, + initial_state, + inplace_final_state, + cu_seqlens, + ssm_state_indices, + num_accepted_tokens, + use_qk_l2norm_in_kernel, + ) + return o, final_state diff --git a/vllm_ascend/patch/platform/patch_common/__init__.py b/vllm_ascend/patch/platform/patch_common/__init__.py index 35ef14970a..45f1b62627 100644 --- a/vllm_ascend/patch/platform/patch_common/__init__.py +++ b/vllm_ascend/patch/platform/patch_common/__init__.py @@ -16,4 +16,5 @@ # import vllm_ascend.patch.platform.patch_common.patch_distributed # noqa +import vllm_ascend.patch.platform.patch_common.patch_mamba_config # noqa import vllm_ascend.patch.platform.patch_common.patch_shared_fused_moe # noqa diff --git a/vllm_ascend/patch/platform/patch_common/patch_mamba_config.py b/vllm_ascend/patch/platform/patch_common/patch_mamba_config.py new file mode 100644 index 0000000000..d9ca8ff312 --- /dev/null +++ b/vllm_ascend/patch/platform/patch_common/patch_mamba_config.py @@ -0,0 +1,97 @@ +# mypy: ignore-errors +import vllm.model_executor.models.config +from vllm.logger import init_logger +from vllm.model_executor.models import ModelRegistry +from vllm.model_executor.models.config import MambaModelConfig +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv +from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec + + +@classmethod +def verify_and_update_config(cls, vllm_config) -> None: + """ + Ensure that page size of attention layers is greater than or + equal to the mamba layers. If not, automatically set the attention + block size to ensure that it is. If the attention page size is + strictly greater than the mamba page size, we pad the mamba page size + to make them equal. 
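+
+    For example (hypothetical sizes): with an attention page of 4 KiB per
+    token and a mamba page of 900 KiB, the attention block size becomes
+    64 * cdiv(900 KiB, 64 * 4 KiB) = 256 tokens, and the mamba page is then
+    padded to 256 * 4 KiB = 1024 KiB so the two page sizes match exactly.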
+ + Args: + vllm_config: vLLM Config + """ + logger = init_logger(__name__) + # Enable FULL_AND_PIECEWISE by default + MambaModelConfig.verify_and_update_config(vllm_config) + + cache_config = vllm_config.cache_config + model_config = vllm_config.model_config + parallel_config = vllm_config.parallel_config + + if cache_config.cache_dtype == "auto": + kv_cache_dtype = model_config.dtype + else: + kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] + + # get attention page size (for 1 token) + attn_page_size_1_token = FullAttentionSpec( + block_size=1, + num_kv_heads=model_config.get_num_kv_heads(parallel_config), + head_size=model_config.get_head_size(), + dtype=kv_cache_dtype, + use_mla=model_config.use_mla).page_size_bytes + + model_cls, _ = ModelRegistry.resolve_model_cls( + model_config.architecture, + model_config=model_config, + ) + + # get mamba page size + mamba_page_size = MambaSpec( + shapes=model_cls.get_mamba_state_shape_from_config(vllm_config), + dtypes=model_cls.get_mamba_state_dtype_from_config(vllm_config), + block_size=model_config.max_model_len, + ).page_size_bytes + + block_alignment_bytes = 64 + + # some attention backends (e.g. FA) only support setting + # block size to multiple of 16, so let's suggest a value + # that would work (note: FA is currently not compatible + # with mamba layers, use FlashInfer instead). + attn_block_size = block_alignment_bytes * cdiv( + mamba_page_size, block_alignment_bytes * attn_page_size_1_token) + + # override attention block size if either (a) the + # user has not set it or (b) the user has set it + # too small. + if (cache_config.block_size is None + or cache_config.block_size < attn_block_size): + cache_config.block_size = attn_block_size + logger.info( + "Setting attention block size to %d tokens " + "to ensure that attention page size is >= mamba page size.", + attn_block_size) + + # compute new attention page size + attn_page_size = \ + cache_config.block_size * attn_page_size_1_token + + assert attn_page_size >= mamba_page_size + + if attn_page_size == mamba_page_size: + # don't need to pad mamba page size + return + + # pad mamba page size to exactly match attention + if (cache_config.mamba_page_size_padded is None + or cache_config.mamba_page_size_padded != attn_page_size): + cache_config.mamba_page_size_padded = (attn_page_size) + mamba_padding_pct = 100 * (attn_page_size - + mamba_page_size) / mamba_page_size + logger.info( + "Padding mamba page size by %.2f%% to ensure " + "that mamba page size and attention page size are " + "exactly equal.", mamba_padding_pct) + + +vllm.model_executor.models.config.HybridAttentionMambaModelConfig.verify_and_update_config = verify_and_update_config diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index e91f2cd5e1..b11c0c4cdb 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -365,3 +365,7 @@ def stateless_init_device_torch_dist_pg( pg._register_backend(device, backend_type, backend_class) return pg + + @classmethod + def support_hybrid_kv_cache(cls) -> bool: + return True diff --git a/vllm_ascend/spec_decode/eagle_proposer.py b/vllm_ascend/spec_decode/eagle_proposer.py index cbbf19fcbd..f627f23c15 100644 --- a/vllm_ascend/spec_decode/eagle_proposer.py +++ b/vllm_ascend/spec_decode/eagle_proposer.py @@ -350,7 +350,8 @@ def _get_eagle_atten_dict( spec_attn_mask=self.runner.spec_attn_mask, attn_state=self.runner.attn_state, decode_token_per_req=self.runner.decode_token_per_req, - ) + num_computed_tokens_cpu=None, + seq_lens=None) 
attn_metadata_i = self.runner.attn_metadata_builder.build( common_attn_metadata, self.runner.get_model()) for layer_name in kv_cache_group_spec.layer_names: @@ -436,7 +437,8 @@ def _propose( spec_attn_mask=self.runner.spec_attn_mask, attn_state=self.runner.attn_state, decode_token_per_req=self.runner.decode_token_per_req, - ) + num_computed_tokens_cpu=None, + seq_lens=None) # FIXME(woosuk): The below two ops cause synchronization. Optimize. attn_metadata = self.runner.attn_metadata_builder.build( common_attn_metadata, self.runner.model) diff --git a/vllm_ascend/spec_decode/mtp_proposer.py b/vllm_ascend/spec_decode/mtp_proposer.py index 5e56493495..35a34c7dc9 100644 --- a/vllm_ascend/spec_decode/mtp_proposer.py +++ b/vllm_ascend/spec_decode/mtp_proposer.py @@ -1,4 +1,5 @@ import types +from typing import Dict import torch import torch.nn as nn @@ -186,6 +187,8 @@ def generate_token_ids(self, hidden_states: torch.Tensor = None, attn_metadata=None, aux_hidden_states: torch.Tensor = None): + if attn_metadata is not None and isinstance(attn_metadata, Dict): + attn_metadata = attn_metadata['model.layers.0.self_attn.attn'] next_token_ids: list[int] = [] for i, token_ids in enumerate(valid_sampled_token_ids): if token_ids: @@ -379,9 +382,10 @@ def _propose( attn_state=self.runner.attn_state, graph_pad_size=self.runner.graph_pad_size, decode_token_per_req=self.runner.decode_token_per_req, - ) + num_computed_tokens_cpu=None, + seq_lens=None) attn_metadata = self.runner.attn_metadata_builder.build( - common_attn_metadata, self.runner.get_model()) + 0, common_attn_metadata, self.runner.get_model()) self.positions[:num_tokens] = target_positions self.hidden_states[:num_tokens] = target_hidden_states diff --git a/vllm_ascend/torchair/torchair_attention.py b/vllm_ascend/torchair/torchair_attention.py index d2443ad442..143fb93c6b 100644 --- a/vllm_ascend/torchair/torchair_attention.py +++ b/vllm_ascend/torchair/torchair_attention.py @@ -98,10 +98,12 @@ class AscendAttentionTorchairMetadataBuilder(AscendAttentionMetadataBuilder): def __init__( self, + kv_cache_spec, + layer_names, vllm_config: VllmConfig, device: torch.device, ): - super().__init__(vllm_config, device) + super().__init__(kv_cache_spec, layer_names, vllm_config, device) self.max_num_blocks_per_req = cdiv( self.model_config.max_model_len, self.vllm_config.cache_config.block_size) @@ -171,6 +173,7 @@ def build_torchair_graph_dummy( def build( self, + common_prefix_len: int, common_attn_metadata: AscendCommonAttentionMetadata, model: nn.Module, ): diff --git a/vllm_ascend/torchair/torchair_mla.py b/vllm_ascend/torchair/torchair_mla.py index 80ada4d04a..1c78045079 100644 --- a/vllm_ascend/torchair/torchair_mla.py +++ b/vllm_ascend/torchair/torchair_mla.py @@ -176,6 +176,8 @@ class AscendMLATorchairMetadataBuilder: # _attn_mask_builder = None def __init__(self, + kv_cache_spec, + layer_names, vllm_config: VllmConfig, device: torch.device, metadata_cls: Optional[AscendMLATorchairMetadata] = None): @@ -372,6 +374,7 @@ def build_torchair_graph_dummy( def build( self, + common_prefix_len: int, common_attn_metadata: AscendCommonAttentionMetadata, model: nn.Module, ) -> AscendMLATorchairMetadata: diff --git a/vllm_ascend/torchair/torchair_model_runner.py b/vllm_ascend/torchair/torchair_model_runner.py index d730c44bfd..a2b9313335 100644 --- a/vllm_ascend/torchair/torchair_model_runner.py +++ b/vllm_ascend/torchair/torchair_model_runner.py @@ -19,7 +19,7 @@ import math import types -from typing import Optional +from typing import Optional, Dict import 
torch import torch.distributed as dist @@ -50,6 +50,9 @@ class NPUTorchairModelRunner(NPUModelRunner): def __init__(self, vllm_config: VllmConfig, device: torch.device): super().__init__(vllm_config, device) + self.attn_metadata_builder = self.attn_backend.get_builder_cls()( + None, None, vllm_config, device) + ascend_config = get_ascend_config() self.new_kv_cache_bytes = -1 self.torchair_compiled_model = None # type: ignore @@ -132,7 +135,8 @@ def _generate_dummy_run_hidden_states(self, with_prefill, is_torchair_compile, input_ids, positions, attn_metadata, num_tokens, intermediate_tensors, inputs_embeds): - + if attn_metadata is not None and isinstance(attn_metadata, Dict): + attn_metadata = attn_metadata['model.layers.0.self_attn.attn'] if not with_prefill: # Only mark static while compiling if is_torchair_compile: @@ -278,6 +282,8 @@ def _generate_process_reqs_hidden_states(self, attn_metadata, with_prefill, input_ids, positions, intermediate_tensors, inputs_embeds): + if attn_metadata is not None and isinstance(attn_metadata, Dict): + attn_metadata = attn_metadata['model.layers.0.self_attn.attn'] model_kwargs = { "kv_caches": self.kv_caches, "attn_metadata": attn_metadata diff --git a/vllm_ascend/worker/block_table.py b/vllm_ascend/worker/block_table.py new file mode 100644 index 0000000000..e01e946991 --- /dev/null +++ b/vllm_ascend/worker/block_table.py @@ -0,0 +1,308 @@ +from typing import Optional, Union + +import numpy as np +import torch +from vllm.distributed import get_dcp_group +from vllm.utils import cdiv + + +class BlockTable: + + def __init__(self, + block_size: int, + max_num_reqs: int, + max_num_blocks_per_req: int, + max_num_batched_tokens: int, + pin_memory: bool, + device: torch.device, + kernel_sizes: Union[list[int], None] = None): + self.max_num_reqs = max_num_reqs + self.max_num_blocks_per_req = max_num_blocks_per_req + self.max_num_batched_tokens = max_num_batched_tokens + self.pin_memory = pin_memory + self.device = device + self.physical_block_size = block_size + # If kernel_sizes is None or [0], use physical block size (no splitting) + if kernel_sizes is None or kernel_sizes == [0]: + self.block_size = block_size + self.logical_block_size = block_size + self.blocks_per_phys_block = 1 + self.use_hybrid_blocks = False + else: + # Find the first kernel size that divides physical_block_size evenly + selected_kernel_size = None + for kernel_size in kernel_sizes: + if kernel_size > 0 \ + and self.physical_block_size % kernel_size == 0: + selected_kernel_size = kernel_size + break + + if selected_kernel_size is None: + raise ValueError( + f"None of the kernel sizes {kernel_sizes} can divide " + f"physical block size {self.physical_block_size} evenly") + + self.block_size = selected_kernel_size + self.logical_block_size = selected_kernel_size + self.blocks_per_phys_block = (self.physical_block_size // + self.logical_block_size) + if self.blocks_per_phys_block > 1: + self.use_hybrid_blocks = True + else: + self.use_hybrid_blocks = False + + if self.use_hybrid_blocks: + logical_table_size = (max_num_blocks_per_req * + self.blocks_per_phys_block) + else: + logical_table_size = max_num_blocks_per_req + + self.block_table = torch.zeros( + (max_num_reqs, logical_table_size), + device=self.device, + dtype=torch.int32, + ) + self.block_table_cpu = torch.zeros( + (max_num_reqs, logical_table_size), + device="cpu", + dtype=torch.int32, + pin_memory=pin_memory, + ) + self.block_table_np = self.block_table_cpu.numpy() + self.num_blocks_per_row = np.zeros(max_num_reqs, 
dtype=np.int32) + + self.slot_mapping_cpu = torch.zeros(self.max_num_batched_tokens, + dtype=torch.int64, + device="cpu", + pin_memory=self.pin_memory) + self.slot_mapping_np = self.slot_mapping_cpu.numpy() + self.slot_mapping = torch.zeros(self.max_num_batched_tokens, + dtype=torch.int64, + device=self.device) + try: + self.dcp_world_size = get_dcp_group().world_size + self.dcp_rank = get_dcp_group().rank_in_group + except AssertionError: + # DCP might not be initialized in testing + self.dcp_world_size = 1 + self.dcp_rank = 0 + + def append_row( + self, + block_ids, + row_idx: int, + ) -> None: + if not block_ids: + return + block_ids = np.array(block_ids) + if self.use_hybrid_blocks: + block_ids = self._convert_physical_to_logical_blocks(block_ids) + + num_blocks = len(block_ids) + start = self.num_blocks_per_row[row_idx] + + self.block_table_np[row_idx, start:start + num_blocks] = block_ids + self.num_blocks_per_row[row_idx] += num_blocks + + def add_row(self, block_ids: list[int], row_idx: int) -> None: + self.num_blocks_per_row[row_idx] = 0 + self.append_row(block_ids, row_idx) + + def move_row(self, src: int, tgt: int) -> None: + num_blocks = self.num_blocks_per_row[src] + self.block_table_np[tgt, :num_blocks] = self.block_table_np[ + src, :num_blocks] + self.num_blocks_per_row[tgt] = num_blocks + + def swap_row(self, src: int, tgt: int) -> None: + num_blocks_src = self.num_blocks_per_row[src] + num_blocks_tgt = self.num_blocks_per_row[tgt] + self.num_blocks_per_row[src] = num_blocks_tgt + self.num_blocks_per_row[tgt] = num_blocks_src + + self.block_table_np[[src, tgt]] = self.block_table_np[[tgt, src]] + + def compute_slot_mapping(self, req_indices: np.ndarray, + positions: np.ndarray) -> None: + # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1] + # where K is the max_num_blocks_per_req and the block size is 2. + # NOTE(woosuk): We can't simply use `token_indices // block_size` + # here because M (max_model_len) is not necessarily divisible by + # block_size. + + if self.dcp_world_size > 1: + # Note(hc): The DCP implement store kvcache with an interleave + # style, the kvcache for the token whose token_idx is i is + # always stored on the GPU whose dcp_rank equals i % cp_world_size: + + # Use a "virtual block" which equals to world_size * block_size + # for block_table_indices calculation. + virtual_block_size = self.block_size * self.dcp_world_size + + # IMPORTANT: In hybrid mode, positions are in logical block space, + # but we need to map them to the correct logical block table indices + logical_block_idx = positions // virtual_block_size + + # Account for the expanded logical table + # (always needed with unified tensor) + # Each physical block is split into multiple logical blocks + # The logical table has been expanded to accommodate this + block_table_indices = (req_indices * self.max_num_blocks_per_req * + self.blocks_per_phys_block + + logical_block_idx) + + block_numbers = self.block_table_np.ravel()[block_table_indices] + # Use virtual_block_size for mask calculation, which marks local + # tokens. 
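+            # Worked example (hypothetical numbers): with block_size=4 and
+            # dcp_world_size=2, virtual_block_size=8; a token at position 13
+            # has virtual offset 5, so it is local to dcp_rank 1 and maps to
+            # slot block_number * 4 + 5 // 2 = block_number * 4 + 2, while
+            # non-local tokens get slot -1.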
+ virtual_block_offsets = positions % virtual_block_size + mask = virtual_block_offsets % self.dcp_world_size == self.dcp_rank + # Calculate local block_offsets + block_offsets = virtual_block_offsets // self.dcp_world_size + # Calculate slot_mapping + slot_mapping = block_numbers * self.block_size + block_offsets + # Write final slots, use -1 for not-local + self.slot_mapping_np[:req_indices.shape[0]] = np.where( + mask, slot_mapping, -1) + else: + # IMPORTANT: In hybrid mode, positions are in logical block space, + # but we need to map them to the correct logical block table indices + logical_block_idx = positions // self.block_size + + # Account for the expanded logical table + # (always needed with unified tensor) + # Each physical block is split into multiple logical blocks + # The logical table has been expanded to accommodate this + block_table_indices = (req_indices * self.max_num_blocks_per_req * + self.blocks_per_phys_block + + logical_block_idx) + + block_numbers = self.block_table_np.ravel()[block_table_indices] + block_offsets = positions % self.block_size + np.add(block_numbers * self.block_size, + block_offsets, + out=self.slot_mapping_np[:req_indices.shape[0]]) + + def commit_block_table(self, num_reqs: int) -> None: + self.block_table[:num_reqs].copy_(self.block_table_cpu[:num_reqs], + non_blocking=True) + + def commit_slot_mapping(self, num_tokens: int) -> None: + self.slot_mapping[:num_tokens].copy_( + self.slot_mapping_cpu[:num_tokens], non_blocking=True) + + def clear(self) -> None: + self.block_table.fill_(0) + self.block_table_cpu.fill_(0) + + def _convert_physical_to_logical_blocks( + self, physical_blocks: np.ndarray) -> np.ndarray: + """Convert physical block IDs to logical block IDs.""" + if not self.use_hybrid_blocks: + return physical_blocks + + # Create logical block IDs by splitting each physical block + logical_blocks: list[int] = [] + for phys_block in physical_blocks: + # Convert physical block to multiple logical blocks + # Physical block 1 becomes logical blocks + # [1*split_ratio, 1*split_ratio+1, ...] + # But we need to account for the fact that block 0 is special + base_logical = (phys_block - 1) * self.blocks_per_phys_block + 1 + logical_blocks.extend( + range(base_logical, base_logical + self.blocks_per_phys_block)) + + return np.array(logical_blocks, dtype=np.int32) + + def get_device_tensor(self) -> torch.Tensor: + """Returns the device tensor of the block table.""" + return self.block_table + + def get_cpu_tensor(self) -> torch.Tensor: + """Returns the CPU tensor of the block table.""" + return self.block_table_cpu + + def get_numpy_array(self) -> np.ndarray: + """Returns the numpy array of the block table.""" + return self.block_table_np + + +class MultiGroupBlockTable: + """The BlockTables for each KV cache group.""" + + def __init__(self, + max_num_reqs: int, + max_model_len: int, + max_num_batched_tokens: int, + pin_memory: bool, + device: torch.device, + block_sizes: list[int], + num_speculative_tokens: int = 0, + kernel_sizes: Optional[list[list[int]]] = None) -> None: + # Note(hc): each dcp rank only store + # (max_model_len//dcp_world_size) tokens in kvcache, + # so the block_size which used for calc max_num_blocks_per_req + # must be multiplied by dcp_world_size. 
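+        # kernel_sizes pairs one-to-one with block_sizes below: None or [0]
+        # keeps the physical block size, while e.g. a physical block size of
+        # 128 combined with kernel_sizes [64] splits each physical block into
+        # 128 // 64 = 2 logical blocks inside that group's BlockTable.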
+ try: + dcp_world_size = get_dcp_group().world_size + except AssertionError: + # DCP might not be initialized in testing + dcp_world_size = 1 + + if kernel_sizes is None: + kernel_sizes = [[0]] * len(block_sizes) + # Ensure kernel_sizes matches block_sizes length + elif len(kernel_sizes) == 1 and len(block_sizes) > 1: + kernel_sizes = kernel_sizes * len(block_sizes) + elif len(kernel_sizes) != len(block_sizes): + raise ValueError( + f"kernel_sizes length ({len(kernel_sizes)}) must match " + f"block_sizes length ({len(block_sizes)})") + + # Use zip to pair block_sizes with kernel_sizes one-to-one + self.block_tables = [ + BlockTable( + block_size, max_num_reqs, + max(cdiv(max_model_len, block_size * dcp_world_size), + 1 + num_speculative_tokens), max_num_batched_tokens, + pin_memory, device, kernel_size_list) + for block_size, kernel_size_list in zip(block_sizes, kernel_sizes) + ] + + def append_row(self, block_ids: tuple[list[int], ...], + row_idx: int) -> None: + for i, block_table in enumerate(self.block_tables): + block_table.append_row(block_ids[i], row_idx) + + def add_row(self, block_ids: tuple[list[int], ...], row_idx: int) -> None: + for i, block_table in enumerate(self.block_tables): + block_table.add_row(block_ids[i], row_idx) + + def move_row(self, src: int, tgt: int) -> None: + for block_table in self.block_tables: + block_table.move_row(src, tgt) + + def swap_row(self, src: int, tgt: int) -> None: + for block_table in self.block_tables: + block_table.swap_row(src, tgt) + + def compute_slot_mapping(self, req_indices: np.ndarray, + positions: np.ndarray) -> None: + for block_table in self.block_tables: + block_table.compute_slot_mapping(req_indices, positions) + + def commit_block_table(self, num_reqs: int) -> None: + for block_table in self.block_tables: + block_table.commit_block_table(num_reqs) + + def commit_slot_mapping(self, num_tokens: int) -> None: + for block_table in self.block_tables: + block_table.commit_slot_mapping(num_tokens) + + def clear(self) -> None: + for block_table in self.block_tables: + block_table.clear() + + def __getitem__(self, idx: int) -> "BlockTable": + """Returns the BlockTable for the i-th KV cache group.""" + return self.block_tables[idx] diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index b4261fa2be..6b9ecd7924 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -19,11 +19,14 @@ import copy import gc -import math +import itertools import time +from collections import defaultdict +from collections.abc import Iterator from contextlib import contextmanager, nullcontext +from copy import deepcopy from dataclasses import dataclass -from typing import TYPE_CHECKING, Dict, List, Optional, Union, cast +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast import numpy as np import numpy.typing as npt @@ -33,10 +36,12 @@ import torch.nn as nn from tqdm import tqdm # type: ignore from vllm.attention import AttentionType, get_attn_backend +from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.layer import Attention from vllm.compilation.counter import compilation_counter from vllm.compilation.monitor import set_cudagraph_capturing_enabled -from vllm.config import CompilationLevel, CUDAGraphMode, VllmConfig +from vllm.config import (CompilationLevel, CUDAGraphMode, VllmConfig, + get_layers_from_vllm_config) from vllm.distributed import tensor_model_parallel_all_gather from vllm.distributed.kv_transfer import 
(get_kv_transfer_group, has_kv_transfer_group) @@ -46,7 +51,8 @@ is_global_first_rank) from vllm.forward_context import BatchDescriptor, get_forward_context from vllm.logger import logger -from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase +from vllm.model_executor.layers.mamba.abstract import MambaBase from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding from vllm.model_executor.model_loader import get_model from vllm.model_executor.models.interfaces import supports_transcription @@ -59,28 +65,32 @@ from vllm.sequence import IntermediateTensors, PoolerOutput from vllm.tasks import GenerationTask, PoolingTask, SupportedTask from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, - LazyLoader, cdiv, is_pin_memory_available) + LazyLoader, cdiv, get_dtype_size, + is_pin_memory_available) +from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder +from vllm.v1.attention.backends.utils import \ + reorder_batch_to_split_decodes_and_prefills from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher -from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, - KVCacheSpec) +from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec, + KVCacheConfig, KVCacheSpec, MambaSpec) from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput, DraftTokenIds, LogprobsTensors, ModelRunnerOutput) from vllm.v1.pool.metadata import PoolingMetadata from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.ngram_proposer import NgramProposer +from vllm.v1.utils import CpuGpuBuffer from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorOutput from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin -from vllm.v1.worker.utils import (bind_kv_cache, gather_mm_placeholders, +from vllm.v1.worker.utils import (AttentionGroup, bind_kv_cache, + gather_mm_placeholders, sanity_check_mm_encoder_outputs, scatter_mm_placeholders) from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ascend_forward_context import set_ascend_forward_context from vllm_ascend.attention.attention_mask import AttentionMaskBuilder -from vllm_ascend.attention.attention_v1 import (AscendAttentionState, - AscendMetadata) -from vllm_ascend.attention.mla_v1 import AscendMLAMetadata +from vllm_ascend.attention.attention_v1 import AscendAttentionState from vllm_ascend.attention.utils import AscendCommonAttentionMetadata from vllm_ascend.compilation.acl_graph import ACLGraphWrapper from vllm_ascend.multistream.ms_split import compute_split_seq_index @@ -91,8 +101,6 @@ from vllm_ascend.spec_decode.eagle_proposer import EagleProposer from vllm_ascend.spec_decode.interface import SpecDcodeType from vllm_ascend.spec_decode.mtp_proposer import MtpProposer -from vllm_ascend.torchair.torchair_attention import AscendTorchairMetadata -from vllm_ascend.torchair.torchair_mla import AscendMLATorchairMetadata from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ, AscendSocVersion, ProfileExecuteDuration, get_ascend_soc_version, is_310p, @@ -241,14 +249,17 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): from vllm.v1.sample.sampler import Sampler self.sampler = Sampler() + self.reorder_batch_threshold: Optional[int] = None # Lazy initialization, these will be set after __init__ self.kv_caches: List[torch.Tensor] = 
[] + self.attn_groups: list[list[AttentionGroup]] = [] self.encoder_cache: Dict[str, torch.Tensor] = {} self.attn_mask = None self.attn_state = None self.requests: Dict[str, CachedRequestState] = {} self.intermediate_tensors: Optional[IntermediateTensors] = None + self.runner_only_attn_layers: set[str] = set() ascend_config = get_ascend_config() if ascend_config.ascend_scheduler_config.enabled: @@ -279,8 +290,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): self.model_config.is_attention_free, use_mla=self.model_config.use_mla, ) - self.attn_metadata_builder = self.attn_backend.get_builder_cls()( - vllm_config, device) + self.attn_mask_builder = AttentionMaskBuilder( self.model_config.max_model_len, self.dtype) @@ -412,6 +422,73 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): self.use_async_scheduling = self.scheduler_config.async_scheduling self.async_output_copy_stream = torch.npu.Stream() if \ self.use_async_scheduling else None + # Input Batch + # NOTE(Chen): Ideally, we should initialize the input batch inside + # `initialize_kv_cache` based on the kv cache config. However, as in + # https://github.com/vllm-project/vllm/pull/18298, due to some unknown + # reasons, we have to initialize the input batch before `load_model`, + # quantization + weight offloading will fail otherwise. As a temporary + # solution, we initialize the input batch here, and re-initialize it + # in `initialize_kv_cache` if the block_sizes here is different from + # the block_sizes in the kv cache config. + self.input_batch = InputBatch( + max_num_reqs=self.max_num_reqs, + max_model_len=self.model_config.max_model_len, + max_num_batched_tokens=self.max_num_tokens, + device=self.device, + pin_memory=self.pin_memory, + vocab_size=self.model_config.get_vocab_size(), + block_sizes=[self.block_size], + is_spec_decode=bool(self.vllm_config.speculative_config), + logitsprocs=build_logitsprocs( + self.vllm_config, self.device, self.pin_memory, + self.is_pooling_model, + self.vllm_config.model_config.logits_processors), + is_pooling_model=self.is_pooling_model, + kernel_block_sizes=None, + ) + self.num_accepted_tokens = self._make_buffer(self.max_num_reqs, + dtype=torch.int64) + self.num_draft_tokens = self._make_buffer(self.max_num_reqs, + dtype=torch.int32) + + def _make_buffer(self, + *size: Union[int, torch.SymInt], + dtype: torch.dtype, + numpy: bool = True) -> CpuGpuBuffer: + # Bfloat16 torch tensors cannot be directly cast to a numpy array, so + # if a bfloat16 buffer is needed without a corresponding numpy array, + # don't bother instantiating the numpy array. + return CpuGpuBuffer(*size, + dtype=dtype, + device=self.device, + pin_memory=self.pin_memory, + with_numpy=numpy) + + def _update_states_after_model_execute( + self, output_token_ids: torch.Tensor) -> None: + """Update the cached states after model execution. + + This is used for MTP/EAGLE for hybrid models, as in linear attention, + only the last token's state is kept. In MTP/EAGLE, for draft tokens + the state are kept util we decide how many tokens are accepted for + each sequence, and a shifting is done during the next iteration + based on the number of accepted tokens. + """ + if not self.model_config.is_hybrid or not self.speculative_config: + return + + # Find the number of accepted tokens for each sequence. 
+ num_accepted_tokens = (torch.cat( + [ + output_token_ids, + torch.full((output_token_ids.size(0), 1), + -1, + device=output_token_ids.device), + ], + dim=1) == -1).int().argmax(-1).cpu().numpy() + for i, num_tokens in enumerate(num_accepted_tokens): + self.input_batch.num_accepted_tokens_cpu[i] = num_tokens def _use_aclgraph(self) -> bool: return self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE and self.compilation_config.level == CompilationLevel.PIECEWISE and not self.model_config.enforce_eager @@ -611,7 +688,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # Condense the batched states if there are gaps left by removed requests self.input_batch.condense() - + # Allow attention backend to reorder the batch, potentially + self._may_reorder_batch(scheduler_output) # Refresh batch metadata with any pending updates. self.input_batch.refresh_metadata() @@ -970,22 +1048,42 @@ def _prepare_input_ids(self, total_num_scheduled_tokens: int, src=self.input_batch.prev_sampled_token_ids[ prev_common_req_indices_tensor, 0]) + def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None: + """ + Update the order of requests in the batch based on the attention + backend's needs. For example, some attention backends (namely MLA) may + want to separate requests based on if the attention computation will be + compute-bound or memory-bound. + + Args: + scheduler_output: The scheduler output. + """ + # Attention free models have zero kv_cache_goups, however models + # like Mamba are also attention free but use the kv_cache for + # keeping its internal state. This is why we check the number + # of kv_cache groups instead of solely checking + # for self.model_config.is_attention_free. + if len(self.kv_cache_config.kv_cache_groups) == 0: + return + + if self.reorder_batch_threshold is not None: + reorder_batch_to_split_decodes_and_prefills( + self.input_batch, + scheduler_output, + decode_threshold=self.reorder_batch_threshold) + def _prepare_inputs( self, scheduler_output: "SchedulerOutput", intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> tuple[Union[AscendMetadata, AscendMLAMetadata, AscendTorchairMetadata, - AscendMLATorchairMetadata], torch.Tensor, np.ndarray, int, - torch.Tensor, int, torch.Tensor, SpecDecodeMetadata, - Optional[torch.Tensor], Optional[torch.Tensor], - Optional[torch.Tensor]]: + ) -> tuple[dict[str, Any], torch.Tensor, np.ndarray, int, torch.Tensor, + int, torch.Tensor, SpecDecodeMetadata, Optional[torch.Tensor], + Optional[torch.Tensor], Optional[torch.Tensor]]: total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens assert total_num_scheduled_tokens > 0 num_reqs = self.input_batch.num_reqs assert num_reqs > 0 - self.attn_metadata_builder.reorder_batch(self.input_batch, - scheduler_output) # OPTIMIZATION: Start copying the block table first. # This way, we can overlap the copy with the following CPU operations. self.input_batch.block_table.commit_block_table(num_reqs) @@ -1088,9 +1186,6 @@ def _prepare_inputs( req_indices, positions_np) self.input_batch.block_table.commit_slot_mapping( total_num_scheduled_tokens) - self.slot_mapping_cpu[:total_num_scheduled_tokens].copy_( - self.input_batch.block_table[0]. 
- slot_mapping_cpu[:total_num_scheduled_tokens]) self.query_start_loc_np[0] = 0 self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens @@ -1131,32 +1226,7 @@ def _prepare_inputs( self.with_prefill = with_prefill self.num_tokens_across_dp = num_tokens_across_dp self._update_graph_pad_size(with_prefill, maybe_padded_num_tokens) - - # Make AscendCommonAttentionMetadata - common_attn_metadata = AscendCommonAttentionMetadata( - query_start_loc=self.query_start_loc[:num_reqs + 1], - query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1], - seq_lens_cpu=self.seq_lens_cpu, - num_reqs=num_reqs, - num_actual_tokens=total_num_scheduled_tokens, - actual_seq_lengths_q=self.actual_seq_lengths_q, - block_table_tensor=self.input_batch.block_table[0]. - get_device_tensor(), - slot_mapping_cpu=self.slot_mapping_cpu, - positions=self.positions, - attn_mask=self.attn_mask, - spec_attn_mask=self.spec_attn_mask, - attn_state=self.attn_state, - enable_dbo_across_dp=enable_dbo, - is_only_prefill=bool(np.all(num_valid_tokens != 1)), - max_query_len=max_num_scheduled_tokens, - graph_pad_size=self.graph_pad_size, - decode_token_per_req=self.decode_token_per_req, - ) - attn_metadata = self.attn_metadata_builder.build( - common_attn_metadata, self.model) - if self.vllm_config.model_config.use_mla: - attn_metadata.num_input_tokens = num_input_tokens + attn_metadata: dict[str, Any] = {} # Prepare input_ids token_indices = (positions_np + @@ -1238,6 +1308,90 @@ def _prepare_inputs( spec_decode_metadata = self._calc_spec_decode_metadata( num_draft_tokens, cu_num_tokens) logits_indices = spec_decode_metadata.logits_indices + self.num_draft_tokens.np[:num_reqs] = num_draft_tokens + self.num_draft_tokens.np[num_reqs:].fill(0) + self.num_draft_tokens.copy_to_gpu() + + # Used in the below loop. + # query_start_loc_cpu = self.query_start_loc.cpu[:num_reqs + 1] + num_computed_tokens_cpu = ( + self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs]) + spec_decode_common_attn_metadata = None + if use_spec_decode: + self.num_accepted_tokens.np[:num_reqs] = ( + self.input_batch.num_accepted_tokens_cpu[:num_reqs]) + self.num_accepted_tokens.np[num_reqs:].fill(1) + self.num_accepted_tokens.copy_to_gpu() + + # Prepare the attention metadata for each KV cache group and make layers + # in the same group share the same metadata. + for kv_cache_group_id, kv_cache_group_spec in enumerate( + self.kv_cache_config.kv_cache_groups): + blk_table = self.input_batch.block_table[kv_cache_group_id] + blk_table_tensor = blk_table.get_device_tensor() + slot_mapping = blk_table.slot_mapping_cpu[: + total_num_scheduled_tokens] + self.slot_mapping_cpu[:total_num_scheduled_tokens].copy_( + slot_mapping) + # # Fill unused with -1. Needed for reshape_and_cache in full cuda + # # graph mode. 
+ # blk_table.slot_mapping[total_num_scheduled_tokens:].fill_(-1) + + # Make AscendCommonAttentionMetadata + common_attn_metadata = AscendCommonAttentionMetadata( + query_start_loc=self.query_start_loc[:num_reqs + 1], + query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1], + seq_lens_cpu=self.seq_lens_cpu, + seq_lens=self.seq_lens_cpu[:num_reqs], + num_reqs=num_reqs, + num_actual_tokens=total_num_scheduled_tokens, + actual_seq_lengths_q=self.actual_seq_lengths_q, + # TODO: change this to the right block table for linear attn + block_table_tensor=blk_table_tensor[:num_reqs], + slot_mapping_cpu=self.slot_mapping_cpu, + num_computed_tokens_cpu=num_computed_tokens_cpu, + positions=self.positions, + attn_mask=self.attn_mask, + spec_attn_mask=self.spec_attn_mask, + attn_state=self.attn_state, + enable_dbo_across_dp=enable_dbo, + is_only_prefill=bool(np.all(num_valid_tokens != 1)), + max_query_len=max_num_scheduled_tokens, + graph_pad_size=self.graph_pad_size, + decode_token_per_req=self.decode_token_per_req, + ) + + if self.speculative_config and \ + spec_decode_common_attn_metadata is None: + spec_decode_common_attn_metadata = common_attn_metadata + + for attn_group in self.attn_groups[kv_cache_group_id]: + common_prefix_len = 0 + extra_attn_metadata_args = {} + builder = attn_group.metadata_builder + if isinstance(builder, GDNAttentionMetadataBuilder): + if use_spec_decode: + extra_attn_metadata_args = dict( + num_accepted_tokens=self.num_accepted_tokens. + gpu[:num_reqs], + num_draft_tokens=self.num_draft_tokens. + gpu[:num_reqs], + ) + attn_metadata_i = builder.build( + common_prefix_len=common_prefix_len, + common_attn_metadata=common_attn_metadata, + **extra_attn_metadata_args) + else: + attn_metadata_i = builder.build( + common_prefix_len=common_prefix_len, + common_attn_metadata=common_attn_metadata, + model=self.model, + **extra_attn_metadata_args) + + if self.vllm_config.model_config.use_mla: + attn_metadata_i.num_input_tokens = num_input_tokens + for layer_name in attn_group.layer_names: + attn_metadata[layer_name] = attn_metadata_i if lmhead_tp_enable(): max_num_reqs_across_dp = maybe_padded_num_tokens if not with_prefill else self.max_num_reqs @@ -1453,9 +1607,7 @@ def propose_draft_token_ids( positions: torch.Tensor, num_scheduled_tokens: int, hidden_states: torch.Tensor, - attn_metadata: Union[AscendMetadata, AscendMLAMetadata, - AscendTorchairMetadata, - AscendMLATorchairMetadata], + attn_metadata: dict[str, Any], aux_hidden_states: torch.Tensor = None, ) -> Optional[list[list[int]]]: if not self.drafter: @@ -1700,6 +1852,7 @@ def execute_model( sampling_metadata, ) sampler_output.sampled_token_ids = output_token_ids + self._update_states_after_model_execute(output_token_ids) discard_sampled_tokens_req_indices: list[int] = [] # TODO(woosuk): The following loop can be slow since it iterates over @@ -2231,31 +2384,22 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: kv_cache_config: Configuration for the KV cache, including the KV cache size of each layer """ + kv_cache_config = deepcopy(kv_cache_config) self.kv_cache_config = kv_cache_config - kv_caches: Dict[str, torch.Tensor] = {} + self.may_reinitialize_input_batch(kv_cache_config) + self.initialize_attn_backend(kv_cache_config) - def align_memory(tensor: torch.Tensor, alignment: int) -> torch.Tensor: - data_ptr = tensor.data_ptr() - aligned_addr = (data_ptr + alignment - 1) // alignment * alignment - offset = (aligned_addr - data_ptr) // tensor.element_size() - return tensor[int(offset):] + if 
self.model_config.is_deepseek_mla: + kv_caches = self.initialize_kv_cache_tensors_deepseek( + kv_cache_config) + else: + kv_caches = self.initialize_kv_cache_tensors(kv_cache_config) - self.input_batch = InputBatch( - max_num_reqs=self.max_num_reqs, - max_model_len=self.model_config.max_model_len, - max_num_batched_tokens=self.max_num_tokens, - device=self.device, - pin_memory=self.pin_memory, - vocab_size=self.model_config.get_vocab_size(), - block_sizes=[self.block_size], - is_spec_decode=bool(self.vllm_config.speculative_config), - logitsprocs=build_logitsprocs( - self.vllm_config, self.device, self.pin_memory, - self.is_pooling_model, - self.vllm_config.model_config.logits_processors), - is_pooling_model=self.is_pooling_model, - ) + if has_kv_transfer_group(): + get_kv_transfer_group().register_kv_caches(kv_caches) + def initialize_kv_cache_tensors_deepseek( + self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]: kv_cache_sizes = {} for kv_cache_tensor in kv_cache_config.kv_cache_tensors: assert len(kv_cache_tensor.shared_by) == 1, ( @@ -2263,12 +2407,131 @@ def align_memory(tensor: torch.Tensor, alignment: int) -> torch.Tensor: "NPU.") kv_cache_sizes[kv_cache_tensor.shared_by[0]] = kv_cache_tensor.size - for kv_cache_group in kv_cache_config.kv_cache_groups: - kv_cache_spec = kv_cache_group.kv_cache_spec + def align_memory(tensor: torch.Tensor, alignment: int) -> torch.Tensor: + data_ptr = tensor.data_ptr() + aligned_addr = (data_ptr + alignment - 1) // alignment * alignment + offset = (aligned_addr - data_ptr) // tensor.element_size() + return tensor[int(offset):] + + kv_caches: Dict[str, torch.Tensor] = {} + for kv_cache_spec, kv_cache_group in self._kv_cache_spec_attn_group_iterator( + ): + attn_backend = kv_cache_group.backend for layer_name in kv_cache_group.layer_names: + if layer_name in self.runner_only_attn_layers: + continue tensor_size = kv_cache_sizes[layer_name] - assert tensor_size % kv_cache_spec.page_size_bytes == 0 num_blocks = tensor_size // kv_cache_spec.page_size_bytes + if self.vllm_config.additional_config.get( + "kv_cache_dtype", None) == 'int8': + kv_cache_shape = attn_backend.get_bsh_kv_cache_shape( + num_blocks, kv_cache_spec.block_size, + kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) + elif hasattr(attn_backend, "get_supported_block_size" + ) and not self.model_config.is_deepseek_mla: + block_size = attn_backend.get_supported_block_size()[0] + block_size_chunk = kv_cache_spec.block_size // block_size + kv_cache_shape = attn_backend.get_kv_cache_shape( + num_blocks * block_size_chunk, block_size, + kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) + else: + kv_cache_shape = self.attn_backend.get_kv_cache_shape( + num_blocks, kv_cache_spec.block_size, + kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) + dtype = kv_cache_spec.dtype + + alignment = 2 * 1024 * 1024 + num_blocks, block_size, num_kv_heads, head_size = kv_cache_shape + rope_dim = self.model_config.hf_text_config.qk_rope_head_dim + nope_dim = head_size - rope_dim + nope_cache_shape = (num_blocks, block_size, num_kv_heads, + nope_dim) + rope_cache_shape = (num_blocks, block_size, num_kv_heads, + rope_dim) + if self.vllm_config.kv_transfer_config is None: + # For no disaggregate pd scenario, allocate kv cache in normal way + rope_cache = torch.zeros(rope_cache_shape, + dtype=dtype, + device=self.device) + nope_cache = torch.zeros(nope_cache_shape, + dtype=dtype, + device=self.device) + rope_cache = self._convert_torch_format(rope_cache) + nope_cache = 
self._convert_torch_format(nope_cache) + else: + + # In order to transfer kv cache through the reigster_memory api from llmdatadist, the memory + # address should be aligned by 2M. In most case, torch_npu can allocate 2M aligned memory, but + # we found there are also some exceptions during test, so we manual align those memory here, this part + # of code may consume 2M * 2 * elem_size memory every layer. + nope_allocate_shape = num_blocks * block_size * num_kv_heads * nope_dim + nope_allocate_shape_alignment = nope_allocate_shape + alignment + rope_allocate_shape = num_blocks * block_size * num_kv_heads * rope_dim + rope_allocate_shape_alignment = rope_allocate_shape + alignment + + nope_cache = torch.zeros(nope_allocate_shape_alignment, + dtype=dtype, + device=self.device) + rope_cache = torch.zeros(rope_allocate_shape_alignment, + dtype=dtype, + device=self.device) + nope_cache = align_memory( + nope_cache, + alignment)[:nope_allocate_shape].view(nope_cache_shape) + rope_cache = align_memory( + rope_cache, + alignment)[:rope_allocate_shape].view(rope_cache_shape) + kv_caches[layer_name] = (nope_cache, rope_cache) + + bind_kv_cache(kv_caches, + self.compilation_config.static_forward_context, + self.kv_caches) + + return kv_caches + + def initialize_kv_cache_tensors( + self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]: + """ + Initialize the memory buffer for KV cache. + + Args: + kv_cache_config: The KV cache config + Returns: + Dict[str, torch.Tensor]: A map between layer names to their + corresponding memory buffer for KV cache. + """ + # init kv cache tensors + kv_cache_raw_tensors: dict[str, torch.Tensor] = {} + for kv_cache_tensor in kv_cache_config.kv_cache_tensors: + for idx in range(len(kv_cache_tensor.shared_by)): + # TODO: REFACTOR ME to sharing hybrid cache + tensor = torch.zeros(kv_cache_tensor.size, + dtype=torch.int8, + device=self.device) + layer_name = kv_cache_tensor.shared_by[idx] + kv_cache_raw_tensors[layer_name] = tensor + + layer_names = set() + for group in kv_cache_config.kv_cache_groups: + for layer_name in group.layer_names: + if layer_name in self.runner_only_attn_layers: + continue + layer_names.add(layer_name) + assert layer_names == set(kv_cache_raw_tensors.keys( + )), "Some layers are not correctly initialized" + + kv_caches: Dict[str, torch.Tensor] = {} + for kv_cache_spec, kv_cache_group in self._kv_cache_spec_attn_group_iterator( + ): + attn_backend = kv_cache_group.backend + for layer_name in kv_cache_group.layer_names: + if layer_name in self.runner_only_attn_layers: + continue + raw_tensor = kv_cache_raw_tensors[layer_name] + + assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0 + num_blocks = raw_tensor.numel( + ) // kv_cache_spec.page_size_bytes # `num_blocks` is the number of blocks the model runner can use. # `kv_cache_config.num_blocks` is the number of blocks that @@ -2278,100 +2541,228 @@ def align_memory(tensor: torch.Tensor, alignment: int) -> torch.Tensor: # different GPUs, and `kv_cache_config.num_blocks` is set to # the min of all `num_blocks`. Verify it here. 
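The per-layer carving that follows turns each raw int8 allocation into the backend's typed KV-cache view without copying. A compact sketch with made-up layout values (the real shape and dtype come from attn_backend.get_kv_cache_shape and the KVCacheSpec):

# Sketch of viewing a raw int8 buffer as a typed KV cache, mirroring the
# FullAttentionSpec branch below (made-up sizes and layout).
import torch

block_size, num_kv_heads, head_size = 64, 2, 16
dtype = torch.float16
dtype_size = torch.tensor([], dtype=dtype).element_size()

# Bytes per block, counting both the K and the V plane.
page_size_bytes = 2 * block_size * num_kv_heads * head_size * dtype_size

raw_tensor = torch.zeros(8 * page_size_bytes, dtype=torch.int8)
num_blocks = raw_tensor.numel() // page_size_bytes           # -> 8

kv_cache_shape = (2, num_blocks, block_size, num_kv_heads, head_size)
kv_cache = raw_tensor.view(dtype).view(kv_cache_shape)       # views, no copy

# If the backend only supports a smaller kernel block size (say 32), the
# same bytes can instead be viewed as twice as many smaller blocks, which
# is what the block_size_chunk branch does.
chunked = raw_tensor.view(dtype).view(2, num_blocks * 2, 32,
                                      num_kv_heads, head_size)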
assert num_blocks >= kv_cache_config.num_blocks - alignment = 2 * 1024 * 1024 + # TODO: remove this after the OOM issue is located and fixed, otherwise, some model may # encounter OOM issue if isinstance(kv_cache_spec, FullAttentionSpec): if self.vllm_config.additional_config.get( "kv_cache_dtype", None) == 'int8': - kv_cache_shape = self.attn_backend.get_bsh_kv_cache_shape( + kv_cache_shape = attn_backend.get_bsh_kv_cache_shape( num_blocks, kv_cache_spec.block_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) + elif hasattr(attn_backend, "get_supported_block_size" + ) and not self.model_config.is_deepseek_mla: + block_size = attn_backend.get_supported_block_size()[0] + block_size_chunk = kv_cache_spec.block_size // block_size + kv_cache_shape = attn_backend.get_kv_cache_shape( + num_blocks * block_size_chunk, block_size, + kv_cache_spec.num_kv_heads, + kv_cache_spec.head_size) else: kv_cache_shape = self.attn_backend.get_kv_cache_shape( num_blocks, kv_cache_spec.block_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) dtype = kv_cache_spec.dtype - if self.model_config.is_deepseek_mla: - num_blocks, block_size, num_kv_heads, head_size = kv_cache_shape - rope_dim = self.model_config.hf_text_config.qk_rope_head_dim - nope_dim = head_size - rope_dim - nope_cache_shape = (num_blocks, block_size, - num_kv_heads, nope_dim) - rope_cache_shape = (num_blocks, block_size, - num_kv_heads, rope_dim) - if self.vllm_config.kv_transfer_config is None: - # For no disaggregate pd scenario, allocate kv cache in normal way - rope_cache = torch.zeros(rope_cache_shape, - dtype=dtype, - device=self.device) - nope_cache = torch.zeros(nope_cache_shape, - dtype=dtype, - device=self.device) - rope_cache = self._convert_torch_format(rope_cache) - nope_cache = self._convert_torch_format(nope_cache) - else: - - # In order to transfer kv cache through the reigster_memory api from llmdatadist, the memory - # address should be aligned by 2M. In most case, torch_npu can allocate 2M aligned memory, but - # we found there are also some exceptions during test, so we manual align those memory here, this part - # of code may consume 2M * 2 * elem_size memory every layer. 
- nope_allocate_shape = num_blocks * block_size * num_kv_heads * nope_dim - nope_allocate_shape_alignment = nope_allocate_shape + alignment - rope_allocate_shape = num_blocks * block_size * num_kv_heads * rope_dim - rope_allocate_shape_alignment = rope_allocate_shape + alignment - - nope_cache = torch.zeros( - nope_allocate_shape_alignment, - dtype=dtype, - device=self.device) - rope_cache = torch.zeros( - rope_allocate_shape_alignment, - dtype=dtype, - device=self.device) - nope_cache = align_memory( - nope_cache, - alignment)[:nope_allocate_shape].view( - nope_cache_shape) - rope_cache = align_memory( - rope_cache, - alignment)[:rope_allocate_shape].view( - rope_cache_shape) - kv_caches[layer_name] = (nope_cache, rope_cache) - else: - num_caches = kv_cache_shape[0] - kv_cache_list = [] - for i in range(num_caches): - cache_shape = kv_cache_shape[1:] - if self.vllm_config.kv_transfer_config is None: - kv_cache = torch.zeros(cache_shape, - dtype=dtype, - device=self.device) - kv_cache = self._convert_torch_format(kv_cache) - else: - cache_size = math.prod(cache_shape) - cache_size_aligned = cache_size + alignment - kv_cache = torch.zeros(cache_size_aligned, - dtype=dtype, - device=self.device) - kv_cache = align_memory( - kv_cache, - alignment)[:cache_size].view(cache_shape) - kv_cache_list.append(kv_cache) - kv_caches[layer_name] = tuple(kv_cache_list) + kv_cache = raw_tensor.view(dtype).view(kv_cache_shape) + kv_cache = self._convert_torch_format(kv_cache) + kv_caches[layer_name] = kv_cache + elif isinstance(kv_cache_spec, MambaSpec): + raw_tensor = kv_cache_raw_tensors[layer_name] + state_tensors = [] + storage_offset_bytes = 0 + for (shape, dtype) in zip(kv_cache_spec.shapes, + kv_cache_spec.dtypes): + dtype_size = get_dtype_size(dtype) + num_element_per_page = ( + kv_cache_spec.page_size_bytes // dtype_size) + target_shape = (num_blocks, *shape) + stride = torch.empty(target_shape).stride() + target_stride = (num_element_per_page, *stride[1:]) + assert storage_offset_bytes % dtype_size == 0 + tensor = torch.as_strided( + raw_tensor.view(dtype), + size=target_shape, + stride=target_stride, + storage_offset=storage_offset_bytes // dtype_size, + ) + state_tensors.append(tensor) + storage_offset_bytes += stride[0] * dtype_size + kv_caches[layer_name] = state_tensors else: - # TODO: add new branches when introducing more types of - # KV cache specs. raise ValueError("Unknown KV cache spec type.") bind_kv_cache(kv_caches, self.compilation_config.static_forward_context, self.kv_caches) - if has_kv_transfer_group(): - get_kv_transfer_group().register_kv_caches(kv_caches) + return kv_caches + + def _kv_cache_spec_attn_group_iterator( + self) -> Iterator[tuple[KVCacheSpec, AttentionGroup]]: + if not self.kv_cache_config.kv_cache_groups: + return + for kv_cache_spec_id, attn_groups in enumerate(self.attn_groups): + for attn_group in attn_groups: + yield self.kv_cache_config.kv_cache_groups[ + kv_cache_spec_id].kv_cache_spec, attn_group + + def may_reinitialize_input_batch(self, + kv_cache_config: KVCacheConfig) -> None: + """ + Re-initialize the input batch if the block sizes are different from + `[self.cache_config.block_size]`. This usually happens when there + are multiple KV cache groups. + + Args: + kv_cache_config: The KV cache configuration. 
+ """ + block_sizes = [ + kv_cache_group.kv_cache_spec.block_size + for kv_cache_group in kv_cache_config.kv_cache_groups + ] + + # Generate kernel_block_sizes that matches each block_size + # For attention backends that support virtual block splitting, + # use the supported block sizes from the backend + # For other backends (like Mamba), use [0] (no splitting) + kernel_block_sizes = [] + for kv_cache_group_id, kv_cache_group in enumerate( + kv_cache_config.kv_cache_groups): + if isinstance(kv_cache_group.kv_cache_spec, AttentionSpec): + # This is an attention backend that supports virtual + # block splitting. Get the supported block sizes from + # the backend. + try: + attn_groups = self.attn_groups[kv_cache_group_id] + except IndexError: + attn_groups = None + if attn_groups: + # Use the backend's supported block size list + backend = attn_groups[0].backend + supported_sizes = backend.get_supported_block_size() + # If no specific sizes supported, use cache config + # block_size + kernel_block_size_list = (supported_sizes + if supported_sizes else + [self.cache_config.block_size]) + else: + # Fallback to cache config block_size if no backend found + kernel_block_size_list = [ + 64 + ] if not self.model_config.is_deepseek_mla else [0] + kernel_block_sizes.append(kernel_block_size_list) + else: + # This is likely Mamba or other non-attention cache, + # no splitting. + kernel_block_sizes.append([0]) + if kernel_block_sizes != [self.cache_config.block_size]: + assert self.cache_config.cpu_offload_gb == 0, ( + "Cannot re-initialize the input batch when CPU weight " + "offloading is enabled. See https://github.com/vllm-project/vllm/pull/18298 " # noqa: E501 + "for more details.") + self.input_batch = InputBatch( + max_num_reqs=self.max_num_reqs, + max_model_len=self.model_config.max_model_len, + max_num_batched_tokens=self.max_num_tokens, + device=self.device, + pin_memory=self.pin_memory, + vocab_size=self.model_config.get_vocab_size(), + block_sizes=block_sizes, + is_spec_decode=bool(self.vllm_config.speculative_config), + logitsprocs=self.input_batch.logitsprocs, + is_pooling_model=self.is_pooling_model, + num_speculative_tokens=( + self.vllm_config.speculative_config.num_speculative_tokens + if self.vllm_config.speculative_config else 0), + kernel_block_sizes=kernel_block_sizes, + ) + + def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: + """ + Initialize the attention backends and attention metadata builders. + """ + assert len(self.attn_groups) == 0, \ + "Attention backends are already initialized" + + def get_attn_backends_for_layers( + layer_names: list[str] + ) -> dict[type[AttentionBackend], list[str]]: + layers = get_layers_from_vllm_config(self.vllm_config, + AttentionLayerBase, + layer_names) + attn_backends = {} + attn_backend_layers = defaultdict(list) + # Dedupe based on full class name; this is a bit safer than + # using the class itself as the key because when we create dynamic + # attention backend subclasses (e.g. ChunkedLocalAttention) unless + # they are cached correctly, there will be different objects per + # layer. 
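The kernel_block_sizes computed here end up in the BlockTable's kernel_size_list and, presumably whenever a kernel block is smaller than a cache block, drive the hybrid physical-to-logical splitting shown earlier in this patch. With a split ratio of 2 the mapping from _convert_physical_to_logical_blocks works out as follows (toy sizes):

# Toy illustration of splitting physical cache blocks into kernel-sized
# logical blocks, mirroring _convert_physical_to_logical_blocks above.
import numpy as np

cache_block_size = 128
kernel_block_size = 64
blocks_per_phys_block = cache_block_size // kernel_block_size   # -> 2

def physical_to_logical(physical_blocks: np.ndarray) -> np.ndarray:
    logical = []
    for phys in physical_blocks:
        # Block 0 is treated as special, so logical numbering starts at 1.
        base = (phys - 1) * blocks_per_phys_block + 1
        logical.extend(range(base, base + blocks_per_phys_block))
    return np.array(logical, dtype=np.int32)

print(physical_to_logical(np.array([1, 2, 5])))
# physical 1 -> logical 1,2; physical 2 -> 3,4; physical 5 -> 9,10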
+ for layer_name in layer_names: + attn_backend = layers[layer_name].get_attn_backend() + key = attn_backend.full_cls_name() + attn_backends[key] = attn_backend + attn_backend_layers[key].append(layer_name) + return { + attn_backends[k]: v + for k, v in attn_backend_layers.items() + } + + def create_attn_groups( + attn_backends_map: dict[AttentionBackend, list[str]], + kv_cache_spec: KVCacheSpec, + ) -> list[AttentionGroup]: + attn_groups: list[AttentionGroup] = [] + for attn_backend, layer_names in attn_backends_map.items(): + attn_metadata_builder_i = attn_backend.get_builder_cls()( + kv_cache_spec, + layer_names, + self.vllm_config, + self.device, + ) + attn_group = AttentionGroup(attn_backend, + attn_metadata_builder_i, + layer_names) + attn_groups.append(attn_group) + return attn_groups + + for kv_cache_group_spec in kv_cache_config.kv_cache_groups: + kv_cache_spec = kv_cache_group_spec.kv_cache_spec + attn_backends = get_attn_backends_for_layers( + kv_cache_group_spec.layer_names) + self.attn_groups.append( + create_attn_groups(attn_backends, kv_cache_spec)) + + # Calculate reorder batch threshold (if needed) + self.calculate_reorder_batch_threshold() + + def _attn_group_iterator(self) -> Iterator[AttentionGroup]: + return itertools.chain.from_iterable(self.attn_groups) + + def calculate_reorder_batch_threshold(self) -> None: + """ + Check that if any backends reorder batches; that the reordering + is compatible (e.g., decode threshold is the same) + """ + for group in self._attn_group_iterator(): + attn_metadata_builder_i = group.metadata_builder + if hasattr(attn_metadata_builder_i, "reorder_batch_threshold"): + # check that if any backends reorder batches; that the reordering + # is compatible (e.g., decode threshold is the same) + reorder_batch_threshold_i = ( + attn_metadata_builder_i.reorder_batch_threshold) + if reorder_batch_threshold_i is not None: + if self.reorder_batch_threshold is not None: + if reorder_batch_threshold_i != \ + self.reorder_batch_threshold: + raise ValueError( + f"Attention backend reorders decodes with " + f"threshold {reorder_batch_threshold_i} but other " + f"backend uses threshold " + f"{self.reorder_batch_threshold}") + else: + self.reorder_batch_threshold = reorder_batch_threshold_i def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: """ @@ -2382,19 +2773,29 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: format. Layers that do not need KV cache are not included. """ - forward_ctx = self.compilation_config.static_forward_context + block_size = self.vllm_config.cache_config.block_size use_mla = self.vllm_config.model_config.use_mla kv_cache_spec: dict[str, KVCacheSpec] = {} - for layer_name, attn_module in forward_ctx.items(): - if isinstance(attn_module, FusedMoE): + attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention) + for layer_name, attn_module in attn_layers.items(): + if (kv_tgt_layer := + attn_module.kv_sharing_target_layer_name) is not None: + # The layer doesn't need its own KV cache and will use that of + # the target layer. We skip creating a KVCacheSpec for it, so + # that KV cache management logic will act as this layer does + # not exist, and doesn't allocate KV cache for the layer. This + # enables the memory saving of cross-layer kv sharing, allowing + # a given amount of memory to accommodate longer context lengths + # or enable more requests to be processed simultaneously. 
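get_attn_backends_for_layers above is essentially a group-by keyed on the backend's fully qualified class name. A generic stand-in version (hypothetical backend classes and layer names; the real code keys on AttentionBackend.full_cls_name()):

# Stand-in sketch of grouping layers by their attention backend.
from collections import defaultdict

class FlashBackend: ...
class LinearBackend: ...

# Hypothetical mapping from layer name to the backend class it selected.
layer_to_backend = {
    "layers.0.attn": FlashBackend,
    "layers.1.attn": FlashBackend,
    "layers.2.linear_attn": LinearBackend,
}

backends = {}
backend_layers = defaultdict(list)
for layer_name, backend in layer_to_backend.items():
    key = f"{backend.__module__}.{backend.__qualname__}"   # dedupe key
    backends[key] = backend
    backend_layers[key].append(layer_name)

groups = {backends[k]: names for k, names in backend_layers.items()}
# -> {FlashBackend: ['layers.0.attn', 'layers.1.attn'],
#     LinearBackend: ['layers.2.linear_attn']}

Keying on the qualified class name rather than the class object keeps dynamically created backend subclasses from producing one group per layer.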
+ self.shared_kv_cache_layers[layer_name] = kv_tgt_layer continue - # TODO: Support other attention modules, e.g., sliding window, - # cross-attention - assert isinstance(attn_module, Attention) + # TODO: Support other attention modules, e.g., cross-attention + # TODO(lucas): move the attention specs into the model layers like + # the attention backends if attn_module.attn_type == AttentionType.DECODER: kv_cache_spec[layer_name] = FullAttentionSpec( - block_size=self.block_size, + block_size=block_size, num_kv_heads=attn_module.num_kv_heads, head_size=attn_module.head_size, dtype=self.kv_cache_dtype, @@ -2409,6 +2810,35 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: raise ValueError( f"Unknown attention type: {attn_module.attn_type}") + mamba_layers = get_layers_from_vllm_config(self.vllm_config, MambaBase) + if len(mamba_layers) > 0: + if (self.vllm_config.speculative_config is not None + and self.vllm_config.model_config.hf_config.model_type + not in ["qwen3_next"]): + raise NotImplementedError( + "Mamba with speculative decoding is not supported yet.") + if self.vllm_config.cache_config.enable_prefix_caching: + raise NotImplementedError( + "Prefix caching is not supported for Mamba yet.") + max_model_len = self.vllm_config.model_config.max_model_len + + page_size_padded = ( + self.vllm_config.cache_config.mamba_page_size_padded) + + # Set block_size to max_model_len, so that mamba model will always + # have only one block in the KV cache. + for layer_name, mamba_module in mamba_layers.items(): + kv_cache_spec[layer_name] = MambaSpec( + shapes=mamba_module.get_state_shape(), + dtypes=mamba_module.get_state_dtype(), + block_size=max_model_len, + page_size_padded=page_size_padded, + mamba_type=mamba_module.mamba_type, + num_speculative_blocks=( + self.speculative_config.num_speculative_tokens + if self.speculative_config else 0), + ) + return kv_cache_spec def initialize_aclgraph_capture(self) -> None: diff --git a/vllm_ascend/worker/npu_input_batch.py b/vllm_ascend/worker/npu_input_batch.py index 42eddcce08..ce37ff921c 100644 --- a/vllm_ascend/worker/npu_input_batch.py +++ b/vllm_ascend/worker/npu_input_batch.py @@ -37,7 +37,8 @@ from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.spec_decode.utils import is_spec_decode_unsupported from vllm.v1.utils import copy_slice -from vllm.v1.worker.block_table import MultiGroupBlockTable + +from vllm_ascend.worker.block_table import MultiGroupBlockTable @dataclass @@ -85,18 +86,19 @@ def get_token_id(self, idx: int) -> int: class InputBatch: def __init__( - self, - max_num_reqs: int, - max_model_len: int, - max_num_batched_tokens: int, - device: torch.device, - pin_memory: bool, - vocab_size: int, - block_sizes: list[int], # The block_size of each kv cache group - logitsprocs: Optional[LogitsProcessors] = None, - is_spec_decode: bool = False, - is_pooling_model: bool = False, - ): + self, + max_num_reqs: int, + max_model_len: int, + max_num_batched_tokens: int, + device: torch.device, + pin_memory: bool, + vocab_size: int, + block_sizes: list[int], # The block_size of each kv cache group + logitsprocs: Optional[LogitsProcessors] = None, + is_spec_decode: bool = False, + is_pooling_model: bool = False, + num_speculative_tokens: int = 0, + kernel_block_sizes: Optional[list[list[int]]] = None): self.is_pooling_model = is_pooling_model self.is_spec_decode = is_spec_decode self.max_num_reqs = max_num_reqs @@ -140,7 +142,8 @@ def __init__( pin_memory=pin_memory, device=device, block_sizes=block_sizes, - ) + 
num_speculative_tokens=num_speculative_tokens, + kernel_sizes=kernel_block_sizes) # Sampling-related. self.temperature = torch.empty((max_num_reqs, ), @@ -215,6 +218,14 @@ def __init__( self.repetition_penalties_cpu_tensor.numpy() self.repetition_penalties_reqs: set[str] = set() + # Speculative decoding + self.num_accepted_tokens_cpu_tensor = torch.ones((max_num_reqs, ), + dtype=torch.int64, + device="cpu", + pin_memory=pin_memory) + self.num_accepted_tokens_cpu = \ + self.num_accepted_tokens_cpu_tensor.numpy() + # lora related self.request_lora_mapping = np.zeros((self.max_num_reqs, ), dtype=np.int32) @@ -409,6 +420,9 @@ def add_request( else: raise NotImplementedError(request) + # Speculative decoding: by default 1 token is generated. + self.num_accepted_tokens_cpu[req_index] = 1 + # Add request lora ID if request.lora_request: lora_id = request.lora_request.lora_int_id @@ -508,6 +522,8 @@ def swap_states(self, i1: int, i2: int) -> None: self.presence_penalties_cpu[i2], self.presence_penalties_cpu[i1] self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] =\ self.repetition_penalties_cpu[i2], self.repetition_penalties_cpu[i1] + self.num_accepted_tokens_cpu[i1], self.num_accepted_tokens_cpu[i2] =\ + self.num_accepted_tokens_cpu[i2], self.num_accepted_tokens_cpu[i1] # NOTE: the following is unsafe # self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\ @@ -614,6 +630,8 @@ def condense(self) -> None: empty_index] = self.presence_penalties_cpu[last_req_index] self.repetition_penalties_cpu[ empty_index] = self.repetition_penalties_cpu[last_req_index] + self.num_accepted_tokens_cpu[ + empty_index] = self.num_accepted_tokens_cpu[last_req_index] generator = self.generators.pop(last_req_index, None) if generator is not None: self.generators[empty_index] = generator
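The num_accepted_tokens_cpu bookkeeping added above is filled by _update_states_after_model_execute earlier in this patch, which counts accepted draft tokens by locating the first rejected slot. A small torch sketch of that trick, assuming rejected draft positions are marked with -1 as in the runner code:

# Count accepted tokens per sequence: append a -1 sentinel column so that
# argmax always finds a first "rejected" position, even when every draft
# token was accepted.
import torch

output_token_ids = torch.tensor([
    [11, 12, -1, -1],   # 2 tokens accepted
    [21, -1, -1, -1],   # 1 token accepted
    [31, 32, 33, 34],   # all 4 accepted (sentinel column catches this case)
])

sentinel = torch.full((output_token_ids.size(0), 1), -1)
num_accepted = (torch.cat([output_token_ids, sentinel], dim=1) == -1) \
    .int().argmax(-1)
print(num_accepted)   # tensor([2, 1, 4])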