[Performance] Performance improvements in non-blockwise fp8 CUTLASS MoE #20762
base: main
tests/kernels/moe/test_cutlass_moe.py

@@ -206,6 +206,10 @@ def run_8_bit(moe_tensors: MOETensors8Bit,
         'topk_ids': topk_ids,
         'w1_scale': moe_tensors.w1_scale,
         'w2_scale': moe_tensors.w2_scale,
+        'ab_strides1': moe_tensors.ab_strides1,
+        'ab_strides2': moe_tensors.ab_strides2,
+        'c_strides1': moe_tensors.c_strides1,
+        'c_strides2': moe_tensors.c_strides2,
         'per_act_token': per_act_token,
         'a1_scale': None  #moe_tensors.a_scale
     }
@@ -439,6 +443,11 @@ def test_run_cutlass_moe_fp8(
     expert_map[start:end] = list(range(num_local_experts))
     expert_map = torch.tensor(expert_map, dtype=torch.int32, device="cuda")

+    ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
+    ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64)
+    c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64)
+    c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
+
     activation = lambda o, i: torch.ops._C.silu_and_mul(o, i)
     a1q, a1q_scale = moe_kernel_quantize_input(mt.a, mt.a_scale,
                                                torch.float8_e4m3fn,

Review comment on lines +446 to +449:

> Could we save on memory by using the same tensor for both [...]?

Reply:

> Yes, this should be possible. I'll push an update.
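A minimal sketch of the saving suggested above, assuming the reviewer means the two stride tensors that share the fill value k (ab_strides1 and c_strides2 in this test), and that the kernels treat the stride tensors as read-only:

```python
import torch

e, n, k = 8, 1024, 2048  # hypothetical test sizes

ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64)
c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64)
# Reuse ab_strides1 instead of allocating a second [e] tensor filled with k;
# this is safe only because the MoE kernels never write to the stride tensors.
c_strides2 = ab_strides1
```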
@@ -447,8 +456,9 @@ def test_run_cutlass_moe_fp8(
     func = lambda output: run_cutlass_moe_fp8(
         output, a1q, mt.w1_q, mt.w2_q, topk_ids, activation,
         global_num_experts, expert_map, mt.w1_scale, mt.w2_scale,
-        a1q_scale, None, workspace13, workspace2, None, mt.a.dtype,
-        per_act_token, per_out_channel, False)
+        a1q_scale, None, ab_strides1, ab_strides2, c_strides1, c_strides2,
+        workspace13, workspace2, None, mt.a.dtype, per_act_token,
+        per_out_channel, False)

     workspace13.random_()
     output_random_workspace = torch.empty(output_shape,
vllm/model_executor/layers/fused_moe/cutlass_moe.py

@@ -11,8 +11,7 @@
 from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
 from vllm.model_executor.layers.fused_moe.prepare_finalize import (
     MoEPrepareAndFinalizeNoEP)
-from vllm.model_executor.layers.fused_moe.utils import (_fp8_perm,
-                                                        _fp8_quantize,
+from vllm.model_executor.layers.fused_moe.utils import (_fp8_quantize,
                                                         _resize_cache)
 from vllm.scalar_type import scalar_types
@@ -32,6 +31,10 @@ def run_cutlass_moe_fp8(
     w2_scale: Optional[torch.Tensor],
     a1q_scale: Optional[torch.Tensor],
     a2_scale: Optional[torch.Tensor],
+    ab_strides1: torch.Tensor,
+    ab_strides2: torch.Tensor,
+    c_strides1: torch.Tensor,
+    c_strides2: torch.Tensor,
     workspace13: torch.Tensor,
     workspace2: torch.Tensor,
     expert_num_tokens: Optional[torch.Tensor],
@@ -150,27 +153,11 @@ def run_cutlass_moe_fp8(
                                     problem_sizes1, problem_sizes2, a_map,
                                     c_map, global_num_experts, N, K)

-    a1q = _fp8_perm(a1q, a_map)
-    a1q_scale = a1q_scale[a_map] if per_act_token else a1q_scale
+    a1q = ops.shuffle_rows(a1q, a_map)
+    a1q_scale = (ops.shuffle_rows(a1q_scale, a_map)
+                 if per_act_token else a1q_scale)

Review comment on lines -155 to +160:

> Does the speedup mentioned in this statement [...] come from being able to use [...]?

Reply:

> It comes from using these custom kernels instead of the pytorch index-based function. The "slow" kernels are expected to only be run to shuffle scales.

     expert_offsets = expert_offsets[:-1]

-    ab_strides1 = torch.full((w1.size(0), ),
-                             K,
-                             device=device,
-                             dtype=torch.int64)
-    c_strides1 = torch.full((w1.size(0), ),
-                            2 * N,
-                            device=device,
-                            dtype=torch.int64)
-    ab_strides2 = torch.full((w1.size(0), ),
-                             N,
-                             device=device,
-                             dtype=torch.int64)
-    c_strides2 = torch.full((w1.size(0), ),
-                            K,
-                            device=device,
-                            dtype=torch.int64)
-
     if use_batched_format:
         c1 = _resize_cache(workspace13, (local_E * padded_M, N * 2))
         c2 = _resize_cache(workspace2, (local_E * padded_M, N))
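To make the shuffle_rows Q&A above concrete: ops.shuffle_rows is vLLM's custom CUDA row-gather kernel, and semantically it matches the PyTorch advanced indexing it replaces. A pure-PyTorch sketch of that equivalence (the helper name shuffle_rows_ref is hypothetical, for illustration only):

```python
import torch

def shuffle_rows_ref(x: torch.Tensor, row_map: torch.Tensor) -> torch.Tensor:
    # Gather: out[i] = x[row_map[i]]. ops.shuffle_rows computes the same
    # result with a dedicated CUDA kernel instead of PyTorch's generic
    # indexing machinery, which is where the speedup discussed above
    # comes from.
    return x[row_map]

# Tiny example: reorder 4 token rows into expert-sorted order.
x = torch.arange(8, dtype=torch.float32).reshape(4, 2)
a_map = torch.tensor([2, 0, 3, 1])
assert torch.equal(shuffle_rows_ref(x, a_map), x[a_map])
```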
@@ -207,7 +194,8 @@ def run_cutlass_moe_fp8(
     else:
         # We can't do this inplace because output may point to the same tensor
         # as c3.
-        output.copy_(c3[c_map].view(M * topk, K), non_blocking=True)
+        output.copy_(ops.shuffle_rows(c3, c_map).view(M * topk, K),
+                     non_blocking=True)


 # TODO (bnell): split class batched vs. non-batched?
@@ -220,6 +208,10 @@ def __init__(
         out_dtype: Optional[torch.dtype],
         per_act_token_quant: bool,
         per_out_ch_quant: bool,
+        ab_strides1: torch.Tensor,
+        ab_strides2: torch.Tensor,
+        c_strides1: torch.Tensor,
+        c_strides2: torch.Tensor,
         block_shape: Optional[list[int]] = None,
         num_dispatchers: Optional[int] = None,
         use_batched_format: bool = False,
@@ -236,6 +228,10 @@ def __init__(
         self.max_experts_per_worker = max_experts_per_worker
         self.num_dispatchers = num_dispatchers
         self.out_dtype = out_dtype
+        self.ab_strides1 = ab_strides1
+        self.ab_strides2 = ab_strides2
+        self.c_strides1 = c_strides1
+        self.c_strides2 = c_strides2
         self.use_batched_format = use_batched_format

     @property
@@ -312,7 +308,8 @@ def apply(
         run_cutlass_moe_fp8(
             output, hidden_states, w1, w2, topk_ids, activation_callable,
             global_num_experts, expert_map, w1_scale, w2_scale, a1q_scale,
-            a2_scale, workspace13, workspace2, expert_num_tokens,
+            a2_scale, self.ab_strides1, self.ab_strides2, self.c_strides1,
+            self.c_strides2, workspace13, workspace2, expert_num_tokens,
             self.out_dtype if self.out_dtype is not None else in_dtype,
             self.per_act_token_quant, self.per_out_ch_quant,
             self.use_batched_format)
@@ -326,6 +323,10 @@ def cutlass_moe_fp8(
     topk_ids: torch.Tensor,
     w1_scale: torch.Tensor,
     w2_scale: torch.Tensor,
+    ab_strides1: torch.Tensor,
+    ab_strides2: torch.Tensor,
+    c_strides1: torch.Tensor,
+    c_strides2: torch.Tensor,
     per_act_token: Optional[bool] = None,
     activation: str = "silu",
     a1_scale: Optional[torch.Tensor] = None,
@@ -353,6 +354,17 @@ def cutlass_moe_fp8(
         Shape: [num_experts] or [num_experts, 2N]
     - w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
         Shape: [num_experts] or [num_experts, K]
+    - ab_strides1 (torch.Tensor): The input/weight strides for the first gemm.
+        Shape: [num_experts]
+    - ab_strides2 (torch.Tensor): The input/weight strides for the second gemm.
+        Shape: [num_experts]
+    - c_strides1 (torch.Tensor): The output strides for the first gemm.
+        Shape: [num_experts]
+    - c_strides2 (torch.Tensor): The output strides for the second gemm.
+        Shape: [num_experts]
+    - per_act_token (Optional[bool]): Whether the scale is per-token or
+        per-tensor.
+    - activation (str): The activation function to use.
     - a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.
         Shape: scalar or [M]
     - a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
@@ -385,6 +397,10 @@ def cutlass_moe_fp8(
             out_dtype=a.dtype,
             per_act_token_quant=per_act_token,
             per_out_ch_quant=per_out_ch,
+            ab_strides1=ab_strides1,
+            ab_strides2=ab_strides2,
+            c_strides1=c_strides1,
+            c_strides2=c_strides2,
             use_batched_format=False,
         ),
     )
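For reference, a rough caller-side sketch of the updated cutlass_moe_fp8 API: the stride tensors are now built once by the caller rather than re-allocated inside run_cutlass_moe_fp8 on every forward pass. The leading positional arguments (a, w1_q, w2_q, topk_weights) and the dummy scale/weight contents are assumptions based on the docstring above, not taken from this diff:

```python
import torch
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp8

# Hypothetical sizes: E experts, hidden size K, intermediate size N, M tokens.
E, N, K, M, topk = 8, 1024, 2048, 16, 2
device = "cuda"

a = torch.randn(M, K, device=device, dtype=torch.half)
w1_q = torch.randn(E, 2 * N, K, device=device).to(torch.float8_e4m3fn)
w2_q = torch.randn(E, K, N, device=device).to(torch.float8_e4m3fn)
w1_scale = torch.ones(E, device=device, dtype=torch.float32)  # [num_experts]
w2_scale = torch.ones(E, device=device, dtype=torch.float32)  # [num_experts]
topk_weights = torch.full((M, topk), 1.0 / topk, device=device)
topk_ids = torch.randint(0, E, (M, topk), device=device, dtype=torch.int64)

# New in this PR: per-expert stride tensors are allocated up front and
# passed through to the grouped CUTLASS gemms.
ab_strides1 = torch.full((E, ), K, device=device, dtype=torch.int64)
ab_strides2 = torch.full((E, ), N, device=device, dtype=torch.int64)
c_strides1 = torch.full((E, ), 2 * N, device=device, dtype=torch.int64)
c_strides2 = torch.full((E, ), K, device=device, dtype=torch.int64)

out = cutlass_moe_fp8(a, w1_q, w2_q, topk_weights, topk_ids,
                      w1_scale, w2_scale,
                      ab_strides1, ab_strides2, c_strides1, c_strides2,
                      per_act_token=False)
```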