From 512e00eca31ef89eafca8d38a170e978d2371bc9 Mon Sep 17 00:00:00 2001
From: ElizaWszola
Date: Fri, 31 Oct 2025 03:53:55 -0400
Subject: [PATCH 1/5] Fix fused quant layernorm tests

Signed-off-by: ElizaWszola
---
 tests/kernels/core/test_fused_quant_layernorm.py | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/tests/kernels/core/test_fused_quant_layernorm.py b/tests/kernels/core/test_fused_quant_layernorm.py
index 63b5a37d3c77..39f8ef0bf1b0 100644
--- a/tests/kernels/core/test_fused_quant_layernorm.py
+++ b/tests/kernels/core/test_fused_quant_layernorm.py
@@ -65,7 +65,7 @@ def ref_dynamic_per_token_quant(
         )
     else:
         assert quant_dtype == torch.int8
-        torch_out, scales = ops.scaled_int8_quant(torch_out)
+        torch_out, scales, _ = ops.scaled_int8_quant(torch_out)
 
     return torch_out, scales, residual
 
@@ -109,7 +109,7 @@ def ops_impl(
 
 @pytest.mark.parametrize("num_tokens, hidden_size", NUM_TOKENS_HIDDEN_SIZES)
 @pytest.mark.parametrize("add_residual", ADD_RESIDUAL)
-@pytest.mark.parametrize("scale_ub", SCALE_UBS)
+@pytest.mark.parametrize("has_scale_ub", SCALE_UBS)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("quant_dtype", QUANT_DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
@@ -119,7 +119,7 @@ def test_rms_norm(
     num_tokens: int,
     hidden_size: int,
     add_residual: bool,
-    scale_ub: bool,
+    has_scale_ub: bool,
     dtype: torch.dtype,
     quant_dtype: torch.dtype,
     seed: int,
@@ -130,7 +130,7 @@
     torch.cuda.manual_seed(seed)
     torch.set_default_device(device)
 
-    if scale_ub is not None and quant_dtype != torch.float8_e4m3fn:
+    if has_scale_ub and quant_dtype != torch.float8_e4m3fn:
         # skip
         return
 
@@ -143,9 +143,11 @@
     scale = 1 / (hidden_size)
     x = torch.randn(num_tokens, hidden_size, dtype=dtype) * scale
     residual = torch.randn_like(x) * scale if add_residual else None
-    if scale_ub is not None:
+    if has_scale_ub:
         rms_x, _ = ref_rms_norm(layer, x, residual)
         scale_ub = torch.mean(rms_x).to(dtype=torch.float32, device="cuda")
+    else:
+        scale_ub = None
 
     ref_out, ref_scales, ref_residual = ref_impl(
         layer, x, quant_dtype, residual, scale_ub

From a618d91455ad06a21c60ec0f181413f878af66e7 Mon Sep 17 00:00:00 2001
From: yewentao256
Date: Mon, 3 Nov 2025 16:44:22 +0000
Subject: [PATCH 2/5] Add fallback

Signed-off-by: yewentao256
---
 .../kernels/core/test_fused_quant_layernorm.py | 17 ++++++++++++-----
 1 file changed, 12 insertions(+), 5 deletions(-)

diff --git a/tests/kernels/core/test_fused_quant_layernorm.py b/tests/kernels/core/test_fused_quant_layernorm.py
index 39f8ef0bf1b0..9379fbc60c1b 100644
--- a/tests/kernels/core/test_fused_quant_layernorm.py
+++ b/tests/kernels/core/test_fused_quant_layernorm.py
@@ -11,7 +11,7 @@
 
 DTYPES = [torch.bfloat16, torch.float]
 QUANT_DTYPES = [torch.int8, torch.float8_e4m3fn]
-VEC_HIDDEN_SIZES = range(1024, 1030)
+VEC_HIDDEN_SIZES = [1024, 1025, 1027, 1029]
 # Avoid combinatorial explosion with full Cartesian product
 NUM_TOKENS_HIDDEN_SIZES = [
     *[(1, i) for i in [1, 64, *VEC_HIDDEN_SIZES, 5120, 5137]],
@@ -158,14 +158,21 @@
 
     assert ref_out.dtype == quant_dtype
     assert ops_out.dtype == quant_dtype
-    assert torch.allclose(ref_scales, ops_scales)
     if quant_dtype == torch.int8:
+        assert torch.allclose(ref_scales, ops_scales, rtol=0.1, atol=1)
         # big atol to account for round-off errors.
         assert torch.allclose(ref_out, ops_out, atol=1)
     else:
-        assert torch.allclose(
-            ref_out.to(dtype=torch.float32), ops_out.to(dtype=torch.float32)
-        )
+        assert torch.allclose(ref_scales, ops_scales)
+        a = ref_out.to(dtype=torch.float32)
+        b = ops_out.to(dtype=torch.float32)
+        ok = torch.allclose(a, b)
+        if not ok:
+            # fallback: compare dequantized values with relaxed tolerance
+            a_deq = a * ref_scales.view(-1, 1)
+            b_deq = b * ops_scales.view(-1, 1)
+            ok = torch.allclose(a_deq, b_deq, rtol=5e-2, atol=5e-2)
+        assert ok
     if add_residual:
         assert torch.allclose(ref_residual, ops_residual)

From f3d4a9e05cb1dc39babfcaf1ce5ac2716be60195 Mon Sep 17 00:00:00 2001
From: yewentao256
Date: Mon, 3 Nov 2025 16:44:30 +0000
Subject: [PATCH 3/5] Fix IMA (illegal memory access) issue

Signed-off-by: yewentao256
---
 csrc/quantization/w8a8/int8/scaled_quant.cu | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/csrc/quantization/w8a8/int8/scaled_quant.cu b/csrc/quantization/w8a8/int8/scaled_quant.cu
index 7fe9e96bfb01..be8ecfeacf8c 100644
--- a/csrc/quantization/w8a8/int8/scaled_quant.cu
+++ b/csrc/quantization/w8a8/int8/scaled_quant.cu
@@ -1,5 +1,6 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <torch/all.h>
+#include <c10/cuda/CUDAGuard.h>
 
 #include <cmath>
 
@@ -275,6 +276,7 @@ void static_scaled_int8_quant(torch::Tensor& out,  // [..., hidden_size]
   int const num_tokens = input.numel() / hidden_size;
   dim3 const grid(num_tokens);
   dim3 const block(std::min(hidden_size, 256));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(
       input.scalar_type(), "static_scaled_int8_quant_kernel", [&] {
@@ -306,6 +308,7 @@ void dynamic_scaled_int8_quant(
   int const num_tokens = input.numel() / hidden_size;
   dim3 const grid(num_tokens);
   dim3 const block(std::min(hidden_size, 256));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(
       input.scalar_type(), "dynamic_scaled_int8_quant_kernel", [&] {

From a25bf16084587adaa65d4e4ae636f85417309f6a Mon Sep 17 00:00:00 2001
From: ElizaWszola
Date: Mon, 3 Nov 2025 13:03:07 -0500
Subject: [PATCH 4/5] Comment about fp8 precision

Signed-off-by: ElizaWszola
---
 tests/kernels/core/test_fused_quant_layernorm.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/tests/kernels/core/test_fused_quant_layernorm.py b/tests/kernels/core/test_fused_quant_layernorm.py
index 9379fbc60c1b..5f788b504876 100644
--- a/tests/kernels/core/test_fused_quant_layernorm.py
+++ b/tests/kernels/core/test_fused_quant_layernorm.py
@@ -171,6 +171,12 @@
             # fallback: compare dequantized values with relaxed tolerance
             a_deq = a * ref_scales.view(-1, 1)
             b_deq = b * ops_scales.view(-1, 1)
+            # NOTE: It is possible that some future test cases trigger this
+            # max diff due to precision issues. If such an error is
+            # encountered, it's recommended to inspect the differences between
+            # all corresponding elements from each tensor (e.g. by looping
+            # over them) and checking how many elements the max diff error
+            # shows up on (just a few bad elements should still be acceptable).
             ok = torch.allclose(a_deq, b_deq, rtol=5e-2, atol=5e-2)
         assert ok
     if add_residual:
         assert torch.allclose(ref_residual, ops_residual)

From 152e69a75fd4b481bf496a44a4be479449e4db2b Mon Sep 17 00:00:00 2001
From: ElizaWszola
Date: Tue, 4 Nov 2025 02:33:29 -0500
Subject: [PATCH 5/5] Lower tol on int8 scales

Signed-off-by: ElizaWszola
---
 tests/kernels/core/test_fused_quant_layernorm.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/kernels/core/test_fused_quant_layernorm.py b/tests/kernels/core/test_fused_quant_layernorm.py
index 5f788b504876..b5fc653ca735 100644
--- a/tests/kernels/core/test_fused_quant_layernorm.py
+++ b/tests/kernels/core/test_fused_quant_layernorm.py
@@ -159,7 +159,7 @@
     assert ref_out.dtype == quant_dtype
     assert ops_out.dtype == quant_dtype
     if quant_dtype == torch.int8:
-        assert torch.allclose(ref_scales, ops_scales, rtol=0.1, atol=1)
+        assert torch.allclose(ref_scales, ops_scales, atol=1e-6)
         # big atol to account for round-off errors.
         assert torch.allclose(ref_out, ops_out, atol=1)
     else:
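
Illustrative note (not part of the patch series): the NOTE added in patch 4 recommends inspecting per-element differences when the relaxed dequantized comparison still fails. A minimal Python sketch of that inspection follows; the random tensors are hypothetical stand-ins for the dequantized reference (a_deq) and kernel (b_deq) outputs from the test.

import torch

# Hypothetical stand-ins for the dequantized reference and kernel outputs;
# in the test these would come from ref_out/ops_out and their scales.
a_deq = torch.randn(83, 5120, dtype=torch.float32)
b_deq = a_deq + 1e-3 * torch.randn_like(a_deq)

rtol, atol = 5e-2, 5e-2
# The same criterion torch.allclose applies, evaluated per element so the
# offending entries can be counted and printed instead of failing opaquely.
bad = (a_deq - b_deq).abs() > atol + rtol * b_deq.abs()
print(f"{int(bad.sum())} / {bad.numel()} elements exceed the tolerance")
for idx in bad.nonzero()[:10]:  # show the first few offenders
    r, c = idx.tolist()
    print((r, c), a_deq[r, c].item(), b_deq[r, c].item())

A handful of offenders sitting near the tolerance boundary is the "just a few bad elements" case the comment treats as acceptable; widespread or large mismatches point at a real kernel problem rather than fp8 round-off.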