Skip to content

Commit 78f7058

Browse files
pyjhzwh authored and facebook-github-bot committed
fix the scaled input issue (#4884)
Summary: - When calculating the scaled_A for a given input A, it does `scaled_A = A * 6.0 / local_amax`. However, it should be `scaled_A = A * global_scale / fp8(local_amax / 6.0 * global_scale)` - Use fp64 precision for global scaling factor (following nvidia's fake quantization of nvfp4 D76363519) Output of numerics_bench output_abs/rel_err_bf16: the average absolute/relative error of gemm output compared to bf16 gemm output_abs/rel_err_mvfp4: the average absolute/relative error of gemm output compared to bf16 gemm Before the diff the relative gemm differnece over fake quant is 90%; After the fix it is 9%, Before the diff: > I0908 183717.805 numerics_bench.py:279] Numeric metrics for native_nvfp4 nvfp4 symm,fp8,amax,none,tensorwise,e2m1,nearest,1x16,0 I0908 183717.806 numerics_bench.py:113] runtime: 0.404 ms. I0908 183717.806 numerics_bench.py:114] TFLOPS: 1147.215. I0908 183717.806 numerics_bench.py:115] output_abs_err_bf16: 0.008. I0908 183717.806 numerics_bench.py:116] output_rel_err_bf16: 1.312. I0908 183717.806 numerics_bench.py:118] output_abs_err_nvfp4: 0.004. I0908 183717.806 numerics_bench.py:121] output_rel_err_nvfp4: 0.902. After the diff: > I0908 182556.008 numerics_bench.py:279] Numeric metrics for native_nvfp4 nvfp4 symm,fp8,amax,none,tensorwise,e2m1,nearest,1x16,0 I0908 182556.008 numerics_bench.py:113] runtime: 0.400 ms. I0908 182556.008 numerics_bench.py:114] TFLOPS: 1160.963. I0908 182556.008 numerics_bench.py:115] output_abs_err_bf16: 0.007. I0908 182556.008 numerics_bench.py:116] output_rel_err_bf16: 1.273. I0908 182556.009 numerics_bench.py:118] output_abs_err_nvfp4: 0.000. I0908 182556.009 numerics_bench.py:121] output_rel_err_nvfp4: 0.092. Pull Request resolved: #4884 Reviewed By: ghjeong12 Differential Revision: D82147819 Pulled By: pyjhzwh fbshipit-source-id: fbbfbdb529526e63a57b9e0eed7ed8b5d9234593
1 parent 0988166 commit 78f7058

File tree

1 file changed

+40
-30
lines changed

1 file changed

+40
-30
lines changed

fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py

Lines changed: 40 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1347,13 +1347,14 @@ def _kernel_nvfp4_quantize(
13471347
group_max = tl.max(tl.abs(a_groups), axis=1).to(tl.float32)
13481348

13491349
# Next we scale A in preparation for quantization.
1350-
scale_ = group_max / 6.0 * input_global_scale
1350+
scale_ = (group_max / 6.0 * input_global_scale).to(tl.float8e4nv)
13511351
# Prevent infinite values in log.
13521352
group_max = tl.where(group_max == 0, BF16_MIN_NORMAL, group_max)
13531353

13541354
# Apply scale_ to input. We do this by broadcasting scale.
1355+
# scaled_a = a * global_scale (fp32) / local_scale (fp8)
13551356
scaled_a = tl.reshape(a, [GROUP_LOAD, GROUP_SIZE]) * tl.reshape(
1356-
6.0 / group_max, [GROUP_LOAD, 1]
1357+
input_global_scale / scale_, [GROUP_LOAD, 1]
13571358
)
13581359
# Reshape back to a flat array.
13591360
scaled_a = tl.reshape(scaled_a, [GROUP_LOAD * GROUP_SIZE])
@@ -1417,7 +1418,7 @@ def _kernel_nvfp4_quantize(
14171418
)
14181419
tl.store(
14191420
scale + actual_offset,
1420-
scale_.to(tl.float8e4nv).to(tl.uint8, bitcast=True),
1421+
scale_.to(tl.uint8, bitcast=True),
14211422
# Prevent writing outside this chunk or the main array.
14221423
mask=(exp_offset < SCALE_SIZE)
14231424
& (exp_offset < (SCALE_CHUNK_SIZE * (pid + 1))),
@@ -1694,13 +1695,14 @@ def _kernel_nvfp4_quantize_silu(
16941695
group_max = tl.max(tl.abs(a_groups), axis=1)
16951696

16961697
# Next we scale A in preparation for quantization.
1697-
scale_ = group_max / 6.0 * input_global_scale
1698+
scale_ = (group_max / 6.0 * input_global_scale).to(tl.float8e4nv)
16981699
# Prevent infinite values in log.
16991700
group_max = tl.where(group_max == 0, BF16_MIN_NORMAL, group_max)
17001701

17011702
# Apply scale_ to input. We do this by broadcasting scale.
1703+
# scaled_a = a * global_scale (fp32) / local_scale (fp8)
17021704
scaled_a = tl.reshape(a, [GROUP_LOAD, GROUP_SIZE]) * tl.reshape(
1703-
6.0 / group_max, [GROUP_LOAD, 1]
1705+
input_global_scale / scale_, [GROUP_LOAD, 1]
17041706
)
17051707
# Reshape back to a flat array.
17061708
scaled_a = tl.reshape(scaled_a, [GROUP_LOAD * GROUP_SIZE])
@@ -1766,7 +1768,7 @@ def _kernel_nvfp4_quantize_silu(
17661768
)
17671769
tl.store(
17681770
scale + actual_offset,
1769-
scale_.to(tl.float8e4nv).to(tl.uint8, bitcast=True),
1771+
scale_.to(tl.uint8, bitcast=True),
17701772
# Prevent writing outside this chunk or the main array.
17711773
mask=(exp_offset < SCALE_SIZE)
17721774
& (exp_offset < (SCALE_CHUNK_SIZE * (pid + 1))),
@@ -2053,13 +2055,14 @@ def _kernel_nvfp4_quantize_rms(
20532055
group_max = tl.max(tl.abs(a_groups), axis=1)
20542056

20552057
# Next we scale A in preparation for quantization.
2056-
scale_ = group_max / 6.0 * input_global_scale
2058+
scale_ = (group_max / 6.0 * input_global_scale).to(tl.float8e4nv)
20572059
# Prevent infinite values in log.
20582060
group_max = tl.where(group_max == 0, BF16_MIN_NORMAL, group_max)
20592061

20602062
# Apply scale_ to input. We do this by broadcasting scale.
2063+
# scaled_a = a * global_scale (fp32) / local_scale (fp8)
20612064
scaled_a = tl.reshape(a, [GROUP_LOAD, GROUP_SIZE]) * tl.reshape(
2062-
6.0 / group_max, [GROUP_LOAD, 1]
2065+
input_global_scale / scale_, [GROUP_LOAD, 1]
20632066
)
20642067
# Reshape back to a flat array.
20652068
scaled_a = tl.reshape(scaled_a, [GROUP_LOAD * GROUP_SIZE])
@@ -2127,7 +2130,7 @@ def _kernel_nvfp4_quantize_rms(
21272130
)
21282131
tl.store(
21292132
scale + actual_offset,
2130-
scale_.to(tl.float8e4nv).to(tl.uint8, bitcast=True),
2133+
scale_.to(tl.uint8, bitcast=True),
21312134
# Prevent writing outside this chunk or the main array.
21322135
mask=(exp_offset < SCALE_SIZE)
21332136
& (exp_offset < (SCALE_CHUNK_SIZE * (pid + 1))),
@@ -2415,13 +2418,14 @@ def _kernel_nvfp4_quantize_stacked(
24152418
group_max = tl.max(tl.abs(a_groups), axis=1).to(tl.float32)
24162419

24172420
# Next we scale A in preparation for quantization.
2418-
scale_ = group_max / 6.0 * input_global_scale
2421+
scale_ = (group_max / 6.0 * input_global_scale).to(tl.float8e4nv)
24192422
# Prevent infinite values in log.
24202423
group_max = tl.where(group_max == 0, BF16_MIN_NORMAL, group_max)
24212424

24222425
# Apply scale_ to input. We do this by broadcasting scale.
2426+
# scaled_a = a * global_scale (fp32) / local_scale (fp8)
24232427
scaled_a = tl.reshape(a, [GROUP_LOAD, GROUP_SIZE]) * tl.reshape(
2424-
6.0 / group_max, [GROUP_LOAD, 1]
2428+
input_global_scale / scale_, [GROUP_LOAD, 1]
24252429
)
24262430
# Reshape back to a flat array.
24272431
scaled_a = tl.reshape(scaled_a, [GROUP_LOAD * GROUP_SIZE])
@@ -2489,7 +2493,7 @@ def _kernel_nvfp4_quantize_stacked(
24892493

24902494
tl.store(
24912495
scale + actual_scale_offset_permute,
2492-
scale_.to(tl.float8e4nv).to(tl.uint8, bitcast=True),
2496+
scale_.to(tl.uint8, bitcast=True),
24932497
# Prevent writing outside this chunk or the main array.
24942498
mask=(row_idx < M)
24952499
& (exp_offset < (SCALE_CHUNK_SIZE * (pid + 1)))
@@ -3092,13 +3096,14 @@ def _kernel_nvfp4_quantize_stacked_silu(
30923096
group_max = tl.max(tl.abs(a_groups), axis=1).to(tl.float32)
30933097

30943098
# Next we scale A in preparation for quantization.
3095-
scale_ = group_max / 6.0 * input_global_scale
3099+
scale_ = (group_max / 6.0 * input_global_scale).to(tl.float8e4nv)
30963100
# Prevent infinite values in log.
30973101
group_max = tl.where(group_max == 0, BF16_MIN_NORMAL, group_max)
30983102

30993103
# Apply scale_ to input. We do this by broadcasting scale.
3104+
# scaled_a = a * global_scale (fp32) / local_scale (fp8)
31003105
scaled_a = tl.reshape(a, [GROUP_LOAD, GROUP_SIZE]) * tl.reshape(
3101-
6.0 / group_max, [GROUP_LOAD, 1]
3106+
input_global_scale / scale_, [GROUP_LOAD, 1]
31023107
)
31033108
# Reshape back to a flat array.
31043109
scaled_a = tl.reshape(scaled_a, [GROUP_LOAD * GROUP_SIZE])
@@ -3166,7 +3171,7 @@ def _kernel_nvfp4_quantize_stacked_silu(
31663171

31673172
tl.store(
31683173
scale + actual_scale_offset_permute,
3169-
scale_.to(tl.float8e4nv).to(tl.uint8, bitcast=True),
3174+
scale_.to(tl.uint8, bitcast=True),
31703175
# Prevent writing outside this chunk or the main array.
31713176
mask=(row_idx < M)
31723177
& (exp_offset < (SCALE_CHUNK_SIZE * (pid + 1)))
@@ -3384,13 +3389,14 @@ def _mega_fp4_quantize_kernel(
33843389
input_global_scale_tensor + tensor_idx, mask=tensor_idx_guard
33853390
)
33863391
# Next we scale A in preparation for quantization.
3387-
scale_ = group_max / 6.0 * input_global_scale
3392+
scale_ = (group_max / 6.0 * input_global_scale).to(tl.float8e4nv)
33883393
# Prevent infinite values in log.
33893394
group_max = tl.where(group_max == 0, BF16_MIN_NORMAL, group_max)
33903395

33913396
# Apply scale_ to input. We do this by broadcasting scale.
3397+
# scaled_a = a * global_scale (fp32) / local_scale (fp8)
33923398
scaled_a = tl.reshape(a, [GROUP_LOAD, GROUP_SIZE]) * tl.reshape(
3393-
6.0 / group_max, [GROUP_LOAD, 1]
3399+
input_global_scale / scale_, [GROUP_LOAD, 1]
33943400
)
33953401
# Reshape back to a flat array.
33963402
scaled_a = tl.reshape(scaled_a, [GROUP_LOAD * GROUP_SIZE])
@@ -3458,7 +3464,7 @@ def _mega_fp4_quantize_kernel(
34583464

34593465
tl.store(
34603466
scale + actual_scale_offset_permute,
3461-
scale_.to(tl.float8e4nv).to(tl.uint8, bitcast=True),
3467+
scale_.to(tl.uint8, bitcast=True),
34623468
# Prevent writing outside this chunk or the main array.
34633469
mask=(row_idx < M)
34643470
& (exp_offset < (SCALE_CHUNK_SIZE * (pid + 1)))
@@ -3654,13 +3660,14 @@ def _mega_fp4_quantize_kernel_with_tensor_idx(
36543660
input_global_scale_tensor + tensor_idx, mask=tensor_idx_guard
36553661
)
36563662
# Next we scale A in preparation for quantization.
3657-
scale_ = group_max / 6.0 * input_global_scale
3663+
scale_ = (group_max / 6.0 * input_global_scale).to(tl.float8e4nv)
36583664
# Prevent infinite values in log.
36593665
group_max = tl.where(group_max == 0, BF16_MIN_NORMAL, group_max)
36603666

36613667
# Apply scale_ to input. We do this by broadcasting scale.
3668+
# scaled_a = a * global_scale (fp32) / local_scale (fp8)
36623669
scaled_a = tl.reshape(a, [GROUP_LOAD, GROUP_SIZE]) * tl.reshape(
3663-
6.0 / group_max, [GROUP_LOAD, 1]
3670+
input_global_scale / scale_, [GROUP_LOAD, 1]
36643671
)
36653672
# Reshape back to a flat array.
36663673
scaled_a = tl.reshape(scaled_a, [GROUP_LOAD * GROUP_SIZE])
@@ -3728,7 +3735,7 @@ def _mega_fp4_quantize_kernel_with_tensor_idx(
37283735

37293736
tl.store(
37303737
scale + actual_scale_offset_permute,
3731-
scale_.to(tl.float8e4nv).to(tl.uint8, bitcast=True),
3738+
scale_.to(tl.uint8, bitcast=True),
37323739
# Prevent writing outside this chunk or the main array.
37333740
mask=(row_idx < M)
37343741
& (exp_offset < (SCALE_CHUNK_SIZE * (pid + 1)))
@@ -4238,13 +4245,14 @@ def _kernel_nvfp4_quantize_stacked_rms(
42384245
group_max = tl.max(tl.abs(a_groups), axis=1).to(tl.float32)
42394246

42404247
# Next we scale A in preparation for quantization.
4241-
scale_ = group_max / 6.0 * input_global_scale
4248+
scale_ = (group_max / 6.0 * input_global_scale).to(tl.float8e4nv)
42424249
# Prevent infinite values in log.
42434250
group_max = tl.where(group_max == 0, BF16_MIN_NORMAL, group_max)
42444251

42454252
# Apply scale_ to input. We do this by broadcasting scale.
4253+
# scaled_a = a * global_scale (fp32) / local_scale (fp8)
42464254
scaled_a = tl.reshape(a, [GROUP_LOAD, GROUP_SIZE]) * tl.reshape(
4247-
6.0 / group_max, [GROUP_LOAD, 1]
4255+
input_global_scale / scale_, [GROUP_LOAD, 1]
42484256
)
42494257
# Reshape back to a flat array.
42504258
scaled_a = tl.reshape(scaled_a, [GROUP_LOAD * GROUP_SIZE])
@@ -4312,7 +4320,7 @@ def _kernel_nvfp4_quantize_stacked_rms(
43124320

43134321
tl.store(
43144322
scale + actual_scale_offset_permute,
4315-
scale_.to(tl.float8e4nv).to(tl.uint8, bitcast=True),
4323+
scale_.to(tl.uint8, bitcast=True),
43164324
# Prevent writing outside this chunk or the main array.
43174325
mask=(row_idx < M)
43184326
& (exp_offset < (SCALE_CHUNK_SIZE * (pid + 1)))
@@ -4580,13 +4588,14 @@ def _mega_fp4_pack_kernel(
45804588
group_max = tl.max(tl.abs(a_groups), axis=1).to(tl.float32)
45814589

45824590
# Next we scale A in preparation for quantization.
4583-
scale_ = group_max / 6.0 * input_global_scale
4591+
scale_ = (group_max / 6.0 * input_global_scale).to(tl.float8e4nv)
45844592
# Prevent infinite values in log.
45854593
group_max = tl.where(group_max == 0, BF16_MIN_NORMAL, group_max)
45864594

45874595
# Apply scale_ to input. We do this by broadcasting scale.
4596+
# scaled_a = a * global_scale (fp32) / local_scale (fp8)
45884597
scaled_a = tl.reshape(a, [GROUP_LOAD, GROUP_SIZE]) * tl.reshape(
4589-
6.0 / group_max, [GROUP_LOAD, 1]
4598+
input_global_scale / scale_, [GROUP_LOAD, 1]
45904599
)
45914600
# Reshape back to a flat array.
45924601
scaled_a = tl.reshape(scaled_a, [GROUP_LOAD * GROUP_SIZE])
@@ -4638,7 +4647,7 @@ def _mega_fp4_pack_kernel(
46384647

46394648
tl.store(
46404649
out + exp_offset,
4641-
scale_.to(tl.float8e4nv).to(tl.uint8, bitcast=True),
4650+
scale_.to(tl.uint8, bitcast=True),
46424651
# Prevent writing outside this chunk or the main array.
46434652
mask=(exp_offset < (SCALE_CHUNK_SIZE * (pid + 1) + SCALE_SHIFT))
46444653
& (exp_offset < SCALE_SIZE + SCALE_SHIFT),
@@ -4792,13 +4801,14 @@ def _mega_fp4_pack_kernel_per_tensor(
47924801
input_global_scale_tensor + tensor_idx, mask=tensor_idx_guard
47934802
)
47944803
# Next we scale A in preparation for quantization.
4795-
scale_ = group_max / 6.0 * input_global_scale
4804+
scale_ = (group_max / 6.0 * input_global_scale).to(tl.float8e4nv)
47964805
# Prevent infinite values in log.
47974806
group_max = tl.where(group_max == 0, BF16_MIN_NORMAL, group_max)
47984807

47994808
# Apply scale_ to input. We do this by broadcasting scale.
4809+
# scaled_a = a * global_scale (fp32) / local_scale (fp8)
48004810
scaled_a = tl.reshape(a, [GROUP_LOAD, GROUP_SIZE]) * tl.reshape(
4801-
6.0 / group_max, [GROUP_LOAD, 1]
4811+
input_global_scale / scale_, [GROUP_LOAD, 1]
48024812
)
48034813
# Reshape back to a flat array.
48044814
scaled_a = tl.reshape(scaled_a, [GROUP_LOAD * GROUP_SIZE])
@@ -4850,7 +4860,7 @@ def _mega_fp4_pack_kernel_per_tensor(
48504860

48514861
tl.store(
48524862
out + exp_offset,
4853-
scale_.to(tl.float8e4nv).to(tl.uint8, bitcast=True),
4863+
scale_.to(tl.uint8, bitcast=True),
48544864
# Prevent writing outside this chunk or the main array.
48554865
mask=(exp_offset < (SCALE_CHUNK_SIZE * (pid + 1) + SCALE_SHIFT))
48564866
& (exp_offset < (SCALE_SIZE + SCALE_SHIFT)),

0 commit comments

Comments
 (0)