
Commit e7b2042

Revert "[Performance] Performance improvements in non-blockwise fp8 CUTLASS MoE (#20762)" (#21334)
Signed-off-by: Ming Yang <minos.future@gmail.com>
1 parent 90f1e55 commit e7b2042
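
In effect, the revert moves construction of the per-expert GEMM stride tensors back inside run_cutlass_moe_fp8, so callers (the benchmark, the tests, and CutlassExpertsFp8) no longer build or pass them. A minimal sketch of the stride layout involved, assuming row-major w1 of shape [E, 2N, K] and w2 of shape [E, K, N] as in the diffs below (the E, N, K values here are hypothetical):

import torch

E, N, K = 8, 2048, 4096  # hypothetical expert count and per-expert GEMM dims

# GEMM 1 reads activations/weights with rows of length K and writes 2N
# columns (gate and up projections fused); GEMM 2 reads rows of length N
# and writes K columns.
ab_strides1 = torch.full((E,), K, device="cuda", dtype=torch.int64)
c_strides1 = torch.full((E,), 2 * N, device="cuda", dtype=torch.int64)
ab_strides2 = torch.full((E,), N, device="cuda", dtype=torch.int64)
c_strides2 = torch.full((E,), K, device="cuda", dtype=torch.int64)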

File tree

6 files changed: 38 additions, 174 deletions


benchmarks/kernels/benchmark_grouped_gemm_cutlass.py

Lines changed: 1 addition & 34 deletions
@@ -80,11 +80,6 @@ def bench_run(
         a, score, topk, renormalize=False
     )
 
-    ab_strides1 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)
-    ab_strides2 = torch.full((num_experts,), n, device="cuda", dtype=torch.int64)
-    c_strides1 = torch.full((num_experts,), 2 * n, device="cuda", dtype=torch.int64)
-    c_strides2 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)
-
     def run_triton_moe(
         a: torch.Tensor,
         w1: torch.Tensor,
@@ -116,10 +111,6 @@ def run_cutlass_moe(
         w2: torch.Tensor,
         w1_scale: torch.Tensor,
         w2_scale: torch.Tensor,
-        ab_strides1: torch.Tensor,
-        ab_strides2: torch.Tensor,
-        c_strides1: torch.Tensor,
-        c_strides2: torch.Tensor,
        topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
         per_act_token: bool,
@@ -134,10 +125,6 @@ def run_cutlass_moe(
                 topk_ids,
                 w1_scale,
                 w2_scale,
-                ab_strides1,
-                ab_strides2,
-                c_strides1,
-                c_strides2,
                 per_act_token,
                 a1_scale=None,
             )
@@ -149,10 +136,6 @@ def run_cutlass_from_graph(
         w2_q: torch.Tensor,
         w1_scale: torch.Tensor,
         w2_scale: torch.Tensor,
-        ab_strides1: torch.Tensor,
-        ab_strides2: torch.Tensor,
-        c_strides1: torch.Tensor,
-        c_strides2: torch.Tensor,
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
     ):
@@ -167,10 +150,6 @@ def run_cutlass_from_graph(
                 topk_ids,
                 w1_scale,
                 w2_scale,
-                ab_strides1,
-                ab_strides2,
-                c_strides1,
-                c_strides2,
                 per_act_token,
                 a1_scale=None,
             )
@@ -215,10 +194,6 @@ def replay_graph(graph, num_repeats):
             w2_q,
             w1_scale,
             w2_scale,
-            ab_strides1,
-            ab_strides2,
-            c_strides1,
-            c_strides2,
             topk_weights,
             topk_ids,
         )
@@ -256,10 +231,6 @@ def replay_graph(graph, num_repeats):
         "w1_scale": w1_scale,
         "w2_scale": w2_scale,
         "per_act_token": per_act_token,
-        "ab_strides1": ab_strides1,
-        "ab_strides2": ab_strides2,
-        "c_strides1": c_strides1,
-        "c_strides2": c_strides2,
         # cuda graph params
         "cutlass_graph": cutlass_graph,
         "triton_graph": triton_graph,
@@ -318,10 +289,6 @@ def replay_graph(graph, num_repeats):
         w2_q,
         w1_scale,
         w2_scale,
-        ab_strides1,
-        ab_strides2,
-        c_strides1,
-        c_strides2,
         topk_weights,
         topk_ids,
         per_act_token,
@@ -330,7 +297,7 @@ def replay_graph(graph, num_repeats):
 
     results.append(
         benchmark.Timer(
-            stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, ab_strides1, ab_strides2, c_strides1, c_strides2, topk_weights, topk_ids, per_act_token, num_runs)", # noqa: E501
+            stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, per_act_token, num_runs)", # noqa: E501
             globals=globals,
             label=label,
             sub_label=sub_label,
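
For context, the timing harness above is torch.utils.benchmark; after the revert the timed statement simply drops the four stride arguments. A self-contained sketch of the measurement pattern (globals_dict, label, and sub_label are hypothetical stand-ins for the benchmark's own values):

import torch.utils.benchmark as benchmark

def measure(stmt: str, globals_dict: dict, label: str, sub_label: str):
    # Timer evaluates stmt inside globals_dict and synchronizes CUDA
    # around the timed region when GPU tensors are involved.
    timer = benchmark.Timer(
        stmt=stmt,
        globals=globals_dict,
        label=label,
        sub_label=sub_label,
        description="cutlass_moe_fp8",
    )
    return timer.timeit(number=10)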

csrc/moe/moe_permute_unpermute_op.cu

Lines changed: 11 additions & 42 deletions
@@ -160,30 +160,6 @@ __global__ void shuffleInputRowsKernel(const T* input,
   }
 }
 
-template <typename T>
-__global__ void shuffleInputRowsKernelSlow(const T* input,
-                                           const int32_t* dst2src_map,
-                                           T* output, int64_t num_src_rows,
-                                           int64_t num_dst_rows,
-                                           int64_t num_cols) {
-  int64_t dest_row_idx = blockIdx.x;
-  int64_t const source_row_idx = dst2src_map[dest_row_idx];
-
-  if (blockIdx.x < num_dst_rows) {
-    // Duplicate and permute rows
-    auto const* source_row_ptr = input + source_row_idx * num_cols;
-    auto* dest_row_ptr = output + dest_row_idx * num_cols;
-
-    int64_t const start_offset = threadIdx.x;
-    int64_t const stride = blockDim.x;
-
-    for (int elem_index = start_offset; elem_index < num_cols;
-         elem_index += stride) {
-      dest_row_ptr[elem_index] = source_row_ptr[elem_index];
-    }
-  }
-}
-
 void shuffle_rows(const torch::Tensor& input_tensor,
                   const torch::Tensor& dst2src_map,
                   torch::Tensor& output_tensor) {
@@ -197,24 +173,17 @@ void shuffle_rows(const torch::Tensor& input_tensor,
   int64_t const num_src_rows = input_tensor.size(0);
   int64_t const num_cols = input_tensor.size(1);
 
-  if (num_cols % (128 / sizeof(input_tensor.scalar_type()) / 8)) {
-    // use slow kernel if num_cols can't be aligned to 128 bits
-    MOE_DISPATCH(input_tensor.scalar_type(), [&] {
-      shuffleInputRowsKernelSlow<scalar_t><<<blocks, threads, 0, stream>>>(
-          reinterpret_cast<scalar_t*>(input_tensor.data_ptr()),
-          dst2src_map.data_ptr<int32_t>(),
-          reinterpret_cast<scalar_t*>(output_tensor.data_ptr()), num_src_rows,
-          num_dest_rows, num_cols);
-    });
-  } else {
-    MOE_DISPATCH(input_tensor.scalar_type(), [&] {
-      shuffleInputRowsKernel<scalar_t><<<blocks, threads, 0, stream>>>(
-          reinterpret_cast<scalar_t*>(input_tensor.data_ptr()),
-          dst2src_map.data_ptr<int32_t>(),
-          reinterpret_cast<scalar_t*>(output_tensor.data_ptr()), num_src_rows,
-          num_dest_rows, num_cols);
-    });
-  }
+  TORCH_CHECK(!(num_cols % (128 / sizeof(input_tensor.scalar_type()) / 8)),
+              "num_cols must be divisible by 128 / "
+              "sizeof(input_tensor.scalar_type()) / 8");
+
+  MOE_DISPATCH(input_tensor.scalar_type(), [&] {
+    shuffleInputRowsKernel<scalar_t><<<blocks, threads, 0, stream>>>(
+        reinterpret_cast<scalar_t*>(input_tensor.data_ptr()),
+        dst2src_map.data_ptr<int32_t>(),
+        reinterpret_cast<scalar_t*>(output_tensor.data_ptr()), num_src_rows,
+        num_dest_rows, num_cols);
+  });
 }
 
 #else
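
For reference, shuffle_rows is a row gather: destination row i is copied from source row dst2src_map[i], and one source row may feed several destination rows (duplicate-and-permute). A rough PyTorch equivalent of the kernel's semantics:

import torch

def shuffle_rows_ref(input_tensor: torch.Tensor,
                     dst2src_map: torch.Tensor) -> torch.Tensor:
    # Row i of the result is row dst2src_map[i] of the input; indices may
    # repeat, duplicating rows, which plain advanced indexing handles.
    return input_tensor[dst2src_map.long()]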

tests/kernels/moe/test_cutlass_moe.py

Lines changed: 2 additions & 12 deletions
@@ -207,10 +207,6 @@ def run_8_bit(moe_tensors: MOETensors8Bit,
         'topk_ids': topk_ids,
         'w1_scale': moe_tensors.w1_scale,
         'w2_scale': moe_tensors.w2_scale,
-        'ab_strides1': moe_tensors.ab_strides1,
-        'ab_strides2': moe_tensors.ab_strides2,
-        'c_strides1': moe_tensors.c_strides1,
-        'c_strides2': moe_tensors.c_strides2,
         'per_act_token': per_act_token,
         'a1_scale': None #moe_tensors.a_scale
     }
@@ -444,11 +440,6 @@ def test_run_cutlass_moe_fp8(
     expert_map[start:end] = list(range(num_local_experts))
     expert_map = torch.tensor(expert_map, dtype=torch.int32, device="cuda")
 
-    ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
-    ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64)
-    c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64)
-    c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
-
     activation = lambda o, i: torch.ops._C.silu_and_mul(o, i)
     a1q, a1q_scale = moe_kernel_quantize_input(mt.a, mt.a_scale,
                                                torch.float8_e4m3fn,
@@ -457,9 +448,8 @@ def test_run_cutlass_moe_fp8(
     func = lambda output: run_cutlass_moe_fp8(
         output, a1q, mt.w1_q, mt.w2_q, topk_ids, activation,
         global_num_experts, expert_map, mt.w1_scale, mt.w2_scale,
-        a1q_scale, None, ab_strides1, ab_strides2, c_strides1, c_strides2,
-        workspace13, workspace2, None, mt.a.dtype, per_act_token,
-        per_out_channel, False)
+        a1q_scale, None, workspace13, workspace2, None, mt.a.dtype,
+        per_act_token, per_out_channel, False)
 
     workspace13.random_()
     output_random_workspace = torch.empty(output_shape,
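
The surrounding test drives run_cutlass_moe_fp8 with an expert_map that emulates expert parallelism: global expert ids owned by this rank map to local ids. A small sketch of that mapping, assuming the usual vLLM convention that -1 marks an expert not resident on this rank (the sizes here are hypothetical):

import torch

global_num_experts, num_local_experts, start = 8, 4, 4

expert_map = [-1] * global_num_experts  # -1: expert lives on another rank
expert_map[start:start + num_local_experts] = range(num_local_experts)
expert_map = torch.tensor(expert_map, dtype=torch.int32, device="cuda")
# expert_map is now [-1, -1, -1, -1, 0, 1, 2, 3]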

tests/kernels/moe/test_pplx_cutlass_moe.py

Lines changed: 0 additions & 22 deletions
@@ -75,7 +75,6 @@ def pplx_cutlass_moe(
     assert torch.cuda.current_device() == pgi.local_rank
 
     num_tokens, hidden_dim = a.shape
-    intermediate_dim = w2.shape[2]
     num_experts = w1.shape[0]
     block_size = hidden_dim # TODO support more cases
     device = pgi.device
@@ -124,31 +123,10 @@ def pplx_cutlass_moe(
         num_local_experts=num_local_experts,
         num_dispatchers=num_dispatchers)
 
-    ab_strides1 = torch.full((num_local_experts, ),
-                             hidden_dim,
-                             device="cuda",
-                             dtype=torch.int64)
-    ab_strides2 = torch.full((num_local_experts, ),
-                             intermediate_dim,
-                             device="cuda",
-                             dtype=torch.int64)
-    c_strides1 = torch.full((num_local_experts, ),
-                            2 * intermediate_dim,
-                            device="cuda",
-                            dtype=torch.int64)
-    c_strides2 = torch.full((num_local_experts, ),
-                            hidden_dim,
-                            device="cuda",
-                            dtype=torch.int64)
-
     experts = CutlassExpertsFp8(num_local_experts,
                                 out_dtype,
                                 per_act_token,
                                 per_out_ch,
-                                ab_strides1,
-                                ab_strides2,
-                                c_strides1,
-                                c_strides2,
                                 num_dispatchers=num_dispatchers,
                                 use_batched_format=True)
 
vllm/model_executor/layers/fused_moe/cutlass_moe.py

Lines changed: 23 additions & 39 deletions
@@ -13,7 +13,8 @@
     MoEPrepareAndFinalizeNoEP)
 from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
     TopKWeightAndReduceDelegate)
-from vllm.model_executor.layers.fused_moe.utils import (_fp8_quantize,
+from vllm.model_executor.layers.fused_moe.utils import (_fp8_perm,
+                                                        _fp8_quantize,
                                                         _resize_cache,
                                                         extract_required_args)
 from vllm.scalar_type import scalar_types
@@ -34,10 +35,6 @@ def run_cutlass_moe_fp8(
     w2_scale: Optional[torch.Tensor],
     a1q_scale: Optional[torch.Tensor],
     a2_scale: Optional[torch.Tensor],
-    ab_strides1: torch.Tensor,
-    ab_strides2: torch.Tensor,
-    c_strides1: torch.Tensor,
-    c_strides2: torch.Tensor,
     workspace13: torch.Tensor,
     workspace2: torch.Tensor,
     expert_num_tokens: Optional[torch.Tensor],
@@ -156,11 +153,27 @@ def run_cutlass_moe_fp8(
                                     problem_sizes1, problem_sizes2, a_map,
                                     c_map, global_num_experts, N, K)
 
-        a1q = ops.shuffle_rows(a1q, a_map)
-        a1q_scale = (ops.shuffle_rows(a1q_scale, a_map)
-                     if per_act_token else a1q_scale)
+        a1q = _fp8_perm(a1q, a_map)
+        a1q_scale = a1q_scale[a_map] if per_act_token else a1q_scale
         expert_offsets = expert_offsets[:-1]
 
+    ab_strides1 = torch.full((w1.size(0), ),
+                             K,
+                             device=device,
+                             dtype=torch.int64)
+    c_strides1 = torch.full((w1.size(0), ),
+                            2 * N,
+                            device=device,
+                            dtype=torch.int64)
+    ab_strides2 = torch.full((w1.size(0), ),
+                             N,
+                             device=device,
+                             dtype=torch.int64)
+    c_strides2 = torch.full((w1.size(0), ),
+                            K,
+                            device=device,
+                            dtype=torch.int64)
+
     if use_batched_format:
         c1 = _resize_cache(workspace13, (local_E * padded_M, N * 2))
         c2 = _resize_cache(workspace2, (local_E * padded_M, N))
@@ -197,8 +210,7 @@ def run_cutlass_moe_fp8(
     else:
         # We can't do this inplace because output may point to the same tensor
         # as c3.
-        output.copy_(ops.shuffle_rows(c3, c_map).view(M * topk, K),
-                     non_blocking=True)
+        output.copy_(c3[c_map].view(M * topk, K), non_blocking=True)
 
 
 # TODO (bnell): split class batched vs. non-batched?
@@ -211,10 +223,6 @@ def __init__(
         out_dtype: Optional[torch.dtype],
         per_act_token_quant: bool,
         per_out_ch_quant: bool,
-        ab_strides1: torch.Tensor,
-        ab_strides2: torch.Tensor,
-        c_strides1: torch.Tensor,
-        c_strides2: torch.Tensor,
         block_shape: Optional[list[int]] = None,
         num_dispatchers: Optional[int] = None,
         use_batched_format: bool = False,
@@ -231,10 +239,6 @@ def __init__(
         self.max_experts_per_worker = max_experts_per_worker
         self.num_dispatchers = num_dispatchers
         self.out_dtype = out_dtype
-        self.ab_strides1 = ab_strides1
-        self.ab_strides2 = ab_strides2
-        self.c_strides1 = c_strides1
-        self.c_strides2 = c_strides2
         self.use_batched_format = use_batched_format
 
     @property
@@ -314,8 +318,7 @@ def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
         run_cutlass_moe_fp8(
             output, hidden_states, w1, w2, topk_ids, activation_callable,
             global_num_experts, expert_map, w1_scale, w2_scale, a1q_scale,
-            a2_scale, self.ab_strides1, self.ab_strides2, self.c_strides1,
-            self.c_strides2, workspace13, workspace2, expert_num_tokens,
+            a2_scale, workspace13, workspace2, expert_num_tokens,
             self.out_dtype if self.out_dtype is not None else in_dtype,
             self.per_act_token_quant, self.per_out_ch_quant,
             self.use_batched_format)
@@ -329,10 +332,6 @@ def cutlass_moe_fp8(
     topk_ids: torch.Tensor,
     w1_scale: torch.Tensor,
     w2_scale: torch.Tensor,
-    ab_strides1: torch.Tensor,
-    ab_strides2: torch.Tensor,
-    c_strides1: torch.Tensor,
-    c_strides2: torch.Tensor,
     per_act_token: Optional[bool] = None,
     activation: str = "silu",
     a1_scale: Optional[torch.Tensor] = None,
@@ -360,17 +359,6 @@ def cutlass_moe_fp8(
         Shape: [num_experts] or [num_experts, 2N]
     - w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
         Shape: [num_experts] or [num_experts, K]
-    - ab_strides1 (torch.Tensor): The input/weight strides for the first gemm.
-        Shape: [num_experts]
-    - ab_strides2 (torch.Tensor): The input/weight strides for the second gemm.
-        Shape: [num_experts]
-    - c_strides1 (torch.Tensor): The output strides for the first gemm.
-        Shape: [num_experts]
-    - c_strides2 (torch.Tensor): The output strides for the second gemm.
-        Shape: [num_experts]
-    - per_act_token (Optional[bool]): Whether the scale is per-token or
-        per-tensor.
-    - activation (str): The activation function to use.
     - a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.
         Shape: scalar or [M]
     - a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
@@ -403,10 +391,6 @@ def cutlass_moe_fp8(
             out_dtype=a.dtype,
             per_act_token_quant=per_act_token,
             per_out_ch_quant=per_out_ch,
-            ab_strides1=ab_strides1,
-            ab_strides2=ab_strides2,
-            c_strides1=c_strides1,
-            c_strides2=c_strides2,
             use_batched_format=False,
         ),
     )
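
The permutation change above swaps ops.shuffle_rows for _fp8_perm on the quantized activations, with plain fancy indexing (a1q_scale[a_map], c3[c_map]) elsewhere. A sketch of what such an fp8-safe permutation helper can look like, inferred from the _fp8_perm import rather than copied from vLLM's utils:

import torch

def fp8_perm_sketch(m: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    # Gather rows of m in the order given by idx. One-byte float dtypes
    # (fp8) are viewed as uint8 for the gather, since advanced indexing on
    # fp8 is not supported everywhere, then viewed back afterwards.
    if torch.is_floating_point(m) and m.dtype.itemsize == 1:
        return m.view(dtype=torch.uint8)[idx].view(dtype=m.dtype)
    return m[idx]

Per-token scales need the same reordering, which the diff performs directly with a1q_scale[a_map].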
