From 5899fbd6d34e2bee64dff031bd5ecd08043ef312 Mon Sep 17 00:00:00 2001 From: siyuanf Date: Fri, 8 Aug 2025 19:39:49 +0000 Subject: [PATCH 1/8] add tg-mxfp4-moe-test Signed-off-by: siyuanf --- tests/kernels/moe/test_mxfp4_moe.py | 375 +++++++++++++++++++++++++++- 1 file changed, 374 insertions(+), 1 deletion(-) diff --git a/tests/kernels/moe/test_mxfp4_moe.py b/tests/kernels/moe/test_mxfp4_moe.py index 824b072a9f93..0f554827d6c6 100644 --- a/tests/kernels/moe/test_mxfp4_moe.py +++ b/tests/kernels/moe/test_mxfp4_moe.py @@ -4,11 +4,18 @@ import importlib import importlib.metadata from dataclasses import dataclass +from typing import Optional import pytest import torch +import torch.nn.functional as F +from flashinfer import (fp4_quantize, mxfp8_quantize, next_positive_power_of_2, + reorder_rows_for_gated_act_gemm, shuffle_matrix_a, + shuffle_matrix_sf_a, trtllm_fp4_block_scale_moe) from packaging import version +from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts + QUARK_MXFP4_AVAILABLE = importlib.util.find_spec( "quark") is not None and version.parse( importlib.metadata.version("amd-quark")) >= version.parse('0.8.99') @@ -54,4 +61,370 @@ def test_mxfp4_loading_and_execution_moe(vllm_runner, model_case: ModelCase): output = llm.generate_greedy("Today I am in the French Alps and", max_tokens=20) - assert output \ No newline at end of file + assert output + + +def swiglu(x, alpha: float = 1.702, limit: Optional[float] = None): + # Note we add an extra bias of 1 to the linear layer + x_glu, x_linear = torch.chunk(x, 2, dim=-1) + if limit is not None: + x_glu = x_glu.clamp(max=limit) + out_glu = x_glu * torch.sigmoid(alpha * x_glu) + if limit is not None: + x_linear = x_linear.clamp(min=-limit, max=limit) + return out_glu * (x_linear + 1) + + +def compute_routing_renormalize( + router_logits: torch.Tensor, + top_k: int) -> tuple[torch.Tensor, torch.Tensor]: + routing_weights, selected_experts = torch.topk(router_logits, + top_k, + dim=-1) + routing_weights = F.softmax(routing_weights, dim=-1, dtype=torch.float) + return selected_experts, routing_weights + + +fp4_lookup_table = [ + 0, 0.5, 1, 1.5, 2, 3, 4, 6, 0, -0.5, -1, -1.5, -2, -3, -4, -6 +] + + +def mxfp4_dequantize(x, scale): + assert x.dtype == torch.uint8 + x = x.view(torch.uint8).to(torch.int32) + x_unpacked = torch.zeros(*x.shape[:-1], + x.shape[-1] * 2, + dtype=torch.int32, + device=x.device) + x_unpacked[..., 0::2].copy_(x & 0xF) + x_unpacked[..., 1::2].copy_((x >> 4) & 0xF) + + x_float = torch.zeros(x_unpacked.shape, + dtype=torch.float32, + device=x.device) + for i, val in enumerate(fp4_lookup_table): + x_float[x_unpacked == i] = val + + scale = scale.view(torch.uint8).to(torch.int32) + scale = (scale << 23).view(torch.float32) + scale = scale.reshape(*x.shape[:-1], -1) + scale = torch.stack([scale] * 32, dim=-1).reshape(*x_float.shape) + + return x_float * scale + + +def mxfp8_dequantize(x, scale): + assert x.dtype == torch.float8_e4m3fn + x_float = x.to(torch.float32) + + scale = scale.view(torch.uint8).to(torch.int32) + scale = (scale << 23).view(torch.float32) + scale = scale.reshape(*x.shape[:-1], -1) + scale = torch.stack([scale] * 32, dim=-1).reshape(*x_float.shape) + + return x_float * scale + + +def reference_bf16_moe( + topk_ids, + topk_weights, + topk, + num_experts, + hidden_states, + hidden_states_scale, + w13, + w2, + w13_scale, + w2_scale, + act_type, +): + w13 = mxfp4_dequantize(w13, w13_scale).to(torch.bfloat16) + w2 = mxfp4_dequantize(w2, w2_scale).to(torch.bfloat16) + if act_type == 'mxfp8': + 
hidden_states = mxfp8_dequantize( + hidden_states, hidden_states_scale).to(torch.bfloat16) + else: + hidden_states = hidden_states.to(torch.bfloat16) + ref_result = fused_experts(hidden_states, + w13, + w2, + topk_weights, + topk_ids, + inplace=False, + activation="silu", + is_act_and_mul=True, + global_num_experts=num_experts, + expert_map=None, + w1_scale=None, + w2_scale=None, + w1_zp=None, + w2_zp=None, + a1_scale=None, + a2_scale=None, + block_shape=None) + return ref_result + + +def get_tile_tokens_dim(x: torch.Tensor, top_k: int, num_experts: int): + # Number of tokens in the input tensor. + num_tokens = x.shape[0] + # Factor to account for the imbalance of the experts. + # factor equals to the + # max_real_num_tokens_per_expert / perfect_num_tokens_per_expert + # - 1.0 means perfect expert distribution. + # - > 1.0 means some experts have more + # tokens than the perfect distribution. + # - < 1.0 does not make sense. + imbalance_factor = 1.3 + # Calculate the number of tokens per expert + # assuming perfect distribution. + num_tokens_per_expert = (num_tokens * top_k) // num_experts + # Apply the imbalance factor. + num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor) + # And pad the number to the next power of 2. + tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert) + # Cap to 8-64 tokens per CTA tile + # as it's the range supported by the kernel. + tile_tokens_dim = min(max(tile_tokens_dim, 8), 64) + return tile_tokens_dim + + +def tg_mxfp4_moe( + router_logits, + topk, + num_experts, + intermediate_size, + hidden_size, + hidden_states, + hidden_states_scale, + w13_weight, + w13_weight_scale, + w2_weight, + w2_weight_scale, + act_type, +) -> torch.Tensor: + sf_block_size = 32 + assert (w13_weight.dim() == 3 and w13_weight.shape[0] == num_experts + and w13_weight.shape[1] == intermediate_size * 2 + and w13_weight.shape[2] == hidden_size // 2) + assert (w13_weight_scale.dim() == 3 + and w13_weight_scale.shape[0] == num_experts + and w13_weight_scale.shape[1] == intermediate_size * 2 + and w13_weight_scale.shape[2] == hidden_size // sf_block_size) + assert (w2_weight.dim() == 3 and w2_weight.shape[0] == num_experts + and w2_weight.shape[1] == hidden_size + and w2_weight.shape[2] == intermediate_size // 2) + assert (w2_weight_scale.dim() == 3 + and w2_weight_scale.shape[1] == hidden_size + and w2_weight_scale.shape[2] == intermediate_size // sf_block_size) + + # Swap w1 and w3 as the defenition of + # swiglu is different in the trtllm-gen + w13_weight_scale_ = w13_weight_scale.clone() + w13_weight_ = w13_weight.clone() + w13_weight[:, :intermediate_size, :].copy_( + w13_weight_[:, intermediate_size:, :]) + w13_weight[:, intermediate_size:, :].copy_( + w13_weight_[:, :intermediate_size, :]) + w13_weight_scale[:, :intermediate_size, :].copy_( + w13_weight_scale_[:, intermediate_size:, :]) + w13_weight_scale[:, intermediate_size:, :].copy_( + w13_weight_scale_[:, :intermediate_size, :]) + + # Interleave the weights and scaling factors for activation + w13_weight_interleaved = [] + w13_weight_scale_interleaved = [] + for i in range(num_experts): + w13_weight_interleaved.append( + reorder_rows_for_gated_act_gemm(w13_weight[i].clone())) + w13_weight_scale_interleaved.append( + reorder_rows_for_gated_act_gemm(w13_weight_scale[i].clone())) + w13_weight = torch.stack(w13_weight_interleaved).reshape( + num_experts, 2 * intermediate_size, hidden_size // 2) + w13_weight_scale = torch.stack(w13_weight_scale_interleaved).reshape( + num_experts, 2 * intermediate_size, 
hidden_size // 32) + + # Shuffle weights and scaling factors for transposed mma output + gemm1_weights_mxfp4_shuffled = [] + gemm1_scales_mxfp4_shuffled = [] + gemm2_weights_mxfp4_shuffled = [] + gemm2_scales_mxfp4_shuffled = [] + epilogue_tile_m = 128 # FIXME: this depends on the kernel internals + for i in range(num_experts): + gemm1_weights_mxfp4_shuffled.append( + shuffle_matrix_a(w13_weight[i].view(torch.uint8), epilogue_tile_m)) + gemm1_scales_mxfp4_shuffled.append( + shuffle_matrix_sf_a(w13_weight_scale[i].view(torch.uint8), + epilogue_tile_m)) + + gemm2_weights_mxfp4_shuffled.append( + shuffle_matrix_a(w2_weight[i].view(torch.uint8), epilogue_tile_m)) + gemm2_scales_mxfp4_shuffled.append( + shuffle_matrix_sf_a(w2_weight_scale[i].view(torch.uint8), + epilogue_tile_m)) + + w13_weight = torch.stack(gemm1_weights_mxfp4_shuffled) + w13_weight_scale = torch.stack(gemm1_scales_mxfp4_shuffled).reshape( + num_experts, 2 * intermediate_size, + hidden_size // sf_block_size).view(torch.float8_e4m3fn) + + w2_weight = torch.stack(gemm2_weights_mxfp4_shuffled) + w2_weight_scale = torch.stack(gemm2_scales_mxfp4_shuffled).reshape( + num_experts, hidden_size, + intermediate_size // sf_block_size).view(torch.float8_e4m3fn) + + tg_result = trtllm_fp4_block_scale_moe( + routing_logits=router_logits.to(torch.bfloat16), + routing_bias=None, + hidden_states=hidden_states, + hidden_states_scale=hidden_states_scale, + gemm1_weights=w13_weight, + gemm1_weights_scale=w13_weight_scale, + gemm1_bias=None, + gemm1_alpha=None, + gemm1_beta=None, + gemm1_clamp_limit=None, + gemm2_weights=w2_weight, + gemm2_weights_scale=w2_weight_scale, + gemm2_bias=None, + output1_scale_scalar=None, + output1_scale_gate_scalar=None, + output2_scale_scalar=None, + num_experts=num_experts, + top_k=topk, + n_group=None, + topk_group=None, + intermediate_size=intermediate_size, + local_expert_offset=0, + local_num_experts=num_experts, + routed_scaling_factor=None, + tile_tokens_dim=get_tile_tokens_dim(hidden_states, topk, num_experts), + routing_method_type=1, + do_finalize=True)[0] + return tg_result + + +def check_accuracy(a, b, atol, rtol, percent): + """Allow a mismatch percentage of 1 - percent.""" + if torch.any(torch.isnan(a)): + raise Exception("NaN in reference output") + if torch.any(torch.isnan(b)): + raise Exception("NaN in actual output") + if torch.any(torch.isinf(a)): + raise Exception("Inf in reference output") + if torch.any(torch.isinf(b)): + raise Exception("Inf in actual output") + assert a.shape == b.shape, f"Shape mismatch: {a.shape} vs {b.shape}" + + left = torch.abs(a - b) + right = atol + rtol * torch.abs(b) + count = torch.sum(left > right) + mismatch_percent = count / a.numel() + if mismatch_percent > 1 - percent: + raise Exception( + f"Mismatch percentage is {mismatch_percent:.4f} for rtol {rtol} " + f"(threshold: {1-percent:.4f})") + + +@pytest.mark.parametrize("topk", [1, 4]) +@pytest.mark.parametrize("num_experts", [32, 128]) +@pytest.mark.parametrize("num_tokens", [1, 128, 1024]) +@pytest.mark.parametrize("intermediate_size,hidden_size", [(4096, 4096)]) +@pytest.mark.parametrize("act_type", ['mxfp8', 'bf16']) +def test_trtllm_gen_mxfp4_fused_moe( + topk: int, + num_experts: int, + num_tokens: int, + intermediate_size: int, + hidden_size: int, + act_type: str, +): + seed = 42 + torch.manual_seed(seed) + hidden_states = torch.randn(num_tokens, + hidden_size, + device="cuda:0", + dtype=torch.bfloat16) + w13 = (torch.randn(num_experts, + intermediate_size * 2, + hidden_size, + device="cuda:0", + 
dtype=torch.bfloat16)) + w2 = (torch.randn(num_experts, + hidden_size, + intermediate_size, + device="cuda:0", + dtype=torch.bfloat16)) + router_logits = torch.rand(num_tokens, num_experts, + dtype=torch.float32).cuda() + + w13, w13_scale = fp4_quantize(w13, + torch.tensor(1.0, device="cuda:0"), + 32, + sf_use_ue8m0=True, + is_sf_swizzled_layout=False) + w13_scale = w13_scale.view(torch.float8_e4m3fn).reshape( + num_experts, intermediate_size * 2, hidden_size // 32) + w2, w2_scale = fp4_quantize(w2, + torch.tensor(1.0, device="cuda:0"), + 32, + sf_use_ue8m0=True, + is_sf_swizzled_layout=False) + w2_scale = w2_scale.view(torch.float8_e4m3fn).reshape( + num_experts, hidden_size, intermediate_size // 32) + if act_type == 'mxfp8': + hidden_states, hidden_states_scale = mxfp8_quantize( + hidden_states, is_sf_swizzled_layout=False) + hidden_states_scale = hidden_states_scale.view( + torch.float8_e4m3fn).reshape(-1) + else: + hidden_states_scale = None + + # reference result + topk_ids, topk_weights = compute_routing_renormalize(router_logits, topk) + ref_result = reference_bf16_moe( + topk_ids, + topk_weights, + topk, + num_experts, + hidden_states, + hidden_states_scale, + w13, + w2, + w13_scale, + w2_scale, + act_type, + ) + + # trtllm-gen result + tg_result = tg_mxfp4_moe( + router_logits, + topk, + num_experts, + intermediate_size, + hidden_size, + hidden_states, + hidden_states_scale, + w13, + w13_scale, + w2, + w2_scale, + act_type, + ) + + # relatively loose accuracy check since the mxfp4 quantization is less accurate + # note that a few tests still fail due to accuracy issues + check_accuracy(ref_result, tg_result, atol=0, rtol=0.35, percent=0.9) + + +if __name__ == "__main__": + torch.set_printoptions(threshold=1000, sci_mode=False, precision=3) + test_trtllm_gen_mxfp4_fused_moe( + topk=4, + num_experts=32, + num_tokens=1024, + intermediate_size=4096, + hidden_size=4096, + act_type='mxfp8', + ) From 0cc4ec0870221054a832c77fa72aba8a3f2053ba Mon Sep 17 00:00:00 2001 From: siyuanf Date: Fri, 8 Aug 2025 22:13:18 +0000 Subject: [PATCH 2/8] add alpha and limit Signed-off-by: siyuanf --- tests/kernels/moe/test_mxfp4_moe.py | 131 ++++++++++++++-------------- 1 file changed, 66 insertions(+), 65 deletions(-) diff --git a/tests/kernels/moe/test_mxfp4_moe.py b/tests/kernels/moe/test_mxfp4_moe.py index 0f554827d6c6..d2f0569be74d 100644 --- a/tests/kernels/moe/test_mxfp4_moe.py +++ b/tests/kernels/moe/test_mxfp4_moe.py @@ -75,16 +75,6 @@ def swiglu(x, alpha: float = 1.702, limit: Optional[float] = None): return out_glu * (x_linear + 1) -def compute_routing_renormalize( - router_logits: torch.Tensor, - top_k: int) -> tuple[torch.Tensor, torch.Tensor]: - routing_weights, selected_experts = torch.topk(router_logits, - top_k, - dim=-1) - routing_weights = F.softmax(routing_weights, dim=-1, dtype=torch.float) - return selected_experts, routing_weights - - fp4_lookup_table = [ 0, 0.5, 1, 1.5, 2, 3, 4, 6, 0, -0.5, -1, -1.5, -2, -3, -4, -6 ] @@ -126,44 +116,34 @@ def mxfp8_dequantize(x, scale): return x_float * scale -def reference_bf16_moe( - topk_ids, - topk_weights, +def reference_moe( + roouting_logits, topk, num_experts, hidden_states, - hidden_states_scale, w13, w2, - w13_scale, - w2_scale, - act_type, + alpha, + limit, ): - w13 = mxfp4_dequantize(w13, w13_scale).to(torch.bfloat16) - w2 = mxfp4_dequantize(w2, w2_scale).to(torch.bfloat16) - if act_type == 'mxfp8': - hidden_states = mxfp8_dequantize( - hidden_states, hidden_states_scale).to(torch.bfloat16) - else: - hidden_states = 
hidden_states.to(torch.bfloat16) - ref_result = fused_experts(hidden_states, - w13, - w2, - topk_weights, - topk_ids, - inplace=False, - activation="silu", - is_act_and_mul=True, - global_num_experts=num_experts, - expert_map=None, - w1_scale=None, - w2_scale=None, - w1_zp=None, - w2_zp=None, - a1_scale=None, - a2_scale=None, - block_shape=None) - return ref_result + experts = torch.topk(roouting_logits, k=topk, dim=-1, sorted=True) + expert_weights = torch.nn.functional.softmax(experts.values, dim=1).to(hidden_states.dtype) + expert_indices = experts.indices + t = hidden_states.clone() + # MLP #1 + mlp1_weight = w13[expert_indices, ...] + # mlp1_bias = w1_bias[topk_ids, ...] + t = torch.einsum("beck,bk->bec", mlp1_weight, t) # + mlp1_bias + t = swiglu(t, alpha=alpha, limit=limit) + + # MLP #2 + mlp2_weight = w2[expert_indices, ...] + # mlp2_bias = w2_bias[topk_ids, ...] + t = torch.einsum("beck,bek->bec", mlp2_weight, t) # + mlp2_bias + + # Weighted sum of experts + t = torch.einsum("bec,be->bc", t, expert_weights) + return t def get_tile_tokens_dim(x: torch.Tensor, top_k: int, num_experts: int): @@ -203,6 +183,8 @@ def tg_mxfp4_moe( w2_weight, w2_weight_scale, act_type, + alpha, + limit, ) -> torch.Tensor: sf_block_size = 32 assert (w13_weight.dim() == 3 and w13_weight.shape[0] == num_experts @@ -282,9 +264,9 @@ def tg_mxfp4_moe( gemm1_weights=w13_weight, gemm1_weights_scale=w13_weight_scale, gemm1_bias=None, - gemm1_alpha=None, + gemm1_alpha=alpha, gemm1_beta=None, - gemm1_clamp_limit=None, + gemm1_clamp_limit=limit, gemm2_weights=w2_weight, gemm2_weights_scale=w2_weight_scale, gemm2_bias=None, @@ -300,7 +282,7 @@ def tg_mxfp4_moe( local_num_experts=num_experts, routed_scaling_factor=None, tile_tokens_dim=get_tile_tokens_dim(hidden_states, topk, num_experts), - routing_method_type=1, + routing_method_type=1, # renormalize do_finalize=True)[0] return tg_result @@ -329,8 +311,9 @@ def check_accuracy(a, b, atol, rtol, percent): @pytest.mark.parametrize("topk", [1, 4]) @pytest.mark.parametrize("num_experts", [32, 128]) -@pytest.mark.parametrize("num_tokens", [1, 128, 1024]) -@pytest.mark.parametrize("intermediate_size,hidden_size", [(4096, 4096)]) +@pytest.mark.parametrize("num_tokens", [1, 64]) +@pytest.mark.parametrize("intermediate_size,hidden_size", [(3072, 3072)]) +@pytest.mark.parametrize("alpha,limit", [(1.0, None), (1.702, 7.0)]) @pytest.mark.parametrize("act_type", ['mxfp8', 'bf16']) def test_trtllm_gen_mxfp4_fused_moe( topk: int, @@ -338,6 +321,8 @@ def test_trtllm_gen_mxfp4_fused_moe( num_tokens: int, intermediate_size: int, hidden_size: int, + alpha: Optional[float], + limit: Optional[float], act_type: str, ): seed = 42 @@ -382,22 +367,37 @@ def test_trtllm_gen_mxfp4_fused_moe( hidden_states_scale = None # reference result - topk_ids, topk_weights = compute_routing_renormalize(router_logits, topk) - ref_result = reference_bf16_moe( - topk_ids, - topk_weights, - topk, - num_experts, - hidden_states, - hidden_states_scale, - w13, - w2, - w13_scale, - w2_scale, - act_type, - ) + ref_result = torch.empty_like(hidden_states, dtype=torch.bfloat16) + w13_ref = mxfp4_dequantize(w13, w13_scale).to(torch.bfloat16) + w2_ref = mxfp4_dequantize(w2, w2_scale).to(torch.bfloat16) + if act_type == 'mxfp8': + hidden_states_ref = mxfp8_dequantize( + hidden_states, hidden_states_scale).to(torch.bfloat16) + else: + hidden_states_ref = hidden_states + # Process tokens in chunks of 32 to reduce memory usage + chunk_size = 32 + num_chunks = (num_tokens + chunk_size - 1) // chunk_size + for i in 
range(num_chunks): + start_idx = i * chunk_size + end_idx = min(start_idx + chunk_size, num_tokens) + chunk_result = reference_moe( + router_logits[start_idx:end_idx], + topk, + num_experts, + hidden_states_ref[start_idx:end_idx], + w13_ref, + w2_ref, + alpha=alpha, + limit=limit + ) + ref_result[start_idx:end_idx].copy_(chunk_result) # trtllm-gen result + if alpha is not None: + alpha = torch.full((num_experts,), alpha, device=hidden_states.device) + if limit is not None: + limit = torch.full((num_experts,), limit, device=hidden_states.device) tg_result = tg_mxfp4_moe( router_logits, topk, @@ -411,11 +411,12 @@ def test_trtllm_gen_mxfp4_fused_moe( w2, w2_scale, act_type, + alpha=alpha, + limit=limit ) - # relatively loose accuracy check since the mxfp4 quantization is less accurate - # note that a few tests still fail due to accuracy issues - check_accuracy(ref_result, tg_result, atol=0, rtol=0.35, percent=0.9) + # relatively loose check since the mxfp4 quantization is less accurate + check_accuracy(ref_result, tg_result, atol=0, rtol=0.5, percent=0.9) if __name__ == "__main__": @@ -423,7 +424,7 @@ def test_trtllm_gen_mxfp4_fused_moe( test_trtllm_gen_mxfp4_fused_moe( topk=4, num_experts=32, - num_tokens=1024, + num_tokens=128, intermediate_size=4096, hidden_size=4096, act_type='mxfp8', From 09982321fa960caa168bbd76c0d037a6096b43b1 Mon Sep 17 00:00:00 2001 From: siyuanf Date: Fri, 8 Aug 2025 23:07:23 +0000 Subject: [PATCH 3/8] upd Signed-off-by: siyuanf --- tests/kernels/moe/test_mxfp4_moe.py | 63 +++++++++++++++-------------- 1 file changed, 32 insertions(+), 31 deletions(-) diff --git a/tests/kernels/moe/test_mxfp4_moe.py b/tests/kernels/moe/test_mxfp4_moe.py index d2f0569be74d..e1e16493e161 100644 --- a/tests/kernels/moe/test_mxfp4_moe.py +++ b/tests/kernels/moe/test_mxfp4_moe.py @@ -8,14 +8,11 @@ import pytest import torch -import torch.nn.functional as F from flashinfer import (fp4_quantize, mxfp8_quantize, next_positive_power_of_2, reorder_rows_for_gated_act_gemm, shuffle_matrix_a, shuffle_matrix_sf_a, trtllm_fp4_block_scale_moe) from packaging import version -from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts - QUARK_MXFP4_AVAILABLE = importlib.util.find_spec( "quark") is not None and version.parse( importlib.metadata.version("amd-quark")) >= version.parse('0.8.99') @@ -69,9 +66,8 @@ def swiglu(x, alpha: float = 1.702, limit: Optional[float] = None): x_glu, x_linear = torch.chunk(x, 2, dim=-1) if limit is not None: x_glu = x_glu.clamp(max=limit) - out_glu = x_glu * torch.sigmoid(alpha * x_glu) - if limit is not None: x_linear = x_linear.clamp(min=-limit, max=limit) + out_glu = x_glu * torch.sigmoid(alpha * x_glu) return out_glu * (x_linear + 1) @@ -125,21 +121,27 @@ def reference_moe( w2, alpha, limit, + act_type, ): experts = torch.topk(roouting_logits, k=topk, dim=-1, sorted=True) - expert_weights = torch.nn.functional.softmax(experts.values, dim=1).to(hidden_states.dtype) + expert_weights = torch.nn.functional.softmax(experts.values, + dim=1).to(hidden_states.dtype) expert_indices = experts.indices t = hidden_states.clone() # MLP #1 mlp1_weight = w13[expert_indices, ...] # mlp1_bias = w1_bias[topk_ids, ...] 
- t = torch.einsum("beck,bk->bec", mlp1_weight, t) # + mlp1_bias + t = torch.einsum("beck,bk->bec", mlp1_weight, t) # + mlp1_bias t = swiglu(t, alpha=alpha, limit=limit) + if act_type == 'mxfp8': + t_quantized, t_scale = mxfp8_quantize(t, is_sf_swizzled_layout=False) + t = mxfp8_dequantize(t_quantized, t_scale).to(torch.bfloat16) + # MLP #2 mlp2_weight = w2[expert_indices, ...] # mlp2_bias = w2_bias[topk_ids, ...] - t = torch.einsum("beck,bek->bec", mlp2_weight, t) # + mlp2_bias + t = torch.einsum("beck,bek->bec", mlp2_weight, t) # + mlp2_bias # Weighted sum of experts t = torch.einsum("bec,be->bc", t, expert_weights) @@ -282,7 +284,7 @@ def tg_mxfp4_moe( local_num_experts=num_experts, routed_scaling_factor=None, tile_tokens_dim=get_tile_tokens_dim(hidden_states, topk, num_experts), - routing_method_type=1, # renormalize + routing_method_type=1, # renormalize do_finalize=True)[0] return tg_result @@ -311,7 +313,7 @@ def check_accuracy(a, b, atol, rtol, percent): @pytest.mark.parametrize("topk", [1, 4]) @pytest.mark.parametrize("num_experts", [32, 128]) -@pytest.mark.parametrize("num_tokens", [1, 64]) +@pytest.mark.parametrize("num_tokens", [1, 128, 1024]) @pytest.mark.parametrize("intermediate_size,hidden_size", [(3072, 3072)]) @pytest.mark.parametrize("alpha,limit", [(1.0, None), (1.702, 7.0)]) @pytest.mark.parametrize("act_type", ['mxfp8', 'bf16']) @@ -388,35 +390,34 @@ def test_trtllm_gen_mxfp4_fused_moe( hidden_states_ref[start_idx:end_idx], w13_ref, w2_ref, - alpha=alpha, - limit=limit + alpha, + limit, + act_type, ) ref_result[start_idx:end_idx].copy_(chunk_result) # trtllm-gen result if alpha is not None: - alpha = torch.full((num_experts,), alpha, device=hidden_states.device) + alpha = torch.full((num_experts, ), alpha, device=hidden_states.device) if limit is not None: - limit = torch.full((num_experts,), limit, device=hidden_states.device) - tg_result = tg_mxfp4_moe( - router_logits, - topk, - num_experts, - intermediate_size, - hidden_size, - hidden_states, - hidden_states_scale, - w13, - w13_scale, - w2, - w2_scale, - act_type, - alpha=alpha, - limit=limit - ) + limit = torch.full((num_experts, ), limit, device=hidden_states.device) + tg_result = tg_mxfp4_moe(router_logits, + topk, + num_experts, + intermediate_size, + hidden_size, + hidden_states, + hidden_states_scale, + w13, + w13_scale, + w2, + w2_scale, + act_type, + alpha=alpha, + limit=limit) # relatively loose check since the mxfp4 quantization is less accurate - check_accuracy(ref_result, tg_result, atol=0, rtol=0.5, percent=0.9) + check_accuracy(ref_result, tg_result, atol=0, rtol=0.8, percent=0.8) if __name__ == "__main__": From 311bacec1c5539de2a5809ea3a23a2d8d4c50bf4 Mon Sep 17 00:00:00 2001 From: siyuanf Date: Sat, 9 Aug 2025 00:04:22 +0000 Subject: [PATCH 4/8] add bias Signed-off-by: siyuanf --- tests/kernels/moe/test_mxfp4_moe.py | 93 ++++++++++++++++++++--------- 1 file changed, 64 insertions(+), 29 deletions(-) diff --git a/tests/kernels/moe/test_mxfp4_moe.py b/tests/kernels/moe/test_mxfp4_moe.py index e1e16493e161..209fa6f1c0b4 100644 --- a/tests/kernels/moe/test_mxfp4_moe.py +++ b/tests/kernels/moe/test_mxfp4_moe.py @@ -118,7 +118,9 @@ def reference_moe( num_experts, hidden_states, w13, + bias13, w2, + bias2, alpha, limit, act_type, @@ -130,8 +132,8 @@ def reference_moe( t = hidden_states.clone() # MLP #1 mlp1_weight = w13[expert_indices, ...] - # mlp1_bias = w1_bias[topk_ids, ...] - t = torch.einsum("beck,bk->bec", mlp1_weight, t) # + mlp1_bias + mlp1_bias = bias13[expert_indices, ...] 
+ t = torch.einsum("beck,bk->bec", mlp1_weight, t) + mlp1_bias t = swiglu(t, alpha=alpha, limit=limit) if act_type == 'mxfp8': @@ -140,8 +142,8 @@ def reference_moe( # MLP #2 mlp2_weight = w2[expert_indices, ...] - # mlp2_bias = w2_bias[topk_ids, ...] - t = torch.einsum("beck,bek->bec", mlp2_weight, t) # + mlp2_bias + mlp2_bias = bias2[expert_indices, ...] + t = torch.einsum("beck,bek->bec", mlp2_weight, t) + mlp2_bias # Weighted sum of experts t = torch.einsum("bec,be->bc", t, expert_weights) @@ -182,8 +184,10 @@ def tg_mxfp4_moe( hidden_states_scale, w13_weight, w13_weight_scale, + w13_bias, w2_weight, w2_weight_scale, + w2_bias, act_type, alpha, limit, @@ -202,11 +206,16 @@ def tg_mxfp4_moe( assert (w2_weight_scale.dim() == 3 and w2_weight_scale.shape[1] == hidden_size and w2_weight_scale.shape[2] == intermediate_size // sf_block_size) + assert (w13_bias.dim() == 2 and w13_bias.shape[0] == num_experts + and w13_bias.shape[1] == intermediate_size * 2) + assert (w2_bias.dim() == 2 and w2_bias.shape[0] == num_experts + and w2_bias.shape[1] == hidden_size) # Swap w1 and w3 as the defenition of # swiglu is different in the trtllm-gen w13_weight_scale_ = w13_weight_scale.clone() w13_weight_ = w13_weight.clone() + w13_bias_ = w13_bias.clone() w13_weight[:, :intermediate_size, :].copy_( w13_weight_[:, intermediate_size:, :]) w13_weight[:, intermediate_size:, :].copy_( @@ -215,48 +224,64 @@ def tg_mxfp4_moe( w13_weight_scale_[:, intermediate_size:, :]) w13_weight_scale[:, intermediate_size:, :].copy_( w13_weight_scale_[:, :intermediate_size, :]) + w13_bias[:, :intermediate_size].copy_(w13_bias_[:, intermediate_size:]) + w13_bias[:, intermediate_size:].copy_(w13_bias_[:, :intermediate_size]) # Interleave the weights and scaling factors for activation w13_weight_interleaved = [] w13_weight_scale_interleaved = [] + w13_bias_interleaved = [] for i in range(num_experts): w13_weight_interleaved.append( reorder_rows_for_gated_act_gemm(w13_weight[i].clone())) w13_weight_scale_interleaved.append( reorder_rows_for_gated_act_gemm(w13_weight_scale[i].clone())) + w13_bias_interleaved.append( + reorder_rows_for_gated_act_gemm(w13_bias[i].clone().reshape(-1, + 1))) w13_weight = torch.stack(w13_weight_interleaved).reshape( num_experts, 2 * intermediate_size, hidden_size // 2) w13_weight_scale = torch.stack(w13_weight_scale_interleaved).reshape( num_experts, 2 * intermediate_size, hidden_size // 32) + w13_bias = torch.stack(w13_bias_interleaved).reshape( + num_experts, 2 * intermediate_size) # Shuffle weights and scaling factors for transposed mma output - gemm1_weights_mxfp4_shuffled = [] - gemm1_scales_mxfp4_shuffled = [] - gemm2_weights_mxfp4_shuffled = [] - gemm2_scales_mxfp4_shuffled = [] + gemm1_weights_shuffled = [] + gemm1_scales_shuffled = [] + gemm2_weights_shuffled = [] + gemm2_scales_shuffled = [] + gemm1_bias_shuffled = [] + gemm2_bias_shuffled = [] epilogue_tile_m = 128 # FIXME: this depends on the kernel internals for i in range(num_experts): - gemm1_weights_mxfp4_shuffled.append( + gemm1_weights_shuffled.append( shuffle_matrix_a(w13_weight[i].view(torch.uint8), epilogue_tile_m)) - gemm1_scales_mxfp4_shuffled.append( + gemm1_scales_shuffled.append( shuffle_matrix_sf_a(w13_weight_scale[i].view(torch.uint8), epilogue_tile_m)) - gemm2_weights_mxfp4_shuffled.append( + gemm2_weights_shuffled.append( shuffle_matrix_a(w2_weight[i].view(torch.uint8), epilogue_tile_m)) - gemm2_scales_mxfp4_shuffled.append( + gemm2_scales_shuffled.append( shuffle_matrix_sf_a(w2_weight_scale[i].view(torch.uint8), 
epilogue_tile_m)) + gemm1_bias_shuffled.append( + shuffle_matrix_a(w13_bias[i].reshape(-1, 1), epilogue_tile_m)) + gemm2_bias_shuffled.append( + shuffle_matrix_a(w2_bias[i].reshape(-1, 1), epilogue_tile_m)) - w13_weight = torch.stack(gemm1_weights_mxfp4_shuffled) - w13_weight_scale = torch.stack(gemm1_scales_mxfp4_shuffled).reshape( + w13_weight = torch.stack(gemm1_weights_shuffled) + w13_weight_scale = torch.stack(gemm1_scales_shuffled).reshape( num_experts, 2 * intermediate_size, hidden_size // sf_block_size).view(torch.float8_e4m3fn) + w13_bias = torch.stack(gemm1_bias_shuffled).reshape(num_experts, -1) - w2_weight = torch.stack(gemm2_weights_mxfp4_shuffled) - w2_weight_scale = torch.stack(gemm2_scales_mxfp4_shuffled).reshape( + w2_weight = torch.stack(gemm2_weights_shuffled) + w2_weight_scale = torch.stack(gemm2_scales_shuffled).reshape( num_experts, hidden_size, intermediate_size // sf_block_size).view(torch.float8_e4m3fn) + w2_bias = torch.stack(gemm2_bias_shuffled).reshape(num_experts, -1) tg_result = trtllm_fp4_block_scale_moe( routing_logits=router_logits.to(torch.bfloat16), @@ -265,13 +290,13 @@ def tg_mxfp4_moe( hidden_states_scale=hidden_states_scale, gemm1_weights=w13_weight, gemm1_weights_scale=w13_weight_scale, - gemm1_bias=None, + gemm1_bias=w13_bias, gemm1_alpha=alpha, gemm1_beta=None, gemm1_clamp_limit=limit, gemm2_weights=w2_weight, gemm2_weights_scale=w2_weight_scale, - gemm2_bias=None, + gemm2_bias=w2_bias, output1_scale_scalar=None, output1_scale_gate_scalar=None, output2_scale_scalar=None, @@ -343,6 +368,9 @@ def test_trtllm_gen_mxfp4_fused_moe( intermediate_size, device="cuda:0", dtype=torch.bfloat16)) + bias13 = torch.randn(num_experts, intermediate_size * 2, + device="cuda:0") * 10 + bias2 = torch.randn(num_experts, hidden_size, device="cuda:0") * 10 router_logits = torch.rand(num_tokens, num_experts, dtype=torch.float32).cuda() @@ -370,13 +398,16 @@ def test_trtllm_gen_mxfp4_fused_moe( # reference result ref_result = torch.empty_like(hidden_states, dtype=torch.bfloat16) - w13_ref = mxfp4_dequantize(w13, w13_scale).to(torch.bfloat16) - w2_ref = mxfp4_dequantize(w2, w2_scale).to(torch.bfloat16) + w13_ref = mxfp4_dequantize(w13.clone(), + w13_scale.clone()).to(torch.bfloat16) + w2_ref = mxfp4_dequantize(w2.clone(), w2_scale.clone()).to(torch.bfloat16) + bias13_ref = bias13.to(torch.bfloat16) + bias2_ref = bias2.to(torch.bfloat16) if act_type == 'mxfp8': hidden_states_ref = mxfp8_dequantize( hidden_states, hidden_states_scale).to(torch.bfloat16) else: - hidden_states_ref = hidden_states + hidden_states_ref = hidden_states.clone() # Process tokens in chunks of 32 to reduce memory usage chunk_size = 32 num_chunks = (num_tokens + chunk_size - 1) // chunk_size @@ -389,7 +420,9 @@ def test_trtllm_gen_mxfp4_fused_moe( num_experts, hidden_states_ref[start_idx:end_idx], w13_ref, + bias13_ref, w2_ref, + bias2_ref, alpha, limit, act_type, @@ -410,8 +443,10 @@ def test_trtllm_gen_mxfp4_fused_moe( hidden_states_scale, w13, w13_scale, + bias13, w2, w2_scale, + bias2, act_type, alpha=alpha, limit=limit) @@ -422,11 +457,11 @@ def test_trtllm_gen_mxfp4_fused_moe( if __name__ == "__main__": torch.set_printoptions(threshold=1000, sci_mode=False, precision=3) - test_trtllm_gen_mxfp4_fused_moe( - topk=4, - num_experts=32, - num_tokens=128, - intermediate_size=4096, - hidden_size=4096, - act_type='mxfp8', - ) + test_trtllm_gen_mxfp4_fused_moe(topk=4, + num_experts=32, + num_tokens=128, + intermediate_size=4096, + hidden_size=4096, + act_type='mxfp8', + alpha=1.0, + limit=None) From 
62230f07a4c50a6d37f25bd8a7e736482d25fb2c Mon Sep 17 00:00:00 2001 From: Siyuan Fu Date: Mon, 11 Aug 2025 14:32:30 -0700 Subject: [PATCH 5/8] add beta Signed-off-by: Siyuan Fu --- tests/kernels/moe/test_mxfp4_moe.py | 69 ++++++++++++++--------------- 1 file changed, 33 insertions(+), 36 deletions(-) diff --git a/tests/kernels/moe/test_mxfp4_moe.py b/tests/kernels/moe/test_mxfp4_moe.py index 209fa6f1c0b4..fc81384a3e06 100644 --- a/tests/kernels/moe/test_mxfp4_moe.py +++ b/tests/kernels/moe/test_mxfp4_moe.py @@ -61,18 +61,21 @@ def test_mxfp4_loading_and_execution_moe(vllm_runner, model_case: ModelCase): assert output -def swiglu(x, alpha: float = 1.702, limit: Optional[float] = None): +def swiglu(x, + alpha: float = 1.702, + beta: float = 1.0, + limit: Optional[float] = None): # Note we add an extra bias of 1 to the linear layer x_glu, x_linear = torch.chunk(x, 2, dim=-1) if limit is not None: x_glu = x_glu.clamp(max=limit) x_linear = x_linear.clamp(min=-limit, max=limit) out_glu = x_glu * torch.sigmoid(alpha * x_glu) - return out_glu * (x_linear + 1) + return out_glu * (x_linear + beta) fp4_lookup_table = [ - 0, 0.5, 1, 1.5, 2, 3, 4, 6, 0, -0.5, -1, -1.5, -2, -3, -4, -6 + 0, 0.5, 1, 1.5, 2, 3, 4, 6, -0, -0.5, -1, -1.5, -2, -3, -4, -6 ] @@ -122,32 +125,33 @@ def reference_moe( w2, bias2, alpha, + beta, limit, act_type, ): + # renormalize routing experts = torch.topk(roouting_logits, k=topk, dim=-1, sorted=True) - expert_weights = torch.nn.functional.softmax(experts.values, - dim=1).to(hidden_states.dtype) + expert_weights = torch.nn.functional.softmax(experts.values, dim=1) expert_indices = experts.indices t = hidden_states.clone() # MLP #1 mlp1_weight = w13[expert_indices, ...] mlp1_bias = bias13[expert_indices, ...] t = torch.einsum("beck,bk->bec", mlp1_weight, t) + mlp1_bias - t = swiglu(t, alpha=alpha, limit=limit) + t = swiglu(t, alpha=alpha, beta=beta, limit=limit) if act_type == 'mxfp8': - t_quantized, t_scale = mxfp8_quantize(t, is_sf_swizzled_layout=False) - t = mxfp8_dequantize(t_quantized, t_scale).to(torch.bfloat16) - + t_quantized, t_scale = mxfp8_quantize(t.to(torch.bfloat16), + is_sf_swizzled_layout=False) + t = mxfp8_dequantize(t_quantized, t_scale) # MLP #2 mlp2_weight = w2[expert_indices, ...] mlp2_bias = bias2[expert_indices, ...] 
t = torch.einsum("beck,bek->bec", mlp2_weight, t) + mlp2_bias - # Weighted sum of experts t = torch.einsum("bec,be->bc", t, expert_weights) - return t + assert t.shape == hidden_states.shape + return t.to(torch.bfloat16) def get_tile_tokens_dim(x: torch.Tensor, top_k: int, num_experts: int): @@ -190,6 +194,7 @@ def tg_mxfp4_moe( w2_bias, act_type, alpha, + beta, limit, ) -> torch.Tensor: sf_block_size = 32 @@ -292,7 +297,7 @@ def tg_mxfp4_moe( gemm1_weights_scale=w13_weight_scale, gemm1_bias=w13_bias, gemm1_alpha=alpha, - gemm1_beta=None, + gemm1_beta=beta, gemm1_clamp_limit=limit, gemm2_weights=w2_weight, gemm2_weights_scale=w2_weight_scale, @@ -340,7 +345,8 @@ def check_accuracy(a, b, atol, rtol, percent): @pytest.mark.parametrize("num_experts", [32, 128]) @pytest.mark.parametrize("num_tokens", [1, 128, 1024]) @pytest.mark.parametrize("intermediate_size,hidden_size", [(3072, 3072)]) -@pytest.mark.parametrize("alpha,limit", [(1.0, None), (1.702, 7.0)]) +@pytest.mark.parametrize("alpha,beta,limit", [(1.0, 1.0, None), + (1.702, 1.0, 7.0)]) @pytest.mark.parametrize("act_type", ['mxfp8', 'bf16']) def test_trtllm_gen_mxfp4_fused_moe( topk: int, @@ -348,7 +354,8 @@ def test_trtllm_gen_mxfp4_fused_moe( num_tokens: int, intermediate_size: int, hidden_size: int, - alpha: Optional[float], + alpha: float, + beta: float, limit: Optional[float], act_type: str, ): @@ -398,16 +405,15 @@ def test_trtllm_gen_mxfp4_fused_moe( # reference result ref_result = torch.empty_like(hidden_states, dtype=torch.bfloat16) - w13_ref = mxfp4_dequantize(w13.clone(), - w13_scale.clone()).to(torch.bfloat16) - w2_ref = mxfp4_dequantize(w2.clone(), w2_scale.clone()).to(torch.bfloat16) - bias13_ref = bias13.to(torch.bfloat16) - bias2_ref = bias2.to(torch.bfloat16) + w13_ref = mxfp4_dequantize(w13.clone(), w13_scale.clone()) + w2_ref = mxfp4_dequantize(w2.clone(), w2_scale.clone()) + bias13_ref = bias13 + bias2_ref = bias2 if act_type == 'mxfp8': hidden_states_ref = mxfp8_dequantize( - hidden_states, hidden_states_scale).to(torch.bfloat16) + hidden_states, hidden_states_scale).to(torch.float32) else: - hidden_states_ref = hidden_states.clone() + hidden_states_ref = hidden_states.to(torch.float32) # Process tokens in chunks of 32 to reduce memory usage chunk_size = 32 num_chunks = (num_tokens + chunk_size - 1) // chunk_size @@ -415,7 +421,7 @@ def test_trtllm_gen_mxfp4_fused_moe( start_idx = i * chunk_size end_idx = min(start_idx + chunk_size, num_tokens) chunk_result = reference_moe( - router_logits[start_idx:end_idx], + router_logits[start_idx:end_idx].to(torch.float32), topk, num_experts, hidden_states_ref[start_idx:end_idx], @@ -424,6 +430,7 @@ def test_trtllm_gen_mxfp4_fused_moe( w2_ref, bias2_ref, alpha, + beta, limit, act_type, ) @@ -434,6 +441,8 @@ def test_trtllm_gen_mxfp4_fused_moe( alpha = torch.full((num_experts, ), alpha, device=hidden_states.device) if limit is not None: limit = torch.full((num_experts, ), limit, device=hidden_states.device) + if beta is not None: + beta = torch.full((num_experts, ), beta, device=hidden_states.device) tg_result = tg_mxfp4_moe(router_logits, topk, num_experts, @@ -449,19 +458,7 @@ def test_trtllm_gen_mxfp4_fused_moe( bias2, act_type, alpha=alpha, + beta=beta, limit=limit) - # relatively loose check since the mxfp4 quantization is less accurate - check_accuracy(ref_result, tg_result, atol=0, rtol=0.8, percent=0.8) - - -if __name__ == "__main__": - torch.set_printoptions(threshold=1000, sci_mode=False, precision=3) - test_trtllm_gen_mxfp4_fused_moe(topk=4, - num_experts=32, - 
num_tokens=128, - intermediate_size=4096, - hidden_size=4096, - act_type='mxfp8', - alpha=1.0, - limit=None) + check_accuracy(ref_result, tg_result, atol=0, rtol=0.3, percent=0.8) From 11fd35209985e3c7a75a37a135e712a7f7908ede Mon Sep 17 00:00:00 2001 From: Siyuan Fu Date: Mon, 11 Aug 2025 15:03:15 -0700 Subject: [PATCH 6/8] address comments Signed-off-by: Siyuan Fu --- tests/kernels/moe/test_mxfp4_moe.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/tests/kernels/moe/test_mxfp4_moe.py b/tests/kernels/moe/test_mxfp4_moe.py index fc81384a3e06..939d126ad86d 100644 --- a/tests/kernels/moe/test_mxfp4_moe.py +++ b/tests/kernels/moe/test_mxfp4_moe.py @@ -8,15 +8,22 @@ import pytest import torch -from flashinfer import (fp4_quantize, mxfp8_quantize, next_positive_power_of_2, - reorder_rows_for_gated_act_gemm, shuffle_matrix_a, - shuffle_matrix_sf_a, trtllm_fp4_block_scale_moe) from packaging import version QUARK_MXFP4_AVAILABLE = importlib.util.find_spec( "quark") is not None and version.parse( importlib.metadata.version("amd-quark")) >= version.parse('0.8.99') +device_props = torch.cuda.get_device_properties(torch.cuda.current_device()) +TRTLLM_GEN_MXFP4_AVAILABLE = torch.cuda.is_available( +) and device_props.major == 10 and device_props.minor == 0 + +if TRTLLM_GEN_MXFP4_AVAILABLE: + from flashinfer import (fp4_quantize, mxfp8_quantize, + next_positive_power_of_2, + reorder_rows_for_gated_act_gemm, shuffle_matrix_a, + shuffle_matrix_sf_a, trtllm_fp4_block_scale_moe) + @dataclass class ModelCase: @@ -359,6 +366,9 @@ def test_trtllm_gen_mxfp4_fused_moe( limit: Optional[float], act_type: str, ): + if not TRTLLM_GEN_MXFP4_AVAILABLE: + pytest.skip( + "This test requires nvidia gpu and compute capability sm100") seed = 42 torch.manual_seed(seed) hidden_states = torch.randn(num_tokens, From 8f8da5d5fbb1000308499a4b890e1c0678b5870e Mon Sep 17 00:00:00 2001 From: Siyuan Fu Date: Mon, 11 Aug 2025 15:06:37 -0700 Subject: [PATCH 7/8] minor Signed-off-by: Siyuan Fu --- tests/kernels/moe/test_mxfp4_moe.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/kernels/moe/test_mxfp4_moe.py b/tests/kernels/moe/test_mxfp4_moe.py index 939d126ad86d..f9ce3ad46ab0 100644 --- a/tests/kernels/moe/test_mxfp4_moe.py +++ b/tests/kernels/moe/test_mxfp4_moe.py @@ -355,6 +355,9 @@ def check_accuracy(a, b, atol, rtol, percent): @pytest.mark.parametrize("alpha,beta,limit", [(1.0, 1.0, None), (1.702, 1.0, 7.0)]) @pytest.mark.parametrize("act_type", ['mxfp8', 'bf16']) +@pytest.mark.skipif( + not TRTLLM_GEN_MXFP4_AVAILABLE, + reason="nvidia gpu and compute capability sm100 is required for this test") def test_trtllm_gen_mxfp4_fused_moe( topk: int, num_experts: int, @@ -366,9 +369,6 @@ def test_trtllm_gen_mxfp4_fused_moe( limit: Optional[float], act_type: str, ): - if not TRTLLM_GEN_MXFP4_AVAILABLE: - pytest.skip( - "This test requires nvidia gpu and compute capability sm100") seed = 42 torch.manual_seed(seed) hidden_states = torch.randn(num_tokens, From 25ea8d485bacf48f537cd749ba9bda906380c2d3 Mon Sep 17 00:00:00 2001 From: Siyuan Fu Date: Mon, 11 Aug 2025 18:59:56 -0700 Subject: [PATCH 8/8] address comments Signed-off-by: Siyuan Fu --- .buildkite/test-pipeline.yaml | 1 + tests/kernels/moe/test_mxfp4_moe.py | 7 ++++--- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index e139c6b30586..f048a0fdc644 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -671,6 +671,7 
@@ steps: - pytest -v -s tests/kernels/quantization/test_nvfp4_quant.py - pytest -v -s tests/kernels/quantization/test_nvfp4_scaled_mm.py - pytest -v -s tests/kernels/moe/test_nvfp4_moe.py + - pytest -v -s tests/kernels/moe/test_mxfp4_moe.py # Fusion - pytest -v -s tests/compile/test_fusion_all_reduce.py diff --git a/tests/kernels/moe/test_mxfp4_moe.py b/tests/kernels/moe/test_mxfp4_moe.py index f9ce3ad46ab0..7bd1ffce58e9 100644 --- a/tests/kernels/moe/test_mxfp4_moe.py +++ b/tests/kernels/moe/test_mxfp4_moe.py @@ -10,13 +10,14 @@ import torch from packaging import version +from vllm.platforms import current_platform + QUARK_MXFP4_AVAILABLE = importlib.util.find_spec( "quark") is not None and version.parse( importlib.metadata.version("amd-quark")) >= version.parse('0.8.99') -device_props = torch.cuda.get_device_properties(torch.cuda.current_device()) -TRTLLM_GEN_MXFP4_AVAILABLE = torch.cuda.is_available( -) and device_props.major == 10 and device_props.minor == 0 +TRTLLM_GEN_MXFP4_AVAILABLE = current_platform.is_cuda( +) and current_platform.is_device_capability(100) if TRTLLM_GEN_MXFP4_AVAILABLE: from flashinfer import (fp4_quantize, mxfp8_quantize,
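
Note for readers following the reference path in the patches above: MXFP4 stores weights as packed 4-bit E2M1 values plus one E8M0 scale byte per 32-element block, and the test's mxfp4_dequantize helper expands both by hand. The sketch below restates that decode step in isolation; it is illustrative only (the helper name mxfp4_dequantize_ref and the low-nibble-first packing assumption come from the test helper, not from any public API), and it assumes one scale byte per 32 consecutive elements of the last dimension.

import torch

# 16-entry E2M1 code book used by the test's reference dequantizer.
FP4_LUT = torch.tensor(
    [0, 0.5, 1, 1.5, 2, 3, 4, 6, -0.0, -0.5, -1, -1.5, -2, -3, -4, -6],
    dtype=torch.float32)


def mxfp4_dequantize_ref(packed: torch.Tensor,
                         scale: torch.Tensor) -> torch.Tensor:
    """packed: uint8, two E2M1 nibbles per byte (low nibble first).
    scale: uint8 E8M0 exponents, one per 32 unpacked elements."""
    lo = packed & 0xF
    hi = (packed >> 4) & 0xF
    # Interleave the nibbles back into their original element order.
    idx = torch.stack([lo, hi], dim=-1).reshape(*packed.shape[:-1], -1).long()
    values = FP4_LUT.to(packed.device)[idx]
    # E8M0 -> float32: the scale byte is a biased exponent, so shifting it
    # into bits [30:23] of an int32 and reinterpreting gives 2**(e - 127).
    s = (scale.to(torch.int32) << 23).view(torch.float32)
    s = s.reshape(*values.shape[:-1], -1).repeat_interleave(32, dim=-1)
    return values * s

The exponent-shift trick is the same one the patches use for both the MXFP4 weight scales and the MXFP8 activation scales; only the value decode differs (table lookup for E2M1 versus a plain .to(torch.float32) cast for E4M3).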
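
The fused kernel is called with routing_method_type=1, which the patch comments label "renormalize": top-k over the raw router logits followed by a softmax over just the selected values. A minimal sketch of that rule, matching reference_moe and the removed compute_routing_renormalize helper from patch 1 (the function name here is illustrative):

import torch


def renormalize_routing(router_logits: torch.Tensor, top_k: int):
    # Pick the top-k experts per token from the raw logits, then softmax over
    # only those k logits so the per-token expert weights sum to 1.
    vals, idx = torch.topk(router_logits, k=top_k, dim=-1)
    weights = torch.softmax(vals.float(), dim=-1)
    return idx, weights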