@@ -1347,13 +1347,14 @@ def _kernel_nvfp4_quantize(
         group_max = tl.max(tl.abs(a_groups), axis=1).to(tl.float32)

         # Next we scale A in preparation for quantization.
-        scale_ = group_max / 6.0 * input_global_scale
+        scale_ = (group_max / 6.0 * input_global_scale).to(tl.float8e4nv)
         # Prevent infinite values in log.
         group_max = tl.where(group_max == 0, BF16_MIN_NORMAL, group_max)

         # Apply scale_ to input. We do this by broadcasting scale.
+        # scaled_a = a * global_scale (fp32) / local_scale (fp8)
         scaled_a = tl.reshape(a, [GROUP_LOAD, GROUP_SIZE]) * tl.reshape(
-            6.0 / group_max, [GROUP_LOAD, 1]
+            input_global_scale / scale_, [GROUP_LOAD, 1]
         )
         # Reshape back to a flat array.
         scaled_a = tl.reshape(scaled_a, [GROUP_LOAD * GROUP_SIZE])
@@ -1417,7 +1418,7 @@ def _kernel_nvfp4_quantize(
         )
         tl.store(
             scale + actual_offset,
-            scale_.to(tl.float8e4nv).to(tl.uint8, bitcast=True),
+            scale_.to(tl.uint8, bitcast=True),
             # Prevent writing outside this chunk or the main array.
             mask=(exp_offset < SCALE_SIZE)
             & (exp_offset < (SCALE_CHUNK_SIZE * (pid + 1))),
@@ -1694,13 +1695,14 @@ def _kernel_nvfp4_quantize_silu(
         group_max = tl.max(tl.abs(a_groups), axis=1)

         # Next we scale A in preparation for quantization.
-        scale_ = group_max / 6.0 * input_global_scale
+        scale_ = (group_max / 6.0 * input_global_scale).to(tl.float8e4nv)
         # Prevent infinite values in log.
         group_max = tl.where(group_max == 0, BF16_MIN_NORMAL, group_max)

         # Apply scale_ to input. We do this by broadcasting scale.
+        # scaled_a = a * global_scale (fp32) / local_scale (fp8)
         scaled_a = tl.reshape(a, [GROUP_LOAD, GROUP_SIZE]) * tl.reshape(
-            6.0 / group_max, [GROUP_LOAD, 1]
+            input_global_scale / scale_, [GROUP_LOAD, 1]
         )
         # Reshape back to a flat array.
         scaled_a = tl.reshape(scaled_a, [GROUP_LOAD * GROUP_SIZE])
@@ -1766,7 +1768,7 @@ def _kernel_nvfp4_quantize_silu(
         )
         tl.store(
             scale + actual_offset,
-            scale_.to(tl.float8e4nv).to(tl.uint8, bitcast=True),
+            scale_.to(tl.uint8, bitcast=True),
             # Prevent writing outside this chunk or the main array.
             mask=(exp_offset < SCALE_SIZE)
             & (exp_offset < (SCALE_CHUNK_SIZE * (pid + 1))),
@@ -2053,13 +2055,14 @@ def _kernel_nvfp4_quantize_rms(
         group_max = tl.max(tl.abs(a_groups), axis=1)

         # Next we scale A in preparation for quantization.
-        scale_ = group_max / 6.0 * input_global_scale
+        scale_ = (group_max / 6.0 * input_global_scale).to(tl.float8e4nv)
         # Prevent infinite values in log.
         group_max = tl.where(group_max == 0, BF16_MIN_NORMAL, group_max)

         # Apply scale_ to input. We do this by broadcasting scale.
+        # scaled_a = a * global_scale (fp32) / local_scale (fp8)
         scaled_a = tl.reshape(a, [GROUP_LOAD, GROUP_SIZE]) * tl.reshape(
-            6.0 / group_max, [GROUP_LOAD, 1]
+            input_global_scale / scale_, [GROUP_LOAD, 1]
         )
         # Reshape back to a flat array.
         scaled_a = tl.reshape(scaled_a, [GROUP_LOAD * GROUP_SIZE])
@@ -2127,7 +2130,7 @@ def _kernel_nvfp4_quantize_rms(
         )
         tl.store(
             scale + actual_offset,
-            scale_.to(tl.float8e4nv).to(tl.uint8, bitcast=True),
+            scale_.to(tl.uint8, bitcast=True),
             # Prevent writing outside this chunk or the main array.
             mask=(exp_offset < SCALE_SIZE)
             & (exp_offset < (SCALE_CHUNK_SIZE * (pid + 1))),
@@ -2415,13 +2418,14 @@ def _kernel_nvfp4_quantize_stacked(
         group_max = tl.max(tl.abs(a_groups), axis=1).to(tl.float32)

         # Next we scale A in preparation for quantization.
-        scale_ = group_max / 6.0 * input_global_scale
+        scale_ = (group_max / 6.0 * input_global_scale).to(tl.float8e4nv)
         # Prevent infinite values in log.
         group_max = tl.where(group_max == 0, BF16_MIN_NORMAL, group_max)

         # Apply scale_ to input. We do this by broadcasting scale.
+        # scaled_a = a * global_scale (fp32) / local_scale (fp8)
         scaled_a = tl.reshape(a, [GROUP_LOAD, GROUP_SIZE]) * tl.reshape(
-            6.0 / group_max, [GROUP_LOAD, 1]
+            input_global_scale / scale_, [GROUP_LOAD, 1]
         )
         # Reshape back to a flat array.
         scaled_a = tl.reshape(scaled_a, [GROUP_LOAD * GROUP_SIZE])
@@ -2489,7 +2493,7 @@ def _kernel_nvfp4_quantize_stacked(

         tl.store(
             scale + actual_scale_offset_permute,
-            scale_.to(tl.float8e4nv).to(tl.uint8, bitcast=True),
+            scale_.to(tl.uint8, bitcast=True),
             # Prevent writing outside this chunk or the main array.
             mask=(row_idx < M)
             & (exp_offset < (SCALE_CHUNK_SIZE * (pid + 1)))
@@ -3092,13 +3096,14 @@ def _kernel_nvfp4_quantize_stacked_silu(
         group_max = tl.max(tl.abs(a_groups), axis=1).to(tl.float32)

         # Next we scale A in preparation for quantization.
-        scale_ = group_max / 6.0 * input_global_scale
+        scale_ = (group_max / 6.0 * input_global_scale).to(tl.float8e4nv)
         # Prevent infinite values in log.
         group_max = tl.where(group_max == 0, BF16_MIN_NORMAL, group_max)

         # Apply scale_ to input. We do this by broadcasting scale.
+        # scaled_a = a * global_scale (fp32) / local_scale (fp8)
         scaled_a = tl.reshape(a, [GROUP_LOAD, GROUP_SIZE]) * tl.reshape(
-            6.0 / group_max, [GROUP_LOAD, 1]
+            input_global_scale / scale_, [GROUP_LOAD, 1]
         )
         # Reshape back to a flat array.
         scaled_a = tl.reshape(scaled_a, [GROUP_LOAD * GROUP_SIZE])
@@ -3166,7 +3171,7 @@ def _kernel_nvfp4_quantize_stacked_silu(

         tl.store(
             scale + actual_scale_offset_permute,
-            scale_.to(tl.float8e4nv).to(tl.uint8, bitcast=True),
+            scale_.to(tl.uint8, bitcast=True),
             # Prevent writing outside this chunk or the main array.
             mask=(row_idx < M)
             & (exp_offset < (SCALE_CHUNK_SIZE * (pid + 1)))
@@ -3384,13 +3389,14 @@ def _mega_fp4_quantize_kernel(
             input_global_scale_tensor + tensor_idx, mask=tensor_idx_guard
         )
         # Next we scale A in preparation for quantization.
-        scale_ = group_max / 6.0 * input_global_scale
+        scale_ = (group_max / 6.0 * input_global_scale).to(tl.float8e4nv)
         # Prevent infinite values in log.
         group_max = tl.where(group_max == 0, BF16_MIN_NORMAL, group_max)

         # Apply scale_ to input. We do this by broadcasting scale.
+        # scaled_a = a * global_scale (fp32) / local_scale (fp8)
         scaled_a = tl.reshape(a, [GROUP_LOAD, GROUP_SIZE]) * tl.reshape(
-            6.0 / group_max, [GROUP_LOAD, 1]
+            input_global_scale / scale_, [GROUP_LOAD, 1]
         )
         # Reshape back to a flat array.
         scaled_a = tl.reshape(scaled_a, [GROUP_LOAD * GROUP_SIZE])
@@ -3458,7 +3464,7 @@ def _mega_fp4_quantize_kernel(

         tl.store(
             scale + actual_scale_offset_permute,
-            scale_.to(tl.float8e4nv).to(tl.uint8, bitcast=True),
+            scale_.to(tl.uint8, bitcast=True),
             # Prevent writing outside this chunk or the main array.
             mask=(row_idx < M)
             & (exp_offset < (SCALE_CHUNK_SIZE * (pid + 1)))
@@ -3654,13 +3660,14 @@ def _mega_fp4_quantize_kernel_with_tensor_idx(
             input_global_scale_tensor + tensor_idx, mask=tensor_idx_guard
         )
         # Next we scale A in preparation for quantization.
-        scale_ = group_max / 6.0 * input_global_scale
+        scale_ = (group_max / 6.0 * input_global_scale).to(tl.float8e4nv)
         # Prevent infinite values in log.
         group_max = tl.where(group_max == 0, BF16_MIN_NORMAL, group_max)

         # Apply scale_ to input. We do this by broadcasting scale.
+        # scaled_a = a * global_scale (fp32) / local_scale (fp8)
         scaled_a = tl.reshape(a, [GROUP_LOAD, GROUP_SIZE]) * tl.reshape(
-            6.0 / group_max, [GROUP_LOAD, 1]
+            input_global_scale / scale_, [GROUP_LOAD, 1]
         )
         # Reshape back to a flat array.
         scaled_a = tl.reshape(scaled_a, [GROUP_LOAD * GROUP_SIZE])
@@ -3728,7 +3735,7 @@ def _mega_fp4_quantize_kernel_with_tensor_idx(

         tl.store(
             scale + actual_scale_offset_permute,
-            scale_.to(tl.float8e4nv).to(tl.uint8, bitcast=True),
+            scale_.to(tl.uint8, bitcast=True),
             # Prevent writing outside this chunk or the main array.
             mask=(row_idx < M)
             & (exp_offset < (SCALE_CHUNK_SIZE * (pid + 1)))
@@ -4238,13 +4245,14 @@ def _kernel_nvfp4_quantize_stacked_rms(
         group_max = tl.max(tl.abs(a_groups), axis=1).to(tl.float32)

         # Next we scale A in preparation for quantization.
-        scale_ = group_max / 6.0 * input_global_scale
+        scale_ = (group_max / 6.0 * input_global_scale).to(tl.float8e4nv)
         # Prevent infinite values in log.
         group_max = tl.where(group_max == 0, BF16_MIN_NORMAL, group_max)

         # Apply scale_ to input. We do this by broadcasting scale.
+        # scaled_a = a * global_scale (fp32) / local_scale (fp8)
         scaled_a = tl.reshape(a, [GROUP_LOAD, GROUP_SIZE]) * tl.reshape(
-            6.0 / group_max, [GROUP_LOAD, 1]
+            input_global_scale / scale_, [GROUP_LOAD, 1]
         )
         # Reshape back to a flat array.
         scaled_a = tl.reshape(scaled_a, [GROUP_LOAD * GROUP_SIZE])
@@ -4312,7 +4320,7 @@ def _kernel_nvfp4_quantize_stacked_rms(

         tl.store(
             scale + actual_scale_offset_permute,
-            scale_.to(tl.float8e4nv).to(tl.uint8, bitcast=True),
+            scale_.to(tl.uint8, bitcast=True),
             # Prevent writing outside this chunk or the main array.
             mask=(row_idx < M)
             & (exp_offset < (SCALE_CHUNK_SIZE * (pid + 1)))
@@ -4580,13 +4588,14 @@ def _mega_fp4_pack_kernel(
         group_max = tl.max(tl.abs(a_groups), axis=1).to(tl.float32)

         # Next we scale A in preparation for quantization.
-        scale_ = group_max / 6.0 * input_global_scale
+        scale_ = (group_max / 6.0 * input_global_scale).to(tl.float8e4nv)
         # Prevent infinite values in log.
         group_max = tl.where(group_max == 0, BF16_MIN_NORMAL, group_max)

         # Apply scale_ to input. We do this by broadcasting scale.
+        # scaled_a = a * global_scale (fp32) / local_scale (fp8)
         scaled_a = tl.reshape(a, [GROUP_LOAD, GROUP_SIZE]) * tl.reshape(
-            6.0 / group_max, [GROUP_LOAD, 1]
+            input_global_scale / scale_, [GROUP_LOAD, 1]
         )
         # Reshape back to a flat array.
         scaled_a = tl.reshape(scaled_a, [GROUP_LOAD * GROUP_SIZE])
@@ -4638,7 +4647,7 @@ def _mega_fp4_pack_kernel(

         tl.store(
             out + exp_offset,
-            scale_.to(tl.float8e4nv).to(tl.uint8, bitcast=True),
+            scale_.to(tl.uint8, bitcast=True),
             # Prevent writing outside this chunk or the main array.
             mask=(exp_offset < (SCALE_CHUNK_SIZE * (pid + 1) + SCALE_SHIFT))
             & (exp_offset < SCALE_SIZE + SCALE_SHIFT),
@@ -4792,13 +4801,14 @@ def _mega_fp4_pack_kernel_per_tensor(
             input_global_scale_tensor + tensor_idx, mask=tensor_idx_guard
         )
         # Next we scale A in preparation for quantization.
-        scale_ = group_max / 6.0 * input_global_scale
+        scale_ = (group_max / 6.0 * input_global_scale).to(tl.float8e4nv)
         # Prevent infinite values in log.
         group_max = tl.where(group_max == 0, BF16_MIN_NORMAL, group_max)

         # Apply scale_ to input. We do this by broadcasting scale.
+        # scaled_a = a * global_scale (fp32) / local_scale (fp8)
         scaled_a = tl.reshape(a, [GROUP_LOAD, GROUP_SIZE]) * tl.reshape(
-            6.0 / group_max, [GROUP_LOAD, 1]
+            input_global_scale / scale_, [GROUP_LOAD, 1]
         )
         # Reshape back to a flat array.
         scaled_a = tl.reshape(scaled_a, [GROUP_LOAD * GROUP_SIZE])
@@ -4850,7 +4860,7 @@ def _mega_fp4_pack_kernel_per_tensor(

         tl.store(
             out + exp_offset,
-            scale_.to(tl.float8e4nv).to(tl.uint8, bitcast=True),
+            scale_.to(tl.uint8, bitcast=True),
             # Prevent writing outside this chunk or the main array.
             mask=(exp_offset < (SCALE_CHUNK_SIZE * (pid + 1) + SCALE_SHIFT))
             & (exp_offset < (SCALE_SIZE + SCALE_SHIFT)),
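Every hunk above applies the same two-part change: the per-group scale `scale_` is cast to `tl.float8e4nv` at the point it is computed, and the input is then rescaled by `input_global_scale / scale_` instead of `6.0 / group_max`, so quantization divides by the same fp8-rounded scale that is later bitcast to `uint8` and stored. A rough NumPy sketch of that arithmetic is below; it is illustrative only, `round_to_fp8_e4m3` is a hypothetical stand-in for the hardware cast, and the zero-group handling is simplified relative to the kernels' `BF16_MIN_NORMAL` guard.

```python
import numpy as np


def round_to_fp8_e4m3(x: np.ndarray) -> np.ndarray:
    # Crude stand-in for .to(tl.float8e4nv): keep 3 mantissa bits and clamp
    # to the e4m3 max of 448. Denormals and NaN handling are ignored here.
    mantissa, exponent = np.frexp(x)
    rounded = np.ldexp(np.round(mantissa * 16.0) / 16.0, exponent)
    return np.clip(rounded, -448.0, 448.0)


def nvfp4_group_scale(a: np.ndarray, input_global_scale: float, group_size: int = 16):
    groups = a.reshape(-1, group_size)
    group_max = np.abs(groups).max(axis=1)
    # New behavior: round the per-group scale to fp8 *before* using it.
    scale_ = round_to_fp8_e4m3(group_max / 6.0 * input_global_scale)
    # scaled_a = a * global_scale (fp32) / local_scale (fp8)
    safe_scale = np.where(scale_ == 0.0, 1.0, scale_)
    scaled_a = groups * (input_global_scale / safe_scale)[:, None]
    return scaled_a.reshape(-1), scale_
```

Dividing by the already-rounded fp8 scale keeps the quantized values consistent with the scale that is actually stored and later used for dequantization, which appears to be the point of replacing `6.0 / group_max` with `input_global_scale / scale_`.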