
Commit 9b5fba1

JChunX authored and facebook-github-bot committed
MI350X FP8 triton patch (#4889)
Summary:
Pull Request resolved: #4889
X-link: facebookresearch/FBGEMM#1913

MI350X has some incompatibilities with the current AMD Triton FP8 setup, namely:
- it moves from the FNUZ FP8 format to the OCP format, and
- its instruction set does not support the BLOCK_K=128 & matrix_instr_nonkdim=16 combination, leading to a crash.

Fix by adding gates.

Reviewed By: bilal

Differential Revision: D81180838

fbshipit-source-id: 5f5b87cc62874fc19d7580e5b3e7ead2a75f4d55
1 parent c388f62 · commit 9b5fba1
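Both gates in this patch key off torch.cuda.get_device_capability(), which on ROCm builds of PyTorch reports the GFX architecture version rather than a CUDA compute capability (e.g. (9, 4) on MI300X/gfx942 and (9, 5) on MI350X/gfx950, as I understand it). Here is a minimal standalone sketch of the dtype-selection logic the patch encodes; the helper name pick_fp8_dtypes is hypothetical and not part of this commit:

import torch

def pick_fp8_dtypes():
    # Sketch only: return the (e4m3, e5m2) FP8 dtypes this patch expects
    # on the current device. Mirrors the asserts in the diff below.
    if torch.version.cuda:
        # NVIDIA GPUs use the OCP FP8 formats.
        return torch.float8_e4m3fn, torch.float8_e5m2
    if torch.version.hip:
        # On ROCm, capability tracks the GFX version; (9, 5) and up
        # (MI350X/gfx950) switch from the FNUZ to the OCP formats.
        if torch.cuda.get_device_capability() >= (9, 5):
            return torch.float8_e4m3fn, torch.float8_e5m2
        return torch.float8_e4m3fnuz, torch.float8_e5m2fnuz
    raise RuntimeError("no CUDA or HIP runtime available")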

File tree

1 file changed (+63, -250)

fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py

Lines changed: 63 additions & 250 deletions
@@ -1246,8 +1246,19 @@ def matmul_fp8_row(
     # View inputs into proper torch fp8 dtype.
     if torch.version.cuda:
         assert a.dtype in (torch.float8_e4m3fn, torch.float8_e5m2)
+    elif torch.version.hip:
+        if torch.cuda.get_device_capability() < (9, 5):
+            assert a.dtype in (
+                torch.float8_e4m3fnuz,
+                torch.float8_e5m2fnuz,
+            )
+        else:
+            assert a.dtype in (torch.float8_e4m3fn, torch.float8_e5m2)
     else:
-        assert a.dtype in (torch.float8_e4m3fnuz, torch.float8_e5m2fnuz)
+        assert a.dtype in (
+            torch.float8_e4m3fnuz,
+            torch.float8_e5m2fnuz,
+        )
     assert b.dtype == pt_fp8_dtype
     M, N, K, m_key, n_key, k_key, c, c_dtype_triton, dot_out_dtype_triton, device = (
         prep_matmul(a, b, dot_out_dtype)
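For orientation, here is roughly how a caller reaches the asserts above. The function names come from this file, but the exact signatures may differ across FBGEMM revisions, so treat this as a sketch rather than the canonical API:

import torch
from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import (
    matmul_fp8_row,
    quantize_fp8_row,
)

a = torch.randn(64, 256, device="cuda", dtype=torch.bfloat16)
b = torch.randn(128, 256, device="cuda", dtype=torch.bfloat16)
# Row-wise quantization picks the platform's FP8 dtype, so the asserts
# pass on both FNUZ (pre-MI350X ROCm) and OCP (CUDA, MI350X+) devices.
aq, a_scale = quantize_fp8_row(a)
bq, b_scale = quantize_fp8_row(b)
c = matmul_fp8_row(aq, bq, a_scale, b_scale)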
@@ -3808,259 +3819,61 @@ def get_full_non_persistent_tuning_space():


 MATMUL_CONFIGS_NON_PERSISTENT: list[Config] = get_full_non_persistent_tuning_space()
+# (BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M, SPLIT_K, waves_per_eu, matrix_instr_nonkdim, kpack, num_warps, num_stages)
+_MATMUL_CONFIG_TUPLES_PINGPONG_4K_8K_16K = [
+    (16, 16, 256, 1, 1, 8, 16, 2, 2, 2),
+    (16, 16, 256, 1, 1, 0, 16, 2, 2, 2),
+    (32, 64, 512, 1, 1, 2, 16, 2, 8, 2),
+    (64, 64, 256, 1, 1, 2, 16, 2, 4, 2),
+    (256, 256, 128, 32, 1, 2, 16, 1, 8, 2),
+    (256, 256, 128, 2, 1, 0, 32, 2, 8, 2),
+    (256, 256, 128, 1, 1, 0, 32, 2, 8, 2),
+    (256, 256, 128, 2, 1, 0, 16, 1, 8, 2),
+    (256, 256, 64, 2, 1, 2, 16, 1, 8, 2),
+    (128, 256, 64, 2, 1, 2, 16, 1, 4, 2),
+    (256, 128, 128, 4, 1, 0, 16, 1, 8, 2),
+    (128, 128, 128, 1, 1, 2, 16, 2, 4, 2),
+    (128, 128, 256, 1, 1, 2, 16, 2, 8, 2),
+    (128, 128, 64, 4, 1, 2, 16, 2, 4, 2),
+    (128, 128, 64, 1, 1, 2, 16, 2, 4, 2),
+    (128, 64, 64, 4, 1, 0, 16, 2, 4, 2),
+    (128, 64, 64, 1, 1, 0, 16, 2, 4, 2),
+    (256, 128, 128, 1, 1, 2, 16, 1, 8, 2),
+]
+
+
+def _should_skip_config(block_k, matrix_instr_nonkdim):
+    """Skip config if BLOCK_K=64 and matrix_instr_nonkdim=16 on GFX95+"""
+    try:
+        return (
+            block_k == 64
+            and matrix_instr_nonkdim == 16
+            and torch.version.hip is not None
+            and torch.cuda.get_device_capability() >= (9, 5)
+        )
+    except RuntimeError:
+        # If no HIP GPUs are available, we can't check device capability
+        # so we don't skip any configs
+        return False
+
+
 MATMUL_CONFIGS_NON_PERSISTENT_PINGPONG_4K_8K_16K = [
     triton.Config(
         {
-            "BLOCK_M": 16,
-            "BLOCK_N": 16,
-            "BLOCK_K": 256,
-            "GROUP_M": 1,
-            "SPLIT_K": 1,
-            "waves_per_eu": 8,
-            "matrix_instr_nonkdim": 16,
-            "kpack": 2,
-        },
-        num_warps=2,
-        num_stages=2,
-    ),
-    triton.Config(
-        {
-            "BLOCK_M": 16,
-            "BLOCK_N": 16,
-            "BLOCK_K": 256,
-            "GROUP_M": 1,
-            "SPLIT_K": 1,
-            "waves_per_eu": 0,
-            "matrix_instr_nonkdim": 16,
-            "kpack": 2,
-        },
-        num_warps=2,
-        num_stages=2,
-    ),
-    triton.Config(
-        {
-            "BLOCK_M": 32,
-            "BLOCK_N": 64,
-            "BLOCK_K": 512,
-            "GROUP_M": 1,
-            "SPLIT_K": 1,
-            "waves_per_eu": 2,
-            "matrix_instr_nonkdim": 16,
-            "kpack": 2,
-        },
-        num_warps=8,
-        num_stages=2,
-    ),
-    triton.Config(
-        {
-            "BLOCK_M": 64,
-            "BLOCK_N": 64,
-            "BLOCK_K": 256,
-            "GROUP_M": 1,
-            "SPLIT_K": 1,
-            "waves_per_eu": 2,
-            "matrix_instr_nonkdim": 16,
-            "kpack": 2,
+            "BLOCK_M": block_m,
+            "BLOCK_N": block_n,
+            "BLOCK_K": block_k,
+            "GROUP_M": group_m,
+            "SPLIT_K": split_k,
+            "waves_per_eu": waves_per_eu,
+            "matrix_instr_nonkdim": matrix_instr_nonkdim,
+            "kpack": kpack,
         },
-        num_warps=4,
-        num_stages=2,
-    ),
-    triton.Config(
-        {
-            "BLOCK_M": 256,
-            "BLOCK_N": 256,
-            "BLOCK_K": 128,
-            "GROUP_M": 32,
-            "SPLIT_K": 1,
-            "waves_per_eu": 2,
-            "matrix_instr_nonkdim": 16,
-            "kpack": 1,
-        },
-        num_warps=8,
-        num_stages=2,
-    ),
-    triton.Config(
-        {
-            "BLOCK_M": 256,
-            "BLOCK_N": 256,
-            "BLOCK_K": 128,
-            "GROUP_M": 2,
-            "SPLIT_K": 1,
-            "waves_per_eu": 0,
-            "matrix_instr_nonkdim": 32,
-            "kpack": 2,
-        },
-        num_warps=8,
-        num_stages=2,
-    ),
-    triton.Config(
-        {
-            "BLOCK_M": 256,
-            "BLOCK_N": 256,
-            "BLOCK_K": 128,
-            "GROUP_M": 1,
-            "SPLIT_K": 1,
-            "waves_per_eu": 0,
-            "matrix_instr_nonkdim": 32,
-            "kpack": 2,
-        },
-        num_warps=8,
-        num_stages=2,
-    ),
-    triton.Config(
-        {
-            "BLOCK_M": 256,
-            "BLOCK_N": 256,
-            "BLOCK_K": 128,
-            "GROUP_M": 2,
-            "SPLIT_K": 1,
-            "waves_per_eu": 0,
-            "matrix_instr_nonkdim": 16,
-            "kpack": 1,
-        },
-        num_warps=8,
-        num_stages=2,
-    ),
-    triton.Config(
-        {
-            "BLOCK_M": 256,
-            "BLOCK_N": 256,
-            "BLOCK_K": 64,
-            "GROUP_M": 2,
-            "SPLIT_K": 1,
-            "waves_per_eu": 2,
-            "matrix_instr_nonkdim": 16,
-            "kpack": 1,
-        },
-        num_warps=8,
-        num_stages=2,
-    ),
-    triton.Config(
-        {
-            "BLOCK_M": 128,
-            "BLOCK_N": 256,
-            "BLOCK_K": 64,
-            "GROUP_M": 2,
-            "SPLIT_K": 1,
-            "waves_per_eu": 2,
-            "matrix_instr_nonkdim": 16,
-            "kpack": 1,
-        },
-        num_warps=4,
-        num_stages=2,
-    ),
-    triton.Config(
-        {
-            "BLOCK_M": 256,
-            "BLOCK_N": 128,
-            "BLOCK_K": 128,
-            "GROUP_M": 4,
-            "SPLIT_K": 1,
-            "waves_per_eu": 0,
-            "matrix_instr_nonkdim": 16,
-            "kpack": 1,
-        },
-        num_warps=8,
-        num_stages=2,
-    ),
-    triton.Config(
-        {
-            "BLOCK_M": 128,
-            "BLOCK_N": 128,
-            "BLOCK_K": 128,
-            "GROUP_M": 1,
-            "SPLIT_K": 1,
-            "waves_per_eu": 2,
-            "matrix_instr_nonkdim": 16,
-            "kpack": 2,
-        },
-        num_warps=4,
-        num_stages=2,
-    ),
-    triton.Config(
-        {
-            "BLOCK_M": 128,
-            "BLOCK_N": 128,
-            "BLOCK_K": 256,
-            "GROUP_M": 1,
-            "SPLIT_K": 1,
-            "waves_per_eu": 2,
-            "matrix_instr_nonkdim": 16,
-            "kpack": 2,
-        },
-        num_warps=8,
-        num_stages=2,
-    ),
-    triton.Config(
-        {
-            "BLOCK_M": 128,
-            "BLOCK_N": 128,
-            "BLOCK_K": 64,
-            "GROUP_M": 4,
-            "SPLIT_K": 1,
-            "waves_per_eu": 2,
-            "matrix_instr_nonkdim": 16,
-            "kpack": 2,
-        },
-        num_warps=4,
-        num_stages=2,
-    ),
-    triton.Config(
-        {
-            "BLOCK_M": 128,
-            "BLOCK_N": 128,
-            "BLOCK_K": 64,
-            "GROUP_M": 1,
-            "SPLIT_K": 1,
-            "waves_per_eu": 2,
-            "matrix_instr_nonkdim": 16,
-            "kpack": 2,
-        },
-        num_warps=4,
-        num_stages=2,
-    ),
-    triton.Config(
-        {
-            "BLOCK_M": 128,
-            "BLOCK_N": 64,
-            "BLOCK_K": 64,
-            "GROUP_M": 4,
-            "SPLIT_K": 1,
-            "waves_per_eu": 0,
-            "matrix_instr_nonkdim": 16,
-            "kpack": 2,
-        },
-        num_warps=4,
-        num_stages=2,
-    ),
-    triton.Config(
-        {
-            "BLOCK_M": 128,
-            "BLOCK_N": 64,
-            "BLOCK_K": 64,
-            "GROUP_M": 1,
-            "SPLIT_K": 1,
-            "waves_per_eu": 0,
-            "matrix_instr_nonkdim": 16,
-            "kpack": 2,
-        },
-        num_warps=4,
-        num_stages=2,
-    ),
-    triton.Config(
-        {
-            "BLOCK_M": 256,
-            "BLOCK_N": 128,
-            "BLOCK_K": 128,
-            "GROUP_M": 1,
-            "SPLIT_K": 1,
-            "waves_per_eu": 2,
-            "matrix_instr_nonkdim": 16,
-            "kpack": 1,
-        },
-        num_warps=8,
-        num_stages=2,
-    ),
+        num_warps=num_warps,
+        num_stages=num_stages,
+    )
+    for block_m, block_n, block_k, group_m, split_k, waves_per_eu, matrix_instr_nonkdim, kpack, num_warps, num_stages in _MATMUL_CONFIG_TUPLES_PINGPONG_4K_8K_16K
+    if not _should_skip_config(block_k, matrix_instr_nonkdim)
 ]

 # Set this to enable full autotuning for proper benchmarking.
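The comprehension above applies the gate at import time: on GFX95+ (MI350X) devices every tuple with BLOCK_K=64 and matrix_instr_nonkdim=16 is dropped before a triton.Config is built; on other devices the list is unchanged. A self-contained illustration of the same predicate shape (the is_gfx95_plus flag stands in for the device probe, which is not queried here):

def should_skip(block_k, nonkdim, is_gfx95_plus):
    # Same shape as _should_skip_config, minus the hardware probe.
    return block_k == 64 and nonkdim == 16 and is_gfx95_plus

tuples = [
    (256, 256, 64, 2, 1, 2, 16, 1, 8, 2),    # BLOCK_K=64, nonkdim=16: skipped on gfx950
    (256, 256, 128, 2, 1, 0, 16, 1, 8, 2),   # BLOCK_K=128: kept everywhere
]
# t[2] is BLOCK_K and t[6] is matrix_instr_nonkdim per the tuple layout comment.
kept = [t for t in tuples if not should_skip(t[2], t[6], True)]
assert len(kept) == 1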
