Skip to content

Revert "[Performance] Performance improvements in non-blockwise fp8 CUTLASS MoE (#20762) #21334

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 1 addition & 34 deletions benchmarks/kernels/benchmark_grouped_gemm_cutlass.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,6 @@ def bench_run(
a, score, topk, renormalize=False
)

ab_strides1 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)
ab_strides2 = torch.full((num_experts,), n, device="cuda", dtype=torch.int64)
c_strides1 = torch.full((num_experts,), 2 * n, device="cuda", dtype=torch.int64)
c_strides2 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)

def run_triton_moe(
a: torch.Tensor,
w1: torch.Tensor,
Expand Down Expand Up @@ -116,10 +111,6 @@ def run_cutlass_moe(
w2: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
ab_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides1: torch.Tensor,
c_strides2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
per_act_token: bool,
Expand All @@ -134,10 +125,6 @@ def run_cutlass_moe(
topk_ids,
w1_scale,
w2_scale,
ab_strides1,
ab_strides2,
c_strides1,
c_strides2,
per_act_token,
a1_scale=None,
)
Expand All @@ -149,10 +136,6 @@ def run_cutlass_from_graph(
w2_q: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
ab_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides1: torch.Tensor,
c_strides2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
):
Expand All @@ -167,10 +150,6 @@ def run_cutlass_from_graph(
topk_ids,
w1_scale,
w2_scale,
ab_strides1,
ab_strides2,
c_strides1,
c_strides2,
per_act_token,
a1_scale=None,
)
Expand Down Expand Up @@ -215,10 +194,6 @@ def replay_graph(graph, num_repeats):
w2_q,
w1_scale,
w2_scale,
ab_strides1,
ab_strides2,
c_strides1,
c_strides2,
topk_weights,
topk_ids,
)
Expand Down Expand Up @@ -256,10 +231,6 @@ def replay_graph(graph, num_repeats):
"w1_scale": w1_scale,
"w2_scale": w2_scale,
"per_act_token": per_act_token,
"ab_strides1": ab_strides1,
"ab_strides2": ab_strides2,
"c_strides1": c_strides1,
"c_strides2": c_strides2,
# cuda graph params
"cutlass_graph": cutlass_graph,
"triton_graph": triton_graph,
Expand Down Expand Up @@ -318,10 +289,6 @@ def replay_graph(graph, num_repeats):
w2_q,
w1_scale,
w2_scale,
ab_strides1,
ab_strides2,
c_strides1,
c_strides2,
topk_weights,
topk_ids,
per_act_token,
Expand All @@ -330,7 +297,7 @@ def replay_graph(graph, num_repeats):

results.append(
benchmark.Timer(
stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, ab_strides1, ab_strides2, c_strides1, c_strides2, topk_weights, topk_ids, per_act_token, num_runs)", # noqa: E501
stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, per_act_token, num_runs)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,
Expand Down
53 changes: 11 additions & 42 deletions csrc/moe/moe_permute_unpermute_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -160,30 +160,6 @@ __global__ void shuffleInputRowsKernel(const T* input,
}
}

template <typename T>
__global__ void shuffleInputRowsKernelSlow(const T* input,
const int32_t* dst2src_map,
T* output, int64_t num_src_rows,
int64_t num_dst_rows,
int64_t num_cols) {
int64_t dest_row_idx = blockIdx.x;
int64_t const source_row_idx = dst2src_map[dest_row_idx];

if (blockIdx.x < num_dst_rows) {
// Duplicate and permute rows
auto const* source_row_ptr = input + source_row_idx * num_cols;
auto* dest_row_ptr = output + dest_row_idx * num_cols;

int64_t const start_offset = threadIdx.x;
int64_t const stride = blockDim.x;

for (int elem_index = start_offset; elem_index < num_cols;
elem_index += stride) {
dest_row_ptr[elem_index] = source_row_ptr[elem_index];
}
}
}

