
Conversation

@hust17yixuan
Contributor

What this PR does / why we need it?

Does this PR introduce any user-facing change?

How was this patch tested?

Signed-off-by: hust17yixuan <303660421@qq.com>

github-actions bot commented Nov 5, 2025

👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:

  • A PR should do only one thing; smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests to ensure it works and is not broken by future PRs.
  • Write the commit message by filling out the PR description, so reviewers and future developers can understand the change.

If CI fails, you can run the linting and testing checks locally according to Contributing and Testing.

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces a new batch_matmul_transpose custom operator for Ascend NPUs to optimize a specific batched matrix multiplication pattern. The changes span the low-level kernel implementation, the host-side tiling logic, the PyTorch bindings, and the operator's usage in the mla_v1.py attention implementation. While the overall implementation appears to integrate the new operator correctly, I've identified a critical thread-safety issue in the host-side C++ code that could lead to race conditions and incorrect results. My review includes a specific code suggestion to resolve this critical bug.
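
For context, here is a minimal sketch of how a custom op like this is typically registered with PyTorch from C++; the namespace, schema, and CPU reference body below are illustrative assumptions, not the PR's actual binding code.

    #include <ATen/ATen.h>
    #include <torch/library.h>

    // Hypothetical reference semantics: per-batch C = (A @ B)^T.
    // The real operator would dispatch to a hand-written Ascend kernel instead.
    at::Tensor batch_matmul_transpose_ref(const at::Tensor& a, const at::Tensor& b) {
        return at::matmul(a, b).transpose(-2, -1).contiguous();
    }

    // Declare the schema once, then register per-backend implementations.
    TORCH_LIBRARY(vllm_ascend, m) {  // namespace name is an assumption
        m.def("batch_matmul_transpose(Tensor a, Tensor b) -> Tensor");
    }

    TORCH_LIBRARY_IMPL(vllm_ascend, CPU, m) {
        m.impl("batch_matmul_transpose", &batch_matmul_transpose_ref);
    }

Python code could then call it as torch.ops.vllm_ascend.batch_matmul_transpose(a, b), which is how attention code such as mla_v1.py would typically consume it.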

Comment on lines +104 to +118
int32_t batchIdx = opShape.m - 1;
uint32_t tilingSize = sizeof(pp_matmul::PpMatmulTilingData);
static auto global_tiling_data = at::empty(
    {tilingSize * MAX_CAPTURE_NUM},
    at::TensorOptions().dtype(at::kByte).device(tensor_a.options().device()));
if (batchIdx >= 0 && batchIdx < MAX_CAPTURE_NUM) {
    aclrtMemcpy(global_tiling_data.data_ptr<uint8_t>() + (tilingSize * batchIdx), tilingSize,
                &matmulTilingData, tilingSize, ACL_MEMCPY_HOST_TO_DEVICE);
} else {
    // Handle the case where batchIdx is out of range
    TORCH_CHECK(false, "batchIdx is out of range: ", batchIdx);
}
at::Tensor tiling_tensor =
    at::from_blob(global_tiling_data.data_ptr<uint8_t>() + (tilingSize * batchIdx), tilingSize, at::kByte);

return std::make_tuple(tiling_tensor, block_dim);

critical

The use of a static local variable global_tiling_data to cache tiling information on the device is not thread-safe. If multiple threads execute this function concurrently with the same opShape.m, they will write to the same memory location, causing a race condition. This can lead to incorrect tiling data being used by the kernel, resulting in wrong computations or crashes.

Given that PpMatmulTilingData is a small struct, it's safer and cleaner to allocate a new tensor for it on each call. This avoids the race condition and also removes the limitation of MAX_CAPTURE_NUM on the m dimension.

    uint32_t tilingSize = sizeof(pp_matmul::PpMatmulTilingData);
    auto tiling_tensor = at::empty({(int64_t)tilingSize}, at::TensorOptions().dtype(at::kByte).device(tensor_a.options().device()));
    aclrtMemcpy(tiling_tensor.data_ptr<uint8_t>(), tilingSize, &matmulTilingData,
                tilingSize, ACL_MEMCPY_HOST_TO_DEVICE);

    return std::make_tuple(tiling_tensor, block_dim);
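
To see why the cached-slot design races, here is a minimal CPU-only sketch (not part of the PR): two callers that map the same static cache slot through at::from_blob get non-owning views of the same memory, so the later write clobbers the earlier one, which is exactly what happens on the device when two threads share a batchIdx.

    #include <ATen/ATen.h>
    #include <iostream>

    int main() {
        // Stand-in for the static global_tiling_data cache (a single slot).
        static auto cache = at::zeros({16}, at::kByte);
        // Two non-owning views of the same slot, as at::from_blob produces.
        auto view_a = at::from_blob(cache.data_ptr<uint8_t>(), {16}, at::kByte);
        auto view_b = at::from_blob(cache.data_ptr<uint8_t>(), {16}, at::kByte);
        view_a.fill_(1);  // "thread A" stores its tiling data
        view_b.fill_(2);  // "thread B" overwrites the slot before A's kernel reads it
        std::cout << (int)view_a[0].item<uint8_t>() << std::endl;  // prints 2, not 1
    }

The suggested per-call allocation trades a small host-to-device copy per invocation for correctness, which is a reasonable cost given how small PpMatmulTilingData is.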
