
[Performance] Performance improvements in non-blockwise fp8 CUTLASS MoE #20762


Open · wants to merge 9 commits into base: main
35 changes: 34 additions & 1 deletion benchmarks/kernels/benchmark_grouped_gemm_cutlass.py
@@ -80,6 +80,11 @@ def bench_run(
a, score, topk, renormalize=False
)

ab_strides1 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)
ab_strides2 = torch.full((num_experts,), n, device="cuda", dtype=torch.int64)
c_strides1 = torch.full((num_experts,), 2 * n, device="cuda", dtype=torch.int64)
c_strides2 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)

def run_triton_moe(
a: torch.Tensor,
w1: torch.Tensor,
@@ -111,6 +116,10 @@ def run_cutlass_moe(
w2: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
ab_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides1: torch.Tensor,
c_strides2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
per_act_token: bool,
@@ -125,6 +134,10 @@ def run_cutlass_moe(
topk_ids,
w1_scale,
w2_scale,
ab_strides1,
ab_strides2,
c_strides1,
c_strides2,
per_act_token,
a1_scale=None,
)
@@ -136,6 +149,10 @@ def run_cutlass_from_graph(
w2_q: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
ab_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides1: torch.Tensor,
c_strides2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
):
@@ -150,6 +167,10 @@ def run_cutlass_from_graph(
topk_ids,
w1_scale,
w2_scale,
ab_strides1,
ab_strides2,
c_strides1,
c_strides2,
per_act_token,
a1_scale=None,
)
@@ -194,6 +215,10 @@ def replay_graph(graph, num_repeats):
w2_q,
w1_scale,
w2_scale,
ab_strides1,
ab_strides2,
c_strides1,
c_strides2,
topk_weights,
topk_ids,
)
@@ -231,6 +256,10 @@ def replay_graph(graph, num_repeats):
"w1_scale": w1_scale,
"w2_scale": w2_scale,
"per_act_token": per_act_token,
"ab_strides1": ab_strides1,
"ab_strides2": ab_strides2,
"c_strides1": c_strides1,
"c_strides2": c_strides2,
# cuda graph params
"cutlass_graph": cutlass_graph,
"triton_graph": triton_graph,
@@ -289,6 +318,10 @@ def replay_graph(graph, num_repeats):
w2_q,
w1_scale,
w2_scale,
ab_strides1,
ab_strides2,
c_strides1,
c_strides2,
topk_weights,
topk_ids,
per_act_token,
@@ -297,7 +330,7 @@ def replay_graph(graph, num_repeats):

results.append(
benchmark.Timer(
stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, per_act_token, num_runs)", # noqa: E501
stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, ab_strides1, ab_strides2, c_strides1, c_strides2, topk_weights, topk_ids, per_act_token, num_runs)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,
53 changes: 42 additions & 11 deletions csrc/moe/moe_permute_unpermute_op.cu
@@ -160,6 +160,30 @@ __global__ void shuffleInputRowsKernel(const T* input,
}
}

template <typename T>
__global__ void shuffleInputRowsKernelSlow(const T* input,
const int32_t* dst2src_map,
T* output, int64_t num_src_rows,
int64_t num_dst_rows,
int64_t num_cols) {
int64_t const dest_row_idx = blockIdx.x;

if (dest_row_idx < num_dst_rows) {
// Duplicate and permute rows; read the gather index only after the
// bounds check so out-of-range blocks never touch dst2src_map
int64_t const source_row_idx = dst2src_map[dest_row_idx];
auto const* source_row_ptr = input + source_row_idx * num_cols;
auto* dest_row_ptr = output + dest_row_idx * num_cols;

int64_t const start_offset = threadIdx.x;
int64_t const stride = blockDim.x;

for (int64_t elem_index = start_offset; elem_index < num_cols;
elem_index += stride) {
dest_row_ptr[elem_index] = source_row_ptr[elem_index];
}
}
}

void shuffle_rows(const torch::Tensor& input_tensor,
const torch::Tensor& dst2src_map,
torch::Tensor& output_tensor) {
@@ -173,17 +197,24 @@ void shuffle_rows(const torch::Tensor& input_tensor,
int64_t const num_src_rows = input_tensor.size(0);
int64_t const num_cols = input_tensor.size(1);

TORCH_CHECK(!(num_cols % (128 / sizeof(input_tensor.scalar_type()) / 8)),
"num_cols must be divisible by 128 / "
"sizeof(input_tensor.scalar_type()) / 8");

MOE_DISPATCH(input_tensor.scalar_type(), [&] {
shuffleInputRowsKernel<scalar_t><<<blocks, threads, 0, stream>>>(
reinterpret_cast<scalar_t*>(input_tensor.data_ptr()),
dst2src_map.data_ptr<int32_t>(),
reinterpret_cast<scalar_t*>(output_tensor.data_ptr()), num_src_rows,
num_dest_rows, num_cols);
});
if (num_cols % (128 / (input_tensor.element_size() * 8))) {
// use the slow kernel if num_cols can't be aligned to 128 bits
// (element_size() gives bytes per element; sizeof(scalar_type()) would
// measure the dtype enum instead)
MOE_DISPATCH(input_tensor.scalar_type(), [&] {
shuffleInputRowsKernelSlow<scalar_t><<<blocks, threads, 0, stream>>>(
reinterpret_cast<scalar_t*>(input_tensor.data_ptr()),
dst2src_map.data_ptr<int32_t>(),
reinterpret_cast<scalar_t*>(output_tensor.data_ptr()), num_src_rows,
num_dest_rows, num_cols);
});
} else {
MOE_DISPATCH(input_tensor.scalar_type(), [&] {
shuffleInputRowsKernel<scalar_t><<<blocks, threads, 0, stream>>>(
reinterpret_cast<scalar_t*>(input_tensor.data_ptr()),
dst2src_map.data_ptr<int32_t>(),
reinterpret_cast<scalar_t*>(output_tensor.data_ptr()), num_src_rows,
num_dest_rows, num_cols);
});
}
}

#else
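In Python terms, the dispatch above reduces to a small alignment check; a sketch for illustration, not code from this PR:

```python
def uses_fast_kernel(num_cols: int, elem_bytes: int) -> bool:
    """True if rows of num_cols elements are 128-bit aligned: fp8 (1 B)
    needs multiples of 16, fp16/bf16 (2 B) of 8, fp32 (4 B) of 4."""
    elems_per_128bit = 128 // (elem_bytes * 8)
    return num_cols % elems_per_128bit == 0
```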
14 changes: 12 additions & 2 deletions tests/kernels/moe/test_cutlass_moe.py
@@ -206,6 +206,10 @@ def run_8_bit(moe_tensors: MOETensors8Bit,
'topk_ids': topk_ids,
'w1_scale': moe_tensors.w1_scale,
'w2_scale': moe_tensors.w2_scale,
'ab_strides1': moe_tensors.ab_strides1,
'ab_strides2': moe_tensors.ab_strides2,
'c_strides1': moe_tensors.c_strides1,
'c_strides2': moe_tensors.c_strides2,
'per_act_token': per_act_token,
'a1_scale': None #moe_tensors.a_scale
}
@@ -439,6 +443,11 @@ def test_run_cutlass_moe_fp8(
expert_map[start:end] = list(range(num_local_experts))
expert_map = torch.tensor(expert_map, dtype=torch.int32, device="cuda")

ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64)
c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64)
c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
Comment on lines +446 to +449

Collaborator

Could we save on memory by using the same tensor for both ab_strides1 and c_strides2?

Contributor Author

Yes, this should be possible. I'll push an update.
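A minimal sketch of the suggested deduplication, assuming the stride tensors stay read-only (the sizes are hypothetical, not from this PR):

```python
import torch

e, n, k = 8, 1024, 512  # hypothetical expert count and GEMM dims

# ab_strides1 and c_strides2 are both filled with k, so one tensor can
# back both arguments, saving an (e,)-shaped int64 allocation per layer.
ab_strides1 = c_strides2 = torch.full((e, ), k, device="cuda",
                                      dtype=torch.int64)
ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64)
c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64)
```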


activation = lambda o, i: torch.ops._C.silu_and_mul(o, i)
a1q, a1q_scale = moe_kernel_quantize_input(mt.a, mt.a_scale,
torch.float8_e4m3fn,
@@ -447,8 +456,9 @@
func = lambda output: run_cutlass_moe_fp8(
output, a1q, mt.w1_q, mt.w2_q, topk_ids, activation,
global_num_experts, expert_map, mt.w1_scale, mt.w2_scale,
a1q_scale, None, workspace13, workspace2, None, mt.a.dtype,
per_act_token, per_out_channel, False)
a1q_scale, None, ab_strides1, ab_strides2, c_strides1, c_strides2,
workspace13, workspace2, None, mt.a.dtype, per_act_token,
per_out_channel, False)

workspace13.random_()
output_random_workspace = torch.empty(output_shape,
22 changes: 22 additions & 0 deletions tests/kernels/moe/test_pplx_cutlass_moe.py
@@ -75,6 +75,7 @@ def pplx_cutlass_moe(
assert torch.cuda.current_device() == pgi.local_rank

num_tokens, hidden_dim = a.shape
intermediate_dim = w2.shape[2]
num_experts = w1.shape[0]
block_size = hidden_dim # TODO support more cases
device = pgi.device
@@ -123,10 +124,31 @@ def pplx_cutlass_moe(
num_local_experts=num_local_experts,
num_dispatchers=num_dispatchers)

ab_strides1 = torch.full((num_local_experts, ),
hidden_dim,
device="cuda",
dtype=torch.int64)
ab_strides2 = torch.full((num_local_experts, ),
intermediate_dim,
device="cuda",
dtype=torch.int64)
c_strides1 = torch.full((num_local_experts, ),
2 * intermediate_dim,
device="cuda",
dtype=torch.int64)
c_strides2 = torch.full((num_local_experts, ),
hidden_dim,
device="cuda",
dtype=torch.int64)

experts = CutlassExpertsFp8(num_local_experts,
out_dtype,
per_act_token,
per_out_ch,
ab_strides1,
ab_strides2,
c_strides1,
c_strides2,
num_dispatchers=num_dispatchers,
use_batched_format=True)

62 changes: 39 additions & 23 deletions vllm/model_executor/layers/fused_moe/cutlass_moe.py
@@ -11,8 +11,7 @@
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP)
from vllm.model_executor.layers.fused_moe.utils import (_fp8_perm,
_fp8_quantize,
from vllm.model_executor.layers.fused_moe.utils import (_fp8_quantize,
_resize_cache)
from vllm.scalar_type import scalar_types

@@ -32,6 +31,10 @@ def run_cutlass_moe_fp8(
w2_scale: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
ab_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides1: torch.Tensor,
c_strides2: torch.Tensor,
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_num_tokens: Optional[torch.Tensor],
@@ -150,27 +153,11 @@
problem_sizes1, problem_sizes2, a_map,
c_map, global_num_experts, N, K)

a1q = _fp8_perm(a1q, a_map)
a1q_scale = a1q_scale[a_map] if per_act_token else a1q_scale
a1q = ops.shuffle_rows(a1q, a_map)
a1q_scale = (ops.shuffle_rows(a1q_scale, a_map)
if per_act_token else a1q_scale)
Comment on lines -155 to +160

Collaborator

Does the speedup mentioned in this statement

> Faster kernels for shuffling hidden_states, input scales and outputs of the function that executes CUTLASS MoE layer.

come from being able to use shuffle_rows when the num_cols % 128 == 0 condition is not guaranteed?

Contributor Author

It comes from using these custom kernels instead of the PyTorch index-based function. The "slow" kernels are expected to only be run to shuffle scales.
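For context, a minimal sketch of the two approaches being compared; the vllm._custom_ops import path is an assumption, while the two-argument shuffle_rows call mirrors the diff above:

```python
import torch

from vllm import _custom_ops as ops  # assumed import path for shuffle_rows

num_rows, num_cols = 256, 7168
x = torch.randn(num_rows, num_cols, device="cuda", dtype=torch.half)
dst2src_map = torch.randint(0, num_rows, (num_rows, ),
                            device="cuda", dtype=torch.int32)

# Old path: PyTorch advanced indexing, i.e. a generic gather kernel.
out_ref = x[dst2src_map.long()]

# New path: the custom row-shuffle kernel. Rows that satisfy the 128-bit
# alignment condition use the vectorized kernel; the elementwise "slow"
# kernel mostly handles narrow tensors such as per-token scales.
out = ops.shuffle_rows(x, dst2src_map)
torch.testing.assert_close(out, out_ref)
```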

expert_offsets = expert_offsets[:-1]

ab_strides1 = torch.full((w1.size(0), ),
K,
device=device,
dtype=torch.int64)
c_strides1 = torch.full((w1.size(0), ),
2 * N,
device=device,
dtype=torch.int64)
ab_strides2 = torch.full((w1.size(0), ),
N,
device=device,
dtype=torch.int64)
c_strides2 = torch.full((w1.size(0), ),
K,
device=device,
dtype=torch.int64)

if use_batched_format:
c1 = _resize_cache(workspace13, (local_E * padded_M, N * 2))
c2 = _resize_cache(workspace2, (local_E * padded_M, N))
@@ -207,7 +194,8 @@
else:
# We can't do this inplace because output may point to the same tensor
# as c3.
output.copy_(c3[c_map].view(M * topk, K), non_blocking=True)
output.copy_(ops.shuffle_rows(c3, c_map).view(M * topk, K),
non_blocking=True)


# TODO (bnell): split class batched vs. non-batched?
@@ -220,6 +208,10 @@ def __init__(
out_dtype: Optional[torch.dtype],
per_act_token_quant: bool,
per_out_ch_quant: bool,
ab_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides1: torch.Tensor,
c_strides2: torch.Tensor,
block_shape: Optional[list[int]] = None,
num_dispatchers: Optional[int] = None,
use_batched_format: bool = False,
@@ -236,6 +228,10 @@ def __init__(
self.max_experts_per_worker = max_experts_per_worker
self.num_dispatchers = num_dispatchers
self.out_dtype = out_dtype
self.ab_strides1 = ab_strides1
self.ab_strides2 = ab_strides2
self.c_strides1 = c_strides1
self.c_strides2 = c_strides2
self.use_batched_format = use_batched_format

@property
@@ -312,7 +308,8 @@ def apply(
run_cutlass_moe_fp8(
output, hidden_states, w1, w2, topk_ids, activation_callable,
global_num_experts, expert_map, w1_scale, w2_scale, a1q_scale,
a2_scale, workspace13, workspace2, expert_num_tokens,
a2_scale, self.ab_strides1, self.ab_strides2, self.c_strides1,
self.c_strides2, workspace13, workspace2, expert_num_tokens,
self.out_dtype if self.out_dtype is not None else in_dtype,
self.per_act_token_quant, self.per_out_ch_quant,
self.use_batched_format)
@@ -326,6 +323,10 @@ def cutlass_moe_fp8(
topk_ids: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
ab_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides1: torch.Tensor,
c_strides2: torch.Tensor,
per_act_token: Optional[bool] = None,
activation: str = "silu",
a1_scale: Optional[torch.Tensor] = None,
@@ -353,6 +354,17 @@
Shape: [num_experts] or [num_experts, 2N]
- w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
Shape: [num_experts] or [num_experts, K]
- ab_strides1 (torch.Tensor): The input/weight strides for the first gemm.
Shape: [num_experts]
- ab_strides2 (torch.Tensor): The input/weight strides for the second gemm.
Shape: [num_experts]
- c_strides1 (torch.Tensor): The output strides for the first gemm.
Shape: [num_experts]
- c_strides2 (torch.Tensor): The output strides for the second gemm.
Shape: [num_experts]
- per_act_token (Optional[bool]): Whether the scale is per-token or
per-tensor.
- activation (str): The activation function to use.
- a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.
Shape: scalar or [M]
- a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
@@ -385,6 +397,10 @@
out_dtype=a.dtype,
per_act_token_quant=per_act_token,
per_out_ch_quant=per_out_ch,
ab_strides1=ab_strides1,
ab_strides2=ab_strides2,
c_strides1=c_strides1,
c_strides2=c_strides2,
use_batched_format=False,
),
)
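For callers migrating to the new signature, a small helper along these lines (hypothetical, not part of the PR) reproduces the stride shapes used in the tests and benchmark above:

```python
import torch


def make_moe_strides(w1: torch.Tensor, w2: torch.Tensor):
    """Build the per-expert stride tensors cutlass_moe_fp8 now expects,
    for w1 of shape [E, 2N, K] and w2 of shape [E, K, N]."""
    e, two_n, k = w1.shape
    n = w2.shape[2]
    assert two_n == 2 * n, "w1 must pack gate and up projections (2N rows)"
    device = w1.device
    ab_strides1 = torch.full((e, ), k, device=device, dtype=torch.int64)
    ab_strides2 = torch.full((e, ), n, device=device, dtype=torch.int64)
    c_strides1 = torch.full((e, ), 2 * n, device=device, dtype=torch.int64)
    c_strides2 = torch.full((e, ), k, device=device, dtype=torch.int64)
    return ab_strides1, ab_strides2, c_strides1, c_strides2
```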