add mla_preprocess kernel #3226
base: main
Conversation
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:
If CI fails, you can run linting and testing checks locally according to Contributing and Testing.
Code Review
This pull request introduces a new `mla_preprocess` kernel, which involves significant changes across the host-side tiling logic, kernel implementation, and PyTorch bindings. The implementation is complex and adds a substantial amount of new code. My review focuses on critical aspects of correctness, performance, and thread safety. I've identified a critical race condition in the host-side tiling logic that needs to be addressed to prevent data corruption in multi-threaded environments. Additionally, there are opportunities to improve kernel performance by optimizing memory copy operations and to enhance correctness and performance on the host by replacing floating-point calculations with integer arithmetic for tiling parameters.
```cpp
static auto global_tiling_data =
    at::empty({tilingSize * MAX_SUPPORT_TOKEN_NUMS},
              at::TensorOptions().dtype(at::kByte).device(hiddenState.options().device()));
if (bIndex >= 0 && bIndex < MAX_SUPPORT_TOKEN_NUMS) {
    aclrtMemcpy(global_tiling_data.data_ptr<uint8_t>() + (tilingSize * bIndex), tilingSize, &tilingData, tilingSize,
                ACL_MEMCPY_HOST_TO_DEVICE);
} else {
```
The use of a static local variable `global_tiling_data` to store tiling information creates a potential race condition. If multiple threads call `mla_preprocess_tiling` concurrently with the same number of tokens `N`, they will calculate the same `bIndex` and attempt to write to the same memory location in `global_tiling_data` simultaneously. This can lead to corrupted data and unpredictable kernel behavior. To ensure thread safety, this buffer should be managed in a thread-safe manner, for example by allocating it as part of a workspace that is unique per call, rather than using a shared static buffer.
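A minimal standalone sketch of the per-call alternative (`pack_tiling` and the raw-byte handling here are illustrative, not the real vLLM Ascend API): because each caller owns its buffer, two threads tiling the same token count can no longer clobber a shared slot.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>
#include <vector>

// Hypothetical helper: pack this call's tiling bytes into a buffer the
// caller owns, instead of indexing into one process-wide static tensor.
std::vector<uint8_t> pack_tiling(const void* tiling, size_t tilingSize) {
    std::vector<uint8_t> workspace(tilingSize);  // unique per call, no sharing
    std::memcpy(workspace.data(), tiling, tilingSize);
    return workspace;  // hand this to the kernel launch path
}
```

In the real code this would mean copying the tiling struct into the per-call workspace tensor (alongside the other scratch buffers) before launch, so no `bIndex` slot is ever shared between concurrent calls.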
```cpp
void PpMatmulTilingApi::GetTileSize()
{
    bool priFlag = !(m_ < n_);
    uint32_t roundBase = pow(2, ceil(log(CeilDiv(priFlag ? n_ : m_, CONST_16)))) * CONST_16;
```
The calculation of `roundBase` uses floating-point functions (`pow`, `ceil`, `log`) for what appears to be an integer calculation to find the next power of two. This approach can be slow and may suffer from floating-point precision issues, potentially leading to incorrect tiling parameters. It is recommended to replace this with integer-based bit manipulation for both performance and correctness. The use of `log` (natural logarithm) is also likely incorrect if the intent is to work with powers of two; `log2` would be the mathematical function to use, but bit manipulation avoids the issue entirely.

For example, to find the smallest power of two greater than or equal to an integer `x`, you can use a bit-twiddling algorithm:
```cpp
// Example of a bit-twiddling algorithm to find the next power of 2
uint32_t next_power_of_2(uint32_t n) {
    if (n == 0) return 1;
    n--;
    n |= n >> 1;
    n |= n >> 2;
    n |= n >> 4;
    n |= n >> 8;
    n |= n >> 16;
    n++;
    return n;
}

// The original line could be refactored to use such a helper function.
uint32_t val = CeilDiv(priFlag ? n_ : m_, CONST_16);
uint32_t next_pow2_val = next_power_of_2(val);
uint32_t roundBase = next_pow2_val * CONST_16;
```
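As a quick sanity check that the integer version rounds exactly as intended, here is a standalone sketch (assuming `CONST_16 == 16`; `ceil_div` stands in for the project's `CeilDiv`):

```cpp
#include <cstdint>

// next_power_of_2 as suggested above, reproduced so this check is standalone.
uint32_t next_power_of_2(uint32_t n) {
    if (n == 0) return 1;
    n--;
    n |= n >> 1; n |= n >> 2; n |= n >> 4; n |= n >> 8; n |= n >> 16;
    n++;
    return n;
}

// Integer ceiling division, standing in for the CeilDiv used by the tiling code.
uint32_t ceil_div(uint32_t a, uint32_t b) { return (a + b - 1) / b; }
```

For instance, a dimension of 300 gives `ceil_div(300, 16) == 19`, which rounds up to 32, so `roundBase` becomes 512; powers of two map to themselves with no floating-point rounding risk.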
```cpp
mlaTilingData.tilingKey = tilingData->tilingKey;
mlaTilingData.n = tilingData->n;

mlaTilingData.mm1.numBatch = tilingData->mm1.numBatch;
mlaTilingData.mm1.m = tilingData->mm1.m;
mlaTilingData.mm1.k = tilingData->mm1.k;
mlaTilingData.mm1.n = tilingData->mm1.n;
mlaTilingData.mm1.m0 = tilingData->mm1.m0;
mlaTilingData.mm1.k0 = tilingData->mm1.k0;
mlaTilingData.mm1.n0 = tilingData->mm1.n0;
mlaTilingData.mm1.mLoop = tilingData->mm1.mLoop;
mlaTilingData.mm1.kLoop = tilingData->mm1.kLoop;
mlaTilingData.mm1.nLoop = tilingData->mm1.nLoop;
mlaTilingData.mm1.coreLoop = tilingData->mm1.coreLoop;
mlaTilingData.mm1.swizzleCount = tilingData->mm1.swizzleCount;
mlaTilingData.mm1.enShuffleK = tilingData->mm1.enShuffleK;
mlaTilingData.mm1.blockDim = tilingData->mm1.blockDim;
mlaTilingData.mm1.enLoadAllAmat = tilingData->mm1.enLoadAllAmat;
mlaTilingData.mm1.b0matPingPongBufferLen = tilingData->mm1.b0matPingPongBufferLen;

mlaTilingData.mm2.numBatch = tilingData->mm2.numBatch;
mlaTilingData.mm2.m = tilingData->mm2.m;
mlaTilingData.mm2.k = tilingData->mm2.k;
mlaTilingData.mm2.n = tilingData->mm2.n;
mlaTilingData.mm2.m0 = tilingData->mm2.m0;
mlaTilingData.mm2.k0 = tilingData->mm2.k0;
mlaTilingData.mm2.n0 = tilingData->mm2.n0;
mlaTilingData.mm2.mLoop = tilingData->mm2.mLoop;
mlaTilingData.mm2.kLoop = tilingData->mm2.kLoop;
mlaTilingData.mm2.nLoop = tilingData->mm2.nLoop;
mlaTilingData.mm2.coreLoop = tilingData->mm2.coreLoop;
mlaTilingData.mm2.swizzleCount = tilingData->mm2.swizzleCount;
mlaTilingData.mm2.enShuffleK = tilingData->mm2.enShuffleK;
mlaTilingData.mm2.blockDim = tilingData->mm2.blockDim;
mlaTilingData.mm2.enLoadAllAmat = tilingData->mm2.enLoadAllAmat;
mlaTilingData.mm2.b0matPingPongBufferLen = tilingData->mm2.b0matPingPongBufferLen;

mlaTilingData.mm3.numBatch = tilingData->mm3.numBatch;
mlaTilingData.mm3.m = tilingData->mm3.m;
mlaTilingData.mm3.k = tilingData->mm3.k;
mlaTilingData.mm3.n = tilingData->mm3.n;
mlaTilingData.mm3.m0 = tilingData->mm3.m0;
mlaTilingData.mm3.k0 = tilingData->mm3.k0;
mlaTilingData.mm3.n0 = tilingData->mm3.n0;
mlaTilingData.mm3.mLoop = tilingData->mm3.mLoop;
mlaTilingData.mm3.kLoop = tilingData->mm3.kLoop;
mlaTilingData.mm3.nLoop = tilingData->mm3.nLoop;
mlaTilingData.mm3.coreLoop = tilingData->mm3.coreLoop;
mlaTilingData.mm3.swizzleCount = tilingData->mm3.swizzleCount;
mlaTilingData.mm3.enShuffleK = tilingData->mm3.enShuffleK;
mlaTilingData.mm3.blockDim = tilingData->mm3.blockDim;

mlaTilingData.perTaskNum = tilingData->perTaskNum;
mlaTilingData.resTaskNum = tilingData->resTaskNum;
mlaTilingData.numCore = tilingData->numCore;

mlaTilingData.rmsNumCore1 = tilingData->rmsNumCore1;
mlaTilingData.rmsNumCol1 = tilingData->rmsNumCol1;
mlaTilingData.rmsNumCore2 = tilingData->rmsNumCore2;
mlaTilingData.rmsNumCol2 = tilingData->rmsNumCol2;

mlaTilingData.hiddenSizeQ = tilingData->hiddenSizeQ;
mlaTilingData.headNumQ = tilingData->headNumQ;
mlaTilingData.headDim = tilingData->headDim;
mlaTilingData.concatSize = tilingData->concatSize;
mlaTilingData.rotaryCoeff = tilingData->rotaryCoeff;
mlaTilingData.ntokens = tilingData->ntokens;
mlaTilingData.realCore = tilingData->realCore;
mlaTilingData.nlCoreRun = tilingData->nlCoreRun;
mlaTilingData.lCoreRun = tilingData->lCoreRun;
mlaTilingData.maxNPerLoopForUb = tilingData->maxNPerLoopForUb;
mlaTilingData.preCoreLoopTime = tilingData->preCoreLoopTime;
mlaTilingData.preCoreLoopNLast = tilingData->preCoreLoopNLast;
mlaTilingData.lastCoreLoopTime = tilingData->lastCoreLoopTime;
mlaTilingData.lastCoreLoopNLast = tilingData->lastCoreLoopNLast;

mlaTilingData.esqFrontCore = tilingData->esqFrontCore;
mlaTilingData.esqTailCore = tilingData->esqTailCore;
mlaTilingData.esqFrontCoreBatch = tilingData->esqFrontCoreBatch;
mlaTilingData.esqTailCoreBatch = tilingData->esqTailCoreBatch;
mlaTilingData.esqHeadNum = tilingData->esqHeadNum;
mlaTilingData.esqColNum = tilingData->esqColNum;
mlaTilingData.esqUbHeadLoop = tilingData->esqUbHeadLoop;
mlaTilingData.esqHeadPerLoop = tilingData->esqHeadPerLoop;
mlaTilingData.esqHeadTail = tilingData->esqHeadTail;
mlaTilingData.esqColLoop = tilingData->esqColLoop;
mlaTilingData.esqColTail = tilingData->esqColTail;

mlaTilingData.s1Offset = tilingData->s1Offset;
mlaTilingData.s2Offset = tilingData->s2Offset;
mlaTilingData.s3Offset = tilingData->s3Offset;
mlaTilingData.s4Offset = tilingData->s4Offset;
mlaTilingData.s5Offset = tilingData->s5Offset;
```
The kernel copies the `MlaTilingData` structure from global memory to the stack field by field. This results in numerous small, individual memory accesses from global memory, which is highly inefficient and can significantly degrade kernel performance. A single bulk copy of the entire `MlaTilingData` structure into a UB buffer, followed by a `memcpy` to the stack variable (or direct use from UB), would be much more efficient.
Example of a more efficient copy:

```cpp
MlaTilingData mlaTilingData;
// Allocate a temporary buffer in UB
AscendC::TBuf<AscendC::TPosition::VECCALC> tilingUB;
AscendC::TPipe pipe;
pipe.InitBuffer(tilingUB, sizeof(MlaTilingData));
// Copy from GM to UB
AscendC::DataCopy(tilingUB.Get<uint8_t>(), reinterpret_cast<__gm__ uint8_t*>(tiling), sizeof(MlaTilingData));
pipe.Barrier();
// Copy from UB to stack
memcpy(&mlaTilingData, tilingUB.Get<uint8_t>().GetPhyAddr(), sizeof(MlaTilingData));
```
```cpp
    qnope_scale_ptr, q_out0_ptr, kv_cache_out0_ptr, q_out1_ptr, kv_cache_out1_ptr, workspace_ptr,
    tiling_ptr, block_dim]() -> int {
    mla_preprocess_impl(stream, hidden_state_ptr, gamma0_ptr, beta0_ptr, quant_scale0_ptr, quant_offset0_ptr, wdqkv_ptr, bias0_ptr,
                        gamma1_ptr, beta1_ptr, quant_scale1_ptr, quant_offset1_ptr, gamma2_ptr, sin_ptr, cos_ptr, sin_ptr, cos_ptr,
```
The `mla_preprocess_impl` function is called with `sin_ptr` and `cos_ptr` passed for both the `(sin1, cos1)` and `(sin2, cos2)` arguments. This suggests a potential mismatch between the caller's intent and the kernel's expectation, or that the kernel signature is redundant. If `sin2` and `cos2` are intended to differ from `sin1` and `cos1`, this is a bug. If they are always the same, the kernel signature and the host-side function should be simplified to accept only one set of sin/cos tensors to avoid confusion and potential future errors.
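If the two pairs really are always identical, a thin wrapper can keep the duplication in exactly one place until the underlying signature is cleaned up. A toy standalone sketch (hypothetical names; the real `mla_preprocess_impl` takes many more arguments):

```cpp
// Toy stand-in for the four-pointer rotary signature under review:
// reports whether it received one deduplicated sin/cos pair.
inline bool impl_sees_same_pair(const float* sin1, const float* cos1,
                                const float* sin2, const float* cos2) {
    return sin1 == sin2 && cos1 == cos2;
}

// Simplified entry point: callers pass a single sin/cos pair, and the
// forwarding needed by the wider signature happens in one place only.
inline bool call_with_single_pair(const float* sin, const float* cos) {
    return impl_sees_same_pair(sin, cos, sin, cos);
}
```

This keeps every call site honest about there being only one rotary table, while leaving the four-argument kernel untouched until it can be slimmed down.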
Signed-off-by: mojave2 <chenchen145@huawei.com>
What this PR does / why we need it?
Adds the `mla_preprocess` custom kernel to provide an optimized pre-processing operator for Multi-head Latent Attention (MLA) on Ascend NPUs.
Does this PR introduce any user-facing change?
How was this patch tested?
Dedicated Ascend kernels are not covered by our CI yet, so no extra automated tests were added. Future MLA-focused regression runs will cover this path.
vLLM version: v0.10.2
vLLM main: vllm-project/vllm@releases/v0.11.0