Commit 291c923

Reapply "[Performance] Performance improvements in non-blockwise fp8 CUTLASS MoE (vllm-project#20762)" (vllm-project#21334)

This reverts commit e7b2042. The original PR is vllm-project#20762.

Authored-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: Ming Yang <minos.future@gmail.com>
1 parent 35366ae commit 291c923

File tree: 6 files changed, +174 -38 lines changed

benchmarks/kernels/benchmark_grouped_gemm_cutlass.py

Lines changed: 34 additions & 1 deletion

@@ -80,6 +80,11 @@ def bench_run(
         a, score, topk, renormalize=False
     )

+    ab_strides1 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)
+    ab_strides2 = torch.full((num_experts,), n, device="cuda", dtype=torch.int64)
+    c_strides1 = torch.full((num_experts,), 2 * n, device="cuda", dtype=torch.int64)
+    c_strides2 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)
+
     def run_triton_moe(
         a: torch.Tensor,
         w1: torch.Tensor,
@@ -111,6 +116,10 @@ def run_cutlass_moe(
         w2: torch.Tensor,
         w1_scale: torch.Tensor,
         w2_scale: torch.Tensor,
+        ab_strides1: torch.Tensor,
+        ab_strides2: torch.Tensor,
+        c_strides1: torch.Tensor,
+        c_strides2: torch.Tensor,
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
         per_act_token: bool,
@@ -125,6 +134,10 @@ def run_cutlass_moe(
             topk_ids,
             w1_scale,
             w2_scale,
+            ab_strides1,
+            ab_strides2,
+            c_strides1,
+            c_strides2,
             per_act_token,
             a1_scale=None,
         )
@@ -136,6 +149,10 @@ def run_cutlass_from_graph(
         w2_q: torch.Tensor,
         w1_scale: torch.Tensor,
         w2_scale: torch.Tensor,
+        ab_strides1: torch.Tensor,
+        ab_strides2: torch.Tensor,
+        c_strides1: torch.Tensor,
+        c_strides2: torch.Tensor,
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
     ):
@@ -150,6 +167,10 @@ def run_cutlass_from_graph(
             topk_ids,
             w1_scale,
             w2_scale,
+            ab_strides1,
+            ab_strides2,
+            c_strides1,
+            c_strides2,
             per_act_token,
             a1_scale=None,
         )
@@ -194,6 +215,10 @@ def replay_graph(graph, num_repeats):
             w2_q,
             w1_scale,
             w2_scale,
+            ab_strides1,
+            ab_strides2,
+            c_strides1,
+            c_strides2,
             topk_weights,
             topk_ids,
         )
@@ -231,6 +256,10 @@ def replay_graph(graph, num_repeats):
         "w1_scale": w1_scale,
         "w2_scale": w2_scale,
         "per_act_token": per_act_token,
+        "ab_strides1": ab_strides1,
+        "ab_strides2": ab_strides2,
+        "c_strides1": c_strides1,
+        "c_strides2": c_strides2,
         # cuda graph params
         "cutlass_graph": cutlass_graph,
         "triton_graph": triton_graph,
@@ -289,6 +318,10 @@ def replay_graph(graph, num_repeats):
             w2_q,
             w1_scale,
             w2_scale,
+            ab_strides1,
+            ab_strides2,
+            c_strides1,
+            c_strides2,
             topk_weights,
             topk_ids,
             per_act_token,
@@ -297,7 +330,7 @@ def replay_graph(graph, num_repeats):

     results.append(
         benchmark.Timer(
-            stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, per_act_token, num_runs)",  # noqa: E501
+            stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, ab_strides1, ab_strides2, c_strides1, c_strides2, topk_weights, topk_ids, per_act_token, num_runs)",  # noqa: E501
             globals=globals,
             label=label,
             sub_label=sub_label,
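
The hoisting visible in this file is the point of the change: the four stride tensors are built once, outside both the timed eager loop and any CUDA graph capture, instead of being reallocated inside the kernel wrapper on every call. This removes four torch.full device allocations per iteration and gives the captured kernels stable tensor addresses to refer to. A minimal sketch of the pattern, with hypothetical sizes (the benchmark derives these from its model configs):

import torch

# Hypothetical sizes for illustration only.
num_experts, k, n = 8, 512, 1024

# Created once and reused. Per-expert row strides for the two grouped GEMMs:
# each GEMM's A/B rows are as wide as its inner dimension, and its C rows as
# wide as its output (the first GEMM emits the gated [*, 2N] activation input).
ab_strides1 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)
ab_strides2 = torch.full((num_experts,), n, device="cuda", dtype=torch.int64)
c_strides1 = torch.full((num_experts,), 2 * n, device="cuda", dtype=torch.int64)
c_strides2 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)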

csrc/moe/moe_permute_unpermute_op.cu

Lines changed: 42 additions & 11 deletions

@@ -160,6 +160,30 @@ __global__ void shuffleInputRowsKernel(const T* input,
   }
 }

+template <typename T>
+__global__ void shuffleInputRowsKernelSlow(const T* input,
+                                           const int32_t* dst2src_map,
+                                           T* output, int64_t num_src_rows,
+                                           int64_t num_dst_rows,
+                                           int64_t num_cols) {
+  int64_t dest_row_idx = blockIdx.x;
+  int64_t const source_row_idx = dst2src_map[dest_row_idx];
+
+  if (blockIdx.x < num_dst_rows) {
+    // Duplicate and permute rows
+    auto const* source_row_ptr = input + source_row_idx * num_cols;
+    auto* dest_row_ptr = output + dest_row_idx * num_cols;
+
+    int64_t const start_offset = threadIdx.x;
+    int64_t const stride = blockDim.x;
+
+    for (int elem_index = start_offset; elem_index < num_cols;
+         elem_index += stride) {
+      dest_row_ptr[elem_index] = source_row_ptr[elem_index];
+    }
+  }
+}
+
 void shuffle_rows(const torch::Tensor& input_tensor,
                   const torch::Tensor& dst2src_map,
                   torch::Tensor& output_tensor) {
@@ -173,17 +197,24 @@ void shuffle_rows(const torch::Tensor& input_tensor,
   int64_t const num_src_rows = input_tensor.size(0);
   int64_t const num_cols = input_tensor.size(1);

-  TORCH_CHECK(!(num_cols % (128 / sizeof(input_tensor.scalar_type()) / 8)),
-              "num_cols must be divisible by 128 / "
-              "sizeof(input_tensor.scalar_type()) / 8");
-
-  MOE_DISPATCH(input_tensor.scalar_type(), [&] {
-    shuffleInputRowsKernel<scalar_t><<<blocks, threads, 0, stream>>>(
-        reinterpret_cast<scalar_t*>(input_tensor.data_ptr()),
-        dst2src_map.data_ptr<int32_t>(),
-        reinterpret_cast<scalar_t*>(output_tensor.data_ptr()), num_src_rows,
-        num_dest_rows, num_cols);
-  });
+  if (num_cols % (128 / sizeof(input_tensor.scalar_type()) / 8)) {
+    // use slow kernel if num_cols can't be aligned to 128 bits
+    MOE_DISPATCH(input_tensor.scalar_type(), [&] {
+      shuffleInputRowsKernelSlow<scalar_t><<<blocks, threads, 0, stream>>>(
+          reinterpret_cast<scalar_t*>(input_tensor.data_ptr()),
+          dst2src_map.data_ptr<int32_t>(),
+          reinterpret_cast<scalar_t*>(output_tensor.data_ptr()), num_src_rows,
+          num_dest_rows, num_cols);
+    });
+  } else {
+    MOE_DISPATCH(input_tensor.scalar_type(), [&] {
+      shuffleInputRowsKernel<scalar_t><<<blocks, threads, 0, stream>>>(
+          reinterpret_cast<scalar_t*>(input_tensor.data_ptr()),
+          dst2src_map.data_ptr<int32_t>(),
+          reinterpret_cast<scalar_t*>(output_tensor.data_ptr()), num_src_rows,
+          num_dest_rows, num_cols);
+    });
+  }
 }

 #else
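
Where the old code hard-failed on unaligned widths, shuffle_rows now falls back to the element-wise kernel whenever a row cannot be covered by whole 128-bit vector accesses. A minimal Python sketch of that divisibility rule, assuming the intent is 16-byte vectorized loads (the helper name and example shapes are illustrative, not part of the commit; the fp8 dtype needs a recent PyTorch):

import torch

def can_use_fast_shuffle(num_cols: int, dtype: torch.dtype) -> bool:
    # Elements that fit in one 128-bit (16-byte) vector access.
    elems_per_vector = 16 // torch.empty((), dtype=dtype).element_size()
    # The fast kernel assumes each row is a whole number of such vectors.
    return num_cols % elems_per_vector == 0

assert can_use_fast_shuffle(4096, torch.float8_e4m3fn)  # 4096 % 16 == 0
assert not can_use_fast_shuffle(4100, torch.bfloat16)   # 4100 % 8 != 0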

tests/kernels/moe/test_cutlass_moe.py

Lines changed: 12 additions & 2 deletions

@@ -207,6 +207,10 @@ def run_8_bit(moe_tensors: MOETensors8Bit,
         'topk_ids': topk_ids,
         'w1_scale': moe_tensors.w1_scale,
         'w2_scale': moe_tensors.w2_scale,
+        'ab_strides1': moe_tensors.ab_strides1,
+        'ab_strides2': moe_tensors.ab_strides2,
+        'c_strides1': moe_tensors.c_strides1,
+        'c_strides2': moe_tensors.c_strides2,
         'per_act_token': per_act_token,
         'a1_scale': None  # moe_tensors.a_scale
     }
@@ -440,6 +444,11 @@ def test_run_cutlass_moe_fp8(
     expert_map[start:end] = list(range(num_local_experts))
     expert_map = torch.tensor(expert_map, dtype=torch.int32, device="cuda")

+    ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
+    ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64)
+    c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64)
+    c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
+
     activation = lambda o, i: torch.ops._C.silu_and_mul(o, i)
     a1q, a1q_scale = moe_kernel_quantize_input(mt.a, mt.a_scale,
                                                torch.float8_e4m3fn,
@@ -448,8 +457,9 @@ def test_run_cutlass_moe_fp8(
     func = lambda output: run_cutlass_moe_fp8(
         output, a1q, mt.w1_q, mt.w2_q, topk_ids, activation,
         global_num_experts, expert_map, mt.w1_scale, mt.w2_scale,
-        a1q_scale, None, workspace13, workspace2, None, mt.a.dtype,
-        per_act_token, per_out_channel, False)
+        a1q_scale, None, ab_strides1, ab_strides2, c_strides1, c_strides2,
+        workspace13, workspace2, None, mt.a.dtype, per_act_token,
+        per_out_channel, False)

     workspace13.random_()
     output_random_workspace = torch.empty(output_shape,
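
The hunk truncates where the test perturbs the shared workspace; the apparent intent is to check that run_cutlass_moe_fp8's output does not depend on whatever the scratch buffers previously held. A sketch of that pattern, with the surrounding harness assumed rather than taken from the diff (func, workspace13, output_shape, and mt mirror names from the test above):

# Run the same MoE functional twice with the scratch workspace deliberately
# set to different contents; if the kernel ever read stale workspace data,
# the two outputs would diverge. The exact assertions here are assumed.
workspace13.random_()
output_random_workspace = torch.empty(output_shape, device="cuda", dtype=mt.a.dtype)
func(output_random_workspace)

workspace13.fill_(0)
output_zero_workspace = torch.empty(output_shape, device="cuda", dtype=mt.a.dtype)
func(output_zero_workspace)

torch.testing.assert_close(output_random_workspace, output_zero_workspace)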

tests/kernels/moe/test_pplx_cutlass_moe.py

Lines changed: 22 additions & 0 deletions

@@ -75,6 +75,7 @@ def pplx_cutlass_moe(
     assert torch.cuda.current_device() == pgi.local_rank

     num_tokens, hidden_dim = a.shape
+    intermediate_dim = w2.shape[2]
     num_experts = w1.shape[0]
     block_size = hidden_dim  # TODO support more cases
     device = pgi.device
@@ -123,10 +124,31 @@ def pplx_cutlass_moe(
         num_local_experts=num_local_experts,
         num_dispatchers=num_dispatchers)

+    ab_strides1 = torch.full((num_local_experts, ),
+                             hidden_dim,
+                             device="cuda",
+                             dtype=torch.int64)
+    ab_strides2 = torch.full((num_local_experts, ),
+                             intermediate_dim,
+                             device="cuda",
+                             dtype=torch.int64)
+    c_strides1 = torch.full((num_local_experts, ),
+                            2 * intermediate_dim,
+                            device="cuda",
+                            dtype=torch.int64)
+    c_strides2 = torch.full((num_local_experts, ),
+                            hidden_dim,
+                            device="cuda",
+                            dtype=torch.int64)
+
     experts = CutlassExpertsFp8(num_local_experts,
                                 out_dtype,
                                 per_act_token,
                                 per_out_ch,
+                                ab_strides1,
+                                ab_strides2,
+                                c_strides1,
+                                c_strides2,
                                 num_dispatchers=num_dispatchers,
                                 use_batched_format=True)

vllm/model_executor/layers/fused_moe/cutlass_moe.py

Lines changed: 39 additions & 23 deletions

@@ -13,8 +13,7 @@
     MoEPrepareAndFinalizeNoEP)
 from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
     TopKWeightAndReduceDelegate)
-from vllm.model_executor.layers.fused_moe.utils import (_fp8_perm,
-                                                        _fp8_quantize,
+from vllm.model_executor.layers.fused_moe.utils import (_fp8_quantize,
                                                         _resize_cache,
                                                         extract_required_args)
 from vllm.scalar_type import scalar_types
@@ -35,6 +34,10 @@ def run_cutlass_moe_fp8(
     w2_scale: Optional[torch.Tensor],
     a1q_scale: Optional[torch.Tensor],
     a2_scale: Optional[torch.Tensor],
+    ab_strides1: torch.Tensor,
+    ab_strides2: torch.Tensor,
+    c_strides1: torch.Tensor,
+    c_strides2: torch.Tensor,
     workspace13: torch.Tensor,
     workspace2: torch.Tensor,
     expert_num_tokens: Optional[torch.Tensor],
@@ -153,27 +156,11 @@ def run_cutlass_moe_fp8(
                                 problem_sizes1, problem_sizes2, a_map,
                                 c_map, global_num_experts, N, K)

-        a1q = _fp8_perm(a1q, a_map)
-        a1q_scale = a1q_scale[a_map] if per_act_token else a1q_scale
+        a1q = ops.shuffle_rows(a1q, a_map)
+        a1q_scale = (ops.shuffle_rows(a1q_scale, a_map)
+                     if per_act_token else a1q_scale)
         expert_offsets = expert_offsets[:-1]

-        ab_strides1 = torch.full((w1.size(0), ),
-                                 K,
-                                 device=device,
-                                 dtype=torch.int64)
-        c_strides1 = torch.full((w1.size(0), ),
-                                2 * N,
-                                device=device,
-                                dtype=torch.int64)
-        ab_strides2 = torch.full((w1.size(0), ),
-                                 N,
-                                 device=device,
-                                 dtype=torch.int64)
-        c_strides2 = torch.full((w1.size(0), ),
-                                K,
-                                device=device,
-                                dtype=torch.int64)
-
     if use_batched_format:
         c1 = _resize_cache(workspace13, (local_E * padded_M, N * 2))
         c2 = _resize_cache(workspace2, (local_E * padded_M, N))
@@ -210,7 +197,8 @@ def run_cutlass_moe_fp8(
     else:
         # We can't do this inplace because output may point to the same tensor
         # as c3.
-        output.copy_(c3[c_map].view(M * topk, K), non_blocking=True)
+        output.copy_(ops.shuffle_rows(c3, c_map).view(M * topk, K),
+                     non_blocking=True)


 # TODO (bnell): split class batched vs. non-batched?
@@ -223,6 +211,10 @@ def __init__(
         out_dtype: Optional[torch.dtype],
         per_act_token_quant: bool,
         per_out_ch_quant: bool,
+        ab_strides1: torch.Tensor,
+        ab_strides2: torch.Tensor,
+        c_strides1: torch.Tensor,
+        c_strides2: torch.Tensor,
         block_shape: Optional[list[int]] = None,
         num_dispatchers: Optional[int] = None,
         use_batched_format: bool = False,
@@ -239,6 +231,10 @@ def __init__(
         self.max_experts_per_worker = max_experts_per_worker
         self.num_dispatchers = num_dispatchers
         self.out_dtype = out_dtype
+        self.ab_strides1 = ab_strides1
+        self.ab_strides2 = ab_strides2
+        self.c_strides1 = c_strides1
+        self.c_strides2 = c_strides2
         self.use_batched_format = use_batched_format

     @property
@@ -318,7 +314,8 @@ def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
         run_cutlass_moe_fp8(
             output, hidden_states, w1, w2, topk_ids, activation_callable,
             global_num_experts, expert_map, w1_scale, w2_scale, a1q_scale,
-            a2_scale, workspace13, workspace2, expert_num_tokens,
+            a2_scale, self.ab_strides1, self.ab_strides2, self.c_strides1,
+            self.c_strides2, workspace13, workspace2, expert_num_tokens,
             self.out_dtype if self.out_dtype is not None else in_dtype,
             self.per_act_token_quant, self.per_out_ch_quant,
             self.use_batched_format)
@@ -332,6 +329,10 @@ def cutlass_moe_fp8(
     topk_ids: torch.Tensor,
     w1_scale: torch.Tensor,
     w2_scale: torch.Tensor,
+    ab_strides1: torch.Tensor,
+    ab_strides2: torch.Tensor,
+    c_strides1: torch.Tensor,
+    c_strides2: torch.Tensor,
     per_act_token: Optional[bool] = None,
     activation: str = "silu",
     a1_scale: Optional[torch.Tensor] = None,
@@ -359,6 +360,17 @@ def cutlass_moe_fp8(
         Shape: [num_experts] or [num_experts, 2N]
     - w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
         Shape: [num_experts] or [num_experts, K]
+    - ab_strides1 (torch.Tensor): The input/weight strides for the first gemm.
+        Shape: [num_experts]
+    - ab_strides2 (torch.Tensor): The input/weight strides for the second gemm.
+        Shape: [num_experts]
+    - c_strides1 (torch.Tensor): The output strides for the first gemm.
+        Shape: [num_experts]
+    - c_strides2 (torch.Tensor): The output strides for the second gemm.
+        Shape: [num_experts]
+    - per_act_token (Optional[bool]): Whether the scale is per-token or
+        per-tensor.
+    - activation (str): The activation function to use.
     - a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.
         Shape: scalar or [M]
     - a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
@@ -391,6 +403,10 @@ def cutlass_moe_fp8(
         out_dtype=a.dtype,
         per_act_token_quant=per_act_token,
         per_out_ch_quant=per_out_ch,
+        ab_strides1=ab_strides1,
+        ab_strides2=ab_strides2,
+        c_strides1=c_strides1,
+        c_strides2=c_strides2,
         use_batched_format=False,
     ),
 )
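
The net effect on callers of this file: the four stride tensors move from per-call torch.full allocations inside run_cutlass_moe_fp8 to state the caller constructs once and reuses, and the Python-side _fp8_perm gather is replaced by the shuffle_rows custom op. A minimal sketch of the caller-side pattern, using the hidden size K and intermediate size N from the docstring above (the holder class is hypothetical, for illustration only):

import torch

class Fp8MoeGemmStrides:
    """Hypothetical holder (not part of vLLM): build the per-expert stride
    tensors once at layer init and pass the same tensors into every
    cutlass_moe_fp8 call."""

    def __init__(self, num_experts: int, n: int, k: int, device: str = "cuda"):
        def full(value: int) -> torch.Tensor:
            return torch.full((num_experts,), value, device=device,
                              dtype=torch.int64)

        self.ab_strides1 = full(k)      # gemm1 input/weight rows are K elements wide
        self.ab_strides2 = full(n)      # gemm2 input/weight rows are N elements wide
        self.c_strides1 = full(2 * n)   # gemm1 writes the gated [*, 2N] output
        self.c_strides2 = full(k)       # gemm2 writes the final [*, K] output

Keeping these as long-lived tensors avoids four device allocations per forward pass and keeps the buffer addresses stable across CUDA-graph replays, which is where the performance win in this commit comes from.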
