
Commit 9fb2d22

[Performance] Performance improvements in non-blockwise fp8 CUTLASS MoE (#20762)

Signed-off-by: ElizaWszola <ewszola@redhat.com>

1 parent 2d6a382

File tree

6 files changed: +174 −38 lines changed
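In short, this commit does two things: (1) the per-expert stride tensors for the two grouped GEMMs (ab_strides1/ab_strides2 and c_strides1/c_strides2) are hoisted out of run_cutlass_moe_fp8 and into its callers, so they are allocated once at setup time instead of on every forward pass, and (2) the Python-side _fp8_perm gather and c3[c_map] indexing are replaced by the ops.shuffle_rows CUDA kernel, which gains a non-vectorized fallback for row widths that cannot be aligned to 128 bits.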

benchmarks/kernels/benchmark_grouped_gemm_cutlass.py

Lines changed: 34 additions & 1 deletion

@@ -80,6 +80,11 @@ def bench_run(
         a, score, topk, renormalize=False
     )
 
+    ab_strides1 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)
+    ab_strides2 = torch.full((num_experts,), n, device="cuda", dtype=torch.int64)
+    c_strides1 = torch.full((num_experts,), 2 * n, device="cuda", dtype=torch.int64)
+    c_strides2 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)
+
     def run_triton_moe(
         a: torch.Tensor,
         w1: torch.Tensor,
@@ -111,6 +116,10 @@ def run_cutlass_moe(
         w2: torch.Tensor,
         w1_scale: torch.Tensor,
         w2_scale: torch.Tensor,
+        ab_strides1: torch.Tensor,
+        ab_strides2: torch.Tensor,
+        c_strides1: torch.Tensor,
+        c_strides2: torch.Tensor,
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
         per_act_token: bool,
@@ -125,6 +134,10 @@ def run_cutlass_moe(
             topk_ids,
             w1_scale,
             w2_scale,
+            ab_strides1,
+            ab_strides2,
+            c_strides1,
+            c_strides2,
             per_act_token,
             a1_scale=None,
         )
@@ -136,6 +149,10 @@ def run_cutlass_from_graph(
         w2_q: torch.Tensor,
         w1_scale: torch.Tensor,
         w2_scale: torch.Tensor,
+        ab_strides1: torch.Tensor,
+        ab_strides2: torch.Tensor,
+        c_strides1: torch.Tensor,
+        c_strides2: torch.Tensor,
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
     ):
@@ -150,6 +167,10 @@ def run_cutlass_from_graph(
             topk_ids,
             w1_scale,
             w2_scale,
+            ab_strides1,
+            ab_strides2,
+            c_strides1,
+            c_strides2,
             per_act_token,
             a1_scale=None,
         )
@@ -194,6 +215,10 @@ def replay_graph(graph, num_repeats):
             w2_q,
             w1_scale,
             w2_scale,
+            ab_strides1,
+            ab_strides2,
+            c_strides1,
+            c_strides2,
             topk_weights,
             topk_ids,
         )
@@ -231,6 +256,10 @@ def replay_graph(graph, num_repeats):
         "w1_scale": w1_scale,
         "w2_scale": w2_scale,
         "per_act_token": per_act_token,
+        "ab_strides1": ab_strides1,
+        "ab_strides2": ab_strides2,
+        "c_strides1": c_strides1,
+        "c_strides2": c_strides2,
         # cuda graph params
         "cutlass_graph": cutlass_graph,
         "triton_graph": triton_graph,
@@ -289,6 +318,10 @@ def replay_graph(graph, num_repeats):
             w2_q,
             w1_scale,
             w2_scale,
+            ab_strides1,
+            ab_strides2,
+            c_strides1,
+            c_strides2,
             topk_weights,
             topk_ids,
             per_act_token,
@@ -297,7 +330,7 @@ def replay_graph(graph, num_repeats):
 
     results.append(
         benchmark.Timer(
-            stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, per_act_token, num_runs)",  # noqa: E501
+            stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, ab_strides1, ab_strides2, c_strides1, c_strides2, topk_weights, topk_ids, per_act_token, num_runs)",  # noqa: E501
             globals=globals,
             label=label,
             sub_label=sub_label,
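The four stride tensors above give the CUTLASS grouped GEMM one leading-dimension entry per expert; since every expert shares the same contiguous row-major layout, each entry is simply the inner dimension of the corresponding matrix. A minimal sketch of the setup used throughout this commit, assuming w1 is [E, 2N, K] and w2 is [E, K, N] as in vLLM's fp8 CUTLASS path (the helper name is illustrative, not a vLLM API):

import torch

def make_moe_gemm_strides(num_experts: int, k: int, n: int,
                          device: str = "cuda"):
    # gemm1: [M, K] activations x [2N, K] weights -> [M, 2N] output
    ab_strides1 = torch.full((num_experts,), k, device=device, dtype=torch.int64)
    c_strides1 = torch.full((num_experts,), 2 * n, device=device, dtype=torch.int64)
    # gemm2: [M, N] activations x [K, N] weights -> [M, K] output
    ab_strides2 = torch.full((num_experts,), n, device=device, dtype=torch.int64)
    c_strides2 = torch.full((num_experts,), k, device=device, dtype=torch.int64)
    return ab_strides1, ab_strides2, c_strides1, c_strides2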

csrc/moe/moe_permute_unpermute_op.cu

Lines changed: 42 additions & 11 deletions

@@ -160,6 +160,30 @@ __global__ void shuffleInputRowsKernel(const T* input,
   }
 }
 
+template <typename T>
+__global__ void shuffleInputRowsKernelSlow(const T* input,
+                                           const int32_t* dst2src_map,
+                                           T* output, int64_t num_src_rows,
+                                           int64_t num_dst_rows,
+                                           int64_t num_cols) {
+  int64_t dest_row_idx = blockIdx.x;
+  int64_t const source_row_idx = dst2src_map[dest_row_idx];
+
+  if (blockIdx.x < num_dst_rows) {
+    // Duplicate and permute rows
+    auto const* source_row_ptr = input + source_row_idx * num_cols;
+    auto* dest_row_ptr = output + dest_row_idx * num_cols;
+
+    int64_t const start_offset = threadIdx.x;
+    int64_t const stride = blockDim.x;
+
+    for (int elem_index = start_offset; elem_index < num_cols;
+         elem_index += stride) {
+      dest_row_ptr[elem_index] = source_row_ptr[elem_index];
+    }
+  }
+}
+
 void shuffle_rows(const torch::Tensor& input_tensor,
                   const torch::Tensor& dst2src_map,
                   torch::Tensor& output_tensor) {
@@ -173,17 +197,24 @@ void shuffle_rows(const torch::Tensor& input_tensor,
   int64_t const num_src_rows = input_tensor.size(0);
   int64_t const num_cols = input_tensor.size(1);
 
-  TORCH_CHECK(!(num_cols % (128 / sizeof(input_tensor.scalar_type()) / 8)),
-              "num_cols must be divisible by 128 / "
-              "sizeof(input_tensor.scalar_type()) / 8");
-
-  MOE_DISPATCH(input_tensor.scalar_type(), [&] {
-    shuffleInputRowsKernel<scalar_t><<<blocks, threads, 0, stream>>>(
-        reinterpret_cast<scalar_t*>(input_tensor.data_ptr()),
-        dst2src_map.data_ptr<int32_t>(),
-        reinterpret_cast<scalar_t*>(output_tensor.data_ptr()), num_src_rows,
-        num_dest_rows, num_cols);
-  });
+  if (num_cols % (128 / sizeof(input_tensor.scalar_type()) / 8)) {
+    // use slow kernel if num_cols can't be aligned to 128 bits
+    MOE_DISPATCH(input_tensor.scalar_type(), [&] {
+      shuffleInputRowsKernelSlow<scalar_t><<<blocks, threads, 0, stream>>>(
+          reinterpret_cast<scalar_t*>(input_tensor.data_ptr()),
+          dst2src_map.data_ptr<int32_t>(),
+          reinterpret_cast<scalar_t*>(output_tensor.data_ptr()), num_src_rows,
+          num_dest_rows, num_cols);
+    });
+  } else {
+    MOE_DISPATCH(input_tensor.scalar_type(), [&] {
+      shuffleInputRowsKernel<scalar_t><<<blocks, threads, 0, stream>>>(
+          reinterpret_cast<scalar_t*>(input_tensor.data_ptr()),
+          dst2src_map.data_ptr<int32_t>(),
+          reinterpret_cast<scalar_t*>(output_tensor.data_ptr()), num_src_rows,
+          num_dest_rows, num_cols);
+    });
+  }
 }
 
 #else
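shuffle_rows(input, dst2src_map, output) writes row dst2src_map[i] of the input into row i of the output; with topk > 1 the map repeats source rows, so this is a gather, not a pure permutation. The existing fast kernel copies each row with 128-bit vectorized accesses, hence the divisibility requirement that used to be a hard TORCH_CHECK; the new slow kernel copies element by element and therefore handles any row width. A PyTorch sketch of the semantics only (a reference, not the vLLM op):

import torch

def shuffle_rows_ref(x: torch.Tensor, dst2src_map: torch.Tensor) -> torch.Tensor:
    # Row i of the result is row dst2src_map[i] of x (rows may repeat).
    if x.element_size() == 1:
        # fp8 tensors don't support fancy indexing, so gather raw bytes
        # and reinterpret, mirroring what the removed _fp8_perm helper did.
        return x.view(torch.uint8)[dst2src_map.long()].view(x.dtype)
    return x[dst2src_map.long()]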

tests/kernels/moe/test_cutlass_moe.py

Lines changed: 12 additions & 2 deletions

@@ -206,6 +206,10 @@ def run_8_bit(moe_tensors: MOETensors8Bit,
         'topk_ids': topk_ids,
         'w1_scale': moe_tensors.w1_scale,
         'w2_scale': moe_tensors.w2_scale,
+        'ab_strides1': moe_tensors.ab_strides1,
+        'ab_strides2': moe_tensors.ab_strides2,
+        'c_strides1': moe_tensors.c_strides1,
+        'c_strides2': moe_tensors.c_strides2,
         'per_act_token': per_act_token,
         'a1_scale': None  #moe_tensors.a_scale
     }
@@ -439,6 +443,11 @@ def test_run_cutlass_moe_fp8(
         expert_map[start:end] = list(range(num_local_experts))
     expert_map = torch.tensor(expert_map, dtype=torch.int32, device="cuda")
 
+    ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
+    ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64)
+    c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64)
+    c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
+
     activation = lambda o, i: torch.ops._C.silu_and_mul(o, i)
     a1q, a1q_scale = moe_kernel_quantize_input(mt.a, mt.a_scale,
                                                torch.float8_e4m3fn,
@@ -447,8 +456,9 @@ def test_run_cutlass_moe_fp8(
     func = lambda output: run_cutlass_moe_fp8(
         output, a1q, mt.w1_q, mt.w2_q, topk_ids, activation,
         global_num_experts, expert_map, mt.w1_scale, mt.w2_scale,
-        a1q_scale, None, workspace13, workspace2, None, mt.a.dtype,
-        per_act_token, per_out_channel, False)
+        a1q_scale, None, ab_strides1, ab_strides2, c_strides1, c_strides2,
+        workspace13, workspace2, None, mt.a.dtype, per_act_token,
+        per_out_channel, False)
 
     workspace13.random_()
     output_random_workspace = torch.empty(output_shape,
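The workspace13.random_() call and the output_random_workspace buffer suggest this test re-runs the kernel under differently initialized scratch buffers and requires identical results, i.e. the kernel must never read uninitialized workspace memory. A hedged sketch of that check (the helper and its names are illustrative, not from the test file):

import torch

def assert_workspace_invariant(run, workspaces: list[torch.Tensor],
                               output_shape, dtype) -> None:
    # Run once with randomized scratch buffers...
    for ws in workspaces:
        ws.random_()
    out_random = torch.empty(output_shape, device="cuda", dtype=dtype)
    run(out_random)
    # ...and once with zeroed ones; the outputs must match exactly.
    for ws in workspaces:
        ws.zero_()
    out_zeroed = torch.empty(output_shape, device="cuda", dtype=dtype)
    run(out_zeroed)
    torch.testing.assert_close(out_random, out_zeroed, rtol=0, atol=0)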

tests/kernels/moe/test_pplx_cutlass_moe.py

Lines changed: 22 additions & 0 deletions

@@ -75,6 +75,7 @@ def pplx_cutlass_moe(
     assert torch.cuda.current_device() == pgi.local_rank
 
     num_tokens, hidden_dim = a.shape
+    intermediate_dim = w2.shape[2]
     num_experts = w1.shape[0]
     block_size = hidden_dim  # TODO support more cases
     device = pgi.device
@@ -123,10 +124,31 @@ def pplx_cutlass_moe(
         num_local_experts=num_local_experts,
         num_dispatchers=num_dispatchers)
 
+    ab_strides1 = torch.full((num_local_experts, ),
+                             hidden_dim,
+                             device="cuda",
+                             dtype=torch.int64)
+    ab_strides2 = torch.full((num_local_experts, ),
+                             intermediate_dim,
+                             device="cuda",
+                             dtype=torch.int64)
+    c_strides1 = torch.full((num_local_experts, ),
+                            2 * intermediate_dim,
+                            device="cuda",
+                            dtype=torch.int64)
+    c_strides2 = torch.full((num_local_experts, ),
+                            hidden_dim,
+                            device="cuda",
+                            dtype=torch.int64)
+
     experts = CutlassExpertsFp8(num_local_experts,
                                 out_dtype,
                                 per_act_token,
                                 per_out_ch,
+                                ab_strides1,
+                                ab_strides2,
+                                c_strides1,
+                                c_strides2,
                                 num_dispatchers=num_dispatchers,
                                 use_batched_format=True)
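Here the strides are derived from the weight shapes rather than from standalone k/n arguments: with w1 of shape [num_local_experts, 2N, K] and w2 of shape [num_local_experts, K, N], hidden_dim is K and intermediate_dim = w2.shape[2] is N. A small hypothetical helper capturing that relationship (layout assumed from the test above):

import torch

def strides_from_weights(w1: torch.Tensor, w2: torch.Tensor):
    # w1: [E, 2N, K] row-major, w2: [E, K, N] row-major (assumed layout)
    num_experts = w1.shape[0]
    hidden_dim = w1.shape[2]        # K
    intermediate_dim = w2.shape[2]  # N

    def full(value: int) -> torch.Tensor:
        return torch.full((num_experts,), value, device=w1.device,
                          dtype=torch.int64)

    return (full(hidden_dim),            # ab_strides1
            full(intermediate_dim),      # ab_strides2
            full(2 * intermediate_dim),  # c_strides1
            full(hidden_dim))            # c_strides2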

vllm/model_executor/layers/fused_moe/cutlass_moe.py

Lines changed: 39 additions & 23 deletions

@@ -13,8 +13,7 @@
     MoEPrepareAndFinalizeNoEP)
 from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
     TopKWeightAndReduceDelegate)
-from vllm.model_executor.layers.fused_moe.utils import (_fp8_perm,
-                                                        _fp8_quantize,
+from vllm.model_executor.layers.fused_moe.utils import (_fp8_quantize,
                                                         _resize_cache)
 from vllm.scalar_type import scalar_types
 
@@ -34,6 +33,10 @@ def run_cutlass_moe_fp8(
     w2_scale: Optional[torch.Tensor],
     a1q_scale: Optional[torch.Tensor],
     a2_scale: Optional[torch.Tensor],
+    ab_strides1: torch.Tensor,
+    ab_strides2: torch.Tensor,
+    c_strides1: torch.Tensor,
+    c_strides2: torch.Tensor,
     workspace13: torch.Tensor,
     workspace2: torch.Tensor,
     expert_num_tokens: Optional[torch.Tensor],
@@ -152,27 +155,11 @@ def run_cutlass_moe_fp8(
                                     problem_sizes1, problem_sizes2, a_map,
                                     c_map, global_num_experts, N, K)
 
-        a1q = _fp8_perm(a1q, a_map)
-        a1q_scale = a1q_scale[a_map] if per_act_token else a1q_scale
+        a1q = ops.shuffle_rows(a1q, a_map)
+        a1q_scale = (ops.shuffle_rows(a1q_scale, a_map)
+                     if per_act_token else a1q_scale)
         expert_offsets = expert_offsets[:-1]
 
-    ab_strides1 = torch.full((w1.size(0), ),
-                             K,
-                             device=device,
-                             dtype=torch.int64)
-    c_strides1 = torch.full((w1.size(0), ),
-                            2 * N,
-                            device=device,
-                            dtype=torch.int64)
-    ab_strides2 = torch.full((w1.size(0), ),
-                             N,
-                             device=device,
-                             dtype=torch.int64)
-    c_strides2 = torch.full((w1.size(0), ),
-                            K,
-                            device=device,
-                            dtype=torch.int64)
-
     if use_batched_format:
         c1 = _resize_cache(workspace13, (local_E * padded_M, N * 2))
         c2 = _resize_cache(workspace2, (local_E * padded_M, N))
@@ -209,7 +196,8 @@ def run_cutlass_moe_fp8(
     else:
         # We can't do this inplace because output may point to the same tensor
         # as c3.
-        output.copy_(c3[c_map].view(M * topk, K), non_blocking=True)
+        output.copy_(ops.shuffle_rows(c3, c_map).view(M * topk, K),
+                     non_blocking=True)
 
 
 # TODO (bnell): split class batched vs. non-batched?
@@ -222,6 +210,10 @@ def __init__(
         out_dtype: Optional[torch.dtype],
         per_act_token_quant: bool,
         per_out_ch_quant: bool,
+        ab_strides1: torch.Tensor,
+        ab_strides2: torch.Tensor,
+        c_strides1: torch.Tensor,
+        c_strides2: torch.Tensor,
         block_shape: Optional[list[int]] = None,
         num_dispatchers: Optional[int] = None,
         use_batched_format: bool = False,
@@ -238,6 +230,10 @@ def __init__(
         self.max_experts_per_worker = max_experts_per_worker
        self.num_dispatchers = num_dispatchers
        self.out_dtype = out_dtype
+        self.ab_strides1 = ab_strides1
+        self.ab_strides2 = ab_strides2
+        self.c_strides1 = c_strides1
+        self.c_strides2 = c_strides2
         self.use_batched_format = use_batched_format
 
     @property
@@ -316,7 +312,8 @@ def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
         run_cutlass_moe_fp8(
             output, hidden_states, w1, w2, topk_ids, activation_callable,
             global_num_experts, expert_map, w1_scale, w2_scale, a1q_scale,
-            a2_scale, workspace13, workspace2, expert_num_tokens,
+            a2_scale, self.ab_strides1, self.ab_strides2, self.c_strides1,
+            self.c_strides2, workspace13, workspace2, expert_num_tokens,
             self.out_dtype if self.out_dtype is not None else in_dtype,
             self.per_act_token_quant, self.per_out_ch_quant,
             self.use_batched_format)
@@ -330,6 +327,10 @@ def cutlass_moe_fp8(
     topk_ids: torch.Tensor,
     w1_scale: torch.Tensor,
     w2_scale: torch.Tensor,
+    ab_strides1: torch.Tensor,
+    ab_strides2: torch.Tensor,
+    c_strides1: torch.Tensor,
+    c_strides2: torch.Tensor,
     per_act_token: Optional[bool] = None,
     activation: str = "silu",
     a1_scale: Optional[torch.Tensor] = None,
@@ -357,6 +358,17 @@ def cutlass_moe_fp8(
         Shape: [num_experts] or [num_experts, 2N]
     - w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
         Shape: [num_experts] or [num_experts, K]
+    - ab_strides1 (torch.Tensor): The input/weight strides for the first gemm.
+        Shape: [num_experts]
+    - ab_strides2 (torch.Tensor): The input/weight strides for the second gemm.
+        Shape: [num_experts]
+    - c_strides1 (torch.Tensor): The output strides for the first gemm.
+        Shape: [num_experts]
+    - c_strides2 (torch.Tensor): The output strides for the second gemm.
+        Shape: [num_experts]
+    - per_act_token (Optional[bool]): Whether the scale is per-token or
+        per-tensor.
+    - activation (str): The activation function to use.
     - a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.
         Shape: scalar or [M]
     - a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
@@ -389,6 +401,10 @@ def cutlass_moe_fp8(
         out_dtype=a.dtype,
         per_act_token_quant=per_act_token,
         per_out_ch_quant=per_out_ch,
+        ab_strides1=ab_strides1,
+        ab_strides2=ab_strides2,
+        c_strides1=c_strides1,
+        c_strides2=c_strides2,
         use_batched_format=False,
     ),
 )
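All of the signature churn above serves one pattern: tensors that are constant across calls (the strides depend only on the expert count and weight shapes) move from the per-call path into construction, e.g. CutlassExpertsFp8 now stores self.ab_strides1 and friends. A minimal before/after sketch of that idea, with an illustrative stand-in for the grouped GEMM launch:

import torch

def grouped_gemm_stub(x: torch.Tensor, strides: torch.Tensor) -> torch.Tensor:
    return x  # stand-in for the CUTLASS grouped GEMM launch

class MoeBefore:
    def forward(self, x: torch.Tensor, num_experts: int, k: int) -> torch.Tensor:
        # Re-created on every forward: extra kernel launches and
        # allocator pressure on the hot path.
        strides = torch.full((num_experts,), k, device=x.device,
                             dtype=torch.int64)
        return grouped_gemm_stub(x, strides)

class MoeAfter:
    def __init__(self, num_experts: int, k: int, device: str = "cuda"):
        # Allocated once and reused by every forward call.
        self.strides = torch.full((num_experts,), k, device=device,
                                  dtype=torch.int64)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return grouped_gemm_stub(x, self.strides)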