void shuffle_rows(const torch::Tensor& input_tensor,
const torch::Tensor& dst2src_map,
torch::Tensor& output_tensor) {
Expand All @@ -197,24 +173,17 @@ void shuffle_rows(const torch::Tensor& input_tensor,
int64_t const num_src_rows = input_tensor.size(0);
int64_t const num_cols = input_tensor.size(1);

if (num_cols % (128 / sizeof(input_tensor.scalar_type()) / 8)) {
// use slow kernel if num_cols can't be aligned to 128 bits
MOE_DISPATCH(input_tensor.scalar_type(), [&] {
shuffleInputRowsKernelSlow<scalar_t><<<blocks, threads, 0, stream>>>(
reinterpret_cast<scalar_t*>(input_tensor.data_ptr()),
dst2src_map.data_ptr<int32_t>(),
reinterpret_cast<scalar_t*>(output_tensor.data_ptr()), num_src_rows,
num_dest_rows, num_cols);
});
} else {
MOE_DISPATCH(input_tensor.scalar_type(), [&] {
shuffleInputRowsKernel<scalar_t><<<blocks, threads, 0, stream>>>(
reinterpret_cast<scalar_t*>(input_tensor.data_ptr()),
dst2src_map.data_ptr<int32_t>(),
reinterpret_cast<scalar_t*>(output_tensor.data_ptr()), num_src_rows,
num_dest_rows, num_cols);
});
}
TORCH_CHECK(!(num_cols % (128 / sizeof(input_tensor.scalar_type()) / 8)),
"num_cols must be divisible by 128 / "
"sizeof(input_tensor.scalar_type()) / 8");

MOE_DISPATCH(input_tensor.scalar_type(), [&] {
shuffleInputRowsKernel<scalar_t><<<blocks, threads, 0, stream>>>(
reinterpret_cast<scalar_t*>(input_tensor.data_ptr()),
dst2src_map.data_ptr<int32_t>(),
reinterpret_cast<scalar_t*>(output_tensor.data_ptr()), num_src_rows,
num_dest_rows, num_cols);
});
}

#else
Expand Down
14 changes: 2 additions & 12 deletions tests/kernels/moe/test_cutlass_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,10 +207,6 @@ def run_8_bit(moe_tensors: MOETensors8Bit,
'topk_ids': topk_ids,
'w1_scale': moe_tensors.w1_scale,
'w2_scale': moe_tensors.w2_scale,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

To align with the proposed fix for CUDA graph compatibility, the stride tensors need to be passed to cutlass_moe_fp8 for testing.

        'w2_scale': moe_tensors.w2_scale,
        'ab_strides1': moe_tensors.ab_strides1,
        'ab_strides2': moe_tensors.ab_strides2,
        'c_strides1': moe_tensors.c_strides1,
        'c_strides2': moe_tensors.c_strides2,

'ab_strides1': moe_tensors.ab_strides1,
'ab_strides2': moe_tensors.ab_strides2,
'c_strides1': moe_tensors.c_strides1,
'c_strides2': moe_tensors.c_strides2,
'per_act_token': per_act_token,
'a1_scale': None #moe_tensors.a_scale
}
Expand Down Expand Up @@ -444,11 +440,6 @@ def test_run_cutlass_moe_fp8(
expert_map[start:end] = list(range(num_local_experts))
expert_map = torch.tensor(expert_map, dtype=torch.int32, device="cuda")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The stride tensors need to be created for the test to be consistent with the proposed fix for CUDA graph compatibility.

        expert_map = torch.tensor(expert_map, dtype=torch.int32, device="cuda")

        ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
        ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64)
        c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64)
        c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64)


ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64)
c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64)
c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64)

activation = lambda o, i: torch.ops._C.silu_and_mul(o, i)
a1q, a1q_scale = moe_kernel_quantize_input(mt.a, mt.a_scale,
torch.float8_e4m3fn,
Expand All @@ -457,9 +448,8 @@ def test_run_cutlass_moe_fp8(
func = lambda output: run_cutlass_moe_fp8(
output, a1q, mt.w1_q, mt.w2_q, topk_ids, activation,
global_num_experts, expert_map, mt.w1_scale, mt.w2_scale,
a1q_scale, None, ab_strides1, ab_strides2, c_strides1, c_strides2,
workspace13, workspace2, None, mt.a.dtype, per_act_token,
per_out_channel, False)
a1q_scale, None, workspace13, workspace2, None, mt.a.dtype,
per_act_token, per_out_channel, False)
Comment on lines +451 to +452
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The stride tensors should be passed to run_cutlass_moe_fp8 to align with the proposed fix for CUDA graph compatibility.

            a1q_scale, None, ab_strides1, ab_strides2, c_strides1, c_strides2,
            workspace13, workspace2, None, mt.a.dtype, per_act_token,
            per_out_channel, False)


