
Commit b3a55fd

Apply quant layer norm fixes from vllm-project#27865, inv scale fix for int8

Signed-off-by: ElizaWszola <ewszola@redhat.com>

Parent: ea9f4db

4 files changed: +35 −12 lines

csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu

Lines changed: 4 additions & 5 deletions
@@ -130,15 +130,14 @@ __global__ void rms_norm_per_block_quant_kernel_3(
   // RMS Norm + Quant
   int token_idx = blockIdx.x * hidden_size / group_size;
   if constexpr (std::is_same_v<scalar_out_t, int8_t>) {
-    for (auto i = threadIdx.x; i < hidden_size; i += blockDim.x) {
-      auto token_group_idx = token_idx + i / group_size;
-      token_scale[token_group_idx] = 1.0f / token_scale[token_group_idx];
-    }
+    // Don't invert token_scale here: do it inside the norm_and_quant kernel.
+    // We do this because individual elements of token_scale can be shared
+    // between multiple threads, so inverting there avoids extra
+    // synchronization overhead.
     vllm::norm_and_quant<scalar_t, scalar_out_t, true, has_residual>(
         out, input, weight, rms[blockIdx.x], token_scale + token_idx,
         hidden_size, residual, group_size);
   } else {
-    // FP8 - Do not invert s_token_scale for exact match with FBGemm
     vllm::norm_and_quant<scalar_t, scalar_out_t, false, has_residual>(
         out, input, weight, rms[blockIdx.x], token_scale + token_idx,
         hidden_size, residual, group_size);
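
For context, here is a minimal PyTorch sketch (not vLLM code; the function name, clamp range, and epsilon are illustrative assumptions) of what the groupwise int8 path computes and why an inverted scale appears at all: the quant step multiplies by 1/scale, so the inversion must happen either in this caller (old code) or inside norm_and_quant (this commit).

import torch

def ref_groupwise_int8_quant(x: torch.Tensor, group_size: int):
    # Illustrative reference, not the fused kernel: one scale per group of
    # `group_size` elements, derived from the group absolute maximum.
    tokens, hidden = x.shape
    groups = x.view(tokens, hidden // group_size, group_size)
    scale = groups.abs().amax(dim=-1, keepdim=True).clamp(min=1e-10) / 127.0
    inv_scale = 1.0 / scale                       # the "inverted" scale
    q = torch.clamp(torch.round(groups * inv_scale), -128, 127).to(torch.int8)
    return q.view(tokens, hidden), scale.squeeze(-1)

x = torch.randn(2, 8)
q, scales = ref_groupwise_int8_quant(x, group_size=4)
dequant = q.view(2, 2, 4).float() * scales.unsqueeze(-1)
assert torch.allclose(dequant.view(2, 8), x, atol=scales.max().item())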

csrc/quantization/fused_kernels/layernorm_utils.cuh

Lines changed: 5 additions & 1 deletion
@@ -148,7 +148,11 @@ __device__ void norm_and_quant(scalar_out_t* __restrict__ output,
     // Norm
     x = static_cast<float>(static_cast<scalar_t>(x * rms) * weight[i]);
     // Quant
-    auto scale_val = (group_size > 0 ? scale[i / group_size] : *scale);
+    // If groupwise and is_scale_inverted is true, invert the scale here.
+    auto scale_val =
+        (group_size > 0 ? (is_scale_inverted ? 1.0f / scale[i / group_size]
+                                             : scale[i / group_size])
+                        : *scale);
     output[token_offset + i] =
         ScaledQuant<scalar_out_t, is_scale_inverted>::quant_fn(x, scale_val);
   }
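
The nested ternary is dense, so here is a rough Python rendering (the helper name is mine, for illustration only) of how scale_val is selected, assuming `scale` points at per-group scales when group_size > 0 and at a single per-token value otherwise.

import torch

def select_scale_val(scale: torch.Tensor, i: int, group_size: int,
                     is_scale_inverted: bool) -> float:
    # Groupwise (group_size > 0): pick element i's group scale and invert it
    # on the fly when the quant function expects an inverted scale.
    if group_size > 0:
        s = scale[i // group_size].item()
        return 1.0 / s if is_scale_inverted else s
    # Per-token (group_size == 0): `*scale` is used as-is.
    return scale[0].item()

group_scales = torch.tensor([0.02, 0.04])
print(select_scale_val(group_scales, i=5, group_size=4, is_scale_inverted=True))   # ≈ 1 / 0.04
print(select_scale_val(group_scales, i=5, group_size=4, is_scale_inverted=False))  # ≈ 0.04

This mirrors the first file's change: the int8 groupwise caller no longer pre-inverts token_scale, so the inversion happens here instead.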

csrc/quantization/w8a8/int8/scaled_quant.cu

Lines changed: 3 additions & 0 deletions
@@ -1,5 +1,6 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <torch/all.h>
+#include <c10/cuda/CUDAGuard.h>
 
 #include <cmath>
 

@@ -275,6 +276,7 @@ void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size]
   int const num_tokens = input.numel() / hidden_size;
   dim3 const grid(num_tokens);
   dim3 const block(std::min(hidden_size, 256));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(
       input.scalar_type(), "static_scaled_int8_quant_kernel", [&] {

@@ -306,6 +308,7 @@ void dynamic_scaled_int8_quant(
   int const num_tokens = input.numel() / hidden_size;
   dim3 const grid(num_tokens);
   dim3 const block(std::min(hidden_size, 256));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(
       input.scalar_type(), "dynamic_scaled_int8_quant_kernel", [&] {
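
The OptionalCUDAGuard matters when the input lives on a GPU other than the current one: the current stream is per-device, so the guard makes getCurrentCUDAStream and the kernel launch target the input's device. A rough Python analogue in plain torch (no vLLM ops; assumes at least two GPUs are available for the demo):

import torch

def launch_on_input_device(x: torch.Tensor) -> torch.Tensor:
    # Python analogue of `OptionalCUDAGuard device_guard(device_of(input))`:
    # make x's device current so current_stream() refers to that device.
    with torch.cuda.device(x.device):
        stream = torch.cuda.current_stream()
        with torch.cuda.stream(stream):
            return x * 2.0  # stand-in for the real quant kernel launch

if torch.cuda.is_available() and torch.cuda.device_count() > 1:
    t = torch.randn(8, device="cuda:1")
    print(launch_on_input_device(t).device)  # cuda:1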

tests/kernels/core/test_fused_quant_layernorm.py

Lines changed: 23 additions & 6 deletions
@@ -17,7 +17,7 @@
 
 DTYPES = [torch.bfloat16, torch.float]
 QUANT_DTYPES = [torch.int8, torch.float8_e4m3fn]
-VEC_HIDDEN_SIZES = range(1024, 1030)
+VEC_HIDDEN_SIZES = [1024, 1025, 1027, 1029]
 # Avoid combinatorial explosion with full Cartesian product
 NUM_TOKENS_HIDDEN_SIZES = [
     *[(1, i) for i in [1, 64, *VEC_HIDDEN_SIZES, 5120, 5137]],

@@ -165,8 +165,8 @@ def test_rms_norm(
         # skip
         return
 
-    # blockwise baseline doesn't support scale_ub
     if group_size is not None and has_scale_ub:
+        # blockwise baseline doesn't support scale_ub
         return
 
     if has_scale_ub and quant_dtype != torch.float8_e4m3fn:

@@ -197,14 +197,31 @@ def test_rms_norm(
 
     assert ref_out.dtype == quant_dtype
     assert ops_out.dtype == quant_dtype
-    assert torch.allclose(ref_scales, ops_scales)
     if quant_dtype == torch.int8:
+        assert torch.allclose(ref_scales, ops_scales, atol=1e-6)
         # big atol to account for round-off errors.
        assert torch.allclose(ref_out, ops_out, atol=1)
     else:
-        assert torch.allclose(
-            ref_out.to(dtype=torch.float32), ops_out.to(dtype=torch.float32)
-        )
+        assert torch.allclose(ref_scales, ops_scales)
+        a = ref_out.to(dtype=torch.float32)
+        b = ops_out.to(dtype=torch.float32)
+        ok = torch.allclose(a, b, atol=1e-6)
+        if not ok:
+            # fallback: compare dequantized values with relaxed tolerance
+            if group_size is None:
+                a_deq = a * ref_scales.view(-1, 1)
+                b_deq = b * ops_scales.view(-1, 1)
+            else:
+                a_deq = a * ref_scales.repeat_interleave(group_size[1], dim=1)
+                b_deq = b * ops_scales.repeat_interleave(group_size[1], dim=1)
+            # NOTE: It is possible that some future test cases trip this
+            # relaxed comparison due to precision issues. If such an error
+            # is encountered, inspect the differences between corresponding
+            # elements of the two tensors (e.g. by looping over them) and
+            # check how many elements hit the max diff; just a few bad
+            # elements should still be considered acceptable.
+            ok = torch.allclose(a_deq, b_deq, rtol=5e-2, atol=5e-2)
+        assert ok
     if add_residual:
         assert torch.allclose(ref_residual, ops_residual)
 
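
One detail of the new fallback worth spelling out: for blockwise quant the scales carry one column per group, so repeat_interleave(group_size[1], dim=1) expands them back to one scale per element before dequantizing. A toy example (the shapes and values here are assumed purely for illustration):

import torch

# Say 2 tokens, hidden_size 8, and 4 elements per scale group.
cols_per_group = 4                      # plays the role of group_size[1]
scales = torch.tensor([[0.10, 0.20],
                       [0.30, 0.40]])   # shape (2, 2): one scale per group
q = torch.randint(-128, 128, (2, 8)).float()

per_element = scales.repeat_interleave(cols_per_group, dim=1)  # shape (2, 8)
dequant = q * per_element
print(per_element[0])   # 0.1 repeated 4 times, then 0.2 repeated 4 times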
