-
Notifications
You must be signed in to change notification settings - Fork 553
[Cherry-pick]bmm_transpose to v011dev #3995
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: v0.11.0-dev
Are you sure you want to change the base?
[Cherry-pick]bmm_transpose to v011dev #3995
Conversation
Signed-off-by: hust17yixuan <303660421@qq.com>
|
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:
If CI fails, you can run linting and testing checks locally according Contributing and Testing. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request introduces a new batch_matmul_transpose custom operator for Ascend NPUs to optimize a specific batched matrix multiplication pattern. The changes span from the low-level kernel implementation and host-side tiling logic to the PyTorch bindings and its usage in the mla_v1.py attention implementation. While the overall implementation seems to correctly integrate the new operator, I've identified a critical thread-safety issue in the host-side C++ code that could lead to race conditions and incorrect results. My review includes a specific code suggestion to resolve this critical bug.
| int32_t batchIdx = opShape.m - 1; | ||
| uint32_t tilingSize = sizeof(pp_matmul::PpMatmulTilingData); | ||
| static auto global_tiling_data = at::empty( | ||
| {tilingSize * MAX_CAPTURE_NUM}, at::TensorOptions().dtype(at::kByte).device(tensor_a.options().device())); | ||
| if (batchIdx >= 0 && batchIdx < MAX_CAPTURE_NUM) { | ||
| aclrtMemcpy(global_tiling_data.data_ptr<uint8_t>() + (tilingSize * batchIdx), tilingSize, &matmulTilingData, | ||
| tilingSize, ACL_MEMCPY_HOST_TO_DEVICE); | ||
| } else { | ||
| // Handle the case where batchIdx is out of range | ||
| TORCH_CHECK(false, "batchIdx is out of range: ", batchIdx); | ||
| } | ||
| at::Tensor tiling_tensor = | ||
| at::from_blob(global_tiling_data.data_ptr<uint8_t>() + (tilingSize * batchIdx), tilingSize, at::kByte); | ||
|
|
||
| return std::make_tuple(tiling_tensor, block_dim); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The use of a static local variable global_tiling_data to cache tiling information on the device is not thread-safe. If multiple threads execute this function concurrently with the same opShape.m, they will write to the same memory location, causing a race condition. This can lead to incorrect tiling data being used by the kernel, resulting in wrong computations or crashes.
Given that PpMatmulTilingData is a small struct, it's safer and cleaner to allocate a new tensor for it on each call. This avoids the race condition and also removes the limitation of MAX_CAPTURE_NUM on the m dimension.
uint32_t tilingSize = sizeof(pp_matmul::PpMatmulTilingData);
auto tiling_tensor = at::empty({(int64_t)tilingSize}, at::TensorOptions().dtype(at::kByte).device(tensor_a.options().device()));
aclrtMemcpy(tiling_tensor.data_ptr<uint8_t>(), tilingSize, &matmulTilingData,
tilingSize, ACL_MEMCPY_HOST_TO_DEVICE);
return std::make_tuple(tiling_tensor, block_dim);
What this PR does / why we need it?
Does this PR introduce any user-facing change?
How was this patch tested?