workspace13.random_()
output_random_workspace = torch.empty(output_shape,
Expand Down
22 changes: 0 additions & 22 deletions tests/kernels/moe/test_pplx_cutlass_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ def pplx_cutlass_moe(
assert torch.cuda.current_device() == pgi.local_rank

num_tokens, hidden_dim = a.shape
intermediate_dim = w2.shape[2]
num_experts = w1.shape[0]
block_size = hidden_dim # TODO support more cases
device = pgi.device
Expand Down Expand Up @@ -124,31 +123,10 @@ def pplx_cutlass_moe(
num_local_experts=num_local_experts,
num_dispatchers=num_dispatchers)

ab_strides1 = torch.full((num_local_experts, ),
hidden_dim,
device="cuda",
dtype=torch.int64)
ab_strides2 = torch.full((num_local_experts, ),
intermediate_dim,
device="cuda",
dtype=torch.int64)
c_strides1 = torch.full((num_local_experts, ),
2 * intermediate_dim,
device="cuda",
dtype=torch.int64)
c_strides2 = torch.full((num_local_experts, ),
hidden_dim,
device="cuda",
dtype=torch.int64)

experts = CutlassExpertsFp8(num_local_experts,
out_dtype,
per_act_token,
per_out_ch,
ab_strides1,
ab_strides2,
c_strides1,
c_strides2,
num_dispatchers=num_dispatchers,
use_batched_format=True)
Comment on lines 126 to 131
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The stride tensors need to be created and passed to CutlassExpertsFp8 for the test to be consistent with the proposed fix for CUDA graph compatibility. You'll also need to re-introduce intermediate_dim which was removed in this PR.

    intermediate_dim = w2.shape[2]
    ab_strides1 = torch.full((num_local_experts, ),
                             hidden_dim,
                             device="cuda",
                             dtype=torch.int64)
    ab_strides2 = torch.full((num_local_experts, ),
                             intermediate_dim,
                             device="cuda",
                             dtype=torch.int64)
    c_strides1 = torch.full((num_local_experts, ),
                            2 * intermediate_dim,
                            device="cuda",
                            dtype=torch.int64)
    c_strides2 = torch.full((num_local_experts, ),
                            hidden_dim,
                            device="cuda",
                            dtype=torch.int64)

    experts = CutlassExpertsFp8(num_local_experts,
                                out_dtype,
                                per_act_token,
                                per_out_ch,
                                ab_strides1,
                                ab_strides2,
                                c_strides1,
                                c_strides2,
                                num_dispatchers=num_dispatchers,
                                use_batched_format=True)


Expand Down
62 changes: 23 additions & 39 deletions vllm/model_executor/layers/fused_moe/cutlass_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
MoEPrepareAndFinalizeNoEP)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate)
from vllm.model_executor.layers.fused_moe.utils import (_fp8_quantize,
from vllm.model_executor.layers.fused_moe.utils import (_fp8_perm,
_fp8_quantize,
_resize_cache,
extract_required_args)
from vllm.scalar_type import scalar_types
Expand All @@ -34,10 +35,6 @@ def run_cutlass_moe_fp8(
w2_scale: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

To maintain CUDA graph compatibility, the stride tensors should be passed as arguments instead of being created inside this function. Please add them back to the function signature.

Suggested change
a2_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
ab_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides1: torch.Tensor,
c_strides2: torch.Tensor,

ab_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides1: torch.Tensor,
c_strides2: torch.Tensor,
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_num_tokens: Optional[torch.Tensor],
Expand Down Expand Up @@ -156,11 +153,27 @@ def run_cutlass_moe_fp8(
problem_sizes1, problem_sizes2, a_map,
c_map, global_num_experts, N, K)

a1q = ops.shuffle_rows(a1q, a_map)
a1q_scale = (ops.shuffle_rows(a1q_scale, a_map)
if per_act_token else a1q_scale)
a1q = _fp8_perm(a1q, a_map)
a1q_scale = a1q_scale[a_map] if per_act_token else a1q_scale
expert_offsets = expert_offsets[:-1]

ab_strides1 = torch.full((w1.size(0), ),
K,
device=device,
dtype=torch.int64)
c_strides1 = torch.full((w1.size(0), ),
2 * N,
device=device,
dtype=torch.int64)
ab_strides2 = torch.full((w1.size(0), ),
N,
device=device,
dtype=torch.int64)
c_strides2 = torch.full((w1.size(0), ),
K,
device=device,
dtype=torch.int64)
Comment on lines +160 to +175
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

This block creates new tensors inside the function, which breaks CUDA graph compatibility as this function can be captured in a CUDA graph (e.g., in benchmarks). This should be removed, and the stride tensors should be passed as arguments to run_cutlass_moe_fp8.


if use_batched_format:
c1 = _resize_cache(workspace13, (local_E * padded_M, N * 2))
c2 = _resize_cache(workspace2, (local_E * padded_M, N))
Expand Down Expand Up @@ -197,8 +210,7 @@ def run_cutlass_moe_fp8(
else:
# We can't do this inplace because output may point to the same tensor
# as c3.
output.copy_(ops.shuffle_rows(c3, c_map).view(M * topk, K),
non_blocking=True)
output.copy_(c3[c_map].view(M * topk, K), non_blocking=True)


# TODO (bnell): split class batched vs. non-batched?
Expand All @@ -211,10 +223,6 @@ def __init__(
out_dtype: Optional[torch.dtype],
per_act_token_quant: bool,
per_out_ch_quant: bool,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The stride tensors need to be passed into the CutlassExpertsFp8 constructor to be available for the apply method. Please add them back to the __init__ signature.

        per_out_ch_quant: bool,
        ab_strides1: torch.Tensor,
        ab_strides2: torch.Tensor,
        c_strides1: torch.Tensor,
        c_strides2: torch.Tensor,

ab_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides1: torch.Tensor,
c_strides2: torch.Tensor,
block_shape: Optional[list[int]] = None,
num_dispatchers: Optional[int] = None,
use_batched_format: bool = False,
Expand All @@ -231,10 +239,6 @@ def __init__(
self.max_experts_per_worker = max_experts_per_worker
self.num_dispatchers = num_dispatchers
self.out_dtype = out_dtype
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The stride tensors should be stored as instance attributes to be used in the apply method.

        self.out_dtype = out_dtype
        self.ab_strides1 = ab_strides1
        self.ab_strides2 = ab_strides2
        self.c_strides1 = c_strides1
        self.c_strides2 = c_strides2

self.ab_strides1 = ab_strides1
self.ab_strides2 = ab_strides2
self.c_strides1 = c_strides1
self.c_strides2 = c_strides2
self.use_batched_format = use_batched_format

@property
Expand Down Expand Up @@ -314,8 +318,7 @@ def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
run_cutlass_moe_fp8(
output, hidden_states, w1, w2, topk_ids, activation_callable,
global_num_experts, expert_map, w1_scale, w2_scale, a1q_scale,
a2_scale, self.ab_strides1, self.ab_strides2, self.c_strides1,
self.c_strides2, workspace13, workspace2, expert_num_tokens,
a2_scale, workspace13, workspace2, expert_num_tokens,
self.out_dtype if self.out_dtype is not None else in_dtype,
Comment on lines +321 to 322
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The stride tensors should be passed to run_cutlass_moe_fp8 from the instance attributes.

            a2_scale, self.ab_strides1, self.ab_strides2, self.c_strides1,
            self.c_strides2, workspace13, workspace2, expert_num_tokens,

self.per_act_token_quant, self.per_out_ch_quant,
self.use_batched_format)
Expand All @@ -329,10 +332,6 @@ def cutlass_moe_fp8(
topk_ids: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The stride tensors should be passed as arguments here as well to maintain CUDA graph compatibility.

    w2_scale: torch.Tensor,
    ab_strides1: torch.Tensor,
    ab_strides2: torch.Tensor,
    c_strides1: torch.Tensor,
    c_strides2: torch.Tensor,

ab_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides1: torch.Tensor,
c_strides2: torch.Tensor,
per_act_token: Optional[bool] = None,
activation: str = "silu",
a1_scale: Optional[torch.Tensor] = None,
Expand Down Expand Up @@ -360,17 +359,6 @@ def cutlass_moe_fp8(
Shape: [num_experts] or [num_experts, 2N]
- w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
Shape: [num_experts] or [num_experts, K]
- ab_strides1 (torch.Tensor): The input/weight strides for the first gemm.
Shape: [num_experts]
- ab_strides2 (torch.Tensor): The input/weight strides for the second gemm.
Shape: [num_experts]
- c_strides1 (torch.Tensor): The output strides for the first gemm.
Shape: [num_experts]
- c_strides2 (torch.Tensor): The output strides for the second gemm.
Shape: [num_experts]
- per_act_token (Optional[bool]): Whether the scale is per-token or
per-tensor.
- activation (str): The activation function to use.
- a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.
Shape: scalar or [M]
- a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
Expand Down Expand Up @@ -403,10 +391,6 @@ def cutlass_moe_fp8(
out_dtype=a.dtype,
per_act_token_quant=per_act_token,
per_out_ch_quant=per_out_ch,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The stride tensors should be passed to the CutlassExpertsFp8 constructor.

            per_out_ch_quant=per_out_ch,
            ab_strides1=ab_strides1,
            ab_strides2=ab_strides2,
            c_strides1=c_strides1,
            c_strides2=c_strides2,

ab_strides1=ab_strides1,
ab_strides2=ab_strides2,
c_strides1=c_strides1,
c_strides2=c_strides2,
use_batched_format=False,
),
)
Expand Down
Loading