Revert "[Performance] Performance improvements in non-blockwise fp8 CUTLASS MoE (#20762) #21334
Conversation
…UTLASS MoE (vllm-project#20762)" This reverts commit 9fb2d22. Signed-off-by: Ming Yang <minos.future@gmail.com>
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default; only a limited subset of checks runs automatically. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.
Code Review
This pull request reverts a previous performance improvement in order to fix a correctness issue, so the changes mostly remove code related to that optimization. However, I've identified a critical issue: the revert breaks CUDA graph compatibility by creating new tensors inside a function that can be captured by a CUDA graph. This will cause benchmarks, and potentially other features that rely on CUDA graphs, to fail. I've provided detailed comments and code suggestions across multiple files to address this by re-introducing the practice of passing stride tensors as arguments, which was the behavior before the original performance-enhancing change.
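To make the concern concrete, here is a minimal sketch (plain PyTorch, not vLLM code; the function name, shapes, and stride tensor are invented for illustration) of the pattern being asked for: everything a CUDA-graph-captured region touches is allocated once up front and reused on replay, instead of being created inside the captured function.

import torch

def moe_like_kernel(x: torch.Tensor, strides: torch.Tensor) -> torch.Tensor:
    # Uses only tensors that already exist; nothing new is allocated on the
    # Python side per call, so the op sequence can be captured and replayed.
    return x * strides.to(x.dtype).unsqueeze(1)

if torch.cuda.is_available():
    device = torch.device("cuda")
    x = torch.randn(8, 16, device=device)
    out = torch.empty_like(x)
    # Pre-compute the stride-like tensor once, outside the captured region.
    strides = torch.full((8,), 16, device=device, dtype=torch.int64)

    # Warm up on a side stream before capture, as PyTorch recommends.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        out.copy_(moe_like_kernel(x, strides))
    torch.cuda.current_stream().wait_stream(s)

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        out.copy_(moe_like_kernel(x, strides))

    # Replays reuse the captured buffers; refresh inputs in place rather
    # than allocating new tensors.
    x.copy_(torch.randn(8, 16, device=device))
    g.replay()

This mirrors the review's suggestions below: the stride tensors become long-lived inputs (created in the tests or in process_weights_after_loading) rather than temporaries built inside run_cutlass_moe_fp8.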
@@ -207,10 +207,6 @@ def run_8_bit(moe_tensors: MOETensors8Bit,
'topk_ids': topk_ids,
'w1_scale': moe_tensors.w1_scale,
'w2_scale': moe_tensors.w2_scale,
To align with the proposed fix for CUDA graph compatibility, the stride tensors need to be passed to cutlass_moe_fp8 for testing.
'w2_scale': moe_tensors.w2_scale,
'ab_strides1': moe_tensors.ab_strides1,
'ab_strides2': moe_tensors.ab_strides2,
'c_strides1': moe_tensors.c_strides1,
'c_strides2': moe_tensors.c_strides2,
@@ -444,11 +440,6 @@ def test_run_cutlass_moe_fp8(
expert_map[start:end] = list(range(num_local_experts))
expert_map = torch.tensor(expert_map, dtype=torch.int32, device="cuda")
The stride tensors need to be created for the test to be consistent with the proposed fix for CUDA graph compatibility.
expert_map = torch.tensor(expert_map, dtype=torch.int32, device="cuda")
ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64)
c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64)
c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
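For context (this is my reading of the surrounding test code rather than something stated in the PR): k is the hidden size and n the intermediate size, so ab_strides1 and c_strides2 use k, ab_strides2 uses n, and c_strides1 uses 2 * n because the first grouped GEMM writes the fused gate/up projection output.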
a1q_scale, None, workspace13, workspace2, None, mt.a.dtype,
per_act_token, per_out_channel, False)
experts = CutlassExpertsFp8(num_local_experts,
out_dtype,
per_act_token,
per_out_ch,
ab_strides1,
ab_strides2,
c_strides1,
c_strides2,
num_dispatchers=num_dispatchers,
use_batched_format=True)
The stride tensors need to be created and passed to CutlassExpertsFp8 for the test to be consistent with the proposed fix for CUDA graph compatibility. You'll also need to re-introduce intermediate_dim, which was removed in this PR.
intermediate_dim = w2.shape[2]
ab_strides1 = torch.full((num_local_experts, ),
hidden_dim,
device="cuda",
dtype=torch.int64)
ab_strides2 = torch.full((num_local_experts, ),
intermediate_dim,
device="cuda",
dtype=torch.int64)
c_strides1 = torch.full((num_local_experts, ),
2 * intermediate_dim,
device="cuda",
dtype=torch.int64)
c_strides2 = torch.full((num_local_experts, ),
hidden_dim,
device="cuda",
dtype=torch.int64)
experts = CutlassExpertsFp8(num_local_experts,
out_dtype,
per_act_token,
per_out_ch,
ab_strides1,
ab_strides2,
c_strides1,
c_strides2,
num_dispatchers=num_dispatchers,
use_batched_format=True)
@@ -34,10 +35,6 @@ def run_cutlass_moe_fp8(
w2_scale: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
To maintain CUDA graph compatibility, the stride tensors should be passed as arguments instead of being created inside this function. Please add them back to the function signature.
a2_scale: Optional[torch.Tensor],
ab_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides1: torch.Tensor,
c_strides2: torch.Tensor,
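With these parameters restored on run_cutlass_moe_fp8, callers have to supply the stride tensors themselves; the other suggestions in this review (the test fixtures and the process_weights_after_loading hook below) show where to create them once so nothing is allocated inside the graph-captured path.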
@@ -329,10 +332,6 @@ def cutlass_moe_fp8(
topk_ids: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
@@ -403,10 +391,6 @@ def cutlass_moe_fp8(
out_dtype=a.dtype,
per_act_token_quant=per_act_token,
per_out_ch_quant=per_out_ch,
@@ -859,21 +859,6 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
requires_grad=False)
The stride tensors should be pre-computed and stored here to be passed to the MoE kernel. This is necessary for CUDA graph compatibility.
device = layer.w13_weight.device
# ab_strides1 and c_strides2 are the same
self.ab_strides1_c_strides2 = torch.full((layer.local_num_experts, ),
layer.hidden_size,
device=device,
dtype=torch.int64)
self.ab_strides2 = torch.full((layer.local_num_experts, ),
layer.intermediate_size_per_partition,
device=device,
dtype=torch.int64)
self.c_strides1 = torch.full((layer.local_num_experts, ),
2 * layer.intermediate_size_per_partition,
device=device,
dtype=torch.int64)
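Note that this suggestion deliberately uses one tensor for both ab_strides1 and c_strides2, since both strides equal hidden_size for every local expert, which saves an allocation per layer.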
@@ -896,10 +881,6 @@ def select_gemm_impl(
moe.in_dtype,
self.input_quant.strategy == QuantizationStrategy.TOKEN,
self.weight_quant.strategy == QuantizationStrategy.CHANNEL,
@@ -968,10 +948,6 @@ def apply(
expert_map=None if self.disable_expert_map else expert_map,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
Thanks for reverting the original PR to help recover trunk health. This will unblock our code sync as well.
cc @ElizaWszola, @tlrmchlsmth, @mgoin, @robertgshaw2-redhat. This is blocking our internal work, so we need to revert for now to unblock. Sorry about the inconvenience, and happy to help with landing the fixed version. Also, if a forward fix is easy to land, we're happy to switch to that as well. :-)
Okay, let's revert for now. Thanks for identifying this.
…CUTLASS MoE (vllm-project#20762) (vllm-project#21334) This reverts commit e7b2042.
…CUTLASS MoE (vllm-project#20762) (vllm-project#21334) This reverts commit e7b2042. The original PR vllm-project#20762 is: Authored-by: ElizaWszola <ewszola@redhat.com> Signed-off-by: Ming Yang <minos.future@gmail.com>
…UTLASS MoE (vllm-project#20762) (vllm-project#21334) Signed-off-by: Ming Yang <minos.future@gmail.com> Signed-off-by: qizixi <qizixi@meta.com>
…UTLASS MoE (vllm-project#20762) (vllm-project#21334) Signed-off-by: Ming Yang <minos.future@gmail.com>
Purpose
This reverts commit 9fb2d22 to fix #21322.
Test Plan
Test Result
local-chat-completions (model=meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8,base_url=http://127.0.0.1:8000/v1/chat/completions,num_concurrent=32), gen_kwargs: (None), limit: 200.0, num_fewshot: 5, batch_size: 1
(Optional) Documentation Update