Conversation

@kiscad kiscad commented Sep 28, 2025

What this PR does / why we need it?

  • Adds the mla_preprocess custom kernel to provide an optimized pre-processing operator for Multi-head Latent Attention (MLA) on Ascend NPUs.
  • Wires the new kernel into the C++ extension pipeline so vLLM can invoke it directly, cutting Python-side tensor shuffling and memory copies that previously bottlenecked MLA compilation paths.

Does this PR introduce any user-facing change?

  • No. The change only introduces a low-level kernel; public APIs and inference behavior remain unchanged.

How was this patch tested?

  • Dedicated Ascend kernels are not covered by our CI yet, so no extra automated tests were added. Future MLA-focused regression runs will cover this path.

  • vLLM version: v0.10.2

  • vLLM main: vllm-project/vllm@releases/v0.11.0

👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:

  • A PR should do only one thing, smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests to ensure it works and is not broken by future PRs.
  • Write the commit message and fill in the PR description to help reviewers and future developers understand the change.

If CI fails, you can run linting and testing checks locally according to Contributing and Testing.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a new mla_preprocess kernel, which involves significant changes across the host-side tiling logic, kernel implementation, and PyTorch bindings. The implementation is complex and adds a substantial amount of new code. My review focuses on critical aspects of correctness, performance, and thread safety. I've identified a critical race condition in the host-side tiling logic that needs to be addressed to prevent data corruption in multi-threaded environments. Additionally, there are opportunities to improve kernel performance by optimizing memory copy operations and to enhance correctness and performance on the host by replacing floating-point calculations with integer arithmetic for tiling parameters.

Comment on lines +680 to +686
static auto global_tiling_data =
at::empty({tilingSize * MAX_SUPPORT_TOKEN_NUMS},
at::TensorOptions().dtype(at::kByte).device(hiddenState.options().device()));
if (bIndex >= 0 && bIndex < MAX_SUPPORT_TOKEN_NUMS) {
aclrtMemcpy(global_tiling_data.data_ptr<uint8_t>() + (tilingSize * bIndex), tilingSize, &tilingData, tilingSize,
ACL_MEMCPY_HOST_TO_DEVICE);
} else {

critical

The use of a static local variable global_tiling_data to store tiling information creates a potential race condition. If multiple threads call mla_preprocess_tiling concurrently with the same number of tokens N, they will calculate the same bIndex and attempt to write to the same memory location in global_tiling_data simultaneously. This can lead to corrupted data and unpredictable kernel behavior. To ensure thread safety, this buffer should be managed in a thread-safe manner, for example by allocating it as part of a workspace that is unique per call, rather than using a shared static buffer.
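One way to address this, sketched below with illustrative names (`TilingData` and `PackTilingWorkspace` are hypothetical stand-ins, not identifiers from this PR): let each call pack its tiling bytes into a buffer it owns, instead of writing into a shared `static` array indexed by `bIndex`.

```cpp
#include <cstdint>
#include <cstring>
#include <vector>

// Stand-in for the real MlaTilingData (illustrative fields only).
struct TilingData {
    uint32_t n;
    uint32_t tilingKey;
};

// Per-call workspace: each caller owns its buffer, so two threads that
// compute the same bIndex can no longer clobber one another's tiling bytes.
std::vector<uint8_t> PackTilingWorkspace(const TilingData& tiling) {
    std::vector<uint8_t> workspace(sizeof(TilingData));
    std::memcpy(workspace.data(), &tiling, sizeof(TilingData));
    // In the real extension this memcpy would instead be an
    // aclrtMemcpy(..., ACL_MEMCPY_HOST_TO_DEVICE) into a freshly
    // allocated workspace tensor rather than a shared static buffer.
    return workspace;
}
```

The trade-off is one allocation per call instead of a one-time static buffer, which is usually negligible next to the kernel launch itself.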

void PpMatmulTilingApi::GetTileSize()
{
bool priFlag = !(m_ < n_);
uint32_t roundBase = pow(2, ceil(log(CeilDiv(priFlag ? n_ : m_, CONST_16)))) * CONST_16;

high

The calculation of roundBase uses floating-point functions (pow, ceil, log) for what appears to be an integer calculation to find the next power of two. This approach can be slow and may suffer from floating-point precision issues, potentially leading to incorrect tiling parameters. It is recommended to replace this with integer-based bit manipulation for both performance and correctness. The use of log (natural logarithm) is also likely incorrect if the intent is to work with powers of two; log2 would be the mathematical function to use, but bit manipulation avoids this entirely.

For example, to find the smallest power of two greater than or equal to an integer x, you can use a bit-twiddling algorithm:

// Example of a bit-twiddling algorithm to find the next power of 2
uint32_t next_power_of_2(uint32_t n) {
    if (n == 0) return 1;
    n--;
    n |= n >> 1;
    n |= n >> 2;
    n |= n >> 4;
    n |= n >> 8;
    n |= n >> 16;
    n++;
    return n;
}

// The original line could be refactored to use such a helper function.
uint32_t val = CeilDiv(priFlag ? n_ : m_, CONST_16);
uint32_t next_pow2_val = next_power_of_2(val);
uint32_t roundBase = next_pow2_val * CONST_16;

Comment on lines +40 to +132
mlaTilingData.tilingKey = tilingData->tilingKey;
mlaTilingData.n = tilingData->n;

mlaTilingData.mm1.numBatch = tilingData->mm1.numBatch;
mlaTilingData.mm1.m = tilingData->mm1.m;
mlaTilingData.mm1.k = tilingData->mm1.k;
mlaTilingData.mm1.n = tilingData->mm1.n;
mlaTilingData.mm1.m0 = tilingData->mm1.m0;
mlaTilingData.mm1.k0 = tilingData->mm1.k0;
mlaTilingData.mm1.n0 = tilingData->mm1.n0;
mlaTilingData.mm1.mLoop = tilingData->mm1.mLoop;
mlaTilingData.mm1.kLoop = tilingData->mm1.kLoop;
mlaTilingData.mm1.nLoop = tilingData->mm1.nLoop;
mlaTilingData.mm1.coreLoop = tilingData->mm1.coreLoop;
mlaTilingData.mm1.swizzleCount = tilingData->mm1.swizzleCount;
mlaTilingData.mm1.enShuffleK = tilingData->mm1.enShuffleK;
mlaTilingData.mm1.blockDim = tilingData->mm1.blockDim;
mlaTilingData.mm1.enLoadAllAmat = tilingData->mm1.enLoadAllAmat;
mlaTilingData.mm1.b0matPingPongBufferLen = tilingData->mm1.b0matPingPongBufferLen;

mlaTilingData.mm2.numBatch = tilingData->mm2.numBatch;
mlaTilingData.mm2.m = tilingData->mm2.m;
mlaTilingData.mm2.k = tilingData->mm2.k;
mlaTilingData.mm2.n = tilingData->mm2.n;
mlaTilingData.mm2.m0 = tilingData->mm2.m0;
mlaTilingData.mm2.k0 = tilingData->mm2.k0;
mlaTilingData.mm2.n0 = tilingData->mm2.n0;
mlaTilingData.mm2.mLoop = tilingData->mm2.mLoop;
mlaTilingData.mm2.kLoop = tilingData->mm2.kLoop;
mlaTilingData.mm2.nLoop = tilingData->mm2.nLoop;
mlaTilingData.mm2.coreLoop = tilingData->mm2.coreLoop;
mlaTilingData.mm2.swizzleCount = tilingData->mm2.swizzleCount;
mlaTilingData.mm2.enShuffleK = tilingData->mm2.enShuffleK;
mlaTilingData.mm2.blockDim = tilingData->mm2.blockDim;
mlaTilingData.mm2.enLoadAllAmat = tilingData->mm2.enLoadAllAmat;
mlaTilingData.mm2.b0matPingPongBufferLen = tilingData->mm2.b0matPingPongBufferLen;

mlaTilingData.mm3.numBatch = tilingData->mm3.numBatch;
mlaTilingData.mm3.m = tilingData->mm3.m;
mlaTilingData.mm3.k = tilingData->mm3.k;
mlaTilingData.mm3.n = tilingData->mm3.n;
mlaTilingData.mm3.m0 = tilingData->mm3.m0;
mlaTilingData.mm3.k0 = tilingData->mm3.k0;
mlaTilingData.mm3.n0 = tilingData->mm3.n0;
mlaTilingData.mm3.mLoop = tilingData->mm3.mLoop;
mlaTilingData.mm3.kLoop = tilingData->mm3.kLoop;
mlaTilingData.mm3.nLoop = tilingData->mm3.nLoop;
mlaTilingData.mm3.coreLoop = tilingData->mm3.coreLoop;
mlaTilingData.mm3.swizzleCount = tilingData->mm3.swizzleCount;
mlaTilingData.mm3.enShuffleK = tilingData->mm3.enShuffleK;
mlaTilingData.mm3.blockDim = tilingData->mm3.blockDim;

mlaTilingData.perTaskNum = tilingData->perTaskNum;
mlaTilingData.resTaskNum = tilingData->resTaskNum;
mlaTilingData.numCore = tilingData->numCore;

mlaTilingData.rmsNumCore1 = tilingData->rmsNumCore1;
mlaTilingData.rmsNumCol1 = tilingData->rmsNumCol1;
mlaTilingData.rmsNumCore2 = tilingData->rmsNumCore2;
mlaTilingData.rmsNumCol2 = tilingData->rmsNumCol2;

mlaTilingData.hiddenSizeQ = tilingData->hiddenSizeQ;
mlaTilingData.headNumQ = tilingData->headNumQ;
mlaTilingData.headDim = tilingData->headDim;
mlaTilingData.concatSize = tilingData->concatSize;
mlaTilingData.rotaryCoeff = tilingData->rotaryCoeff;
mlaTilingData.ntokens = tilingData->ntokens;
mlaTilingData.realCore = tilingData->realCore;
mlaTilingData.nlCoreRun = tilingData->nlCoreRun;
mlaTilingData.lCoreRun = tilingData->lCoreRun;
mlaTilingData.maxNPerLoopForUb = tilingData->maxNPerLoopForUb;
mlaTilingData.preCoreLoopTime = tilingData->preCoreLoopTime;
mlaTilingData.preCoreLoopNLast = tilingData->preCoreLoopNLast;
mlaTilingData.lastCoreLoopTime = tilingData->lastCoreLoopTime;
mlaTilingData.lastCoreLoopNLast = tilingData->lastCoreLoopNLast;

mlaTilingData.esqFrontCore = tilingData->esqFrontCore;
mlaTilingData.esqTailCore = tilingData->esqTailCore;
mlaTilingData.esqFrontCoreBatch = tilingData->esqFrontCoreBatch;
mlaTilingData.esqTailCoreBatch = tilingData->esqTailCoreBatch;
mlaTilingData.esqHeadNum = tilingData->esqHeadNum;
mlaTilingData.esqColNum = tilingData->esqColNum;
mlaTilingData.esqUbHeadLoop = tilingData->esqUbHeadLoop;
mlaTilingData.esqHeadPerLoop = tilingData->esqHeadPerLoop;
mlaTilingData.esqHeadTail = tilingData->esqHeadTail;
mlaTilingData.esqColLoop = tilingData->esqColLoop;
mlaTilingData.esqColTail = tilingData->esqColTail;

mlaTilingData.s1Offset = tilingData->s1Offset;
mlaTilingData.s2Offset = tilingData->s2Offset;
mlaTilingData.s3Offset = tilingData->s3Offset;
mlaTilingData.s4Offset = tilingData->s4Offset;
mlaTilingData.s5Offset = tilingData->s5Offset;

high

The kernel copies the MlaTilingData structure from global memory to the stack field by field. This results in numerous small, individual memory accesses from global memory, which is highly inefficient and can significantly degrade kernel performance. A single bulk memory copy of the entire MlaTilingData structure into a UB buffer, and then a memcpy to the stack variable (or direct use from UB), would be much more efficient.

Example of a more efficient copy:

MlaTilingData mlaTilingData;
// Allocate a temporary buffer in UB
AscendC::TBuf<AscendC::TPosition::VECCALC> tilingUB;
AscendC::TPipe pipe;
pipe.InitBuffer(tilingUB, sizeof(MlaTilingData));

// Copy from GM to UB
AscendC::DataCopy(tilingUB.Get<uint8_t>(), reinterpret_cast<__gm__ uint8_t*>(tiling), sizeof(MlaTilingData));
pipe.Barrier();

// Copy from UB to stack
memcpy(&mlaTilingData, tilingUB.Get<uint8_t>().GetPhyAddr(), sizeof(MlaTilingData));

qnope_scale_ptr, q_out0_ptr, kv_cache_out0_ptr, q_out1_ptr, kv_cache_out1_ptr, workspace_ptr,
tiling_ptr, block_dim]() -> int {
mla_preprocess_impl(stream, hidden_state_ptr, gamma0_ptr, beta0_ptr, quant_scale0_ptr, quant_offset0_ptr, wdqkv_ptr, bias0_ptr,
gamma1_ptr, beta1_ptr, quant_scale1_ptr, quant_offset1_ptr, gamma2_ptr, sin_ptr, cos_ptr, sin_ptr, cos_ptr,

high

The mla_preprocess_impl function is called with sin_ptr and cos_ptr passed for both (sin1, cos1) and (sin2, cos2) arguments. This suggests a potential mismatch between the caller's intent and the kernel's expectation, or that the kernel signature is redundant. If sin2 and cos2 are intended to be different from sin1 and cos1, this is a bug. If they are always the same, the kernel signature and the host-side function should be simplified to accept only one set of sin/cos tensors to avoid confusion and potential future errors.
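If the two pairs are indeed always identical, the simplification could look like the following sketch (all names here are illustrative stubs, not the PR's actual signatures): the public entry point accepts a single sin/cos pair and fans it out internally, so call sites cannot accidentally pass mismatched tables.

```cpp
#include <vector>

// Records the arguments the internal launcher received (test hook only).
static std::vector<const float*> g_last_args;

// Stand-in for the four-pointer internal launcher (mla_preprocess_impl).
void mla_preprocess_impl_stub(const float* sin1, const float* cos1,
                              const float* sin2, const float* cos2) {
    g_last_args = {sin1, cos1, sin2, cos2};
}

// Simplified public entry point: one sin/cos pair, forwarded to both
// rope stages, keeping the redundancy in exactly one place.
void mla_preprocess(const float* sin_tbl, const float* cos_tbl) {
    mla_preprocess_impl_stub(sin_tbl, cos_tbl, sin_tbl, cos_tbl);
}
```

If the kernel genuinely needs two distinct pairs for some configurations, the fix is the opposite: thread separate `sin2`/`cos2` tensors through from the Python binding instead of silently aliasing them.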

Signed-off-by: mojave2 <chenchen145@huawei.com>