diff --git a/CMakeLists.txt b/CMakeLists.txt
index 005590445361..9c7d492af2ab 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -265,6 +265,7 @@ set(VLLM_EXT_SRC
     "csrc/pos_encoding_kernels.cu"
     "csrc/activation_kernels.cu"
     "csrc/layernorm_kernels.cu"
+    "csrc/fused_qknorm_rope_kernel.cu"
     "csrc/layernorm_quant_kernels.cu"
     "csrc/sampler.cu"
     "csrc/cuda_view.cu"
diff --git a/csrc/fused_qknorm_rope_kernel.cu b/csrc/fused_qknorm_rope_kernel.cu
new file mode 100644
index 000000000000..e3cee16de10a
--- /dev/null
+++ b/csrc/fused_qknorm_rope_kernel.cu
@@ -0,0 +1,360 @@
+/*
+ * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+#include <cuda_bf16.h>
+#include <cuda_fp16.h>
+#include <torch/all.h>
+
+#include <cstdint>
+#include <cstdio>
+
+#include "cuda_compat.h"
+#include "dispatch_utils.h"
+
+#define CHECK_TYPE(x, st)                                               \
+  TORCH_CHECK(x.scalar_type() == st, #x " dtype is ", x.scalar_type(), \
+              ", while ", st, " is expected")
+#define CHECK_TH_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CONTIGUOUS(x) \
+  TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
+#define CHECK_INPUT(x, st) \
+  CHECK_TH_CUDA(x);        \
+  CHECK_CONTIGUOUS(x);     \
+  CHECK_TYPE(x, st)
+
+#define FINAL_MASK 0xffffffff
+
+namespace tensorrt_llm::common {
+template <typename T, int N>
+struct packed_as;
+// Specializations of packed_as used in this kernel.
+template <>
+struct packed_as<uint, 1> {
+  using type = uint;
+};
+
+template <>
+struct packed_as<uint, 2> {
+  using type = uint2;
+};
+
+template <>
+struct packed_as<uint, 4> {
+  using type = uint4;
+};
+
+template <typename T>
+__inline__ __device__ T warpReduceSum(T val) {
+#pragma unroll
+  for (int mask = 16; mask > 0; mask >>= 1)
+    val += __shfl_xor_sync(FINAL_MASK, val, mask,
+                           32);  // __shfl_sync returns float for bf16 when sm < 80
+  return val;
+}
+
+template <typename T>
+inline __device__ __host__ T divUp(T m, T n) {
+  return (m + n - 1) / n;
+}
+
+}  // namespace tensorrt_llm::common
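The packed_as helpers above pick the 32-bit-word vector type that covers one lane's slice of a head. A minimal Python sketch of that sizing arithmetic for the head dims dispatched further below; the helper name and the dict layout are mine, purely illustrative:

# Each warp (32 lanes) owns one head; each lane loads head_dim/32 bf16 values,
# packed into a uint / uint2 / uint4 vector (packed_as<uint, vecSize>).
def lane_packing(head_dim: int) -> dict:
    elems_per_thread = head_dim // 32       # bf16 values per lane
    elem_size_bytes = elems_per_thread * 2  # sizeof(__nv_bfloat16) == 2
    vec_size = elem_size_bytes // 4         # number of 32-bit words per lane
    vec_type = {1: "uint", 2: "uint2", 4: "uint4"}[vec_size]
    return dict(elems_per_thread=elems_per_thread, bytes=elem_size_bytes,
                vec_size=vec_size, vec_type=vec_type)

print(lane_packing(64))   # {'elems_per_thread': 2, 'bytes': 4, 'vec_size': 1, 'vec_type': 'uint'}
print(lane_packing(128))  # {'elems_per_thread': 4, 'bytes': 8, 'vec_size': 2, 'vec_type': 'uint2'}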
+
+namespace tensorrt_llm::kernels {
+// NOTE(zhuhaoran): This kernel is adapted from the TensorRT-LLM implementation,
+// with added support for passing the cos_sin_cache as an input.
+// https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/fusedQKNormRopeKernel.cu
+
+// Perform per-head QK Norm and RoPE in a single kernel.
+// head_dim: the dimension of each head
+// interleave: interleave = !is_neox.
+template <int head_dim, bool interleave>
+__global__ void fusedQKNormRopeKernel(
+    __nv_bfloat16* qkv,     // Combined QKV tensor [num_tokens,
+                            // (num_heads_q+num_heads_k+num_heads_v)*head_dim]
+    int const num_heads_q,  // Number of query heads
+    int const num_heads_k,  // Number of key heads
+    int const num_heads_v,  // Number of value heads
+    float const eps,        // Epsilon for RMS normalization
+    __nv_bfloat16 const* q_weight,       // RMSNorm weights for query
+    __nv_bfloat16 const* k_weight,       // RMSNorm weights for key
+    __nv_bfloat16 const* cos_sin_cache,  // Pre-computed cos/sin cache
+    int64_t const* position_ids,         // Position IDs for RoPE
+    int const num_tokens                 // Number of tokens
+) {
+  int const warpsPerBlock = blockDim.x / 32;
+  int const warpId = threadIdx.x / 32;
+  int const laneId = threadIdx.x % 32;
+
+  // Calculate the global warp index to determine which head/token this warp
+  // processes
+  int const globalWarpIdx = blockIdx.x * warpsPerBlock + warpId;
+
+  // Total number of attention heads (Q and K)
+  int const total_qk_heads = num_heads_q + num_heads_k;
+
+  // Determine which token and head type (Q or K) this warp processes
+  int const tokenIdx = globalWarpIdx / total_qk_heads;
+  int const localHeadIdx = globalWarpIdx % total_qk_heads;
+
+  // Skip if this warp is assigned beyond the number of tokens
+  if (tokenIdx >= num_tokens) return;
+
+  bool const isQ = localHeadIdx < num_heads_q;
+  int const headIdx = isQ ? localHeadIdx : localHeadIdx - num_heads_q;
+
+  int const num_heads = num_heads_q + num_heads_k + num_heads_v;
+
+  static_assert(head_dim % (32 * 2) == 0,
+                "head_dim must be divisible by 64 (each warp processes one "
+                "head, and each thread gets an even number of elements)");
+  constexpr int numElemsPerThread = head_dim / 32;
+  float elements[numElemsPerThread];
+  constexpr int elemSizeBytes = numElemsPerThread * sizeof(__nv_bfloat16);
+  static_assert(elemSizeBytes % 4 == 0, "elemSizeBytes must be a multiple of 4");
+  constexpr int vecSize =
+      elemSizeBytes / 4;  // Use packed_as<uint, vecSize> to perform loading/saving.
+  using vec_T = typename tensorrt_llm::common::packed_as<uint, vecSize>::type;
+
+  int offsetWarp;  // Offset for the warp
+  if (isQ) {
+    // Q segment: token offset + head offset within Q segment
+    offsetWarp = tokenIdx * num_heads * head_dim + headIdx * head_dim;
+  } else {
+    // K segment: token offset + entire Q segment + head offset within K segment
+    offsetWarp = tokenIdx * num_heads * head_dim + num_heads_q * head_dim +
+                 headIdx * head_dim;
+  }
+  int offsetThread = offsetWarp + laneId * numElemsPerThread;
+
+  // Sum of squares for RMSNorm
+  float sumOfSquares = 0.0f;
+
+  // Load.
+  {
+    vec_T vec = *reinterpret_cast<vec_T*>(&qkv[offsetThread]);
+    for (int i = 0; i < vecSize; i++) {
+      float2 vals = __bfloat1622float2(*reinterpret_cast<__nv_bfloat162*>(
+          reinterpret_cast<uint*>(&vec) + i));
+      sumOfSquares += vals.x * vals.x;
+      sumOfSquares += vals.y * vals.y;
+
+      elements[2 * i] = vals.x;
+      elements[2 * i + 1] = vals.y;
+    }
+  }
+
+  // Reduce the sum across the warp using the utility function
+  sumOfSquares = tensorrt_llm::common::warpReduceSum(sumOfSquares);
+
+  // Compute the RMS normalization factor
+  float rms_rcp = rsqrtf(sumOfSquares / static_cast<float>(head_dim) + eps);
+
+  // Normalize elements
+  for (int i = 0; i < numElemsPerThread; i++) {
+    int dim = laneId * numElemsPerThread + i;
+    float weight =
+        isQ ? __bfloat162float(q_weight[dim]) : __bfloat162float(k_weight[dim]);
+    elements[i] *= rms_rcp * weight;
+  }
+
+  // Apply RoPE to the normalized elements
+  float elements2[numElemsPerThread];  // Additional buffer required for RoPE.
+
+  int64_t pos_id = position_ids[tokenIdx];
+
+  // Calculate the cache pointer for this position - similar to
+  // pos_encoding_kernels.cu
+  __nv_bfloat16 const* cache_ptr = cos_sin_cache + pos_id * head_dim;
+  int const embed_dim = head_dim / 2;
+  __nv_bfloat16 const* cos_ptr = cache_ptr;
+  __nv_bfloat16 const* sin_ptr = cache_ptr + embed_dim;
+
+  if constexpr (interleave) {
+    // Perform interleaving. Use pre-computed cos/sin values.
+    for (int i = 0; i < numElemsPerThread; i++) {
+      if (i % 2 == 0) {
+        elements2[i] = -elements[i + 1];
+      } else {
+        elements2[i] = elements[i - 1];
+      }
+
+      int dim_idx = laneId * numElemsPerThread + i;
+      int half_dim = dim_idx / 2;
+      // Use pre-computed cos/sin from the cache with optimized memory access
+      float cos_val = __bfloat162float(VLLM_LDG(cos_ptr + half_dim));
+      float sin_val = __bfloat162float(VLLM_LDG(sin_ptr + half_dim));
+
+      elements[i] = elements[i] * cos_val + elements2[i] * sin_val;
+    }
+  } else {
+    // Before exchanging data within the warp, we need to sync.
+    __syncwarp();
+    // Get the data from the other half of the warp. Use pre-computed cos/sin
+    // values.
+    for (int i = 0; i < numElemsPerThread; i++) {
+      elements2[i] = __shfl_xor_sync(0xffffffff, elements[i], 16);
+      if (laneId < 16) {
+        elements2[i] = -elements2[i];
+      }
+
+      int dim_idx = laneId * numElemsPerThread + i;
+      dim_idx = (dim_idx * 2) % head_dim;
+      int half_dim = dim_idx / 2;
+      // Use pre-computed cos/sin from the cache with optimized memory access
+      float cos_val = __bfloat162float(VLLM_LDG(cos_ptr + half_dim));
+      float sin_val = __bfloat162float(VLLM_LDG(sin_ptr + half_dim));
+
+      elements[i] = elements[i] * cos_val + elements2[i] * sin_val;
+    }
+    // __shfl_xor_sync does not provide a memory fence. Need to sync again.
+    __syncwarp();
+  }
+
+  // Store.
+  {
+    vec_T vec;
+    for (int i = 0; i < vecSize; i++) {
+      __nv_bfloat162 vals = __float22bfloat162_rn(
+          make_float2(elements[2 * i], elements[2 * i + 1]));
+      reinterpret_cast<__nv_bfloat162&>(*(reinterpret_cast<uint*>(&vec) + i)) =
+          vals;
+    }
+    vec_T* outputPtr = reinterpret_cast<vec_T*>(&qkv[offsetThread]);
+    *outputPtr = vec;
+  }
+}
+
+// Borrowed from
+// https://github.com/flashinfer-ai/flashinfer/blob/8125d079a43e9a0ba463a4ed1b639cefd084cec9/include/flashinfer/pos_enc.cuh#L568
+#define DISPATCH_INTERLEAVE(interleave, INTERLEAVE, ...) \
+  if (interleave) {                                      \
+    const bool INTERLEAVE = true;                        \
+    __VA_ARGS__                                          \
+  } else {                                               \
+    const bool INTERLEAVE = false;                       \
+    __VA_ARGS__                                          \
+  }
+
+void launchFusedQKNormRope(void* qkv, int const num_tokens,
+                           int const num_heads_q, int const num_heads_k,
+                           int const num_heads_v, int const head_dim,
+                           float const eps, void const* q_weight,
+                           void const* k_weight,
+                           __nv_bfloat16 const* cos_sin_cache,
+                           bool const interleave, int64_t const* position_ids,
+                           cudaStream_t stream) {
+  constexpr int blockSize = 256;
+
+  int const warpsPerBlock = blockSize / 32;
+  int const totalQKHeads = num_heads_q + num_heads_k;
+  int const totalWarps = num_tokens * totalQKHeads;
+
+  int const gridSize = common::divUp(totalWarps, warpsPerBlock);
+  dim3 gridDim(gridSize);
+  dim3 blockDim(blockSize);
+
+  // Head dimensions should be a multiple of 64.
+  // Add more cases as needed.
+  switch (head_dim) {
+    case 64:
+      DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
+        fusedQKNormRopeKernel<64, INTERLEAVE><<<gridDim, blockDim, 0, stream>>>(
+            reinterpret_cast<__nv_bfloat16*>(qkv), num_heads_q, num_heads_k,
+            num_heads_v, eps, reinterpret_cast<__nv_bfloat16 const*>(q_weight),
+            reinterpret_cast<__nv_bfloat16 const*>(k_weight), cos_sin_cache,
+            position_ids, num_tokens);
+      });
+      break;
+    case 128:
+      DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
+        fusedQKNormRopeKernel<128, INTERLEAVE>
+            <<<gridDim, blockDim, 0, stream>>>(
+                reinterpret_cast<__nv_bfloat16*>(qkv), num_heads_q, num_heads_k,
+                num_heads_v, eps,
+                reinterpret_cast<__nv_bfloat16 const*>(q_weight),
+                reinterpret_cast<__nv_bfloat16 const*>(k_weight), cos_sin_cache,
+                position_ids, num_tokens);
+      });
+      break;
+    default:
+      TORCH_CHECK(false,
+                  "Unsupported head dimension for fusedQKNormRope: ", head_dim);
+  }
+}
+}  // namespace tensorrt_llm::kernels
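For intuition about the launch configuration: each 32-thread warp handles one (token, Q-or-K head) pair and V heads are skipped entirely, so the grid size is just divUp(num_tokens * (num_heads_q + num_heads_k), warps_per_block). A small Python sketch of that arithmetic, with made-up token/head counts:

# Minimal sketch of the grid-sizing math in launchFusedQKNormRope.
def qknorm_rope_grid(num_tokens: int, num_heads_q: int, num_heads_k: int,
                     block_size: int = 256) -> tuple[int, int]:
    warps_per_block = block_size // 32
    total_warps = num_tokens * (num_heads_q + num_heads_k)
    grid_size = (total_warps + warps_per_block - 1) // warps_per_block  # divUp
    return grid_size, block_size

# Example values (hypothetical): a Qwen3-like layout with 32 Q heads and 8 K heads.
print(qknorm_rope_grid(num_tokens=4096, num_heads_q=32, num_heads_k=8))
# -> (20480, 256): 4096 * 40 = 163840 warps, 8 warps per 256-thread block.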
+
+void fused_qk_norm_rope(
+    torch::Tensor& qkv,   // Combined QKV tensor [num_tokens,
+                          // (num_heads_q+num_heads_k+num_heads_v)*head_dim]
+    int64_t num_heads_q,  // Number of query heads
+    int64_t num_heads_k,  // Number of key heads
+    int64_t num_heads_v,  // Number of value heads
+    int64_t head_dim,     // Dimension per head
+    double eps,           // Epsilon for RMS normalization
+    torch::Tensor& q_weight,       // RMSNorm weights for query [head_dim]
+    torch::Tensor& k_weight,       // RMSNorm weights for key [head_dim]
+    torch::Tensor& cos_sin_cache,  // Cos/sin cache [max_position, head_dim]
+    bool is_neox,                  // Whether RoPE is applied in Neox style
+    torch::Tensor& position_ids    // Position IDs for RoPE [num_tokens]
+) {
+  // Input validation
+  TORCH_CHECK(qkv.dim() == 2,
+              "QKV tensor must be 2D: [num_tokens, "
+              "(num_heads_q+num_heads_k+num_heads_v)*head_dim]");
+  TORCH_CHECK(position_ids.dim() == 1, "Position IDs must be 1D: [num_tokens]");
+  TORCH_CHECK(q_weight.dim() == 1, "Query weights must be 1D: [head_dim]");
+  TORCH_CHECK(k_weight.dim() == 1, "Key weights must be 1D: [head_dim]");
+  TORCH_CHECK(cos_sin_cache.dim() == 2,
+              "Cos/sin cache must be 2D: [max_position, head_dim]");
+  TORCH_CHECK(q_weight.size(0) == head_dim,
+              "Query weights size must match head dimension");
+  TORCH_CHECK(k_weight.size(0) == head_dim,
+              "Key weights size must match head dimension");
+  TORCH_CHECK(cos_sin_cache.size(1) == head_dim,
+              "Cos/sin cache dimension must match head_dim");
+
+  CHECK_INPUT(qkv, torch::kBFloat16);
+  CHECK_INPUT(position_ids, torch::kInt64);
+  CHECK_INPUT(q_weight, torch::kBFloat16);
+  CHECK_INPUT(k_weight, torch::kBFloat16);
+  CHECK_INPUT(cos_sin_cache, torch::kBFloat16);
+
+  int64_t num_tokens = qkv.size(0);
+  TORCH_CHECK(position_ids.size(0) == num_tokens,
+              "Number of tokens in position_ids must match QKV");
+
+  int64_t total_heads = num_heads_q + num_heads_k + num_heads_v;
+  TORCH_CHECK(
+      qkv.size(1) == total_heads * head_dim,
+      "QKV tensor size must match total number of heads and head dimension");
+
+  auto stream = at::cuda::getCurrentCUDAStream(qkv.get_device());
+
+  tensorrt_llm::kernels::launchFusedQKNormRope(
+      reinterpret_cast<__nv_bfloat16*>(qkv.data_ptr()),
+      static_cast<int>(num_tokens), static_cast<int>(num_heads_q),
+      static_cast<int>(num_heads_k), static_cast<int>(num_heads_v),
+      static_cast<int>(head_dim), static_cast<float>(eps),
+      reinterpret_cast<__nv_bfloat16 const*>(q_weight.data_ptr()),
+      reinterpret_cast<__nv_bfloat16 const*>(k_weight.data_ptr()),
+      reinterpret_cast<__nv_bfloat16 const*>(cos_sin_cache.data_ptr()),
+      !is_neox,  // interleave
+      reinterpret_cast<int64_t const*>(position_ids.data_ptr()), stream);
+}
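Semantically, the op applies per-head RMSNorm to the Q and K slices of qkv and then RoPE driven by cos_sin_cache, leaving V untouched. Below is a minimal, unfused PyTorch sketch of those semantics that could serve as a numerical baseline; the function name, the float32 intermediate math, and returning a new tensor instead of updating qkv in place are my own choices, not part of this change:

import torch

def fused_qk_norm_rope_ref(qkv, num_heads_q, num_heads_k, num_heads_v,
                           head_dim, eps, q_weight, k_weight,
                           cos_sin_cache, is_neox, position_ids):
    """Unfused reference: per-head RMSNorm on Q/K, then RoPE from cos_sin_cache."""
    num_tokens = qkv.shape[0]
    qkv = qkv.view(num_tokens, num_heads_q + num_heads_k + num_heads_v, head_dim)
    q, k, v = qkv.split([num_heads_q, num_heads_k, num_heads_v], dim=1)

    def rms_norm(x, w):
        xf = x.float()
        return xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + eps) * w.float()

    # cos_sin_cache row layout: first half cos, second half sin.
    cos, sin = cos_sin_cache[position_ids].float().chunk(2, dim=-1)
    cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)  # broadcast over heads

    def rope(x):
        if is_neox:  # rotate_half: pair (i, i + head_dim // 2)
            x1, x2 = x.chunk(2, dim=-1)
            rotated = torch.cat((-x2, x1), dim=-1)
            c = torch.cat((cos, cos), dim=-1)
            s = torch.cat((sin, sin), dim=-1)
        else:        # interleaved: pair (2i, 2i + 1)
            x1, x2 = x[..., 0::2], x[..., 1::2]
            rotated = torch.stack((-x2, x1), dim=-1).flatten(-2)
            c = cos.repeat_interleave(2, dim=-1)
            s = sin.repeat_interleave(2, dim=-1)
        return x * c + rotated * s

    q = rope(rms_norm(q, q_weight))
    k = rope(rms_norm(k, k_weight))
    out = torch.cat((q, k, v.float()), dim=1).to(qkv.dtype)
    return out.view(num_tokens, -1)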
diff --git a/csrc/ops.h b/csrc/ops.h
index c135a1404294..495364c5058f 100644
--- a/csrc/ops.h
+++ b/csrc/ops.h
@@ -92,6 +92,12 @@ void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
 void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
                         torch::Tensor& weight, double epsilon);
 
+void fused_qk_norm_rope(torch::Tensor& qkv, int64_t num_heads_q,
+                        int64_t num_heads_k, int64_t num_heads_v,
+                        int64_t head_dim, double eps, torch::Tensor& q_weight,
+                        torch::Tensor& k_weight, torch::Tensor& cos_sin_cache,
+                        bool is_neox, torch::Tensor& position_ids);
+
 void apply_repetition_penalties_(torch::Tensor& logits,
                                  const torch::Tensor& prompt_mask,
                                  const torch::Tensor& output_mask,
diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp
index 2bc526097d15..cfe3658be809 100644
--- a/csrc/torch_bindings.cpp
+++ b/csrc/torch_bindings.cpp
@@ -175,6 +175,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "float epsilon) -> ()");
   ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm);
 
+  // Function for fused QK Norm and RoPE
+  ops.def(
+      "fused_qk_norm_rope(Tensor! qkv, int num_heads_q, "
+      "int num_heads_k, int num_heads_v, int head_dim, float eps, "
+      "Tensor q_weight, Tensor k_weight, Tensor cos_sin_cache, "
+      "bool is_neox, Tensor position_ids) -> ()");
+  ops.impl("fused_qk_norm_rope", torch::kCUDA, &fused_qk_norm_rope);
+
   // Apply repetition penalties to logits in-place
   ops.def(
       "apply_repetition_penalties_(Tensor! logits, Tensor prompt_mask, "
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index 0618451c199a..a3139ba070c9 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -339,6 +339,34 @@ def fused_add_rms_norm(
     torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon)
 
 
+def fused_qk_norm_rope(
+    qkv: torch.Tensor,
+    num_heads_q: int,
+    num_heads_k: int,
+    num_heads_v: int,
+    head_dim: int,
+    eps: float,
+    q_weight: torch.Tensor,
+    k_weight: torch.Tensor,
+    cos_sin_cache: torch.Tensor,
+    is_neox: bool,
+    position_ids: torch.Tensor,
+) -> None:
+    torch.ops._C.fused_qk_norm_rope(
+        qkv,
+        num_heads_q,
+        num_heads_k,
+        num_heads_v,
+        head_dim,
+        eps,
+        q_weight,
+        k_weight,
+        cos_sin_cache,
+        is_neox,
+        position_ids,
+    )
+
+
 def apply_repetition_penalties_torch(
     logits: torch.Tensor,
     prompt_mask: torch.Tensor,
diff --git a/vllm/envs.py b/vllm/envs.py
index 3cf3444e20db..e6405a22418c 100755
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -188,6 +188,7 @@
     VLLM_NIXL_ABORT_REQUEST_TIMEOUT: int = 480
     VLLM_USE_CUDNN_PREFILL: bool = False
     VLLM_ENABLE_CUDAGRAPH_GC: bool = False
+    VLLM_FUSE_QKNORM_AND_ROPE: bool = False
     VLLM_LOOPBACK_IP: str = ""
     VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE: bool = False
     VLLM_ENABLE_RESPONSES_API_STORE: bool = False
@@ -1313,6 +1314,10 @@ def get_vllm_port() -> int | None:
         None,
         ["flashinfer-cudnn", "flashinfer-trtllm", "flashinfer-cutlass"],
     ),
+    # If set, use the fused QK Norm + RoPE kernel
+    "VLLM_FUSE_QKNORM_AND_ROPE": lambda: bool(
+        int(os.getenv("VLLM_FUSE_QKNORM_AND_ROPE", "0"))
+    ),
     # Controls garbage collection during CUDA graph capture.
     # If set to 0 (default), enables GC freezing to speed up capture time.
     # If set to 1, allows GC to run during capture.
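A usage sketch for the new opt-in path; the model name, shapes, epsilon, and the random stand-in cache below are illustrative only, not part of this change:

# Opt in to the fused QK Norm + RoPE path (off by default):
#   VLLM_FUSE_QKNORM_AND_ROPE=1 vllm serve Qwen/Qwen3-8B

import torch
from vllm import _custom_ops as ops

# Hypothetical shapes: 16 tokens, 32 Q heads, 8 K heads, 8 V heads, head_dim 128.
num_tokens, hq, hk, hv, hd = 16, 32, 8, 8, 128
qkv = torch.randn(num_tokens, (hq + hk + hv) * hd, dtype=torch.bfloat16, device="cuda")
q_w = torch.ones(hd, dtype=torch.bfloat16, device="cuda")
k_w = torch.ones(hd, dtype=torch.bfloat16, device="cuda")
cos_sin = torch.randn(4096, hd, dtype=torch.bfloat16, device="cuda")  # stand-in cache
pos = torch.arange(num_tokens, dtype=torch.int64, device="cuda")

# Normalizes and rotates the Q/K slices of `qkv` in place; V is left untouched.
ops.fused_qk_norm_rope(qkv, hq, hk, hv, hd, 1e-6, q_w, k_w, cos_sin, True, pos)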
diff --git a/vllm/model_executor/models/qwen3.py b/vllm/model_executor/models/qwen3.py
index 563d3cc23d72..fbd1fb56f72e 100644
--- a/vllm/model_executor/models/qwen3.py
+++ b/vllm/model_executor/models/qwen3.py
@@ -30,6 +30,8 @@
 from torch import nn
 from transformers import Qwen3Config
 
+from vllm import _custom_ops as ops
+from vllm import envs
 from vllm.attention import Attention, AttentionType
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
@@ -39,7 +41,7 @@
 from vllm.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
-from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding, get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.sequence import IntermediateTensors
 
@@ -117,6 +119,13 @@ def __init__(
             rope_scaling=rope_scaling,
             dual_chunk_attention_config=dual_chunk_attention_config,
         )
+        # Determine if we can use fused QK norm + RoPE
+        self.use_fused_qk_norm_rope = envs.VLLM_FUSE_QKNORM_AND_ROPE and isinstance(
+            self.rotary_emb, RotaryEmbedding
+        )
+        if self.use_fused_qk_norm_rope:
+            logger.info_once("Using fused QK norm + RoPE kernel for Qwen3Attention")
+
         self.attn = Attention(
             self.num_heads,
             self.head_dim,
@@ -136,21 +145,50 @@ def __init__(
         self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
         self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
 
+    def apply_qk_norm_rope(self, qkv, positions):
+        if self.use_fused_qk_norm_rope:
+            ops.fused_qk_norm_rope(
+                qkv,
+                self.num_heads,
+                self.num_kv_heads,
+                self.num_kv_heads,
+                self.head_dim,
+                self.q_norm.variance_epsilon,
+                self.q_norm.weight,
+                self.k_norm.weight,
+                self.rotary_emb.cos_sin_cache,
+                self.rotary_emb.is_neox_style,
+                positions.view(-1),
+            )
+            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        else:
+            # Fallback to non-fused QK Norm & RoPE implementation
+            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+
+            q_by_head = q.view(
+                *q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim
+            )
+            q_by_head = self.q_norm(q_by_head)
+            q = q_by_head.view(q.shape)
+
+            k_by_head = k.view(
+                *k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim
+            )
+            k_by_head = self.k_norm(k_by_head)
+            k = k_by_head.view(k.shape)
+
+            if q.size(0) > 0 and k.size(0) > 0:
+                q, k = self.rotary_emb(positions, q, k)
+
+        return q, k, v
+
     def forward(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
-        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
-        # Add qk-norm
-        q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim)
-        q_by_head = self.q_norm(q_by_head)
-        q = q_by_head.view(q.shape)
-        k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim)
-        k_by_head = self.k_norm(k_by_head)
-        k = k_by_head.view(k.shape)
-        q, k = self.rotary_emb(positions, q, k)
+        q, k, v = self.apply_qk_norm_rope(qkv, positions)
         attn_output = self.attn(q, k, v)
         output, _ = self.o_proj(attn_output)
         return output
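One detail worth noting in apply_qk_norm_rope above: the fused branch works because fused_qk_norm_rope mutates qkv in place (hence the Tensor! annotation in the op schema), so the subsequent split simply returns views of the already normalized-and-rotated buffer. A tiny stand-alone illustration of that view/in-place interaction, with stand-in tensors rather than model code:

import torch

# split() returns views, so an in-place update to the parent tensor is what the
# fused branch relies on: normalize/rotate qkv first, then split it.
qkv = torch.zeros(2, 6)
q, k, v = qkv.split([2, 2, 2], dim=-1)
qkv.add_(1.0)                          # stands in for ops.fused_qk_norm_rope(qkv, ...)
assert q.data_ptr() == qkv.data_ptr()  # q is a view into qkv
assert torch.all(q == 1.0)             # the views observe the in-place update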
diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py
index 8452d7b04f5c..d9fda5cf6e03 100644
--- a/vllm/model_executor/models/qwen3_moe.py
+++ b/vllm/model_executor/models/qwen3_moe.py
@@ -31,6 +31,8 @@
 import torch
 from torch import nn
 
+import vllm.envs as envs
+from vllm import _custom_ops as ops
 from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
@@ -52,7 +54,7 @@
 )
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
-from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding, get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
@@ -275,6 +277,13 @@ def __init__(
             rope_scaling=rope_scaling,
             dual_chunk_attention_config=dual_chunk_attention_config,
         )
+        # Determine if we can use fused QK norm + RoPE
+        self.use_fused_qk_norm_rope = envs.VLLM_FUSE_QKNORM_AND_ROPE and isinstance(
+            self.rotary_emb, RotaryEmbedding
+        )
+        if self.use_fused_qk_norm_rope:
+            logger.info_once("Using fused QK norm + RoPE kernel for Qwen3MoeAttention")
+
         self.attn = Attention(
             self.num_heads,
             self.head_dim,
@@ -294,22 +303,48 @@ def __init__(
         self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
         self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
 
+    def apply_qk_norm_rope(self, qkv, positions):
+        if self.use_fused_qk_norm_rope:
+            ops.fused_qk_norm_rope(
+                qkv,
+                self.num_heads,
+                self.num_kv_heads,
+                self.num_kv_heads,
+                self.head_dim,
+                self.q_norm.variance_epsilon,
+                self.q_norm.weight,
+                self.k_norm.weight,
+                self.rotary_emb.cos_sin_cache,
+                self.rotary_emb.is_neox_style,
+                positions.view(-1),
+            )
+            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        else:
+            # Fallback to non-fused QK Norm & RoPE implementation
+            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+
+            q_by_head = q.view(
+                *q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim
+            )
+            q_by_head = self.q_norm(q_by_head)
+            q = q_by_head.view(q.shape)
+            k_by_head = k.view(
+                *k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim
+            )
+            k_by_head = self.k_norm(k_by_head)
+            k = k_by_head.view(k.shape)
+
+            q, k = self.rotary_emb(positions, q, k)
+
+        return q, k, v
+
     def forward(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
-        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
-        # Add qk-norm
-        q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim)
-        q_by_head = self.q_norm(q_by_head)
-        q = q_by_head.view(q.shape)
-
-        k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim)
-        k_by_head = self.k_norm(k_by_head)
-        k = k_by_head.view(k.shape)
-        q, k = self.rotary_emb(positions, q, k)
+        q, k, v = self.apply_qk_norm_rope(qkv, positions)
         attn_output = self.attn(q, k, v)
         output, _ = self.o_proj(attn_output)
         return output
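A correctness check would compare the fused op against the unfused fallback path shown above. A minimal sketch of such a test, assuming a CUDA build of the extension and reusing vLLM's RMSNorm and get_rope as the reference; the shapes, rope base, and tolerances are my own choices:

import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.rotary_embedding import get_rope

@torch.inference_mode()
def check_fused_qk_norm_rope(num_tokens=32, hq=32, hk=8, hd=128, eps=1e-6):
    device, dtype = "cuda", torch.bfloat16
    q_size, kv_size = hq * hd, hk * hd
    qkv = torch.randn(num_tokens, q_size + 2 * kv_size, dtype=dtype, device=device)
    pos = torch.randint(0, 4096, (num_tokens,), dtype=torch.int64, device=device)

    q_norm = RMSNorm(hd, eps=eps).to(device, dtype)
    k_norm = RMSNorm(hd, eps=eps).to(device, dtype)
    rope = get_rope(hd, rotary_dim=hd, max_position=4096, base=1_000_000,
                    is_neox_style=True).to(device)

    # Reference: the unfused path from apply_qk_norm_rope above.
    q, k, v = qkv.clone().split([q_size, kv_size, kv_size], dim=-1)
    q = q_norm(q.view(num_tokens, hq, hd)).view(num_tokens, q_size)
    k = k_norm(k.view(num_tokens, hk, hd)).view(num_tokens, kv_size)
    q, k = rope(pos, q, k)

    # Fused: normalizes and rotates the Q/K slices of qkv in place.
    fused = qkv.clone()
    ops.fused_qk_norm_rope(fused, hq, hk, hk, hd, eps,
                           q_norm.weight, k_norm.weight,
                           rope.cos_sin_cache.to(dtype), True, pos)
    fq, fk, fv = fused.split([q_size, kv_size, kv_size], dim=-1)

    torch.testing.assert_close(fq, q, atol=2e-2, rtol=2e-2)
    torch.testing.assert_close(fk, k, atol=2e-2, rtol=2e-2)
    torch.testing.assert_close(fv, v, atol=0, rtol=0)  # V must be untouched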