From 3e9fe77f12c064607496d2b08c97082c99136f26 Mon Sep 17 00:00:00 2001 From: l1cacheDell Date: Tue, 11 Feb 2025 23:03:36 +0800 Subject: [PATCH 01/18] add sage attn sm90 kernels --- csrc/gpu/sage_attn_kernels/sageattn.cc | 540 ++++ csrc/gpu/sage_attn_kernels/sageattn_fused.cu | 953 ++++++ .../sageattn_qk_int_sv_f16_kernel.cu | 1690 +++++++++++ .../sageattn_qk_int_sv_f8_kernel.cu | 1239 ++++++++ .../sageattn_qk_int_sv_f8_kernel_sm90.cu | 878 ++++++ csrc/gpu/sage_attn_kernels/sageattn_utils.cuh | 2671 +++++++++++++++++ csrc/setup_cuda.py | 12 +- 7 files changed, 7982 insertions(+), 1 deletion(-) create mode 100644 csrc/gpu/sage_attn_kernels/sageattn.cc create mode 100644 csrc/gpu/sage_attn_kernels/sageattn_fused.cu create mode 100644 csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f16_kernel.cu create mode 100644 csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_kernel.cu create mode 100644 csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_kernel_sm90.cu create mode 100644 csrc/gpu/sage_attn_kernels/sageattn_utils.cuh diff --git a/csrc/gpu/sage_attn_kernels/sageattn.cc b/csrc/gpu/sage_attn_kernels/sageattn.cc new file mode 100644 index 000000000000..10af0fbf5dfa --- /dev/null +++ b/csrc/gpu/sage_attn_kernels/sageattn.cc @@ -0,0 +1,540 @@ +#include "paddle/extension.h" + + +// +// ============== fp16 kernels registry, for sm80 arch ============== +// +// impl: sageattn_qk_int_sv_f16_kernel.cu +// attn buffer kernel +// std::vector qk_int8_sv_f16_accum_f16_attn_buf_fwd( +// paddle::Tensor& query, +// paddle::Tensor& key, +// paddle::Tensor& value, +// paddle::Tensor& output, +// paddle::Tensor& query_scale, +// paddle::Tensor& key_scale, +// int tensor_layout, +// int is_causal, +// int qk_quant_gran, +// float sm_scale, +// int return_lse); + +// std::vector> qk_int8_sv_f16_accum_f16_attn_buf_InferShape( +// std::vector query_shape, +// std::vector key_shape, +// std::vector value_shape, +// std::vector output_shape, +// std::vector query_scale_shape, +// std::vector key_scale_shape) { + +// // force layout: NHD: [bsz, seq_len, num_heads, head_dim] +// int64_t bsz = query_shape[0]; +// int64_t seq_len = query_shape[1]; +// int64_t h_qo = query_shape[2]; + +// std::vector return_shape = {bsz, h_qo, seq_len}; +// return {return_shape}; +// } + +// std::vector qk_int8_sv_f16_accum_f16_attn_buf_InferDtype( +// paddle::DataType A_dtype, +// paddle::DataType B_dtype, +// paddle::DataType C_dtype, +// paddle::DataType D_dtype, +// paddle::DataType E_dtype, +// paddle::DataType F_dtype) { +// return {paddle::DataType::FLOAT32}; +// } + +// PD_BUILD_OP(qk_int8_sv_f16_accum_f16_attn_buf) +// .Inputs({"query", "key", "value", "output", "query_scale", "key_scale"}) +// .Outputs({"out1", "out2", "out3", "out4", "out5", "out6", "lse"}) +// .SetInplaceMap({{"query", "out1"}, {"key", "out2"}, {"value", "out3"}, {"output", "out4"}, {"query_scale", "out5"}, {"key_scale", "out6"}}) // Inplace +// .Attrs({"tensor_layout: int", +// "is_causal: int", +// "qk_quant_gran: int", +// "sm_scale: float", +// "return_lse: int"}) +// .SetKernelFn(PD_KERNEL(qk_int8_sv_f16_accum_f16_attn_buf_fwd)) +// .SetInferShapeFn(PD_INFER_SHAPE(qk_int8_sv_f16_accum_f16_attn_buf_InferShape)) +// .SetInferDtypeFn(PD_INFER_DTYPE(qk_int8_sv_f16_accum_f16_attn_buf_InferDtype)); + + +// attn forward kernel: sv f16 accumulator f32 +std::vector qk_int8_sv_f16_accum_f32_attn_fwd( + paddle::Tensor& query, + paddle::Tensor& key, + paddle::Tensor& value, + paddle::Tensor& output, + paddle::Tensor& query_scale, + paddle::Tensor& key_scale, + int 
tensor_layout, + int is_causal, + int qk_quant_gran, + float sm_scale, + int return_lse); + + +std::vector> qk_int8_sv_f16_accum_f32_attn_InferShape( + std::vector query_shape, + std::vector key_shape, + std::vector value_shape, + std::vector output_shape, + std::vector query_scale_shape, + std::vector key_scale_shape) { + + // force layout: NHD: [bsz, seq_len, num_heads, head_dim] + int64_t bsz = query_shape[0]; + int64_t seq_len = query_shape[1]; + int64_t h_qo = query_shape[2]; + + std::vector return_shape = {bsz, h_qo, seq_len}; + return {return_shape}; +} + +std::vector qk_int8_sv_f16_accum_f32_attn_InferDtype( + paddle::DataType A_dtype, + paddle::DataType B_dtype, + paddle::DataType C_dtype, + paddle::DataType D_dtype, + paddle::DataType E_dtype, + paddle::DataType F_dtype) { + return {paddle::DataType::FLOAT32}; +} + +PD_BUILD_OP(qk_int8_sv_f16_accum_f32_attn) + .Inputs({"query", "key", "value", "output", "query_scale", "key_scale"}) + .Outputs({"out1", "out2", "out3", "out4", "out5", "out6", "lse"}) + .SetInplaceMap({{"query", "out1"}, {"key", "out2"}, {"value", "out3"}, {"output", "out4"}, {"query_scale", "out5"}, {"key_scale", "out6"}}) // Inplace + .Attrs({"tensor_layout: int", + "is_causal: int", + "qk_quant_gran: int", + "sm_scale: float", + "return_lse: int"}) + .SetKernelFn(PD_KERNEL(qk_int8_sv_f16_accum_f32_attn_fwd)) + .SetInferShapeFn(PD_INFER_SHAPE(qk_int8_sv_f16_accum_f32_attn_InferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(qk_int8_sv_f16_accum_f32_attn_InferDtype)); + +// +// ============== fp8 kernels registry, for sm89 arch ============== +// + +std::vector qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn_fwd( + paddle::Tensor& query, + paddle::Tensor& key, + paddle::Tensor& value, + paddle::Tensor& output, + paddle::Tensor& query_scale, + paddle::Tensor& key_scale, + paddle::Tensor& value_scale, + paddle::Tensor& value_mean, + int tensor_layout, + int is_causal, + int qk_quant_gran, + float sm_scale, + int return_lse); + +std::vector> qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn_InferShape( + std::vector query_shape, + std::vector key_shape, + std::vector value_shape, + std::vector output_shape, + std::vector query_scale_shape, + std::vector key_scale_shape, + std::vector value_scale_shape, + std::vector value_mean_shape) { + + // force layout: NHD: [bsz, seq_len, num_heads, head_dim] + int64_t bsz = query_shape[0]; + int64_t seq_len = query_shape[1]; + int64_t h_qo = query_shape[2]; + + std::vector return_shape = {bsz, h_qo, seq_len}; + return {return_shape}; +} + +std::vector qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn_InferDtype( + paddle::DataType A_dtype, + paddle::DataType B_dtype, + paddle::DataType C_dtype, + paddle::DataType D_dtype, + paddle::DataType E_dtype, + paddle::DataType F_dtype, + paddle::DataType G_dtype, + paddle::DataType H_dtype) { + return {paddle::DataType::FLOAT32}; +} + +PD_BUILD_OP(qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn) + .Inputs({"query", "key", "value", "output", "query_scale", "key_scale", "value_scale", "value_mean"}) + .Outputs({"out1", "out2", "out3", "out4", "out5", "out6", "out7", "out8", "lse"}) + .SetInplaceMap({{"query", "out1"}, {"key", "out2"}, {"value", "out3"}, {"output", "out4"}, {"query_scale", "out5"}, {"key_scale", "out6"}, {"value_scale", "out7"}, {"value_mean", "out8"}}) // Inplace + .Attrs({"tensor_layout: int", + "is_causal: int", + "qk_quant_gran: int", + "sm_scale: float", + "return_lse: int"}) + 
.SetKernelFn(PD_KERNEL(qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn_fwd)) + .SetInferShapeFn(PD_INFER_SHAPE(qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn_InferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn_InferDtype)); + +std::vector qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_fwd( + paddle::Tensor& query, + paddle::Tensor& key, + paddle::Tensor& value, + paddle::Tensor& output, + paddle::Tensor& query_scale, + paddle::Tensor& key_scale, + paddle::Tensor& value_scale, + int tensor_layout, + int is_causal, + int qk_quant_gran, + float sm_scale, + int return_lse); + +std::vector> qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_InferShape( + std::vector query_shape, + std::vector key_shape, + std::vector value_shape, + std::vector output_shape, + std::vector query_scale_shape, + std::vector key_scale_shape, + std::vector value_scale_shape) { + + // force layout: NHD: [bsz, seq_len, num_heads, head_dim] + int64_t bsz = query_shape[0]; + int64_t seq_len = query_shape[1]; + int64_t h_qo = query_shape[2]; + + std::vector return_shape = {bsz, h_qo, seq_len}; + return {return_shape}; +} + +std::vector qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_InferDtype( + paddle::DataType A_dtype, + paddle::DataType B_dtype, + paddle::DataType C_dtype, + paddle::DataType D_dtype, + paddle::DataType E_dtype, + paddle::DataType F_dtype, + paddle::DataType G_dtype) { + return {paddle::DataType::FLOAT32}; +} + +PD_BUILD_OP(qk_int8_sv_f8_accum_f32_fuse_v_scale_attn) + .Inputs({"query", "key", "value", "output", "query_scale", "key_scale", "value_scale"}) + .Outputs({"out1", "out2", "out3", "out4", "out5", "out6", "out7", "lse"}) + .SetInplaceMap({{"query", "out1"}, {"key", "out2"}, {"value", "out3"}, {"output", "out4"}, {"query_scale", "out5"}, {"key_scale", "out6"}, {"value_scale", "out7"}}) // Inplace + .Attrs({"tensor_layout: int", + "is_causal: int", + "qk_quant_gran: int", + "sm_scale: float", + "return_lse: int"}) + .SetKernelFn(PD_KERNEL(qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_fwd)) + .SetInferShapeFn(PD_INFER_SHAPE(qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_InferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_InferDtype)); + + +std::vector qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm89_fwd( + paddle::Tensor& query, + paddle::Tensor& key, + paddle::Tensor& value, + paddle::Tensor& output, + paddle::Tensor& query_scale, + paddle::Tensor& key_scale, + paddle::Tensor& value_scale, + int tensor_layout, + int is_causal, + int qk_quant_gran, + float sm_scale, + int return_lse); + +std::vector> qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm89_InferShape( + std::vector query_shape, + std::vector key_shape, + std::vector value_shape, + std::vector output_shape, + std::vector query_scale_shape, + std::vector key_scale_shape, + std::vector value_scale_shape) { + + // force layout: NHD: [bsz, seq_len, num_heads, head_dim] + int64_t bsz = query_shape[0]; + int64_t seq_len = query_shape[1]; + int64_t h_qo = query_shape[2]; + + std::vector return_shape = {bsz, h_qo, seq_len}; + return {return_shape}; +} + +std::vector qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm89_InferDtype( + paddle::DataType A_dtype, + paddle::DataType B_dtype, + paddle::DataType C_dtype, + paddle::DataType D_dtype, + paddle::DataType E_dtype, + paddle::DataType F_dtype, + paddle::DataType G_dtype) { + return {paddle::DataType::FLOAT32}; +} + +PD_BUILD_OP(qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm89) + .Inputs({"query", "key", 
"value", "output", "query_scale", "key_scale", "value_scale"}) + .Outputs({"out1", "out2", "out3", "out4", "out5", "out6", "out7", "lse"}) + .SetInplaceMap({{"query", "out1"}, {"key", "out2"}, {"value", "out3"}, {"output", "out4"}, {"query_scale", "out5"}, {"key_scale", "out6"}, {"value_scale", "out7"}}) // Inplace + .Attrs({"tensor_layout: int", + "is_causal: int", + "qk_quant_gran: int", + "sm_scale: float", + "return_lse: int"}) + .SetKernelFn(PD_KERNEL(qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm89_fwd)) + .SetInferShapeFn(PD_INFER_SHAPE(qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm89_InferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm89_InferDtype)); + +// +// ============== fp8 kernels registry, for sm90 arch ============== +// + +std::vector qk_int8_sv_f8_accum_f32_attn_inst_buf_sm90_fwd( + paddle::Tensor& query, + paddle::Tensor& key, + paddle::Tensor& value, + paddle::Tensor& output, + paddle::Tensor& query_scale, + paddle::Tensor& key_scale, + int tensor_layout, + int is_causal, + int qk_quant_gran, + float sm_scale, + int return_lse); + +std::vector> qk_int8_sv_f8_accum_f32_attn_inst_buf_sm90_InferShape( + std::vector query_shape, + std::vector key_shape, + std::vector value_shape, + std::vector output_shape, + std::vector query_scale_shape, + std::vector key_scale_shape) { + + // force layout: NHD: [bsz, seq_len, num_heads, head_dim] + int64_t bsz = query_shape[0]; + int64_t seq_len = query_shape[1]; + int64_t h_qo = query_shape[2]; + + std::vector return_shape = {bsz, h_qo, seq_len}; + return {return_shape}; +} + +std::vector qk_int8_sv_f8_accum_f32_attn_inst_buf_sm90_InferDtype( + paddle::DataType A_dtype, + paddle::DataType B_dtype, + paddle::DataType C_dtype, + paddle::DataType D_dtype, + paddle::DataType E_dtype, + paddle::DataType F_dtype) { + return {paddle::DataType::FLOAT32}; +} + +PD_BUILD_OP(qk_int8_sv_f8_accum_f32_attn_inst_buf_sm90) + .Inputs({"query", "key", "value", "output", "query_scale", "key_scale"}) + .Outputs({"out1", "out2", "out3", "out4", "out5", "out6", "lse"}) + .SetInplaceMap({{"query", "out1"}, {"key", "out2"}, {"value", "out3"}, {"output", "out4"}, {"query_scale", "out5"}, {"key_scale", "out6"}}) // Inplace + .Attrs({"tensor_layout: int", + "is_causal: int", + "qk_quant_gran: int", + "sm_scale: float", + "return_lse: int"}) + .SetKernelFn(PD_KERNEL(qk_int8_sv_f8_accum_f32_attn_inst_buf_sm90_fwd)) + .SetInferShapeFn(PD_INFER_SHAPE(qk_int8_sv_f8_accum_f32_attn_inst_buf_sm90_InferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(qk_int8_sv_f8_accum_f32_attn_inst_buf_sm90_InferDtype)); + + +std::vector qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90_fwd( + paddle::Tensor& query, + paddle::Tensor& key, + paddle::Tensor& value, + paddle::Tensor& output, + paddle::Tensor& query_scale, + paddle::Tensor& key_scale, + paddle::Tensor& value_scale, + int tensor_layout, + int is_causal, + int qk_quant_gran, + float sm_scale, + int return_lse); + +std::vector> qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90_InferShape( + std::vector query_shape, + std::vector key_shape, + std::vector value_shape, + std::vector output_shape, + std::vector query_scale_shape, + std::vector key_scale_shape, + std::vector value_scale_shape) { + + // force layout: NHD: [bsz, seq_len, num_heads, head_dim] + int64_t bsz = query_shape[0]; + int64_t seq_len = query_shape[1]; + int64_t h_qo = query_shape[2]; + + std::vector return_shape = {bsz, h_qo, seq_len}; + return {return_shape}; +} + +std::vector 
qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90_InferDtype( + paddle::DataType A_dtype, + paddle::DataType B_dtype, + paddle::DataType C_dtype, + paddle::DataType D_dtype, + paddle::DataType E_dtype, + paddle::DataType F_dtype, + paddle::DataType G_dtype) { + return {paddle::DataType::FLOAT32}; +} + +PD_BUILD_OP(qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90) + .Inputs({"query", "key", "value", "output", "query_scale", "key_scale", "value_scale"}) + .Outputs({"out1", "out2", "out3", "out4", "out5", "out6", "out7", "lse"}) + .SetInplaceMap({{"query", "out1"}, {"key", "out2"}, {"value", "out3"}, {"output", "out4"}, {"query_scale", "out5"}, {"key_scale", "out6"}, {"value_scale", "out7"}}) // Inplace + .Attrs({"tensor_layout: int", + "is_causal: int", + "qk_quant_gran: int", + "sm_scale: float", + "return_lse: int"}) + .SetKernelFn(PD_KERNEL(qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90_fwd)) + .SetInferShapeFn(PD_INFER_SHAPE(qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90_InferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90_InferDtype)); + + +// +// ============== fused kernels registry ============== +// + +void quant_per_block_int8_fuse_sub_mean_cuda_fwd( + paddle::Tensor& input, + paddle::Tensor& mean, + paddle::Tensor& output, + paddle::Tensor& scale, + int block_size, + int tensor_layout); + +// quant_per_block_int8_fuse_sub_mean_cuda_fwd does not have any return +// so we don't implement infer type & shape function here. + +PD_BUILD_OP(quant_per_block_int8_fuse_sub_mean_cuda) + .Inputs({"input", "mean", "output", "scale"}) + .Outputs({"out1", "out2", "out3", "out4"}) + .SetInplaceMap({{"input", "out1"}, {"mean", "out2"}, {"output", "out3"}, {"scale", "out4"}}) // Inplace + .Attrs({"block_size: int", "tensor_layout: int"}) + .SetKernelFn(PD_KERNEL(quant_per_block_int8_fuse_sub_mean_cuda_fwd)); + + +void quant_per_warp_int8_cuda_fwd( + paddle::Tensor& input, + paddle::Tensor& output, + paddle::Tensor& scale, + int block_size, + int warp_block_size, + int tensor_layout); + +// quant_per_warp_int8_cuda_fwd does not have any return +// so we don't implement infer type & shape function here. + +PD_BUILD_OP(quant_per_warp_int8_cuda) + .Inputs({"input", "output", "scale"}) + .Outputs({"out1", "out2", "out3"}) + .SetInplaceMap({{"input", "out1"}, {"output", "out2"}, {"scale", "out3"}}) // Inplace + .Attrs({"block_size: int", "warp_block_size: int", "tensor_layout: int"}) + .SetKernelFn(PD_KERNEL(quant_per_warp_int8_cuda_fwd)); + + +void quant_per_block_int8_cuda_scale_fwd( + paddle::Tensor& input, + paddle::Tensor& output, + paddle::Tensor& scale, + float sm_scale, + int block_size, + int tensor_layout); + +// quant_per_block_int8_cuda_scale does not have any return +// so we don't implement infer type & shape function here. + +PD_BUILD_OP(quant_per_block_int8_cuda_scale) + .Inputs({"input", "output", "scale"}) + .Outputs({"out1", "out2", "out3"}) + .SetInplaceMap({{"input", "out1"}, {"output", "out2"}, {"scale", "out3"}}) // Inplace + .Attrs({"sm_scale: float", "block_size: int", "tensor_layout: int"}) + .SetKernelFn(PD_KERNEL(quant_per_block_int8_cuda_scale_fwd)); + + +void quant_per_block_int8_cuda_fwd( + paddle::Tensor& input, + paddle::Tensor& output, + paddle::Tensor& scale, + int block_size, + int tensor_layout); + +// quant_per_block_int8_cuda does not have any return +// so we don't implement infer type & shape function here. 
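+// For reference, a sketch of the shapes quant_per_block_int8_cuda operates on when
+// tensor_layout == 0 (NHD), inferred from the checks in sageattn_fused.cu
+// (illustrative comment only, not an additional runtime check):
+//   input  : [bsz, seq_len, num_heads, head_dim], fp16 or bf16
+//   output : [bsz, seq_len, num_heads, head_dim], int8, written in place
+//   scale  : [bsz, num_heads, ceil(seq_len / block_size)], fp32, one scale per block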
+ +PD_BUILD_OP(quant_per_block_int8_cuda) + .Inputs({"input", "output", "scale"}) + .Outputs({"out1", "out2", "out3"}) + .SetInplaceMap({{"input", "out1"}, {"output", "out2"}, {"scale", "out3"}}) // Inplace + .Attrs({"sm_scale: float", "block_size: int", "tensor_layout: int"}) + .SetKernelFn(PD_KERNEL(quant_per_block_int8_cuda_fwd)); + + +void transpose_pad_permute_cuda_fwd( + paddle::Tensor& input, + paddle::Tensor& output, + int tensor_layout); + +// transpose_pad_permute_cuda_fwd does not have any return +// so we don't implement infer type & shape function here. + +PD_BUILD_OP(transpose_pad_permute_cuda) + .Inputs({"input", "output"}) + .Outputs({"out1", "out2"}) + .SetInplaceMap({{"input", "out1"}, {"output", "out2"}}) // Inplace + .Attrs({"tensor_layout: int"}) + .SetKernelFn(PD_KERNEL(transpose_pad_permute_cuda_fwd)); + + +void scale_fuse_quant_cuda_fwd( + paddle::Tensor& input, + paddle::Tensor& output, + paddle::Tensor& scale, + int num_tokens, + float scale_max, + int tensor_layout); + +// scale_fuse_quant_cuda_fwd does not have any return +// so we don't implement infer type & shape function here. + +PD_BUILD_OP(scale_fuse_quant_cuda) + .Inputs({"input", "output", "scale"}) + .Outputs({"out1", "out2", "out3"}) + .SetInplaceMap({{"input", "out1"}, {"output", "out2"}, {"scale", "out3"}}) // Inplace + .Attrs({"num_tokens: int", "scale_max: float", "tensor_layout: int"}) + .SetKernelFn(PD_KERNEL(scale_fuse_quant_cuda_fwd)); + + +void mean_scale_fuse_quant_cuda_fwd( + paddle::Tensor& input, + paddle::Tensor& output, + paddle::Tensor& mean, + paddle::Tensor& scale, + int num_tokens, + float scale_max, + int tensor_layout); + +// mean_scale_fuse_quant_cuda_fwd does not have any return +// so we don't implement infer type & shape function here. + +PD_BUILD_OP(mean_scale_fuse_quant_cuda) + .Inputs({"input", "output", "mean", "scale"}) + .Outputs({"out1", "out2", "out3", "out4"}) + .SetInplaceMap({{"input", "out1"}, {"output", "out2"}, {"mean", "out3"}, {"scale", "out4"}}) // Inplace + .Attrs({"num_tokens: int", "scale_max: float", "tensor_layout: int"}) + .SetKernelFn(PD_KERNEL(mean_scale_fuse_quant_cuda_fwd)); \ No newline at end of file diff --git a/csrc/gpu/sage_attn_kernels/sageattn_fused.cu b/csrc/gpu/sage_attn_kernels/sageattn_fused.cu new file mode 100644 index 000000000000..501e0eb89641 --- /dev/null +++ b/csrc/gpu/sage_attn_kernels/sageattn_fused.cu @@ -0,0 +1,953 @@ +#include +#include +#include + +#include "sageattn_utils.cuh" +#include "paddle/extension.h" + +enum class QuantType +{ + kInt8, + kInt4, +}; + +template +__device__ __forceinline__ float convert_to_float(T val) +{ + static_assert(std::is_same::value || std::is_same::value, "Only half and bfloat16 are supported"); + + if constexpr (std::is_same::value) + { + return __half2float(val); + } + else if constexpr (std::is_same::value) + { + return __bfloat162float(val); + } +} + +template +__device__ __forceinline__ T convert_from_float(float val) +{ + static_assert(std::is_same::value || std::is_same::value, "Only half and bfloat16 are supported"); + + if constexpr (std::is_same::value) + { + return __float2half_rn(val); + } + else if constexpr (std::is_same::value) + { + return __float2bfloat16_rn(val); + } +} + + +// +// =========== kernel impl zone =========== +// +// Notice: the address in paddle is aligned in 4 bytes, not 16 bytes +// so theoretically the `float4(xx) = xx` way cannot load data from global memory +// peacefully, instead it will trigger CUDA 719: address misaligned fault. 
+// Thus, we will use pragma unroll macro and load data in a for loop, which introduces +// extra latency, but works in paddle. +template +__global__ void QuantInt8Kernel(T *__restrict__ input, T *__restrict__ mean, int8_t *__restrict__ output, float *__restrict__ scale, float sm_scale, const uint32_t num_tokens, + const uint32_t stride_bz_input, const uint32_t stride_seq_input, const uint32_t stride_h_input, + const uint32_t stride_bz_mean, const uint32_t stride_h_mean, + const uint32_t stride_bz_output, const uint32_t stride_seq_output, const uint32_t stride_h_output, + const uint32_t stride_bz_scale, const uint32_t stride_h_scale) +{ + static_assert(std::is_same::value || std::is_same::value, "Only half and bfloat16 are supported"); + static_assert(num_pack_per_thread > 0, "The number of pack per thread must be greater than 0"); + + constexpr uint32_t pack_size = 8; // float4 contains 8 half or 8 bfloat16 + constexpr uint32_t num_threads_per_token = head_dim / pack_size; + + static_assert(num_threads_per_token <= 32, "The number of threads per token must be less than or equal to warp size"); + + T x_val[num_pack_per_thread][8]; + T mean_val[8]; + float x_val_float[num_pack_per_thread][8]; + float mean_val_float[8]; + + uint32_t bx = blockIdx.x; + uint32_t head_id = blockIdx.y; + uint32_t batch_id = blockIdx.z; + uint32_t thread_id = threadIdx.x; + + uint32_t thread_base_token = bx * BLOCK_SIZE + thread_id / num_threads_per_token; + T *input_ptr_base = input + batch_id * stride_bz_input + head_id * stride_h_input + thread_base_token * stride_seq_input + thread_id % num_threads_per_token * pack_size; + T *mean_ptr_base = mean + batch_id * stride_bz_mean + head_id * stride_h_mean + thread_id % num_threads_per_token * pack_size; + int8_t *output_ptr_base = output + batch_id * stride_bz_output + head_id * stride_h_output + thread_base_token * stride_seq_output + thread_id % num_threads_per_token * pack_size; + float *scale_ptr_base = scale + batch_id * stride_bz_scale + head_id * stride_h_scale + bx; + + if constexpr (sub_mean) + { + // *(float4*)(&mean_val[0]) = *(float4*)(mean_ptr_base); + // for unable-align reasons, we unroll it manually. +#pragma unroll + for (int ii = 0; ii < 8; ii++) { + mean_val[ii] = mean_ptr_base[ii]; + } + +#pragma unroll + for (uint32_t j = 0; j < 8; j++) + { + mean_val_float[j] = convert_to_float(mean_val[j]); + } + } + + constexpr uint32_t iter_stride = BLOCK_SIZE / num_pack_per_thread; + + // load the data + for (uint32_t i = 0; i < num_pack_per_thread; i++) + { + if (thread_base_token + i * iter_stride < num_tokens) + { + // *(float4*)(&x_val[i][0]) = *(float4*)(input_ptr_base + i * iter_stride * stride_seq_input); + // for unable-align reasons, we unroll it manually. 
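+ // Note on the workaround (descriptive comment, no functional change): the
+ // vectorized form kept in the comment above would issue a single 16-byte load,
+ // which requires the pointer to be 16-byte aligned; since (as noted at the top
+ // of this file) Paddle only guarantees 4-byte alignment here, the #pragma unroll
+ // loop below falls back to eight element-wise loads instead.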
+#pragma unroll + for (int ii = 0; ii < 8; ii++) { + x_val[i][ii] = *(input_ptr_base + i * iter_stride * stride_seq_input + ii); + } +#pragma unroll + for (uint32_t j = 0; j < 8; j++) + { + x_val_float[i][j] = convert_to_float(x_val[i][j]); + } + + if constexpr (sub_mean) + { +#pragma unroll + for (uint32_t j = 0; j < 8; j++) + { + x_val_float[i][j] -= mean_val_float[j]; + } + } + + if constexpr (has_sm_scale) + { +#pragma unroll + for (uint32_t j = 0; j < 8; j++) + { + x_val_float[i][j] *= sm_scale; + } + } + } + else + { +#pragma unroll + for (uint32_t j = 0; j < 8; j++) + { + x_val_float[i][j] = 0.0f; + } + } + } + + float amax_val = 0.0000001f; // prevent from dividing by zero + +#pragma unroll + for (uint32_t i = 0; i < num_pack_per_thread; i++) + { +#pragma unroll + for (uint32_t j = 0; j < 8; j++) + { + amax_val = fmaxf(amax_val, fabsf(x_val_float[i][j])); + } + } + + __shared__ float s_amax; + const float block_amax_val = sageattn::blockReduceMax(amax_val); + if (thread_id == 0) + { + s_amax = block_amax_val; + scale_ptr_base[0] = s_amax / 127.0f; + } + + __syncthreads(); + + float tmp_scale = 127.0f / s_amax; + + char4 o_val[num_pack_per_thread][2]; + +#pragma unroll + for (uint32_t i = 0; i < num_pack_per_thread; i++) + { +#pragma unroll + for (uint32_t j = 0; j < 2; j += 1) + { + o_val[i][j] = make_char4( + float_to_int8_rn(x_val_float[i][j * 4 + 0] * tmp_scale), + float_to_int8_rn(x_val_float[i][j * 4 + 1] * tmp_scale), + float_to_int8_rn(x_val_float[i][j * 4 + 2] * tmp_scale), + float_to_int8_rn(x_val_float[i][j * 4 + 3] * tmp_scale) + ); + } + } + + // int8 result +#pragma unroll + for (uint32_t i = 0; i < num_pack_per_thread; i++) + { + + if (thread_base_token + i * iter_stride < num_tokens) + { + *reinterpret_cast(output_ptr_base + i * iter_stride * stride_seq_output) = *reinterpret_cast(&o_val[i][0]); + } + } +} + +template +__global__ void TransposePadPermuteKernel(T *__restrict__ input, T *__restrict__ output, const uint32_t num_tokens, + const uint32_t stride_bz_input, const uint32_t stride_seq_input, const uint32_t stride_h_input, + const uint32_t stride_bz_output, const uint32_t stride_d_output, const uint32_t stride_h_output) +{ + + static_assert(std::is_same::value || std::is_same::value, "Only half and bfloat16 are supported"); + + constexpr uint32_t pack_size = 8; // float4 contains 8 half or 8 bfloat16 + uint32_t num_threads_per_token = head_dim / pack_size; + uint32_t num_threads_per_cta = CTA_SIZE / pack_size; + + uint32_t bx = blockIdx.x; + uint32_t head_id = blockIdx.y; + uint32_t batch_id = blockIdx.z; + uint32_t thread_id = threadIdx.x; + + uint32_t thread_base_token = bx * CTA_SIZE + thread_id / num_threads_per_token; + + T *input_ptr_base = input + batch_id * stride_bz_input + head_id * stride_h_input + thread_base_token * stride_seq_input + thread_id % num_threads_per_token * pack_size; + T* output_ptr_base = output + batch_id * stride_bz_output + head_id * stride_h_output + bx * CTA_SIZE + thread_id % num_threads_per_cta * pack_size + thread_id / num_threads_per_cta * stride_d_output; + + __shared__ T shared_load[CTA_SIZE][head_dim]; + __shared__ T shared_store[head_dim][CTA_SIZE]; + + // 0, 1, 4, 5, 8, 9, 12, 13, 2, 3, 6, 7, 10, 11, 14, 15 + // permute on the seq dimension for fp8 mma + uint32_t smem_load_row_base = ((thread_id / num_threads_per_token) / 16) * 16; + uint32_t smem_load_row_mod = (thread_id / num_threads_per_token) % 16; + uint32_t smem_load_row = smem_load_row_base + (smem_load_row_mod / 8) * 2 + ((smem_load_row_mod / 2) % 4) * 4 + 
(smem_load_row_mod % 2); + + constexpr cp_async::SharedMemFillMode fill_mode = pad_zero ? cp_async::SharedMemFillMode::kFillZero : cp_async::SharedMemFillMode::kNoFill; + cp_async::pred_load_128b(shared_load[smem_load_row] + thread_id % num_threads_per_token * pack_size, input_ptr_base, thread_base_token < num_tokens); + cp_async::commit_group(); + cp_async::wait_group<0>(); + __syncthreads(); + + uint32_t smem_row_base = thread_id % CTA_SIZE; + uint32_t smem_col_base = thread_id / CTA_SIZE; + uint32_t smem_col_stride = head_dim / 8; + + // TODO: use ldmatrix to do permutation +#pragma unroll + for (uint32_t i = 0; i < 8; i++) + { + shared_store[smem_col_base + i * smem_col_stride][smem_row_base] = shared_load[smem_row_base][smem_col_base + i * smem_col_stride]; + } + + __syncthreads(); + + // *(float4*)(output_ptr_base) = *(float4*)(&shared_store[thread_id / num_threads_per_cta][thread_id % num_threads_per_cta * pack_size]); + // for unable-align reasons, we unroll it manually. +#pragma unroll + for (int i = 0; i < 8; i++) { + *(output_ptr_base + i) = shared_store[thread_id / num_threads_per_cta][thread_id % num_threads_per_cta * pack_size + i]; // TODO: not debugged, maybe some problem + } +} + +template +__global__ void MeanScaleKernel(T *__restrict__ input, int8_t *__restrict__ output, float *__restrict__ mean, float *__restrict__ scale, const float scale_max, const uint32_t num_tokens, + const uint32_t stride_bz_input, const uint32_t stride_d_input, const uint32_t stride_h_input, + const uint32_t stride_bz_output, const uint32_t stride_d_output, const uint32_t stride_h_output, + const uint32_t stride_bz_mean, const uint32_t stride_h_mean, + const uint32_t stride_bz_scale, const uint32_t stride_h_scale) +{ + static_assert(std::is_same::value || std::is_same::value, "Only half and bfloat16 are supported"); + + constexpr uint32_t pack_size = 8; // float4 contains 8 half or 8 bfloat16 + + uint32_t head_id = blockIdx.x; + uint32_t batch_id = blockIdx.y; + uint32_t d_id = blockIdx.z; + uint32_t thread_id = threadIdx.x; + + uint32_t num_threads = blockDim.x; + uint32_t gmem_stride = num_threads * pack_size; + // pad the number of tokens to 16 to deal with fp8 permute in previous kernel + uint32_t fp8_padded_num_tokens = (num_tokens + 15) / 16 * 16; + uint32_t num_iters = fp8_padded_num_tokens / gmem_stride + ((fp8_padded_num_tokens % gmem_stride) > thread_id * pack_size); + + T *input_ptr_base = input + batch_id * stride_bz_input + head_id * stride_h_input + d_id * stride_d_input + thread_id * pack_size; + int8_t *output_ptr_base = output + batch_id * stride_bz_output + head_id * stride_h_output + d_id * stride_d_output + thread_id * pack_size; + + T x_val[8]; + float x_val_float[8]; + uint32_t x_val_fp8[2]; + + float max_val = - 1000000.0f; + float min_val = 1000000.0f; + float sum_val = 0.0f; + + for (int i = 0; i < num_iters; i++) + { + // *(float4*)(&x_val[0]) = *(float4*)(input_ptr_base + i * gmem_stride); +#pragma unroll + for (int ii = 0; ii < 8; ii++) { + x_val[ii] = *(input_ptr_base + i * gmem_stride + ii); // TODO: not debugged + } +#pragma unroll + for (uint32_t j = 0; j < 8; j++) + { + float x_temp = convert_to_float(x_val[j]); + max_val = fmaxf(max_val, x_temp); + min_val = fminf(min_val, x_temp); + + if constexpr (sub_mean) + { + sum_val += x_temp; + } + } + } + + // reduce + __shared__ float s_amax_val; + __shared__ float s_mean_val; + + float block_max_val = sageattn::blockReduceMax(max_val); + float block_min_val = sageattn::blockReduceMin(min_val); + float block_sum_val; + 
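+ // Per-channel fp8 quantization as computed below (descriptive comment; scale_max
+ // is supplied by the caller, typically the e4m3 representable maximum):
+ //   mean  = sum(x) / padded_len                      (only when sub_mean)
+ //   amax  = max(|max(x) - mean|, |min(x) - mean|)    (or max(|max|, |min|) without sub_mean)
+ //   scale = amax / scale_max
+ //   q(x)  = cast_to_e4m3((x - mean) * scale_max / amax)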
+ if constexpr (sub_mean) + { + block_sum_val = sageattn::blockReduceSum(sum_val); + } + + if (thread_id == 0) + { + s_mean_val = block_sum_val / fp8_padded_num_tokens; + + if constexpr (sub_mean) + { + s_amax_val = fmaxf(fabsf(block_max_val - s_mean_val), fabsf(block_min_val - s_mean_val)); + mean[batch_id * stride_bz_mean + head_id * stride_h_mean + d_id] = s_mean_val; + } + else + { + s_amax_val = fmaxf(fabsf(block_max_val), fabsf(block_min_val)); + } + + scale[batch_id * stride_bz_scale + head_id * stride_h_scale + d_id] = s_amax_val / scale_max; + } + + __syncthreads(); + + float mean_val = s_mean_val; + float recp_scale = scale_max / s_amax_val; + + // recalculate num_iters to cover all fp8 output tokens to prevent nan in random initialization + uint32_t padded_num_tokens = (num_tokens + pad_size - 1) / pad_size * pad_size; + num_iters = padded_num_tokens / gmem_stride + ((padded_num_tokens % gmem_stride) > thread_id * pack_size); + + for (int i = 0; i < num_iters; i++) + { + // *(float4*)(&x_val[0]) = *(float4*)(input_ptr_base + i * gmem_stride); +#pragma unroll + for (int ii = 0; ii < 8; ii++) { + x_val[ii] = *(input_ptr_base + i * gmem_stride + ii); // TODO: not debugged + } +#pragma unroll + for (uint32_t j = 0; j < 8; j++) + { + x_val_float[j] = convert_to_float(x_val[j]); + if constexpr (sub_mean) + { + x_val_float[j] = (x_val_float[j] - mean_val) * recp_scale; + } + else + { + x_val_float[j] *= recp_scale; + } + } + + floatx4_to_e4m3x4(x_val_fp8, x_val_float, x_val_float + 2); + floatx4_to_e4m3x4(x_val_fp8 + 1, x_val_float + 4, x_val_float + 6); + + *(uint2*)(output_ptr_base + i * gmem_stride) = *(uint2*)(&x_val_fp8[0]); + } +} + + +// +// =========== kernel API zone =========== +// +void quant_per_block_int8_fuse_sub_mean_cuda_fwd( + paddle::Tensor& input, + paddle::Tensor& mean, + paddle::Tensor& output, + paddle::Tensor& scale, + int block_size, + int tensor_layout) +{ + CHECK_CUDA(input); + CHECK_CUDA(mean); + CHECK_CUDA(output); + CHECK_CUDA(scale); + + CHECK_DTYPE(output, paddle::DataType::INT8); + CHECK_DTYPE(scale, paddle::DataType::FLOAT32); + + CHECK_LASTDIM_CONTIGUOUS(input); + CHECK_CONTIGUOUS(mean); + CHECK_CONTIGUOUS(output); + CHECK_CONTIGUOUS(scale); + + CHECK_DIMS(input, 4); + CHECK_DIMS(mean, 3); + CHECK_DIMS(output, 4); + CHECK_DIMS(scale, 3); + + const int batch_size = input.shape()[0]; + const int head_dim = input.shape()[3]; + + int stride_bz_input = input.strides()[0]; + int stride_bz_output = output.strides()[0]; + + int num_tokens, num_heads; + int stride_seq_input, stride_h_input, stride_seq_output, stride_h_output; + + if (tensor_layout == 0) + { + num_tokens = input.shape()[1]; + num_heads = input.shape()[2]; + stride_seq_input = input.strides()[1]; + stride_h_input = input.strides()[2]; + stride_seq_output = output.strides()[1]; + stride_h_output = output.strides()[2]; + } + else + { + num_tokens = input.shape()[2]; + num_heads = input.shape()[1]; + stride_seq_input = input.strides()[2]; + stride_h_input = input.strides()[1]; + stride_seq_output = output.strides()[2]; + stride_h_output = output.strides()[1]; + } + + auto input_dtype = input.dtype(); + auto mean_dtype = mean.dtype(); + + PD_CHECK(input_dtype == mean_dtype, "Input and mean must have the same data type"); + + DISPATCH_PADDLE_DTYPE_TO_CTYPE_FP16(input_dtype, c_type, { + DISPATCH_BLOCK_SIZE(block_size, BLOCK_SIZE, { + DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { + CHECK_SHAPE(mean, batch_size, num_heads, head_dim); + CHECK_SHAPE(output, input.shape()[0], input.shape()[1], input.shape()[2], 
input.shape()[3]); + CHECK_SHAPE(scale, batch_size, num_heads, (num_tokens + BLOCK_SIZE - 1) / BLOCK_SIZE); + + dim3 grid((num_tokens + BLOCK_SIZE - 1) / BLOCK_SIZE, num_heads, batch_size); + + constexpr int num_pack_per_thread = (BLOCK_SIZE * (HEAD_DIM / 8) + 1023) / 1024; + + dim3 block(BLOCK_SIZE * (HEAD_DIM / 8) / num_pack_per_thread); + + QuantInt8Kernel<<>>( + reinterpret_cast(input.data()), + reinterpret_cast(mean.data()), + output.data(), + reinterpret_cast(scale.data()), + 0.0f, + num_tokens, + stride_bz_input, stride_seq_input, stride_h_input, + mean.strides()[0], mean.strides()[1], + stride_bz_output, stride_seq_output, stride_h_output, + scale.strides()[0], scale.strides()[1] + ); + }); + }); + }); +} + +void quant_per_warp_int8_cuda_fwd( + paddle::Tensor& input, + paddle::Tensor& output, + paddle::Tensor& scale, + int block_size, + int warp_block_size, + int tensor_layout) +{ + CHECK_CUDA(input); + CHECK_CUDA(output); + CHECK_CUDA(scale); + + CHECK_DTYPE(output, paddle::DataType::INT8); + CHECK_DTYPE(scale, paddle::DataType::FLOAT32); + + CHECK_LASTDIM_CONTIGUOUS(input); + CHECK_CONTIGUOUS(output); + CHECK_CONTIGUOUS(scale); + + CHECK_DIMS(input, 4); + CHECK_DIMS(output, 4); + CHECK_DIMS(scale, 3); + + const int batch_size = input.shape()[0]; + const int head_dim = input.shape()[3]; + + int stride_bz_input = input.strides()[0]; + int stride_bz_output = output.strides()[0]; + + int num_tokens, num_heads; + int stride_seq_input, stride_h_input, stride_seq_output, stride_h_output; + + if (tensor_layout == 0) + { + num_tokens = input.shape()[1]; + num_heads = input.shape()[2]; + stride_seq_input = input.strides()[1]; + stride_h_input = input.strides()[2]; + stride_seq_output = output.strides()[1]; + stride_h_output = output.strides()[2]; + } + else + { + num_tokens = input.shape()[2]; + num_heads = input.shape()[1]; + stride_seq_input = input.strides()[2]; + stride_h_input = input.strides()[1]; + stride_seq_output = output.strides()[2]; + stride_h_output = output.strides()[1]; + } + + auto input_dtype = input.dtype(); + + DISPATCH_PADDLE_DTYPE_TO_CTYPE_FP16(input_dtype, c_type, { + DISPATCH_WARP_BLOCK_SIZE(warp_block_size, WARP_BLOCK_SIZE, { + DISPATCH_BLOCK_SIZE(block_size, BLOCK_SIZE, { + DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { + CHECK_SHAPE(output, input.shape()[0], input.shape()[1], input.shape()[2], input.shape()[3]); + CHECK_SHAPE(scale, batch_size, num_heads, (num_tokens + BLOCK_SIZE - 1) / BLOCK_SIZE * (BLOCK_SIZE / WARP_BLOCK_SIZE)); + dim3 grid((num_tokens + BLOCK_SIZE - 1) / BLOCK_SIZE * (BLOCK_SIZE / WARP_BLOCK_SIZE), num_heads, batch_size); + constexpr int num_pack_per_thread = (WARP_BLOCK_SIZE * (HEAD_DIM / 8) + 1023) / 1024; + dim3 block(WARP_BLOCK_SIZE * (HEAD_DIM / 8) / num_pack_per_thread); + + QuantInt8Kernel<<>>( + reinterpret_cast(input.data()), + nullptr, + output.data(), + reinterpret_cast(scale.data()), + 0.0, + num_tokens, + stride_bz_input, stride_seq_input, stride_h_input, + 0, 0, + stride_bz_output, stride_seq_output, stride_h_output, + scale.strides()[0], scale.strides()[1] + ); + }); + }); + }); + }); +} + +void quant_per_block_int8_cuda_scale_fwd( + paddle::Tensor& input, + paddle::Tensor& output, + paddle::Tensor& scale, + float sm_scale, + int block_size, + int tensor_layout) +{ + CHECK_CUDA(input); + CHECK_CUDA(output); + CHECK_CUDA(scale); + + CHECK_DTYPE(output, paddle::DataType::INT8); + CHECK_DTYPE(scale, paddle::DataType::FLOAT32); +
CHECK_LASTDIM_CONTIGUOUS(input); + CHECK_CONTIGUOUS(output); + CHECK_CONTIGUOUS(scale); + + CHECK_DIMS(input, 4); + CHECK_DIMS(output, 4); + CHECK_DIMS(scale, 3); + + const int batch_size = input.shape()[0]; + const int head_dim = input.shape()[3]; + + int stride_bz_input = input.strides()[0]; + int stride_bz_output = output.strides()[0]; + + int num_tokens, num_heads; + int stride_seq_input, stride_h_input, stride_seq_output, stride_h_output; + + if (tensor_layout == 0) + { + num_tokens = input.shape()[1]; + num_heads = input.shape()[2]; + stride_seq_input = input.strides()[1]; + stride_h_input = input.strides()[2]; + stride_seq_output = output.strides()[1]; + stride_h_output = output.strides()[2]; + } + else + { + num_tokens = input.shape()[2]; + num_heads = input.shape()[1]; + stride_seq_input = input.strides()[2]; + stride_h_input = input.strides()[1]; + stride_seq_output = output.strides()[2]; + stride_h_output = output.strides()[1]; + } + + auto input_dtype = input.dtype(); + + DISPATCH_PADDLE_DTYPE_TO_CTYPE_FP16(input_dtype, c_type, { + DISPATCH_BLOCK_SIZE(block_size, BLOCK_SIZE, { + DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { + + CHECK_SHAPE(output, input.shape()[0], input.shape()[1], input.shape()[2], input.shape()[3]); + CHECK_SHAPE(scale, batch_size, num_heads, (num_tokens + BLOCK_SIZE - 1) / BLOCK_SIZE); + + dim3 grid((num_tokens + BLOCK_SIZE - 1) / BLOCK_SIZE, num_heads, batch_size); + + constexpr int num_pack_per_thread = (BLOCK_SIZE * (HEAD_DIM / 8) + 1023) / 1024; + + dim3 block(BLOCK_SIZE * (HEAD_DIM / 8) / num_pack_per_thread); + + QuantInt8Kernel<<>>( + reinterpret_cast(input.data()), + nullptr, + output.data(), + reinterpret_cast(scale.data()), + sm_scale, + num_tokens, + stride_bz_input, stride_seq_input, stride_h_input, + 0, 0, + stride_bz_output, stride_seq_output, stride_h_output, + scale.strides()[0], scale.strides()[1] + ); + }); + }); + }); +} + +void quant_per_block_int8_cuda_fwd( + paddle::Tensor& input, + paddle::Tensor& output, + paddle::Tensor& scale, + int block_size, + int tensor_layout) +{ + CHECK_CUDA(input); + CHECK_CUDA(output); + CHECK_CUDA(scale); + + CHECK_DTYPE(output, paddle::DataType::INT8); + CHECK_DTYPE(scale, paddle::DataType::FLOAT32); + + CHECK_LASTDIM_CONTIGUOUS(input); + CHECK_CONTIGUOUS(output); + CHECK_CONTIGUOUS(scale); + + CHECK_DIMS(input, 4); + CHECK_DIMS(output, 4); + CHECK_DIMS(scale, 3); + + const int batch_size = input.shape()[0]; + const int head_dim = input.shape()[3]; + + int stride_bz_input = input.strides()[0]; + int stride_bz_output = output.strides()[0]; + + int num_tokens, num_heads; + int stride_seq_input, stride_h_input, stride_seq_output, stride_h_output; + + if (tensor_layout == 0) + { + num_tokens = input.shape()[1]; + num_heads = input.shape()[2]; + stride_seq_input = input.strides()[1]; + stride_h_input = input.strides()[2]; + stride_seq_output = output.strides()[1]; + stride_h_output = output.strides()[2]; + } + else + { + num_tokens = input.shape()[2]; + num_heads = input.shape()[1]; + stride_seq_input = input.strides()[2]; + stride_h_input = input.strides()[1]; + stride_seq_output = output.strides()[2]; + stride_h_output = output.strides()[1]; + } + + auto input_dtype = input.dtype(); + + DISPATCH_PADDLE_DTYPE_TO_CTYPE_FP16(input_dtype, c_type, { + DISPATCH_BLOCK_SIZE(block_size, BLOCK_SIZE, { + DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { + + CHECK_SHAPE(output, input.shape()[0], input.shape()[1], input.shape()[2], input.shape()[3]); + CHECK_SHAPE(scale, batch_size, num_heads, (num_tokens + BLOCK_SIZE - 1) / 
BLOCK_SIZE); + + dim3 grid((num_tokens + BLOCK_SIZE - 1) / BLOCK_SIZE, num_heads, batch_size); + + constexpr int num_pack_per_thread = (BLOCK_SIZE * (HEAD_DIM / 8) + 1023) / 1024; + + dim3 block(BLOCK_SIZE * (HEAD_DIM / 8) / num_pack_per_thread); + + QuantInt8Kernel<<>>( + reinterpret_cast(input.data()), + nullptr, + output.data(), + reinterpret_cast(scale.data()), + 0.0f, + num_tokens, + stride_bz_input, stride_seq_input, stride_h_input, + 0, 0, + stride_bz_output, stride_seq_output, stride_h_output, + scale.strides()[0], scale.strides()[1] + ); + }); + }); + }); +} + +void transpose_pad_permute_cuda_fwd( + paddle::Tensor& input, + paddle::Tensor& output, + int tensor_layout) +{ + CHECK_CUDA(input); + CHECK_CUDA(output); + + CHECK_LASTDIM_CONTIGUOUS(input); + CHECK_CONTIGUOUS(output); + + CHECK_DIMS(input, 4); + CHECK_DIMS(output, 4); + + constexpr int CTA_SIZE = 64; + + const int batch_size = input.shape()[0]; + const int head_dim = input.shape()[3]; + + int stride_bz_input = input.strides()[0]; + int stride_bz_output = output.strides()[0]; + + int num_tokens, padded_num_tokens, num_heads; + int stride_seq_input, stride_h_input, stride_d_output, stride_h_output; + + if (tensor_layout == 0) + { + num_tokens = input.shape()[1]; + num_heads = input.shape()[2]; + stride_seq_input = input.strides()[1]; + stride_h_input = input.strides()[2]; + stride_d_output = output.strides()[1]; + stride_h_output = output.strides()[2]; + + padded_num_tokens = (num_tokens + CTA_SIZE - 1) / CTA_SIZE * CTA_SIZE; + + CHECK_SHAPE(output, batch_size, head_dim, num_heads, padded_num_tokens); + } + else + { + num_tokens = input.shape()[2]; + num_heads = input.shape()[1]; + stride_seq_input = input.strides()[2]; + stride_h_input = input.strides()[1]; + stride_d_output = output.strides()[2]; + stride_h_output = output.strides()[1]; + + padded_num_tokens = (num_tokens + CTA_SIZE - 1) / CTA_SIZE * CTA_SIZE; + CHECK_SHAPE(output, batch_size, num_heads, head_dim, padded_num_tokens); + } + + auto input_dtype = input.dtype(); + auto output_dtype = output.dtype(); + + PD_CHECK(input_dtype == output_dtype, "Input and output must have the same data type"); + + DISPATCH_PADDLE_DTYPE_TO_CTYPE_FP16(input_dtype, c_type, { + DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { + dim3 grid(padded_num_tokens / CTA_SIZE, num_heads, batch_size); + + static_assert(CTA_SIZE * HEAD_DIM <= 8192); + + dim3 block(CTA_SIZE * (HEAD_DIM / 8)); + + TransposePadPermuteKernel<<>>( + reinterpret_cast(input.data()), + reinterpret_cast(output.data()), + num_tokens, + stride_bz_input, stride_seq_input, stride_h_input, + stride_bz_output, stride_d_output, stride_h_output + ); + }); + }); +} + +void scale_fuse_quant_cuda_fwd( + paddle::Tensor& input, + paddle::Tensor& output, + paddle::Tensor& scale, + int num_tokens, + float scale_max, + int tensor_layout) +{ + CHECK_CUDA(input); + CHECK_CUDA(output); + CHECK_CUDA(scale); + + // CHECK_DTYPE(output, torch::kInt8); + CHECK_DTYPE(scale, paddle::DataType::FLOAT32); + + CHECK_CONTIGUOUS(input); + CHECK_CONTIGUOUS(output); + CHECK_CONTIGUOUS(scale); + + CHECK_DIMS(input, 4); + CHECK_DIMS(output, 4); + CHECK_DIMS(scale, 3); + + const int batch_size = input.shape()[0]; + const int num_tokens_padded = input.shape()[3]; + + int stride_bz_input = input.strides()[0]; + int stride_bz_output = output.strides()[0]; + + int num_heads, head_dim; + int stride_d_input, stride_h_input, stride_d_output, stride_h_output; + + if (tensor_layout == 0) + { + num_heads = input.shape()[2]; + head_dim = input.shape()[1]; + stride_d_input = 
input.strides()[1]; + stride_h_input = input.strides()[2]; + stride_d_output = output.strides()[1]; + stride_h_output = output.strides()[2]; + } + else + { + num_heads = input.shape()[1]; + head_dim = input.shape()[2]; + stride_d_input = input.strides()[2]; + stride_h_input = input.strides()[1]; + stride_d_output = output.strides()[2]; + stride_h_output = output.strides()[1]; + } + + CHECK_SHAPE(output, input.shape()[0], input.shape()[1], input.shape()[2], input.shape()[3]); + CHECK_SHAPE(scale, batch_size, num_heads, head_dim); + + constexpr int CTA_SIZE = 256; + + dim3 grid(num_heads, batch_size, head_dim); + dim3 block(CTA_SIZE); + + auto input_dtype = input.dtype(); + + DISPATCH_PADDLE_DTYPE_TO_CTYPE_FP16(input_dtype, c_type, { + MeanScaleKernel<64, false, c_type><<>>( + reinterpret_cast(input.data()), + reinterpret_cast(output.data()), + nullptr, + reinterpret_cast(scale.data()), + scale_max, + num_tokens, + stride_bz_input, stride_d_input, stride_h_input, + stride_bz_output, stride_d_output, stride_h_output, + 0, 0, + scale.strides()[0], scale.strides()[1] + ); + }); +} + +void mean_scale_fuse_quant_cuda_fwd( + paddle::Tensor& input, + paddle::Tensor& output, + paddle::Tensor& mean, + paddle::Tensor& scale, + int num_tokens, + float scale_max, + int tensor_layout) +{ + CHECK_CUDA(input); + CHECK_CUDA(output); + CHECK_CUDA(mean); + CHECK_CUDA(scale); + + // CHECK_DTYPE(output, torch::kInt8); + CHECK_DTYPE(mean, paddle::DataType::FLOAT32); + CHECK_DTYPE(scale, paddle::DataType::FLOAT32); + + CHECK_CONTIGUOUS(input); + CHECK_CONTIGUOUS(output); + CHECK_CONTIGUOUS(mean); + CHECK_CONTIGUOUS(scale); + + CHECK_DIMS(input, 4); + CHECK_DIMS(output, 4); + CHECK_DIMS(mean, 3); + CHECK_DIMS(scale, 3); + + const int batch_size = input.shape()[0]; + const int num_tokens_padded = input.shape()[3]; + + int stride_bz_input = input.strides()[0]; + int stride_bz_output = output.strides()[0]; + + int num_heads, head_dim; + int stride_d_input, stride_h_input, stride_d_output, stride_h_output; + + if (tensor_layout == 0) + { + num_heads = input.shape()[2]; + head_dim = input.shape()[1]; + stride_d_input = input.strides()[1]; + stride_h_input = input.strides()[2]; + stride_d_output = output.strides()[1]; + stride_h_output = output.strides()[2]; + } + else + { + num_heads = input.shape()[1]; + head_dim = input.shape()[2]; + stride_d_input = input.strides()[2]; + stride_h_input = input.strides()[1]; + stride_d_output = output.strides()[2]; + stride_h_output = output.strides()[1]; + } + + CHECK_SHAPE(output, input.shape()[0], input.shape()[1], input.shape()[2], input.shape()[3]); + CHECK_SHAPE(mean, batch_size, num_heads, head_dim); + CHECK_SHAPE(scale, batch_size, num_heads, head_dim); + + constexpr int CTA_SIZE = 256; + + dim3 grid(num_heads, batch_size, head_dim); + dim3 block(CTA_SIZE); + + auto input_dtype = input.dtype(); + + DISPATCH_PADDLE_DTYPE_TO_CTYPE_FP16(input_dtype, c_type, { + MeanScaleKernel<64, true, c_type><<>>( + reinterpret_cast(input.data()), + reinterpret_cast(output.data()), + reinterpret_cast(mean.data()), + reinterpret_cast(scale.data()), + scale_max, + num_tokens, + stride_bz_input, stride_d_input, stride_h_input, + stride_bz_output, stride_d_output, stride_h_output, + mean.strides()[0], mean.strides()[1], + scale.strides()[0], scale.strides()[1] + ); + }); +} \ No newline at end of file diff --git a/csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f16_kernel.cu b/csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f16_kernel.cu new file mode 100644 index 000000000000..e6e5d0daa270 --- 
/dev/null +++ b/csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f16_kernel.cu @@ -0,0 +1,1690 @@ +#include + +#include "paddle/extension.h" + +// #include "sageattn.h" +#include "sageattn_utils.cuh" + +#define PACK_SIZE_QK 16 // as if it is int8 +#define PACK_SIZE_V 16 // fp8 +#define PACK_SIZE_O 8 // fp16 + +// treat as if int8 tensor core +#define MMA_QK_M 16 +#define MMA_QK_N 16 +#define MMA_QK_K 32 + +// fp8 tensor core +#define MMA_SV_M 16 +#define MMA_SV_N 16 +#define MMA_SV_K 32 + +// qk_int_sv_f16_buffer +// when instantiating, the head dim = 64, which makes the V_STRIDE = 64, then div 16 = 4, +// which triggered the compiling fault. +// it is the macro: PACK_SIZE_V and MMA_SV_K's problem, so we will redefine them here: +#ifdef PACK_SIZE_V +#define PACK_SIZE_V 8 +#endif + +#ifdef MMA_SV_K +#define MMA_SV_K 16 +#endif + +// inner impl +template +__global__ void qk_int_sv_f16_attn_buffer_kernel(int8_t *__restrict__ Q, int8_t *__restrict__ K, half *__restrict__ V, DTypeOut *__restrict__ O, float *__restrict__ Lse, + float *__restrict__ Q_scale, float *__restrict__ K_scale, DTypeOut *__restrict__ V_mean, + const uint32_t qo_len, const uint32_t kv_len, const uint32_t num_kv_groups, + const uint32_t stride_bz_q, const uint32_t stride_seq_q, const uint32_t stride_h_q, + const uint32_t stride_bz_k, const uint32_t stride_seq_k, const uint32_t stride_h_k, + const uint32_t stride_bz_v, const uint32_t stride_seq_v, const uint32_t stride_h_v, + const uint32_t stride_bz_o, const uint32_t stride_seq_o, const uint32_t stride_h_o, + float sm_scale) +{ + // compile time check + static_assert(DTypeQK == SADataType::kInt8 || DTypeQK == SADataType::kInt4, "DTypeQK must be int8 or int4"); + static_assert(Q_GRAN == QuantGranularity::kPerBlock || Q_GRAN == QuantGranularity::kPerWarp || Q_GRAN == QuantGranularity::kPerThread, "Q_GRAN must be kPerBlock, kPerWarp or kPerThread"); + static_assert(K_GRAN == QuantGranularity::kPerBlock || K_GRAN == QuantGranularity::kPerWarp || K_GRAN == QuantGranularity::kPerThread, "K_GRAN must be kPerBlock, kPerWarp or kPerThread"); + static_assert(std::is_same::value || std::is_same::value, "DTypeOut must be half or nv_bfloat16"); + static_assert(head_dim % 64 == 0, "head_dim must be a multiple of 64"); + static_assert(CTA_Q / CTA_K <= 2); // for efficient causal implementation + + using DTypeOut2 = typename std::conditional::value, half2, nv_bfloat162>::type; + + constexpr uint32_t num_warps_q = CTA_Q / WARP_Q; + constexpr uint32_t num_warps_k = CTA_K / WARP_K; + constexpr uint32_t num_warps = num_warps_q * num_warps_k; + constexpr uint32_t num_tiles_q = WARP_Q / MMA_QK_M; + constexpr uint32_t num_tiles_k = WARP_K / MMA_QK_N; + constexpr uint32_t num_tiles_qk_inner = (DTypeQK == SADataType::kInt8) ? (head_dim / MMA_QK_K) : (head_dim / 2 / MMA_QK_K); + constexpr uint32_t num_tiles_v = head_dim / MMA_SV_N; + + constexpr uint32_t QK_SMEM_STRIDE = (DTypeQK == SADataType::kInt8) ? 
(head_dim) : (head_dim / 2); + constexpr uint32_t O_SMEM_STRIDE = head_dim; + constexpr uint32_t V_SMEM_STRIDE = head_dim; + + extern __shared__ int8_t smem[]; + + const uint32_t lane_id = get_lane_id(); + const uint32_t warp_id = get_warp_id(); + + // maximize L2 hit rate + const uint32_t batch_id = blockIdx.z; + const uint32_t bx = blockIdx.x; + const uint32_t num_qo_heads = gridDim.y; + const uint32_t head_id = blockIdx.y; + + // transfer to base 2 instead of base e with better numerical efficiency + sm_scale *= math::log2e; + + // RS holds the fragment of S + int32_t RS[num_tiles_q][num_tiles_k][8]; + half RO[num_tiles_q][num_tiles_v][8]; + float m[num_tiles_q][2]; // max + float d[num_tiles_q][2]; // denominator + + float m_buf[num_tiles_q][2]; // buffer for m + float RO_buf[num_tiles_q][num_tiles_v][8]; // buffer for RO + + uint32_t q_scale_idx, k_scale_idx; + + if constexpr (Q_GRAN == QuantGranularity::kPerBlock) + { + const uint32_t num_block_q = gridDim.x; + q_scale_idx = batch_id * num_qo_heads * num_block_q + head_id * num_block_q + bx; + } + else if constexpr (Q_GRAN == QuantGranularity::kPerWarp) + { + const uint32_t num_warp_block_q = gridDim.x * num_warps_q; + q_scale_idx = batch_id * num_qo_heads * num_warp_block_q + head_id * num_warp_block_q + bx * num_warps_q + get_warp_idx_q(); + } + else if constexpr (Q_GRAN == QuantGranularity::kPerThread) + { + const uint32_t num_warp_block_q = gridDim.x * num_warps_q; + q_scale_idx = batch_id * num_qo_heads * (num_warp_block_q * 8) + head_id * (num_warp_block_q * 8) + bx * (num_warps_q * 8) + get_warp_idx_q() * 8 + lane_id / 4; + } + + if constexpr (K_GRAN == QuantGranularity::kPerBlock) + { + const uint32_t num_block_k = div_ceil(kv_len, CTA_K); + k_scale_idx = batch_id * (num_qo_heads / num_kv_groups) * num_block_k + (head_id / num_kv_groups) * num_block_k; + } + else if constexpr (K_GRAN == QuantGranularity::kPerWarp) + { + const uint32_t num_warp_block_k = div_ceil(kv_len, CTA_K) * (CTA_K / WARP_K); + k_scale_idx = batch_id * (num_qo_heads / num_kv_groups) * num_warp_block_k + (head_id / num_kv_groups) * num_warp_block_k + get_warp_idx_k(); + } + else if constexpr (K_GRAN == QuantGranularity::kPerThread) + { + const uint32_t num_warp_block_k = div_ceil(kv_len, CTA_K) * (CTA_K / WARP_K); + k_scale_idx = batch_id * (num_qo_heads / num_kv_groups) * (num_warp_block_k * 4) + (head_id / num_kv_groups) * (num_warp_block_k * 4) + get_warp_idx_k() * 4 + lane_id % 4; + } + + constexpr uint32_t k_scale_advance_offset = (K_GRAN == QuantGranularity::kPerBlock) ? 1 : (K_GRAN == QuantGranularity::kPerWarp) ? (CTA_K / WARP_K) : (CTA_K / WARP_K) * 4; + + // initialize o, m, d +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fv = 0; fv < num_tiles_v; fv++) + { + +#pragma unroll + for (uint32_t k = 0; k < 4; k++) + { + ((int32_t*)RO[fq][fv])[k] = 0; + } + +#pragma unroll + for (uint32_t k = 0; k < 8; k++) + { + RO_buf[fq][fv][k] = 0.0f; + } + } + } +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t k = 0; k < 2; k++) + { + m[fq][k] = -5000000.0f; + m_buf[fq][k] = -5000000.0f; + d[fq][k] = 1.0f; + } + } + + constexpr uint32_t K_smem_idx_offset = CTA_Q; + constexpr uint32_t V_smem_idx_offset = CTA_Q + CTA_K; + + constexpr SwizzleMode swizzle_mode_QK = (QK_SMEM_STRIDE == 32) ? SwizzleMode::k32B : (QK_SMEM_STRIDE == 64) ? 
SwizzleMode::k64B : SwizzleMode::k128B; + smem_t smem_Q(smem); + smem_t smem_K(smem + K_smem_idx_offset * QK_SMEM_STRIDE); + constexpr SwizzleMode swizzle_mode_V = (V_SMEM_STRIDE == 32) ? SwizzleMode::k64B : SwizzleMode::k128B; + smem_t smem_V(smem + V_smem_idx_offset * QK_SMEM_STRIDE); + constexpr SwizzleMode swizzle_mode_O = (O_SMEM_STRIDE == 32) ? SwizzleMode::k64B : SwizzleMode::k128B; + smem_t smem_O(smem); + + constexpr uint32_t global_to_shared_line_lanes_QK = (QK_SMEM_STRIDE == 32) ? 2 : (QK_SMEM_STRIDE == 64) ? 4 : 8; + constexpr uint32_t global_to_shared_copy_lines_per_warp_QK = (QK_SMEM_STRIDE == 32) ? 16 : (QK_SMEM_STRIDE == 64) ? 8 : 4; + constexpr uint32_t global_to_shared_line_lanes_V = (V_SMEM_STRIDE == 32) ? 4 : 8; + constexpr uint32_t global_to_shared_copy_lines_per_warp_V = (V_SMEM_STRIDE == 32) ? 8 : 4; + constexpr uint32_t global_to_shared_line_lanes_O = (O_SMEM_STRIDE == 32) ? 4 : 8; + constexpr uint32_t global_to_shared_copy_lines_per_warp_O = (O_SMEM_STRIDE == 32) ? 8 : 4; + + constexpr uint32_t QK_smem_iters_row = QK_SMEM_STRIDE / (global_to_shared_line_lanes_QK * PACK_SIZE_QK); + constexpr uint32_t Q_smem_iters_col = CTA_Q / (num_warps * global_to_shared_copy_lines_per_warp_QK); + constexpr uint32_t K_smem_iters_col = CTA_K / (num_warps * global_to_shared_copy_lines_per_warp_QK); + constexpr uint32_t V_smem_iters_row = V_SMEM_STRIDE / (global_to_shared_line_lanes_V * PACK_SIZE_V); + constexpr uint32_t V_smem_iters_col = CTA_K / (num_warps * global_to_shared_copy_lines_per_warp_V); + constexpr uint32_t O_smem_iters_row = O_SMEM_STRIDE / (global_to_shared_line_lanes_O * PACK_SIZE_O); + constexpr uint32_t O_smem_iters_col = CTA_Q / (num_warps * global_to_shared_copy_lines_per_warp_O); + + int8_t *Q_lane_base_ptr = Q + batch_id * stride_bz_q + head_id * stride_h_q + (bx * CTA_Q + CTA_Q / num_warps * warp_id + lane_id / global_to_shared_line_lanes_QK) * stride_seq_q + (lane_id % global_to_shared_line_lanes_QK) * PACK_SIZE_QK; + int8_t *K_lane_base_ptr = K + batch_id * stride_bz_k + (head_id / num_kv_groups) * stride_h_k + (CTA_K / num_warps * warp_id + lane_id / global_to_shared_line_lanes_QK) * stride_seq_k + (lane_id % global_to_shared_line_lanes_QK) * PACK_SIZE_QK; + half *V_lane_base_ptr = V + batch_id * stride_bz_v + (head_id / num_kv_groups) * stride_h_v + (CTA_K / num_warps * warp_id + lane_id / global_to_shared_line_lanes_V) * stride_seq_v + (lane_id % global_to_shared_line_lanes_V) * PACK_SIZE_V; + uint32_t Q_smem_offset_load = smem_Q.get_permuted_offset(warp_id * global_to_shared_copy_lines_per_warp_QK * Q_smem_iters_col + lane_id / global_to_shared_line_lanes_QK, lane_id % global_to_shared_line_lanes_QK); + uint32_t K_smem_offset_load = smem_K.get_permuted_offset(warp_id * global_to_shared_copy_lines_per_warp_QK * K_smem_iters_col + lane_id / global_to_shared_line_lanes_QK, lane_id % global_to_shared_line_lanes_QK); + uint32_t V_smem_offset_load = smem_V.get_permuted_offset(warp_id * global_to_shared_copy_lines_per_warp_V * V_smem_iters_col + lane_id / global_to_shared_line_lanes_V, lane_id % global_to_shared_line_lanes_V); + + uint32_t Q_smem_offset_mma = smem_Q.get_permuted_offset(get_warp_idx_q() * WARP_Q + lane_id % 16, lane_id / 16); + uint32_t K_smem_offset_mma = smem_K.get_permuted_offset(get_warp_idx_k() * WARP_K + lane_id % 8 + (lane_id / 16) * 8, (lane_id / 8) % 2); + uint32_t V_smem_offset_mma = smem_V.get_permuted_offset(get_warp_idx_k() * WARP_K + lane_id % 16, lane_id / 16); + + // for causal masking + uint32_t Q_idx_lane_base = bx * CTA_Q + 
get_warp_idx_q() * WARP_Q + lane_id / 4; + uint32_t K_idx_lane_base = get_warp_idx_k() * WARP_K + 2 * (lane_id % 4); + + // for loading + uint32_t Q_load_idx_lane_base = bx * CTA_Q + CTA_Q / num_warps * warp_id + lane_id / global_to_shared_line_lanes_QK; + uint32_t K_load_idx_lane_base = CTA_K / num_warps * warp_id + lane_id / global_to_shared_line_lanes_QK; + uint32_t V_load_idx_lane_base = CTA_K / num_warps * warp_id + lane_id / global_to_shared_line_lanes_V; + + const uint32_t num_iterations = div_ceil( + mask_mode == MaskMode::kCausal + ? min(kv_len, (bx + 1) * CTA_Q) + : kv_len, + CTA_K); + + // load Q with predicate + load_global_to_share( + &Q_lane_base_ptr, Q_smem_offset_load, stride_seq_q, smem_Q, Q_load_idx_lane_base, qo_len); + cp_async::commit_group(); + cp_async::wait_group<0>(); + __syncthreads(); + + // for num_tiles_qk_inner = 1, we load all Qs in register + uint32_t RQ[num_tiles_q][4]; + if constexpr (num_tiles_qk_inner == 1) + { +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { + smem_Q.ldmatrix_m8n8x4(Q_smem_offset_mma, RQ[fq]); + Q_smem_offset_mma = smem_Q.advance_offset_by_row<16>(Q_smem_offset_mma); + } + } + + // load K with predicate + load_global_to_share( + &K_lane_base_ptr, K_smem_offset_load, stride_seq_k, smem_K, K_load_idx_lane_base, kv_len); + cp_async::commit_group(); + + float q_scale = Q_scale[q_scale_idx]; + + float original_sm_scale = sm_scale; + float dequant_scale = q_scale * K_scale[k_scale_idx + 0 * k_scale_advance_offset]; + + sm_scale = original_sm_scale * dequant_scale; + + // load V with predicate + load_global_to_share( + &V_lane_base_ptr, V_smem_offset_load, stride_seq_v, smem_V, V_load_idx_lane_base, kv_len); + cp_async::commit_group(); + + K_load_idx_lane_base += CTA_K; + V_load_idx_lane_base += CTA_K; + + uint32_t num_flush_times = div_ceil(num_iterations, Buffer_Iter) - (num_iterations % Buffer_Iter == 1); // leave at least two iterations for the last flush + uint32_t iter = 1; + +#pragma unroll + for (uint32_t flush_time = 0; flush_time < num_flush_times - 1; flush_time++) + { +#pragma unroll + for (; iter <= (flush_time + 1) * Buffer_Iter; iter++) + { + // ensure K is ready + cp_async::wait_group<1>(); + __syncthreads(); + + // compute QK^T + if constexpr (num_tiles_qk_inner == 1) + { + compute_int_qk( + smem_K, RS, RQ, K_smem_offset_mma); + } + else + { + compute_int_qk( + smem_Q, smem_K, RS, Q_smem_offset_mma, K_smem_offset_mma); + } + + float RS_f32[num_tiles_q][num_tiles_k][8]; + + #pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { + #pragma unroll + for (uint32_t fk = 0; fk < num_tiles_k; fk++) + { + #pragma unroll + for (uint32_t k = 0; k < 8; k++) + { + RS_f32[fq][fk][k] = __int2float_rz(RS[fq][fk][k]); + } + } + } + + // do not apply causal mask and out of bound mask for these iterations + K_idx_lane_base += CTA_K; + + update_mdo(RS_f32, RO, m, d, sm_scale); + + if constexpr (DenominatorAccumUnit == ComputeUnit::kCudaCore) + { + accumulate_d(RS_f32, d); + } + + uint32_t RS_f16[num_tiles_q][num_tiles_k][4]; + RS_32_to_16(RS_f32, RS_f16); + + if constexpr (DenominatorAccumUnit == ComputeUnit::kTensorCore) + { + accumulate_d(RS_f16, d); + } + + __syncthreads(); + + // load K + load_global_to_share( + &K_lane_base_ptr, K_smem_offset_load, stride_seq_k, smem_K); + cp_async::commit_group(); + + dequant_scale = q_scale * K_scale[k_scale_idx + iter * k_scale_advance_offset]; + sm_scale = original_sm_scale * dequant_scale; + + // ensure V is ready + cp_async::wait_group<1>(); + __syncthreads(); + + 
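+      // Reading aid for the buffered accumulation in this kernel (comment only, no behavioral change):
+      // the S*V product below is accumulated into the half-precision RO registers; every Buffer_Iter
+      // inner iterations the enclosing flush loop folds RO into the fp32 RO_buf, rescaling the old
+      // buffer by o_scale = exp2(m_buf - m) and then clearing RO. Keeping the long-running sum in
+      // fp32 outside the hot loop bounds the error of the fp16 accumulator. Illustrative flush step
+      // (same math as the "update buffer" block further down):
+      //   RO_buf = RO_buf * exp2(m_old - m_new) + float(RO);  RO = 0;  m_old = m_new;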
compute_fp16_sv_permuted( + smem_V, RS_f16, RO, d, V_smem_offset_mma); + + __syncthreads(); + // load V + load_global_to_share( + &V_lane_base_ptr, V_smem_offset_load, stride_seq_v, smem_V); + cp_async::commit_group(); + K_load_idx_lane_base += CTA_K; + V_load_idx_lane_base += CTA_K; + } + + // update buffer +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t k = 0; k < 2; k++) + { + float o_scale = math::ptx_exp2(m_buf[fq][k] - m[fq][k]); +#pragma unroll + for (uint32_t fv = 0; fv < num_tiles_v; fv++) + { + // update buffer + RO_buf[fq][fv][k * 2 + 0] = RO_buf[fq][fv][k * 2 + 0] * o_scale + __half2float(RO[fq][fv][k * 2 + 0]); + RO_buf[fq][fv][k * 2 + 1] = RO_buf[fq][fv][k * 2 + 1] * o_scale + __half2float(RO[fq][fv][k * 2 + 1]); + RO_buf[fq][fv][k * 2 + 4] = RO_buf[fq][fv][k * 2 + 4] * o_scale + __half2float(RO[fq][fv][k * 2 + 4]); + RO_buf[fq][fv][k * 2 + 5] = RO_buf[fq][fv][k * 2 + 5] * o_scale + __half2float(RO[fq][fv][k * 2 + 5]); + + // update m_buf + m_buf[fq][k] = m[fq][k]; + + // clear RO + *((int32_t*)&RO[fq][fv][k * 2 + 0]) = 0; + *((int32_t*)&RO[fq][fv][k * 2 + 4]) = 0; + } + } + } + } + +#pragma unroll + for (; iter < num_iterations - 1; iter++) + { + // ensure K is ready + cp_async::wait_group<1>(); + __syncthreads(); + + // compute QK^T + if constexpr (num_tiles_qk_inner == 1) + { + compute_int_qk( + smem_K, RS, RQ, K_smem_offset_mma); + } + else + { + compute_int_qk( + smem_Q, smem_K, RS, Q_smem_offset_mma, K_smem_offset_mma); + } + + float RS_f32[num_tiles_q][num_tiles_k][8]; + +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fk = 0; fk < num_tiles_k; fk++) + { +#pragma unroll + for (uint32_t k = 0; k < 8; k++) + { + RS_f32[fq][fk][k] = __int2float_rz(RS[fq][fk][k]); + } + } + } + + // do not apply causal mask and out of bound mask for these iterations + K_idx_lane_base += CTA_K; + + update_mdo(RS_f32, RO, m, d, sm_scale); + + if constexpr (DenominatorAccumUnit == ComputeUnit::kCudaCore) + { + accumulate_d(RS_f32, d); + } + + uint32_t RS_f16[num_tiles_q][num_tiles_k][4]; + RS_32_to_16(RS_f32, RS_f16); + + if constexpr (DenominatorAccumUnit == ComputeUnit::kTensorCore) + { + accumulate_d(RS_f16, d); + } + + __syncthreads(); + + // load K + load_global_to_share( + &K_lane_base_ptr, K_smem_offset_load, stride_seq_k, smem_K); + cp_async::commit_group(); + + dequant_scale = q_scale * K_scale[k_scale_idx + iter * k_scale_advance_offset]; + sm_scale = original_sm_scale * dequant_scale; + + // ensure V is ready + cp_async::wait_group<1>(); + __syncthreads(); + + compute_fp16_sv_permuted( + smem_V, RS_f16, RO, d, V_smem_offset_mma); + + __syncthreads(); + // load V + load_global_to_share( + &V_lane_base_ptr, V_smem_offset_load, stride_seq_v, smem_V); + cp_async::commit_group(); + K_load_idx_lane_base += CTA_K; + V_load_idx_lane_base += CTA_K; + } + + // second last iter, apply causal mask + if (num_iterations > 1) + { + // ensure K is ready + cp_async::wait_group<1>(); + __syncthreads(); + + // compute QK^T + if constexpr (num_tiles_qk_inner == 1) + { + compute_int_qk( + smem_K, RS, RQ, K_smem_offset_mma); + } + else + { + compute_int_qk( + smem_Q, smem_K, RS, Q_smem_offset_mma, K_smem_offset_mma); + } + + float RS_f32[num_tiles_q][num_tiles_k][8]; + +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fk = 0; fk < num_tiles_k; fk++) + { +#pragma unroll + for (uint32_t k = 0; k < 8; k++) + { + RS_f32[fq][fk][k] = 
__int2float_rz(RS[fq][fk][k]) * dequant_scale; + } + } + } + + if constexpr (mask_mode == MaskMode::kCausal) + { + apply_causal_mask(Q_idx_lane_base, K_idx_lane_base, RS_f32); + } + // apply_out_of_bound_mask(K_idx_lane_base, RS_f32, kv_len); + K_idx_lane_base += CTA_K; + + update_mdo(RS_f32, RO, m, d, original_sm_scale); + + if constexpr (DenominatorAccumUnit == ComputeUnit::kCudaCore) + { + accumulate_d(RS_f32, d); + } + + uint32_t RS_f16[num_tiles_q][num_tiles_k][4]; + RS_32_to_16(RS_f32, RS_f16); + + if constexpr (DenominatorAccumUnit == ComputeUnit::kTensorCore) + { + accumulate_d(RS_f16, d); + } + + __syncthreads(); + + // load K with predicate + load_global_to_share( + &K_lane_base_ptr, K_smem_offset_load, stride_seq_k, smem_K, K_load_idx_lane_base, kv_len); + cp_async::commit_group(); + + dequant_scale = q_scale * K_scale[k_scale_idx + (num_iterations - 1) * k_scale_advance_offset]; + sm_scale = original_sm_scale * dequant_scale; + + // ensure V is ready + cp_async::wait_group<1>(); + __syncthreads(); + + compute_fp16_sv_permuted( + smem_V, RS_f16, RO, d, V_smem_offset_mma); + + __syncthreads(); + // load V with predicate + load_global_to_share( + &V_lane_base_ptr, V_smem_offset_load, stride_seq_v, smem_V, V_load_idx_lane_base, kv_len); + cp_async::commit_group(); + K_load_idx_lane_base += CTA_K; + V_load_idx_lane_base += CTA_K; + } + + // last iter, apply causal mask and out of bound mask + { + // ensure K is ready + cp_async::wait_group<1>(); + __syncthreads(); + + // compute QK^T + if constexpr (num_tiles_qk_inner == 1) + { + compute_int_qk( + smem_K, RS, RQ, K_smem_offset_mma); + } + else + { + compute_int_qk( + smem_Q, smem_K, RS, Q_smem_offset_mma, K_smem_offset_mma); + } + + float RS_f32[num_tiles_q][num_tiles_k][8]; + +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fk = 0; fk < num_tiles_k; fk++) + { +#pragma unroll + for (uint32_t k = 0; k < 8; k++) + { + RS_f32[fq][fk][k] = __int2float_rz(RS[fq][fk][k]) * dequant_scale; + } + } + } + + if constexpr (mask_mode == MaskMode::kCausal) + { + apply_causal_mask(Q_idx_lane_base, K_idx_lane_base, RS_f32); + } + // check out of bound in the last iter + apply_out_of_bound_mask(K_idx_lane_base, RS_f32, kv_len); + K_idx_lane_base += CTA_K; + + update_mdo(RS_f32, RO, m, d, original_sm_scale); + + if constexpr (DenominatorAccumUnit == ComputeUnit::kCudaCore) + { + accumulate_d(RS_f32, d); + } + + uint32_t RS_f16[num_tiles_q][num_tiles_k][4]; + RS_32_to_16(RS_f32, RS_f16); + + if constexpr (DenominatorAccumUnit == ComputeUnit::kTensorCore) + { + accumulate_d(RS_f16, d); + } + + // ensure V is ready + cp_async::wait_group<0>(); + __syncthreads(); + + compute_fp16_sv_permuted( + smem_V, RS_f16, RO, d, V_smem_offset_mma); + + __syncthreads(); + + } + + // update buffer +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t k = 0; k < 2; k++) + { + float o_scale = math::ptx_exp2(m_buf[fq][k] - m[fq][k]); +#pragma unroll + for (uint32_t fv = 0; fv < num_tiles_v; fv++) + { + // update buffer + RO_buf[fq][fv][k * 2 + 0] = RO_buf[fq][fv][k * 2 + 0] * o_scale + __half2float(RO[fq][fv][k * 2 + 0]); + RO_buf[fq][fv][k * 2 + 1] = RO_buf[fq][fv][k * 2 + 1] * o_scale + __half2float(RO[fq][fv][k * 2 + 1]); + RO_buf[fq][fv][k * 2 + 4] = RO_buf[fq][fv][k * 2 + 4] * o_scale + __half2float(RO[fq][fv][k * 2 + 4]); + RO_buf[fq][fv][k * 2 + 5] = RO_buf[fq][fv][k * 2 + 5] * o_scale + __half2float(RO[fq][fv][k * 2 + 5]); + + // update m_buf + // m_buf[fq][k] = 
m[fq][k]; + + // // clear RO + // *((int32_t*)&RO[fq][fv][k * 2 + 0]) = 0; + // *((int32_t*)&RO[fq][fv][k * 2 + 4]) = 0; + } + } + } + + // TODO: thread block sync mdo state for num_warps_k > 0 + + normalize_d(RO_buf, m, d); + + // save the result to shared memory + uint32_t smem_O_row_base = get_warp_idx_q() * WARP_Q + lane_id / 4; +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fv = 0; fv < num_tiles_v; fv++) + { + uint32_t offset_O = smem_O.get_permuted_offset(smem_O_row_base + fq * MMA_QK_M, fv * (MMA_SV_N / PACK_SIZE_O)); + + // convert RO_buf to half + uint32_t RO_f16[4]; +#pragma unroll + for (uint32_t k = 0; k < 4; k++) + { + if constexpr (std::is_same::value) + { + ((half2*)RO_f16)[k] = __float22half2_rn(((float2*)RO_buf[fq][fv])[k]); + } + else if constexpr (std::is_same::value) + { + ((nv_bfloat162*)RO_f16)[k] = __float22bfloat162_rn(((float2*)RO_buf[fq][fv])[k]); + } + } + + ((int32_t*)(smem_O.base + offset_O))[lane_id % 4] = RO_f16[0]; + ((int32_t*)(smem_O.base + offset_O + 8 * (O_SMEM_STRIDE / PACK_SIZE_O)))[lane_id % 4] = RO_f16[1]; + + // ! permuted, make sure you know what you are doing + ((int32_t*)(smem_O.base + (offset_O ^ 0x1)))[lane_id % 4] = RO_f16[2]; + ((int32_t*)(smem_O.base + (offset_O ^ 0x1) + 8 * (O_SMEM_STRIDE / PACK_SIZE_O)))[lane_id % 4] = RO_f16[3]; + + } + } + + // ! do we need to sync here? + __syncwarp(); + + // shared memory to global memory + DTypeOut *O_lane_ptr = O + batch_id * stride_bz_o + head_id * stride_h_o + (bx * CTA_Q + WARP_Q * get_warp_idx_q() + lane_id / global_to_shared_line_lanes_O) * stride_seq_o + lane_id % global_to_shared_line_lanes_O * PACK_SIZE_O; + uint32_t offset_O = smem_O.get_permuted_offset(get_warp_idx_q() * WARP_Q + lane_id / global_to_shared_line_lanes_O, lane_id % global_to_shared_line_lanes_O); + uint32_t O_load_idx_lane_base = bx * CTA_Q + CTA_Q / num_warps * warp_id + lane_id / global_to_shared_line_lanes_O; + +#pragma unroll + for (uint32_t i = 0; i < O_smem_iters_col; i++) + { +#pragma unroll + for (uint32_t j = 0; j < O_smem_iters_row; j++) + { + if (O_load_idx_lane_base < qo_len) + { + smem_O.store_128b(offset_O, O_lane_ptr); + } + O_lane_ptr += (global_to_shared_line_lanes_O * PACK_SIZE_O); + offset_O = smem_O.advance_offset_by_column(offset_O); + } + + offset_O = smem_O.advance_offset_by_row(offset_O - (O_smem_iters_row * global_to_shared_line_lanes_O)); + O_lane_ptr += ((global_to_shared_copy_lines_per_warp_O * stride_seq_o) - (O_smem_iters_row * global_to_shared_line_lanes_O * PACK_SIZE_O)); + O_load_idx_lane_base += global_to_shared_copy_lines_per_warp_O; + } + + if constexpr (return_lse) + { + // ! this only works for num_tiles_q = 2 + uint32_t lse_idx = bx * CTA_Q + lane_id / 4 + 8 * (lane_id % 4) + WARP_Q * get_warp_idx_q(); + float *lse_lane_ptr = Lse + batch_id * (qo_len * num_qo_heads) + head_id * qo_len + lse_idx; + uint32_t fq = (lane_id % 4) / 2; + uint32_t k = (lane_id % 4) % 2; + + if (lse_idx < qo_len) + { + lse_lane_ptr[0] = (math::ptx_log2(d[fq][k]) + m[fq][k]); // TODO: here has some bug. 
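+      // Hedged note on the TODO above: m and d are maintained in the base-2 domain (sm_scale is
+      // pre-multiplied by math::log2e and the rescaling uses ptx_exp2), so log2(d) + m is a
+      // log2-sum-exp of the scaled logits. If the caller expects the conventional natural-log LSE,
+      // the stored value would additionally need a * ln(2) conversion, e.g.:
+      //   lse_lane_ptr[0] = (math::ptx_log2(d[fq][k]) + m[fq][k]) * 0.6931471805599453f;
+      // This is an observation from the surrounding code, not a confirmed fix for the TODO.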
+ } + } + +} + +// impl -> see sageattn.h file +// tensor_layout 0 for [B, N, H, D] (NHD, b, s, head, dim), +// 1 for [B, H, N, D] (HND) +// std::vector qk_int8_sv_f16_accum_f16_attn_buf_fwd(paddle::Tensor& query, +// paddle::Tensor& key, +// paddle::Tensor& value, +// paddle::Tensor& output, +// paddle::Tensor& query_scale, +// paddle::Tensor& key_scale, +// int tensor_layout, +// int is_causal, +// int qk_quant_gran, +// float sm_scale, +// int return_lse) +// { +// CHECK_CUDA(query); +// CHECK_CUDA(key); +// CHECK_CUDA(value); +// CHECK_CUDA(output); +// CHECK_CUDA(query_scale); +// CHECK_CUDA(key_scale); + +// CHECK_CONTIGUOUS(query); +// CHECK_CONTIGUOUS(key); +// CHECK_LASTDIM_CONTIGUOUS(value); +// CHECK_LASTDIM_CONTIGUOUS(output); +// CHECK_CONTIGUOUS(query_scale); +// CHECK_CONTIGUOUS(key_scale); + +// CHECK_DTYPE(query, paddle::DataType::INT8); +// CHECK_DTYPE(key, paddle::DataType::INT8); +// CHECK_DTYPE(value, paddle::DataType::FLOAT16); // TODO: there maybe some problem, for bf16 type +// CHECK_DTYPE(query_scale, paddle::DataType::FLOAT32); +// CHECK_DTYPE(key_scale, paddle::DataType::FLOAT32); + +// CHECK_DIMS(query, 4); +// CHECK_DIMS(key, 4); +// CHECK_DIMS(value, 4); +// CHECK_DIMS(output, 4); +// CHECK_DIMS(query_scale, 3); +// CHECK_DIMS(key_scale, 3); + +// const int head_dim = query.shape()[3]; +// const int batch_size = query.shape()[0]; + +// int stride_bz_q = query.strides()[0]; +// int stride_bz_k = key.strides()[0]; +// int stride_bz_v = value.strides()[0]; +// int stride_bz_o = output.strides()[0]; + +// int qo_len, kv_len, num_qo_heads, num_kv_heads; +// int stride_seq_q, stride_seq_k, stride_seq_v, stride_seq_o; +// int stride_h_q, stride_h_k, stride_h_v, stride_h_o; + +// if (tensor_layout == 0) +// { +// qo_len = query.shape()[1]; +// kv_len = key.shape()[1]; +// num_qo_heads = query.shape()[2]; +// num_kv_heads = key.shape()[2]; +// CHECK_SHAPE(key, batch_size, kv_len, num_kv_heads, head_dim); +// CHECK_SHAPE(value, batch_size, kv_len, num_kv_heads, head_dim); + +// stride_seq_q = query.strides()[1]; +// stride_seq_k = key.strides()[1]; +// stride_seq_v = value.strides()[1]; +// stride_seq_o = output.strides()[1]; + +// stride_h_q = query.strides()[2]; +// stride_h_k = key.strides()[2]; +// stride_h_v = value.strides()[2]; +// stride_h_o = output.strides()[2]; +// } +// else if (tensor_layout == 1) +// { +// qo_len = query.shape()[2]; +// kv_len = key.shape()[2]; +// num_qo_heads = query.shape()[1]; +// num_kv_heads = key.shape()[1]; +// CHECK_SHAPE(key, batch_size, num_kv_heads, kv_len, head_dim); +// CHECK_SHAPE(value, batch_size, num_kv_heads, kv_len, head_dim); + +// stride_seq_q = query.strides()[2]; +// stride_seq_k = key.strides()[2]; +// stride_seq_v = value.strides()[2]; +// stride_seq_o = output.strides()[2]; + +// stride_h_q = query.strides()[1]; +// stride_h_k = key.strides()[1]; +// stride_h_v = value.strides()[1]; +// stride_h_o = output.strides()[1]; +// } +// else +// { +// throw std::invalid_argument("tensor_layout must be 0 or 1"); +// } + +// if (num_qo_heads % num_kv_heads != 0) { +// std::ostringstream err_msg; +// err_msg << "num_qo_heads (" << num_qo_heads << ") must be divisible by num_kv_heads (" << num_kv_heads << ")"; +// throw std::invalid_argument(err_msg.str()); +// } + +// const int num_kv_groups = num_qo_heads / num_kv_heads; + +// paddle::Tensor lse = paddle::empty({1}); +// if (return_lse) +// { +// lse = paddle::empty({batch_size, num_qo_heads, qo_len}, paddle::DataType::FLOAT32); +// } + +// auto output_dtype = 
output.dtype(); // in [bfloat16 or float16] + +// DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { +// DISPATCH_CAUSAL(is_causal, IS_CAUSAL, { +// DISPATCH_QK_QUANT_GRAN(qk_quant_gran, QK_QUANT_GRAN, { +// DISPATCH_RETURN_LSE(return_lse, RETURN_LSE, { +// DISPATCH_PADDLE_DTYPE_TO_CTYPE_FP16(output_dtype, DTypeOut, { +// constexpr int CTA_Q = (HEAD_DIM == 256) ? 64 : 128; +// constexpr int CTA_K = (HEAD_DIM == 256) ? 32 : 64; +// constexpr int WARP_Q = (HEAD_DIM == 256) ? 16 : 32; +// constexpr int WARP_K = (HEAD_DIM == 256) ? 32 : 64; + +// constexpr MaskMode mask_mode = IS_CAUSAL ? MaskMode::kCausal : MaskMode::kNone; + +// if constexpr (QK_QUANT_GRAN == static_cast(QuantGranularity::kPerWarp)) +// { +// CHECK_SHAPE(query_scale, batch_size, num_qo_heads, static_cast(div_ceil(qo_len, CTA_Q) * (CTA_Q / WARP_Q))); +// CHECK_SHAPE(key_scale, batch_size, num_kv_heads, static_cast(div_ceil(kv_len, CTA_K) * (CTA_K / WARP_K))); +// } +// else if constexpr (QK_QUANT_GRAN == static_cast(QuantGranularity::kPerThread)) +// { +// CHECK_SHAPE(query_scale, batch_size, num_qo_heads, static_cast(div_ceil(qo_len, CTA_Q) * (CTA_Q / WARP_Q) * 8)); +// CHECK_SHAPE(key_scale, batch_size, num_kv_heads, static_cast(div_ceil(kv_len, CTA_K) * (CTA_K / WARP_K) * 4)); +// } +// else +// { +// static_assert(QK_QUANT_GRAN == static_cast(QuantGranularity::kPerWarp) || QK_QUANT_GRAN == static_cast(QuantGranularity::kPerThread), "Unsupported quantization granularity"); +// } + +// // smem_Q smem_K smem_V smem_O +// size_t smem_max = std::max(CTA_Q * HEAD_DIM * sizeof(int8_t) + CTA_K * HEAD_DIM * sizeof(int8_t) + CTA_K * HEAD_DIM * sizeof(half), CTA_Q * HEAD_DIM * sizeof(half)); + +// auto kernel_func = qk_int_sv_f16_attn_buffer_kernel(QK_QUANT_GRAN), static_cast(QK_QUANT_GRAN), DTypeOut, ComputeUnit::kTensorCore, +// mask_mode, 32, RETURN_LSE>; + +// cudaFuncSetAttribute(kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_max); + +// dim3 grid(div_ceil(qo_len, CTA_Q), num_qo_heads, batch_size); +// dim3 block(32, (CTA_Q / WARP_Q) * (CTA_K / WARP_K)); + +// // using PD216bit = typename pd_cvt::PD16bitTrait<>::DataType; +// using PD16bitRe = typename pd_cvt::PD16bitReTrait::DataType; + +// PD16bitRe* output_data = output.data(); + +// kernel_func<<>>( +// query.data(), +// key.data(), +// reinterpret_cast(value.data()), // reinterpret_cast(reinterpret_cast(value.data())) +// reinterpret_cast(output.data()), +// (RETURN_LSE) ? 
lse.data() : nullptr, +// query_scale.data(), +// key_scale.data(), +// nullptr, +// qo_len, +// kv_len, +// num_kv_groups, +// stride_bz_q, stride_seq_q, stride_h_q, +// stride_bz_k, stride_seq_k, stride_h_k, +// stride_bz_v, stride_seq_v, stride_h_v, +// stride_bz_o, stride_seq_o, stride_h_o, +// sm_scale); +// }); +// }); +// }); +// }); +// }); + +// return {lse}; +// } + +// qk_int_sv_f16 impl +// the previous one stands for buffer +template +__global__ void qk_int_sv_f16_attn_kernel(int8_t *__restrict__ Q, int8_t *__restrict__ K, half *__restrict__ V, DTypeOut *__restrict__ O, float *__restrict__ Lse, + float *__restrict__ Q_scale, float *__restrict__ K_scale, DTypeOut *__restrict__ V_mean, + const uint32_t qo_len, const uint32_t kv_len, const uint32_t num_kv_groups, + const uint32_t stride_bz_q, const uint32_t stride_seq_q, const uint32_t stride_h_q, + const uint32_t stride_bz_k, const uint32_t stride_seq_k, const uint32_t stride_h_k, + const uint32_t stride_bz_v, const uint32_t stride_seq_v, const uint32_t stride_h_v, + const uint32_t stride_bz_o, const uint32_t stride_seq_o, const uint32_t stride_h_o, + float sm_scale) +{ + // compile time check + static_assert(DTypeQK == SADataType::kInt8 || DTypeQK == SADataType::kInt4, "DTypeQK must be int8 or int4"); + static_assert(Q_GRAN == QuantGranularity::kPerBlock || Q_GRAN == QuantGranularity::kPerWarp || Q_GRAN == QuantGranularity::kPerThread, "Q_GRAN must be kPerBlock, kPerWarp or kPerThread"); + static_assert(K_GRAN == QuantGranularity::kPerBlock || K_GRAN == QuantGranularity::kPerWarp || K_GRAN == QuantGranularity::kPerThread, "K_GRAN must be kPerBlock, kPerWarp or kPerThread"); + static_assert(std::is_same::value || !use_inst_buffer, "use_inst_buffer only supports DTypeSVAccum as float"); + static_assert(std::is_same::value || std::is_same::value, "DTypeSVAccum must be float or half"); + static_assert(std::is_same::value || std::is_same::value, "DTypeOut must be half or nv_bfloat16"); + static_assert(head_dim % 64 == 0, "head_dim must be a multiple of 64"); + static_assert(!fuse_v_mean || std::is_same::value, "fuse_v_mean only supports half"); + static_assert(CTA_Q / CTA_K <= 2); // for efficient causal implementation + + using DTypeOut2 = typename std::conditional::value, half2, nv_bfloat162>::type; + + constexpr uint32_t num_warps_q = CTA_Q / WARP_Q; + constexpr uint32_t num_warps_k = CTA_K / WARP_K; + constexpr uint32_t num_warps = num_warps_q * num_warps_k; + constexpr uint32_t num_tiles_q = WARP_Q / MMA_QK_M; + constexpr uint32_t num_tiles_k = WARP_K / MMA_QK_N; + constexpr uint32_t num_tiles_qk_inner = (DTypeQK == SADataType::kInt8) ? (head_dim / MMA_QK_K) : (head_dim / 2 / MMA_QK_K); + constexpr uint32_t num_tiles_v = head_dim / MMA_SV_N; + + constexpr uint32_t QK_SMEM_STRIDE = (DTypeQK == SADataType::kInt8) ? 
(head_dim) : (head_dim / 2); + constexpr uint32_t O_SMEM_STRIDE = head_dim; + constexpr uint32_t V_SMEM_STRIDE = head_dim; + + extern __shared__ int8_t smem[]; + + const uint32_t lane_id = get_lane_id(); + const uint32_t warp_id = get_warp_id(); + + // maximize L2 hit rate + const uint32_t batch_id = blockIdx.z; + const uint32_t bx = blockIdx.x; + const uint32_t num_qo_heads = gridDim.y; + const uint32_t head_id = blockIdx.y; + + // transfer to base 2 instead of base e with better numerical efficiency + sm_scale *= math::log2e; + + // RS holds the fragment of S + int32_t RS[num_tiles_q][num_tiles_k][8]; + DTypeSVAccum RO[num_tiles_q][num_tiles_v][8]; + float m[num_tiles_q][2]; // max + float d[num_tiles_q][2]; // denominator + + uint32_t q_scale_idx, k_scale_idx; + + if constexpr (Q_GRAN == QuantGranularity::kPerBlock) + { + const uint32_t num_block_q = gridDim.x; + q_scale_idx = batch_id * num_qo_heads * num_block_q + head_id * num_block_q + bx; + } + else if constexpr (Q_GRAN == QuantGranularity::kPerWarp) + { + const uint32_t num_warp_block_q = gridDim.x * num_warps_q; + q_scale_idx = batch_id * num_qo_heads * num_warp_block_q + head_id * num_warp_block_q + bx * num_warps_q + get_warp_idx_q(); + } + else if constexpr (Q_GRAN == QuantGranularity::kPerThread) + { + const uint32_t num_warp_block_q = gridDim.x * num_warps_q; + q_scale_idx = batch_id * num_qo_heads * (num_warp_block_q * 8) + head_id * (num_warp_block_q * 8) + bx * (num_warps_q * 8) + get_warp_idx_q() * 8 + lane_id / 4; + } + + if constexpr (K_GRAN == QuantGranularity::kPerBlock) + { + const uint32_t num_block_k = div_ceil(kv_len, CTA_K); + k_scale_idx = batch_id * (num_qo_heads / num_kv_groups) * num_block_k + (head_id / num_kv_groups) * num_block_k; + } + else if constexpr (K_GRAN == QuantGranularity::kPerWarp) + { + const uint32_t num_warp_block_k = div_ceil(kv_len, CTA_K) * (CTA_K / WARP_K); + k_scale_idx = batch_id * (num_qo_heads / num_kv_groups) * num_warp_block_k + (head_id / num_kv_groups) * num_warp_block_k + get_warp_idx_k(); + } + else if constexpr (K_GRAN == QuantGranularity::kPerThread) + { + const uint32_t num_warp_block_k = div_ceil(kv_len, CTA_K) * (CTA_K / WARP_K); + k_scale_idx = batch_id * (num_qo_heads / num_kv_groups) * (num_warp_block_k * 4) + (head_id / num_kv_groups) * (num_warp_block_k * 4) + get_warp_idx_k() * 4 + lane_id % 4; + } + + constexpr uint32_t k_scale_advance_offset = (K_GRAN == QuantGranularity::kPerBlock) ? 1 : (K_GRAN == QuantGranularity::kPerWarp) ? (CTA_K / WARP_K) : (CTA_K / WARP_K) * 4; + + // initialize o, m, d +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fv = 0; fv < num_tiles_v; fv++) + { + if constexpr (std::is_same::value) + { +#pragma unroll + for (uint32_t k = 0; k < 8; k++) + { + RO[fq][fv][k] = 0.0f; + } + } + else if constexpr (std::is_same::value) + { +#pragma unroll + for (uint32_t k = 0; k < 4; k++) + { + ((int32_t*)RO[fq][fv])[k] = 0; + } + } + } + } +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t k = 0; k < 2; k++) + { + m[fq][k] = -5000000.0f; + d[fq][k] = 1.0f; + } + } + + constexpr uint32_t K_smem_idx_offset = CTA_Q; + constexpr uint32_t V_smem_idx_offset = CTA_Q + CTA_K; + + constexpr SwizzleMode swizzle_mode_QK = (QK_SMEM_STRIDE == 32) ? SwizzleMode::k32B : (QK_SMEM_STRIDE == 64) ? 
SwizzleMode::k64B : SwizzleMode::k128B; + smem_t smem_Q(smem); + smem_t smem_K(smem + K_smem_idx_offset * QK_SMEM_STRIDE); + constexpr SwizzleMode swizzle_mode_V = (V_SMEM_STRIDE == 32) ? SwizzleMode::k64B : SwizzleMode::k128B; + smem_t smem_V(smem + V_smem_idx_offset * QK_SMEM_STRIDE); + constexpr SwizzleMode swizzle_mode_O = (O_SMEM_STRIDE == 32) ? SwizzleMode::k64B : SwizzleMode::k128B; + smem_t smem_O(smem); + + constexpr uint32_t global_to_shared_line_lanes_QK = (QK_SMEM_STRIDE == 32) ? 2 : (QK_SMEM_STRIDE == 64) ? 4 : 8; + constexpr uint32_t global_to_shared_copy_lines_per_warp_QK = (QK_SMEM_STRIDE == 32) ? 16 : (QK_SMEM_STRIDE == 64) ? 8 : 4; + constexpr uint32_t global_to_shared_line_lanes_V = (V_SMEM_STRIDE == 32) ? 4 : 8; + constexpr uint32_t global_to_shared_copy_lines_per_warp_V = (V_SMEM_STRIDE == 32) ? 8 : 4; + constexpr uint32_t global_to_shared_line_lanes_O = (O_SMEM_STRIDE == 32) ? 4 : 8; + constexpr uint32_t global_to_shared_copy_lines_per_warp_O = (O_SMEM_STRIDE == 32) ? 8 : 4; + + constexpr uint32_t QK_smem_iters_row = QK_SMEM_STRIDE / (global_to_shared_line_lanes_QK * PACK_SIZE_QK); + constexpr uint32_t Q_smem_iters_col = CTA_Q / (num_warps * global_to_shared_copy_lines_per_warp_QK); + constexpr uint32_t K_smem_iters_col = CTA_K / (num_warps * global_to_shared_copy_lines_per_warp_QK); + constexpr uint32_t V_smem_iters_row = V_SMEM_STRIDE / (global_to_shared_line_lanes_V * PACK_SIZE_V); + constexpr uint32_t V_smem_iters_col = CTA_K / (num_warps * global_to_shared_copy_lines_per_warp_V); + constexpr uint32_t O_smem_iters_row = O_SMEM_STRIDE / (global_to_shared_line_lanes_O * PACK_SIZE_O); + constexpr uint32_t O_smem_iters_col = CTA_Q / (num_warps * global_to_shared_copy_lines_per_warp_O); + + int8_t *Q_lane_base_ptr = Q + batch_id * stride_bz_q + head_id * stride_h_q + (bx * CTA_Q + CTA_Q / num_warps * warp_id + lane_id / global_to_shared_line_lanes_QK) * stride_seq_q + (lane_id % global_to_shared_line_lanes_QK) * PACK_SIZE_QK; + int8_t *K_lane_base_ptr = K + batch_id * stride_bz_k + (head_id / num_kv_groups) * stride_h_k + (CTA_K / num_warps * warp_id + lane_id / global_to_shared_line_lanes_QK) * stride_seq_k + (lane_id % global_to_shared_line_lanes_QK) * PACK_SIZE_QK; + half *V_lane_base_ptr = V + batch_id * stride_bz_v + (head_id / num_kv_groups) * stride_h_v + (CTA_K / num_warps * warp_id + lane_id / global_to_shared_line_lanes_V) * stride_seq_v + (lane_id % global_to_shared_line_lanes_V) * PACK_SIZE_V; + uint32_t Q_smem_offset_load = smem_Q.get_permuted_offset(warp_id * global_to_shared_copy_lines_per_warp_QK * Q_smem_iters_col + lane_id / global_to_shared_line_lanes_QK, lane_id % global_to_shared_line_lanes_QK); + uint32_t K_smem_offset_load = smem_K.get_permuted_offset(warp_id * global_to_shared_copy_lines_per_warp_QK * K_smem_iters_col + lane_id / global_to_shared_line_lanes_QK, lane_id % global_to_shared_line_lanes_QK); + uint32_t V_smem_offset_load = smem_V.get_permuted_offset(warp_id * global_to_shared_copy_lines_per_warp_V * V_smem_iters_col + lane_id / global_to_shared_line_lanes_V, lane_id % global_to_shared_line_lanes_V); + + uint32_t Q_smem_offset_mma = smem_Q.get_permuted_offset(get_warp_idx_q() * WARP_Q + lane_id % 16, lane_id / 16); + uint32_t K_smem_offset_mma = smem_K.get_permuted_offset(get_warp_idx_k() * WARP_K + lane_id % 8 + (lane_id / 16) * 8, (lane_id / 8) % 2); + uint32_t V_smem_offset_mma = smem_V.get_permuted_offset(get_warp_idx_k() * WARP_K + lane_id % 16, lane_id / 16); + + // for causal masking + uint32_t Q_idx_lane_base = bx * CTA_Q + 
get_warp_idx_q() * WARP_Q + lane_id / 4; + uint32_t K_idx_lane_base = get_warp_idx_k() * WARP_K + 2 * (lane_id % 4); + + // for loading + uint32_t Q_load_idx_lane_base = bx * CTA_Q + CTA_Q / num_warps * warp_id + lane_id / global_to_shared_line_lanes_QK; + uint32_t K_load_idx_lane_base = CTA_K / num_warps * warp_id + lane_id / global_to_shared_line_lanes_QK; + uint32_t V_load_idx_lane_base = CTA_K / num_warps * warp_id + lane_id / global_to_shared_line_lanes_V; + + const uint32_t num_iterations = div_ceil( + mask_mode == MaskMode::kCausal + ? min(kv_len, (bx + 1) * CTA_Q) + : kv_len, + CTA_K); + + // load Q with predicate + load_global_to_share( + &Q_lane_base_ptr, Q_smem_offset_load, stride_seq_q, smem_Q, Q_load_idx_lane_base, qo_len); + cp_async::commit_group(); + cp_async::wait_group<0>(); + __syncthreads(); + + // for num_tiles_qk_inner = 1, we load all Qs in register + uint32_t RQ[num_tiles_q][4]; + if constexpr (num_tiles_qk_inner == 1) + { +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { + smem_Q.ldmatrix_m8n8x4(Q_smem_offset_mma, RQ[fq]); + Q_smem_offset_mma = smem_Q.advance_offset_by_row<16>(Q_smem_offset_mma); + } + } + + // load K with predicate + load_global_to_share( + &K_lane_base_ptr, K_smem_offset_load, stride_seq_k, smem_K, K_load_idx_lane_base, kv_len); + cp_async::commit_group(); + + float q_scale = Q_scale[q_scale_idx]; + + float original_sm_scale = sm_scale; + float dequant_scale = q_scale * K_scale[k_scale_idx + 0 * k_scale_advance_offset]; + + sm_scale = original_sm_scale * dequant_scale; + + // load V with predicate + load_global_to_share( + &V_lane_base_ptr, V_smem_offset_load, stride_seq_v, smem_V, V_load_idx_lane_base, kv_len); + cp_async::commit_group(); + + K_load_idx_lane_base += CTA_K; + V_load_idx_lane_base += CTA_K; + +#pragma unroll + for (uint32_t iter = 1; iter < num_iterations - 1; iter++) + { + // ensure K is ready + cp_async::wait_group<1>(); + __syncthreads(); + + // compute QK^T + if constexpr (num_tiles_qk_inner == 1) + { + compute_int_qk( + smem_K, RS, RQ, K_smem_offset_mma); + } + else + { + compute_int_qk( + smem_Q, smem_K, RS, Q_smem_offset_mma, K_smem_offset_mma); + } + + float RS_f32[num_tiles_q][num_tiles_k][8]; + +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fk = 0; fk < num_tiles_k; fk++) + { +#pragma unroll + for (uint32_t k = 0; k < 8; k++) + { + RS_f32[fq][fk][k] = __int2float_rz(RS[fq][fk][k]); + } + } + } + + // do not apply causal mask and out of bound mask for these iterations + K_idx_lane_base += CTA_K; + + if constexpr (std::is_same::value) + { + update_mdo(RS_f32, RO, m, d, sm_scale); + } + else if constexpr (std::is_same::value) + { + update_mdo(RS_f32, RO, m, d, sm_scale); + } + + if constexpr (DenominatorAccumUnit == ComputeUnit::kCudaCore) + { + accumulate_d(RS_f32, d); + } + + uint32_t RS_f16[num_tiles_q][num_tiles_k][4]; + RS_32_to_16(RS_f32, RS_f16); + + if constexpr (DenominatorAccumUnit == ComputeUnit::kTensorCore) + { + accumulate_d(RS_f16, d); + } + + __syncthreads(); + + // load K + load_global_to_share( + &K_lane_base_ptr, K_smem_offset_load, stride_seq_k, smem_K); + cp_async::commit_group(); + + dequant_scale = q_scale * K_scale[k_scale_idx + iter * k_scale_advance_offset]; + sm_scale = original_sm_scale * dequant_scale; + + // ensure V is ready + cp_async::wait_group<1>(); + __syncthreads(); + + if constexpr (!use_inst_buffer) + { + compute_fp16_sv_permuted( + smem_V, RS_f16, RO, d, V_smem_offset_mma); + } + else + { + 
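+        // use_inst_buffer path: per the static_assert at the top of this kernel it is only valid
+        // with an fp32 accumulator (DTypeSVAccum == float). Judging by the name and the call site,
+        // compute_fp16_sv_permuted_inst_buf (defined in sageattn_utils.cuh) appears to accumulate
+        // the MMA result through an intermediate register buffer before adding into the fp32 RO;
+        // this is an assumption from the call site, see the utils header for the actual contract.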
compute_fp16_sv_permuted_inst_buf( + smem_V, RS_f16, RO, d, V_smem_offset_mma); + } + + __syncthreads(); + // load V + load_global_to_share( + &V_lane_base_ptr, V_smem_offset_load, stride_seq_v, smem_V); + cp_async::commit_group(); + K_load_idx_lane_base += CTA_K; + V_load_idx_lane_base += CTA_K; + } + + // second last iter, apply causal mask + if (num_iterations > 1) + { + // ensure K is ready + cp_async::wait_group<1>(); + __syncthreads(); + + // compute QK^T + if constexpr (num_tiles_qk_inner == 1) + { + compute_int_qk( + smem_K, RS, RQ, K_smem_offset_mma); + } + else + { + compute_int_qk( + smem_Q, smem_K, RS, Q_smem_offset_mma, K_smem_offset_mma); + } + + float RS_f32[num_tiles_q][num_tiles_k][8]; + +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fk = 0; fk < num_tiles_k; fk++) + { +#pragma unroll + for (uint32_t k = 0; k < 8; k++) + { + RS_f32[fq][fk][k] = __int2float_rz(RS[fq][fk][k]) * dequant_scale; + } + } + } + + if constexpr (mask_mode == MaskMode::kCausal) + { + apply_causal_mask(Q_idx_lane_base, K_idx_lane_base, RS_f32); + } + // apply_out_of_bound_mask(K_idx_lane_base, RS_f32, kv_len); + K_idx_lane_base += CTA_K; + + if constexpr (std::is_same::value) + { + update_mdo(RS_f32, RO, m, d, original_sm_scale); + } + else if constexpr (std::is_same::value) + { + update_mdo(RS_f32, RO, m, d, original_sm_scale); + } + + if constexpr (DenominatorAccumUnit == ComputeUnit::kCudaCore) + { + accumulate_d(RS_f32, d); + } + + uint32_t RS_f16[num_tiles_q][num_tiles_k][4]; + RS_32_to_16(RS_f32, RS_f16); + + if constexpr (DenominatorAccumUnit == ComputeUnit::kTensorCore) + { + accumulate_d(RS_f16, d); + } + + __syncthreads(); + + // load K with predicate + load_global_to_share( + &K_lane_base_ptr, K_smem_offset_load, stride_seq_k, smem_K, K_load_idx_lane_base, kv_len); + cp_async::commit_group(); + + dequant_scale = q_scale * K_scale[k_scale_idx + (num_iterations - 1) * k_scale_advance_offset]; + sm_scale = original_sm_scale * dequant_scale; + + // ensure V is ready + cp_async::wait_group<1>(); + __syncthreads(); + + if constexpr (!use_inst_buffer) + { + compute_fp16_sv_permuted( + smem_V, RS_f16, RO, d, V_smem_offset_mma); + } + else + { + compute_fp16_sv_permuted_inst_buf( + smem_V, RS_f16, RO, d, V_smem_offset_mma); + } + + __syncthreads(); + // load V with predicate + load_global_to_share( + &V_lane_base_ptr, V_smem_offset_load, stride_seq_v, smem_V, V_load_idx_lane_base, kv_len); + cp_async::commit_group(); + K_load_idx_lane_base += CTA_K; + V_load_idx_lane_base += CTA_K; + } + + // last iter, apply causal mask and out of bound mask + { + // ensure K is ready + cp_async::wait_group<1>(); + __syncthreads(); + + // compute QK^T + if constexpr (num_tiles_qk_inner == 1) + { + compute_int_qk( + smem_K, RS, RQ, K_smem_offset_mma); + } + else + { + compute_int_qk( + smem_Q, smem_K, RS, Q_smem_offset_mma, K_smem_offset_mma); + } + + float RS_f32[num_tiles_q][num_tiles_k][8]; + +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fk = 0; fk < num_tiles_k; fk++) + { +#pragma unroll + for (uint32_t k = 0; k < 8; k++) + { + RS_f32[fq][fk][k] = __int2float_rz(RS[fq][fk][k]) * dequant_scale; + } + } + } + + if constexpr (mask_mode == MaskMode::kCausal) + { + apply_causal_mask(Q_idx_lane_base, K_idx_lane_base, RS_f32); + } + // check out of bound in the last iter + apply_out_of_bound_mask(K_idx_lane_base, RS_f32, kv_len); + K_idx_lane_base += CTA_K; + + if constexpr (std::is_same::value) + { + 
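+      // Why original_sm_scale here: in this last iteration (and the second-to-last one) the int32
+      // scores were already multiplied by dequant_scale when converted to float above, so only the
+      // softmax scale (already folded with log2e at the top of the kernel) remains to be applied.
+      // In the steady-state loop the two factors are folded together instead
+      // (sm_scale = original_sm_scale * dequant_scale) and applied in one go; either grouping of
+      //   (q_scale * k_scale) * int_score * original_sm_scale
+      // yields the same logit value.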
update_mdo(RS_f32, RO, m, d, original_sm_scale); + } + else if constexpr (std::is_same::value) + { + update_mdo(RS_f32, RO, m, d, original_sm_scale); + } + + if constexpr (DenominatorAccumUnit == ComputeUnit::kCudaCore) + { + accumulate_d(RS_f32, d); + } + + uint32_t RS_f16[num_tiles_q][num_tiles_k][4]; + RS_32_to_16(RS_f32, RS_f16); + + if constexpr (DenominatorAccumUnit == ComputeUnit::kTensorCore) + { + accumulate_d(RS_f16, d); + } + + // ensure V is ready + cp_async::wait_group<0>(); + __syncthreads(); + + if constexpr (!use_inst_buffer) + { + compute_fp16_sv_permuted( + smem_V, RS_f16, RO, d, V_smem_offset_mma); + } + else + { + compute_fp16_sv_permuted_inst_buf( + smem_V, RS_f16, RO, d, V_smem_offset_mma); + } + + __syncthreads(); + + } + + // TODO: thread block sync mdo state for num_warps_k > 0 + + normalize_d(RO, m, d); + + // save the result + // if (get_warp_idx_k() == 0) + // { + + // convert half to bfloat16 + if constexpr (std::is_same::value && std::is_same::value) + { +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fv = 0; fv < num_tiles_v; fv++) + { + ((nv_bfloat162*)RO[fq][fv])[0] = __float22bfloat162_rn(__half22float2(((half2*)RO[fq][fv])[0])); + ((nv_bfloat162*)RO[fq][fv])[1] = __float22bfloat162_rn(__half22float2(((half2*)RO[fq][fv])[1])); + ((nv_bfloat162*)RO[fq][fv])[2] = __float22bfloat162_rn(__half22float2(((half2*)RO[fq][fv])[2])); + ((nv_bfloat162*)RO[fq][fv])[3] = __float22bfloat162_rn(__half22float2(((half2*)RO[fq][fv])[3])); + } + } + } + + // add v_mean + if constexpr (fuse_v_mean) + { + DTypeOut2 v_mean[2]; + DTypeOut *V_mean_lane_ptr = V_mean + batch_id * (num_qo_heads / num_kv_groups) * head_dim + (head_id / num_kv_groups) * head_dim + lane_id % 4 * 2; +#pragma unroll + for (uint32_t fv = 0; fv < num_tiles_v; fv++) + { + v_mean[0] = *((DTypeOut2*)(V_mean_lane_ptr + fv * 16)); + v_mean[1] = *((DTypeOut2*)(V_mean_lane_ptr + 8 + fv * 16)); +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { + ((DTypeOut2*)RO[fq][fv])[0] = __hadd2(((DTypeOut2*)RO[fq][fv])[0], v_mean[0]); + ((DTypeOut2*)RO[fq][fv])[1] = __hadd2(((DTypeOut2*)RO[fq][fv])[1], v_mean[0]); + ((DTypeOut2*)RO[fq][fv])[2] = __hadd2(((DTypeOut2*)RO[fq][fv])[2], v_mean[1]); + ((DTypeOut2*)RO[fq][fv])[3] = __hadd2(((DTypeOut2*)RO[fq][fv])[3], v_mean[1]); + } + } + } + + // save the result to shared memory + uint32_t smem_O_row_base = get_warp_idx_q() * WARP_Q + lane_id / 4; +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fv = 0; fv < num_tiles_v; fv++) + { + uint32_t offset_O = smem_O.get_permuted_offset(smem_O_row_base + fq * MMA_QK_M, fv * (MMA_SV_N / PACK_SIZE_O)); + + if constexpr (std::is_same::value) + { + // convert RO to half + uint32_t RO_f16[4]; +#pragma unroll + for (uint32_t k = 0; k < 4; k++) + { + if constexpr (std::is_same::value) + { + ((half2*)RO_f16)[k] = __float22half2_rn(((float2*)RO[fq][fv])[k]); + } + else if constexpr (std::is_same::value) + { + ((nv_bfloat162*)RO_f16)[k] = __float22bfloat162_rn(((float2*)RO[fq][fv])[k]); + } + } + + ((int32_t*)(smem_O.base + offset_O))[lane_id % 4] = RO_f16[0]; + ((int32_t*)(smem_O.base + offset_O + 8 * (O_SMEM_STRIDE / PACK_SIZE_O)))[lane_id % 4] = RO_f16[1]; + + // ! 
permuted, make sure you know what you are doing + ((int32_t*)(smem_O.base + (offset_O ^ 0x1)))[lane_id % 4] = RO_f16[2]; + ((int32_t*)(smem_O.base + (offset_O ^ 0x1) + 8 * (O_SMEM_STRIDE / PACK_SIZE_O)))[lane_id % 4] = RO_f16[3]; + } + else if constexpr (std::is_same::value) + { + ((int32_t*)(smem_O.base + offset_O))[lane_id % 4] = ((int32_t*)RO[fq][fv])[0]; + ((int32_t*)(smem_O.base + offset_O + 8 * (O_SMEM_STRIDE / PACK_SIZE_O)))[lane_id % 4] = ((int32_t*)RO[fq][fv])[1]; + + // ! permuted, make sure you know what you are doing + ((int32_t*)(smem_O.base + (offset_O ^ 0x1)))[lane_id % 4] = ((int32_t*)RO[fq][fv])[2]; + ((int32_t*)(smem_O.base + (offset_O ^ 0x1) + 8 * (O_SMEM_STRIDE / PACK_SIZE_O)))[lane_id % 4] = ((int32_t*)RO[fq][fv])[3]; + } + } + } + + // ! do we need to sync here? + __syncwarp(); + + // shared memory to global memory + DTypeOut *O_lane_ptr = O + batch_id * stride_bz_o + head_id * stride_h_o + (bx * CTA_Q + WARP_Q * get_warp_idx_q() + lane_id / global_to_shared_line_lanes_O) * stride_seq_o + lane_id % global_to_shared_line_lanes_O * PACK_SIZE_O; + uint32_t offset_O = smem_O.get_permuted_offset(get_warp_idx_q() * WARP_Q + lane_id / global_to_shared_line_lanes_O, lane_id % global_to_shared_line_lanes_O); + uint32_t O_load_idx_lane_base = bx * CTA_Q + CTA_Q / num_warps * warp_id + lane_id / global_to_shared_line_lanes_O; + +#pragma unroll + for (uint32_t i = 0; i < O_smem_iters_col; i++) + { +#pragma unroll + for (uint32_t j = 0; j < O_smem_iters_row; j++) + { + if (O_load_idx_lane_base < qo_len) + { + smem_O.store_128b(offset_O, O_lane_ptr); + } + O_lane_ptr += (global_to_shared_line_lanes_O * PACK_SIZE_O); + offset_O = smem_O.advance_offset_by_column(offset_O); + } + + offset_O = smem_O.advance_offset_by_row(offset_O - (O_smem_iters_row * global_to_shared_line_lanes_O)); + O_lane_ptr += ((global_to_shared_copy_lines_per_warp_O * stride_seq_o) - (O_smem_iters_row * global_to_shared_line_lanes_O * PACK_SIZE_O)); + O_load_idx_lane_base += global_to_shared_copy_lines_per_warp_O; + } + + if constexpr (return_lse) + { + // ! 
this only works for num_tiles_q = 2 + uint32_t lse_idx = bx * CTA_Q + lane_id / 4 + 8 * (lane_id % 4) + WARP_Q * get_warp_idx_q(); + float *lse_lane_ptr = Lse + batch_id * (qo_len * num_qo_heads) + head_id * qo_len + lse_idx; + uint32_t fq = (lane_id % 4) / 2; + uint32_t k = (lane_id % 4) % 2; + + if (lse_idx < qo_len) + { + lse_lane_ptr[0] = (math::ptx_log2(d[fq][k]) + m[fq][k]); + } + } +} + +// tensor_layout 0 for [B, N, H, D], 1 for [B, H, N, D] +// impl -> see sageattn.h file +std::vector qk_int8_sv_f16_accum_f32_attn_fwd( + paddle::Tensor& query, + paddle::Tensor& key, + paddle::Tensor& value, + paddle::Tensor& output, + paddle::Tensor& query_scale, + paddle::Tensor& key_scale, + int tensor_layout, + int is_causal, + int qk_quant_gran, + float sm_scale, + int return_lse) +{ + CHECK_CUDA(query); + CHECK_CUDA(key); + CHECK_CUDA(value); + CHECK_CUDA(output); + CHECK_CUDA(query_scale); + CHECK_CUDA(key_scale); + + CHECK_CONTIGUOUS(query); + CHECK_CONTIGUOUS(key); + CHECK_LASTDIM_CONTIGUOUS(value); + CHECK_LASTDIM_CONTIGUOUS(output); + CHECK_CONTIGUOUS(query_scale); + CHECK_CONTIGUOUS(key_scale); + + CHECK_DTYPE(query, paddle::DataType::INT8); + CHECK_DTYPE(key, paddle::DataType::INT8); + CHECK_DTYPE(value, paddle::DataType::FLOAT16); + CHECK_DTYPE(query_scale, paddle::DataType::FLOAT32); + CHECK_DTYPE(key_scale, paddle::DataType::FLOAT32); + + CHECK_DIMS(query, 4); + CHECK_DIMS(key, 4); + CHECK_DIMS(value, 4); + CHECK_DIMS(output, 4); + CHECK_DIMS(query_scale, 3); + CHECK_DIMS(key_scale, 3); + + const int head_dim = query.shape()[3]; + const int batch_size = query.shape()[0]; + + int stride_bz_q = query.strides()[0]; + int stride_bz_k = key.strides()[0]; + int stride_bz_v = value.strides()[0]; + int stride_bz_o = output.strides()[0]; + + int qo_len, kv_len, num_qo_heads, num_kv_heads; + int stride_seq_q, stride_seq_k, stride_seq_v, stride_seq_o; + int stride_h_q, stride_h_k, stride_h_v, stride_h_o; + + if (tensor_layout == 0) + { + qo_len = query.shape()[1]; + kv_len = key.shape()[1]; + num_qo_heads = query.shape()[2]; + num_kv_heads = key.shape()[2]; + CHECK_SHAPE(key, batch_size, kv_len, num_kv_heads, head_dim); + CHECK_SHAPE(value, batch_size, kv_len, num_kv_heads, head_dim); + + stride_seq_q = query.strides()[1]; + stride_seq_k = key.strides()[1]; + stride_seq_v = value.strides()[1]; + stride_seq_o = output.strides()[1]; + + stride_h_q = query.strides()[2]; + stride_h_k = key.strides()[2]; + stride_h_v = value.strides()[2]; + stride_h_o = output.strides()[2]; + } + else if (tensor_layout == 1) + { + qo_len = query.shape()[2]; + kv_len = key.shape()[2]; + num_qo_heads = query.shape()[1]; + num_kv_heads = key.shape()[1]; + CHECK_SHAPE(key, batch_size, num_kv_heads, kv_len, head_dim); + CHECK_SHAPE(value, batch_size, num_kv_heads, kv_len, head_dim); + + stride_seq_q = query.strides()[2]; + stride_seq_k = key.strides()[2]; + stride_seq_v = value.strides()[2]; + stride_seq_o = output.strides()[2]; + + stride_h_q = query.strides()[1]; + stride_h_k = key.strides()[1]; + stride_h_v = value.strides()[1]; + stride_h_o = output.strides()[1]; + } + else + { + throw std::invalid_argument("tensor_layout must be 0 or 1"); + } + + if (num_qo_heads % num_kv_heads != 0) { + std::ostringstream err_msg; + err_msg << "num_qo_heads (" << num_qo_heads << ") must be divisible by num_kv_heads (" << num_kv_heads << ")"; + throw std::invalid_argument(err_msg.str()); + } + + const int num_kv_groups = num_qo_heads / num_kv_heads; + + paddle::Tensor lse = paddle::empty({1}); + if (return_lse) + { + lse = 
paddle::empty({batch_size, num_qo_heads, qo_len}, paddle::DataType::FLOAT32); + } + + auto output_dtype = output.dtype(); + + DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { + DISPATCH_CAUSAL(is_causal, IS_CAUSAL, { + DISPATCH_QK_QUANT_GRAN(qk_quant_gran, QK_QUANT_GRAN, { + DISPATCH_RETURN_LSE(return_lse, RETURN_LSE, { + DISPATCH_PADDLE_DTYPE_TO_CTYPE_FP16(output_dtype, DTypeOut, { + constexpr int CTA_Q = 128; + constexpr int CTA_K = 64; + constexpr int WARP_Q = 32; + constexpr int WARP_K = 64; + + constexpr MaskMode mask_mode = IS_CAUSAL ? MaskMode::kCausal : MaskMode::kNone; + + if constexpr (QK_QUANT_GRAN == static_cast(QuantGranularity::kPerWarp)) + { + CHECK_SHAPE(query_scale, batch_size, num_qo_heads, static_cast(div_ceil(qo_len, CTA_Q) * (CTA_Q / WARP_Q))); + CHECK_SHAPE(key_scale, batch_size, num_kv_heads, static_cast(div_ceil(kv_len, CTA_K) * (CTA_K / WARP_K))); + } + else if constexpr (QK_QUANT_GRAN == static_cast(QuantGranularity::kPerThread)) + { + CHECK_SHAPE(query_scale, batch_size, num_qo_heads, static_cast(div_ceil(qo_len, CTA_Q) * (CTA_Q / WARP_Q) * 8)); + CHECK_SHAPE(key_scale, batch_size, num_kv_heads, static_cast(div_ceil(kv_len, CTA_K) * (CTA_K / WARP_K) * 4)); + } + else + { + static_assert(QK_QUANT_GRAN == static_cast(QuantGranularity::kPerWarp) || QK_QUANT_GRAN == static_cast(QuantGranularity::kPerThread), "Unsupported quantization granularity"); + } + + // smem_Q smem_K smem_V smem_O + size_t smem_max = std::max(CTA_Q * HEAD_DIM * sizeof(int8_t) + CTA_K * HEAD_DIM * sizeof(int8_t) + CTA_K * HEAD_DIM * sizeof(half), CTA_Q * HEAD_DIM * sizeof(half)); + + auto kernel_func = qk_int_sv_f16_attn_kernel(QK_QUANT_GRAN), static_cast(QK_QUANT_GRAN), float, false, DTypeOut, ComputeUnit::kTensorCore, + mask_mode, RETURN_LSE, false>; + + cudaFuncSetAttribute(kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_max); + + dim3 grid(div_ceil(qo_len, CTA_Q), num_qo_heads, batch_size); + dim3 block(32, (CTA_Q / WARP_Q) * (CTA_K / WARP_K)); + + kernel_func<<>>( + query.data(), + key.data(), + reinterpret_cast(value.data()), + reinterpret_cast(output.data()), + (RETURN_LSE) ? 
reinterpret_cast(lse.data()) : nullptr, + reinterpret_cast(query_scale.data()), + reinterpret_cast(key_scale.data()), + nullptr, + qo_len, + kv_len, + num_kv_groups, + stride_bz_q, stride_seq_q, stride_h_q, + stride_bz_k, stride_seq_k, stride_h_k, + stride_bz_v, stride_seq_v, stride_h_v, + stride_bz_o, stride_seq_o, stride_h_o, + sm_scale); + }); + }); + }); + }); + }); + + return {lse}; +} \ No newline at end of file diff --git a/csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_kernel.cu b/csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_kernel.cu new file mode 100644 index 000000000000..eca067e1aef0 --- /dev/null +++ b/csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_kernel.cu @@ -0,0 +1,1239 @@ +#include + +#include "paddle/extension.h" + +#include "sageattn_utils.cuh" + +#define PACK_SIZE_QK 16 // as if it is int8 +#define PACK_SIZE_V 16 // fp8 +#define PACK_SIZE_O 8 // fp16 + +// treat as if int8 tensor core +#define MMA_QK_M 16 +#define MMA_QK_N 16 +#define MMA_QK_K 32 + +// fp8 tensor core +#define MMA_SV_M 16 +#define MMA_SV_N 16 +#define MMA_SV_K 32 + +// kernel impl +// qk_int_sv_f8 impl +// inner impl +template +__global__ void qk_int_sv_f8_attn_kernel(int8_t *__restrict__ Q, int8_t *__restrict__ K, int8_t *__restrict__ V, DTypeOut *__restrict__ O, float *__restrict__ Lse, + float *__restrict__ Q_scale, float *__restrict__ K_scale, float *__restrict__ V_scale, float *__restrict__ V_mean, + const uint32_t qo_len, const uint32_t kv_len, const uint32_t num_kv_groups, + const uint32_t stride_bz_q, const uint32_t stride_seq_q, const uint32_t stride_h_q, + const uint32_t stride_bz_k, const uint32_t stride_seq_k, const uint32_t stride_h_k, + const uint32_t stride_bz_v, const uint32_t stride_h_v, const uint32_t stride_d_v, + const uint32_t stride_bz_o, const uint32_t stride_seq_o, const uint32_t stride_h_o, + float sm_scale) +{ + // compile time check + static_assert(DTypeQK == SADataType::kInt8 || DTypeQK == SADataType::kInt4, "DTypeQK must be int8 or int4"); + static_assert(Q_GRAN == QuantGranularity::kPerBlock || Q_GRAN == QuantGranularity::kPerWarp || Q_GRAN == QuantGranularity::kPerThread, "Q_GRAN must be kPerBlock, kPerWarp or kPerThread"); + static_assert(K_GRAN == QuantGranularity::kPerBlock || K_GRAN == QuantGranularity::kPerWarp || K_GRAN == QuantGranularity::kPerThread, "K_GRAN must be kPerBlock, kPerWarp or kPerThread"); + static_assert(head_dim % 64 == 0, "head_dim must be a multiple of 64"); + static_assert(std::is_same::value, "DTypeSVAccum must be float, half is WIP"); + static_assert(std::is_same::value || std::is_same::value, "DTypeOut must be half or nv_bfloat16"); + static_assert(CTA_K % 64 == 0); + static_assert(CTA_Q / CTA_K <= 2); // for efficient causal implementation + + constexpr uint32_t num_warps_q = CTA_Q / WARP_Q; + constexpr uint32_t num_warps_k = CTA_K / WARP_K; + constexpr uint32_t num_warps = num_warps_q * num_warps_k; + constexpr uint32_t num_tiles_q = WARP_Q / MMA_QK_M; + constexpr uint32_t num_tiles_k = WARP_K / MMA_QK_N; + constexpr uint32_t num_tiles_qk_inner = (DTypeQK == SADataType::kInt8) ? (head_dim / MMA_QK_K) : (head_dim / 2 / MMA_QK_K); + constexpr uint32_t num_tiles_v = head_dim / MMA_SV_N; + + constexpr uint32_t QK_SMEM_STRIDE = (DTypeQK == SADataType::kInt8) ? 
(head_dim) : (head_dim / 2); + constexpr uint32_t O_SMEM_STRIDE = head_dim; + // for fp16: head_dim + constexpr uint32_t V_SMEM_STRIDE = CTA_K; + + extern __shared__ int8_t smem[]; + + const uint32_t lane_id = get_lane_id(); + const uint32_t warp_id = get_warp_id(); + + // maximize L2 hit rate + const uint32_t batch_id = blockIdx.z; + const uint32_t bx = blockIdx.x; + const uint32_t num_qo_heads = gridDim.y; + const uint32_t head_id = blockIdx.y; + + // transfer to base 2 instead of base e with better numerical efficiency + sm_scale *= math::log2e; + + // RS holds the fragment of S + int32_t RS[num_tiles_q][num_tiles_k][8]; + DTypeSVAccum RO[num_tiles_q][num_tiles_v][8]; + float m[num_tiles_q][2]; // max + float d[num_tiles_q][2]; // denominator + + uint32_t q_scale_idx, k_scale_idx; + + if constexpr (Q_GRAN == QuantGranularity::kPerBlock) + { + const uint32_t num_block_q = gridDim.x; + q_scale_idx = batch_id * num_qo_heads * num_block_q + head_id * num_block_q + bx; + } + else if constexpr (Q_GRAN == QuantGranularity::kPerWarp) + { + const uint32_t num_warp_block_q = gridDim.x * num_warps_q; + q_scale_idx = batch_id * num_qo_heads * num_warp_block_q + head_id * num_warp_block_q + bx * num_warps_q + get_warp_idx_q(); + } + else if constexpr (Q_GRAN == QuantGranularity::kPerThread) + { + const uint32_t num_warp_block_q = gridDim.x * num_warps_q; + q_scale_idx = batch_id * num_qo_heads * (num_warp_block_q * 8) + head_id * (num_warp_block_q * 8) + bx * (num_warps_q * 8) + get_warp_idx_q() * 8 + lane_id / 4; + } + + if constexpr (K_GRAN == QuantGranularity::kPerBlock) + { + const uint32_t num_block_k = div_ceil(kv_len, CTA_K); + k_scale_idx = batch_id * (num_qo_heads / num_kv_groups) * num_block_k + (head_id / num_kv_groups) * num_block_k; + } + else if constexpr (K_GRAN == QuantGranularity::kPerWarp) + { + const uint32_t num_warp_block_k = div_ceil(kv_len, CTA_K) * (CTA_K / WARP_K); + k_scale_idx = batch_id * (num_qo_heads / num_kv_groups) * num_warp_block_k + (head_id / num_kv_groups) * num_warp_block_k + get_warp_idx_k(); + } + else if constexpr (K_GRAN == QuantGranularity::kPerThread) + { + const uint32_t num_warp_block_k = div_ceil(kv_len, CTA_K) * (CTA_K / WARP_K); + k_scale_idx = batch_id * (num_qo_heads / num_kv_groups) * (num_warp_block_k * 4) + (head_id / num_kv_groups) * (num_warp_block_k * 4) + get_warp_idx_k() * 4 + lane_id % 4; + } + + constexpr uint32_t k_scale_advance_offset = (K_GRAN == QuantGranularity::kPerBlock) ? 1 : (K_GRAN == QuantGranularity::kPerWarp) ? (CTA_K / WARP_K) : (CTA_K / WARP_K) * 4; + + // initialize o, m, d +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fv = 0; fv < num_tiles_v; fv++) + { + if constexpr (std::is_same::value) + { +#pragma unroll + for (uint32_t k = 0; k < 8; k++) + { + RO[fq][fv][k] = 0.0f; + } + } + else if constexpr (std::is_same::value) + { +#pragma unroll + for (uint32_t k = 0; k < 4; k++) + { + ((int32_t*)RO[fq][fv])[k] = 0; + } + } + } + } +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t k = 0; k < 2; k++) + { + m[fq][k] = -5000000.0f; + d[fq][k] = 1.0f; + } + } + + constexpr uint32_t K_smem_idx_offset = CTA_Q; + constexpr uint32_t V_smem_idx_offset = CTA_Q + CTA_K; + + constexpr SwizzleMode swizzle_mode_QK = (QK_SMEM_STRIDE == 32) ? SwizzleMode::k32B : (QK_SMEM_STRIDE == 64) ? 
SwizzleMode::k64B : SwizzleMode::k128B; + smem_t smem_Q(smem); + smem_t smem_K(smem + K_smem_idx_offset * QK_SMEM_STRIDE); + // for fp16: 32 + constexpr SwizzleMode swizzle_mode_V = (V_SMEM_STRIDE == 64) ? SwizzleMode::k64B : SwizzleMode::k128B; + smem_t smem_V(smem + V_smem_idx_offset * QK_SMEM_STRIDE); + constexpr SwizzleMode swizzle_mode_O = (O_SMEM_STRIDE == 32) ? SwizzleMode::k64B : SwizzleMode::k128B; + smem_t smem_O(smem); + + constexpr uint32_t global_to_shared_line_lanes_QK = (QK_SMEM_STRIDE == 32) ? 2 : (QK_SMEM_STRIDE == 64) ? 4 : 8; + constexpr uint32_t global_to_shared_copy_lines_per_warp_QK = (QK_SMEM_STRIDE == 32) ? 16 : (QK_SMEM_STRIDE == 64) ? 8 : 4; + // for fp16: 32 + constexpr uint32_t global_to_shared_line_lanes_V = (V_SMEM_STRIDE == 64) ? 4 : 8; + // for fp16: 32 + constexpr uint32_t global_to_shared_copy_lines_per_warp_V = (V_SMEM_STRIDE == 64) ? 8 : 4; + constexpr uint32_t global_to_shared_line_lanes_O = (O_SMEM_STRIDE == 32) ? 4 : 8; + constexpr uint32_t global_to_shared_copy_lines_per_warp_O = (O_SMEM_STRIDE == 32) ? 8 : 4; + + constexpr uint32_t QK_smem_iters_row = QK_SMEM_STRIDE / (global_to_shared_line_lanes_QK * PACK_SIZE_QK); + constexpr uint32_t Q_smem_iters_col = CTA_Q / (num_warps * global_to_shared_copy_lines_per_warp_QK); + constexpr uint32_t K_smem_iters_col = CTA_K / (num_warps * global_to_shared_copy_lines_per_warp_QK); + constexpr uint32_t V_smem_iters_row = V_SMEM_STRIDE / (global_to_shared_line_lanes_V * PACK_SIZE_V); + // for fp16: CTA_K + constexpr uint32_t V_smem_iters_col = head_dim / (num_warps * global_to_shared_copy_lines_per_warp_V); + constexpr uint32_t O_smem_iters_row = O_SMEM_STRIDE / (global_to_shared_line_lanes_O * PACK_SIZE_O); + constexpr uint32_t O_smem_iters_col = CTA_Q / (num_warps * global_to_shared_copy_lines_per_warp_O); + + int8_t *Q_lane_base_ptr = Q + batch_id * stride_bz_q + head_id * stride_h_q + (bx * CTA_Q + CTA_Q / num_warps * warp_id + lane_id / global_to_shared_line_lanes_QK) * stride_seq_q + (lane_id % global_to_shared_line_lanes_QK) * PACK_SIZE_QK; + int8_t *K_lane_base_ptr = K + batch_id * stride_bz_k + (head_id / num_kv_groups) * stride_h_k + (CTA_K / num_warps * warp_id + lane_id / global_to_shared_line_lanes_QK) * stride_seq_k + (lane_id % global_to_shared_line_lanes_QK) * PACK_SIZE_QK; + // for fp16: CTA_K / num_warps * warp_id * stride_seq_v + lane_id / global_to_shared_line_lanes_V * stride_seq_v + int8_t *V_lane_base_ptr = V + batch_id * stride_bz_v + (head_id / num_kv_groups) * stride_h_v + head_dim / num_warps * warp_id * stride_d_v + lane_id / global_to_shared_line_lanes_V * stride_d_v + (lane_id % global_to_shared_line_lanes_V) * PACK_SIZE_V; + uint32_t Q_smem_offset_load = smem_Q.get_permuted_offset(warp_id * global_to_shared_copy_lines_per_warp_QK * Q_smem_iters_col + lane_id / global_to_shared_line_lanes_QK, lane_id % global_to_shared_line_lanes_QK); + uint32_t K_smem_offset_load = smem_K.get_permuted_offset(warp_id * global_to_shared_copy_lines_per_warp_QK * K_smem_iters_col + lane_id / global_to_shared_line_lanes_QK, lane_id % global_to_shared_line_lanes_QK); + uint32_t V_smem_offset_load = smem_V.get_permuted_offset(warp_id * global_to_shared_copy_lines_per_warp_V * V_smem_iters_col + lane_id / global_to_shared_line_lanes_V, lane_id % global_to_shared_line_lanes_V); + + uint32_t Q_smem_offset_mma = smem_Q.get_permuted_offset(get_warp_idx_q() * WARP_Q + lane_id % 16, lane_id / 16); + uint32_t K_smem_offset_mma = smem_K.get_permuted_offset(get_warp_idx_k() * WARP_K + lane_id % 8 + (lane_id / 16) * 
8, (lane_id / 8) % 2); + // for fp 16: + // uint32_t V_smem_offset_mma = smem_V.get_permuted_offset(get_warp_idx_k() * WARP_K + lane_id % 16, lane_id / 16); + uint32_t V_smem_offset_mma = smem_V.get_permuted_offset(lane_id % 8 + (lane_id / 16) * 8, get_warp_idx_k() * WARP_K / PACK_SIZE_V + (lane_id / 8) % 2); + + // for causal masking + uint32_t Q_idx_lane_base = bx * CTA_Q + get_warp_idx_q() * WARP_Q + lane_id / 4; + uint32_t K_idx_lane_base = get_warp_idx_k() * WARP_K + 2 * (lane_id % 4); + + // for loading + uint32_t Q_load_idx_lane_base = bx * CTA_Q + CTA_Q / num_warps * warp_id + lane_id / global_to_shared_line_lanes_QK; + uint32_t K_load_idx_lane_base = CTA_K / num_warps * warp_id + lane_id / global_to_shared_line_lanes_QK; + + const uint32_t num_iterations = div_ceil( + mask_mode == MaskMode::kCausal + ? min(kv_len, (bx + 1) * CTA_Q) + : kv_len, + CTA_K); + + // load Q with predicate + load_global_to_share( + &Q_lane_base_ptr, Q_smem_offset_load, stride_seq_q, smem_Q, Q_load_idx_lane_base, qo_len); + cp_async::commit_group(); + cp_async::wait_group<0>(); + __syncthreads(); + + // for num_tiles_qk_inner = 1, we load all Qs in register + uint32_t RQ[num_tiles_q][4]; + if constexpr (num_tiles_qk_inner == 1) + { +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { + smem_Q.ldmatrix_m8n8x4(Q_smem_offset_mma, RQ[fq]); + Q_smem_offset_mma = smem_Q.advance_offset_by_row<16>(Q_smem_offset_mma); + } + } + + // load K with predicate + load_global_to_share( + &K_lane_base_ptr, K_smem_offset_load, stride_seq_k, smem_K, K_load_idx_lane_base, kv_len); + cp_async::commit_group(); + + float q_scale = Q_scale[q_scale_idx]; + + float original_sm_scale = sm_scale; + float dequant_scale = q_scale * K_scale[k_scale_idx + 0 * k_scale_advance_offset]; + + sm_scale = original_sm_scale * dequant_scale; + + // load V + // ! we assume that V is padded. If not, there might be illegal memory access or nan issue. 
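+ // Note: "padded" here means V's innermost (kv) dimension is expected to already be padded up to a
+ // multiple of CTA_K; the host-side launchers below assert value.shape()[3] >= div_ceil(kv_len, CTA_K) * CTA_K.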
+ // for fp16: + // load_global_to_share stride_seq_v + load_fp8_V_global_to_share( + &V_lane_base_ptr, V_smem_offset_load, stride_d_v, smem_V); + cp_async::commit_group(); + + K_load_idx_lane_base += CTA_K; + +#pragma unroll + for (uint32_t iter = 1; iter < num_iterations - 1; iter++) + { + // ensure K is ready + cp_async::wait_group<1>(); + __syncthreads(); + + // compute QK^T + if constexpr (num_tiles_qk_inner == 1) + { + compute_int_qk( + smem_K, RS, RQ, K_smem_offset_mma); + } + else + { + compute_int_qk( + smem_Q, smem_K, RS, Q_smem_offset_mma, K_smem_offset_mma); + } + float RS_f32[num_tiles_q][num_tiles_k][8]; + +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fk = 0; fk < num_tiles_k; fk++) + { +#pragma unroll + for (uint32_t k = 0; k < 8; k++) + { + RS_f32[fq][fk][k] = __int2float_rz(RS[fq][fk][k]); + } + } + } + + K_idx_lane_base += CTA_K; + + if constexpr (std::is_same::value) + { + update_mdo(RS_f32, RO, m, d, sm_scale); + } + else if constexpr (std::is_same::value) + { + update_mdo(RS_f32, RO, m, d, sm_scale); + } + + if constexpr (DenominatorAccumUnit == ComputeUnit::kCudaCore) + { + accumulate_d(RS_f32, d); + } + + uint32_t RS_f8[num_tiles_q][num_tiles_k / 2][4]; + RS_32_to_8(RS_f32, RS_f8); + + if constexpr (DenominatorAccumUnit == ComputeUnit::kTensorCore) + { + accumulate_d_f8(RS_f8, d); + } + + __syncthreads(); + + // load K without predicate + load_global_to_share( + &K_lane_base_ptr, K_smem_offset_load, stride_seq_k, smem_K); + cp_async::commit_group(); + + dequant_scale = q_scale * K_scale[k_scale_idx + iter * k_scale_advance_offset]; + sm_scale = original_sm_scale * dequant_scale; + + // ensure V is ready + cp_async::wait_group<1>(); + __syncthreads(); + + // for fp16: + // compute_fp16_sv_permuted( + // smem_V, RS_f16, RO, d, V_smem_offset_mma); + if constexpr (!use_inst_buffer) + { + compute_fp8_sv( + smem_V, RS_f8, RO, d); + } + else + { + compute_fp8_sv_inst_buf( + smem_V, RS_f8, RO, d); + } + __syncthreads(); + // load V + // for fp16: + // load_global_to_share stride_seq_v + load_fp8_V_global_to_share( + &V_lane_base_ptr, V_smem_offset_load, stride_d_v, smem_V); + cp_async::commit_group(); + + K_load_idx_lane_base += CTA_K; + } + + // second last iter, apply causal mask + if (num_iterations > 1) + { + // ensure K is ready + cp_async::wait_group<1>(); + __syncthreads(); + + // compute QK^T + if constexpr (num_tiles_qk_inner == 1) + { + compute_int_qk( + smem_K, RS, RQ, K_smem_offset_mma); + } + else + { + compute_int_qk( + smem_Q, smem_K, RS, Q_smem_offset_mma, K_smem_offset_mma); + } + + float RS_f32[num_tiles_q][num_tiles_k][8]; + +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fk = 0; fk < num_tiles_k; fk++) + { +#pragma unroll + for (uint32_t k = 0; k < 8; k++) + { + RS_f32[fq][fk][k] = __int2float_rz(RS[fq][fk][k]) * dequant_scale; + } + } + } + + if constexpr (mask_mode == MaskMode::kCausal) + { + apply_causal_mask(Q_idx_lane_base, K_idx_lane_base, RS_f32); + } + // apply_out_of_bound_mask(K_idx_lane_base, RS_f32, kv_len); + K_idx_lane_base += CTA_K; + + if constexpr (std::is_same::value) + { + update_mdo(RS_f32, RO, m, d, original_sm_scale); + } + else if constexpr (std::is_same::value) + { + update_mdo(RS_f32, RO, m, d, original_sm_scale); + } + + if constexpr (DenominatorAccumUnit == ComputeUnit::kCudaCore) + { + accumulate_d(RS_f32, d); + } + + uint32_t RS_f8[num_tiles_q][num_tiles_k / 2][4]; + RS_32_to_8(RS_f32, RS_f8); + + if constexpr 
(DenominatorAccumUnit == ComputeUnit::kTensorCore) + { + accumulate_d_f8(RS_f8, d); + } + + __syncthreads(); + + // load K with predicate + load_global_to_share( + &K_lane_base_ptr, K_smem_offset_load, stride_seq_k, smem_K, K_load_idx_lane_base, kv_len); + cp_async::commit_group(); + + dequant_scale = q_scale * K_scale[k_scale_idx + (num_iterations - 1) * k_scale_advance_offset]; + sm_scale = original_sm_scale * dequant_scale; + + // ensure V is ready + cp_async::wait_group<1>(); + __syncthreads(); + + // for fp16: + // compute_fp16_sv_permuted( + // smem_V, RS_f16, RO, d, V_smem_offset_mma); + if constexpr (!use_inst_buffer) + { + compute_fp8_sv( + smem_V, RS_f8, RO, d); + } + else + { + compute_fp8_sv_inst_buf( + smem_V, RS_f8, RO, d); + } + + __syncthreads(); + // load V + // for fp16: + // load_global_to_share stride_seq_v + load_fp8_V_global_to_share( + &V_lane_base_ptr, V_smem_offset_load, stride_d_v, smem_V); + cp_async::commit_group(); + K_load_idx_lane_base += CTA_K; + } + + // last iter, apply causal mask and out of bound mask + { + // ensure K is ready + cp_async::wait_group<1>(); + __syncthreads(); + + // compute QK^T + if constexpr (num_tiles_qk_inner == 1) + { + compute_int_qk( + smem_K, RS, RQ, K_smem_offset_mma); + } + else + { + compute_int_qk( + smem_Q, smem_K, RS, Q_smem_offset_mma, K_smem_offset_mma); + } + + float RS_f32[num_tiles_q][num_tiles_k][8]; + +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fk = 0; fk < num_tiles_k; fk++) + { +#pragma unroll + for (uint32_t k = 0; k < 8; k++) + { + RS_f32[fq][fk][k] = __int2float_rz(RS[fq][fk][k]) * dequant_scale; + } + } + } + + if constexpr (mask_mode == MaskMode::kCausal) + { + apply_causal_mask(Q_idx_lane_base, K_idx_lane_base, RS_f32); + } + apply_out_of_bound_mask(K_idx_lane_base, RS_f32, kv_len); + K_idx_lane_base += CTA_K; + + if constexpr (std::is_same::value) + { + update_mdo(RS_f32, RO, m, d, original_sm_scale); + } + else if constexpr (std::is_same::value) + { + update_mdo(RS_f32, RO, m, d, original_sm_scale); + } + + if constexpr (DenominatorAccumUnit == ComputeUnit::kCudaCore) + { + accumulate_d(RS_f32, d); + } + + uint32_t RS_f8[num_tiles_q][num_tiles_k / 2][4]; + RS_32_to_8(RS_f32, RS_f8); + + if constexpr (DenominatorAccumUnit == ComputeUnit::kTensorCore) + { + accumulate_d_f8(RS_f8, d); + } + + // ensure V is ready + cp_async::wait_group<0>(); + __syncthreads(); + + // for fp16: + // compute_fp16_sv_permuted( + // smem_V, RS_f16, RO, d, V_smem_offset_mma); + if constexpr (!use_inst_buffer) + { + compute_fp8_sv( + smem_V, RS_f8, RO, d); + } + else + { + compute_fp8_sv_inst_buf( + smem_V, RS_f8, RO, d); + } + + __syncthreads(); + + } + + // TODO: thread block sync mdo state for num_warps_k > 0. Then only one thread block needs to do the final saving. + + normalize_d(RO, m, d); + + // ! 
here we just implement the case for fp32 acumulation + if constexpr (fuse_v_scale) + { + float v_scale[4]; + float *V_scale_base_ptr = V_scale + batch_id * (num_qo_heads / num_kv_groups) * head_dim + (head_id / num_kv_groups) * head_dim + (lane_id % 4 ) * 2; +#pragma unroll + for (uint32_t fv = 0; fv < num_tiles_v; fv++) + { + ((float2*)v_scale)[0] = *((float2*)(V_scale_base_ptr + fv * 16)); + ((float2*)v_scale)[1] = *((float2*)(V_scale_base_ptr + fv * 16 + 8)); +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { + RO[fq][fv][0] *= v_scale[0]; + RO[fq][fv][1] *= v_scale[1]; + RO[fq][fv][2] *= v_scale[0]; + RO[fq][fv][3] *= v_scale[1]; + RO[fq][fv][4] *= v_scale[2]; + RO[fq][fv][5] *= v_scale[3]; + RO[fq][fv][6] *= v_scale[2]; + RO[fq][fv][7] *= v_scale[3]; + } + } + } + + if constexpr (fuse_v_mean) + { + float v_mean[4]; + float *V_mean_base_ptr = V_mean + batch_id * (num_qo_heads / num_kv_groups) * head_dim + (head_id / num_kv_groups) * head_dim + (lane_id % 4 ) * 2; +#pragma unroll + for (uint32_t fv = 0; fv < num_tiles_v; fv++) + { + ((float2*)v_mean)[0] = *((float2*)(V_mean_base_ptr + fv * 16)); + ((float2*)v_mean)[1] = *((float2*)(V_mean_base_ptr + fv * 16 + 8)); +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { + RO[fq][fv][0] += v_mean[0]; + RO[fq][fv][1] += v_mean[1]; + RO[fq][fv][2] += v_mean[0]; + RO[fq][fv][3] += v_mean[1]; + RO[fq][fv][4] += v_mean[2]; + RO[fq][fv][5] += v_mean[3]; + RO[fq][fv][6] += v_mean[2]; + RO[fq][fv][7] += v_mean[3]; + } + } + } + + // save the result to shared memory + uint32_t smem_O_row_base = get_warp_idx_q() * WARP_Q + lane_id / 4; +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fv = 0; fv < num_tiles_v; fv++) + { + uint32_t offset_O = smem_O.get_permuted_offset(smem_O_row_base + fq * MMA_QK_M, fv * (MMA_SV_N / PACK_SIZE_O)); + + if constexpr (std::is_same::value) + { + // convert RO to half + uint32_t RO_f16[4]; +#pragma unroll + for (uint32_t k = 0; k < 4; k++) + { + if constexpr (std::is_same::value) + { + ((half2*)RO_f16)[k] = __float22half2_rn(((float2*)RO[fq][fv])[k]); + } + else + { + ((nv_bfloat162*)RO_f16)[k] = __float22bfloat162_rn(((float2*)RO[fq][fv])[k]); + } + } + + ((int32_t*)(smem_O.base + offset_O))[lane_id % 4] = RO_f16[0]; + ((int32_t*)(smem_O.base + offset_O + 8 * (O_SMEM_STRIDE / PACK_SIZE_O)))[lane_id % 4] = RO_f16[1]; + + offset_O = smem_O.get_permuted_offset(smem_O_row_base + fq * MMA_QK_M, fv * (MMA_SV_N / PACK_SIZE_O) + 1); + ((int32_t*)(smem_O.base + offset_O))[lane_id % 4] = RO_f16[2]; + ((int32_t*)(smem_O.base + offset_O + 8 * (O_SMEM_STRIDE / PACK_SIZE_O)))[lane_id % 4] = RO_f16[3]; + } + else if constexpr (std::is_same::value) + { + // ! need to convert to bf16 if necessary + // ((int32_t*)(smem_O.base + offset_O))[lane_id % 4] = ((int32_t*)RO[fq][fv])[0]; + // ((int32_t*)(smem_O.base + offset_O + 8 * (O_SMEM_STRIDE / PACK_SIZE_O)))[lane_id % 4] = ((int32_t*)RO[fq][fv])[1]; + + // offset_O = smem_O.get_permuted_offset(smem_O_row_base + fq * MMA_QK_M, fv * (MMA_SV_N / PACK_SIZE_O) + 1); + // ((int32_t*)(smem_O.base + offset_O))[lane_id % 4] = ((int32_t*)RO[fq][fv])[2]; + // ((int32_t*)(smem_O.base + offset_O + 8 * (O_SMEM_STRIDE / PACK_SIZE_O)))[lane_id % 4] = ((int32_t*)RO[fq][fv])[3]; + } + } + } + + // ! do we need to sync here? 
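+ // With the launch configurations used in this file CTA_K == WARP_K, so num_warps_k == 1 and each warp
+ // writes and then re-reads only its own WARP_Q rows of smem_O; __syncwarp() is therefore sufficient here.
+ // A full __syncthreads() would only be needed if warps shared output rows (e.g. num_warps_k > 1).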
+ __syncwarp(); + + // shared memory to global memory + DTypeOut *O_lane_ptr = O + batch_id * stride_bz_o + head_id * stride_h_o + (bx * CTA_Q + WARP_Q * get_warp_idx_q() + lane_id / global_to_shared_line_lanes_O) * stride_seq_o + lane_id % global_to_shared_line_lanes_O * PACK_SIZE_O; + uint32_t offset_O = smem_O.get_permuted_offset(get_warp_idx_q() * WARP_Q + lane_id / global_to_shared_line_lanes_O, lane_id % global_to_shared_line_lanes_O); + uint32_t O_load_idx_lane_base = bx * CTA_Q + CTA_Q / num_warps * warp_id + lane_id / global_to_shared_line_lanes_O; + +#pragma unroll + for (uint32_t i = 0; i < O_smem_iters_col; i++) + { +#pragma unroll + for (uint32_t j = 0; j < O_smem_iters_row; j++) + { + if (O_load_idx_lane_base < qo_len) + { + smem_O.store_128b(offset_O, O_lane_ptr); + } + O_lane_ptr += (global_to_shared_line_lanes_O * PACK_SIZE_O); + offset_O = smem_O.advance_offset_by_column(offset_O); + } + + offset_O = smem_O.advance_offset_by_row(offset_O - (O_smem_iters_row * global_to_shared_line_lanes_O)); + O_lane_ptr += ((global_to_shared_copy_lines_per_warp_O * stride_seq_o) - (O_smem_iters_row * global_to_shared_line_lanes_O * PACK_SIZE_O)); + O_load_idx_lane_base += global_to_shared_copy_lines_per_warp_O; + } + + if constexpr (return_lse) + { + // ! this only works for num_tiles_q = 2 + uint32_t lse_idx = bx * CTA_Q + lane_id / 4 + 8 * (lane_id % 4) + WARP_Q * get_warp_idx_q(); + float *lse_lane_ptr = Lse + batch_id * (qo_len * num_qo_heads) + head_id * qo_len + lse_idx; + uint32_t fq = (lane_id % 4) / 2; + uint32_t k = (lane_id % 4) % 2; + + if (lse_idx < qo_len) + { + lse_lane_ptr[0] = (math::ptx_log2(d[fq][k]) + m[fq][k] - S_FP8_OFFSET); + } + } +} // kernel impl end + +// impl -> see sageattn.h file +std::vector qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn_fwd( + paddle::Tensor& query, + paddle::Tensor& key, + paddle::Tensor& value, + paddle::Tensor& output, + paddle::Tensor& query_scale, + paddle::Tensor& key_scale, + paddle::Tensor& value_scale, + paddle::Tensor& value_mean, + int tensor_layout, + int is_causal, + int qk_quant_gran, + float sm_scale, + int return_lse) +{ + CHECK_CUDA(query); + CHECK_CUDA(key); + CHECK_CUDA(value); + CHECK_CUDA(output); + CHECK_CUDA(query_scale); + CHECK_CUDA(key_scale); + CHECK_CUDA(value_scale); + CHECK_CUDA(value_mean); + + CHECK_LASTDIM_CONTIGUOUS(query); + CHECK_LASTDIM_CONTIGUOUS(key); + CHECK_CONTIGUOUS(value); // ensure value is contiguous to prevent troubles in the kernel + CHECK_LASTDIM_CONTIGUOUS(output); + CHECK_CONTIGUOUS(query_scale); + CHECK_CONTIGUOUS(key_scale); + CHECK_CONTIGUOUS(value_scale); + CHECK_CONTIGUOUS(value_mean); + + CHECK_DTYPE(query, paddle::DataType::INT8); + CHECK_DTYPE(key, paddle::DataType::INT8); + // TODO: how to check fp8 data type? 
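+ // Paddle exposes an FP8 dtype, so this could be checked directly with
+ // CHECK_DTYPE(value, paddle::DataType::FLOAT8_E4M3FN), as the SM90 wrappers in
+ // sageattn_qk_int_sv_f8_kernel_sm90.cu (added in this patch) already do.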
+ // CHECK_DTYPE(value, torch::kHalf); + CHECK_DTYPE(query_scale, paddle::DataType::FLOAT32); + CHECK_DTYPE(key_scale, paddle::DataType::FLOAT32); + CHECK_DTYPE(value_scale, paddle::DataType::FLOAT32); + CHECK_DTYPE(value_mean, paddle::DataType::FLOAT32); + + CHECK_DIMS(query, 4); + CHECK_DIMS(key, 4); + CHECK_DIMS(value, 4); + CHECK_DIMS(output, 4); + CHECK_DIMS(query_scale, 3); + CHECK_DIMS(key_scale, 3); + CHECK_DIMS(value_scale, 3); + CHECK_DIMS(value_mean, 3); + + const int batch_size = query.shape()[0]; + const int head_dim = query.shape()[3]; + + int stride_bz_q = query.strides()[0]; + int stride_bz_k = key.strides()[0]; + int stride_bz_v = value.strides()[0]; + int stride_bz_o = output.strides()[0]; + + int qo_len, kv_len, num_qo_heads, num_kv_heads; + int stride_seq_q, stride_h_q, stride_seq_k, stride_h_k, stride_h_v, stride_d_v, stride_seq_o, stride_h_o; + + if (tensor_layout == 0) + { + qo_len = query.shape()[1]; + kv_len = key.shape()[1]; + num_qo_heads = query.shape()[2]; + num_kv_heads = key.shape()[2]; + + stride_seq_q = query.strides()[1]; + stride_h_q = query.strides()[2]; + stride_seq_k = key.strides()[1]; + stride_h_k = key.strides()[2]; + stride_h_v = value.strides()[2]; + stride_d_v = value.strides()[1]; + stride_seq_o = output.strides()[1]; + stride_h_o = output.strides()[2]; + + CHECK_SHAPE(key, batch_size, kv_len, num_kv_heads, head_dim); + CHECK_SHAPE(output, batch_size, qo_len, num_qo_heads, head_dim); + assert(value.shape()[1] == head_dim); + assert(value.shape()[2] == num_kv_heads); + } + else + { + qo_len = query.shape()[2]; + kv_len = key.shape()[2]; + num_qo_heads = query.shape()[1]; + num_kv_heads = key.shape()[1]; + + stride_seq_q = query.strides()[2]; + stride_h_q = query.strides()[1]; + stride_seq_k = key.strides()[2]; + stride_h_k = key.strides()[1]; + stride_h_v = value.strides()[1]; + stride_d_v = value.strides()[2]; + stride_seq_o = output.strides()[2]; + stride_h_o = output.strides()[1]; + + CHECK_SHAPE(key, batch_size, num_kv_heads, kv_len, head_dim); + CHECK_SHAPE(output, batch_size, num_qo_heads, qo_len, head_dim); + assert(value.shape()[2] == head_dim); + assert(value.shape()[1] == num_kv_heads); + } + + if (num_qo_heads % num_kv_heads != 0) { + std::ostringstream err_msg; + err_msg << "num_qo_heads (" << num_qo_heads << ") must be divisible by num_kv_heads (" << num_kv_heads << ")"; + throw std::invalid_argument(err_msg.str()); + } + + paddle::Tensor lse = paddle::empty({1}, paddle::DataType::FLOAT32); + if (return_lse) + { + lse = paddle::empty({batch_size, num_qo_heads, qo_len}, paddle::DataType::FLOAT32); + } + + const int num_kv_groups = num_qo_heads / num_kv_heads; + + auto output_dtype = output.dtype(); + + DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { + DISPATCH_CAUSAL(is_causal, IS_CAUSAL, { + DISPATCH_QK_QUANT_GRAN(qk_quant_gran, QK_QUANT_GRAN, { + DISPATCH_RETURN_LSE(return_lse, RETURN_LSE, { + DISPATCH_PADDLE_DTYPE_TO_CTYPE_FP16(output_dtype, DTypeOut, { + constexpr int CTA_Q = (HEAD_DIM == 256) ? 64 : 128; + constexpr int CTA_K = (HEAD_DIM == 256) ? 64 : 64; + constexpr int WARP_Q = (HEAD_DIM == 256) ? 16 : 32; + constexpr int WARP_K = (HEAD_DIM == 256) ? 64 : 64; + + assert(value.shape()[0] == batch_size); + assert(value.shape()[3] >= div_ceil(kv_len, CTA_K) * CTA_K); + + constexpr MaskMode mask_mode = IS_CAUSAL ? 
MaskMode::kCausal : MaskMode::kNone; + + if constexpr (QK_QUANT_GRAN == static_cast(QuantGranularity::kPerWarp)) + { + CHECK_SHAPE(query_scale, batch_size, num_qo_heads, static_cast(div_ceil(qo_len, CTA_Q) * (CTA_Q / WARP_Q))); + CHECK_SHAPE(key_scale, batch_size, num_kv_heads, static_cast(div_ceil(kv_len, CTA_K) * (CTA_K / WARP_K))); + } + else if constexpr (QK_QUANT_GRAN == static_cast(QuantGranularity::kPerThread)) + { + CHECK_SHAPE(query_scale, batch_size, num_qo_heads, static_cast(div_ceil(qo_len, CTA_Q) * (CTA_Q / WARP_Q) * 8)); + CHECK_SHAPE(key_scale, batch_size, num_kv_heads, static_cast(div_ceil(kv_len, CTA_K) * (CTA_K / WARP_K) * 4)); + } + else + { + static_assert(QK_QUANT_GRAN == static_cast(QuantGranularity::kPerWarp) || QK_QUANT_GRAN == static_cast(QuantGranularity::kPerThread), "Unsupported quantization granularity"); + } + + CHECK_SHAPE(value_scale, batch_size, num_kv_heads, head_dim); + CHECK_SHAPE(value_mean, batch_size, num_kv_heads, head_dim); + + // smem_Q smem_K smem_V smem_O + size_t smem_max = std::max(CTA_Q * HEAD_DIM * sizeof(int8_t) + CTA_K * HEAD_DIM * sizeof(int8_t) + CTA_K * HEAD_DIM * sizeof(int8_t), CTA_Q * HEAD_DIM * sizeof(half)); + + auto kernel_func = qk_int_sv_f8_attn_kernel(QK_QUANT_GRAN), static_cast(QK_QUANT_GRAN), + float, false, DTypeOut, ComputeUnit::kCudaCore, mask_mode, RETURN_LSE, true, true>; + + cudaFuncSetAttribute(kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_max); + + dim3 grid(div_ceil(qo_len, CTA_Q), num_qo_heads, batch_size); + dim3 block(32, (CTA_Q / WARP_Q) * (CTA_K / WARP_K)); + + kernel_func<<>>( + query.data(), + key.data(), + reinterpret_cast(value.data()), + reinterpret_cast(output.data()), + (RETURN_LSE) ? reinterpret_cast(lse.data()) : nullptr, + reinterpret_cast(query_scale.data()), + reinterpret_cast(key_scale.data()), + reinterpret_cast(value_scale.data()), + reinterpret_cast(value_mean.data()), + qo_len, + kv_len, + num_kv_groups, + stride_bz_q, stride_seq_q, stride_h_q, + stride_bz_k, stride_seq_k, stride_h_k, + stride_bz_v, stride_h_v, stride_d_v, + stride_bz_o, stride_seq_o, stride_h_o, + sm_scale); + }); + }); + }); + }); + }); + + return {lse}; +} + +std::vector qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_fwd( + paddle::Tensor& query, + paddle::Tensor& key, + paddle::Tensor& value, + paddle::Tensor& output, + paddle::Tensor& query_scale, + paddle::Tensor& key_scale, + paddle::Tensor& value_scale, + int tensor_layout, + int is_causal, + int qk_quant_gran, + float sm_scale, + int return_lse) +{ + CHECK_CUDA(query); + CHECK_CUDA(key); + CHECK_CUDA(value); + CHECK_CUDA(output); + CHECK_CUDA(query_scale); + CHECK_CUDA(key_scale); + CHECK_CUDA(value_scale); + + CHECK_LASTDIM_CONTIGUOUS(query); + CHECK_LASTDIM_CONTIGUOUS(key); + CHECK_CONTIGUOUS(value); // ensure value is contiguous to prevent troubles in the kernel + CHECK_LASTDIM_CONTIGUOUS(output); + CHECK_CONTIGUOUS(query_scale); + CHECK_CONTIGUOUS(key_scale); + CHECK_CONTIGUOUS(value_scale); + + CHECK_DTYPE(query, paddle::DataType::INT8); + CHECK_DTYPE(key, paddle::DataType::INT8); + // TODO: how to check fp8 data type? 
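+ // (Same TODO as above: CHECK_DTYPE(value, paddle::DataType::FLOAT8_E4M3FN) would resolve it.)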
+ // CHECK_DTYPE(value, torch::kHalf); + CHECK_DTYPE(query_scale, paddle::DataType::FLOAT32); + CHECK_DTYPE(key_scale, paddle::DataType::FLOAT32); + CHECK_DTYPE(value_scale, paddle::DataType::FLOAT32); + + CHECK_DIMS(query, 4); + CHECK_DIMS(key, 4); + CHECK_DIMS(value, 4); + CHECK_DIMS(output, 4); + CHECK_DIMS(query_scale, 3); + CHECK_DIMS(key_scale, 3); + CHECK_DIMS(value_scale, 3); + + const int batch_size = query.shape()[0]; + const int head_dim = query.shape()[3]; + + int stride_bz_q = query.strides()[0]; + int stride_bz_k = key.strides()[0]; + int stride_bz_v = value.strides()[0]; + int stride_bz_o = output.strides()[0]; + + int qo_len, kv_len, num_qo_heads, num_kv_heads; + int stride_seq_q, stride_h_q, stride_seq_k, stride_h_k, stride_h_v, stride_d_v, stride_seq_o, stride_h_o; + + if (tensor_layout == 0) + { + qo_len = query.shape()[1]; + kv_len = key.shape()[1]; + num_qo_heads = query.shape()[2]; + num_kv_heads = key.shape()[2]; + + stride_seq_q = query.strides()[1]; + stride_h_q = query.strides()[2]; + stride_seq_k = key.strides()[1]; + stride_h_k = key.strides()[2]; + stride_h_v = value.strides()[2]; + stride_d_v = value.strides()[1]; + stride_seq_o = output.strides()[1]; + stride_h_o = output.strides()[2]; + + CHECK_SHAPE(key, batch_size, kv_len, num_kv_heads, head_dim); + CHECK_SHAPE(output, batch_size, qo_len, num_qo_heads, head_dim); + assert(value.shape()[1] == head_dim); + assert(value.shape()[2] == num_kv_heads); + } + else + { + qo_len = query.shape()[2]; + kv_len = key.shape()[2]; + num_qo_heads = query.shape()[1]; + num_kv_heads = key.shape()[1]; + + stride_seq_q = query.strides()[2]; + stride_h_q = query.strides()[1]; + stride_seq_k = key.strides()[2]; + stride_h_k = key.strides()[1]; + stride_h_v = value.strides()[1]; + stride_d_v = value.strides()[2]; + stride_seq_o = output.strides()[2]; + stride_h_o = output.strides()[1]; + + CHECK_SHAPE(key, batch_size, num_kv_heads, kv_len, head_dim); + CHECK_SHAPE(output, batch_size, num_qo_heads, qo_len, head_dim); + assert(value.shape()[2] == head_dim); + assert(value.shape()[1] == num_kv_heads); + } + + if (num_qo_heads % num_kv_heads != 0) { + std::ostringstream err_msg; + err_msg << "num_qo_heads (" << num_qo_heads << ") must be divisible by num_kv_heads (" << num_kv_heads << ")"; + throw std::invalid_argument(err_msg.str()); + } + + paddle::Tensor lse = paddle::empty({1}, paddle::DataType::FLOAT32); + if (return_lse) + { + lse = paddle::empty({batch_size, num_qo_heads, qo_len}, paddle::DataType::FLOAT32); + } + + const int num_kv_groups = num_qo_heads / num_kv_heads; + + auto output_dtype = output.dtype(); + + DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { + DISPATCH_CAUSAL(is_causal, IS_CAUSAL, { + DISPATCH_QK_QUANT_GRAN(qk_quant_gran, QK_QUANT_GRAN, { + DISPATCH_RETURN_LSE(return_lse, RETURN_LSE, { + DISPATCH_PADDLE_DTYPE_TO_CTYPE_FP16(output_dtype, DTypeOut, { + + constexpr int CTA_Q = (HEAD_DIM == 256) ? 64 : 128; + constexpr int CTA_K = (HEAD_DIM == 256) ? 64 : 64; + constexpr int WARP_Q = (HEAD_DIM == 256) ? 16 : 32; + constexpr int WARP_K = (HEAD_DIM == 256) ? 64 : 64; + + assert(value.shape()[0] == batch_size); + assert(value.shape()[3] >= div_ceil(kv_len, CTA_K) * CTA_K); + + constexpr MaskMode mask_mode = IS_CAUSAL ? 
MaskMode::kCausal : MaskMode::kNone; + + if constexpr (QK_QUANT_GRAN == static_cast(QuantGranularity::kPerWarp)) + { + CHECK_SHAPE(query_scale, batch_size, num_qo_heads, static_cast(div_ceil(qo_len, CTA_Q) * (CTA_Q / WARP_Q))); + CHECK_SHAPE(key_scale, batch_size, num_kv_heads, static_cast(div_ceil(kv_len, CTA_K) * (CTA_K / WARP_K))); + } + else if constexpr (QK_QUANT_GRAN == static_cast(QuantGranularity::kPerThread)) + { + CHECK_SHAPE(query_scale, batch_size, num_qo_heads, static_cast(div_ceil(qo_len, CTA_Q) * (CTA_Q / WARP_Q) * 8)); + CHECK_SHAPE(key_scale, batch_size, num_kv_heads, static_cast(div_ceil(kv_len, CTA_K) * (CTA_K / WARP_K) * 4)); + } + else + { + static_assert(QK_QUANT_GRAN == static_cast(QuantGranularity::kPerWarp) || QK_QUANT_GRAN == static_cast(QuantGranularity::kPerThread), "Unsupported quantization granularity"); + } + + CHECK_SHAPE(value_scale, batch_size, num_kv_heads, head_dim); + + // smem_Q smem_K smem_V smem_O + size_t smem_max = std::max(CTA_Q * HEAD_DIM * sizeof(int8_t) + CTA_K * HEAD_DIM * sizeof(int8_t) + CTA_K * HEAD_DIM * sizeof(int8_t), CTA_Q * HEAD_DIM * sizeof(half)); + + auto kernel_func = qk_int_sv_f8_attn_kernel(QK_QUANT_GRAN), static_cast(QK_QUANT_GRAN), + float, false, DTypeOut, ComputeUnit::kCudaCore, mask_mode, RETURN_LSE, true, false>; + + cudaFuncSetAttribute(kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_max); + + dim3 grid(div_ceil(qo_len, CTA_Q), num_qo_heads, batch_size); + dim3 block(32, (CTA_Q / WARP_Q) * (CTA_K / WARP_K)); + + kernel_func<<>>( + query.data(), + key.data(), + reinterpret_cast(value.data()), + reinterpret_cast(output.data()), + (RETURN_LSE) ? reinterpret_cast(lse.data()) : nullptr, + reinterpret_cast(query_scale.data()), + reinterpret_cast(key_scale.data()), + reinterpret_cast(value_scale.data()), + nullptr, + qo_len, + kv_len, + num_kv_groups, + stride_bz_q, stride_seq_q, stride_h_q, + stride_bz_k, stride_seq_k, stride_h_k, + stride_bz_v, stride_h_v, stride_d_v, + stride_bz_o, stride_seq_o, stride_h_o, + sm_scale); + }); + }); + }); + }); + }); + + return {lse}; +} + +std::vector qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm89_fwd( + paddle::Tensor& query, + paddle::Tensor& key, + paddle::Tensor& value, + paddle::Tensor& output, + paddle::Tensor& query_scale, + paddle::Tensor& key_scale, + paddle::Tensor& value_scale, + int tensor_layout, + int is_causal, + int qk_quant_gran, + float sm_scale, + int return_lse) +{ + CHECK_CUDA(query); + CHECK_CUDA(key); + CHECK_CUDA(value); + CHECK_CUDA(output); + CHECK_CUDA(query_scale); + CHECK_CUDA(key_scale); + CHECK_CUDA(value_scale); + + CHECK_LASTDIM_CONTIGUOUS(query); + CHECK_LASTDIM_CONTIGUOUS(key); + CHECK_CONTIGUOUS(value); // ensure value is contiguous to prevent troubles in the kernel + CHECK_LASTDIM_CONTIGUOUS(output); + CHECK_CONTIGUOUS(query_scale); + CHECK_CONTIGUOUS(key_scale); + CHECK_CONTIGUOUS(value_scale); + + CHECK_DTYPE(query, paddle::DataType::INT8); + CHECK_DTYPE(key, paddle::DataType::INT8); + // TODO: how to check fp8 data type? 
+ // CHECK_DTYPE(value, torch::kHalf); + CHECK_DTYPE(query_scale, paddle::DataType::FLOAT32); + CHECK_DTYPE(key_scale, paddle::DataType::FLOAT32); + CHECK_DTYPE(value_scale, paddle::DataType::FLOAT32); + + CHECK_DIMS(query, 4); + CHECK_DIMS(key, 4); + CHECK_DIMS(value, 4); + CHECK_DIMS(output, 4); + CHECK_DIMS(query_scale, 3); + CHECK_DIMS(key_scale, 3); + CHECK_DIMS(value_scale, 3); + + const int batch_size = query.shape()[0]; + const int head_dim = query.shape()[3]; + + int stride_bz_q = query.strides()[0]; + int stride_bz_k = key.strides()[0]; + int stride_bz_v = value.strides()[0]; + int stride_bz_o = output.strides()[0]; + + int qo_len, kv_len, num_qo_heads, num_kv_heads; + int stride_seq_q, stride_h_q, stride_seq_k, stride_h_k, stride_h_v, stride_d_v, stride_seq_o, stride_h_o; + + if (tensor_layout == 0) + { + qo_len = query.shape()[1]; + kv_len = key.shape()[1]; + num_qo_heads = query.shape()[2]; + num_kv_heads = key.shape()[2]; + + stride_seq_q = query.strides()[1]; + stride_h_q = query.strides()[2]; + stride_seq_k = key.strides()[1]; + stride_h_k = key.strides()[2]; + stride_h_v = value.strides()[2]; + stride_d_v = value.strides()[1]; + stride_seq_o = output.strides()[1]; + stride_h_o = output.strides()[2]; + + CHECK_SHAPE(key, batch_size, kv_len, num_kv_heads, head_dim); + CHECK_SHAPE(output, batch_size, qo_len, num_qo_heads, head_dim); + assert(value.shape()[1] == head_dim); + assert(value.shape()[2] == num_kv_heads); + } + else + { + qo_len = query.shape()[2]; + kv_len = key.shape()[2]; + num_qo_heads = query.shape()[1]; + num_kv_heads = key.shape()[1]; + + stride_seq_q = query.strides()[2]; + stride_h_q = query.strides()[1]; + stride_seq_k = key.strides()[2]; + stride_h_k = key.strides()[1]; + stride_h_v = value.strides()[1]; + stride_d_v = value.strides()[2]; + stride_seq_o = output.strides()[2]; + stride_h_o = output.strides()[1]; + + CHECK_SHAPE(key, batch_size, num_kv_heads, kv_len, head_dim); + CHECK_SHAPE(output, batch_size, num_qo_heads, qo_len, head_dim); + assert(value.shape()[2] == head_dim); + assert(value.shape()[1] == num_kv_heads); + } + + if (num_qo_heads % num_kv_heads != 0) { + std::ostringstream err_msg; + err_msg << "num_qo_heads (" << num_qo_heads << ") must be divisible by num_kv_heads (" << num_kv_heads << ")"; + throw std::invalid_argument(err_msg.str()); + } + + paddle::Tensor lse = paddle::empty({1}); + if (return_lse) + { + lse = paddle::empty({batch_size, num_qo_heads, qo_len}, paddle::DataType::FLOAT32); + } + + const int num_kv_groups = num_qo_heads / num_kv_heads; + + auto output_dtype = output.dtype(); + + DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { + DISPATCH_CAUSAL(is_causal, IS_CAUSAL, { + DISPATCH_QK_QUANT_GRAN(qk_quant_gran, QK_QUANT_GRAN, { + DISPATCH_RETURN_LSE(return_lse, RETURN_LSE, { + DISPATCH_PADDLE_DTYPE_TO_CTYPE_FP16(output_dtype, DTypeOut, { + + constexpr int CTA_Q = (HEAD_DIM == 256) ? 64 : 128; + constexpr int CTA_K = (HEAD_DIM == 256) ? 64 : 64; + constexpr int WARP_Q = (HEAD_DIM == 256) ? 16 : 32; + constexpr int WARP_K = (HEAD_DIM == 256) ? 64 : 64; + + assert(value.shape()[0] == batch_size); + assert(value.shape()[3] >= div_ceil(kv_len, CTA_K) * CTA_K); + + constexpr MaskMode mask_mode = IS_CAUSAL ? 
MaskMode::kCausal : MaskMode::kNone; + + if constexpr (QK_QUANT_GRAN == static_cast(QuantGranularity::kPerWarp)) + { + CHECK_SHAPE(query_scale, batch_size, num_qo_heads, static_cast(div_ceil(qo_len, CTA_Q) * (CTA_Q / WARP_Q))); + CHECK_SHAPE(key_scale, batch_size, num_kv_heads, static_cast(div_ceil(kv_len, CTA_K) * (CTA_K / WARP_K))); + } + else if constexpr (QK_QUANT_GRAN == static_cast(QuantGranularity::kPerThread)) + { + CHECK_SHAPE(query_scale, batch_size, num_qo_heads, static_cast(div_ceil(qo_len, CTA_Q) * (CTA_Q / WARP_Q) * 8)); + CHECK_SHAPE(key_scale, batch_size, num_kv_heads, static_cast(div_ceil(kv_len, CTA_K) * (CTA_K / WARP_K) * 4)); + } + else + { + static_assert(QK_QUANT_GRAN == static_cast(QuantGranularity::kPerWarp) || QK_QUANT_GRAN == static_cast(QuantGranularity::kPerThread), "Unsupported quantization granularity"); + } + + CHECK_SHAPE(value_scale, batch_size, num_kv_heads, head_dim); + + // smem_Q smem_K smem_V smem_O + size_t smem_max = std::max(CTA_Q * HEAD_DIM * sizeof(int8_t) + CTA_K * HEAD_DIM * sizeof(int8_t) + CTA_K * HEAD_DIM * sizeof(int8_t), CTA_Q * HEAD_DIM * sizeof(half)); + + auto kernel_func = qk_int_sv_f8_attn_kernel(QK_QUANT_GRAN), static_cast(QK_QUANT_GRAN), + float, true, DTypeOut, ComputeUnit::kCudaCore, mask_mode, RETURN_LSE, true, false>; + + cudaFuncSetAttribute(kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_max); + + dim3 grid(div_ceil(qo_len, CTA_Q), num_qo_heads, batch_size); + dim3 block(32, (CTA_Q / WARP_Q) * (CTA_K / WARP_K)); + + kernel_func<<>>( + query.data(), + key.data(), + reinterpret_cast(value.data()), + reinterpret_cast(output.data()), + (RETURN_LSE) ? reinterpret_cast(lse.data()) : nullptr, + reinterpret_cast(query_scale.data()), + reinterpret_cast(key_scale.data()), + reinterpret_cast(value_scale.data()), + nullptr, + qo_len, + kv_len, + num_kv_groups, + stride_bz_q, stride_seq_q, stride_h_q, + stride_bz_k, stride_seq_k, stride_h_k, + stride_bz_v, stride_h_v, stride_d_v, + stride_bz_o, stride_seq_o, stride_h_o, + sm_scale); + }); + }); + }); + }); + }); + + return {lse}; +} diff --git a/csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_kernel_sm90.cu b/csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_kernel_sm90.cu new file mode 100644 index 000000000000..84083d543afd --- /dev/null +++ b/csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_kernel_sm90.cu @@ -0,0 +1,878 @@ +#include + +#include "paddle/extension.h" + +#include "sageattn_utils.cuh" + +template +CUtensorMap create_tensor_map_4D(T* gmem_ptr, int d1, int d2, int d3, int d4, int stride1, int stride2, int stride3) { + constexpr int smem_stride = BlockMinorSize * sizeof(T); + static_assert(sizeof(T) == 2 || sizeof(T) == 1); + static_assert(smem_stride == 32 || smem_stride == 64 || smem_stride == 128); + + CUtensorMap tma_map; + void* gmem_address = (void*)gmem_ptr; + uint64_t gmem_prob_shape[5] = {(uint64_t)d4, (uint64_t)d3, (uint64_t)d2, (uint64_t)d1, 1}; + uint64_t gmem_prob_stride[5] = {(uint64_t)stride3 * sizeof(T), (uint64_t)stride2 * sizeof(T), (uint64_t)stride1 * sizeof(T), 0, 0}; + uint32_t smem_box_shape[5] = {uint32_t(BlockMinorSize), uint32_t(BlockMajorSize), 1, 1, 1}; + uint32_t smem_box_stride[5] = {1, 1, 1, 1, 1}; + + CUresult result = cuTensorMapEncodeTiled( + &tma_map, (sizeof(T) == 2) ? CU_TENSOR_MAP_DATA_TYPE_BFLOAT16 : CU_TENSOR_MAP_DATA_TYPE_UINT8, 4, gmem_address, gmem_prob_shape, + gmem_prob_stride, smem_box_shape, smem_box_stride, CU_TENSOR_MAP_INTERLEAVE_NONE, + (swizzle == false) ? CU_TENSOR_MAP_SWIZZLE_NONE : (smem_stride == 128) ? 
CU_TENSOR_MAP_SWIZZLE_128B : (smem_stride == 64) ? CU_TENSOR_MAP_SWIZZLE_64B : CU_TENSOR_MAP_SWIZZLE_32B, + promotion_mode, CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE); + + assert(result == CUDA_SUCCESS); + + return tma_map; +} + +__device__ __forceinline__ void init_barrier(uint64_t* bar, int thread_count) { + uint32_t bar_ptr = static_cast(__cvta_generic_to_shared(bar)); + asm volatile ( + "mbarrier.init.shared::cta.b64 [%0], %1;\n" + :: "r"(bar_ptr), "r"(thread_count) + ); +} + +template +__device__ __forceinline__ void expect_bytes(uint64_t* bar) { + uint32_t bar_ptr = static_cast(__cvta_generic_to_shared(bar)); + asm volatile ("mbarrier.arrive.expect_tx.shared::cta.b64 _, [%0], %1;\n" + :: "r"(bar_ptr), "n"(bytes)); +} + +template +__device__ __forceinline__ void load_async_4D(T *dst, void const* const src_tma_map, uint64_t* bar, int s0, int s1, int s2, int s3) { + uint64_t tma_ptr = reinterpret_cast(src_tma_map); + uint32_t mbar_ptr = static_cast(__cvta_generic_to_shared(bar)); + uint32_t dst_ptr = static_cast(__cvta_generic_to_shared(dst)); + + asm volatile ( + "cp.async.bulk.tensor.4d.shared::cluster.global.tile.mbarrier::complete_tx::bytes" + " [%0], [%1, {%3, %4, %5, %6}], [%2];" + : + : "r"(dst_ptr), "l"(tma_ptr), "r"(mbar_ptr), + "r"(s0), "r"(s1), "r"(s2), "r"(s3) + : "memory" + ); +} + +template +__device__ __forceinline__ void store_async_4D(void const* dst_tma_map, T *src, int global_token_idx, int global_head_idx, int global_batch_idx) { + uint64_t tma_ptr = reinterpret_cast(dst_tma_map); + uint32_t src_ptr = static_cast(__cvta_generic_to_shared(src)); + + asm volatile ( + "cp.async.bulk.tensor.4d.global.shared::cta.tile.bulk_group" + " [%0, {%2, %3, %4, %5}], [%1];" + : + : "l"(tma_ptr), "r"(src_ptr), + "n"(0), "r"(global_token_idx), "r"(global_head_idx), "r"(global_batch_idx) + : "memory" + ); +} + +__device__ __forceinline__ void wait(uint64_t* bar, int kPhaseBit) { + uint32_t mbar_ptr = static_cast(__cvta_generic_to_shared(bar)); + asm volatile ( + "{\n" + ".reg .pred P1;\n" + "LAB_WAIT:\n" + "mbarrier.try_wait.parity.shared::cta.b64 P1, [%0], %1;\n" + "@P1 bra.uni DONE;\n" + "bra.uni LAB_WAIT;\n" + "DONE:\n" + "}\n" + :: "r"(mbar_ptr), + "r"(kPhaseBit) + ); +} + +template +__device__ __forceinline__ void arrive(uint64_t* bar) { + uint32_t mbar_ptr = static_cast(__cvta_generic_to_shared(bar)); + asm volatile ( + "mbarrier.arrive.release.cta.shared::cta.b64 _, [%0], %1;\n" + : + : "r"(mbar_ptr), "n"(count) + : "memory" + ); +} + +// +// ======= kernel impl ======= +// + +template +__global__ void qk_int8_sv_f8_attn_kernel(const __grid_constant__ CUtensorMap tensorMapQ, + const __grid_constant__ CUtensorMap tensorMapK, + const __grid_constant__ CUtensorMap tensorMapV, + float *__restrict__ Q_scale, float *__restrict__ K_scale, float *__restrict__ V_scale, + DTypeOut* O, uint32_t stride_bz_o, uint32_t stride_h_o, uint32_t stride_seq_o, + const uint32_t qo_len, const uint32_t kv_len, const uint32_t num_kv_groups, + float sm_scale) +{ + static_assert(NUM_THREADS == 128); + static_assert(CTA_Q <= CTA_K); + + const uint32_t warp_idx = (threadIdx.x % 128) / 32; + const uint32_t lane_id = threadIdx.x % 32; + + constexpr uint32_t num_tiles_q = CTA_Q / 64; + constexpr uint32_t num_tiles_k = CTA_K / 16; + constexpr uint32_t num_tiles_qk_inner = head_dim / 32; + constexpr uint32_t num_tiles_v = head_dim / 16; + constexpr uint32_t num_tiles_pv_inner = CTA_K / 32; + + const uint32_t batch_id = blockIdx.z; + const uint32_t bx = blockIdx.x; + const uint32_t head_id = blockIdx.y; + const uint32_t 
num_qo_heads = gridDim.y; + const uint32_t kv_head_id = head_id / num_kv_groups; + + sm_scale *= math::log2e; + + extern __shared__ __align__(128) int8_t smem_[]; + + int8_t *sQ = (int8_t*)smem_; + int8_t *sK = (int8_t*)(smem_ + CTA_Q * head_dim * sizeof(int8_t)); + int8_t *sV = (int8_t*)(smem_ + CTA_Q * head_dim * sizeof(int8_t) + CTA_K * head_dim * sizeof(int8_t)); + half *sO = (half*)smem_; + + int32_t RS[num_tiles_q][num_tiles_k][8]; + float RO[num_tiles_q][num_tiles_v][8]; + float m[num_tiles_q][2]; + float d[num_tiles_q][2]; + + uint32_t q_scale_idx, k_scale_idx; + + if constexpr (Q_GRAN == QuantGranularity::kPerBlock) + { + const uint32_t num_block_q = gridDim.x; + q_scale_idx = batch_id * num_qo_heads * num_block_q + head_id * num_block_q + bx; + } + else if constexpr (Q_GRAN == QuantGranularity::kPerWarp) + { + const uint32_t num_warp_block_q = gridDim.x * 4; + q_scale_idx = batch_id * num_qo_heads * num_warp_block_q + head_id * num_warp_block_q + bx * 4 + warp_idx; + } + else if constexpr (Q_GRAN == QuantGranularity::kPerThread) + { + const uint32_t num_warp_block_q = gridDim.x * 4; + q_scale_idx = batch_id * num_qo_heads * (num_warp_block_q * 8) + head_id * (num_warp_block_q * 8) + bx * (4 * 8) + warp_idx * 8 + lane_id / 4; + } + + if constexpr (K_GRAN == QuantGranularity::kPerBlock || K_GRAN == QuantGranularity::kPerWarp) + { + const uint32_t num_block_k = div_ceil(kv_len, CTA_K); + k_scale_idx = batch_id * (num_qo_heads / num_kv_groups) * num_block_k + (head_id / num_kv_groups) * num_block_k; + } + else if constexpr (K_GRAN == QuantGranularity::kPerThread) + { + const uint32_t num_block_k = div_ceil(kv_len, CTA_K); + k_scale_idx = batch_id * (num_qo_heads / num_kv_groups) * (num_block_k * 4) + (head_id / num_kv_groups) * (num_block_k * 4) + lane_id % 4; + } + + constexpr uint32_t k_scale_advance_offset = (K_GRAN == QuantGranularity::kPerBlock || K_GRAN == QuantGranularity::kPerWarp) ? 1 : 4; + + uint32_t Q_idx_lane_base = bx * CTA_Q + warp_idx * 16 + lane_id / 4; + +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { + m[fq][0] = -5000000.0f; + m[fq][1] = -5000000.0f; + d[fq][0] = 1.0f; + d[fq][1] = 1.0f; + } + +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fv = 0; fv < num_tiles_v; fv++) + { +#pragma unroll + for (uint32_t k = 0; k < 8; k++) + { + RO[fq][fv][k] = 0.0f; + } + } + } + + __shared__ __align__(8) uint64_t barrier_Q; + __shared__ __align__(8) uint64_t barrier_K; + __shared__ __align__(8) uint64_t barrier_V; + + if (threadIdx.x == 0) { + init_barrier(&barrier_Q, 1); + init_barrier(&barrier_K, 1); + init_barrier(&barrier_V, 1); + } + + __syncthreads(); + + // load Q, K, V + if (threadIdx.x == 0) + { + expect_bytes<(CTA_Q * head_dim) * sizeof(int8_t)>(&barrier_Q); + expect_bytes<(CTA_K * head_dim) * sizeof(int8_t)>(&barrier_K); + expect_bytes<(CTA_K * head_dim) * sizeof(int8_t)>(&barrier_V); + load_async_4D(sQ, &tensorMapQ, &barrier_Q, 0, bx * CTA_Q, head_id, batch_id); + load_async_4D(sK, &tensorMapK, &barrier_K, 0, 0, kv_head_id, batch_id); + load_async_4D(sV, &tensorMapV, &barrier_V, 0, 0, kv_head_id, batch_id); + } + + float q_scale = Q_scale[q_scale_idx]; + float original_sm_scale = sm_scale; + + // wait for Q + wait(&barrier_Q, 0); + + const uint32_t num_iterations = div_ceil( + mask_mode == MaskMode::kCausal + ? 
min(kv_len, (bx + 1) * CTA_Q) + : kv_len, + CTA_K); + + int p = 1; + for (uint32_t iter = 1; iter < num_iterations; iter++) + { + p ^= 1; + + float dequant_scale = q_scale * K_scale[k_scale_idx + (iter - 1) * k_scale_advance_offset]; + sm_scale = original_sm_scale * dequant_scale; + + // wait for K + wait(&barrier_K, p); + + // compute QK^T + wgmma::warpgroup_arrive(); +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { + int8_t *sQ_local = sQ + fq * 64 * head_dim; + wgmma::wgmma_s8s8s32(RS[fq], sQ_local, sK); +#pragma unroll + for (int k_it = 1; k_it < num_tiles_qk_inner; k_it++) + { + wgmma::wgmma_s8s8s32(RS[fq], &sQ_local[k_it*32], &sK[k_it*32]); + } + } + wgmma::warpgroup_commit_batch(); + wgmma::warpgroup_wait<0>(); + + // load K + if (threadIdx.x == 0) + { + expect_bytes<(CTA_K * head_dim) * sizeof(int8_t)>(&barrier_K); + load_async_4D(sK, &tensorMapK, &barrier_K, 0, iter * CTA_K, kv_head_id, batch_id); + } + + // convert RS to float + float RS_f32[num_tiles_q][num_tiles_k][8]; +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fk = 0; fk < num_tiles_k; fk++) + { +#pragma unroll + for (uint32_t k = 0; k < 8; k++) + { + RS_f32[fq][fk][k] = __int2float_rz(RS[fq][fk][k]); + } + } + } + + update_mdo(RS_f32, RO, m, d, sm_scale); + + // accumulate d on thread basis +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unrol + for (uint32_t fk = 0; fk < num_tiles_k; fk++) + { + d[fq][0] += (RS_f32[fq][fk][0] + RS_f32[fq][fk][1] + RS_f32[fq][fk][4] + RS_f32[fq][fk][5]); + d[fq][1] += (RS_f32[fq][fk][2] + RS_f32[fq][fk][3] + RS_f32[fq][fk][6] + RS_f32[fq][fk][7]); + } + } + + uint32_t RS_f8[num_tiles_q][num_tiles_pv_inner][4]; + RS_32_to_8(RS_f32, RS_f8); + + // wait for V + wait(&barrier_V, p); + + float RO_temp[num_tiles_q][num_tiles_v][8]; + wgmma::warpgroup_arrive(); +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { + wgmma::wgmma_f8f8f32(RO_temp[fq], RS_f8[fq][0], &sV[0]); +#pragma unroll + for (uint32_t v_it = 1; v_it < num_tiles_pv_inner; v_it++) + { + wgmma::wgmma_f8f8f32(RO_temp[fq], RS_f8[fq][v_it], &sV[v_it * 32]); + } + } + + wgmma::warpgroup_commit_batch(); + wgmma::warpgroup_wait<0>(); + +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fv = 0; fv < num_tiles_v; fv++) + { +#pragma unroll + for (uint32_t k = 0; k < 8; k++) + { + RO[fq][fv][k] += RO_temp[fq][fv][k]; + } + } + } + + // load V + if (threadIdx.x == 0) + { + expect_bytes<(CTA_K * head_dim) * sizeof(int8_t)>(&barrier_V); + load_async_4D(sV, &tensorMapV, &barrier_V, iter * CTA_K, 0, kv_head_id, batch_id); + } + } + + { + p ^= 1; + + float dequant_scale = q_scale * K_scale[k_scale_idx + (num_iterations - 1) * k_scale_advance_offset]; + sm_scale = original_sm_scale; + + // wait for K + wait(&barrier_K, p); + + // compute QK^T + wgmma::warpgroup_arrive(); +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { + int8_t *sQ_local = sQ + fq * 64 * head_dim; + wgmma::wgmma_s8s8s32(RS[fq], sQ_local, sK); +#pragma unroll + for (int k_it = 1; k_it < num_tiles_qk_inner; k_it++) + { + wgmma::wgmma_s8s8s32(RS[fq], &sQ_local[k_it*32], &sK[k_it*32]); + } + } + wgmma::warpgroup_commit_batch(); + wgmma::warpgroup_wait<0>(); + + // convert RS to float + float RS_f32[num_tiles_q][num_tiles_k][8]; +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fk = 0; fk < num_tiles_k; fk++) + { +#pragma unroll + for (uint32_t k = 0; k < 
8; k++) + { + RS_f32[fq][fk][k] = __int2float_rz(RS[fq][fk][k]) * dequant_scale; + } + } + } + + // masking +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fk = 0; fk < num_tiles_k; fk++) + { +#pragma unroll + for (uint32_t k = 0; k < 8; k++) + { + const uint32_t q_idx = Q_idx_lane_base + fq * 64 + 8 * ((k % 4) / 2); + const uint32_t k_idx = (num_iterations - 1) * CTA_K + fk * 16 + 2 * (lane_id % 4) + 8 * (k / 4) + k % 2; + + bool is_out_of_bounds; + + if constexpr (mask_mode == MaskMode::kCausal) + { + is_out_of_bounds = (k_idx > q_idx) || (k_idx >= kv_len); + } + else + { + is_out_of_bounds = (k_idx >= kv_len); + } + + if (is_out_of_bounds) + { + RS_f32[fq][fk][k] = -5000000.0f; + } + } + } + } + + update_mdo(RS_f32, RO, m, d, sm_scale); + + // accumulate d on thread basis +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unrol + for (uint32_t fk = 0; fk < num_tiles_k; fk++) + { + d[fq][0] += (RS_f32[fq][fk][0] + RS_f32[fq][fk][1] + RS_f32[fq][fk][4] + RS_f32[fq][fk][5]); + d[fq][1] += (RS_f32[fq][fk][2] + RS_f32[fq][fk][3] + RS_f32[fq][fk][6] + RS_f32[fq][fk][7]); + } + } + + uint32_t RS_f8[num_tiles_q][num_tiles_pv_inner][4]; + RS_32_to_8(RS_f32, RS_f8); + + // wait for V + wait(&barrier_V, p); + + float RO_temp[num_tiles_q][num_tiles_v][8]; + wgmma::warpgroup_arrive(); +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { + wgmma::wgmma_f8f8f32(RO_temp[fq], RS_f8[fq][0], &sV[0]); +#pragma unroll + for (uint32_t v_it = 1; v_it < num_tiles_pv_inner; v_it++) + { + wgmma::wgmma_f8f8f32(RO_temp[fq], RS_f8[fq][v_it], &sV[v_it * 32]); + } + } + + wgmma::warpgroup_commit_batch(); + wgmma::warpgroup_wait<0>(); + +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fv = 0; fv < num_tiles_v; fv++) + { +#pragma unroll + for (uint32_t k = 0; k < 8; k++) + { + RO[fq][fv][k] += RO_temp[fq][fv][k]; + } + } + } + } + + normalize_d(RO, m, d); + + if constexpr (fuse_v_scale) + { + float v_scale[4]; + float *V_scale_base_ptr = V_scale + batch_id * (num_qo_heads / num_kv_groups) * head_dim + (head_id / num_kv_groups) * head_dim + (lane_id % 4 ) * 2; + #pragma unroll + for (uint32_t fv = 0; fv < num_tiles_v; fv++) + { + ((float2*)v_scale)[0] = *((float2*)(V_scale_base_ptr + fv * 16)); + ((float2*)v_scale)[1] = *((float2*)(V_scale_base_ptr + fv * 16 + 8)); + + #pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { + RO[fq][fv][0] *= v_scale[0]; + RO[fq][fv][1] *= v_scale[1]; + RO[fq][fv][2] *= v_scale[0]; + RO[fq][fv][3] *= v_scale[1]; + RO[fq][fv][4] *= v_scale[2]; + RO[fq][fv][5] *= v_scale[3]; + RO[fq][fv][6] *= v_scale[2]; + RO[fq][fv][7] *= v_scale[3]; + } + } + } + + DTypeOut *O_lane_ptr = O + batch_id * stride_bz_o + head_id * stride_h_o + (bx * CTA_Q + warp_idx * 16 + (lane_id / 4)) * stride_seq_o + (lane_id % 4) * 2 ; +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fv = 0; fv < head_dim/16; fv++) + { + if (Q_idx_lane_base + fq * 64 < qo_len) + { + if constexpr (std::is_same::value) + { + ((half2*)(O_lane_ptr + fq * 64 * stride_seq_o + fv * 16))[0] = __float22half2_rn(((float2*)(RO[fq][fv]))[0]); + ((half2*)(O_lane_ptr + fq * 64 * stride_seq_o + fv * 16 + 8))[0] = __float22half2_rn(((float2*)(RO[fq][fv]))[2]); + } + else + { + ((nv_bfloat162*)(O_lane_ptr + fq * 64 * stride_seq_o + fv * 16))[0] = __float22bfloat162_rn(((float2*)(RO[fq][fv]))[0]); + ((nv_bfloat162*)(O_lane_ptr + fq * 64 * stride_seq_o 
+ fv * 16 + 8))[0] = __float22bfloat162_rn(((float2*)(RO[fq][fv]))[2]); + } + } + + if (Q_idx_lane_base + fq * 64 + 8 < qo_len) + { + if constexpr (std::is_same::value) + { + ((half2*)(O_lane_ptr + fq * 64 * stride_seq_o + fv * 16 + 8 * stride_seq_o))[0] = __float22half2_rn(((float2*)(RO[fq][fv]))[1]); + ((half2*)(O_lane_ptr + fq * 64 * stride_seq_o + fv * 16 + 8 + 8 * stride_seq_o))[0] = __float22half2_rn(((float2*)(RO[fq][fv]))[3]); + } + else + { + ((nv_bfloat162*)(O_lane_ptr + fq * 64 * stride_seq_o + fv * 16 + 8 * stride_seq_o))[0] = __float22bfloat162_rn(((float2*)(RO[fq][fv]))[1]); + ((nv_bfloat162*)(O_lane_ptr + fq * 64 * stride_seq_o + fv * 16 + 8 + 8 * stride_seq_o))[0] = __float22bfloat162_rn(((float2*)(RO[fq][fv]))[3]); + } + } + } + } +} + +std::vector qk_int8_sv_f8_accum_f32_attn_inst_buf_sm90_fwd( + paddle::Tensor& query, + paddle::Tensor& key, + paddle::Tensor& value, + paddle::Tensor& output, + paddle::Tensor& query_scale, + paddle::Tensor& key_scale, + int tensor_layout, + int is_causal, + int qk_quant_gran, + float sm_scale, + int return_lse) +{ + CHECK_CUDA(query); + CHECK_CUDA(key); + CHECK_CUDA(value); + CHECK_CUDA(output); + CHECK_CUDA(query_scale); + CHECK_CUDA(key_scale); + + CHECK_LASTDIM_CONTIGUOUS(query); + CHECK_LASTDIM_CONTIGUOUS(key); + CHECK_LASTDIM_CONTIGUOUS(value); + CHECK_LASTDIM_CONTIGUOUS(output); + CHECK_CONTIGUOUS(query_scale); + CHECK_CONTIGUOUS(key_scale); + + CHECK_DTYPE(query, paddle::DataType::INT8); + CHECK_DTYPE(key, paddle::DataType::INT8); + CHECK_DTYPE(value, paddle::DataType::FLOAT8_E4M3FN); + CHECK_DTYPE(query_scale, paddle::DataType::FLOAT32); + CHECK_DTYPE(key_scale, paddle::DataType::FLOAT32); + + CHECK_DIMS(query, 4); + CHECK_DIMS(key, 4); + CHECK_DIMS(value, 4); + CHECK_DIMS(output, 4); + CHECK_DIMS(query_scale, 3); + CHECK_DIMS(key_scale, 3); + + const int batch_size = query.shape()[0]; + const int head_dim = query.shape()[3]; + + int stride_bz_q = query.strides()[0]; + int stride_bz_k = key.strides()[0]; + int stride_bz_v = value.strides()[0]; + int stride_bz_o = output.strides()[0]; + + int qo_len, kv_len, padded_kv_len, num_qo_heads, num_kv_heads; + int stride_seq_q, stride_h_q, stride_seq_k, stride_h_k, stride_h_v, stride_d_v, stride_seq_o, stride_h_o; + + assert(value.shape()[0] == batch_size); + + if (tensor_layout == 0) + { + qo_len = query.shape()[1]; + kv_len = key.shape()[1]; + num_qo_heads = query.shape()[2]; + num_kv_heads = key.shape()[2]; + + stride_seq_q = query.strides()[1]; + stride_h_q = query.strides()[2]; + stride_seq_k = key.strides()[1]; + stride_h_k = key.strides()[2]; + stride_h_v = value.strides()[2]; + stride_d_v = value.strides()[1]; + stride_seq_o = output.strides()[1]; + stride_h_o = output.strides()[2]; + + CHECK_SHAPE(key, batch_size, kv_len, num_kv_heads, head_dim); + CHECK_SHAPE(output, batch_size, qo_len, num_qo_heads, head_dim); + assert(value.shape()[1] == head_dim); + assert(value.shape()[2] == num_kv_heads); + } + else + { + qo_len = query.shape()[2]; + kv_len = key.shape()[2]; + num_qo_heads = query.shape()[1]; + num_kv_heads = key.shape()[1]; + + stride_seq_q = query.strides()[2]; + stride_h_q = query.strides()[1]; + stride_seq_k = key.strides()[2]; + stride_h_k = key.strides()[1]; + stride_h_v = value.strides()[1]; + stride_d_v = value.strides()[2]; + stride_seq_o = output.strides()[2]; + stride_h_o = output.strides()[1]; + + CHECK_SHAPE(key, batch_size, num_kv_heads, kv_len, head_dim); + CHECK_SHAPE(output, batch_size, num_qo_heads, qo_len, head_dim); + assert(value.shape()[2] == head_dim); 
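+ // i.e. for this layout the quantized V is expected transposed as [batch, num_kv_heads, head_dim, padded_kv_len];
+ // the padded kv length is checked against div_ceil(kv_len, CTA_K) * CTA_K in the dispatch below.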
+ assert(value.shape()[1] == num_kv_heads); + } + + if (num_qo_heads % num_kv_heads != 0) { + std::ostringstream err_msg; + err_msg << "num_qo_heads (" << num_qo_heads << ") must be divisible by num_kv_heads (" << num_kv_heads << ")"; + throw std::invalid_argument(err_msg.str()); + } + + paddle::Tensor lse = paddle::empty({0}, paddle::DataType::FLOAT32); + if (return_lse) + { + lse = paddle::empty({batch_size, num_qo_heads, qo_len}, paddle::DataType::FLOAT32); + } + + const int num_kv_groups = num_qo_heads / num_kv_heads; + + auto output_type = output.dtype(); + + DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { + DISPATCH_CAUSAL(is_causal, IS_CAUSAL, { + DISPATCH_QK_QUANT_GRAN(qk_quant_gran, QK_QUANT_GRAN, { + DISPATCH_PADDLE_DTYPE_TO_CTYPE_FP16(output_type, DTypeOut, { + constexpr int CTA_Q = 64; + constexpr int CTA_K = 128; + constexpr int NUM_THREADS = 128; + + constexpr MaskMode mask_mode = IS_CAUSAL ? MaskMode::kCausal : MaskMode::kNone; + + assert(value.shape()[3] >= div_ceil(kv_len, CTA_K) * CTA_K); + + if constexpr (QK_QUANT_GRAN == static_cast(QuantGranularity::kPerWarp)) + { + CHECK_SHAPE(query_scale, batch_size, num_qo_heads, static_cast(div_ceil(qo_len, CTA_Q) * (NUM_THREADS / 32))); + CHECK_SHAPE(key_scale, batch_size, num_kv_heads, static_cast(div_ceil(kv_len, CTA_K))); + } + else if constexpr (QK_QUANT_GRAN == static_cast(QuantGranularity::kPerThread)) + { + CHECK_SHAPE(query_scale, batch_size, num_qo_heads, static_cast(div_ceil(qo_len, CTA_Q) * (NUM_THREADS / 32) * 8)); + CHECK_SHAPE(key_scale, batch_size, num_kv_heads, static_cast(div_ceil(kv_len, CTA_K) * 4)); + } + else + { + static_assert(QK_QUANT_GRAN == static_cast(QuantGranularity::kPerWarp) || QK_QUANT_GRAN == static_cast(QuantGranularity::kPerThread), "Unsupported quantization granularity"); + } + + CUtensorMap tma_map_Q = create_tensor_map_4D(reinterpret_cast(query.data()), batch_size, num_qo_heads, qo_len, HEAD_DIM, stride_bz_q, stride_h_q, stride_seq_q); + CUtensorMap tma_map_K = create_tensor_map_4D(reinterpret_cast(key.data()), batch_size, num_kv_heads, kv_len, HEAD_DIM, stride_bz_k, stride_h_k, stride_seq_k); + CUtensorMap tma_map_V = create_tensor_map_4D(reinterpret_cast(value.data()), batch_size, num_kv_heads, HEAD_DIM, value.shape()[3], stride_bz_v, stride_h_v, stride_d_v); + + auto* kernel = qk_int8_sv_f8_attn_kernel(QK_QUANT_GRAN), static_cast(QK_QUANT_GRAN), DTypeOut, mask_mode, false>; + size_t sMemSize = CTA_Q * HEAD_DIM * sizeof(int8_t) + CTA_K * HEAD_DIM * sizeof(int8_t) + CTA_K * HEAD_DIM * sizeof(int8_t); + cudaFuncSetAttribute( + kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, sMemSize); + + dim3 grid(div_ceil(qo_len, CTA_Q), num_qo_heads, batch_size); + kernel<<>>( + tma_map_Q, + tma_map_K, + tma_map_V, + reinterpret_cast(query_scale.data()), + reinterpret_cast(key_scale.data()), + nullptr, + reinterpret_cast(output.data()), + stride_bz_o, stride_h_o, stride_seq_o, + qo_len, kv_len, num_kv_groups, sm_scale); + }); + }); + }); + }); + + return {lse}; +} + +std::vector qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90_fwd( + paddle::Tensor& query, + paddle::Tensor& key, + paddle::Tensor& value, + paddle::Tensor& output, + paddle::Tensor& query_scale, + paddle::Tensor& key_scale, + paddle::Tensor& value_scale, + int tensor_layout, + int is_causal, + int qk_quant_gran, + float sm_scale, + int return_lse) +{ + CHECK_CUDA(query); + CHECK_CUDA(key); + CHECK_CUDA(value); + CHECK_CUDA(output); + CHECK_CUDA(query_scale); + CHECK_CUDA(key_scale); + CHECK_CUDA(value_scale); + + 
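+ // Unlike the SM89 wrappers, V only needs its last dimension contiguous here: Q, K and V are read through
+ // TMA descriptors built from their explicit strides (create_tensor_map_4D), which is presumably why
+ // CHECK_LASTDIM_CONTIGUOUS(value) suffices instead of CHECK_CONTIGUOUS(value).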
CHECK_LASTDIM_CONTIGUOUS(query); + CHECK_LASTDIM_CONTIGUOUS(key); + CHECK_LASTDIM_CONTIGUOUS(value); + CHECK_LASTDIM_CONTIGUOUS(output); + CHECK_CONTIGUOUS(query_scale); + CHECK_CONTIGUOUS(key_scale); + CHECK_CONTIGUOUS(value_scale); + + CHECK_DTYPE(query, paddle::DataType::INT8); + CHECK_DTYPE(key, paddle::DataType::INT8); + CHECK_DTYPE(value, paddle::DataType::FLOAT8_E4M3FN); + CHECK_DTYPE(query_scale, paddle::DataType::FLOAT32); + CHECK_DTYPE(key_scale, paddle::DataType::FLOAT32); + CHECK_DTYPE(value_scale, paddle::DataType::FLOAT32); + + CHECK_DIMS(query, 4); + CHECK_DIMS(key, 4); + CHECK_DIMS(value, 4); + CHECK_DIMS(output, 4); + CHECK_DIMS(query_scale, 3); + CHECK_DIMS(key_scale, 3); + CHECK_DIMS(value_scale, 3); + + const int batch_size = query.shape()[0]; + const int head_dim = query.shape()[3]; + + int stride_bz_q = query.strides()[0]; + int stride_bz_k = key.strides()[0]; + int stride_bz_v = value.strides()[0]; + int stride_bz_o = output.strides()[0]; + + int qo_len, kv_len, padded_kv_len, num_qo_heads, num_kv_heads; + int stride_seq_q, stride_h_q, stride_seq_k, stride_h_k, stride_h_v, stride_d_v, stride_seq_o, stride_h_o; + + assert(value.shape()[0] == batch_size); + + if (tensor_layout == 0) + { + qo_len = query.shape()[1]; + kv_len = key.shape()[1]; + num_qo_heads = query.shape()[2]; + num_kv_heads = key.shape()[2]; + + stride_seq_q = query.strides()[1]; + stride_h_q = query.strides()[2]; + stride_seq_k = key.strides()[1]; + stride_h_k = key.strides()[2]; + stride_h_v = value.strides()[2]; + stride_d_v = value.strides()[1]; + stride_seq_o = output.strides()[1]; + stride_h_o = output.strides()[2]; + + CHECK_SHAPE(key, batch_size, kv_len, num_kv_heads, head_dim); + CHECK_SHAPE(output, batch_size, qo_len, num_qo_heads, head_dim); + assert(value.shape()[1] == head_dim); + assert(value.shape()[2] == num_kv_heads); + } + else + { + qo_len = query.shape()[2]; + kv_len = key.shape()[2]; + num_qo_heads = query.shape()[1]; + num_kv_heads = key.shape()[1]; + + stride_seq_q = query.strides()[2]; + stride_h_q = query.strides()[1]; + stride_seq_k = key.strides()[2]; + stride_h_k = key.strides()[1]; + stride_h_v = value.strides()[1]; + stride_d_v = value.strides()[2]; + stride_seq_o = output.strides()[2]; + stride_h_o = output.strides()[1]; + + CHECK_SHAPE(key, batch_size, num_kv_heads, kv_len, head_dim); + CHECK_SHAPE(output, batch_size, num_qo_heads, qo_len, head_dim); + assert(value.shape()[2] == head_dim); + assert(value.shape()[1] == num_kv_heads); + } + + if (num_qo_heads % num_kv_heads != 0) { + std::ostringstream err_msg; + err_msg << "num_qo_heads (" << num_qo_heads << ") must be divisible by num_kv_heads (" << num_kv_heads << ")"; + throw std::invalid_argument(err_msg.str()); + } + + paddle::Tensor lse = paddle::empty({1}, paddle::DataType::FLOAT32); + if (return_lse) + { + lse = paddle::empty({batch_size, num_qo_heads, qo_len}, paddle::DataType::FLOAT32); + } + + const int num_kv_groups = num_qo_heads / num_kv_heads; + + auto output_dtype = output.dtype(); + + DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { + DISPATCH_CAUSAL(is_causal, IS_CAUSAL, { + DISPATCH_QK_QUANT_GRAN(qk_quant_gran, QK_QUANT_GRAN, { + DISPATCH_PADDLE_DTYPE_TO_CTYPE_FP16(output_dtype, DTypeOut, { + constexpr int CTA_Q = 64; + constexpr int CTA_K = 128; + constexpr int NUM_THREADS = 128; + + constexpr MaskMode mask_mode = IS_CAUSAL ? 
MaskMode::kCausal : MaskMode::kNone; + + assert(value.shape()[3] >= div_ceil(kv_len, CTA_K) * CTA_K); + + if constexpr (QK_QUANT_GRAN == static_cast(QuantGranularity::kPerWarp)) + { + CHECK_SHAPE(query_scale, batch_size, num_qo_heads, static_cast(div_ceil(qo_len, CTA_Q) * (NUM_THREADS / 32))); + CHECK_SHAPE(key_scale, batch_size, num_kv_heads, static_cast(div_ceil(kv_len, CTA_K))); + } + else if constexpr (QK_QUANT_GRAN == static_cast(QuantGranularity::kPerThread)) + { + CHECK_SHAPE(query_scale, batch_size, num_qo_heads, static_cast(div_ceil(qo_len, CTA_Q) * (NUM_THREADS / 32) * 8)); + CHECK_SHAPE(key_scale, batch_size, num_kv_heads, static_cast(div_ceil(kv_len, CTA_K) * 4)); + } + else + { + static_assert(QK_QUANT_GRAN == static_cast(QuantGranularity::kPerWarp) || QK_QUANT_GRAN == static_cast(QuantGranularity::kPerThread), "Unsupported quantization granularity"); + } + + CHECK_SHAPE(value_scale, batch_size, num_kv_heads, head_dim); + + CUtensorMap tma_map_Q = create_tensor_map_4D(reinterpret_cast(query.data()), batch_size, num_qo_heads, qo_len, HEAD_DIM, stride_bz_q, stride_h_q, stride_seq_q); + CUtensorMap tma_map_K = create_tensor_map_4D(reinterpret_cast(key.data()), batch_size, num_kv_heads, kv_len, HEAD_DIM, stride_bz_k, stride_h_k, stride_seq_k); + CUtensorMap tma_map_V = create_tensor_map_4D(reinterpret_cast(value.data()), batch_size, num_kv_heads, HEAD_DIM, value.shape()[3], stride_bz_v, stride_h_v, stride_d_v); + + auto* kernel = qk_int8_sv_f8_attn_kernel(QK_QUANT_GRAN), static_cast(QK_QUANT_GRAN), DTypeOut, mask_mode, true>; + size_t sMemSize = CTA_Q * HEAD_DIM * sizeof(int8_t) + CTA_K * HEAD_DIM * sizeof(int8_t) + CTA_K * HEAD_DIM * sizeof(int8_t); + cudaFuncSetAttribute( + kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, sMemSize); + + dim3 grid(div_ceil(qo_len, CTA_Q), num_qo_heads, batch_size); + kernel<<>>( + tma_map_Q, + tma_map_K, + tma_map_V, + reinterpret_cast(query_scale.data()), + reinterpret_cast(key_scale.data()), + reinterpret_cast(value_scale.data()), + reinterpret_cast(output.data()), + stride_bz_o, stride_h_o, stride_seq_o, + qo_len, kv_len, num_kv_groups, sm_scale); + }); + }); + }); + }); + + return {lse}; +} \ No newline at end of file diff --git a/csrc/gpu/sage_attn_kernels/sageattn_utils.cuh b/csrc/gpu/sage_attn_kernels/sageattn_utils.cuh new file mode 100644 index 000000000000..49e0fb96893a --- /dev/null +++ b/csrc/gpu/sage_attn_kernels/sageattn_utils.cuh @@ -0,0 +1,2671 @@ +#pragma once +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "paddle/extension.h" + +// currently we do not support INT4, so we implement INT8 sage attention inference temperarily + +#define FINAL_MASK 0xffffffff +#define WARP_SIZE 32 + +#define S_FP8_OFFSET 8.807f +#define S_FP8_OFFSET_EXP 6680.8477f +#define S_FP8_OFFSET_EXP_INV 0.0022326917f + +#if (__CUDACC_VER_MAJOR__ * 10000 + __CUDACC_VER_MINOR__ * 100 >= 120400) +#if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 890)) +#define FP8_CAST_ENABLED +#endif +#endif + +#if defined(__CUDA_ARCH__) +#define RUNTIME_ASSERT(x) __brkpt() +#else +#include +#define RUNTIME_ASSERT(x) assert(0 && x) +#endif + +// dispatch_utils.h +#define DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, ...) 
\ + if (head_dim == 64) { \ + constexpr int HEAD_DIM = 64; \ + __VA_ARGS__ \ + } else if (head_dim == 128) { \ + constexpr int HEAD_DIM = 128; \ + __VA_ARGS__ \ + } else { \ + std::ostringstream err_msg; \ + err_msg << "Unsupported head dim: " << int(head_dim); \ + throw std::invalid_argument(err_msg.str()); \ + } + +#define DISPATCH_CAUSAL(is_causal, IS_CAUSAL, ...) \ + if (is_causal == 1) { \ + constexpr bool IS_CAUSAL = true; \ + __VA_ARGS__ \ + } else if (is_causal == 0) { \ + constexpr bool IS_CAUSAL = false; \ + __VA_ARGS__ \ + } else { \ + std::ostringstream err_msg; \ + err_msg << "Unsupported causal mode: " << int(is_causal); \ + throw std::invalid_argument(err_msg.str()); \ + } + +#define DISPATCH_QK_QUANT_GRAN(qk_quant_gran, QK_QUANT_GRAN, ...) \ + if (qk_quant_gran == 2) { \ + constexpr int QK_QUANT_GRAN = 2; \ + __VA_ARGS__ \ + } else if (qk_quant_gran == 3) { \ + constexpr int QK_QUANT_GRAN = 3; \ + __VA_ARGS__ \ + } else { \ + std::ostringstream err_msg; \ + err_msg << "Unsupported qk_quant_gran: " << int(qk_quant_gran); \ + throw std::invalid_argument(err_msg.str()); \ + } + +#define DISPATCH_RETURN_LSE(return_lse, RETURN_LSE, ...) \ + if (return_lse == 1) { \ + constexpr bool RETURN_LSE = true; \ + __VA_ARGS__ \ + } else if (return_lse == 0) { \ + constexpr bool RETURN_LSE = false; \ + __VA_ARGS__ \ + } else { \ + std::ostringstream err_msg; \ + err_msg << "Unsupported causal mode: " << int(return_lse); \ + throw std::invalid_argument(err_msg.str()); \ + } + +// DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16 +// here we will use paddle's DataType +#define DISPATCH_PADDLE_DTYPE_TO_CTYPE_FP16(paddle_dtype, c_type, ...) \ + if (paddle_dtype == paddle::DataType::FLOAT16) { \ + using c_type = half; \ + __VA_ARGS__ \ + } else if (paddle_dtype == paddle::DataType::BFLOAT16) { \ + using c_type = nv_bfloat16; \ + __VA_ARGS__ \ + } else { \ + std::ostringstream oss; \ + oss << __PRETTY_FUNCTION__ << " failed to dispatch data type " << paddle_dtype; \ + PD_CHECK(false, oss.str()); \ + } + +#define DISPATCH_BLOCK_SIZE(block_size, BLOCK_SIZE, ...) \ + if (block_size == 64) { \ + constexpr int BLOCK_SIZE = 64; \ + __VA_ARGS__ \ + } else if (block_size == 128) { \ + constexpr int BLOCK_SIZE = 128; \ + __VA_ARGS__ \ + } else { \ + std::ostringstream err_msg; \ + err_msg << "Unsupported block_size " << int(block_size); \ + throw std::invalid_argument(err_msg.str()); \ + } + +#define DISPATCH_WARP_BLOCK_SIZE(warp_block_size, WARP_BLOCK_SIZE, ...) 
\ + if (warp_block_size == 16) { \ + constexpr int WARP_BLOCK_SIZE = 16; \ + __VA_ARGS__ \ + } else if (warp_block_size == 32) { \ + constexpr int WARP_BLOCK_SIZE = 32; \ + __VA_ARGS__ \ + } else { \ + std::ostringstream err_msg; \ + err_msg << "Unsupported warp_block_size " << int(warp_block_size); \ + throw std::invalid_argument(err_msg.str()); \ + } + +// define the macro for necessary checks, originally in `utils.cuh` +#define CHECK_CUDA(x) \ + PD_CHECK(x.is_gpu(), "Tensor " #x " must be on CUDA") // shift to paddle API: is_gpu() + +// CHECK_DTYPE aims at testing the tensor datatype, use paddle::DataType +#define CHECK_DTYPE(x, true_dtype) \ + PD_CHECK(x.dtype() == true_dtype, \ + "Tensor " #x " must have dtype (" #true_dtype ")") // DataType dtype() const; +#define CHECK_DIMS(x, true_dim) \ + PD_CHECK(x.dims().size() == true_dim, \ + "Tensor " #x " must have dimension number (" #true_dim ")") // paddle API: .dims().size() +#define CHECK_NUMEL(x, minimum) \ + PD_CHECK(x.numel() >= minimum, \ + "Tensor " #x " must have at last " #minimum " elements") +#define CHECK_SHAPE(x, ...) \ + PD_CHECK(x.dims() == common::DDim({__VA_ARGS__}), \ + "Tensor " #x " must have shape (" #__VA_ARGS__ ")") +#define CHECK_CONTIGUOUS(x) \ + PD_CHECK(x.is_contiguous(), "Tensor " #x " must be contiguous") // TODO: check if valid +#define CHECK_LASTDIM_CONTIGUOUS(x) \ + PD_CHECK(x.strides().at(x.strides().size() - 1) == 1, \ + "Tensor " #x " must be contiguous at the last dimension") + + +namespace sageattn { + +template +__inline__ __device__ T warpReduceSum(T val) { +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) + val += __shfl_xor_sync(0xffffffff, val, mask, 32); + return val; +} + +template +__inline__ __device__ T warpReduceSumV2(T* val) +{ +#pragma unroll + for (int i = 0; i < NUM; i++) + { +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) + val[i] += __shfl_xor_sync(FINAL_MASK, val[i], mask, 32); + } + return (T) (0.0f); +} + +/* Calculate the sum of all elements in a block */ +template +__inline__ __device__ T blockReduceSum(T val) { + static __shared__ T shared[32]; + int lane = threadIdx.x & 0x1f; + int wid = threadIdx.x >> 5; + + val = warpReduceSum(val); + + if (lane == 0) + shared[wid] = val; + + __syncthreads(); + + // Modify from blockDim.x << 5 to blockDim.x / 32. to prevent + // blockDim.x is not divided by 32 + val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : (T)(0.0f); + val = warpReduceSum(val); + return val; +} + +/* Calculate the sum of all elements in a block */ +template +__inline__ __device__ T blockAllReduceSum(T val) { + static __shared__ T shared[32]; + int lane = threadIdx.x & 0x1f; + int wid = threadIdx.x >> 5; + + val = warpReduceSum(val); + + if (lane == 0) + shared[wid] = val; + + __syncthreads(); + + // Modify from blockDim.x << 5 to blockDim.x / 32. to prevent + // blockDim.x is not divided by 32 + val = (lane < (blockDim.x / 32.f)) ? shared[lane] : (T)(0.0f); + val = warpReduceSum(val); + return val; +} + +template +__inline__ __device__ T blockReduceSumV2(T* val) +{ + static __shared__ T shared[NUM][33]; + int lane = threadIdx.x & 0x1f; + int wid = threadIdx.x >> 5; + + warpReduceSumV2(val); + + if (lane == 0) + { +#pragma unroll + for (int i = 0; i < NUM; i++) + { + shared[i][wid] = val[i]; + } + } + + __syncthreads(); + + bool is_mask = threadIdx.x < (blockDim.x / 32.f); +#pragma unroll + for (int i = 0; i < NUM; i++) + { + val[i] = is_mask ? 
shared[i][lane] : (T) (0.0f); + } + warpReduceSumV2(val); + return (T) 0.0f; +} + +template +__inline__ __device__ T warpReduceMax(T val) +{ +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) + val = max(val, __shfl_xor_sync(0xffffffff, val, mask, 32)); + return val; +} +/* Calculate the maximum of all elements in a block */ +template +__inline__ __device__ T blockReduceMax(T val) +{ + static __shared__ T shared[32]; + int lane = threadIdx.x & 0x1f; // in-warp idx + int wid = threadIdx.x >> 5; // warp idx + val = warpReduceMax(val); // get maxx in each warp + if (lane == 0) // record in-warp maxx by warp Idx + shared[wid] = val; + __syncthreads(); + // Modify from blockDim.x << 5 to blockDim.x / 32. to prevent + // blockDim.x is not divided by 32 + val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : -1e20f; + val = warpReduceMax(val); + return val; +} + +/* Calculate the maximum of all elements in a block */ +template +__inline__ __device__ T blockAllReduceMax(T val) +{ + static __shared__ T shared[32]; + int lane = threadIdx.x & 0x1f; // in-warp idx + int wid = threadIdx.x >> 5; // warp idx + + val = warpReduceMax(val); // get maxx in each warp + + if (lane == 0) // record in-warp maxx by warp Idx + shared[wid] = val; + + __syncthreads(); + + // Modify from blockDim.x << 5 to blockDim.x / 32. to prevent + // blockDim.x is not divided by 32 + val = (lane < (blockDim.x / 32.f)) ? shared[lane] : -1e20f; + val = warpReduceMax(val); + + return val; +} + +template +__inline__ __device__ T warpReduceMin(T val) +{ +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) + val = min(val, __shfl_xor_sync(0xffffffff, val, mask, 32)); + return val; +} +/* Calculate the minimum of all elements in a block */ +template +__inline__ __device__ T blockReduceMin(T val) +{ + static __shared__ T shared[32]; + int lane = threadIdx.x & 0x1f; // in-warp idx + int wid = threadIdx.x >> 5; // warp idx + val = warpReduceMin(val); // get minx in each warp + if (lane == 0) // record in-warp minx by warp Idx + shared[wid] = val; + __syncthreads(); + // Modify from blockDim.x << 5 to blockDim.x / 32. to prevent + // blockDim.x is not divided by 32 + val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : 1e20f; + val = warpReduceMin(val); + return val; +} + +} // namespace sageattn + +namespace mma{ + +#if (__CUDACC_VER_MAJOR__ >= 11) +#if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800)) +#define MMA_F16F16F32_M16N8K16_ENABLED +#define MMA_F16F16F16_M16N8K16_ENABLED +#define MMA_S8S8S32_M16N8K32_ENABLED +#define MMA_S4S4S32_M16N8K64_ENABLED +#endif +#if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 750)) +#define MMA_F16F16F32_M16N8K8_ENABLED +#define MMA_F16F16F16_M16N8K8_ENABLED +#define LDMATRIX_M8N8X2_ENABLED +#define LDMATRIX_M8N8X4_ENABLED +#endif +#endif + +#if (__CUDACC_VER_MAJOR__ * 10000 + __CUDACC_VER_MINOR__ * 100 >= 120400) +#if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 890)) +#define MMA_F8F8F32_M16N8K16_ENABLED +#endif +#endif + +#if defined(__CUDA_ARCH__) +#define RUNTIME_ASSERT(x) __brkpt() +#else +#include +#define RUNTIME_ASSERT(x) assert(0 && x) +#endif + +enum class MMAMode { + kInit = 0U, + kInplaceUpdate = 1U, +}; + +/*! 
+ * \brief Wrapper of PTX ldmatrix m8n8.x2 instruction, loads data from shared memory + * to fragment + * \tparam T data type of the fragment + * \param R pointer to the fragment + * \param smem_ptr pointer to the shared memory + */ +template +__device__ __forceinline__ void ldmatrix_m8n8x2(uint32_t* R, T* smem_ptr) { +#ifdef LDMATRIX_M8N8X2_ENABLED + uint32_t smem_int_ptr = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0, %1}, [%2];\n" + : "=r"(R[0]), "=r"(R[1]) + : "r"(smem_int_ptr)); +#else + RUNTIME_ASSERT("Unsupported CUDA architecture for ldmatrix instruction"); +#endif +} + +/*! + * \brief Wrapper of PTX ldmatrix m8n8.x4 instruction, loads data from shared memory + * to fragment + * \tparam T data type of the fragment + * \param R pointer to the fragment + * \param smem_ptr pointer to the shared memory + */ +template +__device__ __forceinline__ void ldmatrix_m8n8x4(uint32_t* R, T* smem_ptr) { +#ifdef LDMATRIX_M8N8X4_ENABLED + uint32_t smem_int_ptr = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];\n" + : "=r"(R[0]), "=r"(R[1]), "=r"(R[2]), "=r"(R[3]) + : "r"(smem_int_ptr)); +#else + RUNTIME_ASSERT("Unsupported CUDA architecture for ldmatrix instruction"); +#endif +} + +/*! + * \brief Wrapper of PTX ldmatrix m8n8.x4 transposed instruction, loads data from + * shared memory to fragment and transposes the fragment + * \tparam T data type of the fragment + * \param R pointer to the fragment + * \param smem_ptr pointer to the shared memory + */ +template +__device__ __forceinline__ void ldmatrix_m8n8x4_trans(uint32_t* R, T* smem_ptr) { +#ifdef LDMATRIX_M8N8X4_ENABLED + uint32_t smem_int_ptr = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile("ldmatrix.sync.aligned.trans.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];\n" + : "=r"(R[0]), "=r"(R[1]), "=r"(R[2]), "=r"(R[3]) + : "r"(smem_int_ptr)); +#else + RUNTIME_ASSERT("Unsupported CUDA architecture for ldmatrix instruction"); +#endif +} + +/*! + * \brief Wrapper of the mma m16n8k16 instruction for row major and column major f16 matrix + * multiplication, accumulated in f32. + * \tparam mma_mode The mode of mma instruction, either kInit or kInplaceUpdate + * \param C pointer to the accumulator + * \param A pointer to the fragment of matrix A + * \param B pointer to the fragment of matrix B + */ +template +__device__ __forceinline__ void mma_sync_m16n8k16_row_col_f16f16f32(float* C, uint32_t* A, + uint32_t* B) { +#ifdef MMA_F16F16F32_M16N8K16_ENABLED + // ! only support half dtype now + if constexpr (mma_mode == MMAMode::kInplaceUpdate) + { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "f"(C[0]), "f"(C[1]), + "f"(C[2]), "f"(C[3])); + } + else if constexpr (mma_mode == MMAMode::kInit) + { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "f"(0.f), "f"(0.f), + "f"(0.f), "f"(0.f)); + } +#else + RUNTIME_ASSERT("Unsupported CUDA architecture for mma instruction"); +#endif +} + +/*! 
+ * \brief Wrapper of the mma m16n16k16 instruction for row major and column major f16 matrix + * multiplication, accumulated in f32. + * \tparam mma_mode The mode of mma instruction, either kInit or kInplaceUpdate + * \param C pointer to the accumulator + * \param A pointer to the fragment of matrix A + * \param B pointer to the fragment of matrix B + */ +template +__device__ __forceinline__ void mma_sync_m16n16k16_row_col_f16f16f32(float* C, uint32_t* A, + uint32_t* B) { +#ifdef MMA_F16F16F32_M16N8K16_ENABLED + // ! only support half dtype now + if constexpr (mma_mode == MMAMode::kInplaceUpdate) + { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "f"(C[0]), "f"(C[1]), + "f"(C[2]), "f"(C[3])); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[2]), "r"(B[3]), "f"(C[4]), "f"(C[5]), + "f"(C[6]), "f"(C[7])); + } + else if constexpr (mma_mode == MMAMode::kInit) + { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "f"(0.f), "f"(0.f), + "f"(0.f), "f"(0.f)); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[2]), "r"(B[3]), "f"(0.f), "f"(0.f), + "f"(0.f), "f"(0.f)); + } +#else + RUNTIME_ASSERT("Unsupported CUDA architecture for mma instruction"); +#endif +} + +/*! + * \brief Wrapper of the mma m16n8k16 instruction for row major and column major f16 matrix + * multiplication, accumulated in f16. + * \tparam mma_mode The mode of mma instruction, either kInit or kInplaceUpdate + * \param C pointer to the accumulator + * \param A pointer to the fragment of matrix A + * \param B pointer to the fragment of matrix B + */ +template +__device__ __forceinline__ void mma_sync_m16n8k16_row_col_f16f16f16(uint32_t* C, uint32_t* A, + uint32_t* B) { +#ifdef MMA_F16F16F16_M16N8K16_ENABLED + if constexpr (mma_mode == MMAMode::kInplaceUpdate) + { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(C[0]), "=r"(C[1]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1])); + } + else if constexpr (mma_mode == MMAMode::kInit) + { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(C[0]), "=r"(C[1]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(0), "r"(0)); + } +#else + RUNTIME_ASSERT("Unsupported CUDA architecture for mma instruction"); +#endif +} + +/*! + * \brief Wrapper of the mma m16n16k16 instruction for row major and column major f16 matrix + * multiplication, accumulated in f16. 
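+ *
+ * Implemented as two m16n8k16 MMAs that share the A fragment: B[0..1]/C[0..1]
+ * produce the left 16x8 tile and B[2..3]/C[2..3] the right one, so C packs
+ * eight f16 accumulator values into four 32-bit registers per thread.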
+ * \tparam mma_mode The mode of mma instruction, either kInit or kInplaceUpdate + * \param C pointer to the accumulator + * \param A pointer to the fragment of matrix A + * \param B pointer to the fragment of matrix B + */ +template +__device__ __forceinline__ void mma_sync_m16n16k16_row_col_f16f16f16(uint32_t* C, uint32_t* A, + uint32_t* B) { +#ifdef MMA_F16F16F16_M16N8K16_ENABLED + if constexpr (mma_mode == MMAMode::kInplaceUpdate) + { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(C[0]), "=r"(C[1]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1])); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(C[2]), "=r"(C[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[2]), "r"(B[3]), "r"(C[2]), "r"(C[3])); + } + else if constexpr (mma_mode == MMAMode::kInit) + { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(C[0]), "=r"(C[1]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(0), "r"(0)); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(C[2]), "=r"(C[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[2]), "r"(B[3]), "r"(0), "r"(0)); + } +#else + RUNTIME_ASSERT("Unsupported CUDA architecture for mma instruction"); +#endif +} + +/*! + * \brief Wrapper of the mma m16n8k32 instruction for row major and column major int8 matrix + * multiplication, accumulated in int32. + * \tparam mma_mode The mode of mma instruction, either kInit or kInplaceUpdate + * \param C pointer to the accumulator + * \param A pointer to the fragment of matrix A + * \param B pointer to the fragment of matrix B + */ +template +__device__ __forceinline__ void mma_sync_m16n8k32_row_col_s8s8s32(int32_t* C, uint32_t* A, + uint32_t* B) { +#ifdef MMA_S8S8S32_M16N8K32_ENABLED + if constexpr (mma_mode == MMAMode::kInplaceUpdate) + { + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(C[0]), "=r"(C[1]), "=r"(C[2]), "=r"(C[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]), + "r"(C[2]), "r"(C[3])); + } + else if constexpr (mma_mode == MMAMode::kInit) + { + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(C[0]), "=r"(C[1]), "=r"(C[2]), "=r"(C[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(0), "r"(0), + "r"(0), "r"(0)); + } +#else + RUNTIME_ASSERT("Unsupported CUDA architecture for mma instruction"); +#endif +} + +/*! + * \brief Wrapper of the mma m16n16k32 instruction for row major and column major int8 matrix + * multiplication, accumulated in int32. 
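+ *
+ * Composed of two m16n8k32 INT8 MMAs over a shared A fragment; the accumulator
+ * C holds eight int32 values per thread (C[0..3] for the left 16x8 tile,
+ * C[4..7] for the right one).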
+ * \tparam mma_mode The mode of mma instruction, either kInit or kInplaceUpdate + * \param C pointer to the accumulator + * \param A pointer to the fragment of matrix A + * \param B pointer to the fragment of matrix B + */ +template +__device__ __forceinline__ void mma_sync_m16n16k32_row_col_s8s8s32(int32_t* C, uint32_t* A, + uint32_t* B) { +#ifdef MMA_S8S8S32_M16N8K32_ENABLED + if constexpr (mma_mode == MMAMode::kInplaceUpdate) + { + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(C[0]), "=r"(C[1]), "=r"(C[2]), "=r"(C[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]), + "r"(C[2]), "r"(C[3])); + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(C[4]), "=r"(C[5]), "=r"(C[6]), "=r"(C[7]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[2]), "r"(B[3]), "r"(C[4]), "r"(C[5]), + "r"(C[6]), "r"(C[7])); + } + else if constexpr (mma_mode == MMAMode::kInit) + { + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(C[0]), "=r"(C[1]), "=r"(C[2]), "=r"(C[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(0), "r"(0), + "r"(0), "r"(0)); + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(C[4]), "=r"(C[5]), "=r"(C[6]), "=r"(C[7]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[2]), "r"(B[3]), "r"(0), "r"(0), + "r"(0), "r"(0)); + } +#else + RUNTIME_ASSERT("Unsupported CUDA architecture for mma instruction"); +#endif +} + +/*! + * \brief Wrapper of the mma m16n8k32 instruction for row major and column major int4 matrix + * multiplication, accumulated in int32. + * \tparam mma_mode The mode of mma instruction, either kInit or kInplaceUpdate + * \param C pointer to the accumulator + * \param A pointer to the fragment of matrix A + * \param B pointer to the fragment of matrix B + */ +template +__device__ __forceinline__ void mma_sync_m16n8k64_row_col_s4s4s32(int32_t* C, uint32_t* A, + uint32_t* B) { +#ifdef MMA_S4S4S32_M16N8K64_ENABLED + if constexpr (mma_mode == MMAMode::kInplaceUpdate) + { + asm volatile( + "mma.sync.aligned.m16n8k64.row.col.s32.s4.s4.s32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(C[0]), "=r"(C[1]), "=r"(C[2]), "=r"(C[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]), + "r"(C[2]), "r"(C[3])); + } + else if constexpr (mma_mode == MMAMode::kInit) + { + asm volatile( + "mma.sync.aligned.m16n8k64.row.col.s32.s4.s4.s32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(C[0]), "=r"(C[1]), "=r"(C[2]), "=r"(C[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(0), "r"(0), + "r"(0), "r"(0)); + } +#else + RUNTIME_ASSERT("Unsupported CUDA architecture for mma instruction"); +#endif +} + +/*! + * \brief Wrapper of the mma m16n16k64 instruction for row major and column major int4 matrix + * multiplication, accumulated in int32. 
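+ *
+ * Same two-instruction pattern as the INT8 variant, but over m16n8k64 INT4
+ * tiles. Per the comment at the top of this header, INT4 is not currently
+ * exercised by the SageAttention kernels; the wrapper appears to be kept for
+ * completeness.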
+ * \tparam mma_mode The mode of mma instruction, either kInit or kInplaceUpdate + * \param C pointer to the accumulator + * \param A pointer to the fragment of matrix A + * \param B pointer to the fragment of matrix B + */ +template +__device__ __forceinline__ void mma_sync_m16n16k64_row_col_s4s4s32(int32_t* C, uint32_t* A, + uint32_t* B) { +#ifdef MMA_S4S4S32_M16N8K64_ENABLED + if constexpr (mma_mode == MMAMode::kInplaceUpdate) + { + asm volatile( + "mma.sync.aligned.m16n8k64.row.col.s32.s4.s4.s32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(C[0]), "=r"(C[1]), "=r"(C[2]), "=r"(C[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]), + "r"(C[2]), "r"(C[3])); + asm volatile( + "mma.sync.aligned.m16n8k64.row.col.s32.s4.s4.s32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(C[4]), "=r"(C[5]), "=r"(C[6]), "=r"(C[7]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[2]), "r"(B[3]), "r"(C[4]), "r"(C[5]), + "r"(C[6]), "r"(C[7])); + } + else if constexpr (mma_mode == MMAMode::kInit) + { + asm volatile( + "mma.sync.aligned.m16n8k64.row.col.s32.s4.s4.s32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(C[0]), "=r"(C[1]), "=r"(C[2]), "=r"(C[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(0), "r"(0), + "r"(0), "r"(0)); + asm volatile( + "mma.sync.aligned.m16n8k64.row.col.s32.s4.s4.s32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(C[4]), "=r"(C[5]), "=r"(C[6]), "=r"(C[7]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[2]), "r"(B[3]), "r"(0), "r"(0), + "r"(0), "r"(0)); + } +#else + RUNTIME_ASSERT("Unsupported CUDA architecture for mma instruction"); +#endif +} + +/*! + * \brief Wrapper of the mma m16n8k32 instruction for row major and column major fp8 e4m3 matrix + * multiplication, accumulated in fp32. + * \tparam mma_mode The mode of mma instruction, either kInit or kInplaceUpdate + * \param C pointer to the accumulator + * \param A pointer to the fragment of matrix A + * \param B pointer to the fragment of matrix B + */ +template +__device__ __forceinline__ void mma_sync_m16n8k32_row_col_f8f8f32(float* C, uint32_t* A, + uint32_t* B) { +#ifdef MMA_F8F8F32_M16N8K16_ENABLED + if constexpr (mma_mode == MMAMode::kInplaceUpdate) + { + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "f"(C[0]), "f"(C[1]), + "f"(C[2]), "f"(C[3])); + } + else if constexpr (mma_mode == MMAMode::kInit) + { + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "f"(0.f), "f"(0.f), + "f"(0.f), "f"(0.f)); + } +#else + RUNTIME_ASSERT("Unsupported CUDA architecture for mma instruction"); +#endif +} + +/*! + * \brief Wrapper of the mma m16n16k32 instruction for row major and column major fp8 matrix + * multiplication, accumulated in fp32. 
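+ *
+ * Two m16n8k32 e4m3 x e4m3 MMAs accumulating into eight f32 values per thread.
+ * Only compiled when MMA_F8F8F32_M16N8K16_ENABLED is defined (CUDA >= 12.4 and
+ * sm_89 or newer, per the guards at the top of this namespace).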
+ * \tparam mma_mode The mode of mma instruction, either kInit or kInplaceUpdate + * \param C pointer to the accumulator + * \param A pointer to the fragment of matrix A + * \param B pointer to the fragment of matrix B + */ +template +__device__ __forceinline__ void mma_sync_m16n16k32_row_col_f8f8f32(float* C, uint32_t* A, + uint32_t* B) { +#ifdef MMA_F8F8F32_M16N8K16_ENABLED + if constexpr (mma_mode == MMAMode::kInplaceUpdate) + { + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "f"(C[0]), "f"(C[1]), + "f"(C[2]), "f"(C[3])); + + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[2]), "r"(B[3]), "f"(C[4]), "f"(C[5]), + "f"(C[6]), "f"(C[7])); + } + else if constexpr (mma_mode == MMAMode::kInit) + { + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "f"(0.f), "f"(0.f), + "f"(0.f), "f"(0.f)); + + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[2]), "r"(B[3]), "f"(0.f), "f"(0.f), + "f"(0.f), "f"(0.f)); + } +#else + RUNTIME_ASSERT("Unsupported CUDA architecture for mma instruction"); +#endif +} + +/*! + * \brief Use mma instructions to compute rowsum. + */ +__device__ __forceinline__ void rowsum_f16f16f32(float* d, uint32_t* s) { +#ifdef MMA_F16F16F32_M16N8K16_ENABLED + asm volatile( + "{\n" + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0, _, %1, _}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, 0., %9, 0.};\n" + "}\n" + : "=f"(d[0]), "=f"(d[1]) + : "r"(s[0]), "r"(s[1]), "r"(s[2]), "r"(s[3]), "r"(1006648320), // 1006648320 packs two 1.0f in half precision + "r"(1006648320), "f"(d[0]), "f"(d[1])); +#else + RUNTIME_ASSERT("Unsupported CUDA architecture for mma instruction"); +#endif +} + +/*! + * \brief Use mma instructions to compute rowsum. 
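+ *        The all-ones operand 943208504 below is 0x38383838, i.e. four
+ *        e4m3-encoded 1.0 values packed into one register, so the MMA reduces
+ *        each 32-wide row of the e4m3 fragment into the two partial sums
+ *        d[0] and d[1] held by each thread.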
+ */ +__device__ __forceinline__ void rowsum_f8f8f32(float* d, uint32_t* s) { +#ifdef MMA_F8F8F32_M16N8K16_ENABLED + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 " + "{%0, _, %1, _}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, 0., %9, 0.};\n" + : "=f"(d[0]), "=f"(d[1]) + : "r"(s[0]), "r"(s[1]), "r"(s[2]), "r"(s[3]), "r"(943208504), "r"(943208504), // 943208504 packs four 1.0f in e4m3 + "f"(d[0]), "f"(d[1])); +#else + RUNTIME_ASSERT("Unsupported CUDA architecture for mma instruction"); +#endif +} + +} // namespace mma + + +// namespace cp_async +// intend to wrap the copy asynchronizely operations +namespace cp_async { + +enum class SharedMemFillMode { + kFillZero, // Fill zero to shared memory when predicate is false + kNoFill // Do not fill zero to shared memory when predicate is false +}; + +enum class PrefetchMode { + kNoPrefetch, // Do not fetch additional data from global memory to L2 + kPrefetch // Fetch additional data from global memory to L2 +}; + +#if (__CUDACC_VER_MAJOR__ >= 11) +#if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800)) +#define CP_ASYNC_ENABLED +#endif +#endif + +/*! + * \brief Wrapper of PTX cp.async.commit_group instruction, commit all prior uncommitted + * cp.async instructions to a group + */ +__device__ __forceinline__ void commit_group() { +#ifdef CP_ASYNC_ENABLED + asm volatile("cp.async.commit_group;\n" ::); +#endif +} + +/*! + * \brief Wrapper of PTX cp.async.wait_group instruction + * \tparam n Wait till most recent n groups are committed + */ +template +__device__ __forceinline__ void wait_group() { +#ifdef CP_ASYNC_ENABLED + asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); +#endif +} + +/*! + * \brief Wrapper of PTX cp.async.cg.shared.global instruction, asynchronously copy data from + * global memory to shared memory + * \tparam prefetch_mode Whether to fetch additional data from global memory to L2 + * \tparam T Data type + * \param smem_ptr Pointer to shared memory + * \param gmem_ptr Pointer to global memory + */ +template +__device__ __forceinline__ void load_128b(T* smem_ptr, const T* gmem_ptr) { +#ifdef CP_ASYNC_ENABLED + uint32_t smem_int_ptr = static_cast(__cvta_generic_to_shared(smem_ptr)); + if constexpr (prefetch_mode == PrefetchMode::kPrefetch) { + asm volatile("cp.async.cg.shared.global.L2::128B [%0], [%1], %2, %3;\n" ::"r"(smem_int_ptr), + "l"(gmem_ptr), "n"(16), "r"(16)); + } else { + asm volatile("cp.async.cg.shared.global [%0], [%1], %2, %3;\n" ::"r"(smem_int_ptr), + "l"(gmem_ptr), "n"(16), "r"(16)); + } +#else + *((uint4*)smem_ptr) = *((uint4*)gmem_ptr); +#endif +} + +/*! + * \brief Wrapper of PTX cp.async.cg.shared.global instruction, asynchronously copy data from + * global memory to shared memory with predicate. + * \tparam prefetch_mode Whether to fetch additional data from global memory to L2 + * \tparam fill_mode Whether to fill zero to shared memory when predicate is false + * \tparam T Data type + * \param smem_ptr Pointer to shared memory + * \param gmem_ptr Pointer to global memory + * \param predicate Predicate value + * \note fill zero is slower than not fill zero + */ +template +__device__ __forceinline__ void pred_load_128b(T* smem_ptr, const T* gmem_ptr, bool predicate) { +#ifdef CP_ASYNC_ENABLED + uint32_t smem_int_ptr = static_cast(__cvta_generic_to_shared(smem_ptr)); + if constexpr (fill_mode == SharedMemFillMode::kFillZero) { + int src_in_bytes = predicate ? 
16 : 0; + if constexpr (prefetch_mode == PrefetchMode::kPrefetch) { + asm volatile("cp.async.cg.shared.global.L2::128B [%0], [%1], %2, %3;\n" ::"r"(smem_int_ptr), + "l"(gmem_ptr), "n"(16), "r"(src_in_bytes)); + } else { + asm volatile("cp.async.cg.shared.global [%0], [%1], %2, %3;\n" ::"r"(smem_int_ptr), + "l"(gmem_ptr), "n"(16), "r"(src_in_bytes)); + } + } else { + if constexpr (prefetch_mode == PrefetchMode::kPrefetch) { + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global.L2::128B [%1], [%2], %3;\n" + "}\n" ::"r"((int)predicate), + "r"(smem_int_ptr), "l"(gmem_ptr), "n"(16)); + } else { + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)predicate), + "r"(smem_int_ptr), "l"(gmem_ptr), "n"(16)); + } + } +#else + if (predicate) { + *((uint4*)smem_ptr) = *((uint4*)gmem_ptr); + } else { + if constexpr (fill_mode == SharedMemFillMode::kFillZero) { + *((uint4*)smem_ptr) = make_uint4(0, 0, 0, 0); + } + } +#endif +} + +} // namespace cp_async + + +// namespace math +// math operations using ptx +namespace math { + +// log2(e) +constexpr float log2e = 1.44269504088896340736f; +constexpr float log2e_recp = 1.0f / log2e; + +__forceinline__ __device__ half2 uint32_as_half2(uint32_t x) { return *(half2*)&x; } + +__forceinline__ __device__ uint32_t half2_as_uint32(half2 x) { return *(uint32_t*)&x; } + +/*! + * \brief Wrapper of PTX ex2.approx instruction, which computes 2^x + * \param x input + */ +__forceinline__ __device__ float ptx_exp2(float x) { + float y; + asm volatile("ex2.approx.ftz.f32 %0, %1;" : "=f"(y) : "f"(x)); + return y; +} + +/*! + * \brief Wrapper of PTX lg2.approx instruction, which computes log2(x) + * \param x input + */ +__forceinline__ __device__ float ptx_log2(float x) { + float y; + asm volatile("lg2.approx.ftz.f32 %0, %1;" : "=f"(y) : "f"(x)); + return y; +} + +/*! + * \brief Wrapper of PTX ex2.approx.f16x2 instruction, which computes 2^x + * \param x input + */ +__forceinline__ __device__ half2 ptx_exp2(half2 x) { + uint32_t y_u32; + uint32_t x_u32 = half2_as_uint32(x); + asm volatile("ex2.approx.f16x2 %0, %1;" : "=r"(y_u32) : "r"(x_u32)); + return uint32_as_half2(y_u32); +} + +/*! + * \brief Wrapper of PTX ex2.approx.f16 instruction, which computes 2^x + * \param x input + */ +__forceinline__ __device__ half ptx_exp2(half x) { + ushort y_u16; + asm volatile("ex2.approx.f16 %0, %1;" : "=h"(y_u16) : "h"(__half_as_ushort(x))); + return __ushort_as_half(y_u16); +} + +/*! + * \brief Wrapper of PTX rcp.approx instruction, which computes 1/x + * \param x input + */ +__forceinline__ __device__ float ptx_rcp(float x) { + float y; + asm volatile("rcp.approx.ftz.f32 %0, %1;" : "=f"(y) : "f"(x)); + return y; +} + +/*! + * \brief Wrapper of PTX shfl.sync.bfly instruction, which performs a butterfly shuffle + * between threads in a warp. + * \param x The value in the source lane + * \param lane_mask The mask to perform thread index xor with: y[i] <- x[i ^ delta] + */ +__forceinline__ __device__ float shfl_xor_sync(float x, int lane_mask) { + float y; + asm volatile("shfl.sync.bfly.b32 %0, %1, %2, 0x1f, 0xffffffff;" + : "=f"(y) + : "f"(x), "r"(lane_mask)); + return y; +} + +/*! + * \brief Wrapper of PTX shfl.sync.bfly instruction on half2, which performs a butterfly + * shuffle between threads in a warp. 
+ * \param x The value in the source lane + * \param lane_mask The mask to perform thread index xor with: y[i] <- x[i ^ lane_mask] + */ +__forceinline__ __device__ half2 shfl_xor_sync(half2 x, int lane_mask) { + return __shfl_xor_sync(0xffffffff, x, lane_mask); +} + +/*! + * \brief Wrapper of PTX rsqrt approximation instruction, which computes 1/sqrt(x) + * \param x input + */ +__forceinline__ __device__ float rsqrt(float x) { + float y; + asm volatile("rsqrt.approx.ftz.f32 %0, %1;" : "=f"(y) : "f"(x)); + return y; +} + +/*! + * \brief Wrapper of PTX tanh.approx.f32 instruction, which computes tanh(x) + * \param x input + */ +__forceinline__ __device__ float tanh(float x) { + float y; + asm volatile("tanh.approx.f32 %0, %1;" : "=f"(y) : "f"(x)); + return y; +} + +/*! + * \brief Wrapper of PTX tanh.approx.f16x2 instruction, which computes tanh(x) + * \param x input + */ +__forceinline__ __device__ half2 tanh(half2 x) { + uint32_t y_u32; + uint32_t x_u32 = half2_as_uint32(x); + asm volatile("tanh.approx.f16x2 %0, %1;" : "=r"(y_u32) : "r"(x_u32)); + return uint32_as_half2(y_u32); +} + +/*! + * \brief Wrapper of PTX tanh.approx.f16 instruction, which computes tanh(x) + * \param x input + */ +__forceinline__ __device__ half tanh(half x) { + ushort y_u16; + asm volatile("tanh.approx.f16 %0, %1;" : "=h"(y_u16) : "h"(__half_as_ushort(x))); + return __ushort_as_half(y_u16); +} + +} // namespace math + + +// originally in `permuted_smem.cuh` +enum class SwizzleMode { + k32B, // for k32B mode, a line of shared memory must have 32B (16 half value) + k64B, // for k64B mode, a line of shared memory must have 64B (32 half value) + k128B, // 128B already spans all banks in shared memory. a line of shared memory can have multiple 128B. +}; + +// Use 128bit as the granularity to fetch/store data per thread to maximize memory bandwidth +using b128_t = uint4; + +/*! + * \brief A stateless shared memory wrapper that uses templates to avoid runtime conditionals. It makes sure + * that access to consecutive rows idx in the same column idx will make full use of the shared memory bank through + * permutation in the granularity of 128bit. + * + * This struct treats all offsets to be the number of `b128_t` elements. It is designed to be stateless, + * meaning it does not maintain any information about the current pointer position. The offset returnd by + * the struct can be used to access the shared memory through the provided interface. + * + * The struct guarantees that the read to permuted offset (i, j) will be the value stored in permuted offset (i, j). + * We assume that shared memory operation operates on at least two consecutive 128-bit values in a row within a warp. + * Under this assumption, we do not permute for k32B mode. + */ +template +struct smem_t { + // The base pointer. + b128_t* base; + // How many b128_t value a row contains + // uint32_t stride; + + __device__ __forceinline__ smem_t() : base(nullptr) {} + template + __device__ __forceinline__ smem_t(T* base) : base((b128_t*)base) { + if constexpr (swizzle_mode == SwizzleMode::k128B) { + static_assert(stride % 8 == 0, "Stride must be multiple of 8 for 128B swizzle mode"); + } else if constexpr (swizzle_mode == SwizzleMode::k64B) { + static_assert(stride == 4, "Stride must be 4 for 64B swizzle mode"); + } else if constexpr (swizzle_mode == SwizzleMode::k32B) { + static_assert(stride == 2, "Stride must be 2 for 32B swizzle mode"); + } else { + static_assert(swizzle_mode != swizzle_mode, "Unsupported swizzle mode"); + } + } + + /*! 
+ * \brief Set the base pointer. + */ + template + __device__ __forceinline__ void set_base(T* new_base) { + base = (b128_t*)new_base; + } + + /*! + * \brief Compute the element offset given coordinates in a permuted shared memory. + * \param i The row index. + * \param j The column index. + */ + static __device__ __forceinline__ uint32_t get_permuted_offset(const uint32_t &i, const uint32_t &j) { + if constexpr (swizzle_mode == SwizzleMode::k128B) { + return i * stride + (j ^ (i % 8)); + } else if constexpr (swizzle_mode == SwizzleMode::k64B) { + return i * stride + (j ^ ((i / 2) % 4)); + } else if constexpr (swizzle_mode == SwizzleMode::k32B) { + return i * stride + j; + } + } + + /*! + * \tparam step_size The step size to advance the offset in the permuted shared memory. + * \param offset The current offset. + */ + template + static __device__ __forceinline__ uint32_t advance_offset_by_column(const uint32_t &offset) { + if constexpr (swizzle_mode == SwizzleMode::k128B) { + static_assert(step_size % 8 == 0, + "Unsupported step size"); + return offset + step_size; + } else if constexpr (swizzle_mode == SwizzleMode::k64B) { + static_assert(step_size == 4, "Unsupported step size"); + return offset + step_size; + } else if constexpr (swizzle_mode == SwizzleMode::k32B) { + static_assert(step_size == 2, "Unsupported step size"); + return offset + step_size; + } + } + + // ! use with care + template + static __device__ __forceinline__ uint32_t advance_offset_by_column(const uint32_t &offset, const uint32_t &step_idx) { + if constexpr (swizzle_mode == SwizzleMode::k128B) { + static_assert(step_size == 2 || step_size == 4 || step_size % 8 == 0, + "Unsupported step size"); + if constexpr (step_size == 2) { + return (offset ^ (0x2 + (0x4 * (step_idx % 2 == 1)))) + (step_idx % 4 == 3) * 8; + } else if constexpr (step_size == 4) { + return (offset ^ 0x4) + (step_idx % 2 == 1) * 8; + } else { + // step_size % 8 == 0 + return offset + step_size; + } + } else if constexpr (swizzle_mode == SwizzleMode::k64B) { + static_assert(step_size == 2 || step_size == 4, "Unsupported step size"); + if constexpr (step_size == 2) { + return (offset ^ 0x2) + (step_idx % 2 == 1) * 4; + } else { + return offset + step_size; + } + } else if constexpr (swizzle_mode == SwizzleMode::k32B) { + return offset + step_size; + } + } + + template + static __device__ __forceinline__ uint32_t advance_offset_by_row(const uint32_t &offset) { + if constexpr (swizzle_mode == SwizzleMode::k128B) { + static_assert(step_size == 4 || step_size % 8 == 0, "Unsupported step size"); + if constexpr (step_size == 4) { + return (offset ^ 0x4) + step_size * stride; + } else { + // step_size % 8 == 0 + return offset + step_size * stride; + } + } else if constexpr (swizzle_mode == SwizzleMode::k64B) { + static_assert(step_size == 4 || step_size % 8 == 0, "Unsupported step size"); + if constexpr (step_size == 4) { + return (offset ^ 0x2) + step_size * stride; + } else { + // step_size % 8 == 0 + return offset + step_size * stride; + } + } else if constexpr (swizzle_mode == SwizzleMode::k32B) { + return offset + step_size * stride; + } + } + + __device__ __forceinline__ void ldmatrix_m8n8x2(const uint32_t &offset, uint32_t* R) const { + b128_t* smem_ptr = base + offset; + mma::ldmatrix_m8n8x2(R, smem_ptr); + } + + __device__ __forceinline__ void ldmatrix_m8n8x4(const uint32_t &offset, uint32_t* R) const { + b128_t* smem_ptr = base + offset; + mma::ldmatrix_m8n8x4(R, smem_ptr); + } + + __device__ __forceinline__ void ldmatrix_m8n8x4_trans(const uint32_t 
&offset, uint32_t* R) const { + b128_t* smem_ptr = base + offset; + mma::ldmatrix_m8n8x4_trans(R, smem_ptr); + } + + template + __device__ __forceinline__ void load_128b_async(const uint32_t &offset, const T* gptr, bool predicate) const { + b128_t* smem_ptr = base + offset; + cp_async::pred_load_128b( + smem_ptr, reinterpret_cast(gptr), predicate); + } + + template + __device__ __forceinline__ void load_128b_async(const uint32_t &offset, const T* gptr) const { + b128_t* smem_ptr = base + offset; + cp_async::load_128b(smem_ptr, reinterpret_cast(gptr)); + } + + template + __device__ __forceinline__ void store_128b(const uint32_t &offset, T* gptr) const { + *reinterpret_cast(gptr) = *(base + offset); + } +}; + + +// numeric conversion +__device__ __forceinline__ void floatx4_to_e4m3x4(uint32_t *dest, float *source0, float *source1) +{ +#ifdef FP8_CAST_ENABLED + asm volatile( \ + "{\n" \ + ".reg .b16 lo;\n" \ + ".reg .b16 hi;\n" \ + "cvt.rn.satfinite.e4m3x2.f32 lo, %2, %1;\n" \ + "cvt.rn.satfinite.e4m3x2.f32 hi, %4, %3;\n" \ + "mov.b32 %0, {lo, hi};\n" \ + "}" \ + : "=r"(dest[0]) : "f"(source0[0]), "f"(source0[1]), "f"(source1[0]), "f"(source1[1])); +#else + RUNTIME_ASSERT("Unsupported CUDA architecture for FP8 CAST instruction"); +#endif +} + +__device__ __forceinline__ void floatx4_to_e5m2x4(uint32_t *dest, float *source0, float *source1) +{ +#ifdef FP8_CAST_ENABLED + asm volatile( \ + "{\n" \ + ".reg .b16 lo;\n" \ + ".reg .b16 hi;\n" \ + "cvt.rn.satfinite.e5m2x2.f32 lo, %2, %1;\n" \ + "cvt.rn.satfinite.e5m2x2.f32 hi, %4, %3;\n" \ + "mov.b32 %0, {lo, hi};\n" \ + "}" \ + : "=r"(dest[0]) : "f"(source0[0]), "f"(source1[1]), "f"(source1[0]), "f"(source1[1])); +#else + RUNTIME_ASSERT("Unsupported CUDA architecture for FP8 CAST instruction"); +#endif +} + +__device__ __forceinline__ void halfx4_to_e4m3x4(uint32_t *dest, uint32_t *source0, uint32_t *source1) +{ +#ifdef FP8_CAST_ENABLED + asm volatile( \ + "{\n" \ + ".reg .b16 lo;\n" \ + ".reg .b16 hi;\n" \ + "cvt.rn.satfinite.e4m3x2.f16x2 lo, %1;\n" \ + "cvt.rn.satfinite.e4m3x2.f16x2 hi, %2;\n" \ + "mov.b32 %0, {lo, hi};\n" \ + "}" \ + : "=r"(dest[0]) : "r"(source0[0]), "r"(source1[0])); +#else + RUNTIME_ASSERT("Unsupported CUDA architecture for FP8 CAST instruction"); +#endif +} + +__device__ __forceinline__ void halfx4_to_e5m2x4(uint32_t *dest, uint32_t *source0, uint32_t *source1) +{ +#ifdef FP8_CAST_ENABLED + asm volatile( \ + "{\n" \ + ".reg .b16 lo;\n" \ + ".reg .b16 hi;\n" \ + "cvt.rn.satfinite.e5m2x2.f16x2 lo, %1;\n" \ + "cvt.rn.satfinite.e5m2x2.f16x2 hi, %2;\n" \ + "mov.b32 %0, {lo, hi};\n" \ + "}" \ + : "=r"(dest[0]) : "r"(source0[0]), "r"(source1[0])); +#else + RUNTIME_ASSERT("Unsupported CUDA architecture for FP8 CAST instruction"); +#endif +} + +__device__ __forceinline__ void e4m3x4_to_halfx4(uint32_t *dest0, uint32_t *dest1, uint32_t *source) +{ +#ifdef FP8_CAST_ENABLED + asm volatile( \ + "{\n" \ + ".reg .b16 lo, hi;\n" \ + "mov.b32 {lo, hi}, %2;\n" \ + "cvt.rn.f16x2.e4m3x2 %0, lo;\n" \ + "cvt.rn.f16x2.e4m3x2 %1, hi;\n" \ + "}\n" : "=r"(dest0[0]), "=r"(dest1[0]) : "r"(source[0])); +#else + RUNTIME_ASSERT("Unsupported CUDA architecture for FP8 CAST instruction"); +#endif +} + +__device__ __forceinline__ void e5m2x4_to_halfx4(uint32_t *dest0, uint32_t *dest1, uint32_t *source) +{ +#ifdef FP8_CAST_ENABLED + asm volatile( \ + "{\n" \ + ".reg .b16 lo, hi;\n" \ + "mov.b32 {lo, hi}, %2;\n" \ + "cvt.rn.f16x2.e5m2x2 %0, lo;\n" \ + "cvt.rn.f16x2.e5m2x2 %1, hi;\n" \ + "}\n" : "=r"(dest0[0]), "=r"(dest1[0]) : "r"(source[0])); +#else + 
RUNTIME_ASSERT("Unsupported CUDA architecture for FP8 CAST instruction"); +#endif +} + +__device__ __forceinline__ int8_t float_to_int8_rn(float x) +{ + uint32_t dst; + asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(dst) : "f"(x)); + return reinterpret_cast(dst); +} + + +// attn_utils +enum class MaskMode { + kNone = 0, + kCausal = 1, +}; + +// we do not use paddle::DataType, because paddle or torch do not support INT4 natively. We now use our defined DataType. +// To avoid conflict with paddle DataType or something, we rename the `DataType` here to `SADataType` +enum class SADataType { + kHalf, + kInt8, + kInt4, + kE4M3, + kE5M2, +}; + +enum class QuantGranularity { + kPerTensor = 0, + kPerBlock = 1, + kPerWarp = 2, + kPerThread = 3, +}; + +enum class ComputeUnit { + kTensorCore, + kCudaCore, +}; + +inline __device__ __host__ size_t div_ceil(size_t a, size_t b) { + return (a % b != 0) ? (a / b + 1) : (a / b); +} + +__device__ __forceinline__ uint32_t get_warp_id() +{ + return threadIdx.y; +} + +__device__ __forceinline__ uint32_t get_lane_id() +{ + return threadIdx.x; +} + +template +__device__ __forceinline__ uint32_t get_warp_idx_q() +{ + return get_warp_id() / num_warps_k; +} + +template +__device__ __forceinline__ uint32_t get_warp_idx_k() +{ + return get_warp_id() % num_warps_k; +} + +template +__device__ __forceinline__ void load_global_to_share(T **lane_ptr, uint32_t &smem_offset, + const uint32_t &gmem_stride, + const smem_t &smem) +{ + static_assert(global_to_shared_copy_lines_per_warp_per_iter * global_to_shared_line_lanes == WARP_SIZE); + static_assert(std::is_same::value || std::is_same::value); + + constexpr uint32_t pack_size = std::is_same::value ? 8 : 16; + +#pragma unroll + for (uint32_t i = 0; i < smem_iters_col; i++) + { +#pragma unroll + for (uint32_t j = 0; j < smem_iters_row; j++) + { + smem.load_128b_async(smem_offset, *lane_ptr); + *lane_ptr += (global_to_shared_line_lanes * pack_size); + smem_offset = smem.advance_offset_by_column(smem_offset); + } + + smem_offset = smem.advance_offset_by_row(smem_offset - (smem_iters_row * global_to_shared_line_lanes)); + *lane_ptr += ((global_to_shared_copy_lines_per_warp_per_iter * gmem_stride) - (smem_iters_row * global_to_shared_line_lanes * pack_size)); + } + smem_offset -= (smem_iters_col * global_to_shared_copy_lines_per_warp_per_iter * stride); + *lane_ptr += (CTA - smem_iters_col * global_to_shared_copy_lines_per_warp_per_iter) * gmem_stride; +} + +// with predicate +template +__device__ __forceinline__ void load_global_to_share(T **lane_ptr, uint32_t &smem_offset, + const uint32_t &gmem_stride, + const smem_t &smem, uint32_t base_idx, uint32_t max_len) +{ + static_assert(global_to_shared_copy_lines_per_warp_per_iter * global_to_shared_line_lanes == WARP_SIZE); + static_assert(std::is_same::value || std::is_same::value); + + constexpr uint32_t pack_size = std::is_same::value ? 
8 : 16; + +#pragma unroll + for (uint32_t i = 0; i < smem_iters_col; i++) + { +#pragma unroll + for (uint32_t j = 0; j < smem_iters_row; j++) + { + smem.load_128b_async(smem_offset, *lane_ptr, base_idx < max_len); + *lane_ptr += (global_to_shared_line_lanes * pack_size); + smem_offset = smem.advance_offset_by_column(smem_offset); + } + + smem_offset = smem.advance_offset_by_row(smem_offset - (smem_iters_row * global_to_shared_line_lanes)); + *lane_ptr += ((global_to_shared_copy_lines_per_warp_per_iter * gmem_stride) - (smem_iters_row * global_to_shared_line_lanes * pack_size)); + base_idx += global_to_shared_copy_lines_per_warp_per_iter; + } + smem_offset -= (smem_iters_col * global_to_shared_copy_lines_per_warp_per_iter * stride); + *lane_ptr += (CTA - smem_iters_col * global_to_shared_copy_lines_per_warp_per_iter) * gmem_stride; +} + +template +__device__ __forceinline__ void load_fp8_V_global_to_share(int8_t **lane_ptr, uint32_t &smem_offset, + const uint32_t &gmem_stride, + const smem_t &smem) +{ + static_assert(global_to_shared_copy_lines_per_warp_per_iter * global_to_shared_line_lanes == WARP_SIZE); + constexpr uint32_t pack_size_fp8 = 16; + +#pragma unroll + for (uint32_t i = 0; i < smem_iters_col; i++) + { +#pragma unroll + for (uint32_t j = 0; j < smem_iters_row; j++) + { + smem.load_128b_async(smem_offset, *lane_ptr); + *lane_ptr += (global_to_shared_line_lanes * pack_size_fp8); + smem_offset = smem.advance_offset_by_column(smem_offset); + } + + smem_offset = smem.advance_offset_by_row(smem_offset - (smem_iters_row * global_to_shared_line_lanes)); + *lane_ptr += ((global_to_shared_copy_lines_per_warp_per_iter * gmem_stride) - (smem_iters_row * global_to_shared_line_lanes * pack_size_fp8)); + } + smem_offset -= (smem_iters_col * global_to_shared_copy_lines_per_warp_per_iter * stride); + // for QK: *lane_ptr += (CTA - smem_iters_col * global_to_shared_copy_lines_per_warp_per_iter) * gmem_stride; + *lane_ptr += CTA; // ! prevent underflow + *lane_ptr -= (smem_iters_col * global_to_shared_copy_lines_per_warp_per_iter) * gmem_stride; +} + +template +__device__ __forceinline__ void compute_int_qk(const smem_t &smem_Q, const smem_t &smem_K, int32_t RS[][num_tiles_k][8], uint32_t &offset_Q, uint32_t &offset_K) +{ + static_assert(DTypeQK == SADataType::kInt8 || DTypeQK == SADataType::kInt4); + + uint32_t RQ[num_tiles_q][4]; + uint32_t RK[4]; + + // the first iteration, mma mode is kInit +#pragma unroll + for (uint32_t iter = 0; iter < 1; iter++) + { + // load RQ +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { + smem_Q.ldmatrix_m8n8x4(offset_Q, RQ[fq]); + offset_Q = smem_Q.advance_offset_by_row<16>(offset_Q); + } + // ! 
using permutation invariance + offset_Q = smem_Q.advance_offset_by_column<2>(offset_Q - (num_tiles_q * 16 * stride), iter); + +#pragma unroll + for (uint32_t fk = 0; fk < num_tiles_k; fk++) + { + // load RK + smem_K.ldmatrix_m8n8x4(offset_K, RK); + offset_K = smem_K.advance_offset_by_row<16>(offset_K); + + // mma +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { + if constexpr (DTypeQK == SADataType::kInt8) + { + mma::mma_sync_m16n16k32_row_col_s8s8s32(RS[fq][fk], RQ[fq], RK); + } + else if constexpr (DTypeQK == SADataType::kInt4) + { + mma::mma_sync_m16n16k64_row_col_s4s4s32(RS[fq][fk], RQ[fq], RK); + } + } + } + offset_K = smem_K.advance_offset_by_column<2>(offset_K - (num_tiles_k * 16 * stride), iter); + } + + // following iteration, mma mode is kInplace +#pragma unroll + for (uint32_t iter = 1; iter < num_tiles_qk_inner; iter++) + { + // load RQ +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { + smem_Q.ldmatrix_m8n8x4(offset_Q, RQ[fq]); + offset_Q = smem_Q.advance_offset_by_row<16>(offset_Q); + } + offset_Q = smem_Q.advance_offset_by_column<2>(offset_Q - (num_tiles_q * 16 * stride), iter); + +#pragma unroll + for (uint32_t fk = 0; fk < num_tiles_k; fk++) + { + // load RK + smem_K.ldmatrix_m8n8x4(offset_K, RK); + offset_K = smem_K.advance_offset_by_row<16>(offset_K); + + // mma +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { + if constexpr (DTypeQK == SADataType::kInt8) + { + mma::mma_sync_m16n16k32_row_col_s8s8s32(RS[fq][fk], RQ[fq], RK); + } + else if constexpr (DTypeQK == SADataType::kInt4) + { + mma::mma_sync_m16n16k64_row_col_s4s4s32(RS[fq][fk], RQ[fq], RK); + } + } + } + offset_K = smem_K.advance_offset_by_column<2>(offset_K - (num_tiles_k * 16 * stride), iter); + } + + offset_Q -= (2 * num_tiles_qk_inner); + offset_K -= (2 * num_tiles_qk_inner); +} + +// for case when num_tiles_qk_inner = 1 +template +__device__ __forceinline__ void compute_int_qk(const smem_t &smem_K, int32_t RS[][num_tiles_k][8], uint32_t RQ[][4], uint32_t offset_K) +{ + static_assert(DTypeQK == SADataType::kInt8 || DTypeQK == SADataType::kInt4); + static_assert(num_tiles_qk_inner == 1); + + uint32_t RK[4]; + + // mma mode is kInit +#pragma unroll + for (uint32_t fk = 0; fk < num_tiles_k; fk++) + { + // load RK + smem_K.ldmatrix_m8n8x4(offset_K, RK); + offset_K = smem_K.advance_offset_by_row<16>(offset_K); + + // mma +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { + if constexpr (DTypeQK == SADataType::kInt8) + { + mma::mma_sync_m16n16k32_row_col_s8s8s32(RS[fq][fk], RQ[fq], RK); + } + else if constexpr (DTypeQK == SADataType::kInt4) + { + mma::mma_sync_m16n16k64_row_col_s4s4s32(RS[fq][fk], RQ[fq], RK); + } + } + } +} + +template +__device__ __forceinline__ void apply_causal_mask(const uint32_t &Q_idx_lane_base, const uint32_t &K_idx_lane_base, DTypeQKAccum RS[][num_tiles_k][8]) +{ +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fk = 0; fk < num_tiles_k; fk++) + { +#pragma unroll + for (uint32_t k = 0; k < 8; k++) + { + const uint32_t q_idx = Q_idx_lane_base + fq * 16 + 8 * ((k % 4) / 2); + const uint32_t kv_idx = K_idx_lane_base + fk * 16 + 8 * (k / 4) + k % 2; + const bool out_of_boundary = (kv_idx > q_idx); + + if constexpr (std::is_same::value) + { + RS[fq][fk][k] = (out_of_boundary ? -5000000.0f : RS[fq][fk][k]); + } + else if constexpr (std::is_same::value) + { + RS[fq][fk][k] = (out_of_boundary ? 
__float2half_rn(-50000.0f) : RS[fq][fk][k]); + } + } + } + } +} + +template +__device__ __forceinline__ void apply_out_of_bound_mask(const uint32_t &K_idx_lane_base, DTypeQKAccum RS[][num_tiles_k][8], const uint32_t &kv_len) +{ +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fk = 0; fk < num_tiles_k; fk++) + { +#pragma unroll + for (uint32_t k = 0; k < 8; k++) + { + const uint32_t kv_idx = K_idx_lane_base + fk * 16 + 8 * (k / 4) + k % 2; + const bool out_of_boundary = (kv_idx >= kv_len); + + if constexpr (std::is_same::value) + { + RS[fq][fk][k] = (out_of_boundary ? -5000000.0f : RS[fq][fk][k]); + } + else if constexpr (std::is_same::value) + { + RS[fq][fk][k] = (out_of_boundary ? __float2half_rn(-50000.0f) : RS[fq][fk][k]); + } + } + } + } +} + +// for DTypeQKAccum float +template +__device__ __forceinline__ void update_mdo(float RS[][num_tiles_k][8], DTypeSVAccum RO[][num_tiles_v][8], float m[][2], float d[][2], const float &sm_scale) +{ + static_assert(std::is_same::value || (!use_half_o_scale)); +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t k = 0; k < 2; k++) + { + // assign the smallest value possible + float m_prev = m[fq][k]; + float m_temp = -5000000.0f; +#pragma unroll + for (uint32_t fk = 0; fk < num_tiles_k; fk++) + { + float m_local = max(max(RS[fq][fk][k * 2 + 0], RS[fq][fk][k * 2 + 1]), + max(RS[fq][fk][k * 2 + 4], RS[fq][fk][k * 2 + 5])); + m_temp = max(m_temp, m_local); + } + // exchange element with the 4 threads in the row + if constexpr (!fuse_scale) + { + m_temp *= sm_scale; + } + m_temp = max(m_temp, __shfl_xor_sync(0xffffffff, m_temp, 0x1)); // 0 exchange with 1, 2 exchange with 3 + m_temp = max(m_temp, __shfl_xor_sync(0xffffffff, m_temp, 0x2)); // 0 exchange with 2, 1 exchange with 3 + + m[fq][k] = max(m[fq][k], m_temp); + + float o_scale = math::ptx_exp2(m_prev - m[fq][k]); + + // update denominator + d[fq][k] *= o_scale; + + half2 o_scale2 = __floats2half2_rn(o_scale, o_scale); + + // update RO +#pragma unroll + for (uint32_t fv = 0; fv < num_tiles_v; fv++) + { + if constexpr (std::is_same::value) + { + RO[fq][fv][k * 2 + 0] *= o_scale; + RO[fq][fv][k * 2 + 1] *= o_scale; + RO[fq][fv][k * 2 + 4] *= o_scale; + RO[fq][fv][k * 2 + 5] *= o_scale; + } + else if constexpr (std::is_same::value) + { + if constexpr (use_half_o_scale) + { + ((half2*)RO[fq][fv])[k] = __hmul2(((half2*)RO[fq][fv])[k], o_scale2); + ((half2*)RO[fq][fv])[k + 2] = __hmul2(((half2*)RO[fq][fv])[k + 2], o_scale2); + } + else + { + RO[fq][fv][k * 2 + 0] = __float2half_rn(__half2float(RO[fq][fv][k * 2 + 0]) * o_scale); + RO[fq][fv][k * 2 + 1] = __float2half_rn(__half2float(RO[fq][fv][k * 2 + 1]) * o_scale); + RO[fq][fv][k * 2 + 4] = __float2half_rn(__half2float(RO[fq][fv][k * 2 + 4]) * o_scale); + RO[fq][fv][k * 2 + 5] = __float2half_rn(__half2float(RO[fq][fv][k * 2 + 5]) * o_scale); + } + } + } + + // raise RS to exponent + float negative_m = -m[fq][k]; + if constexpr (exp_offset) + { + negative_m += S_FP8_OFFSET; // times 400 to achieve smaller quantization error of fp8 S + } +#pragma unroll + for (uint32_t fk = 0; fk < num_tiles_k; fk++) + { + if constexpr (fuse_scale) + { + RS[fq][fk][k * 2 + 0] = math::ptx_exp2(RS[fq][fk][k * 2 + 0] + negative_m); + RS[fq][fk][k * 2 + 1] = math::ptx_exp2(RS[fq][fk][k * 2 + 1] + negative_m); + RS[fq][fk][k * 2 + 4] = math::ptx_exp2(RS[fq][fk][k * 2 + 4] + negative_m); + RS[fq][fk][k * 2 + 5] = math::ptx_exp2(RS[fq][fk][k * 2 + 5] + negative_m); + } + else + { + 
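+ // Sketch of the math in the non-fused branch below (assuming the caller has already
+ // folded log2e into sm_scale, as the kernels in this file do before calling update_mdo):
+ //   exp(x * scale - m_e) == exp2(x * (scale * log2e) - m_e * log2e)
+ // and both sm_scale and the running max m are already kept in that base-2 domain here,
+ // so each element needs only one fmaf plus a ptx_exp2 instead of a mul, sub and expf.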
RS[fq][fk][k * 2 + 0] = math::ptx_exp2(fmaf(RS[fq][fk][k * 2 + 0], sm_scale, negative_m)); + RS[fq][fk][k * 2 + 1] = math::ptx_exp2(fmaf(RS[fq][fk][k * 2 + 1], sm_scale, negative_m)); + RS[fq][fk][k * 2 + 4] = math::ptx_exp2(fmaf(RS[fq][fk][k * 2 + 4], sm_scale, negative_m)); + RS[fq][fk][k * 2 + 5] = math::ptx_exp2(fmaf(RS[fq][fk][k * 2 + 5], sm_scale, negative_m)); + } + } + } + } +} + +template +__device__ __forceinline__ void RS_32_to_16(T RS[][num_tiles_k][8], uint32_t RS_16[][num_tiles_k][4]) +{ + static_assert(sizeof(T) == 4); +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fk = 0; fk < num_tiles_k; fk++) + { + ((half2*)RS_16[fq][fk])[0] = __float22half2_rn(((float2*)RS[fq][fk])[0]); + ((half2*)RS_16[fq][fk])[1] = __float22half2_rn(((float2*)RS[fq][fk])[1]); + ((half2*)RS_16[fq][fk])[2] = __float22half2_rn(((float2*)RS[fq][fk])[2]); + ((half2*)RS_16[fq][fk])[3] = __float22half2_rn(((float2*)RS[fq][fk])[3]); + } + } +} + +template +__device__ __forceinline__ void RS_32_to_8(float RS[][num_tiles_k][8], uint32_t RS_8[][num_tiles_k / 2][4]) +{ +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fk = 0; fk < num_tiles_k / 2; fk++) + { + floatx4_to_e4m3x4(RS_8[fq][fk], RS[fq][fk * 2 + 0], RS[fq][fk * 2 + 0] + 4); + floatx4_to_e4m3x4(RS_8[fq][fk] + 1, RS[fq][fk * 2 + 0] + 2, RS[fq][fk * 2 + 0] + 6); + floatx4_to_e4m3x4(RS_8[fq][fk] + 2, RS[fq][fk * 2 + 1], RS[fq][fk * 2 + 1] + 4); + floatx4_to_e4m3x4(RS_8[fq][fk] + 3, RS[fq][fk * 2 + 1] + 2, RS[fq][fk * 2 + 1] + 6); + } + } +} + +template +__device__ __forceinline__ void RS_16_to_8(uint32_t RS[][num_tiles_k][4], uint32_t RS_8[][num_tiles_k / 2][4]) +{ +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fk = 0; fk < num_tiles_k / 2; fk++) + { + halfx4_to_e4m3x4(RS_8[fq][fk], RS[fq][fk * 2 + 0], RS[fq][fk * 2 + 0] + 2); + halfx4_to_e4m3x4(RS_8[fq][fk] + 1, RS[fq][fk * 2 + 0] + 1, RS[fq][fk * 2 + 0] + 3); + halfx4_to_e4m3x4(RS_8[fq][fk] + 2, RS[fq][fk * 2 + 1], RS[fq][fk * 2 + 1] + 2); + halfx4_to_e4m3x4(RS_8[fq][fk] + 3, RS[fq][fk * 2 + 1] + 1, RS[fq][fk * 2 + 1] + 3); + } + } +} + +template +__device__ __forceinline__ void RS_8_to_16(uint32_t RS_8[][num_tiles_k / 2][4], uint32_t RS[][num_tiles_k][4]) +{ +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fk = 0; fk < num_tiles_k / 2; fk++) + { + e4m3x4_to_halfx4(RS[fq][fk * 2 + 0], RS[fq][fk * 2 + 0] + 2, RS_8[fq][fk]); + e4m3x4_to_halfx4(RS[fq][fk * 2 + 0] + 1, RS[fq][fk * 2 + 0] + 3, RS_8[fq][fk] + 1); + e4m3x4_to_halfx4(RS[fq][fk * 2 + 1], RS[fq][fk * 2 + 1] + 2, RS_8[fq][fk] + 2); + e4m3x4_to_halfx4(RS[fq][fk * 2 + 1] + 1, RS[fq][fk * 2 + 1] + 3, RS_8[fq][fk] + 3); + } + } +} + +template +__device__ __forceinline__ void accumulate_d(T RS[][num_tiles_k][(compute_unit == ComputeUnit::kTensorCore)? 
4 : 8], float d[][2]) +{ + // for compute unit cuda core, RS is float + // for compute unit tensor core, RS is packed half + static_assert((std::is_same::value && compute_unit == ComputeUnit::kCudaCore) || + (std::is_same::value && compute_unit == ComputeUnit::kTensorCore)); + +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fk = 0; fk < num_tiles_k; fk++) + { + if constexpr (compute_unit == ComputeUnit::kTensorCore) + { + // full accumulate with tensor core + mma::rowsum_f16f16f32(d[fq], (uint32_t*)(RS[fq][fk])); + } + else if constexpr (compute_unit == ComputeUnit::kCudaCore) + { + // partial accumulate with cuda core + d[fq][0] += RS[fq][fk][0] + RS[fq][fk][1] + RS[fq][fk][4] + RS[fq][fk][5]; + d[fq][1] += RS[fq][fk][2] + RS[fq][fk][3] + RS[fq][fk][6] + RS[fq][fk][7]; + } + } + } +} + +template +__device__ __forceinline__ void accumulate_d_f8(uint32_t RS[][num_tiles_k / 2][4], float d[][2]) +{ +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fk = 0; fk < num_tiles_k / 2; fk++) + { + mma::rowsum_f8f8f32(d[fq], RS[fq][fk]); + } + } +} + +template +__device__ __forceinline__ void compute_fp16_sv(const smem_t &smem_V, uint32_t RS_f16[][num_tiles_k][4], DTypeSVAccum RO[][num_tiles_v][8], float d[][2]) +{ + uint32_t smem_V_row_base = get_warp_idx_k() * (num_tiles_k * 16) + get_lane_id() % 16; + uint32_t smem_V_col_base = get_lane_id() / 16; +#pragma unroll + for (uint32_t fk = 0; fk < num_tiles_k; fk++) + { +#pragma unroll + for (uint32_t fv = 0; fv < num_tiles_v; fv++) + { + // load RV + uint32_t RV[4]; + uint32_t offset_V = (smem_V).get_permuted_offset(smem_V_row_base + fk * 16, smem_V_col_base + fv * 2); + smem_V.ldmatrix_m8n8x4_trans(offset_V, RV); +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { + if constexpr (std::is_same::value) + { + mma::mma_sync_m16n16k16_row_col_f16f16f32(RO[fq][fv], RS_f16[fq][fk], RV); + } + else if constexpr (std::is_same::value) + { + mma::mma_sync_m16n16k16_row_col_f16f16f16((uint32_t*)RO[fq][fv], RS_f16[fq][fk], RV); + } + } + } + } +} + +template +__device__ __forceinline__ void compute_fp16_sv_permuted(const smem_t &smem_V, T RS_f16[][num_tiles_k][RS_width], DTypeSVAccum RO[][num_tiles_v][8], float d[][2], uint32_t &offset_V) +{ + static_assert(sizeof(T) == 4); + + // ! be sure you know what you are doing +#pragma unroll + for (uint32_t fk = 0; fk < num_tiles_k; fk++) + { +#pragma unroll + for (uint32_t fv = 0; fv < num_tiles_v; fv++) + { + // load RV + uint32_t RV[4]; + smem_V.ldmatrix_m8n8x4_trans(offset_V, RV); +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { + if constexpr (std::is_same::value) + { + mma::mma_sync_m16n16k16_row_col_f16f16f32(RO[fq][fv], (uint32_t*)(RS_f16[fq][fk]), RV); + } + else if constexpr (std::is_same::value) + { + mma::mma_sync_m16n16k16_row_col_f16f16f16((uint32_t*)RO[fq][fv], (uint32_t*)(RS_f16[fq][fk]), RV); + } + } + + offset_V = smem_V.advance_offset_by_column<2>(offset_V, fv); + } + offset_V = smem_V.advance_offset_by_row<16>(offset_V - (2 * num_tiles_v)); + } + + // make offset_V their original value + offset_V -= (16 * num_tiles_k * stride); +} + +template +__device__ __forceinline__ void compute_fp16_sv_permuted_inst_buf(const smem_t &smem_V, T RS_f16[][num_tiles_k][RS_width], DTypeSVAccum RO[][num_tiles_v][8], float d[][2], uint32_t &offset_V) +{ + static_assert(sizeof(T) == 4); + static_assert(std::is_same::value); + + uint32_t RO_inst_buf[num_tiles_q][num_tiles_v][4]; + + // ! 
be sure you know what you are doing +#pragma unroll + for (uint32_t fk = 0; fk < 1; fk++) + { +#pragma unroll + for (uint32_t fv = 0; fv < num_tiles_v; fv++) + { + // load RV + uint32_t RV[4]; + smem_V.ldmatrix_m8n8x4_trans(offset_V, RV); +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { + { + mma::mma_sync_m16n16k16_row_col_f16f16f16((uint32_t*)RO_inst_buf[fq][fv], (uint32_t*)(RS_f16[fq][fk]), RV); + } + } + + offset_V = smem_V.advance_offset_by_column<2>(offset_V, fv); + } + offset_V = smem_V.advance_offset_by_row<16>(offset_V - (2 * num_tiles_v)); + } + +#pragma unroll + for (uint32_t fk = 1; fk < num_tiles_k; fk++) + { +#pragma unroll + for (uint32_t fv = 0; fv < num_tiles_v; fv++) + { + // load RV + uint32_t RV[4]; + smem_V.ldmatrix_m8n8x4_trans(offset_V, RV); +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { + { + mma::mma_sync_m16n16k16_row_col_f16f16f16((uint32_t*)RO_inst_buf[fq][fv], (uint32_t*)(RS_f16[fq][fk]), RV); + } + } + + offset_V = smem_V.advance_offset_by_column<2>(offset_V, fv); + } + offset_V = smem_V.advance_offset_by_row<16>(offset_V - (2 * num_tiles_v)); + } + + // accumulate into RO +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fv = 0; fv < num_tiles_v; fv++) + { + RO[fq][fv][0] += __half2float(((half2*)RO_inst_buf[fq][fv])[0].x); + RO[fq][fv][1] += __half2float(((half2*)RO_inst_buf[fq][fv])[0].y); + RO[fq][fv][2] += __half2float(((half2*)RO_inst_buf[fq][fv])[1].x); + RO[fq][fv][3] += __half2float(((half2*)RO_inst_buf[fq][fv])[1].y); + RO[fq][fv][4] += __half2float(((half2*)RO_inst_buf[fq][fv])[2].x); + RO[fq][fv][5] += __half2float(((half2*)RO_inst_buf[fq][fv])[2].y); + RO[fq][fv][6] += __half2float(((half2*)RO_inst_buf[fq][fv])[3].x); + RO[fq][fv][7] += __half2float(((half2*)RO_inst_buf[fq][fv])[3].y); + } + } + + // make offset_V their original value + offset_V -= (16 * num_tiles_k * stride); +} + +template +__device__ __forceinline__ void normalize_d(DTypeSVAccum RO[][num_tiles_v][8], DTypeQKAccum m[][2], float d[][2]) +{ + if constexpr (compute_unit == ComputeUnit::kCudaCore) + { + // accumulate_d performs partial accumulation with cuda core + // aggregate d +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t k = 0; k < 2; k++) + { + d[fq][k] += __shfl_xor_sync(0xffffffff, d[fq][k], 0x1); // sum 0 and 1, 2 and 3 + d[fq][k] += __shfl_xor_sync(0xffffffff, d[fq][k], 0x2); // sum 0 and 2, 1 and 3 + } + } + } + + // divide O by d + float d_rcp[num_tiles_q][2]; +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t k = 0; k < 2; k++) + { + // TODO: check m to prevent nan + d_rcp[fq][k] = math::ptx_rcp(d[fq][k]); + } + } + +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fv = 0; fv < num_tiles_v; fv++) + { +#pragma unroll + for (uint32_t k = 0; k < 8; k++) + { + if constexpr (std::is_same::value) + { + RO[fq][fv][k] *= d_rcp[fq][(k % 4) / 2]; + } + else if constexpr (std::is_same::value) + { + RO[fq][fv][k] = __float2half_rn(__half2float(RO[fq][fv][k]) * d_rcp[fq][(k % 4) / 2]); + } + } + } + } +} + +template +__device__ __forceinline__ void compute_fp8_sv(const smem_t &smem_V, uint32_t RS_f8[][num_tiles_k / 2][4], DTypeSVAccum RO[][num_tiles_v][8], float d[][2]) +{ + uint32_t smem_V_row_base = get_lane_id() % 8 + (get_lane_id() / 16) * 8; + // uint32_t smem_V_col_base = get_warp_idx_k() * ((16 * num_tiles_k) / 16) + (get_lane_id() / 8) % 
2; + uint32_t smem_V_col_base = (get_lane_id() / 8) % 2; +#pragma unroll + for (uint32_t fk = 0; fk < num_tiles_k / 2; fk++) + { + uint32_t offset_V = smem_V.get_permuted_offset(smem_V_row_base, smem_V_col_base + fk * 2); +#pragma unroll + for (uint32_t fv = 0; fv < num_tiles_v; fv++) + { + // load RV + uint32_t RV[4]; + // uint32_t offset_V = (smem_V).get_permuted_offset(smem_V_row_base + fv * 16, smem_V_col_base + fk * 2); + smem_V.ldmatrix_m8n8x4(offset_V, RV); +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { + if constexpr (std::is_same::value) + { + mma::mma_sync_m16n16k32_row_col_f8f8f32(RO[fq][fv], RS_f8[fq][fk], RV); + } + else if constexpr (std::is_same::value) + { + // ! Not Implemented + } + } + offset_V = smem_V.advance_offset_by_row<16>(offset_V); + } + } +} + +template +__device__ __forceinline__ void compute_fp8_sv_inst_buf(const smem_t &smem_V, uint32_t RS_f8[][num_tiles_k / 2][4], DTypeSVAccum RO[][num_tiles_v][8], float d[][2]) +{ + uint32_t smem_V_row_base = get_lane_id() % 8 + (get_lane_id() / 16) * 8; + // uint32_t smem_V_col_base = get_warp_idx_k() * ((16 * num_tiles_k) / 16) + (get_lane_id() / 8) % 2; + uint32_t smem_V_col_base = (get_lane_id() / 8) % 2; + + float RO_inst_buf[num_tiles_q][num_tiles_v][8]; + +#pragma unroll + for (uint32_t fk = 0; fk < 1; fk++) + { + uint32_t offset_V = smem_V.get_permuted_offset(smem_V_row_base, smem_V_col_base + fk * 2); +#pragma unroll + for (uint32_t fv = 0; fv < num_tiles_v; fv++) + { + // load RV + uint32_t RV[4]; + // uint32_t offset_V = (smem_V).get_permuted_offset(smem_V_row_base + fv * 16, smem_V_col_base + fk * 2); + smem_V.ldmatrix_m8n8x4(offset_V, RV); +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { + if constexpr (std::is_same::value) + { + mma::mma_sync_m16n16k32_row_col_f8f8f32(RO_inst_buf[fq][fv], RS_f8[fq][fk], RV); + } + else if constexpr (std::is_same::value) + { + // ! Not Implemented + } + } + offset_V = smem_V.advance_offset_by_row<16>(offset_V); + } + } + +#pragma unroll + for (uint32_t fk = 1; fk < num_tiles_k / 2; fk++) + { + uint32_t offset_V = smem_V.get_permuted_offset(smem_V_row_base, smem_V_col_base + fk * 2); +#pragma unroll + for (uint32_t fv = 0; fv < num_tiles_v; fv++) + { + // load RV + uint32_t RV[4]; + // uint32_t offset_V = (smem_V).get_permuted_offset(smem_V_row_base + fv * 16, smem_V_col_base + fk * 2); + smem_V.ldmatrix_m8n8x4(offset_V, RV); +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { + if constexpr (std::is_same::value) + { + mma::mma_sync_m16n16k32_row_col_f8f8f32(RO_inst_buf[fq][fv], RS_f8[fq][fk], RV); + } + else if constexpr (std::is_same::value) + { + // ! Not Implemented + } + } + offset_V = smem_V.advance_offset_by_row<16>(offset_V); + } + } + +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fv = 0; fv < num_tiles_v; fv++) + { + RO[fq][fv][0] += RO_inst_buf[fq][fv][0]; + RO[fq][fv][1] += RO_inst_buf[fq][fv][1]; + RO[fq][fv][2] += RO_inst_buf[fq][fv][2]; + RO[fq][fv][3] += RO_inst_buf[fq][fv][3]; + RO[fq][fv][4] += RO_inst_buf[fq][fv][4]; + RO[fq][fv][5] += RO_inst_buf[fq][fv][5]; + RO[fq][fv][6] += RO_inst_buf[fq][fv][6]; + RO[fq][fv][7] += RO_inst_buf[fq][fv][7]; + } + } +} + +// paddle converter zone +namespace pd_cvt { + +// phi::dtype::xx16 -> half or nv_bfloat16 +template +struct PD16bitTrait { + using DataType = T; +}; + +template <> +struct PD16bitTrait { + // Since LayerNormDirectCUDAFunctor register half type, we need to convert + // phi::float16 to half. 
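+ // Minimal usage sketch (hypothetical call site, not part of the kernels in this patch):
+ //   using cuda_half_t = pd_cvt::PD16bitTrait<phi::dtype::float16>::DataType; // == half
+ //   auto* q_ptr = reinterpret_cast<cuda_half_t*>(query.data<phi::dtype::float16>());
+ // PD16bitReTrait below maps back from half / __nv_bfloat16 to the phi 16-bit dtypes.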
+ using DataType = half; +}; + +#ifdef PADDLE_CUDA_BF16 +template <> +class PD16bitTrait { +public: + using DataType = __nv_bfloat16; +}; +#endif + +// half or nv_bfloat16 -> phi::dtype::xx16 +template +struct PD16bitReTrait { + using DataType = T; +}; + +template <> +struct PD16bitReTrait { + using DataType = phi::dtype::float16; +}; + +#ifdef PADDLE_CUDA_BF16 +template<> +class PD16bitReTrait<__nv_bfloat16> { +public: + using DataType = phi::dtype::bfloat16; +}; +#endif + +}; // paddle converter zone end + +// namespace wgmma +namespace wgmma{ +__device__ __forceinline__ uint64_t matrix_descriptor_encode(uint64_t x) { return (((x) & 0x3FFFF) >> 0x4); } + +template +__device__ uint64_t make_smem_desc(T* ptr) { + static_assert(stride == 32 || stride == 64 || stride == 128); + uint32_t addr = static_cast(__cvta_generic_to_shared(ptr)); + uint64_t desc = 0x0000000000000000; + desc |= matrix_descriptor_encode(addr); + desc |= matrix_descriptor_encode((uint64_t)16) << 16; + desc |= matrix_descriptor_encode((uint64_t)(8 * stride)) << 32; + desc |= ((stride == 128) ? 1llu : (stride == 64) ? 2llu : 3llu) << 62; + return desc; +} + +__device__ __forceinline__ void warpgroup_arrive() { + asm volatile("wgmma.fence.sync.aligned;\n" ::: "memory"); +} + +__device__ __forceinline__ void warpgroup_commit_batch() { + asm volatile("wgmma.commit_group.sync.aligned;\n" ::: "memory"); +} + +template +__device__ __forceinline__ void warpgroup_wait() { + static_assert(N >= 0 && N <= 7, "WGMMA wait: N must be in range [0, 7]"); + asm volatile("wgmma.wait_group.sync.aligned %0;\n" ::"n"(N) : "memory"); +} + +template +__device__ __forceinline__ void wgmma_m64n128k16_f16f16f32(float d[][8], T* sA, T* sB) { + uint64_t desc_a = make_smem_desc(&sA[0]); + uint64_t desc_b = make_smem_desc(&sB[0]); + asm volatile( + "{\n" + "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}, " + " %64," + " %65," + " %66, %67, %68, %69, %70;\n" + "}\n" + : "+f"(d[0][0]), "+f"(d[0][1]), "+f"(d[0][2]), "+f"(d[0][3]), "+f"(d[0][4]), "+f"(d[0][5]), "+f"(d[0][6]), "+f"(d[0][7]), + "+f"(d[1][0]), "+f"(d[1][1]), "+f"(d[1][2]), "+f"(d[1][3]), "+f"(d[1][4]), "+f"(d[1][5]), "+f"(d[1][6]), "+f"(d[1][7]), + "+f"(d[2][0]), "+f"(d[2][1]), "+f"(d[2][2]), "+f"(d[2][3]), "+f"(d[2][4]), "+f"(d[2][5]), "+f"(d[2][6]), "+f"(d[2][7]), + "+f"(d[3][0]), "+f"(d[3][1]), "+f"(d[3][2]), "+f"(d[3][3]), "+f"(d[3][4]), "+f"(d[3][5]), "+f"(d[3][6]), "+f"(d[3][7]), + "+f"(d[4][0]), "+f"(d[4][1]), "+f"(d[4][2]), "+f"(d[4][3]), "+f"(d[4][4]), "+f"(d[4][5]), "+f"(d[4][6]), "+f"(d[4][7]), + "+f"(d[5][0]), "+f"(d[5][1]), "+f"(d[5][2]), "+f"(d[5][3]), "+f"(d[5][4]), "+f"(d[5][5]), "+f"(d[5][6]), "+f"(d[5][7]), + "+f"(d[6][0]), "+f"(d[6][1]), "+f"(d[6][2]), "+f"(d[6][3]), "+f"(d[6][4]), "+f"(d[6][5]), "+f"(d[6][6]), "+f"(d[6][7]), + "+f"(d[7][0]), "+f"(d[7][1]), "+f"(d[7][2]), "+f"(d[7][3]), "+f"(d[7][4]), "+f"(d[7][5]), "+f"(d[7][6]), "+f"(d[7][7]) + : "l"(desc_a), "l"(desc_b), "n"(int32_t(ScaleD)), "n"(int32_t(ScaleA)), + "n"(int32_t(ScaleB)), "n"(int32_t(TransA)), "n"(int32_t(TransB))); +} + +template +__device__ __forceinline__ void wgmma_m64n64k16_f16f16f32(float d[][8], T* sA, T* sB) { + uint64_t desc_a = 
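+ // make_smem_desc above packs the SM90 GMMA shared-memory descriptor: bits [0,14) hold
+ // the smem base address >> 4, bits [16,30) the leading-dimension byte offset (16 >> 4),
+ // bits [32,46) the stride byte offset (8 * stride >> 4), and bits 62-63 the swizzle mode
+ // (1 = 128B, 2 = 64B, 3 = 32B), consistent with the swizzled layouts the TMA loads produce.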
make_smem_desc(&sA[0]); + uint64_t desc_b = make_smem_desc(&sB[0]); + asm volatile( + "{\n" + "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}, " + " %32," + " %33," + " %34, %35, %36, %37, %38;\n" + "}\n" + : "+f"(d[0][0]), "+f"(d[0][1]), "+f"(d[0][2]), "+f"(d[0][3]), "+f"(d[0][4]), "+f"(d[0][5]), "+f"(d[0][6]), "+f"(d[0][7]), + "+f"(d[1][0]), "+f"(d[1][1]), "+f"(d[1][2]), "+f"(d[1][3]), "+f"(d[1][4]), "+f"(d[1][5]), "+f"(d[1][6]), "+f"(d[1][7]), + "+f"(d[2][0]), "+f"(d[2][1]), "+f"(d[2][2]), "+f"(d[2][3]), "+f"(d[2][4]), "+f"(d[2][5]), "+f"(d[2][6]), "+f"(d[2][7]), + "+f"(d[3][0]), "+f"(d[3][1]), "+f"(d[3][2]), "+f"(d[3][3]), "+f"(d[3][4]), "+f"(d[3][5]), "+f"(d[3][6]), "+f"(d[3][7]) + : "l"(desc_a), "l"(desc_b), "n"(int32_t(ScaleD)), "n"(int32_t(ScaleA)), + "n"(int32_t(ScaleB)), "n"(int32_t(TransA)), "n"(int32_t(TransB))); +} + +template +__device__ __forceinline__ void wgmma_m64n128k16_f16f16f32(float d[][8], uint32_t RA[], T* sB) { + uint64_t desc_b = make_smem_desc(&sB[0]); + asm volatile( + "{\n" + "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}, " + "{%64, %65, %66, %67}, " + " %68," + " %69, %70, %71, %72;\n" + "}\n" + : "+f"(d[0][0]), "+f"(d[0][1]), "+f"(d[0][2]), "+f"(d[0][3]), "+f"(d[0][4]), "+f"(d[0][5]), "+f"(d[0][6]), "+f"(d[0][7]), + "+f"(d[1][0]), "+f"(d[1][1]), "+f"(d[1][2]), "+f"(d[1][3]), "+f"(d[1][4]), "+f"(d[1][5]), "+f"(d[1][6]), "+f"(d[1][7]), + "+f"(d[2][0]), "+f"(d[2][1]), "+f"(d[2][2]), "+f"(d[2][3]), "+f"(d[2][4]), "+f"(d[2][5]), "+f"(d[2][6]), "+f"(d[2][7]), + "+f"(d[3][0]), "+f"(d[3][1]), "+f"(d[3][2]), "+f"(d[3][3]), "+f"(d[3][4]), "+f"(d[3][5]), "+f"(d[3][6]), "+f"(d[3][7]), + "+f"(d[4][0]), "+f"(d[4][1]), "+f"(d[4][2]), "+f"(d[4][3]), "+f"(d[4][4]), "+f"(d[4][5]), "+f"(d[4][6]), "+f"(d[4][7]), + "+f"(d[5][0]), "+f"(d[5][1]), "+f"(d[5][2]), "+f"(d[5][3]), "+f"(d[5][4]), "+f"(d[5][5]), "+f"(d[5][6]), "+f"(d[5][7]), + "+f"(d[6][0]), "+f"(d[6][1]), "+f"(d[6][2]), "+f"(d[6][3]), "+f"(d[6][4]), "+f"(d[6][5]), "+f"(d[6][6]), "+f"(d[6][7]), + "+f"(d[7][0]), "+f"(d[7][1]), "+f"(d[7][2]), "+f"(d[7][3]), "+f"(d[7][4]), "+f"(d[7][5]), "+f"(d[7][6]), "+f"(d[7][7]) + : "r"(RA[0]), "r"(RA[1]), "r"(RA[2]), "r"(RA[3]), + "l"(desc_b), "n"(int32_t(ScaleD)), "n"(int32_t(ScaleA)), + "n"(int32_t(ScaleB)), "n"(int32_t(TransB))); +} + +template +__device__ __forceinline__ void wgmma_m64n64k16_f16f16f32(float d[][8], uint32_t RA[], T* sB) { + uint64_t desc_b = make_smem_desc(&sB[0]); + asm volatile( + "{\n" + "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}, " + "{%32, %33, %34, %35}, " + " %36," + " %37, %38, %39, %40;\n" + "}\n" + : "+f"(d[0][0]), "+f"(d[0][1]), "+f"(d[0][2]), "+f"(d[0][3]), "+f"(d[0][4]), "+f"(d[0][5]), "+f"(d[0][6]), "+f"(d[0][7]), + "+f"(d[1][0]), "+f"(d[1][1]), "+f"(d[1][2]), "+f"(d[1][3]), "+f"(d[1][4]), "+f"(d[1][5]), "+f"(d[1][6]), "+f"(d[1][7]), + 
"+f"(d[2][0]), "+f"(d[2][1]), "+f"(d[2][2]), "+f"(d[2][3]), "+f"(d[2][4]), "+f"(d[2][5]), "+f"(d[2][6]), "+f"(d[2][7]), + "+f"(d[3][0]), "+f"(d[3][1]), "+f"(d[3][2]), "+f"(d[3][3]), "+f"(d[3][4]), "+f"(d[3][5]), "+f"(d[3][6]), "+f"(d[3][7]) + : "r"(RA[0]), "r"(RA[1]), "r"(RA[2]), "r"(RA[3]), + "l"(desc_b), "n"(int32_t(ScaleD)), "n"(int32_t(ScaleA)), + "n"(int32_t(ScaleB)), "n"(int32_t(TransB))); +} + +template +__device__ __forceinline__ void wgmma_m64n64k32_f8f8f32(float d[][8], uint32_t RA[], T* sB) { + uint64_t desc_b = make_smem_desc(&sB[0]); + asm volatile( + "{\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}, " + "{%32, %33, %34, %35}, " + " %36," + " %37," + " %38, %39;\n" + "}\n" + : "+f"(d[0][0]), "+f"(d[0][1]), "+f"(d[0][2]), "+f"(d[0][3]), "+f"(d[0][4]), "+f"(d[0][5]), "+f"(d[0][6]), "+f"(d[0][7]), + "+f"(d[1][0]), "+f"(d[1][1]), "+f"(d[1][2]), "+f"(d[1][3]), "+f"(d[1][4]), "+f"(d[1][5]), "+f"(d[1][6]), "+f"(d[1][7]), + "+f"(d[2][0]), "+f"(d[2][1]), "+f"(d[2][2]), "+f"(d[2][3]), "+f"(d[2][4]), "+f"(d[2][5]), "+f"(d[2][6]), "+f"(d[2][7]), + "+f"(d[3][0]), "+f"(d[3][1]), "+f"(d[3][2]), "+f"(d[3][3]), "+f"(d[3][4]), "+f"(d[3][5]), "+f"(d[3][6]), "+f"(d[3][7]) + : "r"(RA[0]), "r"(RA[1]), "r"(RA[2]), "r"(RA[3]), + "l"(desc_b), "n"(int32_t(ScaleD)), + "n"(1), "n"(1)); +} + +template +__device__ __forceinline__ void wgmma_m64n128k32_f8f8f32(float d[][8], uint32_t RA[], T* sB) { + uint64_t desc_b = make_smem_desc(&sB[0]); + asm volatile( + "{\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}, " + "{%64, %65, %66, %67}, " + " %68," + " %69," + " %70, %71;\n" + "}\n" + : "+f"(d[0][0]), "+f"(d[0][1]), "+f"(d[0][2]), "+f"(d[0][3]), "+f"(d[0][4]), "+f"(d[0][5]), "+f"(d[0][6]), "+f"(d[0][7]), + "+f"(d[1][0]), "+f"(d[1][1]), "+f"(d[1][2]), "+f"(d[1][3]), "+f"(d[1][4]), "+f"(d[1][5]), "+f"(d[1][6]), "+f"(d[1][7]), + "+f"(d[2][0]), "+f"(d[2][1]), "+f"(d[2][2]), "+f"(d[2][3]), "+f"(d[2][4]), "+f"(d[2][5]), "+f"(d[2][6]), "+f"(d[2][7]), + "+f"(d[3][0]), "+f"(d[3][1]), "+f"(d[3][2]), "+f"(d[3][3]), "+f"(d[3][4]), "+f"(d[3][5]), "+f"(d[3][6]), "+f"(d[3][7]), + "+f"(d[4][0]), "+f"(d[4][1]), "+f"(d[4][2]), "+f"(d[4][3]), "+f"(d[4][4]), "+f"(d[4][5]), "+f"(d[4][6]), "+f"(d[4][7]), + "+f"(d[5][0]), "+f"(d[5][1]), "+f"(d[5][2]), "+f"(d[5][3]), "+f"(d[5][4]), "+f"(d[5][5]), "+f"(d[5][6]), "+f"(d[5][7]), + "+f"(d[6][0]), "+f"(d[6][1]), "+f"(d[6][2]), "+f"(d[6][3]), "+f"(d[6][4]), "+f"(d[6][5]), "+f"(d[6][6]), "+f"(d[6][7]), + "+f"(d[7][0]), "+f"(d[7][1]), "+f"(d[7][2]), "+f"(d[7][3]), "+f"(d[7][4]), "+f"(d[7][5]), "+f"(d[7][6]), "+f"(d[7][7]) + : "r"(RA[0]), "r"(RA[1]), "r"(RA[2]), "r"(RA[3]), + "l"(desc_b), "n"(int32_t(ScaleD)), + "n"(1), "n"(1)); +} + +template +__device__ void wgmma_m64n128k32_s8s8s32(int32_t d[][8], T* sA, T* sB) { + uint64_t desc_a = make_smem_desc(&sA[0]); + uint64_t desc_b = make_smem_desc(&sB[0]); + asm volatile( + "{\n" + "wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, 
%13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}, " + " %64," + " %65," + " %66;\n" + "}\n" + : "+r"(d[0][0]), "+r"(d[0][1]), "+r"(d[0][2]), "+r"(d[0][3]), "+r"(d[0][4]), "+r"(d[0][5]), "+r"(d[0][6]), "+r"(d[0][7]), + "+r"(d[1][0]), "+r"(d[1][1]), "+r"(d[1][2]), "+r"(d[1][3]), "+r"(d[1][4]), "+r"(d[1][5]), "+r"(d[1][6]), "+r"(d[1][7]), + "+r"(d[2][0]), "+r"(d[2][1]), "+r"(d[2][2]), "+r"(d[2][3]), "+r"(d[2][4]), "+r"(d[2][5]), "+r"(d[2][6]), "+r"(d[2][7]), + "+r"(d[3][0]), "+r"(d[3][1]), "+r"(d[3][2]), "+r"(d[3][3]), "+r"(d[3][4]), "+r"(d[3][5]), "+r"(d[3][6]), "+r"(d[3][7]), + "+r"(d[4][0]), "+r"(d[4][1]), "+r"(d[4][2]), "+r"(d[4][3]), "+r"(d[4][4]), "+r"(d[4][5]), "+r"(d[4][6]), "+r"(d[4][7]), + "+r"(d[5][0]), "+r"(d[5][1]), "+r"(d[5][2]), "+r"(d[5][3]), "+r"(d[5][4]), "+r"(d[5][5]), "+r"(d[5][6]), "+r"(d[5][7]), + "+r"(d[6][0]), "+r"(d[6][1]), "+r"(d[6][2]), "+r"(d[6][3]), "+r"(d[6][4]), "+r"(d[6][5]), "+r"(d[6][6]), "+r"(d[6][7]), + "+r"(d[7][0]), "+r"(d[7][1]), "+r"(d[7][2]), "+r"(d[7][3]), "+r"(d[7][4]), "+r"(d[7][5]), "+r"(d[7][6]), "+r"(d[7][7]) + : "l"(desc_a), "l"(desc_b), "n"(int32_t(ScaleD))); +} + +template +__device__ void wgmma_m64n64k32_s8s8s32(int32_t d[][8], T* sA, T* sB) { + uint64_t desc_a = make_smem_desc(&sA[0]); + uint64_t desc_b = make_smem_desc(&sB[0]); + asm volatile( + "{\n" + "wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34;\n" + "}\n" + : "+r"(d[0][0]), "+r"(d[0][1]), "+r"(d[0][2]), "+r"(d[0][3]), "+r"(d[0][4]), "+r"(d[0][5]), "+r"(d[0][6]), "+r"(d[0][7]), + "+r"(d[1][0]), "+r"(d[1][1]), "+r"(d[1][2]), "+r"(d[1][3]), "+r"(d[1][4]), "+r"(d[1][5]), "+r"(d[1][6]), "+r"(d[1][7]), + "+r"(d[2][0]), "+r"(d[2][1]), "+r"(d[2][2]), "+r"(d[2][3]), "+r"(d[2][4]), "+r"(d[2][5]), "+r"(d[2][6]), "+r"(d[2][7]), + "+r"(d[3][0]), "+r"(d[3][1]), "+r"(d[3][2]), "+r"(d[3][3]), "+r"(d[3][4]), "+r"(d[3][5]), "+r"(d[3][6]), "+r"(d[3][7]) + : "l"(desc_a), "l"(desc_b), "n"(int32_t(ScaleD))); +} + +template +__device__ __forceinline__ void wgmma_f16f16f32(float d[WGMMA_N/16][8], T* sA, T* sB) { + static_assert(std::is_same::value); + + static_assert(WGMMA_N == 128 || WGMMA_N == 64); + if constexpr (WGMMA_N == 128) { + wgmma_m64n128k16_f16f16f32(d, sA, sB); + } + else if constexpr (WGMMA_N == 64) { + wgmma_m64n64k16_f16f16f32(d, sA, sB); + } +} + +template +__device__ __forceinline__ void wgmma_s8s8s32(int32_t d[WGMMA_N/16][8], T* sA, T* sB) { + static_assert(WGMMA_N == 128 || WGMMA_N == 64); + if constexpr (WGMMA_N == 128) { + wgmma_m64n128k32_s8s8s32(d, sA, sB); + } + else if constexpr (WGMMA_N == 64) { + wgmma_m64n64k32_s8s8s32(d, sA, sB); + } +} + +template +__device__ __forceinline__ void wgmma_f8f8f32(float d[][8], uint32_t* RA, T* sB) { + static_assert(WGMMA_N == 128 || WGMMA_N == 64); + if constexpr (WGMMA_N == 128) { + wgmma_m64n128k32_f8f8f32(d, RA, sB); + } + else if constexpr (WGMMA_N == 64) { + wgmma_m64n64k32_f8f8f32(d, RA, sB); + } +} + +} // namespace wgmma \ No newline at end of file diff --git a/csrc/setup_cuda.py b/csrc/setup_cuda.py index 6e0ce8e20658..07d257c938a4 100644 --- a/csrc/setup_cuda.py +++ b/csrc/setup_cuda.py @@ -163,6 
+163,16 @@ def get_gencode_flags(): "gpu/fp8_gemm_with_cutlass/fp8_fp8_fp8_dual_gemm.cu", ] +if cc >= 89 and cuda_version >= 12.4: + nvcc_compile_args += [ + "-std=c++17", + "--use_fast_math", + "--threads=8", + "-D_GLIBCXX_USE_CXX11_ABI=1", + ] + sources += find_end_files("./gpu/sage_attn_kernels", ".cu") + sources += ["./gpu/sage_attn_kernels/sageattn.cc"] + if cc >= 90 and cuda_version >= 12.0: nvcc_compile_args += ["-DNDEBUG"] os.system("python utils/auto_gen_fp8_fp8_gemm_fused_kernels_sm90.py --cuda_arch 90") @@ -178,7 +188,7 @@ def get_gencode_flags(): name="paddlenlp_ops", ext_modules=CUDAExtension( sources=sources, - extra_compile_args={"cxx": ["-O3"], "nvcc": nvcc_compile_args}, + extra_compile_args={"cxx": ["-O3", "-fopenmp", "-lgomp", "-std=c++17", "-DENABLE_BF16"], "nvcc": nvcc_compile_args}, libraries=["cublasLt"], library_dirs=[library_path], ), From ca60579964c347fa41d752b999ea39cda53da9cd Mon Sep 17 00:00:00 2001 From: l1cacheDell Date: Wed, 12 Feb 2025 18:54:50 +0800 Subject: [PATCH 02/18] fix --- csrc/setup_cuda.py | 1 + 1 file changed, 1 insertion(+) diff --git a/csrc/setup_cuda.py b/csrc/setup_cuda.py index 07d257c938a4..12a36c1e306b 100644 --- a/csrc/setup_cuda.py +++ b/csrc/setup_cuda.py @@ -136,6 +136,7 @@ def get_gencode_flags(): cc = get_sm_version() cuda_version = float(paddle.version.cuda()) +cuda_version = 12.4 if cc >= 80: sources += ["gpu/int8_gemm_with_cutlass/gemm_dequant.cu"] From 7bb2c795d8d02a1382a39d2d617bde9f39837959 Mon Sep 17 00:00:00 2001 From: l1cacheDell Date: Wed, 19 Feb 2025 18:44:35 +0800 Subject: [PATCH 03/18] add ds sageattn kernel --- .../sageattn_qk_int_sv_f8_dsk_kernel_sm90.cu | 952 ++++++++++++++++++ 1 file changed, 952 insertions(+) create mode 100644 csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_dsk_kernel_sm90.cu diff --git a/csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_dsk_kernel_sm90.cu b/csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_dsk_kernel_sm90.cu new file mode 100644 index 000000000000..dd843da22d0b --- /dev/null +++ b/csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_dsk_kernel_sm90.cu @@ -0,0 +1,952 @@ +#include +#include +#include "paddle/extension.h" + +#include "sageattn_utils.cuh" + +template +CUtensorMap create_tensor_map_4D(T* gmem_ptr, int d1, int d2, int d3, int d4, int stride1, int stride2, int stride3) { + constexpr int smem_stride = BlockMinorSize * sizeof(T); + static_assert(sizeof(T) == 2 || sizeof(T) == 1); + static_assert(smem_stride == 32 || smem_stride == 64 || smem_stride == 128); + + CUtensorMap tma_map; + void* gmem_address = (void*)gmem_ptr; + uint64_t gmem_prob_shape[5] = {(uint64_t)d4, (uint64_t)d3, (uint64_t)d2, (uint64_t)d1, 1}; + uint64_t gmem_prob_stride[5] = {(uint64_t)stride3 * sizeof(T), (uint64_t)stride2 * sizeof(T), (uint64_t)stride1 * sizeof(T), 0, 0}; + uint32_t smem_box_shape[5] = {uint32_t(BlockMinorSize), uint32_t(BlockMajorSize), 1, 1, 1}; + uint32_t smem_box_stride[5] = {1, 1, 1, 1, 1}; + + CUresult result = cuTensorMapEncodeTiled( + &tma_map, (sizeof(T) == 2) ? CU_TENSOR_MAP_DATA_TYPE_BFLOAT16 : CU_TENSOR_MAP_DATA_TYPE_UINT8, 4, gmem_address, gmem_prob_shape, + gmem_prob_stride, smem_box_shape, smem_box_stride, CU_TENSOR_MAP_INTERLEAVE_NONE, + (swizzle == false) ? CU_TENSOR_MAP_SWIZZLE_NONE : (smem_stride == 128) ? CU_TENSOR_MAP_SWIZZLE_128B : (smem_stride == 64) ? 
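+ // The tensor map describes the global tensor with d4 as the contiguous innermost
+ // dimension and asks TMA to deliver BlockMajorSize x BlockMinorSize tiles into shared
+ // memory; the swizzle mode chosen here from the shared-memory row size (128/64/32 bytes)
+ // is expected to match the wgmma shared-memory descriptors in sageattn_utils.cuh.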
CU_TENSOR_MAP_SWIZZLE_64B : CU_TENSOR_MAP_SWIZZLE_32B, + promotion_mode, CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE); + + assert(result == CUDA_SUCCESS); + + return tma_map; +} + +__device__ __forceinline__ void init_barrier(uint64_t* bar, int thread_count) { + uint32_t bar_ptr = static_cast(__cvta_generic_to_shared(bar)); + asm volatile ( + "mbarrier.init.shared::cta.b64 [%0], %1;\n" + :: "r"(bar_ptr), "r"(thread_count) + ); +} + +template +__device__ __forceinline__ void expect_bytes(uint64_t* bar) { + uint32_t bar_ptr = static_cast(__cvta_generic_to_shared(bar)); + asm volatile ("mbarrier.arrive.expect_tx.shared::cta.b64 _, [%0], %1;\n" + :: "r"(bar_ptr), "n"(bytes)); +} + +template +__device__ __forceinline__ void load_async_4D(T *dst, void const* const src_tma_map, uint64_t* bar, int s0, int s1, int s2, int s3) { + uint64_t tma_ptr = reinterpret_cast(src_tma_map); + uint32_t mbar_ptr = static_cast(__cvta_generic_to_shared(bar)); + uint32_t dst_ptr = static_cast(__cvta_generic_to_shared(dst)); + + asm volatile ( + "cp.async.bulk.tensor.4d.shared::cluster.global.tile.mbarrier::complete_tx::bytes" + " [%0], [%1, {%3, %4, %5, %6}], [%2];" + : + : "r"(dst_ptr), "l"(tma_ptr), "r"(mbar_ptr), + "r"(s0), "r"(s1), "r"(s2), "r"(s3) + : "memory" + ); +} + +template +__device__ __forceinline__ void store_async_4D(void const* dst_tma_map, T *src, int global_token_idx, int global_head_idx, int global_batch_idx) { + uint64_t tma_ptr = reinterpret_cast(dst_tma_map); + uint32_t src_ptr = static_cast(__cvta_generic_to_shared(src)); + + asm volatile ( + "cp.async.bulk.tensor.4d.global.shared::cta.tile.bulk_group" + " [%0, {%2, %3, %4, %5}], [%1];" + : + : "l"(tma_ptr), "r"(src_ptr), + "n"(0), "r"(global_token_idx), "r"(global_head_idx), "r"(global_batch_idx) + : "memory" + ); +} + +__device__ __forceinline__ void wait(uint64_t* bar, int kPhaseBit) { + uint32_t mbar_ptr = static_cast(__cvta_generic_to_shared(bar)); + asm volatile ( + "{\n" + ".reg .pred P1;\n" + "LAB_WAIT:\n" + "mbarrier.try_wait.parity.shared::cta.b64 P1, [%0], %1;\n" + "@P1 bra.uni DONE;\n" + "bra.uni LAB_WAIT;\n" + "DONE:\n" + "}\n" + :: "r"(mbar_ptr), + "r"(kPhaseBit) + ); +} + +template +__device__ __forceinline__ void arrive(uint64_t* bar) { + uint32_t mbar_ptr = static_cast(__cvta_generic_to_shared(bar)); + asm volatile ( + "mbarrier.arrive.release.cta.shared::cta.b64 _, [%0], %1;\n" + : + : "r"(mbar_ptr), "n"(count) + : "memory" + ); +} + +// +// ======= kernel impl ======= +// + +template +__global__ void qk_int8_sv_f8_attn_dsk_kernel(const __grid_constant__ CUtensorMap tensorMapQ, + const __grid_constant__ CUtensorMap tensorMapK, + const __grid_constant__ CUtensorMap tensorMapQ_pe, + const __grid_constant__ CUtensorMap tensorMapK_pe, + const __grid_constant__ CUtensorMap tensorMapV, + float *__restrict__ Q_scale, float *__restrict__ K_scale, float *__restrict__ V_scale, + DTypeOut* O, uint32_t stride_bz_o, uint32_t stride_h_o, uint32_t stride_seq_o, + const uint32_t qo_len, const uint32_t kv_len, const uint32_t num_kv_groups, + float sm_scale) +{ + static_assert(NUM_THREADS == 128); + static_assert(CTA_Q <= CTA_K); + + const uint32_t warp_idx = (threadIdx.x % 128) / 32; + const uint32_t lane_id = threadIdx.x % 32; + + constexpr uint32_t num_tiles_q = CTA_Q / 64; + constexpr uint32_t num_tiles_k = CTA_K / 16; + constexpr uint32_t num_tiles_qk_inner = head_dim / 32; + constexpr uint32_t num_tiles_qk_pe_inner = head_dim_pe / 32; + constexpr uint32_t num_tiles_v = head_dim / 16; + constexpr uint32_t num_tiles_pv_inner = CTA_K / 32; + + const 
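+ // Tile bookkeeping for the constants above: the WGMMA shapes used below are M = 64 with
+ // K = 32 on the int8/fp8 paths, hence num_tiles_q = CTA_Q / 64 and the *_inner counts of
+ // head_dim / 32; head_dim_pe covers the extra RoPE ("pe") slice of the split head that
+ // this dsk kernel handles on top of the plain sm90 kernel.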
uint32_t batch_id = blockIdx.z; + const uint32_t bx = blockIdx.x; + const uint32_t head_id = blockIdx.y; + const uint32_t num_qo_heads = gridDim.y; + const uint32_t kv_head_id = head_id / num_kv_groups; + + sm_scale *= math::log2e; + + extern __shared__ __align__(128) int8_t smem_[]; + + /* // original: + * int8_t *sQ = (int8_t*)smem_; + * int8_t *sK = (int8_t*)(smem_ + CTA_Q * head_dim * sizeof(int8_t)); + * int8_t *sV = (int8_t*)(smem_ + CTA_Q * head_dim * sizeof(int8_t) + CTA_K * head_dim * sizeof(int8_t)); + * half *sO = (half*)smem_; + */ + + int8_t *sQ = (int8_t*)smem_; // 0 + int8_t *sQ_pe = (int8_t*)(smem_ + CTA_Q * (head_dim) * sizeof(int8_t)); // 0 + head_dim + + int8_t *sK = (int8_t*)(smem_ + CTA_Q * (head_dim + head_dim_pe) * sizeof(int8_t)); // 0 + head_dim + pe + int8_t *sK_pe = (int8_t*)(smem_ + CTA_Q * (head_dim + head_dim_pe) * sizeof(int8_t) + CTA_K * (head_dim) * sizeof(int8_t)); // 0 + head_dim + pe + head_dim + int8_t *sV = (int8_t*)(smem_ + CTA_Q * (head_dim + head_dim_pe) * sizeof(int8_t) + CTA_K * (head_dim + head_dim_pe) * sizeof(int8_t)); + half *sO = (half*)smem_; + + int32_t RS[num_tiles_q][num_tiles_k][8]; + int32_t RS_pe[num_tiles_q][num_tiles_k][8]; + float RO[num_tiles_q][num_tiles_v][8]; + float m[num_tiles_q][2]; + float d[num_tiles_q][2]; + + uint32_t q_scale_idx, k_scale_idx; + + // scale shape: (b, h_qo, (qo_len + BLKQ - 1) // BLKQ) + if constexpr (Q_GRAN == QuantGranularity::kPerBlock) + { + const uint32_t num_block_q = gridDim.x; + q_scale_idx = batch_id * num_qo_heads * num_block_q + head_id * num_block_q + bx; + } + else if constexpr (Q_GRAN == QuantGranularity::kPerWarp) + { + const uint32_t num_warp_block_q = gridDim.x * 4; + q_scale_idx = batch_id * num_qo_heads * num_warp_block_q + head_id * num_warp_block_q + bx * 4 + warp_idx; + } + else if constexpr (Q_GRAN == QuantGranularity::kPerThread) + { + const uint32_t num_warp_block_q = gridDim.x * 4; + q_scale_idx = batch_id * num_qo_heads * (num_warp_block_q * 8) + head_id * (num_warp_block_q * 8) + bx * (4 * 8) + warp_idx * 8 + lane_id / 4; + } + + if constexpr (K_GRAN == QuantGranularity::kPerBlock || K_GRAN == QuantGranularity::kPerWarp) + { + const uint32_t num_block_k = div_ceil(kv_len, CTA_K); + k_scale_idx = batch_id * (num_qo_heads / num_kv_groups) * num_block_k + (head_id / num_kv_groups) * num_block_k; + } + else if constexpr (K_GRAN == QuantGranularity::kPerThread) + { + const uint32_t num_block_k = div_ceil(kv_len, CTA_K); + k_scale_idx = batch_id * (num_qo_heads / num_kv_groups) * (num_block_k * 4) + (head_id / num_kv_groups) * (num_block_k * 4) + lane_id % 4; + } + + constexpr uint32_t k_scale_advance_offset = (K_GRAN == QuantGranularity::kPerBlock || K_GRAN == QuantGranularity::kPerWarp) ? 
1 : 4; + + uint32_t Q_idx_lane_base = bx * CTA_Q + warp_idx * 16 + lane_id / 4; + +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { + m[fq][0] = -5000000.0f; + m[fq][1] = -5000000.0f; + d[fq][0] = 1.0f; + d[fq][1] = 1.0f; + } + +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fv = 0; fv < num_tiles_v; fv++) + { +#pragma unroll + for (uint32_t k = 0; k < 8; k++) + { + RO[fq][fv][k] = 0.0f; + } + } + } + + __shared__ __align__(8) uint64_t barrier_Q; + __shared__ __align__(8) uint64_t barrier_K; + __shared__ __align__(8) uint64_t barrier_Q_pe; + __shared__ __align__(8) uint64_t barrier_K_pe; + __shared__ __align__(8) uint64_t barrier_V; + + if (threadIdx.x == 0) { + init_barrier(&barrier_Q, 1); + init_barrier(&barrier_K, 1); + init_barrier(&barrier_Q_pe, 1); + init_barrier(&barrier_K_pe, 1); + init_barrier(&barrier_V, 1); + } + + __syncthreads(); + + // load Q, K, V + if (threadIdx.x == 0) + { + expect_bytes<(CTA_Q * (head_dim)) * sizeof(int8_t)>(&barrier_Q); + expect_bytes<(CTA_K * (head_dim)) * sizeof(int8_t)>(&barrier_K); + expect_bytes<(CTA_Q * (head_dim_pe)) * sizeof(int8_t)>(&barrier_Q_pe); + expect_bytes<(CTA_K * (head_dim_pe)) * sizeof(int8_t)>(&barrier_K_pe); + expect_bytes<(CTA_K * (head_dim)) * sizeof(int8_t)>(&barrier_V); + + load_async_4D(sQ, &tensorMapQ, &barrier_Q, 0, bx * CTA_Q, head_id, batch_id); + load_async_4D(sQ_pe, &tensorMapQ_pe, &barrier_Q_pe, 0, bx * CTA_Q, head_id, batch_id); + load_async_4D(sK, &tensorMapK, &barrier_K, 0, 0, kv_head_id, batch_id); + load_async_4D(sK_pe, &tensorMapK_pe, &barrier_K_pe, 0, 0, kv_head_id, batch_id); + load_async_4D(sV, &tensorMapV, &barrier_V, 0, 0, kv_head_id, batch_id); + } + + float q_scale = Q_scale[q_scale_idx]; + float original_sm_scale = sm_scale; + + // wait for Q + wait(&barrier_Q, 0); + wait(&barrier_Q_pe, 0); + + const uint32_t num_iterations = div_ceil( + mask_mode == MaskMode::kCausal + ? 
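+ // Causal case: queries of this block attend only to keys with index < (bx + 1) * CTA_Q,
+ // so the KV loop needs ceil(min(kv_len, (bx + 1) * CTA_Q) / CTA_K) tiles rather than the
+ // full kv_len. The last tile is peeled out of the main loop below so the causal and
+ // out-of-bound masking cost is paid only once.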
min(kv_len, (bx + 1) * CTA_Q) + : kv_len, + CTA_K); + + int p = 1; + for (uint32_t iter = 1; iter < num_iterations; iter++) + { + p ^= 1; + + float dequant_scale = q_scale * K_scale[k_scale_idx + (iter - 1) * k_scale_advance_offset]; + sm_scale = original_sm_scale * dequant_scale; + + // wait for K + wait(&barrier_K, p); + wait(&barrier_K_pe, p); + + // compute QK^T + wgmma::warpgroup_arrive(); +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { + int8_t *sQ_local = sQ + fq * 64 * head_dim; + int8_t *sQ_local_pe = sQ_pe + fq * 64 * head_dim_pe; + wgmma::wgmma_s8s8s32(RS[fq], sQ_local, sK); + wgmma::wgmma_s8s8s32(RS_pe[fq], sQ_local_pe, sK_pe); // add one line +#pragma unroll + for (int k_it = 1; k_it < num_tiles_qk_inner; k_it++) + { + wgmma::wgmma_s8s8s32(RS[fq], &sQ_local[k_it*32], &sK[k_it*32]); + if (k_it < num_tiles_qk_pe_inner) { + wgmma::wgmma_s8s8s32(RS_pe[fq], &sQ_local_pe[k_it*32], &sK_pe[k_it*32]); // add one line + } + } + } + wgmma::warpgroup_commit_batch(); + wgmma::warpgroup_wait<0>(); + + // load K + if (threadIdx.x == 0) + { + expect_bytes<(CTA_K * head_dim) * sizeof(int8_t)>(&barrier_K); + expect_bytes<(CTA_K * head_dim_pe) * sizeof(int8_t)>(&barrier_K_pe); // add one line + load_async_4D(sK, &tensorMapK, &barrier_K, 0, iter * CTA_K, kv_head_id, batch_id); + load_async_4D(sK_pe, &tensorMapK_pe, &barrier_K_pe, 0, iter * CTA_K, kv_head_id, batch_id); // add one line + } + + // convert RS to float + float RS_f32[num_tiles_q][num_tiles_k][8]; +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fk = 0; fk < num_tiles_k; fk++) + { +#pragma unroll + for (uint32_t k = 0; k < 8; k++) + { + RS_f32[fq][fk][k] = __int2float_rz(RS[fq][fk][k] + RS_pe[fq][fk][k]); // add one line + } + } + } + + update_mdo(RS_f32, RO, m, d, sm_scale); + + // accumulate d on thread basis +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unrol + for (uint32_t fk = 0; fk < num_tiles_k; fk++) + { + d[fq][0] += (RS_f32[fq][fk][0] + RS_f32[fq][fk][1] + RS_f32[fq][fk][4] + RS_f32[fq][fk][5]); + d[fq][1] += (RS_f32[fq][fk][2] + RS_f32[fq][fk][3] + RS_f32[fq][fk][6] + RS_f32[fq][fk][7]); + } + } + + uint32_t RS_f8[num_tiles_q][num_tiles_pv_inner][4]; + RS_32_to_8(RS_f32, RS_f8); + + // wait for V + wait(&barrier_V, p); + + float RO_temp[num_tiles_q][num_tiles_v][8]; + wgmma::warpgroup_arrive(); +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { + wgmma::wgmma_f8f8f32(RO_temp[fq], RS_f8[fq][0], &sV[0]); +#pragma unroll + for (uint32_t v_it = 1; v_it < num_tiles_pv_inner; v_it++) + { + wgmma::wgmma_f8f8f32(RO_temp[fq], RS_f8[fq][v_it], &sV[v_it * 32]); + } + } + + wgmma::warpgroup_commit_batch(); + wgmma::warpgroup_wait<0>(); + +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fv = 0; fv < num_tiles_v; fv++) + { +#pragma unroll + for (uint32_t k = 0; k < 8; k++) + { + RO[fq][fv][k] += RO_temp[fq][fv][k]; + } + } + } + + // load V + if (threadIdx.x == 0) + { + expect_bytes<(CTA_K * head_dim) * sizeof(int8_t)>(&barrier_V); + load_async_4D(sV, &tensorMapV, &barrier_V, iter * CTA_K, 0, kv_head_id, batch_id); + } + } + + { + p ^= 1; + + float dequant_scale = q_scale * K_scale[k_scale_idx + (num_iterations - 1) * k_scale_advance_offset]; + sm_scale = original_sm_scale; + + // wait for K + wait(&barrier_K, p); + wait(&barrier_K_pe, p); + + // compute QK^T + wgmma::warpgroup_arrive(); +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { + int8_t 
*sQ_local = sQ + fq * 64 * head_dim; + int8_t *sQ_local_pe = sQ_pe + fq * 64 * head_dim_pe; + wgmma::wgmma_s8s8s32(RS[fq], sQ_local, sK); + wgmma::wgmma_s8s8s32(RS_pe[fq], sQ_local_pe, sK_pe); +#pragma unroll + for (int k_it = 1; k_it < num_tiles_qk_inner; k_it++) + { + wgmma::wgmma_s8s8s32(RS[fq], &sQ_local[k_it*32], &sK[k_it*32]); + if (k_it < num_tiles_qk_pe_inner) { + wgmma::wgmma_s8s8s32(RS_pe[fq], &sQ_local_pe[k_it*32], &sK_pe[k_it*32]); // add one line + } + } + } + wgmma::warpgroup_commit_batch(); + wgmma::warpgroup_wait<0>(); + + // convert RS to float + float RS_f32[num_tiles_q][num_tiles_k][8]; +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fk = 0; fk < num_tiles_k; fk++) + { +#pragma unroll + for (uint32_t k = 0; k < 8; k++) + { + RS_f32[fq][fk][k] = __int2float_rz(RS[fq][fk][k] + RS_pe[fq][fk][k]) * dequant_scale; + } + } + } + + // masking +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fk = 0; fk < num_tiles_k; fk++) + { +#pragma unroll + for (uint32_t k = 0; k < 8; k++) + { + const uint32_t q_idx = Q_idx_lane_base + fq * 64 + 8 * ((k % 4) / 2); + const uint32_t k_idx = (num_iterations - 1) * CTA_K + fk * 16 + 2 * (lane_id % 4) + 8 * (k / 4) + k % 2; + + bool is_out_of_bounds; + + if constexpr (mask_mode == MaskMode::kCausal) + { + is_out_of_bounds = (k_idx > q_idx) || (k_idx >= kv_len); + } + else + { + is_out_of_bounds = (k_idx >= kv_len); + } + + if (is_out_of_bounds) + { + RS_f32[fq][fk][k] = -5000000.0f; + } + } + } + } + + update_mdo(RS_f32, RO, m, d, sm_scale); + + // accumulate d on thread basis +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unrol + for (uint32_t fk = 0; fk < num_tiles_k; fk++) + { + d[fq][0] += (RS_f32[fq][fk][0] + RS_f32[fq][fk][1] + RS_f32[fq][fk][4] + RS_f32[fq][fk][5]); + d[fq][1] += (RS_f32[fq][fk][2] + RS_f32[fq][fk][3] + RS_f32[fq][fk][6] + RS_f32[fq][fk][7]); + } + } + + uint32_t RS_f8[num_tiles_q][num_tiles_pv_inner][4]; + RS_32_to_8(RS_f32, RS_f8); + + // wait for V + wait(&barrier_V, p); + + float RO_temp[num_tiles_q][num_tiles_v][8]; + wgmma::warpgroup_arrive(); +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { + wgmma::wgmma_f8f8f32(RO_temp[fq], RS_f8[fq][0], &sV[0]); +#pragma unroll + for (uint32_t v_it = 1; v_it < num_tiles_pv_inner; v_it++) + { + wgmma::wgmma_f8f8f32(RO_temp[fq], RS_f8[fq][v_it], &sV[v_it * 32]); + } + } + + wgmma::warpgroup_commit_batch(); + wgmma::warpgroup_wait<0>(); + +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fv = 0; fv < num_tiles_v; fv++) + { +#pragma unroll + for (uint32_t k = 0; k < 8; k++) + { + RO[fq][fv][k] += RO_temp[fq][fv][k]; + } + } + } + } + + normalize_d(RO, m, d); + + if constexpr (fuse_v_scale) + { + float v_scale[4]; + float *V_scale_base_ptr = V_scale + batch_id * (num_qo_heads / num_kv_groups) * head_dim + (head_id / num_kv_groups) * head_dim + (lane_id % 4 ) * 2; + #pragma unroll + for (uint32_t fv = 0; fv < num_tiles_v; fv++) + { + ((float2*)v_scale)[0] = *((float2*)(V_scale_base_ptr + fv * 16)); + ((float2*)v_scale)[1] = *((float2*)(V_scale_base_ptr + fv * 16 + 8)); + + #pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { + RO[fq][fv][0] *= v_scale[0]; + RO[fq][fv][1] *= v_scale[1]; + RO[fq][fv][2] *= v_scale[0]; + RO[fq][fv][3] *= v_scale[1]; + RO[fq][fv][4] *= v_scale[2]; + RO[fq][fv][5] *= v_scale[3]; + RO[fq][fv][6] *= v_scale[2]; + RO[fq][fv][7] *= v_scale[3]; + } + 
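+ // Because O = P * V is linear in the columns of V, scaling the output columns by the
+ // per-channel V scale here is equivalent to dequantizing the FP8 V before the PV matmul;
+ // this is what lets the fuse_v_scale variant skip materialising a dequantized V tile.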
} + } + + DTypeOut *O_lane_ptr = O + batch_id * stride_bz_o + head_id * stride_h_o + (bx * CTA_Q + warp_idx * 16 + (lane_id / 4)) * stride_seq_o + (lane_id % 4) * 2 ; +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fv = 0; fv < head_dim/16; fv++) + { + if (Q_idx_lane_base + fq * 64 < qo_len) + { + if constexpr (std::is_same::value) + { + ((half2*)(O_lane_ptr + fq * 64 * stride_seq_o + fv * 16))[0] = __float22half2_rn(((float2*)(RO[fq][fv]))[0]); + ((half2*)(O_lane_ptr + fq * 64 * stride_seq_o + fv * 16 + 8))[0] = __float22half2_rn(((float2*)(RO[fq][fv]))[2]); + } + else + { + ((nv_bfloat162*)(O_lane_ptr + fq * 64 * stride_seq_o + fv * 16))[0] = __float22bfloat162_rn(((float2*)(RO[fq][fv]))[0]); + ((nv_bfloat162*)(O_lane_ptr + fq * 64 * stride_seq_o + fv * 16 + 8))[0] = __float22bfloat162_rn(((float2*)(RO[fq][fv]))[2]); + } + } + + if (Q_idx_lane_base + fq * 64 + 8 < qo_len) + { + if constexpr (std::is_same::value) + { + ((half2*)(O_lane_ptr + fq * 64 * stride_seq_o + fv * 16 + 8 * stride_seq_o))[0] = __float22half2_rn(((float2*)(RO[fq][fv]))[1]); + ((half2*)(O_lane_ptr + fq * 64 * stride_seq_o + fv * 16 + 8 + 8 * stride_seq_o))[0] = __float22half2_rn(((float2*)(RO[fq][fv]))[3]); + } + else + { + ((nv_bfloat162*)(O_lane_ptr + fq * 64 * stride_seq_o + fv * 16 + 8 * stride_seq_o))[0] = __float22bfloat162_rn(((float2*)(RO[fq][fv]))[1]); + ((nv_bfloat162*)(O_lane_ptr + fq * 64 * stride_seq_o + fv * 16 + 8 + 8 * stride_seq_o))[0] = __float22bfloat162_rn(((float2*)(RO[fq][fv]))[3]); + } + } + } + } +} + +std::vector qk_int8_sv_f8_accum_f32_attn_inst_buf_dsk_sm90_fwd( + paddle::Tensor& query, + paddle::Tensor& key, + paddle::Tensor& query_pe, + paddle::Tensor& key_pe, + paddle::Tensor& value, + paddle::Tensor& output, + paddle::Tensor& query_scale, + paddle::Tensor& key_scale, + int tensor_layout, + int is_causal, + int qk_quant_gran, + float sm_scale, + int return_lse) +{ + CHECK_CUDA(query); + CHECK_CUDA(key); + CHECK_CUDA(query_pe); + CHECK_CUDA(key_pe); + CHECK_CUDA(value); + CHECK_CUDA(output); + CHECK_CUDA(query_scale); + CHECK_CUDA(key_scale); + + CHECK_LASTDIM_CONTIGUOUS(query); + CHECK_LASTDIM_CONTIGUOUS(key); + CHECK_LASTDIM_CONTIGUOUS(query_pe); + CHECK_LASTDIM_CONTIGUOUS(key_pe); + CHECK_LASTDIM_CONTIGUOUS(value); + CHECK_LASTDIM_CONTIGUOUS(output); + CHECK_CONTIGUOUS(query_scale); + CHECK_CONTIGUOUS(key_scale); + + CHECK_DTYPE(query, paddle::DataType::INT8); + CHECK_DTYPE(key, paddle::DataType::INT8); + CHECK_DTYPE(query_pe, paddle::DataType::INT8); + CHECK_DTYPE(key_pe, paddle::DataType::INT8); + CHECK_DTYPE(value, paddle::DataType::FLOAT8_E4M3FN); + CHECK_DTYPE(query_scale, paddle::DataType::FLOAT32); + CHECK_DTYPE(key_scale, paddle::DataType::FLOAT32); + + CHECK_DIMS(query, 4); + CHECK_DIMS(key, 4); + CHECK_DIMS(query_pe, 4); + CHECK_DIMS(key_pe, 4); + CHECK_DIMS(value, 4); + CHECK_DIMS(output, 4); + CHECK_DIMS(query_scale, 3); + CHECK_DIMS(key_scale, 3); + + const int batch_size = query.shape()[0]; + const int head_dim = query.shape()[3]; // 现在query是正常的128, 多出来的64在query_pe里面,所以这样做没什么问题 + + int stride_bz_q = query.strides()[0]; + int stride_bz_k = key.strides()[0]; + int stride_bz_v = value.strides()[0]; + int stride_bz_o = output.strides()[0]; + + int qo_len, kv_len, padded_kv_len, num_qo_heads, num_kv_heads; + int stride_seq_q, stride_h_q, stride_seq_k, stride_h_k, stride_h_v, stride_d_v, stride_seq_o, stride_h_o; + + assert(value.shape()[0] == batch_size); + + if (tensor_layout == 0) + { + qo_len = query.shape()[1]; + kv_len = 
key.shape()[1]; + num_qo_heads = query.shape()[2]; + num_kv_heads = key.shape()[2]; + + stride_seq_q = query.strides()[1]; + stride_h_q = query.strides()[2]; + stride_seq_k = key.strides()[1]; + stride_h_k = key.strides()[2]; + stride_h_v = value.strides()[2]; + stride_d_v = value.strides()[1]; + stride_seq_o = output.strides()[1]; + stride_h_o = output.strides()[2]; + + CHECK_SHAPE(key, batch_size, kv_len, num_kv_heads, head_dim); + CHECK_SHAPE(output, batch_size, qo_len, num_qo_heads, head_dim); + assert(value.shape()[1] == head_dim); + assert(value.shape()[2] == num_kv_heads); + } + else + { + qo_len = query.shape()[2]; + kv_len = key.shape()[2]; + num_qo_heads = query.shape()[1]; + num_kv_heads = key.shape()[1]; + + stride_seq_q = query.strides()[2]; + stride_h_q = query.strides()[1]; + stride_seq_k = key.strides()[2]; + stride_h_k = key.strides()[1]; + stride_h_v = value.strides()[1]; + stride_d_v = value.strides()[2]; + stride_seq_o = output.strides()[2]; + stride_h_o = output.strides()[1]; + + CHECK_SHAPE(key, batch_size, num_kv_heads, kv_len, head_dim); + CHECK_SHAPE(output, batch_size, num_qo_heads, qo_len, head_dim); + assert(value.shape()[2] == head_dim); + assert(value.shape()[1] == num_kv_heads); + } + + if (num_qo_heads % num_kv_heads != 0) { + std::ostringstream err_msg; + err_msg << "num_qo_heads (" << num_qo_heads << ") must be divisible by num_kv_heads (" << num_kv_heads << ")"; + throw std::invalid_argument(err_msg.str()); + } + + paddle::Tensor lse = paddle::empty({0}, paddle::DataType::FLOAT32); + if (return_lse) + { + lse = paddle::empty({batch_size, num_qo_heads, qo_len}, paddle::DataType::FLOAT32); + } + + const int num_kv_groups = num_qo_heads / num_kv_heads; + + auto output_type = output.dtype(); + + DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { + DISPATCH_CAUSAL(is_causal, IS_CAUSAL, { + DISPATCH_QK_QUANT_GRAN(qk_quant_gran, QK_QUANT_GRAN, { + DISPATCH_PADDLE_DTYPE_TO_CTYPE_FP16(output_type, DTypeOut, { + constexpr int CTA_Q = 64; + constexpr int CTA_K = 128; + constexpr int NUM_THREADS = 128; + constexpr int HEAD_DIM_PE = 64; + + constexpr MaskMode mask_mode = IS_CAUSAL ? 
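+ // Note on the V layout checked just below: the (padded) sequence length is expected as
+ // the innermost, contiguous dimension of value, padded to at least
+ // ceil(kv_len / CTA_K) * CTA_K, with head_dim as a middle dimension; the V tensor map is
+ // therefore built with HEAD_DIM as the row extent and value.shape()[3] as the inner extent.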
MaskMode::kCausal : MaskMode::kNone; + + assert(value.shape()[3] >= div_ceil(kv_len, CTA_K) * CTA_K); + + if constexpr (QK_QUANT_GRAN == static_cast(QuantGranularity::kPerWarp)) + { + CHECK_SHAPE(query_scale, batch_size, num_qo_heads, static_cast(div_ceil(qo_len, CTA_Q) * (NUM_THREADS / 32))); + CHECK_SHAPE(key_scale, batch_size, num_kv_heads, static_cast(div_ceil(kv_len, CTA_K))); + } + else if constexpr (QK_QUANT_GRAN == static_cast(QuantGranularity::kPerThread)) + { + CHECK_SHAPE(query_scale, batch_size, num_qo_heads, static_cast(div_ceil(qo_len, CTA_Q) * (NUM_THREADS / 32) * 8)); + CHECK_SHAPE(key_scale, batch_size, num_kv_heads, static_cast(div_ceil(kv_len, CTA_K) * 4)); + } + else + { + static_assert(QK_QUANT_GRAN == static_cast(QuantGranularity::kPerWarp) || QK_QUANT_GRAN == static_cast(QuantGranularity::kPerThread), "Unsupported quantization granularity"); + } + + CUtensorMap tma_map_Q = create_tensor_map_4D(reinterpret_cast(query.data()), batch_size, num_qo_heads, qo_len, HEAD_DIM, stride_bz_q, stride_h_q, stride_seq_q); + CUtensorMap tma_map_K = create_tensor_map_4D(reinterpret_cast(key.data()), batch_size, num_kv_heads, kv_len, HEAD_DIM, stride_bz_k, stride_h_k, stride_seq_k); + CUtensorMap tma_map_Q_pe = create_tensor_map_4D(reinterpret_cast(query_pe.data()), batch_size, num_qo_heads, qo_len, HEAD_DIM_PE, stride_bz_q, stride_h_q, stride_seq_q); + CUtensorMap tma_map_K_pe = create_tensor_map_4D(reinterpret_cast(key_pe.data()), batch_size, num_kv_heads, kv_len, HEAD_DIM_PE, stride_bz_k, stride_h_k, stride_seq_k); + + CUtensorMap tma_map_V = create_tensor_map_4D(reinterpret_cast(value.data()), batch_size, num_kv_heads, HEAD_DIM, value.shape()[3], stride_bz_v, stride_h_v, stride_d_v); + + auto* kernel = qk_int8_sv_f8_attn_dsk_kernel(QK_QUANT_GRAN), static_cast(QK_QUANT_GRAN), DTypeOut, mask_mode, false>; + size_t sMemSize = CTA_Q * HEAD_DIM * sizeof(int8_t) + CTA_K * HEAD_DIM * sizeof(int8_t) + CTA_K * HEAD_DIM * sizeof(int8_t); + sMemSize += CTA_Q * HEAD_DIM_PE * sizeof(int8_t) + CTA_K * HEAD_DIM_PE * sizeof(int8_t); // add extra space for qk pe + cudaFuncSetAttribute( + kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, sMemSize); + + dim3 grid(div_ceil(qo_len, CTA_Q), num_qo_heads, batch_size); + kernel<<>>( + tma_map_Q, + tma_map_K, + tma_map_Q_pe, + tma_map_K_pe, + tma_map_V, + reinterpret_cast(query_scale.data()), + reinterpret_cast(key_scale.data()), + nullptr, + reinterpret_cast(output.data()), + stride_bz_o, stride_h_o, stride_seq_o, + qo_len, kv_len, num_kv_groups, sm_scale); + }); + }); + }); + }); + + return {lse}; +} + +std::vector qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_dsk_sm90_fwd( + paddle::Tensor& query, + paddle::Tensor& key, + paddle::Tensor& query_pe, + paddle::Tensor& key_pe, + paddle::Tensor& value, + paddle::Tensor& output, + paddle::Tensor& query_scale, + paddle::Tensor& key_scale, + paddle::Tensor& value_scale, + int tensor_layout, + int is_causal, + int qk_quant_gran, + float sm_scale, + int return_lse) +{ + CHECK_CUDA(query); + CHECK_CUDA(key); + CHECK_CUDA(query_pe); + CHECK_CUDA(key_pe); + CHECK_CUDA(value); + CHECK_CUDA(output); + CHECK_CUDA(query_scale); + CHECK_CUDA(key_scale); + CHECK_CUDA(value_scale); + + CHECK_LASTDIM_CONTIGUOUS(query); + CHECK_LASTDIM_CONTIGUOUS(key); + CHECK_LASTDIM_CONTIGUOUS(query_pe); + CHECK_LASTDIM_CONTIGUOUS(key_pe); + CHECK_LASTDIM_CONTIGUOUS(value); + CHECK_LASTDIM_CONTIGUOUS(output); + CHECK_CONTIGUOUS(query_scale); + CHECK_CONTIGUOUS(key_scale); + CHECK_CONTIGUOUS(value_scale); + + CHECK_DTYPE(query, 
paddle::DataType::INT8); + CHECK_DTYPE(key, paddle::DataType::INT8); + CHECK_DTYPE(query_pe, paddle::DataType::INT8); + CHECK_DTYPE(key_pe, paddle::DataType::INT8); + CHECK_DTYPE(value, paddle::DataType::FLOAT8_E4M3FN); + CHECK_DTYPE(query_scale, paddle::DataType::FLOAT32); + CHECK_DTYPE(key_scale, paddle::DataType::FLOAT32); + CHECK_DTYPE(value_scale, paddle::DataType::FLOAT32); + + CHECK_DIMS(query, 4); + CHECK_DIMS(key, 4); + CHECK_DIMS(query_pe, 4); + CHECK_DIMS(key_pe, 4); + CHECK_DIMS(value, 4); + CHECK_DIMS(output, 4); + CHECK_DIMS(query_scale, 3); + CHECK_DIMS(key_scale, 3); + CHECK_DIMS(value_scale, 3); + + const int batch_size = query.shape()[0]; + const int head_dim = query.shape()[3]; + + int stride_bz_q = query.strides()[0]; + int stride_bz_k = key.strides()[0]; + int stride_bz_v = value.strides()[0]; + int stride_bz_o = output.strides()[0]; + + int qo_len, kv_len, padded_kv_len, num_qo_heads, num_kv_heads; + int stride_seq_q, stride_h_q, stride_seq_k, stride_h_k, stride_h_v, stride_d_v, stride_seq_o, stride_h_o; + + assert(value.shape()[0] == batch_size); + + if (tensor_layout == 0) + { + qo_len = query.shape()[1]; + kv_len = key.shape()[1]; + num_qo_heads = query.shape()[2]; + num_kv_heads = key.shape()[2]; + + stride_seq_q = query.strides()[1]; + stride_h_q = query.strides()[2]; + stride_seq_k = key.strides()[1]; + stride_h_k = key.strides()[2]; + stride_h_v = value.strides()[2]; + stride_d_v = value.strides()[1]; + stride_seq_o = output.strides()[1]; + stride_h_o = output.strides()[2]; + + CHECK_SHAPE(key, batch_size, kv_len, num_kv_heads, head_dim); + CHECK_SHAPE(output, batch_size, qo_len, num_qo_heads, head_dim); + + assert(value.shape()[1] == head_dim); + assert(value.shape()[2] == num_kv_heads); + } + else + { + qo_len = query.shape()[2]; + kv_len = key.shape()[2]; + num_qo_heads = query.shape()[1]; + num_kv_heads = key.shape()[1]; + + stride_seq_q = query.strides()[2]; + stride_h_q = query.strides()[1]; + stride_seq_k = key.strides()[2]; + stride_h_k = key.strides()[1]; + stride_h_v = value.strides()[1]; + stride_d_v = value.strides()[2]; + stride_seq_o = output.strides()[2]; + stride_h_o = output.strides()[1]; + + CHECK_SHAPE(key, batch_size, num_kv_heads, kv_len, head_dim); + CHECK_SHAPE(output, batch_size, num_qo_heads, qo_len, head_dim); + assert(value.shape()[2] == head_dim); + assert(value.shape()[1] == num_kv_heads); + } + + if (num_qo_heads % num_kv_heads != 0) { + std::ostringstream err_msg; + err_msg << "num_qo_heads (" << num_qo_heads << ") must be divisible by num_kv_heads (" << num_kv_heads << ")"; + throw std::invalid_argument(err_msg.str()); + } + + paddle::Tensor lse = paddle::empty({1}, paddle::DataType::FLOAT32); + if (return_lse) + { + lse = paddle::empty({batch_size, num_qo_heads, qo_len}, paddle::DataType::FLOAT32); + } + + const int num_kv_groups = num_qo_heads / num_kv_heads; + + auto output_dtype = output.dtype(); + + DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { + DISPATCH_CAUSAL(is_causal, IS_CAUSAL, { + DISPATCH_QK_QUANT_GRAN(qk_quant_gran, QK_QUANT_GRAN, { + DISPATCH_PADDLE_DTYPE_TO_CTYPE_FP16(output_dtype, DTypeOut, { + constexpr int CTA_Q = 64; + constexpr int CTA_K = 128; + constexpr int NUM_THREADS = 128; + constexpr int HEAD_DIM_PE = 64; + + constexpr MaskMode mask_mode = IS_CAUSAL ? 
MaskMode::kCausal : MaskMode::kNone; + + assert(value.shape()[3] >= div_ceil(kv_len, CTA_K) * CTA_K); + + if constexpr (QK_QUANT_GRAN == static_cast(QuantGranularity::kPerWarp)) + { + CHECK_SHAPE(query_scale, batch_size, num_qo_heads, static_cast(div_ceil(qo_len, CTA_Q) * (NUM_THREADS / 32))); + CHECK_SHAPE(key_scale, batch_size, num_kv_heads, static_cast(div_ceil(kv_len, CTA_K))); + } + else if constexpr (QK_QUANT_GRAN == static_cast(QuantGranularity::kPerThread)) + { + CHECK_SHAPE(query_scale, batch_size, num_qo_heads, static_cast(div_ceil(qo_len, CTA_Q) * (NUM_THREADS / 32) * 8)); + CHECK_SHAPE(key_scale, batch_size, num_kv_heads, static_cast(div_ceil(kv_len, CTA_K) * 4)); + } + else + { + static_assert(QK_QUANT_GRAN == static_cast(QuantGranularity::kPerWarp) || QK_QUANT_GRAN == static_cast(QuantGranularity::kPerThread), "Unsupported quantization granularity"); + } + + CHECK_SHAPE(value_scale, batch_size, num_kv_heads, HEAD_DIM); + CUtensorMap tma_map_Q = create_tensor_map_4D(reinterpret_cast(query.data()), batch_size, num_qo_heads, qo_len, HEAD_DIM, stride_bz_q, stride_h_q, stride_seq_q); + CUtensorMap tma_map_Q_pe = create_tensor_map_4D(reinterpret_cast(query_pe.data()), batch_size, num_qo_heads, qo_len, HEAD_DIM_PE, stride_bz_q, stride_h_q, stride_seq_q); + CUtensorMap tma_map_K = create_tensor_map_4D(reinterpret_cast(key.data()), batch_size, num_kv_heads, kv_len, HEAD_DIM, stride_bz_k, stride_h_k, stride_seq_k); + CUtensorMap tma_map_K_pe = create_tensor_map_4D(reinterpret_cast(key_pe.data()), batch_size, num_kv_heads, kv_len, HEAD_DIM_PE, stride_bz_k, stride_h_k, stride_seq_k); + + CUtensorMap tma_map_V = create_tensor_map_4D(reinterpret_cast(value.data()), batch_size, num_kv_heads, HEAD_DIM, value.shape()[3], stride_bz_v, stride_h_v, stride_d_v); + + auto* kernel = qk_int8_sv_f8_attn_dsk_kernel(QK_QUANT_GRAN), static_cast(QK_QUANT_GRAN), DTypeOut, mask_mode, true>; + size_t sMemSize = CTA_Q * HEAD_DIM * sizeof(int8_t) + CTA_K * HEAD_DIM * sizeof(int8_t) + CTA_K * HEAD_DIM * sizeof(int8_t); + sMemSize += CTA_Q * HEAD_DIM_PE * sizeof(int8_t) + CTA_K * HEAD_DIM_PE * sizeof(int8_t); + cudaFuncSetAttribute( + kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, sMemSize); + + dim3 grid(div_ceil(qo_len, CTA_Q), num_qo_heads, batch_size); + kernel<<>>( + tma_map_Q, + tma_map_K, + tma_map_Q_pe, + tma_map_K_pe, + tma_map_V, + reinterpret_cast(query_scale.data()), + reinterpret_cast(key_scale.data()), + reinterpret_cast(value_scale.data()), + reinterpret_cast(output.data()), + stride_bz_o, stride_h_o, stride_seq_o, + qo_len, kv_len, num_kv_groups, sm_scale); + }); + }); + }); + }); + + return {lse}; +} \ No newline at end of file From 0a52545357bd9774357118bba04ff537ef843c3f Mon Sep 17 00:00:00 2001 From: l1cacheDell Date: Tue, 25 Feb 2025 14:28:40 +0800 Subject: [PATCH 04/18] update kernels --- csrc/gpu/sage_attn_kernels/sageattn.cc | 540 ------ csrc/gpu/sage_attn_kernels/sageattn_fused.cu | 106 +- .../sageattn_qk_int_sv_f16_kernel.cu | 1690 ----------------- .../sageattn_qk_int_sv_f16_kernel_sm80.cu | 1537 +++++++++++++++ .../sageattn_qk_int_sv_f8_dsk_kernel_sm90.cu | 952 ---------- ...u => sageattn_qk_int_sv_f8_kernel_sm89.cu} | 621 +++++- .../sageattn_qk_int_sv_f8_kernel_sm90.cu | 84 +- csrc/gpu/sage_attn_kernels/sageattn_utils.cuh | 58 +- csrc/setup_cuda.py | 18 +- 9 files changed, 2342 insertions(+), 3264 deletions(-) delete mode 100644 csrc/gpu/sage_attn_kernels/sageattn.cc delete mode 100644 csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f16_kernel.cu create mode 100644 
csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f16_kernel_sm80.cu delete mode 100644 csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_dsk_kernel_sm90.cu rename csrc/gpu/sage_attn_kernels/{sageattn_qk_int_sv_f8_kernel.cu => sageattn_qk_int_sv_f8_kernel_sm89.cu} (67%) diff --git a/csrc/gpu/sage_attn_kernels/sageattn.cc b/csrc/gpu/sage_attn_kernels/sageattn.cc deleted file mode 100644 index 10af0fbf5dfa..000000000000 --- a/csrc/gpu/sage_attn_kernels/sageattn.cc +++ /dev/null @@ -1,540 +0,0 @@ -#include "paddle/extension.h" - - -// -// ============== fp16 kernels registry, for sm80 arch ============== -// -// impl: sageattn_qk_int_sv_f16_kernel.cu -// attn buffer kernel -// std::vector qk_int8_sv_f16_accum_f16_attn_buf_fwd( -// paddle::Tensor& query, -// paddle::Tensor& key, -// paddle::Tensor& value, -// paddle::Tensor& output, -// paddle::Tensor& query_scale, -// paddle::Tensor& key_scale, -// int tensor_layout, -// int is_causal, -// int qk_quant_gran, -// float sm_scale, -// int return_lse); - -// std::vector> qk_int8_sv_f16_accum_f16_attn_buf_InferShape( -// std::vector query_shape, -// std::vector key_shape, -// std::vector value_shape, -// std::vector output_shape, -// std::vector query_scale_shape, -// std::vector key_scale_shape) { - -// // force layout: NHD: [bsz, seq_len, num_heads, head_dim] -// int64_t bsz = query_shape[0]; -// int64_t seq_len = query_shape[1]; -// int64_t h_qo = query_shape[2]; - -// std::vector return_shape = {bsz, h_qo, seq_len}; -// return {return_shape}; -// } - -// std::vector qk_int8_sv_f16_accum_f16_attn_buf_InferDtype( -// paddle::DataType A_dtype, -// paddle::DataType B_dtype, -// paddle::DataType C_dtype, -// paddle::DataType D_dtype, -// paddle::DataType E_dtype, -// paddle::DataType F_dtype) { -// return {paddle::DataType::FLOAT32}; -// } - -// PD_BUILD_OP(qk_int8_sv_f16_accum_f16_attn_buf) -// .Inputs({"query", "key", "value", "output", "query_scale", "key_scale"}) -// .Outputs({"out1", "out2", "out3", "out4", "out5", "out6", "lse"}) -// .SetInplaceMap({{"query", "out1"}, {"key", "out2"}, {"value", "out3"}, {"output", "out4"}, {"query_scale", "out5"}, {"key_scale", "out6"}}) // Inplace -// .Attrs({"tensor_layout: int", -// "is_causal: int", -// "qk_quant_gran: int", -// "sm_scale: float", -// "return_lse: int"}) -// .SetKernelFn(PD_KERNEL(qk_int8_sv_f16_accum_f16_attn_buf_fwd)) -// .SetInferShapeFn(PD_INFER_SHAPE(qk_int8_sv_f16_accum_f16_attn_buf_InferShape)) -// .SetInferDtypeFn(PD_INFER_DTYPE(qk_int8_sv_f16_accum_f16_attn_buf_InferDtype)); - - -// attn forward kernel: sv f16 accumulator f32 -std::vector qk_int8_sv_f16_accum_f32_attn_fwd( - paddle::Tensor& query, - paddle::Tensor& key, - paddle::Tensor& value, - paddle::Tensor& output, - paddle::Tensor& query_scale, - paddle::Tensor& key_scale, - int tensor_layout, - int is_causal, - int qk_quant_gran, - float sm_scale, - int return_lse); - - -std::vector> qk_int8_sv_f16_accum_f32_attn_InferShape( - std::vector query_shape, - std::vector key_shape, - std::vector value_shape, - std::vector output_shape, - std::vector query_scale_shape, - std::vector key_scale_shape) { - - // force layout: NHD: [bsz, seq_len, num_heads, head_dim] - int64_t bsz = query_shape[0]; - int64_t seq_len = query_shape[1]; - int64_t h_qo = query_shape[2]; - - std::vector return_shape = {bsz, h_qo, seq_len}; - return {return_shape}; -} - -std::vector qk_int8_sv_f16_accum_f32_attn_InferDtype( - paddle::DataType A_dtype, - paddle::DataType B_dtype, - paddle::DataType C_dtype, - paddle::DataType D_dtype, - paddle::DataType E_dtype, 
- paddle::DataType F_dtype) { - return {paddle::DataType::FLOAT32}; -} - -PD_BUILD_OP(qk_int8_sv_f16_accum_f32_attn) - .Inputs({"query", "key", "value", "output", "query_scale", "key_scale"}) - .Outputs({"out1", "out2", "out3", "out4", "out5", "out6", "lse"}) - .SetInplaceMap({{"query", "out1"}, {"key", "out2"}, {"value", "out3"}, {"output", "out4"}, {"query_scale", "out5"}, {"key_scale", "out6"}}) // Inplace - .Attrs({"tensor_layout: int", - "is_causal: int", - "qk_quant_gran: int", - "sm_scale: float", - "return_lse: int"}) - .SetKernelFn(PD_KERNEL(qk_int8_sv_f16_accum_f32_attn_fwd)) - .SetInferShapeFn(PD_INFER_SHAPE(qk_int8_sv_f16_accum_f32_attn_InferShape)) - .SetInferDtypeFn(PD_INFER_DTYPE(qk_int8_sv_f16_accum_f32_attn_InferDtype)); - -// -// ============== fp8 kernels registry, for sm89 arch ============== -// - -std::vector qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn_fwd( - paddle::Tensor& query, - paddle::Tensor& key, - paddle::Tensor& value, - paddle::Tensor& output, - paddle::Tensor& query_scale, - paddle::Tensor& key_scale, - paddle::Tensor& value_scale, - paddle::Tensor& value_mean, - int tensor_layout, - int is_causal, - int qk_quant_gran, - float sm_scale, - int return_lse); - -std::vector> qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn_InferShape( - std::vector query_shape, - std::vector key_shape, - std::vector value_shape, - std::vector output_shape, - std::vector query_scale_shape, - std::vector key_scale_shape, - std::vector value_scale_shape, - std::vector value_mean_shape) { - - // force layout: NHD: [bsz, seq_len, num_heads, head_dim] - int64_t bsz = query_shape[0]; - int64_t seq_len = query_shape[1]; - int64_t h_qo = query_shape[2]; - - std::vector return_shape = {bsz, h_qo, seq_len}; - return {return_shape}; -} - -std::vector qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn_InferDtype( - paddle::DataType A_dtype, - paddle::DataType B_dtype, - paddle::DataType C_dtype, - paddle::DataType D_dtype, - paddle::DataType E_dtype, - paddle::DataType F_dtype, - paddle::DataType G_dtype, - paddle::DataType H_dtype) { - return {paddle::DataType::FLOAT32}; -} - -PD_BUILD_OP(qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn) - .Inputs({"query", "key", "value", "output", "query_scale", "key_scale", "value_scale", "value_mean"}) - .Outputs({"out1", "out2", "out3", "out4", "out5", "out6", "out7", "out8", "lse"}) - .SetInplaceMap({{"query", "out1"}, {"key", "out2"}, {"value", "out3"}, {"output", "out4"}, {"query_scale", "out5"}, {"key_scale", "out6"}, {"value_scale", "out7"}, {"value_mean", "out8"}}) // Inplace - .Attrs({"tensor_layout: int", - "is_causal: int", - "qk_quant_gran: int", - "sm_scale: float", - "return_lse: int"}) - .SetKernelFn(PD_KERNEL(qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn_fwd)) - .SetInferShapeFn(PD_INFER_SHAPE(qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn_InferShape)) - .SetInferDtypeFn(PD_INFER_DTYPE(qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn_InferDtype)); - -std::vector qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_fwd( - paddle::Tensor& query, - paddle::Tensor& key, - paddle::Tensor& value, - paddle::Tensor& output, - paddle::Tensor& query_scale, - paddle::Tensor& key_scale, - paddle::Tensor& value_scale, - int tensor_layout, - int is_causal, - int qk_quant_gran, - float sm_scale, - int return_lse); - -std::vector> qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_InferShape( - std::vector query_shape, - std::vector key_shape, - std::vector value_shape, - std::vector output_shape, - std::vector query_scale_shape, 
- std::vector key_scale_shape, - std::vector value_scale_shape) { - - // force layout: NHD: [bsz, seq_len, num_heads, head_dim] - int64_t bsz = query_shape[0]; - int64_t seq_len = query_shape[1]; - int64_t h_qo = query_shape[2]; - - std::vector return_shape = {bsz, h_qo, seq_len}; - return {return_shape}; -} - -std::vector qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_InferDtype( - paddle::DataType A_dtype, - paddle::DataType B_dtype, - paddle::DataType C_dtype, - paddle::DataType D_dtype, - paddle::DataType E_dtype, - paddle::DataType F_dtype, - paddle::DataType G_dtype) { - return {paddle::DataType::FLOAT32}; -} - -PD_BUILD_OP(qk_int8_sv_f8_accum_f32_fuse_v_scale_attn) - .Inputs({"query", "key", "value", "output", "query_scale", "key_scale", "value_scale"}) - .Outputs({"out1", "out2", "out3", "out4", "out5", "out6", "out7", "lse"}) - .SetInplaceMap({{"query", "out1"}, {"key", "out2"}, {"value", "out3"}, {"output", "out4"}, {"query_scale", "out5"}, {"key_scale", "out6"}, {"value_scale", "out7"}}) // Inplace - .Attrs({"tensor_layout: int", - "is_causal: int", - "qk_quant_gran: int", - "sm_scale: float", - "return_lse: int"}) - .SetKernelFn(PD_KERNEL(qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_fwd)) - .SetInferShapeFn(PD_INFER_SHAPE(qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_InferShape)) - .SetInferDtypeFn(PD_INFER_DTYPE(qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_InferDtype)); - - -std::vector qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm89_fwd( - paddle::Tensor& query, - paddle::Tensor& key, - paddle::Tensor& value, - paddle::Tensor& output, - paddle::Tensor& query_scale, - paddle::Tensor& key_scale, - paddle::Tensor& value_scale, - int tensor_layout, - int is_causal, - int qk_quant_gran, - float sm_scale, - int return_lse); - -std::vector> qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm89_InferShape( - std::vector query_shape, - std::vector key_shape, - std::vector value_shape, - std::vector output_shape, - std::vector query_scale_shape, - std::vector key_scale_shape, - std::vector value_scale_shape) { - - // force layout: NHD: [bsz, seq_len, num_heads, head_dim] - int64_t bsz = query_shape[0]; - int64_t seq_len = query_shape[1]; - int64_t h_qo = query_shape[2]; - - std::vector return_shape = {bsz, h_qo, seq_len}; - return {return_shape}; -} - -std::vector qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm89_InferDtype( - paddle::DataType A_dtype, - paddle::DataType B_dtype, - paddle::DataType C_dtype, - paddle::DataType D_dtype, - paddle::DataType E_dtype, - paddle::DataType F_dtype, - paddle::DataType G_dtype) { - return {paddle::DataType::FLOAT32}; -} - -PD_BUILD_OP(qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm89) - .Inputs({"query", "key", "value", "output", "query_scale", "key_scale", "value_scale"}) - .Outputs({"out1", "out2", "out3", "out4", "out5", "out6", "out7", "lse"}) - .SetInplaceMap({{"query", "out1"}, {"key", "out2"}, {"value", "out3"}, {"output", "out4"}, {"query_scale", "out5"}, {"key_scale", "out6"}, {"value_scale", "out7"}}) // Inplace - .Attrs({"tensor_layout: int", - "is_causal: int", - "qk_quant_gran: int", - "sm_scale: float", - "return_lse: int"}) - .SetKernelFn(PD_KERNEL(qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm89_fwd)) - .SetInferShapeFn(PD_INFER_SHAPE(qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm89_InferShape)) - .SetInferDtypeFn(PD_INFER_DTYPE(qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm89_InferDtype)); - -// -// ============== fp8 kernels registry, for sm90 arch ============== -// - -std::vector 
qk_int8_sv_f8_accum_f32_attn_inst_buf_sm90_fwd( - paddle::Tensor& query, - paddle::Tensor& key, - paddle::Tensor& value, - paddle::Tensor& output, - paddle::Tensor& query_scale, - paddle::Tensor& key_scale, - int tensor_layout, - int is_causal, - int qk_quant_gran, - float sm_scale, - int return_lse); - -std::vector> qk_int8_sv_f8_accum_f32_attn_inst_buf_sm90_InferShape( - std::vector query_shape, - std::vector key_shape, - std::vector value_shape, - std::vector output_shape, - std::vector query_scale_shape, - std::vector key_scale_shape) { - - // force layout: NHD: [bsz, seq_len, num_heads, head_dim] - int64_t bsz = query_shape[0]; - int64_t seq_len = query_shape[1]; - int64_t h_qo = query_shape[2]; - - std::vector return_shape = {bsz, h_qo, seq_len}; - return {return_shape}; -} - -std::vector qk_int8_sv_f8_accum_f32_attn_inst_buf_sm90_InferDtype( - paddle::DataType A_dtype, - paddle::DataType B_dtype, - paddle::DataType C_dtype, - paddle::DataType D_dtype, - paddle::DataType E_dtype, - paddle::DataType F_dtype) { - return {paddle::DataType::FLOAT32}; -} - -PD_BUILD_OP(qk_int8_sv_f8_accum_f32_attn_inst_buf_sm90) - .Inputs({"query", "key", "value", "output", "query_scale", "key_scale"}) - .Outputs({"out1", "out2", "out3", "out4", "out5", "out6", "lse"}) - .SetInplaceMap({{"query", "out1"}, {"key", "out2"}, {"value", "out3"}, {"output", "out4"}, {"query_scale", "out5"}, {"key_scale", "out6"}}) // Inplace - .Attrs({"tensor_layout: int", - "is_causal: int", - "qk_quant_gran: int", - "sm_scale: float", - "return_lse: int"}) - .SetKernelFn(PD_KERNEL(qk_int8_sv_f8_accum_f32_attn_inst_buf_sm90_fwd)) - .SetInferShapeFn(PD_INFER_SHAPE(qk_int8_sv_f8_accum_f32_attn_inst_buf_sm90_InferShape)) - .SetInferDtypeFn(PD_INFER_DTYPE(qk_int8_sv_f8_accum_f32_attn_inst_buf_sm90_InferDtype)); - - -std::vector qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90_fwd( - paddle::Tensor& query, - paddle::Tensor& key, - paddle::Tensor& value, - paddle::Tensor& output, - paddle::Tensor& query_scale, - paddle::Tensor& key_scale, - paddle::Tensor& value_scale, - int tensor_layout, - int is_causal, - int qk_quant_gran, - float sm_scale, - int return_lse); - -std::vector> qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90_InferShape( - std::vector query_shape, - std::vector key_shape, - std::vector value_shape, - std::vector output_shape, - std::vector query_scale_shape, - std::vector key_scale_shape, - std::vector value_scale_shape) { - - // force layout: NHD: [bsz, seq_len, num_heads, head_dim] - int64_t bsz = query_shape[0]; - int64_t seq_len = query_shape[1]; - int64_t h_qo = query_shape[2]; - - std::vector return_shape = {bsz, h_qo, seq_len}; - return {return_shape}; -} - -std::vector qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90_InferDtype( - paddle::DataType A_dtype, - paddle::DataType B_dtype, - paddle::DataType C_dtype, - paddle::DataType D_dtype, - paddle::DataType E_dtype, - paddle::DataType F_dtype, - paddle::DataType G_dtype) { - return {paddle::DataType::FLOAT32}; -} - -PD_BUILD_OP(qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90) - .Inputs({"query", "key", "value", "output", "query_scale", "key_scale", "value_scale"}) - .Outputs({"out1", "out2", "out3", "out4", "out5", "out6", "out7", "lse"}) - .SetInplaceMap({{"query", "out1"}, {"key", "out2"}, {"value", "out3"}, {"output", "out4"}, {"query_scale", "out5"}, {"key_scale", "out6"}, {"value_scale", "out7"}}) // Inplace - .Attrs({"tensor_layout: int", - "is_causal: int", - "qk_quant_gran: int", - "sm_scale: float", - "return_lse: 
int"}) - .SetKernelFn(PD_KERNEL(qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90_fwd)) - .SetInferShapeFn(PD_INFER_SHAPE(qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90_InferShape)) - .SetInferDtypeFn(PD_INFER_DTYPE(qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90_InferDtype)); - - -// -// ============== fused kernels registry ============== -// - -void quant_per_block_int8_fuse_sub_mean_cuda_fwd( - paddle::Tensor& input, - paddle::Tensor& mean, - paddle::Tensor& output, - paddle::Tensor& scale, - int block_size, - int tensor_layout); - -// quant_per_block_int8_fuse_sub_mean_cuda_fwd does not have any return -// so we don't implement infer type & shape function here. - -PD_BUILD_OP(quant_per_block_int8_fuse_sub_mean_cuda) - .Inputs({"input", "mean", "output", "scale"}) - .Outputs({"out1", "out2", "out3", "out4"}) - .SetInplaceMap({{"input", "out1"}, {"mean", "out2"}, {"output", "out3"}, {"scale", "out4"}}) // Inplace - .Attrs({"block_size: int", "tensor_layout: int"}) - .SetKernelFn(PD_KERNEL(quant_per_block_int8_fuse_sub_mean_cuda_fwd)); - - -void quant_per_warp_int8_cuda_fwd( - paddle::Tensor& input, - paddle::Tensor& output, - paddle::Tensor& scale, - int block_size, - int warp_block_size, - int tensor_layout); - -// quant_per_warp_int8_cuda_fwd does not have any return -// so we don't implement infer type & shape function here. - -PD_BUILD_OP(quant_per_warp_int8_cuda) - .Inputs({"input", "output", "scale"}) - .Outputs({"out1", "out2", "out3"}) - .SetInplaceMap({{"input", "out1"}, {"output", "out2"}, {"scale", "out3"}}) // Inplace - .Attrs({"block_size: int", "warp_block_size: int", "tensor_layout: int"}) - .SetKernelFn(PD_KERNEL(quant_per_warp_int8_cuda_fwd)); - - -void quant_per_block_int8_cuda_scale_fwd( - paddle::Tensor& input, - paddle::Tensor& output, - paddle::Tensor& scale, - float sm_scale, - int block_size, - int tensor_layout); - -// quant_per_block_int8_cuda_scale does not have any return -// so we don't implement infer type & shape function here. - -PD_BUILD_OP(quant_per_block_int8_cuda_scale) - .Inputs({"input", "output", "scale"}) - .Outputs({"out1", "out2", "out3"}) - .SetInplaceMap({{"input", "out1"}, {"output", "out2"}, {"scale", "out3"}}) // Inplace - .Attrs({"sm_scale: float", "block_size: int", "tensor_layout: int"}) - .SetKernelFn(PD_KERNEL(quant_per_block_int8_cuda_scale_fwd)); - - -void quant_per_block_int8_cuda_fwd( - paddle::Tensor& input, - paddle::Tensor& output, - paddle::Tensor& scale, - int block_size, - int tensor_layout); - -// quant_per_block_int8_cuda does not have any return -// so we don't implement infer type & shape function here. - -PD_BUILD_OP(quant_per_block_int8_cuda) - .Inputs({"input", "output", "scale"}) - .Outputs({"out1", "out2", "out3"}) - .SetInplaceMap({{"input", "out1"}, {"output", "out2"}, {"scale", "out3"}}) // Inplace - .Attrs({"sm_scale: float", "block_size: int", "tensor_layout: int"}) - .SetKernelFn(PD_KERNEL(quant_per_block_int8_cuda_fwd)); - - -void transpose_pad_permute_cuda_fwd( - paddle::Tensor& input, - paddle::Tensor& output, - int tensor_layout); - -// transpose_pad_permute_cuda_fwd does not have any return -// so we don't implement infer type & shape function here. 
- -PD_BUILD_OP(transpose_pad_permute_cuda) - .Inputs({"input", "output"}) - .Outputs({"out1", "out2"}) - .SetInplaceMap({{"input", "out1"}, {"output", "out2"}}) // Inplace - .Attrs({"tensor_layout: int"}) - .SetKernelFn(PD_KERNEL(transpose_pad_permute_cuda_fwd)); - - -void scale_fuse_quant_cuda_fwd( - paddle::Tensor& input, - paddle::Tensor& output, - paddle::Tensor& scale, - int num_tokens, - float scale_max, - int tensor_layout); - -// scale_fuse_quant_cuda_fwd does not have any return -// so we don't implement infer type & shape function here. - -PD_BUILD_OP(scale_fuse_quant_cuda) - .Inputs({"input", "output", "scale"}) - .Outputs({"out1", "out2", "out3"}) - .SetInplaceMap({{"input", "out1"}, {"output", "out2"}, {"scale", "out3"}}) // Inplace - .Attrs({"num_tokens: int", "scale_max: float", "tensor_layout: int"}) - .SetKernelFn(PD_KERNEL(scale_fuse_quant_cuda_fwd)); - - -void mean_scale_fuse_quant_cuda_fwd( - paddle::Tensor& input, - paddle::Tensor& output, - paddle::Tensor& mean, - paddle::Tensor& scale, - int num_tokens, - float scale_max, - int tensor_layout); - -// mean_scale_fuse_quant_cuda_fwd does not have any return -// so we don't implement infer type & shape function here. - -PD_BUILD_OP(mean_scale_fuse_quant_cuda) - .Inputs({"input", "output", "mean", "scale"}) - .Outputs({"out1", "out2", "out3", "out4"}) - .SetInplaceMap({{"input", "out1"}, {"output", "out2"}, {"mean", "out3"}, {"scale", "out4"}}) // Inplace - .Attrs({"num_tokens: int", "scale_max: float", "tensor_layout: int"}) - .SetKernelFn(PD_KERNEL(mean_scale_fuse_quant_cuda_fwd)); \ No newline at end of file diff --git a/csrc/gpu/sage_attn_kernels/sageattn_fused.cu b/csrc/gpu/sage_attn_kernels/sageattn_fused.cu index 501e0eb89641..84d25794bb36 100644 --- a/csrc/gpu/sage_attn_kernels/sageattn_fused.cu +++ b/csrc/gpu/sage_attn_kernels/sageattn_fused.cu @@ -83,13 +83,7 @@ __global__ void QuantInt8Kernel(T *__restrict__ input, T *__restrict__ mean, int if constexpr (sub_mean) { - // *(float4*)(&mean_val[0]) = *(float4*)(mean_ptr_base); - // for unable-align reasons, we unroll it manually. -#pragma unroll - for (int ii = 0; ii < 8; ii++) { - mean_val[ii] = mean_ptr_base[ii]; - } - + *(float4*)(&mean_val[0]) = *(float4*)(mean_ptr_base); // 8 elements #pragma unroll for (uint32_t j = 0; j < 8; j++) { @@ -104,12 +98,7 @@ __global__ void QuantInt8Kernel(T *__restrict__ input, T *__restrict__ mean, int { if (thread_base_token + i * iter_stride < num_tokens) { - // *(float4*)(&x_val[i][0]) = *(float4*)(input_ptr_base + i * iter_stride * stride_seq_input); - // for unable-align reasons, we unroll it manually. -#pragma unroll - for (int ii = 0; ii < 8; ii++) { - x_val[i][ii] = *(input_ptr_base + i * iter_stride * stride_seq_input + ii); - } + *(float4*)(&x_val[i][0]) = *(float4*)(input_ptr_base + i * iter_stride * stride_seq_input); #pragma unroll for (uint32_t j = 0; j < 8; j++) { @@ -247,12 +236,12 @@ __global__ void TransposePadPermuteKernel(T *__restrict__ input, T *__restrict__ __syncthreads(); - // *(float4*)(output_ptr_base) = *(float4*)(&shared_store[thread_id / num_threads_per_cta][thread_id % num_threads_per_cta * pack_size]); + *(float4*)(output_ptr_base) = *(float4*)(&shared_store[thread_id / num_threads_per_cta][thread_id % num_threads_per_cta * pack_size]); // for unable-align reasons, we unroll it manually. 
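The hunks in this file switch the manually unrolled element-wise copies back to single 128-bit (float4) accesses; such vectorized loads and stores are only valid when the address is 16-byte aligned, which is exactly the concern the now-stale "unable-align" comments refer to. A small device-side sketch of an alignment-guarded variant is shown below; copy_8xhalf is a hypothetical helper, not part of the patch.

#include <cuda_fp16.h>
#include <cstdint>

// Copy 8 half values (16 bytes): one vectorized 128-bit access when both
// pointers are 16-byte aligned, otherwise a scalar fallback loop.
__device__ __forceinline__ void copy_8xhalf(half* dst, const half* src) {
  const bool aligned =
      ((reinterpret_cast<uintptr_t>(dst) | reinterpret_cast<uintptr_t>(src)) & 0xF) == 0;
  if (aligned) {
    *reinterpret_cast<float4*>(dst) = *reinterpret_cast<const float4*>(src);
  } else {
#pragma unroll
    for (int i = 0; i < 8; ++i) dst[i] = src[i];
  }
}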
-#pragma unroll - for (int i = 0; i < 8; i++) { - *(output_ptr_base + i) = shared_store[thread_id / num_threads_per_cta][thread_id % num_threads_per_cta * pack_size + i]; // TODO: not debugged, maybe some problem - } +// #pragma unroll +// for (int i = 0; i < 8; i++) { +// *(output_ptr_base + i) = shared_store[thread_id / num_threads_per_cta][thread_id % num_threads_per_cta * pack_size + i]; // TODO: not debugged, maybe some problem +// } } template @@ -290,11 +279,11 @@ __global__ void MeanScaleKernel(T *__restrict__ input, int8_t *__restrict__ outp for (int i = 0; i < num_iters; i++) { - // *(float4*)(&x_val[0]) = *(float4*)(input_ptr_base + i * gmem_stride); -#pragma unroll - for (int ii = 0; ii < 8; ii++) { - x_val[ii] = *(input_ptr_base + i * gmem_stride + ii); // TODO: not debugged - } + *(float4*)(&x_val[0]) = *(float4*)(input_ptr_base + i * gmem_stride); +// #pragma unroll +// for (int ii = 0; ii < 8; ii++) { +// x_val[ii] = *(input_ptr_base + i * gmem_stride + ii); // TODO: not debugged +// } #pragma unroll for (uint32_t j = 0; j < 8; j++) { @@ -350,11 +339,11 @@ __global__ void MeanScaleKernel(T *__restrict__ input, int8_t *__restrict__ outp for (int i = 0; i < num_iters; i++) { - // *(float4*)(&x_val[0]) = *(float4*)(input_ptr_base + i * gmem_stride); -#pragma unroll - for (int ii = 0; ii < 8; ii++) { - x_val[ii] = *(input_ptr_base + i * gmem_stride + ii); // TODO: not debugged - } + *(float4*)(&x_val[0]) = *(float4*)(input_ptr_base + i * gmem_stride); +// #pragma unroll +// for (int ii = 0; ii < 8; ii++) { +// x_val[ii] = *(input_ptr_base + i * gmem_stride + ii); // TODO: not debugged +// } #pragma unroll for (uint32_t j = 0; j < 8; j++) { @@ -438,10 +427,9 @@ void quant_per_block_int8_fuse_sub_mean_cuda_fwd( auto mean_dtype = mean.dtype(); PD_CHECK(input_dtype == mean_dtype, "Input and mean must have the same data type"); - DISPATCH_PADDLE_DTYPE_TO_CTYPE_FP16(input_dtype, c_type, { DISPATCH_BLOCK_SIZE(block_size, BLOCK_SIZE, { - DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { + DISPATCH_HEAD_DIM_QK(head_dim, HEAD_DIM, { CHECK_SHAPE(mean, batch_size, num_heads, head_dim); CHECK_SHAPE(output, input.shape()[0], input.shape()[1], input.shape()[2], input.shape()[3]); CHECK_SHAPE(scale, batch_size, num_heads, (num_tokens + BLOCK_SIZE - 1) / BLOCK_SIZE); @@ -451,9 +439,6 @@ void quant_per_block_int8_fuse_sub_mean_cuda_fwd( constexpr int num_pack_per_thread = (BLOCK_SIZE * (HEAD_DIM / 8) + 1023) / 1024; dim3 block(BLOCK_SIZE * (HEAD_DIM / 8) / num_pack_per_thread); - std::cout << "resources: " << (num_tokens + BLOCK_SIZE - 1) / BLOCK_SIZE << " " << num_heads << " " <<<>>( reinterpret_cast(input.data()), reinterpret_cast(mean.data()), @@ -471,6 +456,13 @@ void quant_per_block_int8_fuse_sub_mean_cuda_fwd( }); } +PD_BUILD_OP(quant_per_block_int8_fuse_sub_mean_cuda) + .Inputs({"input", "mean", "output", "scale"}) + .Outputs({"out1", "out2", "out3", "out4"}) + .SetInplaceMap({{"input", "out1"}, {"mean", "out2"}, {"output", "out3"}, {"scale", "out4"}}) // Inplace + .Attrs({"block_size: int", "tensor_layout: int"}) + .SetKernelFn(PD_KERNEL(quant_per_block_int8_fuse_sub_mean_cuda_fwd)); + void quant_per_warp_int8_cuda_fwd( paddle::Tensor& input, paddle::Tensor& output, @@ -527,7 +519,7 @@ void quant_per_warp_int8_cuda_fwd( DISPATCH_PADDLE_DTYPE_TO_CTYPE_FP16(input_dtype, c_type, { DISPATCH_WARP_BLOCK_SIZE(warp_block_size, WARP_BLOCK_SIZE, { DISPATCH_BLOCK_SIZE(block_size, BLOCK_SIZE, { - DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { + DISPATCH_HEAD_DIM_QK(head_dim, HEAD_DIM, { CHECK_SHAPE(output, 
input.shape()[0], input.shape()[1], input.shape()[2], input.shape()[3]); CHECK_SHAPE(scale, batch_size, num_heads, (num_tokens + BLOCK_SIZE - 1) / BLOCK_SIZE * (BLOCK_SIZE / WARP_BLOCK_SIZE)); dim3 grid((num_tokens + BLOCK_SIZE - 1) / BLOCK_SIZE * (BLOCK_SIZE / WARP_BLOCK_SIZE), num_heads, batch_size); @@ -552,6 +544,13 @@ void quant_per_warp_int8_cuda_fwd( }); } +PD_BUILD_OP(quant_per_warp_int8_cuda) + .Inputs({"input", "output", "scale"}) + .Outputs({"out1", "out2", "out3"}) + .SetInplaceMap({{"input", "out1"}, {"output", "out2"}, {"scale", "out3"}}) // Inplace + .Attrs({"block_size: int", "warp_block_size: int", "tensor_layout: int"}) + .SetKernelFn(PD_KERNEL(quant_per_warp_int8_cuda_fwd)); + void quant_per_block_int8_cuda_scale_fwd( paddle::Tensor& input, paddle::Tensor& output, @@ -635,6 +634,13 @@ void quant_per_block_int8_cuda_scale_fwd( }); } +PD_BUILD_OP(quant_per_block_int8_cuda_scale) + .Inputs({"input", "output", "scale"}) + .Outputs({"out1", "out2", "out3"}) + .SetInplaceMap({{"input", "out1"}, {"output", "out2"}, {"scale", "out3"}}) // Inplace + .Attrs({"sm_scale: float", "block_size: int", "tensor_layout: int"}) + .SetKernelFn(PD_KERNEL(quant_per_block_int8_cuda_scale_fwd)); + void quant_per_block_int8_cuda_fwd( paddle::Tensor& input, paddle::Tensor& output, @@ -717,6 +723,14 @@ void quant_per_block_int8_cuda_fwd( }); } +PD_BUILD_OP(quant_per_block_int8_cuda) + .Inputs({"input", "output", "scale"}) + .Outputs({"out1", "out2", "out3"}) + .SetInplaceMap({{"input", "out1"}, {"output", "out2"}, {"scale", "out3"}}) // Inplace + .Attrs({"sm_scale: float", "block_size: int", "tensor_layout: int"}) + .SetKernelFn(PD_KERNEL(quant_per_block_int8_cuda_fwd)); + +// quant v用,但是v不是192,所以可以沿用原来的DISPATCH_HEAD_DIM void transpose_pad_permute_cuda_fwd( paddle::Tensor& input, paddle::Tensor& output, @@ -792,6 +806,13 @@ void transpose_pad_permute_cuda_fwd( }); } +PD_BUILD_OP(transpose_pad_permute_cuda) + .Inputs({"input", "output"}) + .Outputs({"out1", "out2"}) + .SetInplaceMap({{"input", "out1"}, {"output", "out2"}}) // Inplace + .Attrs({"tensor_layout: int"}) + .SetKernelFn(PD_KERNEL(transpose_pad_permute_cuda_fwd)); + void scale_fuse_quant_cuda_fwd( paddle::Tensor& input, paddle::Tensor& output, @@ -869,6 +890,14 @@ void scale_fuse_quant_cuda_fwd( }); } +PD_BUILD_OP(scale_fuse_quant_cuda) + .Inputs({"input", "output", "scale"}) + .Outputs({"out1", "out2", "out3"}) + .SetInplaceMap({{"input", "out1"}, {"output", "out2"}, {"scale", "out3"}}) // Inplace + .Attrs({"num_tokens: int", "scale_max: float", "tensor_layout: int"}) + .SetKernelFn(PD_KERNEL(scale_fuse_quant_cuda_fwd)); + +// smooth v void mean_scale_fuse_quant_cuda_fwd( paddle::Tensor& input, paddle::Tensor& output, @@ -950,4 +979,11 @@ void mean_scale_fuse_quant_cuda_fwd( scale.strides()[0], scale.strides()[1] ); }); -} \ No newline at end of file +} + +PD_BUILD_OP(mean_scale_fuse_quant_cuda) + .Inputs({"input", "output", "mean", "scale"}) + .Outputs({"out1", "out2", "out3", "out4"}) + .SetInplaceMap({{"input", "out1"}, {"output", "out2"}, {"mean", "out3"}, {"scale", "out4"}}) // Inplace + .Attrs({"num_tokens: int", "scale_max: float", "tensor_layout: int"}) + .SetKernelFn(PD_KERNEL(mean_scale_fuse_quant_cuda_fwd)); \ No newline at end of file diff --git a/csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f16_kernel.cu b/csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f16_kernel.cu deleted file mode 100644 index e6e5d0daa270..000000000000 --- a/csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f16_kernel.cu +++ /dev/null @@ -1,1690 +0,0 @@ 
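The deleted qk_int_sv_f16 kernel below folds the INT8 dequantization factor q_scale * k_scale into sm_scale and pre-multiplies by log2(e) ("transfer to base 2"), so the online softmax can use exp2 instead of exp. A tiny host-side reference of that identity, assuming a score dequantizes as int_score * q_scale * k_scale; attn_weight_ref is a hypothetical name, not part of the patch.

#include <cmath>
#include <cstdint>

// exp(s * sm_scale) == exp2(s * sm_scale * log2(e)); pre-multiplying sm_scale
// by log2(e) once lets the kernel use the cheaper exp2 path throughout.
static float attn_weight_ref(int32_t int_score, float q_scale, float k_scale,
                             float sm_scale) {
  const float log2e = 1.4426950408889634f;
  const float s = static_cast<float>(int_score) * q_scale * k_scale;  // dequantized QK^T entry
  const float via_exp  = std::exp(s * sm_scale);
  const float via_exp2 = std::exp2(s * sm_scale * log2e);
  (void)via_exp2;  // identical to via_exp up to floating-point rounding
  return via_exp;
}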
-#include - -#include "paddle/extension.h" - -// #include "sageattn.h" -#include "sageattn_utils.cuh" - -#define PACK_SIZE_QK 16 // as if it is int8 -#define PACK_SIZE_V 16 // fp8 -#define PACK_SIZE_O 8 // fp16 - -// treat as if int8 tensor core -#define MMA_QK_M 16 -#define MMA_QK_N 16 -#define MMA_QK_K 32 - -// fp8 tensor core -#define MMA_SV_M 16 -#define MMA_SV_N 16 -#define MMA_SV_K 32 - -// qk_int_sv_f16_buffer -// when instantiating, the head dim = 64, which makes the V_STRIDE = 64, then div 16 = 4, -// which triggered the compiling fault. -// it is the macro: PACK_SIZE_V and MMA_SV_K's problem, so we will redefine them here: -#ifdef PACK_SIZE_V -#define PACK_SIZE_V 8 -#endif - -#ifdef MMA_SV_K -#define MMA_SV_K 16 -#endif - -// inner impl -template -__global__ void qk_int_sv_f16_attn_buffer_kernel(int8_t *__restrict__ Q, int8_t *__restrict__ K, half *__restrict__ V, DTypeOut *__restrict__ O, float *__restrict__ Lse, - float *__restrict__ Q_scale, float *__restrict__ K_scale, DTypeOut *__restrict__ V_mean, - const uint32_t qo_len, const uint32_t kv_len, const uint32_t num_kv_groups, - const uint32_t stride_bz_q, const uint32_t stride_seq_q, const uint32_t stride_h_q, - const uint32_t stride_bz_k, const uint32_t stride_seq_k, const uint32_t stride_h_k, - const uint32_t stride_bz_v, const uint32_t stride_seq_v, const uint32_t stride_h_v, - const uint32_t stride_bz_o, const uint32_t stride_seq_o, const uint32_t stride_h_o, - float sm_scale) -{ - // compile time check - static_assert(DTypeQK == SADataType::kInt8 || DTypeQK == SADataType::kInt4, "DTypeQK must be int8 or int4"); - static_assert(Q_GRAN == QuantGranularity::kPerBlock || Q_GRAN == QuantGranularity::kPerWarp || Q_GRAN == QuantGranularity::kPerThread, "Q_GRAN must be kPerBlock, kPerWarp or kPerThread"); - static_assert(K_GRAN == QuantGranularity::kPerBlock || K_GRAN == QuantGranularity::kPerWarp || K_GRAN == QuantGranularity::kPerThread, "K_GRAN must be kPerBlock, kPerWarp or kPerThread"); - static_assert(std::is_same::value || std::is_same::value, "DTypeOut must be half or nv_bfloat16"); - static_assert(head_dim % 64 == 0, "head_dim must be a multiple of 64"); - static_assert(CTA_Q / CTA_K <= 2); // for efficient causal implementation - - using DTypeOut2 = typename std::conditional::value, half2, nv_bfloat162>::type; - - constexpr uint32_t num_warps_q = CTA_Q / WARP_Q; - constexpr uint32_t num_warps_k = CTA_K / WARP_K; - constexpr uint32_t num_warps = num_warps_q * num_warps_k; - constexpr uint32_t num_tiles_q = WARP_Q / MMA_QK_M; - constexpr uint32_t num_tiles_k = WARP_K / MMA_QK_N; - constexpr uint32_t num_tiles_qk_inner = (DTypeQK == SADataType::kInt8) ? (head_dim / MMA_QK_K) : (head_dim / 2 / MMA_QK_K); - constexpr uint32_t num_tiles_v = head_dim / MMA_SV_N; - - constexpr uint32_t QK_SMEM_STRIDE = (DTypeQK == SADataType::kInt8) ? 
(head_dim) : (head_dim / 2); - constexpr uint32_t O_SMEM_STRIDE = head_dim; - constexpr uint32_t V_SMEM_STRIDE = head_dim; - - extern __shared__ int8_t smem[]; - - const uint32_t lane_id = get_lane_id(); - const uint32_t warp_id = get_warp_id(); - - // maximize L2 hit rate - const uint32_t batch_id = blockIdx.z; - const uint32_t bx = blockIdx.x; - const uint32_t num_qo_heads = gridDim.y; - const uint32_t head_id = blockIdx.y; - - // transfer to base 2 instead of base e with better numerical efficiency - sm_scale *= math::log2e; - - // RS holds the fragment of S - int32_t RS[num_tiles_q][num_tiles_k][8]; - half RO[num_tiles_q][num_tiles_v][8]; - float m[num_tiles_q][2]; // max - float d[num_tiles_q][2]; // denominator - - float m_buf[num_tiles_q][2]; // buffer for m - float RO_buf[num_tiles_q][num_tiles_v][8]; // buffer for RO - - uint32_t q_scale_idx, k_scale_idx; - - if constexpr (Q_GRAN == QuantGranularity::kPerBlock) - { - const uint32_t num_block_q = gridDim.x; - q_scale_idx = batch_id * num_qo_heads * num_block_q + head_id * num_block_q + bx; - } - else if constexpr (Q_GRAN == QuantGranularity::kPerWarp) - { - const uint32_t num_warp_block_q = gridDim.x * num_warps_q; - q_scale_idx = batch_id * num_qo_heads * num_warp_block_q + head_id * num_warp_block_q + bx * num_warps_q + get_warp_idx_q(); - } - else if constexpr (Q_GRAN == QuantGranularity::kPerThread) - { - const uint32_t num_warp_block_q = gridDim.x * num_warps_q; - q_scale_idx = batch_id * num_qo_heads * (num_warp_block_q * 8) + head_id * (num_warp_block_q * 8) + bx * (num_warps_q * 8) + get_warp_idx_q() * 8 + lane_id / 4; - } - - if constexpr (K_GRAN == QuantGranularity::kPerBlock) - { - const uint32_t num_block_k = div_ceil(kv_len, CTA_K); - k_scale_idx = batch_id * (num_qo_heads / num_kv_groups) * num_block_k + (head_id / num_kv_groups) * num_block_k; - } - else if constexpr (K_GRAN == QuantGranularity::kPerWarp) - { - const uint32_t num_warp_block_k = div_ceil(kv_len, CTA_K) * (CTA_K / WARP_K); - k_scale_idx = batch_id * (num_qo_heads / num_kv_groups) * num_warp_block_k + (head_id / num_kv_groups) * num_warp_block_k + get_warp_idx_k(); - } - else if constexpr (K_GRAN == QuantGranularity::kPerThread) - { - const uint32_t num_warp_block_k = div_ceil(kv_len, CTA_K) * (CTA_K / WARP_K); - k_scale_idx = batch_id * (num_qo_heads / num_kv_groups) * (num_warp_block_k * 4) + (head_id / num_kv_groups) * (num_warp_block_k * 4) + get_warp_idx_k() * 4 + lane_id % 4; - } - - constexpr uint32_t k_scale_advance_offset = (K_GRAN == QuantGranularity::kPerBlock) ? 1 : (K_GRAN == QuantGranularity::kPerWarp) ? (CTA_K / WARP_K) : (CTA_K / WARP_K) * 4; - - // initialize o, m, d -#pragma unroll - for (uint32_t fq = 0; fq < num_tiles_q; fq++) - { -#pragma unroll - for (uint32_t fv = 0; fv < num_tiles_v; fv++) - { - -#pragma unroll - for (uint32_t k = 0; k < 4; k++) - { - ((int32_t*)RO[fq][fv])[k] = 0; - } - -#pragma unroll - for (uint32_t k = 0; k < 8; k++) - { - RO_buf[fq][fv][k] = 0.0f; - } - } - } -#pragma unroll - for (uint32_t fq = 0; fq < num_tiles_q; fq++) - { -#pragma unroll - for (uint32_t k = 0; k < 2; k++) - { - m[fq][k] = -5000000.0f; - m_buf[fq][k] = -5000000.0f; - d[fq][k] = 1.0f; - } - } - - constexpr uint32_t K_smem_idx_offset = CTA_Q; - constexpr uint32_t V_smem_idx_offset = CTA_Q + CTA_K; - - constexpr SwizzleMode swizzle_mode_QK = (QK_SMEM_STRIDE == 32) ? SwizzleMode::k32B : (QK_SMEM_STRIDE == 64) ? 
SwizzleMode::k64B : SwizzleMode::k128B; - smem_t smem_Q(smem); - smem_t smem_K(smem + K_smem_idx_offset * QK_SMEM_STRIDE); - constexpr SwizzleMode swizzle_mode_V = (V_SMEM_STRIDE == 32) ? SwizzleMode::k64B : SwizzleMode::k128B; - smem_t smem_V(smem + V_smem_idx_offset * QK_SMEM_STRIDE); - constexpr SwizzleMode swizzle_mode_O = (O_SMEM_STRIDE == 32) ? SwizzleMode::k64B : SwizzleMode::k128B; - smem_t smem_O(smem); - - constexpr uint32_t global_to_shared_line_lanes_QK = (QK_SMEM_STRIDE == 32) ? 2 : (QK_SMEM_STRIDE == 64) ? 4 : 8; - constexpr uint32_t global_to_shared_copy_lines_per_warp_QK = (QK_SMEM_STRIDE == 32) ? 16 : (QK_SMEM_STRIDE == 64) ? 8 : 4; - constexpr uint32_t global_to_shared_line_lanes_V = (V_SMEM_STRIDE == 32) ? 4 : 8; - constexpr uint32_t global_to_shared_copy_lines_per_warp_V = (V_SMEM_STRIDE == 32) ? 8 : 4; - constexpr uint32_t global_to_shared_line_lanes_O = (O_SMEM_STRIDE == 32) ? 4 : 8; - constexpr uint32_t global_to_shared_copy_lines_per_warp_O = (O_SMEM_STRIDE == 32) ? 8 : 4; - - constexpr uint32_t QK_smem_iters_row = QK_SMEM_STRIDE / (global_to_shared_line_lanes_QK * PACK_SIZE_QK); - constexpr uint32_t Q_smem_iters_col = CTA_Q / (num_warps * global_to_shared_copy_lines_per_warp_QK); - constexpr uint32_t K_smem_iters_col = CTA_K / (num_warps * global_to_shared_copy_lines_per_warp_QK); - constexpr uint32_t V_smem_iters_row = V_SMEM_STRIDE / (global_to_shared_line_lanes_V * PACK_SIZE_V); - constexpr uint32_t V_smem_iters_col = CTA_K / (num_warps * global_to_shared_copy_lines_per_warp_V); - constexpr uint32_t O_smem_iters_row = O_SMEM_STRIDE / (global_to_shared_line_lanes_O * PACK_SIZE_O); - constexpr uint32_t O_smem_iters_col = CTA_Q / (num_warps * global_to_shared_copy_lines_per_warp_O); - - int8_t *Q_lane_base_ptr = Q + batch_id * stride_bz_q + head_id * stride_h_q + (bx * CTA_Q + CTA_Q / num_warps * warp_id + lane_id / global_to_shared_line_lanes_QK) * stride_seq_q + (lane_id % global_to_shared_line_lanes_QK) * PACK_SIZE_QK; - int8_t *K_lane_base_ptr = K + batch_id * stride_bz_k + (head_id / num_kv_groups) * stride_h_k + (CTA_K / num_warps * warp_id + lane_id / global_to_shared_line_lanes_QK) * stride_seq_k + (lane_id % global_to_shared_line_lanes_QK) * PACK_SIZE_QK; - half *V_lane_base_ptr = V + batch_id * stride_bz_v + (head_id / num_kv_groups) * stride_h_v + (CTA_K / num_warps * warp_id + lane_id / global_to_shared_line_lanes_V) * stride_seq_v + (lane_id % global_to_shared_line_lanes_V) * PACK_SIZE_V; - uint32_t Q_smem_offset_load = smem_Q.get_permuted_offset(warp_id * global_to_shared_copy_lines_per_warp_QK * Q_smem_iters_col + lane_id / global_to_shared_line_lanes_QK, lane_id % global_to_shared_line_lanes_QK); - uint32_t K_smem_offset_load = smem_K.get_permuted_offset(warp_id * global_to_shared_copy_lines_per_warp_QK * K_smem_iters_col + lane_id / global_to_shared_line_lanes_QK, lane_id % global_to_shared_line_lanes_QK); - uint32_t V_smem_offset_load = smem_V.get_permuted_offset(warp_id * global_to_shared_copy_lines_per_warp_V * V_smem_iters_col + lane_id / global_to_shared_line_lanes_V, lane_id % global_to_shared_line_lanes_V); - - uint32_t Q_smem_offset_mma = smem_Q.get_permuted_offset(get_warp_idx_q() * WARP_Q + lane_id % 16, lane_id / 16); - uint32_t K_smem_offset_mma = smem_K.get_permuted_offset(get_warp_idx_k() * WARP_K + lane_id % 8 + (lane_id / 16) * 8, (lane_id / 8) % 2); - uint32_t V_smem_offset_mma = smem_V.get_permuted_offset(get_warp_idx_k() * WARP_K + lane_id % 16, lane_id / 16); - - // for causal masking - uint32_t Q_idx_lane_base = bx * CTA_Q + 
get_warp_idx_q() * WARP_Q + lane_id / 4; - uint32_t K_idx_lane_base = get_warp_idx_k() * WARP_K + 2 * (lane_id % 4); - - // for loading - uint32_t Q_load_idx_lane_base = bx * CTA_Q + CTA_Q / num_warps * warp_id + lane_id / global_to_shared_line_lanes_QK; - uint32_t K_load_idx_lane_base = CTA_K / num_warps * warp_id + lane_id / global_to_shared_line_lanes_QK; - uint32_t V_load_idx_lane_base = CTA_K / num_warps * warp_id + lane_id / global_to_shared_line_lanes_V; - - const uint32_t num_iterations = div_ceil( - mask_mode == MaskMode::kCausal - ? min(kv_len, (bx + 1) * CTA_Q) - : kv_len, - CTA_K); - - // load Q with predicate - load_global_to_share( - &Q_lane_base_ptr, Q_smem_offset_load, stride_seq_q, smem_Q, Q_load_idx_lane_base, qo_len); - cp_async::commit_group(); - cp_async::wait_group<0>(); - __syncthreads(); - - // for num_tiles_qk_inner = 1, we load all Qs in register - uint32_t RQ[num_tiles_q][4]; - if constexpr (num_tiles_qk_inner == 1) - { -#pragma unroll - for (uint32_t fq = 0; fq < num_tiles_q; fq++) - { - smem_Q.ldmatrix_m8n8x4(Q_smem_offset_mma, RQ[fq]); - Q_smem_offset_mma = smem_Q.advance_offset_by_row<16>(Q_smem_offset_mma); - } - } - - // load K with predicate - load_global_to_share( - &K_lane_base_ptr, K_smem_offset_load, stride_seq_k, smem_K, K_load_idx_lane_base, kv_len); - cp_async::commit_group(); - - float q_scale = Q_scale[q_scale_idx]; - - float original_sm_scale = sm_scale; - float dequant_scale = q_scale * K_scale[k_scale_idx + 0 * k_scale_advance_offset]; - - sm_scale = original_sm_scale * dequant_scale; - - // load V with predicate - load_global_to_share( - &V_lane_base_ptr, V_smem_offset_load, stride_seq_v, smem_V, V_load_idx_lane_base, kv_len); - cp_async::commit_group(); - - K_load_idx_lane_base += CTA_K; - V_load_idx_lane_base += CTA_K; - - uint32_t num_flush_times = div_ceil(num_iterations, Buffer_Iter) - (num_iterations % Buffer_Iter == 1); // leave at least two iterations for the last flush - uint32_t iter = 1; - -#pragma unroll - for (uint32_t flush_time = 0; flush_time < num_flush_times - 1; flush_time++) - { -#pragma unroll - for (; iter <= (flush_time + 1) * Buffer_Iter; iter++) - { - // ensure K is ready - cp_async::wait_group<1>(); - __syncthreads(); - - // compute QK^T - if constexpr (num_tiles_qk_inner == 1) - { - compute_int_qk( - smem_K, RS, RQ, K_smem_offset_mma); - } - else - { - compute_int_qk( - smem_Q, smem_K, RS, Q_smem_offset_mma, K_smem_offset_mma); - } - - float RS_f32[num_tiles_q][num_tiles_k][8]; - - #pragma unroll - for (uint32_t fq = 0; fq < num_tiles_q; fq++) - { - #pragma unroll - for (uint32_t fk = 0; fk < num_tiles_k; fk++) - { - #pragma unroll - for (uint32_t k = 0; k < 8; k++) - { - RS_f32[fq][fk][k] = __int2float_rz(RS[fq][fk][k]); - } - } - } - - // do not apply causal mask and out of bound mask for these iterations - K_idx_lane_base += CTA_K; - - update_mdo(RS_f32, RO, m, d, sm_scale); - - if constexpr (DenominatorAccumUnit == ComputeUnit::kCudaCore) - { - accumulate_d(RS_f32, d); - } - - uint32_t RS_f16[num_tiles_q][num_tiles_k][4]; - RS_32_to_16(RS_f32, RS_f16); - - if constexpr (DenominatorAccumUnit == ComputeUnit::kTensorCore) - { - accumulate_d(RS_f16, d); - } - - __syncthreads(); - - // load K - load_global_to_share( - &K_lane_base_ptr, K_smem_offset_load, stride_seq_k, smem_K); - cp_async::commit_group(); - - dequant_scale = q_scale * K_scale[k_scale_idx + iter * k_scale_advance_offset]; - sm_scale = original_sm_scale * dequant_scale; - - // ensure V is ready - cp_async::wait_group<1>(); - __syncthreads(); - - 
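Throughout this loop the kernel overlaps data movement with math: each K/V tile load is issued with cp_async::commit_group() and the code waits with cp_async::wait_group<1>() so that one copy group stays in flight while the previous tile is consumed. A generic CUDA sketch of that double-buffered staging, using the standard __pipeline_* primitives rather than the library's cp_async wrappers, is given below; buffer size, thread count and the name staged_copy_sketch are illustrative assumptions only.

#include <cuda_pipeline.h>
#include <cstdint>

// Double-buffered global->shared staging. Launch with 128 threads per block;
// each thread copies one 16-byte chunk, so a "tile" here is 128 * 16 = 2048 bytes,
// and gmem is assumed to be 16-byte aligned with num_tiles * 2048 bytes available.
__global__ void staged_copy_sketch(const int8_t* __restrict__ gmem, int num_tiles) {
  constexpr int kTileBytes = 128 * 16;
  __shared__ alignas(16) int8_t tile[2][kTileBytes];
  if (num_tiles <= 0 || threadIdx.x >= 128) return;
  const int off = threadIdx.x * 16;

  // prologue: start loading tile 0
  __pipeline_memcpy_async(&tile[0][off], gmem + off, 16);
  __pipeline_commit();

  for (int i = 0; i < num_tiles; ++i) {
    const int cur = i & 1, nxt = (i + 1) & 1;
    if (i + 1 < num_tiles) {   // prefetch the next tile while tile `cur` is consumed
      __pipeline_memcpy_async(&tile[nxt][off],
                              gmem + (size_t)(i + 1) * kTileBytes + off, 16);
      __pipeline_commit();
    }
    // Same discipline as cp_async::wait_group: wait until at most the most
    // recent copy group is still outstanding, so tile `cur` is ready.
    __pipeline_wait_prior(i + 1 < num_tiles ? 1 : 0);
    __syncthreads();
    // ... compute on tile[cur] here (QK^T / PV in the real kernel) ...
    __syncthreads();
  }
  __pipeline_wait_prior(0);    // drain outstanding copies before exit
}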
compute_fp16_sv_permuted( - smem_V, RS_f16, RO, d, V_smem_offset_mma); - - __syncthreads(); - // load V - load_global_to_share( - &V_lane_base_ptr, V_smem_offset_load, stride_seq_v, smem_V); - cp_async::commit_group(); - K_load_idx_lane_base += CTA_K; - V_load_idx_lane_base += CTA_K; - } - - // update buffer -#pragma unroll - for (uint32_t fq = 0; fq < num_tiles_q; fq++) - { -#pragma unroll - for (uint32_t k = 0; k < 2; k++) - { - float o_scale = math::ptx_exp2(m_buf[fq][k] - m[fq][k]); -#pragma unroll - for (uint32_t fv = 0; fv < num_tiles_v; fv++) - { - // update buffer - RO_buf[fq][fv][k * 2 + 0] = RO_buf[fq][fv][k * 2 + 0] * o_scale + __half2float(RO[fq][fv][k * 2 + 0]); - RO_buf[fq][fv][k * 2 + 1] = RO_buf[fq][fv][k * 2 + 1] * o_scale + __half2float(RO[fq][fv][k * 2 + 1]); - RO_buf[fq][fv][k * 2 + 4] = RO_buf[fq][fv][k * 2 + 4] * o_scale + __half2float(RO[fq][fv][k * 2 + 4]); - RO_buf[fq][fv][k * 2 + 5] = RO_buf[fq][fv][k * 2 + 5] * o_scale + __half2float(RO[fq][fv][k * 2 + 5]); - - // update m_buf - m_buf[fq][k] = m[fq][k]; - - // clear RO - *((int32_t*)&RO[fq][fv][k * 2 + 0]) = 0; - *((int32_t*)&RO[fq][fv][k * 2 + 4]) = 0; - } - } - } - } - -#pragma unroll - for (; iter < num_iterations - 1; iter++) - { - // ensure K is ready - cp_async::wait_group<1>(); - __syncthreads(); - - // compute QK^T - if constexpr (num_tiles_qk_inner == 1) - { - compute_int_qk( - smem_K, RS, RQ, K_smem_offset_mma); - } - else - { - compute_int_qk( - smem_Q, smem_K, RS, Q_smem_offset_mma, K_smem_offset_mma); - } - - float RS_f32[num_tiles_q][num_tiles_k][8]; - -#pragma unroll - for (uint32_t fq = 0; fq < num_tiles_q; fq++) - { -#pragma unroll - for (uint32_t fk = 0; fk < num_tiles_k; fk++) - { -#pragma unroll - for (uint32_t k = 0; k < 8; k++) - { - RS_f32[fq][fk][k] = __int2float_rz(RS[fq][fk][k]); - } - } - } - - // do not apply causal mask and out of bound mask for these iterations - K_idx_lane_base += CTA_K; - - update_mdo(RS_f32, RO, m, d, sm_scale); - - if constexpr (DenominatorAccumUnit == ComputeUnit::kCudaCore) - { - accumulate_d(RS_f32, d); - } - - uint32_t RS_f16[num_tiles_q][num_tiles_k][4]; - RS_32_to_16(RS_f32, RS_f16); - - if constexpr (DenominatorAccumUnit == ComputeUnit::kTensorCore) - { - accumulate_d(RS_f16, d); - } - - __syncthreads(); - - // load K - load_global_to_share( - &K_lane_base_ptr, K_smem_offset_load, stride_seq_k, smem_K); - cp_async::commit_group(); - - dequant_scale = q_scale * K_scale[k_scale_idx + iter * k_scale_advance_offset]; - sm_scale = original_sm_scale * dequant_scale; - - // ensure V is ready - cp_async::wait_group<1>(); - __syncthreads(); - - compute_fp16_sv_permuted( - smem_V, RS_f16, RO, d, V_smem_offset_mma); - - __syncthreads(); - // load V - load_global_to_share( - &V_lane_base_ptr, V_smem_offset_load, stride_seq_v, smem_V); - cp_async::commit_group(); - K_load_idx_lane_base += CTA_K; - V_load_idx_lane_base += CTA_K; - } - - // second last iter, apply causal mask - if (num_iterations > 1) - { - // ensure K is ready - cp_async::wait_group<1>(); - __syncthreads(); - - // compute QK^T - if constexpr (num_tiles_qk_inner == 1) - { - compute_int_qk( - smem_K, RS, RQ, K_smem_offset_mma); - } - else - { - compute_int_qk( - smem_Q, smem_K, RS, Q_smem_offset_mma, K_smem_offset_mma); - } - - float RS_f32[num_tiles_q][num_tiles_k][8]; - -#pragma unroll - for (uint32_t fq = 0; fq < num_tiles_q; fq++) - { -#pragma unroll - for (uint32_t fk = 0; fk < num_tiles_k; fk++) - { -#pragma unroll - for (uint32_t k = 0; k < 8; k++) - { - RS_f32[fq][fk][k] = 
__int2float_rz(RS[fq][fk][k]) * dequant_scale; - } - } - } - - if constexpr (mask_mode == MaskMode::kCausal) - { - apply_causal_mask(Q_idx_lane_base, K_idx_lane_base, RS_f32); - } - // apply_out_of_bound_mask(K_idx_lane_base, RS_f32, kv_len); - K_idx_lane_base += CTA_K; - - update_mdo(RS_f32, RO, m, d, original_sm_scale); - - if constexpr (DenominatorAccumUnit == ComputeUnit::kCudaCore) - { - accumulate_d(RS_f32, d); - } - - uint32_t RS_f16[num_tiles_q][num_tiles_k][4]; - RS_32_to_16(RS_f32, RS_f16); - - if constexpr (DenominatorAccumUnit == ComputeUnit::kTensorCore) - { - accumulate_d(RS_f16, d); - } - - __syncthreads(); - - // load K with predicate - load_global_to_share( - &K_lane_base_ptr, K_smem_offset_load, stride_seq_k, smem_K, K_load_idx_lane_base, kv_len); - cp_async::commit_group(); - - dequant_scale = q_scale * K_scale[k_scale_idx + (num_iterations - 1) * k_scale_advance_offset]; - sm_scale = original_sm_scale * dequant_scale; - - // ensure V is ready - cp_async::wait_group<1>(); - __syncthreads(); - - compute_fp16_sv_permuted( - smem_V, RS_f16, RO, d, V_smem_offset_mma); - - __syncthreads(); - // load V with predicate - load_global_to_share( - &V_lane_base_ptr, V_smem_offset_load, stride_seq_v, smem_V, V_load_idx_lane_base, kv_len); - cp_async::commit_group(); - K_load_idx_lane_base += CTA_K; - V_load_idx_lane_base += CTA_K; - } - - // last iter, apply causal mask and out of bound mask - { - // ensure K is ready - cp_async::wait_group<1>(); - __syncthreads(); - - // compute QK^T - if constexpr (num_tiles_qk_inner == 1) - { - compute_int_qk( - smem_K, RS, RQ, K_smem_offset_mma); - } - else - { - compute_int_qk( - smem_Q, smem_K, RS, Q_smem_offset_mma, K_smem_offset_mma); - } - - float RS_f32[num_tiles_q][num_tiles_k][8]; - -#pragma unroll - for (uint32_t fq = 0; fq < num_tiles_q; fq++) - { -#pragma unroll - for (uint32_t fk = 0; fk < num_tiles_k; fk++) - { -#pragma unroll - for (uint32_t k = 0; k < 8; k++) - { - RS_f32[fq][fk][k] = __int2float_rz(RS[fq][fk][k]) * dequant_scale; - } - } - } - - if constexpr (mask_mode == MaskMode::kCausal) - { - apply_causal_mask(Q_idx_lane_base, K_idx_lane_base, RS_f32); - } - // check out of bound in the last iter - apply_out_of_bound_mask(K_idx_lane_base, RS_f32, kv_len); - K_idx_lane_base += CTA_K; - - update_mdo(RS_f32, RO, m, d, original_sm_scale); - - if constexpr (DenominatorAccumUnit == ComputeUnit::kCudaCore) - { - accumulate_d(RS_f32, d); - } - - uint32_t RS_f16[num_tiles_q][num_tiles_k][4]; - RS_32_to_16(RS_f32, RS_f16); - - if constexpr (DenominatorAccumUnit == ComputeUnit::kTensorCore) - { - accumulate_d(RS_f16, d); - } - - // ensure V is ready - cp_async::wait_group<0>(); - __syncthreads(); - - compute_fp16_sv_permuted( - smem_V, RS_f16, RO, d, V_smem_offset_mma); - - __syncthreads(); - - } - - // update buffer -#pragma unroll - for (uint32_t fq = 0; fq < num_tiles_q; fq++) - { -#pragma unroll - for (uint32_t k = 0; k < 2; k++) - { - float o_scale = math::ptx_exp2(m_buf[fq][k] - m[fq][k]); -#pragma unroll - for (uint32_t fv = 0; fv < num_tiles_v; fv++) - { - // update buffer - RO_buf[fq][fv][k * 2 + 0] = RO_buf[fq][fv][k * 2 + 0] * o_scale + __half2float(RO[fq][fv][k * 2 + 0]); - RO_buf[fq][fv][k * 2 + 1] = RO_buf[fq][fv][k * 2 + 1] * o_scale + __half2float(RO[fq][fv][k * 2 + 1]); - RO_buf[fq][fv][k * 2 + 4] = RO_buf[fq][fv][k * 2 + 4] * o_scale + __half2float(RO[fq][fv][k * 2 + 4]); - RO_buf[fq][fv][k * 2 + 5] = RO_buf[fq][fv][k * 2 + 5] * o_scale + __half2float(RO[fq][fv][k * 2 + 5]); - - // update m_buf - // m_buf[fq][k] = 
m[fq][k]; - - // // clear RO - // *((int32_t*)&RO[fq][fv][k * 2 + 0]) = 0; - // *((int32_t*)&RO[fq][fv][k * 2 + 4]) = 0; - } - } - } - - // TODO: thread block sync mdo state for num_warps_k > 0 - - normalize_d(RO_buf, m, d); - - // save the result to shared memory - uint32_t smem_O_row_base = get_warp_idx_q() * WARP_Q + lane_id / 4; -#pragma unroll - for (uint32_t fq = 0; fq < num_tiles_q; fq++) - { -#pragma unroll - for (uint32_t fv = 0; fv < num_tiles_v; fv++) - { - uint32_t offset_O = smem_O.get_permuted_offset(smem_O_row_base + fq * MMA_QK_M, fv * (MMA_SV_N / PACK_SIZE_O)); - - // convert RO_buf to half - uint32_t RO_f16[4]; -#pragma unroll - for (uint32_t k = 0; k < 4; k++) - { - if constexpr (std::is_same::value) - { - ((half2*)RO_f16)[k] = __float22half2_rn(((float2*)RO_buf[fq][fv])[k]); - } - else if constexpr (std::is_same::value) - { - ((nv_bfloat162*)RO_f16)[k] = __float22bfloat162_rn(((float2*)RO_buf[fq][fv])[k]); - } - } - - ((int32_t*)(smem_O.base + offset_O))[lane_id % 4] = RO_f16[0]; - ((int32_t*)(smem_O.base + offset_O + 8 * (O_SMEM_STRIDE / PACK_SIZE_O)))[lane_id % 4] = RO_f16[1]; - - // ! permuted, make sure you know what you are doing - ((int32_t*)(smem_O.base + (offset_O ^ 0x1)))[lane_id % 4] = RO_f16[2]; - ((int32_t*)(smem_O.base + (offset_O ^ 0x1) + 8 * (O_SMEM_STRIDE / PACK_SIZE_O)))[lane_id % 4] = RO_f16[3]; - - } - } - - // ! do we need to sync here? - __syncwarp(); - - // shared memory to global memory - DTypeOut *O_lane_ptr = O + batch_id * stride_bz_o + head_id * stride_h_o + (bx * CTA_Q + WARP_Q * get_warp_idx_q() + lane_id / global_to_shared_line_lanes_O) * stride_seq_o + lane_id % global_to_shared_line_lanes_O * PACK_SIZE_O; - uint32_t offset_O = smem_O.get_permuted_offset(get_warp_idx_q() * WARP_Q + lane_id / global_to_shared_line_lanes_O, lane_id % global_to_shared_line_lanes_O); - uint32_t O_load_idx_lane_base = bx * CTA_Q + CTA_Q / num_warps * warp_id + lane_id / global_to_shared_line_lanes_O; - -#pragma unroll - for (uint32_t i = 0; i < O_smem_iters_col; i++) - { -#pragma unroll - for (uint32_t j = 0; j < O_smem_iters_row; j++) - { - if (O_load_idx_lane_base < qo_len) - { - smem_O.store_128b(offset_O, O_lane_ptr); - } - O_lane_ptr += (global_to_shared_line_lanes_O * PACK_SIZE_O); - offset_O = smem_O.advance_offset_by_column(offset_O); - } - - offset_O = smem_O.advance_offset_by_row(offset_O - (O_smem_iters_row * global_to_shared_line_lanes_O)); - O_lane_ptr += ((global_to_shared_copy_lines_per_warp_O * stride_seq_o) - (O_smem_iters_row * global_to_shared_line_lanes_O * PACK_SIZE_O)); - O_load_idx_lane_base += global_to_shared_copy_lines_per_warp_O; - } - - if constexpr (return_lse) - { - // ! this only works for num_tiles_q = 2 - uint32_t lse_idx = bx * CTA_Q + lane_id / 4 + 8 * (lane_id % 4) + WARP_Q * get_warp_idx_q(); - float *lse_lane_ptr = Lse + batch_id * (qo_len * num_qo_heads) + head_id * qo_len + lse_idx; - uint32_t fq = (lane_id % 4) / 2; - uint32_t k = (lane_id % 4) % 2; - - if (lse_idx < qo_len) - { - lse_lane_ptr[0] = (math::ptx_log2(d[fq][k]) + m[fq][k]); // TODO: here has some bug. 
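The log-sum-exp epilogue above writes math::ptx_log2(d[fq][k]) + m[fq][k] per query row and carries a TODO about a suspected bug. Since sm_scale was pre-multiplied by log2(e), m and d appear to live in the base-2 domain, so a natural-log reference value would be ln(2) * (m + log2(d)). A small numerically stable host reference that could be used to cross-check the kernel output is sketched below; logsumexp_ref is a hypothetical name, not part of the patch.

#include <algorithm>
#include <cmath>
#include <vector>

// Numerically stable log-sum-exp of one row of (already scaled) attention scores.
// Natural-log result; divide by ln(2) to compare with a base-2 (log2) accumulator.
static float logsumexp_ref(const std::vector<float>& scores) {
  float m = -INFINITY;
  for (float s : scores) m = std::max(m, s);
  float d = 0.0f;
  for (float s : scores) d += std::exp(s - m);
  return m + std::log(d);   // == log(sum_i exp(scores[i]))
}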
- } - } - -} - -// impl -> see sageattn.h file -// tensor_layout 0 for [B, N, H, D] (NHD, b, s, head, dim), -// 1 for [B, H, N, D] (HND) -// std::vector qk_int8_sv_f16_accum_f16_attn_buf_fwd(paddle::Tensor& query, -// paddle::Tensor& key, -// paddle::Tensor& value, -// paddle::Tensor& output, -// paddle::Tensor& query_scale, -// paddle::Tensor& key_scale, -// int tensor_layout, -// int is_causal, -// int qk_quant_gran, -// float sm_scale, -// int return_lse) -// { -// CHECK_CUDA(query); -// CHECK_CUDA(key); -// CHECK_CUDA(value); -// CHECK_CUDA(output); -// CHECK_CUDA(query_scale); -// CHECK_CUDA(key_scale); - -// CHECK_CONTIGUOUS(query); -// CHECK_CONTIGUOUS(key); -// CHECK_LASTDIM_CONTIGUOUS(value); -// CHECK_LASTDIM_CONTIGUOUS(output); -// CHECK_CONTIGUOUS(query_scale); -// CHECK_CONTIGUOUS(key_scale); - -// CHECK_DTYPE(query, paddle::DataType::INT8); -// CHECK_DTYPE(key, paddle::DataType::INT8); -// CHECK_DTYPE(value, paddle::DataType::FLOAT16); // TODO: there maybe some problem, for bf16 type -// CHECK_DTYPE(query_scale, paddle::DataType::FLOAT32); -// CHECK_DTYPE(key_scale, paddle::DataType::FLOAT32); - -// CHECK_DIMS(query, 4); -// CHECK_DIMS(key, 4); -// CHECK_DIMS(value, 4); -// CHECK_DIMS(output, 4); -// CHECK_DIMS(query_scale, 3); -// CHECK_DIMS(key_scale, 3); - -// const int head_dim = query.shape()[3]; -// const int batch_size = query.shape()[0]; - -// int stride_bz_q = query.strides()[0]; -// int stride_bz_k = key.strides()[0]; -// int stride_bz_v = value.strides()[0]; -// int stride_bz_o = output.strides()[0]; - -// int qo_len, kv_len, num_qo_heads, num_kv_heads; -// int stride_seq_q, stride_seq_k, stride_seq_v, stride_seq_o; -// int stride_h_q, stride_h_k, stride_h_v, stride_h_o; - -// if (tensor_layout == 0) -// { -// qo_len = query.shape()[1]; -// kv_len = key.shape()[1]; -// num_qo_heads = query.shape()[2]; -// num_kv_heads = key.shape()[2]; -// CHECK_SHAPE(key, batch_size, kv_len, num_kv_heads, head_dim); -// CHECK_SHAPE(value, batch_size, kv_len, num_kv_heads, head_dim); - -// stride_seq_q = query.strides()[1]; -// stride_seq_k = key.strides()[1]; -// stride_seq_v = value.strides()[1]; -// stride_seq_o = output.strides()[1]; - -// stride_h_q = query.strides()[2]; -// stride_h_k = key.strides()[2]; -// stride_h_v = value.strides()[2]; -// stride_h_o = output.strides()[2]; -// } -// else if (tensor_layout == 1) -// { -// qo_len = query.shape()[2]; -// kv_len = key.shape()[2]; -// num_qo_heads = query.shape()[1]; -// num_kv_heads = key.shape()[1]; -// CHECK_SHAPE(key, batch_size, num_kv_heads, kv_len, head_dim); -// CHECK_SHAPE(value, batch_size, num_kv_heads, kv_len, head_dim); - -// stride_seq_q = query.strides()[2]; -// stride_seq_k = key.strides()[2]; -// stride_seq_v = value.strides()[2]; -// stride_seq_o = output.strides()[2]; - -// stride_h_q = query.strides()[1]; -// stride_h_k = key.strides()[1]; -// stride_h_v = value.strides()[1]; -// stride_h_o = output.strides()[1]; -// } -// else -// { -// throw std::invalid_argument("tensor_layout must be 0 or 1"); -// } - -// if (num_qo_heads % num_kv_heads != 0) { -// std::ostringstream err_msg; -// err_msg << "num_qo_heads (" << num_qo_heads << ") must be divisible by num_kv_heads (" << num_kv_heads << ")"; -// throw std::invalid_argument(err_msg.str()); -// } - -// const int num_kv_groups = num_qo_heads / num_kv_heads; - -// paddle::Tensor lse = paddle::empty({1}); -// if (return_lse) -// { -// lse = paddle::empty({batch_size, num_qo_heads, qo_len}, paddle::DataType::FLOAT32); -// } - -// auto output_dtype = 
output.dtype(); // in [bfloat16 or float16] - -// DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { -// DISPATCH_CAUSAL(is_causal, IS_CAUSAL, { -// DISPATCH_QK_QUANT_GRAN(qk_quant_gran, QK_QUANT_GRAN, { -// DISPATCH_RETURN_LSE(return_lse, RETURN_LSE, { -// DISPATCH_PADDLE_DTYPE_TO_CTYPE_FP16(output_dtype, DTypeOut, { -// constexpr int CTA_Q = (HEAD_DIM == 256) ? 64 : 128; -// constexpr int CTA_K = (HEAD_DIM == 256) ? 32 : 64; -// constexpr int WARP_Q = (HEAD_DIM == 256) ? 16 : 32; -// constexpr int WARP_K = (HEAD_DIM == 256) ? 32 : 64; - -// constexpr MaskMode mask_mode = IS_CAUSAL ? MaskMode::kCausal : MaskMode::kNone; - -// if constexpr (QK_QUANT_GRAN == static_cast(QuantGranularity::kPerWarp)) -// { -// CHECK_SHAPE(query_scale, batch_size, num_qo_heads, static_cast(div_ceil(qo_len, CTA_Q) * (CTA_Q / WARP_Q))); -// CHECK_SHAPE(key_scale, batch_size, num_kv_heads, static_cast(div_ceil(kv_len, CTA_K) * (CTA_K / WARP_K))); -// } -// else if constexpr (QK_QUANT_GRAN == static_cast(QuantGranularity::kPerThread)) -// { -// CHECK_SHAPE(query_scale, batch_size, num_qo_heads, static_cast(div_ceil(qo_len, CTA_Q) * (CTA_Q / WARP_Q) * 8)); -// CHECK_SHAPE(key_scale, batch_size, num_kv_heads, static_cast(div_ceil(kv_len, CTA_K) * (CTA_K / WARP_K) * 4)); -// } -// else -// { -// static_assert(QK_QUANT_GRAN == static_cast(QuantGranularity::kPerWarp) || QK_QUANT_GRAN == static_cast(QuantGranularity::kPerThread), "Unsupported quantization granularity"); -// } - -// // smem_Q smem_K smem_V smem_O -// size_t smem_max = std::max(CTA_Q * HEAD_DIM * sizeof(int8_t) + CTA_K * HEAD_DIM * sizeof(int8_t) + CTA_K * HEAD_DIM * sizeof(half), CTA_Q * HEAD_DIM * sizeof(half)); - -// auto kernel_func = qk_int_sv_f16_attn_buffer_kernel(QK_QUANT_GRAN), static_cast(QK_QUANT_GRAN), DTypeOut, ComputeUnit::kTensorCore, -// mask_mode, 32, RETURN_LSE>; - -// cudaFuncSetAttribute(kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_max); - -// dim3 grid(div_ceil(qo_len, CTA_Q), num_qo_heads, batch_size); -// dim3 block(32, (CTA_Q / WARP_Q) * (CTA_K / WARP_K)); - -// // using PD216bit = typename pd_cvt::PD16bitTrait<>::DataType; -// using PD16bitRe = typename pd_cvt::PD16bitReTrait::DataType; - -// PD16bitRe* output_data = output.data(); - -// kernel_func<<>>( -// query.data(), -// key.data(), -// reinterpret_cast(value.data()), // reinterpret_cast(reinterpret_cast(value.data())) -// reinterpret_cast(output.data()), -// (RETURN_LSE) ? 
lse.data() : nullptr, -// query_scale.data(), -// key_scale.data(), -// nullptr, -// qo_len, -// kv_len, -// num_kv_groups, -// stride_bz_q, stride_seq_q, stride_h_q, -// stride_bz_k, stride_seq_k, stride_h_k, -// stride_bz_v, stride_seq_v, stride_h_v, -// stride_bz_o, stride_seq_o, stride_h_o, -// sm_scale); -// }); -// }); -// }); -// }); -// }); - -// return {lse}; -// } - -// qk_int_sv_f16 impl -// the previous one stands for buffer -template -__global__ void qk_int_sv_f16_attn_kernel(int8_t *__restrict__ Q, int8_t *__restrict__ K, half *__restrict__ V, DTypeOut *__restrict__ O, float *__restrict__ Lse, - float *__restrict__ Q_scale, float *__restrict__ K_scale, DTypeOut *__restrict__ V_mean, - const uint32_t qo_len, const uint32_t kv_len, const uint32_t num_kv_groups, - const uint32_t stride_bz_q, const uint32_t stride_seq_q, const uint32_t stride_h_q, - const uint32_t stride_bz_k, const uint32_t stride_seq_k, const uint32_t stride_h_k, - const uint32_t stride_bz_v, const uint32_t stride_seq_v, const uint32_t stride_h_v, - const uint32_t stride_bz_o, const uint32_t stride_seq_o, const uint32_t stride_h_o, - float sm_scale) -{ - // compile time check - static_assert(DTypeQK == SADataType::kInt8 || DTypeQK == SADataType::kInt4, "DTypeQK must be int8 or int4"); - static_assert(Q_GRAN == QuantGranularity::kPerBlock || Q_GRAN == QuantGranularity::kPerWarp || Q_GRAN == QuantGranularity::kPerThread, "Q_GRAN must be kPerBlock, kPerWarp or kPerThread"); - static_assert(K_GRAN == QuantGranularity::kPerBlock || K_GRAN == QuantGranularity::kPerWarp || K_GRAN == QuantGranularity::kPerThread, "K_GRAN must be kPerBlock, kPerWarp or kPerThread"); - static_assert(std::is_same::value || !use_inst_buffer, "use_inst_buffer only supports DTypeSVAccum as float"); - static_assert(std::is_same::value || std::is_same::value, "DTypeSVAccum must be float or half"); - static_assert(std::is_same::value || std::is_same::value, "DTypeOut must be half or nv_bfloat16"); - static_assert(head_dim % 64 == 0, "head_dim must be a multiple of 64"); - static_assert(!fuse_v_mean || std::is_same::value, "fuse_v_mean only supports half"); - static_assert(CTA_Q / CTA_K <= 2); // for efficient causal implementation - - using DTypeOut2 = typename std::conditional::value, half2, nv_bfloat162>::type; - - constexpr uint32_t num_warps_q = CTA_Q / WARP_Q; - constexpr uint32_t num_warps_k = CTA_K / WARP_K; - constexpr uint32_t num_warps = num_warps_q * num_warps_k; - constexpr uint32_t num_tiles_q = WARP_Q / MMA_QK_M; - constexpr uint32_t num_tiles_k = WARP_K / MMA_QK_N; - constexpr uint32_t num_tiles_qk_inner = (DTypeQK == SADataType::kInt8) ? (head_dim / MMA_QK_K) : (head_dim / 2 / MMA_QK_K); - constexpr uint32_t num_tiles_v = head_dim / MMA_SV_N; - - constexpr uint32_t QK_SMEM_STRIDE = (DTypeQK == SADataType::kInt8) ? 
(head_dim) : (head_dim / 2); - constexpr uint32_t O_SMEM_STRIDE = head_dim; - constexpr uint32_t V_SMEM_STRIDE = head_dim; - - extern __shared__ int8_t smem[]; - - const uint32_t lane_id = get_lane_id(); - const uint32_t warp_id = get_warp_id(); - - // maximize L2 hit rate - const uint32_t batch_id = blockIdx.z; - const uint32_t bx = blockIdx.x; - const uint32_t num_qo_heads = gridDim.y; - const uint32_t head_id = blockIdx.y; - - // transfer to base 2 instead of base e with better numerical efficiency - sm_scale *= math::log2e; - - // RS holds the fragment of S - int32_t RS[num_tiles_q][num_tiles_k][8]; - DTypeSVAccum RO[num_tiles_q][num_tiles_v][8]; - float m[num_tiles_q][2]; // max - float d[num_tiles_q][2]; // denominator - - uint32_t q_scale_idx, k_scale_idx; - - if constexpr (Q_GRAN == QuantGranularity::kPerBlock) - { - const uint32_t num_block_q = gridDim.x; - q_scale_idx = batch_id * num_qo_heads * num_block_q + head_id * num_block_q + bx; - } - else if constexpr (Q_GRAN == QuantGranularity::kPerWarp) - { - const uint32_t num_warp_block_q = gridDim.x * num_warps_q; - q_scale_idx = batch_id * num_qo_heads * num_warp_block_q + head_id * num_warp_block_q + bx * num_warps_q + get_warp_idx_q(); - } - else if constexpr (Q_GRAN == QuantGranularity::kPerThread) - { - const uint32_t num_warp_block_q = gridDim.x * num_warps_q; - q_scale_idx = batch_id * num_qo_heads * (num_warp_block_q * 8) + head_id * (num_warp_block_q * 8) + bx * (num_warps_q * 8) + get_warp_idx_q() * 8 + lane_id / 4; - } - - if constexpr (K_GRAN == QuantGranularity::kPerBlock) - { - const uint32_t num_block_k = div_ceil(kv_len, CTA_K); - k_scale_idx = batch_id * (num_qo_heads / num_kv_groups) * num_block_k + (head_id / num_kv_groups) * num_block_k; - } - else if constexpr (K_GRAN == QuantGranularity::kPerWarp) - { - const uint32_t num_warp_block_k = div_ceil(kv_len, CTA_K) * (CTA_K / WARP_K); - k_scale_idx = batch_id * (num_qo_heads / num_kv_groups) * num_warp_block_k + (head_id / num_kv_groups) * num_warp_block_k + get_warp_idx_k(); - } - else if constexpr (K_GRAN == QuantGranularity::kPerThread) - { - const uint32_t num_warp_block_k = div_ceil(kv_len, CTA_K) * (CTA_K / WARP_K); - k_scale_idx = batch_id * (num_qo_heads / num_kv_groups) * (num_warp_block_k * 4) + (head_id / num_kv_groups) * (num_warp_block_k * 4) + get_warp_idx_k() * 4 + lane_id % 4; - } - - constexpr uint32_t k_scale_advance_offset = (K_GRAN == QuantGranularity::kPerBlock) ? 1 : (K_GRAN == QuantGranularity::kPerWarp) ? (CTA_K / WARP_K) : (CTA_K / WARP_K) * 4; - - // initialize o, m, d -#pragma unroll - for (uint32_t fq = 0; fq < num_tiles_q; fq++) - { -#pragma unroll - for (uint32_t fv = 0; fv < num_tiles_v; fv++) - { - if constexpr (std::is_same::value) - { -#pragma unroll - for (uint32_t k = 0; k < 8; k++) - { - RO[fq][fv][k] = 0.0f; - } - } - else if constexpr (std::is_same::value) - { -#pragma unroll - for (uint32_t k = 0; k < 4; k++) - { - ((int32_t*)RO[fq][fv])[k] = 0; - } - } - } - } -#pragma unroll - for (uint32_t fq = 0; fq < num_tiles_q; fq++) - { -#pragma unroll - for (uint32_t k = 0; k < 2; k++) - { - m[fq][k] = -5000000.0f; - d[fq][k] = 1.0f; - } - } - - constexpr uint32_t K_smem_idx_offset = CTA_Q; - constexpr uint32_t V_smem_idx_offset = CTA_Q + CTA_K; - - constexpr SwizzleMode swizzle_mode_QK = (QK_SMEM_STRIDE == 32) ? SwizzleMode::k32B : (QK_SMEM_STRIDE == 64) ? 
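// --- annotation (not part of the original patch) ---------------------------
// Scale-tensor layout implied by the indexing above: for kPerThread
// granularity each warp owns 8 Q scales (indexed by lane_id / 4, the 8
// row-owner groups of a 32-lane warp) and 4 K scales (indexed by lane_id % 4),
// which is where the "* 8" and "* 4" factors in the host-side CHECK_SHAPE
// calls come from. Written out for Q (per thread, per head):
//   q_scale shape: [B, H_qo, ceil(qo_len / CTA_Q) * (CTA_Q / WARP_Q) * 8]
//   q_scale_idx   = ((batch_id * H_qo + head_id) * gridDim.x * num_warps_q
//                    + bx * num_warps_q + warp_idx_q) * 8 + lane_id / 4
// with warp_idx_q standing in for get_warp_idx_q<...>(); the per-warp and
// per-block cases drop the trailing * 8 / lane terms accordingly.
// ---------------------------------------------------------------------------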
SwizzleMode::k64B : SwizzleMode::k128B; - smem_t smem_Q(smem); - smem_t smem_K(smem + K_smem_idx_offset * QK_SMEM_STRIDE); - constexpr SwizzleMode swizzle_mode_V = (V_SMEM_STRIDE == 32) ? SwizzleMode::k64B : SwizzleMode::k128B; - smem_t smem_V(smem + V_smem_idx_offset * QK_SMEM_STRIDE); - constexpr SwizzleMode swizzle_mode_O = (O_SMEM_STRIDE == 32) ? SwizzleMode::k64B : SwizzleMode::k128B; - smem_t smem_O(smem); - - constexpr uint32_t global_to_shared_line_lanes_QK = (QK_SMEM_STRIDE == 32) ? 2 : (QK_SMEM_STRIDE == 64) ? 4 : 8; - constexpr uint32_t global_to_shared_copy_lines_per_warp_QK = (QK_SMEM_STRIDE == 32) ? 16 : (QK_SMEM_STRIDE == 64) ? 8 : 4; - constexpr uint32_t global_to_shared_line_lanes_V = (V_SMEM_STRIDE == 32) ? 4 : 8; - constexpr uint32_t global_to_shared_copy_lines_per_warp_V = (V_SMEM_STRIDE == 32) ? 8 : 4; - constexpr uint32_t global_to_shared_line_lanes_O = (O_SMEM_STRIDE == 32) ? 4 : 8; - constexpr uint32_t global_to_shared_copy_lines_per_warp_O = (O_SMEM_STRIDE == 32) ? 8 : 4; - - constexpr uint32_t QK_smem_iters_row = QK_SMEM_STRIDE / (global_to_shared_line_lanes_QK * PACK_SIZE_QK); - constexpr uint32_t Q_smem_iters_col = CTA_Q / (num_warps * global_to_shared_copy_lines_per_warp_QK); - constexpr uint32_t K_smem_iters_col = CTA_K / (num_warps * global_to_shared_copy_lines_per_warp_QK); - constexpr uint32_t V_smem_iters_row = V_SMEM_STRIDE / (global_to_shared_line_lanes_V * PACK_SIZE_V); - constexpr uint32_t V_smem_iters_col = CTA_K / (num_warps * global_to_shared_copy_lines_per_warp_V); - constexpr uint32_t O_smem_iters_row = O_SMEM_STRIDE / (global_to_shared_line_lanes_O * PACK_SIZE_O); - constexpr uint32_t O_smem_iters_col = CTA_Q / (num_warps * global_to_shared_copy_lines_per_warp_O); - - int8_t *Q_lane_base_ptr = Q + batch_id * stride_bz_q + head_id * stride_h_q + (bx * CTA_Q + CTA_Q / num_warps * warp_id + lane_id / global_to_shared_line_lanes_QK) * stride_seq_q + (lane_id % global_to_shared_line_lanes_QK) * PACK_SIZE_QK; - int8_t *K_lane_base_ptr = K + batch_id * stride_bz_k + (head_id / num_kv_groups) * stride_h_k + (CTA_K / num_warps * warp_id + lane_id / global_to_shared_line_lanes_QK) * stride_seq_k + (lane_id % global_to_shared_line_lanes_QK) * PACK_SIZE_QK; - half *V_lane_base_ptr = V + batch_id * stride_bz_v + (head_id / num_kv_groups) * stride_h_v + (CTA_K / num_warps * warp_id + lane_id / global_to_shared_line_lanes_V) * stride_seq_v + (lane_id % global_to_shared_line_lanes_V) * PACK_SIZE_V; - uint32_t Q_smem_offset_load = smem_Q.get_permuted_offset(warp_id * global_to_shared_copy_lines_per_warp_QK * Q_smem_iters_col + lane_id / global_to_shared_line_lanes_QK, lane_id % global_to_shared_line_lanes_QK); - uint32_t K_smem_offset_load = smem_K.get_permuted_offset(warp_id * global_to_shared_copy_lines_per_warp_QK * K_smem_iters_col + lane_id / global_to_shared_line_lanes_QK, lane_id % global_to_shared_line_lanes_QK); - uint32_t V_smem_offset_load = smem_V.get_permuted_offset(warp_id * global_to_shared_copy_lines_per_warp_V * V_smem_iters_col + lane_id / global_to_shared_line_lanes_V, lane_id % global_to_shared_line_lanes_V); - - uint32_t Q_smem_offset_mma = smem_Q.get_permuted_offset(get_warp_idx_q() * WARP_Q + lane_id % 16, lane_id / 16); - uint32_t K_smem_offset_mma = smem_K.get_permuted_offset(get_warp_idx_k() * WARP_K + lane_id % 8 + (lane_id / 16) * 8, (lane_id / 8) % 2); - uint32_t V_smem_offset_mma = smem_V.get_permuted_offset(get_warp_idx_k() * WARP_K + lane_id % 16, lane_id / 16); - - // for causal masking - uint32_t Q_idx_lane_base = bx * CTA_Q + 
get_warp_idx_q() * WARP_Q + lane_id / 4; - uint32_t K_idx_lane_base = get_warp_idx_k() * WARP_K + 2 * (lane_id % 4); - - // for loading - uint32_t Q_load_idx_lane_base = bx * CTA_Q + CTA_Q / num_warps * warp_id + lane_id / global_to_shared_line_lanes_QK; - uint32_t K_load_idx_lane_base = CTA_K / num_warps * warp_id + lane_id / global_to_shared_line_lanes_QK; - uint32_t V_load_idx_lane_base = CTA_K / num_warps * warp_id + lane_id / global_to_shared_line_lanes_V; - - const uint32_t num_iterations = div_ceil( - mask_mode == MaskMode::kCausal - ? min(kv_len, (bx + 1) * CTA_Q) - : kv_len, - CTA_K); - - // load Q with predicate - load_global_to_share( - &Q_lane_base_ptr, Q_smem_offset_load, stride_seq_q, smem_Q, Q_load_idx_lane_base, qo_len); - cp_async::commit_group(); - cp_async::wait_group<0>(); - __syncthreads(); - - // for num_tiles_qk_inner = 1, we load all Qs in register - uint32_t RQ[num_tiles_q][4]; - if constexpr (num_tiles_qk_inner == 1) - { -#pragma unroll - for (uint32_t fq = 0; fq < num_tiles_q; fq++) - { - smem_Q.ldmatrix_m8n8x4(Q_smem_offset_mma, RQ[fq]); - Q_smem_offset_mma = smem_Q.advance_offset_by_row<16>(Q_smem_offset_mma); - } - } - - // load K with predicate - load_global_to_share( - &K_lane_base_ptr, K_smem_offset_load, stride_seq_k, smem_K, K_load_idx_lane_base, kv_len); - cp_async::commit_group(); - - float q_scale = Q_scale[q_scale_idx]; - - float original_sm_scale = sm_scale; - float dequant_scale = q_scale * K_scale[k_scale_idx + 0 * k_scale_advance_offset]; - - sm_scale = original_sm_scale * dequant_scale; - - // load V with predicate - load_global_to_share( - &V_lane_base_ptr, V_smem_offset_load, stride_seq_v, smem_V, V_load_idx_lane_base, kv_len); - cp_async::commit_group(); - - K_load_idx_lane_base += CTA_K; - V_load_idx_lane_base += CTA_K; - -#pragma unroll - for (uint32_t iter = 1; iter < num_iterations - 1; iter++) - { - // ensure K is ready - cp_async::wait_group<1>(); - __syncthreads(); - - // compute QK^T - if constexpr (num_tiles_qk_inner == 1) - { - compute_int_qk( - smem_K, RS, RQ, K_smem_offset_mma); - } - else - { - compute_int_qk( - smem_Q, smem_K, RS, Q_smem_offset_mma, K_smem_offset_mma); - } - - float RS_f32[num_tiles_q][num_tiles_k][8]; - -#pragma unroll - for (uint32_t fq = 0; fq < num_tiles_q; fq++) - { -#pragma unroll - for (uint32_t fk = 0; fk < num_tiles_k; fk++) - { -#pragma unroll - for (uint32_t k = 0; k < 8; k++) - { - RS_f32[fq][fk][k] = __int2float_rz(RS[fq][fk][k]); - } - } - } - - // do not apply causal mask and out of bound mask for these iterations - K_idx_lane_base += CTA_K; - - if constexpr (std::is_same::value) - { - update_mdo(RS_f32, RO, m, d, sm_scale); - } - else if constexpr (std::is_same::value) - { - update_mdo(RS_f32, RO, m, d, sm_scale); - } - - if constexpr (DenominatorAccumUnit == ComputeUnit::kCudaCore) - { - accumulate_d(RS_f32, d); - } - - uint32_t RS_f16[num_tiles_q][num_tiles_k][4]; - RS_32_to_16(RS_f32, RS_f16); - - if constexpr (DenominatorAccumUnit == ComputeUnit::kTensorCore) - { - accumulate_d(RS_f16, d); - } - - __syncthreads(); - - // load K - load_global_to_share( - &K_lane_base_ptr, K_smem_offset_load, stride_seq_k, smem_K); - cp_async::commit_group(); - - dequant_scale = q_scale * K_scale[k_scale_idx + iter * k_scale_advance_offset]; - sm_scale = original_sm_scale * dequant_scale; - - // ensure V is ready - cp_async::wait_group<1>(); - __syncthreads(); - - if constexpr (!use_inst_buffer) - { - compute_fp16_sv_permuted( - smem_V, RS_f16, RO, d, V_smem_offset_mma); - } - else - { - 
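// --- annotation (not part of the original patch) ---------------------------
// The loop above is a two-stage cp.async software pipeline: every K tile and
// every V tile is committed as its own async-copy group, and
// cp_async::wait_group<N>() (cp.async.wait_group N) blocks until at most N
// committed groups are still in flight. So the first wait_group<1>() in each
// iteration guarantees the current K tile has landed (only the trailing V copy
// may still be pending), and the second one, issued after the next K prefetch,
// guarantees the current V tile has landed. Skeleton of the pattern
// (load_K / load_V / compute_* are illustrative names, not the kernel's
// helpers; the real loop also peels the last two tiles to apply masks):
#if 0
load_K(0); commit_group();            // group: K0
load_V(0); commit_group();            // group: V0
for (int i = 0; i < steady_iters; ++i) {
  wait_group<1>(); __syncthreads();   // K_i ready; V_i may still be in flight
  compute_QK(i);
  load_K(i + 1); commit_group();      // prefetch next K
  wait_group<1>(); __syncthreads();   // V_i ready; K_{i+1} may still be in flight
  compute_SV(i);
  load_V(i + 1); commit_group();      // prefetch next V
}
#endif
// ---------------------------------------------------------------------------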
compute_fp16_sv_permuted_inst_buf( - smem_V, RS_f16, RO, d, V_smem_offset_mma); - } - - __syncthreads(); - // load V - load_global_to_share( - &V_lane_base_ptr, V_smem_offset_load, stride_seq_v, smem_V); - cp_async::commit_group(); - K_load_idx_lane_base += CTA_K; - V_load_idx_lane_base += CTA_K; - } - - // second last iter, apply causal mask - if (num_iterations > 1) - { - // ensure K is ready - cp_async::wait_group<1>(); - __syncthreads(); - - // compute QK^T - if constexpr (num_tiles_qk_inner == 1) - { - compute_int_qk( - smem_K, RS, RQ, K_smem_offset_mma); - } - else - { - compute_int_qk( - smem_Q, smem_K, RS, Q_smem_offset_mma, K_smem_offset_mma); - } - - float RS_f32[num_tiles_q][num_tiles_k][8]; - -#pragma unroll - for (uint32_t fq = 0; fq < num_tiles_q; fq++) - { -#pragma unroll - for (uint32_t fk = 0; fk < num_tiles_k; fk++) - { -#pragma unroll - for (uint32_t k = 0; k < 8; k++) - { - RS_f32[fq][fk][k] = __int2float_rz(RS[fq][fk][k]) * dequant_scale; - } - } - } - - if constexpr (mask_mode == MaskMode::kCausal) - { - apply_causal_mask(Q_idx_lane_base, K_idx_lane_base, RS_f32); - } - // apply_out_of_bound_mask(K_idx_lane_base, RS_f32, kv_len); - K_idx_lane_base += CTA_K; - - if constexpr (std::is_same::value) - { - update_mdo(RS_f32, RO, m, d, original_sm_scale); - } - else if constexpr (std::is_same::value) - { - update_mdo(RS_f32, RO, m, d, original_sm_scale); - } - - if constexpr (DenominatorAccumUnit == ComputeUnit::kCudaCore) - { - accumulate_d(RS_f32, d); - } - - uint32_t RS_f16[num_tiles_q][num_tiles_k][4]; - RS_32_to_16(RS_f32, RS_f16); - - if constexpr (DenominatorAccumUnit == ComputeUnit::kTensorCore) - { - accumulate_d(RS_f16, d); - } - - __syncthreads(); - - // load K with predicate - load_global_to_share( - &K_lane_base_ptr, K_smem_offset_load, stride_seq_k, smem_K, K_load_idx_lane_base, kv_len); - cp_async::commit_group(); - - dequant_scale = q_scale * K_scale[k_scale_idx + (num_iterations - 1) * k_scale_advance_offset]; - sm_scale = original_sm_scale * dequant_scale; - - // ensure V is ready - cp_async::wait_group<1>(); - __syncthreads(); - - if constexpr (!use_inst_buffer) - { - compute_fp16_sv_permuted( - smem_V, RS_f16, RO, d, V_smem_offset_mma); - } - else - { - compute_fp16_sv_permuted_inst_buf( - smem_V, RS_f16, RO, d, V_smem_offset_mma); - } - - __syncthreads(); - // load V with predicate - load_global_to_share( - &V_lane_base_ptr, V_smem_offset_load, stride_seq_v, smem_V, V_load_idx_lane_base, kv_len); - cp_async::commit_group(); - K_load_idx_lane_base += CTA_K; - V_load_idx_lane_base += CTA_K; - } - - // last iter, apply causal mask and out of bound mask - { - // ensure K is ready - cp_async::wait_group<1>(); - __syncthreads(); - - // compute QK^T - if constexpr (num_tiles_qk_inner == 1) - { - compute_int_qk( - smem_K, RS, RQ, K_smem_offset_mma); - } - else - { - compute_int_qk( - smem_Q, smem_K, RS, Q_smem_offset_mma, K_smem_offset_mma); - } - - float RS_f32[num_tiles_q][num_tiles_k][8]; - -#pragma unroll - for (uint32_t fq = 0; fq < num_tiles_q; fq++) - { -#pragma unroll - for (uint32_t fk = 0; fk < num_tiles_k; fk++) - { -#pragma unroll - for (uint32_t k = 0; k < 8; k++) - { - RS_f32[fq][fk][k] = __int2float_rz(RS[fq][fk][k]) * dequant_scale; - } - } - } - - if constexpr (mask_mode == MaskMode::kCausal) - { - apply_causal_mask(Q_idx_lane_base, K_idx_lane_base, RS_f32); - } - // check out of bound in the last iter - apply_out_of_bound_mask(K_idx_lane_base, RS_f32, kv_len); - K_idx_lane_base += CTA_K; - - if constexpr (std::is_same::value) - { - 
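// --- annotation (not part of the original patch) ---------------------------
// Dequantization note: the steady-state loop earlier converts RS to float as
// is and folds the per-tile factor dequant_scale = q_scale * K_scale[...] into
// the softmax scale (sm_scale = original_sm_scale * dequant_scale), while the
// two trailing, mask-applying iterations multiply each element by
// dequant_scale and pass original_sm_scale instead. Assuming update_mdo
// applies its scale argument multiplicatively to the scores (which the two
// call patterns imply), both paths evaluate
//   2^(original_sm_scale * dequant_scale * s_int)
// per element, with original_sm_scale already carrying the log2(e) factor from
// kernel entry; the folded form just saves one multiply per register element
// in the hot loop.
// ---------------------------------------------------------------------------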
update_mdo(RS_f32, RO, m, d, original_sm_scale); - } - else if constexpr (std::is_same::value) - { - update_mdo(RS_f32, RO, m, d, original_sm_scale); - } - - if constexpr (DenominatorAccumUnit == ComputeUnit::kCudaCore) - { - accumulate_d(RS_f32, d); - } - - uint32_t RS_f16[num_tiles_q][num_tiles_k][4]; - RS_32_to_16(RS_f32, RS_f16); - - if constexpr (DenominatorAccumUnit == ComputeUnit::kTensorCore) - { - accumulate_d(RS_f16, d); - } - - // ensure V is ready - cp_async::wait_group<0>(); - __syncthreads(); - - if constexpr (!use_inst_buffer) - { - compute_fp16_sv_permuted( - smem_V, RS_f16, RO, d, V_smem_offset_mma); - } - else - { - compute_fp16_sv_permuted_inst_buf( - smem_V, RS_f16, RO, d, V_smem_offset_mma); - } - - __syncthreads(); - - } - - // TODO: thread block sync mdo state for num_warps_k > 0 - - normalize_d(RO, m, d); - - // save the result - // if (get_warp_idx_k() == 0) - // { - - // convert half to bfloat16 - if constexpr (std::is_same::value && std::is_same::value) - { -#pragma unroll - for (uint32_t fq = 0; fq < num_tiles_q; fq++) - { -#pragma unroll - for (uint32_t fv = 0; fv < num_tiles_v; fv++) - { - ((nv_bfloat162*)RO[fq][fv])[0] = __float22bfloat162_rn(__half22float2(((half2*)RO[fq][fv])[0])); - ((nv_bfloat162*)RO[fq][fv])[1] = __float22bfloat162_rn(__half22float2(((half2*)RO[fq][fv])[1])); - ((nv_bfloat162*)RO[fq][fv])[2] = __float22bfloat162_rn(__half22float2(((half2*)RO[fq][fv])[2])); - ((nv_bfloat162*)RO[fq][fv])[3] = __float22bfloat162_rn(__half22float2(((half2*)RO[fq][fv])[3])); - } - } - } - - // add v_mean - if constexpr (fuse_v_mean) - { - DTypeOut2 v_mean[2]; - DTypeOut *V_mean_lane_ptr = V_mean + batch_id * (num_qo_heads / num_kv_groups) * head_dim + (head_id / num_kv_groups) * head_dim + lane_id % 4 * 2; -#pragma unroll - for (uint32_t fv = 0; fv < num_tiles_v; fv++) - { - v_mean[0] = *((DTypeOut2*)(V_mean_lane_ptr + fv * 16)); - v_mean[1] = *((DTypeOut2*)(V_mean_lane_ptr + 8 + fv * 16)); -#pragma unroll - for (uint32_t fq = 0; fq < num_tiles_q; fq++) - { - ((DTypeOut2*)RO[fq][fv])[0] = __hadd2(((DTypeOut2*)RO[fq][fv])[0], v_mean[0]); - ((DTypeOut2*)RO[fq][fv])[1] = __hadd2(((DTypeOut2*)RO[fq][fv])[1], v_mean[0]); - ((DTypeOut2*)RO[fq][fv])[2] = __hadd2(((DTypeOut2*)RO[fq][fv])[2], v_mean[1]); - ((DTypeOut2*)RO[fq][fv])[3] = __hadd2(((DTypeOut2*)RO[fq][fv])[3], v_mean[1]); - } - } - } - - // save the result to shared memory - uint32_t smem_O_row_base = get_warp_idx_q() * WARP_Q + lane_id / 4; -#pragma unroll - for (uint32_t fq = 0; fq < num_tiles_q; fq++) - { -#pragma unroll - for (uint32_t fv = 0; fv < num_tiles_v; fv++) - { - uint32_t offset_O = smem_O.get_permuted_offset(smem_O_row_base + fq * MMA_QK_M, fv * (MMA_SV_N / PACK_SIZE_O)); - - if constexpr (std::is_same::value) - { - // convert RO to half - uint32_t RO_f16[4]; -#pragma unroll - for (uint32_t k = 0; k < 4; k++) - { - if constexpr (std::is_same::value) - { - ((half2*)RO_f16)[k] = __float22half2_rn(((float2*)RO[fq][fv])[k]); - } - else if constexpr (std::is_same::value) - { - ((nv_bfloat162*)RO_f16)[k] = __float22bfloat162_rn(((float2*)RO[fq][fv])[k]); - } - } - - ((int32_t*)(smem_O.base + offset_O))[lane_id % 4] = RO_f16[0]; - ((int32_t*)(smem_O.base + offset_O + 8 * (O_SMEM_STRIDE / PACK_SIZE_O)))[lane_id % 4] = RO_f16[1]; - - // ! 
permuted, make sure you know what you are doing - ((int32_t*)(smem_O.base + (offset_O ^ 0x1)))[lane_id % 4] = RO_f16[2]; - ((int32_t*)(smem_O.base + (offset_O ^ 0x1) + 8 * (O_SMEM_STRIDE / PACK_SIZE_O)))[lane_id % 4] = RO_f16[3]; - } - else if constexpr (std::is_same::value) - { - ((int32_t*)(smem_O.base + offset_O))[lane_id % 4] = ((int32_t*)RO[fq][fv])[0]; - ((int32_t*)(smem_O.base + offset_O + 8 * (O_SMEM_STRIDE / PACK_SIZE_O)))[lane_id % 4] = ((int32_t*)RO[fq][fv])[1]; - - // ! permuted, make sure you know what you are doing - ((int32_t*)(smem_O.base + (offset_O ^ 0x1)))[lane_id % 4] = ((int32_t*)RO[fq][fv])[2]; - ((int32_t*)(smem_O.base + (offset_O ^ 0x1) + 8 * (O_SMEM_STRIDE / PACK_SIZE_O)))[lane_id % 4] = ((int32_t*)RO[fq][fv])[3]; - } - } - } - - // ! do we need to sync here? - __syncwarp(); - - // shared memory to global memory - DTypeOut *O_lane_ptr = O + batch_id * stride_bz_o + head_id * stride_h_o + (bx * CTA_Q + WARP_Q * get_warp_idx_q() + lane_id / global_to_shared_line_lanes_O) * stride_seq_o + lane_id % global_to_shared_line_lanes_O * PACK_SIZE_O; - uint32_t offset_O = smem_O.get_permuted_offset(get_warp_idx_q() * WARP_Q + lane_id / global_to_shared_line_lanes_O, lane_id % global_to_shared_line_lanes_O); - uint32_t O_load_idx_lane_base = bx * CTA_Q + CTA_Q / num_warps * warp_id + lane_id / global_to_shared_line_lanes_O; - -#pragma unroll - for (uint32_t i = 0; i < O_smem_iters_col; i++) - { -#pragma unroll - for (uint32_t j = 0; j < O_smem_iters_row; j++) - { - if (O_load_idx_lane_base < qo_len) - { - smem_O.store_128b(offset_O, O_lane_ptr); - } - O_lane_ptr += (global_to_shared_line_lanes_O * PACK_SIZE_O); - offset_O = smem_O.advance_offset_by_column(offset_O); - } - - offset_O = smem_O.advance_offset_by_row(offset_O - (O_smem_iters_row * global_to_shared_line_lanes_O)); - O_lane_ptr += ((global_to_shared_copy_lines_per_warp_O * stride_seq_o) - (O_smem_iters_row * global_to_shared_line_lanes_O * PACK_SIZE_O)); - O_load_idx_lane_base += global_to_shared_copy_lines_per_warp_O; - } - - if constexpr (return_lse) - { - // ! 
this only works for num_tiles_q = 2 - uint32_t lse_idx = bx * CTA_Q + lane_id / 4 + 8 * (lane_id % 4) + WARP_Q * get_warp_idx_q(); - float *lse_lane_ptr = Lse + batch_id * (qo_len * num_qo_heads) + head_id * qo_len + lse_idx; - uint32_t fq = (lane_id % 4) / 2; - uint32_t k = (lane_id % 4) % 2; - - if (lse_idx < qo_len) - { - lse_lane_ptr[0] = (math::ptx_log2(d[fq][k]) + m[fq][k]); - } - } -} - -// tensor_layout 0 for [B, N, H, D], 1 for [B, H, N, D] -// impl -> see sageattn.h file -std::vector qk_int8_sv_f16_accum_f32_attn_fwd( - paddle::Tensor& query, - paddle::Tensor& key, - paddle::Tensor& value, - paddle::Tensor& output, - paddle::Tensor& query_scale, - paddle::Tensor& key_scale, - int tensor_layout, - int is_causal, - int qk_quant_gran, - float sm_scale, - int return_lse) -{ - CHECK_CUDA(query); - CHECK_CUDA(key); - CHECK_CUDA(value); - CHECK_CUDA(output); - CHECK_CUDA(query_scale); - CHECK_CUDA(key_scale); - - CHECK_CONTIGUOUS(query); - CHECK_CONTIGUOUS(key); - CHECK_LASTDIM_CONTIGUOUS(value); - CHECK_LASTDIM_CONTIGUOUS(output); - CHECK_CONTIGUOUS(query_scale); - CHECK_CONTIGUOUS(key_scale); - - CHECK_DTYPE(query, paddle::DataType::INT8); - CHECK_DTYPE(key, paddle::DataType::INT8); - CHECK_DTYPE(value, paddle::DataType::FLOAT16); - CHECK_DTYPE(query_scale, paddle::DataType::FLOAT32); - CHECK_DTYPE(key_scale, paddle::DataType::FLOAT32); - - CHECK_DIMS(query, 4); - CHECK_DIMS(key, 4); - CHECK_DIMS(value, 4); - CHECK_DIMS(output, 4); - CHECK_DIMS(query_scale, 3); - CHECK_DIMS(key_scale, 3); - - const int head_dim = query.shape()[3]; - const int batch_size = query.shape()[0]; - - int stride_bz_q = query.strides()[0]; - int stride_bz_k = key.strides()[0]; - int stride_bz_v = value.strides()[0]; - int stride_bz_o = output.strides()[0]; - - int qo_len, kv_len, num_qo_heads, num_kv_heads; - int stride_seq_q, stride_seq_k, stride_seq_v, stride_seq_o; - int stride_h_q, stride_h_k, stride_h_v, stride_h_o; - - if (tensor_layout == 0) - { - qo_len = query.shape()[1]; - kv_len = key.shape()[1]; - num_qo_heads = query.shape()[2]; - num_kv_heads = key.shape()[2]; - CHECK_SHAPE(key, batch_size, kv_len, num_kv_heads, head_dim); - CHECK_SHAPE(value, batch_size, kv_len, num_kv_heads, head_dim); - - stride_seq_q = query.strides()[1]; - stride_seq_k = key.strides()[1]; - stride_seq_v = value.strides()[1]; - stride_seq_o = output.strides()[1]; - - stride_h_q = query.strides()[2]; - stride_h_k = key.strides()[2]; - stride_h_v = value.strides()[2]; - stride_h_o = output.strides()[2]; - } - else if (tensor_layout == 1) - { - qo_len = query.shape()[2]; - kv_len = key.shape()[2]; - num_qo_heads = query.shape()[1]; - num_kv_heads = key.shape()[1]; - CHECK_SHAPE(key, batch_size, num_kv_heads, kv_len, head_dim); - CHECK_SHAPE(value, batch_size, num_kv_heads, kv_len, head_dim); - - stride_seq_q = query.strides()[2]; - stride_seq_k = key.strides()[2]; - stride_seq_v = value.strides()[2]; - stride_seq_o = output.strides()[2]; - - stride_h_q = query.strides()[1]; - stride_h_k = key.strides()[1]; - stride_h_v = value.strides()[1]; - stride_h_o = output.strides()[1]; - } - else - { - throw std::invalid_argument("tensor_layout must be 0 or 1"); - } - - if (num_qo_heads % num_kv_heads != 0) { - std::ostringstream err_msg; - err_msg << "num_qo_heads (" << num_qo_heads << ") must be divisible by num_kv_heads (" << num_kv_heads << ")"; - throw std::invalid_argument(err_msg.str()); - } - - const int num_kv_groups = num_qo_heads / num_kv_heads; - - paddle::Tensor lse = paddle::empty({1}); - if (return_lse) - { - lse = 
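// --- annotation (not part of the original patch) ---------------------------
// Worked example of the stride bookkeeping above, assuming fully contiguous
// tensors (the checks require full contiguity for Q/K and a contiguous last
// dim for V/O). Element strides:
//   tensor_layout == 0 (NHD, [B, S, H, D]): stride_bz = S*H*D,
//                                           stride_seq = H*D, stride_h = D
//   tensor_layout == 1 (HND, [B, H, S, D]): stride_bz = H*S*D,
//                                           stride_h = S*D, stride_seq = D
// e.g. B = 2, S = 1024, H = 16, D = 128 in NHD gives stride_bz = 2097152,
// stride_seq = 2048, stride_h = 128.
// ---------------------------------------------------------------------------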
paddle::empty({batch_size, num_qo_heads, qo_len}, paddle::DataType::FLOAT32); - } - - auto output_dtype = output.dtype(); - - DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { - DISPATCH_CAUSAL(is_causal, IS_CAUSAL, { - DISPATCH_QK_QUANT_GRAN(qk_quant_gran, QK_QUANT_GRAN, { - DISPATCH_RETURN_LSE(return_lse, RETURN_LSE, { - DISPATCH_PADDLE_DTYPE_TO_CTYPE_FP16(output_dtype, DTypeOut, { - constexpr int CTA_Q = 128; - constexpr int CTA_K = 64; - constexpr int WARP_Q = 32; - constexpr int WARP_K = 64; - - constexpr MaskMode mask_mode = IS_CAUSAL ? MaskMode::kCausal : MaskMode::kNone; - - if constexpr (QK_QUANT_GRAN == static_cast(QuantGranularity::kPerWarp)) - { - CHECK_SHAPE(query_scale, batch_size, num_qo_heads, static_cast(div_ceil(qo_len, CTA_Q) * (CTA_Q / WARP_Q))); - CHECK_SHAPE(key_scale, batch_size, num_kv_heads, static_cast(div_ceil(kv_len, CTA_K) * (CTA_K / WARP_K))); - } - else if constexpr (QK_QUANT_GRAN == static_cast(QuantGranularity::kPerThread)) - { - CHECK_SHAPE(query_scale, batch_size, num_qo_heads, static_cast(div_ceil(qo_len, CTA_Q) * (CTA_Q / WARP_Q) * 8)); - CHECK_SHAPE(key_scale, batch_size, num_kv_heads, static_cast(div_ceil(kv_len, CTA_K) * (CTA_K / WARP_K) * 4)); - } - else - { - static_assert(QK_QUANT_GRAN == static_cast(QuantGranularity::kPerWarp) || QK_QUANT_GRAN == static_cast(QuantGranularity::kPerThread), "Unsupported quantization granularity"); - } - - // smem_Q smem_K smem_V smem_O - size_t smem_max = std::max(CTA_Q * HEAD_DIM * sizeof(int8_t) + CTA_K * HEAD_DIM * sizeof(int8_t) + CTA_K * HEAD_DIM * sizeof(half), CTA_Q * HEAD_DIM * sizeof(half)); - - auto kernel_func = qk_int_sv_f16_attn_kernel(QK_QUANT_GRAN), static_cast(QK_QUANT_GRAN), float, false, DTypeOut, ComputeUnit::kTensorCore, - mask_mode, RETURN_LSE, false>; - - cudaFuncSetAttribute(kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_max); - - dim3 grid(div_ceil(qo_len, CTA_Q), num_qo_heads, batch_size); - dim3 block(32, (CTA_Q / WARP_Q) * (CTA_K / WARP_K)); - - kernel_func<<>>( - query.data(), - key.data(), - reinterpret_cast(value.data()), - reinterpret_cast(output.data()), - (RETURN_LSE) ? reinterpret_cast(lse.data()) : nullptr, - reinterpret_cast(query_scale.data()), - reinterpret_cast(key_scale.data()), - nullptr, - qo_len, - kv_len, - num_kv_groups, - stride_bz_q, stride_seq_q, stride_h_q, - stride_bz_k, stride_seq_k, stride_h_k, - stride_bz_v, stride_seq_v, stride_h_v, - stride_bz_o, stride_seq_o, stride_h_o, - sm_scale); - }); - }); - }); - }); - }); - - return {lse}; -} \ No newline at end of file diff --git a/csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f16_kernel_sm80.cu b/csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f16_kernel_sm80.cu new file mode 100644 index 000000000000..cb666b9a7a35 --- /dev/null +++ b/csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f16_kernel_sm80.cu @@ -0,0 +1,1537 @@ +#include + +#include "paddle/extension.h" + +// #include "sageattn.h" +#include "sageattn_utils.cuh" + +#define PACK_SIZE_QK 16 // as if it is int8 +#define PACK_SIZE_V 16 // fp8 +#define PACK_SIZE_O 8 // fp16 + +// treat as if int8 tensor core +#define MMA_QK_M 16 +#define MMA_QK_N 16 +#define MMA_QK_K 32 + +// fp8 tensor core +#define MMA_SV_M 16 +#define MMA_SV_N 16 +#define MMA_SV_K 32 + +// qk_int_sv_f16_buffer +// when instantiating, the head dim = 64, which makes the V_STRIDE = 64, then div 16 = 4, +// which triggered the compiling fault. 
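// --- annotation (not part of the original patch) ---------------------------
// The PACK_SIZE_* values above appear to count elements per 128-bit (16-byte)
// async copy: 16 int8 for Q/K and 8 halfs for the fp16 output. Since V stays
// fp16 in this file, PACK_SIZE_V = 16 would mean 32 bytes per copy, and for
// head_dim = 64 the derived constant
//   V_smem_iters_row = V_SMEM_STRIDE / (8 * PACK_SIZE_V) = 64 / 128
// truncates to 0, which is presumably the build fault mentioned in the
// surrounding comment; PACK_SIZE_V = 8 gives 64 / 64 = 1. Note that the
// override below re-#defines the macros to different values without an
// #undef, which compilers typically flag as a macro redefinition; a
// conventional guarded form would be (illustrative sketch only, kept
// inactive here):
#if 0
#ifdef PACK_SIZE_V
#undef PACK_SIZE_V
#endif
#define PACK_SIZE_V 8   // 8 halfs = 16 bytes = one 128-bit vector

#ifdef MMA_SV_K
#undef MMA_SV_K
#endif
#define MMA_SV_K 16
#endif
// ---------------------------------------------------------------------------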
+// it is the macro: PACK_SIZE_V and MMA_SV_K's problem, so we will redefine them here: +#ifdef PACK_SIZE_V +#define PACK_SIZE_V 8 +#endif + +#ifdef MMA_SV_K +#define MMA_SV_K 16 +#endif + +// inner impl +template +__global__ void qk_int_sv_f16_attn_kernel(int8_t *__restrict__ Q, int8_t *__restrict__ K, half *__restrict__ V, DTypeOut *__restrict__ O, float *__restrict__ Lse, + float *__restrict__ Q_scale, float *__restrict__ K_scale, DTypeOut *__restrict__ V_mean, + const uint32_t qo_len, const uint32_t kv_len, const uint32_t num_kv_groups, + const uint32_t stride_bz_q, const uint32_t stride_seq_q, const uint32_t stride_h_q, + const uint32_t stride_bz_k, const uint32_t stride_seq_k, const uint32_t stride_h_k, + const uint32_t stride_bz_v, const uint32_t stride_seq_v, const uint32_t stride_h_v, + const uint32_t stride_bz_o, const uint32_t stride_seq_o, const uint32_t stride_h_o, + float sm_scale) +{ + // compile time check + static_assert(DTypeQK == SADataType::kInt8 || DTypeQK == SADataType::kInt4, "DTypeQK must be int8 or int4"); + static_assert(Q_GRAN == QuantGranularity::kPerBlock || Q_GRAN == QuantGranularity::kPerWarp || Q_GRAN == QuantGranularity::kPerThread, "Q_GRAN must be kPerBlock, kPerWarp or kPerThread"); + static_assert(K_GRAN == QuantGranularity::kPerBlock || K_GRAN == QuantGranularity::kPerWarp || K_GRAN == QuantGranularity::kPerThread, "K_GRAN must be kPerBlock, kPerWarp or kPerThread"); + static_assert(std::is_same::value || !use_inst_buffer, "use_inst_buffer only supports DTypeSVAccum as float"); + static_assert(std::is_same::value || std::is_same::value, "DTypeSVAccum must be float or half"); + static_assert(std::is_same::value || std::is_same::value, "DTypeOut must be half or nv_bfloat16"); + static_assert(head_dim % 64 == 0, "head_dim must be a multiple of 64"); + static_assert(!fuse_v_mean || std::is_same::value, "fuse_v_mean only supports half"); + static_assert(CTA_Q / CTA_K <= 2); // for efficient causal implementation + + using DTypeOut2 = typename std::conditional::value, half2, nv_bfloat162>::type; + + constexpr uint32_t num_warps_q = CTA_Q / WARP_Q; + constexpr uint32_t num_warps_k = CTA_K / WARP_K; + constexpr uint32_t num_warps = num_warps_q * num_warps_k; + constexpr uint32_t num_tiles_q = WARP_Q / MMA_QK_M; + constexpr uint32_t num_tiles_k = WARP_K / MMA_QK_N; + constexpr uint32_t num_tiles_qk_inner = (DTypeQK == SADataType::kInt8) ? (head_dim / MMA_QK_K) : (head_dim / 2 / MMA_QK_K); + constexpr uint32_t num_tiles_v = head_dim / MMA_SV_N; + + constexpr uint32_t QK_SMEM_STRIDE = (DTypeQK == SADataType::kInt8) ? 
(head_dim) : (head_dim / 2); + constexpr uint32_t O_SMEM_STRIDE = head_dim; + constexpr uint32_t V_SMEM_STRIDE = head_dim; + + extern __shared__ int8_t smem[]; + + const uint32_t lane_id = get_lane_id(); + const uint32_t warp_id = get_warp_id(); + + // maximize L2 hit rate + const uint32_t batch_id = blockIdx.z; + const uint32_t bx = blockIdx.x; + const uint32_t num_qo_heads = gridDim.y; + const uint32_t head_id = blockIdx.y; + + // transfer to base 2 instead of base e with better numerical efficiency + sm_scale *= math::log2e; + + // RS holds the fragment of S + int32_t RS[num_tiles_q][num_tiles_k][8]; + DTypeSVAccum RO[num_tiles_q][num_tiles_v][8]; + float m[num_tiles_q][2]; // max + float d[num_tiles_q][2]; // denominator + + uint32_t q_scale_idx, k_scale_idx; + + if constexpr (Q_GRAN == QuantGranularity::kPerBlock) + { + const uint32_t num_block_q = gridDim.x; + q_scale_idx = batch_id * num_qo_heads * num_block_q + head_id * num_block_q + bx; + } + else if constexpr (Q_GRAN == QuantGranularity::kPerWarp) + { + const uint32_t num_warp_block_q = gridDim.x * num_warps_q; + q_scale_idx = batch_id * num_qo_heads * num_warp_block_q + head_id * num_warp_block_q + bx * num_warps_q + get_warp_idx_q(); + } + else if constexpr (Q_GRAN == QuantGranularity::kPerThread) + { + const uint32_t num_warp_block_q = gridDim.x * num_warps_q; + q_scale_idx = batch_id * num_qo_heads * (num_warp_block_q * 8) + head_id * (num_warp_block_q * 8) + bx * (num_warps_q * 8) + get_warp_idx_q() * 8 + lane_id / 4; + } + + if constexpr (K_GRAN == QuantGranularity::kPerBlock) + { + const uint32_t num_block_k = div_ceil(kv_len, CTA_K); + k_scale_idx = batch_id * (num_qo_heads / num_kv_groups) * num_block_k + (head_id / num_kv_groups) * num_block_k; + } + else if constexpr (K_GRAN == QuantGranularity::kPerWarp) + { + const uint32_t num_warp_block_k = div_ceil(kv_len, CTA_K) * (CTA_K / WARP_K); + k_scale_idx = batch_id * (num_qo_heads / num_kv_groups) * num_warp_block_k + (head_id / num_kv_groups) * num_warp_block_k + get_warp_idx_k(); + } + else if constexpr (K_GRAN == QuantGranularity::kPerThread) + { + const uint32_t num_warp_block_k = div_ceil(kv_len, CTA_K) * (CTA_K / WARP_K); + k_scale_idx = batch_id * (num_qo_heads / num_kv_groups) * (num_warp_block_k * 4) + (head_id / num_kv_groups) * (num_warp_block_k * 4) + get_warp_idx_k() * 4 + lane_id % 4; + } + + constexpr uint32_t k_scale_advance_offset = (K_GRAN == QuantGranularity::kPerBlock) ? 1 : (K_GRAN == QuantGranularity::kPerWarp) ? (CTA_K / WARP_K) : (CTA_K / WARP_K) * 4; + + // initialize o, m, d +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fv = 0; fv < num_tiles_v; fv++) + { + if constexpr (std::is_same::value) + { +#pragma unroll + for (uint32_t k = 0; k < 8; k++) + { + RO[fq][fv][k] = 0.0f; + } + } + else if constexpr (std::is_same::value) + { +#pragma unroll + for (uint32_t k = 0; k < 4; k++) + { + ((int32_t*)RO[fq][fv])[k] = 0; + } + } + } + } +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t k = 0; k < 2; k++) + { + m[fq][k] = -5000000.0f; + d[fq][k] = 1.0f; + } + } + + constexpr uint32_t K_smem_idx_offset = CTA_Q; + constexpr uint32_t V_smem_idx_offset = CTA_Q + CTA_K; + + constexpr SwizzleMode swizzle_mode_QK = (QK_SMEM_STRIDE == 32) ? SwizzleMode::k32B : (QK_SMEM_STRIDE == 64) ? 
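// --- annotation (not part of the original patch) ---------------------------
// Worked numbers for the compile-time configuration above, using the tile
// sizes the host wrapper below instantiates (CTA_Q = 128, CTA_K = 64,
// WARP_Q = 32, WARP_K = 64) and the m16n16k32 int8 MMA shape defined at the
// top of this file:
//   num_warps_q        = 128 / 32 = 4,  num_warps_k = 64 / 64 = 1  -> 4 warps
//   num_tiles_q        = 32 / 16  = 2   (the "num_tiles_q = 2" assumption in
//                                        the LSE epilogue)
//   num_tiles_k        = 64 / 16  = 4
//   num_tiles_qk_inner = head_dim / 32  (int8 path; 2 for head_dim = 64)
//   num_tiles_v        = head_dim / 16  (4 for head_dim = 64)
// and the dynamic shared memory the wrapper requests,
//   max(CTA_Q*D + CTA_K*D + 2*CTA_K*D, 2*CTA_Q*D) bytes,
// comes to 20480 B for head_dim 64 and 40960 B for head_dim 128 (Q/K/V tiles
// during the main loop vs. the O tile that reuses the same buffer in the
// epilogue).
// ---------------------------------------------------------------------------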
SwizzleMode::k64B : SwizzleMode::k128B; + smem_t smem_Q(smem); + smem_t smem_K(smem + K_smem_idx_offset * QK_SMEM_STRIDE); + constexpr SwizzleMode swizzle_mode_V = (V_SMEM_STRIDE == 32) ? SwizzleMode::k64B : SwizzleMode::k128B; + smem_t smem_V(smem + V_smem_idx_offset * QK_SMEM_STRIDE); + constexpr SwizzleMode swizzle_mode_O = (O_SMEM_STRIDE == 32) ? SwizzleMode::k64B : SwizzleMode::k128B; + smem_t smem_O(smem); + + constexpr uint32_t global_to_shared_line_lanes_QK = (QK_SMEM_STRIDE == 32) ? 2 : (QK_SMEM_STRIDE == 64) ? 4 : 8; + constexpr uint32_t global_to_shared_copy_lines_per_warp_QK = (QK_SMEM_STRIDE == 32) ? 16 : (QK_SMEM_STRIDE == 64) ? 8 : 4; + constexpr uint32_t global_to_shared_line_lanes_V = (V_SMEM_STRIDE == 32) ? 4 : 8; + constexpr uint32_t global_to_shared_copy_lines_per_warp_V = (V_SMEM_STRIDE == 32) ? 8 : 4; + constexpr uint32_t global_to_shared_line_lanes_O = (O_SMEM_STRIDE == 32) ? 4 : 8; + constexpr uint32_t global_to_shared_copy_lines_per_warp_O = (O_SMEM_STRIDE == 32) ? 8 : 4; + + constexpr uint32_t QK_smem_iters_row = QK_SMEM_STRIDE / (global_to_shared_line_lanes_QK * PACK_SIZE_QK); + constexpr uint32_t Q_smem_iters_col = CTA_Q / (num_warps * global_to_shared_copy_lines_per_warp_QK); + constexpr uint32_t K_smem_iters_col = CTA_K / (num_warps * global_to_shared_copy_lines_per_warp_QK); + constexpr uint32_t V_smem_iters_row = V_SMEM_STRIDE / (global_to_shared_line_lanes_V * PACK_SIZE_V); + constexpr uint32_t V_smem_iters_col = CTA_K / (num_warps * global_to_shared_copy_lines_per_warp_V); + constexpr uint32_t O_smem_iters_row = O_SMEM_STRIDE / (global_to_shared_line_lanes_O * PACK_SIZE_O); + constexpr uint32_t O_smem_iters_col = CTA_Q / (num_warps * global_to_shared_copy_lines_per_warp_O); + + int8_t *Q_lane_base_ptr = Q + batch_id * stride_bz_q + head_id * stride_h_q + (bx * CTA_Q + CTA_Q / num_warps * warp_id + lane_id / global_to_shared_line_lanes_QK) * stride_seq_q + (lane_id % global_to_shared_line_lanes_QK) * PACK_SIZE_QK; + int8_t *K_lane_base_ptr = K + batch_id * stride_bz_k + (head_id / num_kv_groups) * stride_h_k + (CTA_K / num_warps * warp_id + lane_id / global_to_shared_line_lanes_QK) * stride_seq_k + (lane_id % global_to_shared_line_lanes_QK) * PACK_SIZE_QK; + half *V_lane_base_ptr = V + batch_id * stride_bz_v + (head_id / num_kv_groups) * stride_h_v + (CTA_K / num_warps * warp_id + lane_id / global_to_shared_line_lanes_V) * stride_seq_v + (lane_id % global_to_shared_line_lanes_V) * PACK_SIZE_V; + uint32_t Q_smem_offset_load = smem_Q.get_permuted_offset(warp_id * global_to_shared_copy_lines_per_warp_QK * Q_smem_iters_col + lane_id / global_to_shared_line_lanes_QK, lane_id % global_to_shared_line_lanes_QK); + uint32_t K_smem_offset_load = smem_K.get_permuted_offset(warp_id * global_to_shared_copy_lines_per_warp_QK * K_smem_iters_col + lane_id / global_to_shared_line_lanes_QK, lane_id % global_to_shared_line_lanes_QK); + uint32_t V_smem_offset_load = smem_V.get_permuted_offset(warp_id * global_to_shared_copy_lines_per_warp_V * V_smem_iters_col + lane_id / global_to_shared_line_lanes_V, lane_id % global_to_shared_line_lanes_V); + + uint32_t Q_smem_offset_mma = smem_Q.get_permuted_offset(get_warp_idx_q() * WARP_Q + lane_id % 16, lane_id / 16); + uint32_t K_smem_offset_mma = smem_K.get_permuted_offset(get_warp_idx_k() * WARP_K + lane_id % 8 + (lane_id / 16) * 8, (lane_id / 8) % 2); + uint32_t V_smem_offset_mma = smem_V.get_permuted_offset(get_warp_idx_k() * WARP_K + lane_id % 16, lane_id / 16); + + // for causal masking + uint32_t Q_idx_lane_base = bx * CTA_Q + 
get_warp_idx_q() * WARP_Q + lane_id / 4; + uint32_t K_idx_lane_base = get_warp_idx_k() * WARP_K + 2 * (lane_id % 4); + + // for loading + uint32_t Q_load_idx_lane_base = bx * CTA_Q + CTA_Q / num_warps * warp_id + lane_id / global_to_shared_line_lanes_QK; + uint32_t K_load_idx_lane_base = CTA_K / num_warps * warp_id + lane_id / global_to_shared_line_lanes_QK; + uint32_t V_load_idx_lane_base = CTA_K / num_warps * warp_id + lane_id / global_to_shared_line_lanes_V; + + const uint32_t num_iterations = div_ceil( + mask_mode == MaskMode::kCausal + ? min(kv_len, (bx + 1) * CTA_Q) + : kv_len, + CTA_K); + + // load Q with predicate + load_global_to_share( + &Q_lane_base_ptr, Q_smem_offset_load, stride_seq_q, smem_Q, Q_load_idx_lane_base, qo_len); + cp_async::commit_group(); + cp_async::wait_group<0>(); + __syncthreads(); + + // for num_tiles_qk_inner = 1, we load all Qs in register + uint32_t RQ[num_tiles_q][4]; + if constexpr (num_tiles_qk_inner == 1) + { +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { + smem_Q.ldmatrix_m8n8x4(Q_smem_offset_mma, RQ[fq]); + Q_smem_offset_mma = smem_Q.advance_offset_by_row<16>(Q_smem_offset_mma); + } + } + + // load K with predicate + load_global_to_share( + &K_lane_base_ptr, K_smem_offset_load, stride_seq_k, smem_K, K_load_idx_lane_base, kv_len); + cp_async::commit_group(); + + float q_scale = Q_scale[q_scale_idx]; + + float original_sm_scale = sm_scale; + float dequant_scale = q_scale * K_scale[k_scale_idx + 0 * k_scale_advance_offset]; + + sm_scale = original_sm_scale * dequant_scale; + + // load V with predicate + load_global_to_share( + &V_lane_base_ptr, V_smem_offset_load, stride_seq_v, smem_V, V_load_idx_lane_base, kv_len); + cp_async::commit_group(); + + K_load_idx_lane_base += CTA_K; + V_load_idx_lane_base += CTA_K; + +#pragma unroll + for (uint32_t iter = 1; iter < num_iterations - 1; iter++) + { + // ensure K is ready + cp_async::wait_group<1>(); + __syncthreads(); + + // compute QK^T + if constexpr (num_tiles_qk_inner == 1) + { + compute_int_qk( + smem_K, RS, RQ, K_smem_offset_mma); + } + else + { + compute_int_qk( + smem_Q, smem_K, RS, Q_smem_offset_mma, K_smem_offset_mma); + } + + float RS_f32[num_tiles_q][num_tiles_k][8]; + +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fk = 0; fk < num_tiles_k; fk++) + { +#pragma unroll + for (uint32_t k = 0; k < 8; k++) + { + RS_f32[fq][fk][k] = __int2float_rz(RS[fq][fk][k]); + } + } + } + + // do not apply causal mask and out of bound mask for these iterations + K_idx_lane_base += CTA_K; + + if constexpr (std::is_same::value) + { + update_mdo(RS_f32, RO, m, d, sm_scale); + } + else if constexpr (std::is_same::value) + { + update_mdo(RS_f32, RO, m, d, sm_scale); + } + + if constexpr (DenominatorAccumUnit == ComputeUnit::kCudaCore) + { + accumulate_d(RS_f32, d); + } + + uint32_t RS_f16[num_tiles_q][num_tiles_k][4]; + RS_32_to_16(RS_f32, RS_f16); + + if constexpr (DenominatorAccumUnit == ComputeUnit::kTensorCore) + { + accumulate_d(RS_f16, d); + } + + __syncthreads(); + + // load K + load_global_to_share( + &K_lane_base_ptr, K_smem_offset_load, stride_seq_k, smem_K); + cp_async::commit_group(); + + dequant_scale = q_scale * K_scale[k_scale_idx + iter * k_scale_advance_offset]; + sm_scale = original_sm_scale * dequant_scale; + + // ensure V is ready + cp_async::wait_group<1>(); + __syncthreads(); + + if constexpr (!use_inst_buffer) + { + compute_fp16_sv_permuted( + smem_V, RS_f16, RO, d, V_smem_offset_mma); + } + else + { + 
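// --- annotation (not part of the original patch) ---------------------------
// On the register-resident Q fast path above: with MMA_QK_K = 32,
// num_tiles_qk_inner is head_dim / 32 for int8 Q/K and head_dim / 64 for int4,
// so num_tiles_qk_inner == 1 (all of Q kept in RQ, no smem_Q reads in the
// loop) can only occur on the int4 path at head_dim = 64 once the
// head_dim % 64 == 0 static_assert is taken into account; the int8
// instantiations used by the qk_int8_* wrappers always take the smem_Q branch
// of compute_int_qk.
// ---------------------------------------------------------------------------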
compute_fp16_sv_permuted_inst_buf( + smem_V, RS_f16, RO, d, V_smem_offset_mma); + } + + __syncthreads(); + // load V + load_global_to_share( + &V_lane_base_ptr, V_smem_offset_load, stride_seq_v, smem_V); + cp_async::commit_group(); + K_load_idx_lane_base += CTA_K; + V_load_idx_lane_base += CTA_K; + } + + // second last iter, apply causal mask + if (num_iterations > 1) + { + // ensure K is ready + cp_async::wait_group<1>(); + __syncthreads(); + + // compute QK^T + if constexpr (num_tiles_qk_inner == 1) + { + compute_int_qk( + smem_K, RS, RQ, K_smem_offset_mma); + } + else + { + compute_int_qk( + smem_Q, smem_K, RS, Q_smem_offset_mma, K_smem_offset_mma); + } + + float RS_f32[num_tiles_q][num_tiles_k][8]; + +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fk = 0; fk < num_tiles_k; fk++) + { +#pragma unroll + for (uint32_t k = 0; k < 8; k++) + { + RS_f32[fq][fk][k] = __int2float_rz(RS[fq][fk][k]) * dequant_scale; + } + } + } + + if constexpr (mask_mode == MaskMode::kCausal) + { + apply_causal_mask(Q_idx_lane_base, K_idx_lane_base, RS_f32); + } + // apply_out_of_bound_mask(K_idx_lane_base, RS_f32, kv_len); + K_idx_lane_base += CTA_K; + + if constexpr (std::is_same::value) + { + update_mdo(RS_f32, RO, m, d, original_sm_scale); + } + else if constexpr (std::is_same::value) + { + update_mdo(RS_f32, RO, m, d, original_sm_scale); + } + + if constexpr (DenominatorAccumUnit == ComputeUnit::kCudaCore) + { + accumulate_d(RS_f32, d); + } + + uint32_t RS_f16[num_tiles_q][num_tiles_k][4]; + RS_32_to_16(RS_f32, RS_f16); + + if constexpr (DenominatorAccumUnit == ComputeUnit::kTensorCore) + { + accumulate_d(RS_f16, d); + } + + __syncthreads(); + + // load K with predicate + load_global_to_share( + &K_lane_base_ptr, K_smem_offset_load, stride_seq_k, smem_K, K_load_idx_lane_base, kv_len); + cp_async::commit_group(); + + dequant_scale = q_scale * K_scale[k_scale_idx + (num_iterations - 1) * k_scale_advance_offset]; + sm_scale = original_sm_scale * dequant_scale; + + // ensure V is ready + cp_async::wait_group<1>(); + __syncthreads(); + + if constexpr (!use_inst_buffer) + { + compute_fp16_sv_permuted( + smem_V, RS_f16, RO, d, V_smem_offset_mma); + } + else + { + compute_fp16_sv_permuted_inst_buf( + smem_V, RS_f16, RO, d, V_smem_offset_mma); + } + + __syncthreads(); + // load V with predicate + load_global_to_share( + &V_lane_base_ptr, V_smem_offset_load, stride_seq_v, smem_V, V_load_idx_lane_base, kv_len); + cp_async::commit_group(); + K_load_idx_lane_base += CTA_K; + V_load_idx_lane_base += CTA_K; + } + + // last iter, apply causal mask and out of bound mask + { + // ensure K is ready + cp_async::wait_group<1>(); + __syncthreads(); + + // compute QK^T + if constexpr (num_tiles_qk_inner == 1) + { + compute_int_qk( + smem_K, RS, RQ, K_smem_offset_mma); + } + else + { + compute_int_qk( + smem_Q, smem_K, RS, Q_smem_offset_mma, K_smem_offset_mma); + } + + float RS_f32[num_tiles_q][num_tiles_k][8]; + +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fk = 0; fk < num_tiles_k; fk++) + { +#pragma unroll + for (uint32_t k = 0; k < 8; k++) + { + RS_f32[fq][fk][k] = __int2float_rz(RS[fq][fk][k]) * dequant_scale; + } + } + } + + if constexpr (mask_mode == MaskMode::kCausal) + { + apply_causal_mask(Q_idx_lane_base, K_idx_lane_base, RS_f32); + } + // check out of bound in the last iter + apply_out_of_bound_mask(K_idx_lane_base, RS_f32, kv_len); + K_idx_lane_base += CTA_K; + + if constexpr (std::is_same::value) + { + 
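// --- annotation (not part of the original patch) ---------------------------
// Why only the last two iterations apply masks: this CTA's queries cover rows
// [bx*CTA_Q, (bx+1)*CTA_Q), and in causal mode num_iterations is capped at
// ceil(min(kv_len, (bx+1)*CTA_Q) / CTA_K), so the causal diagonal can
// intersect at most CTA_Q / CTA_K trailing KV tiles. The
// static_assert(CTA_Q / CTA_K <= 2) above bounds that at two, which is why the
// steady-state loop skips apply_causal_mask entirely, the second-to-last
// iteration applies only the causal mask, and the final iteration additionally
// applies the out-of-bound mask (its KV tile may extend past kv_len).
// ---------------------------------------------------------------------------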
update_mdo(RS_f32, RO, m, d, original_sm_scale); + } + else if constexpr (std::is_same::value) + { + update_mdo(RS_f32, RO, m, d, original_sm_scale); + } + + if constexpr (DenominatorAccumUnit == ComputeUnit::kCudaCore) + { + accumulate_d(RS_f32, d); + } + + uint32_t RS_f16[num_tiles_q][num_tiles_k][4]; + RS_32_to_16(RS_f32, RS_f16); + + if constexpr (DenominatorAccumUnit == ComputeUnit::kTensorCore) + { + accumulate_d(RS_f16, d); + } + + // ensure V is ready + cp_async::wait_group<0>(); + __syncthreads(); + + if constexpr (!use_inst_buffer) + { + compute_fp16_sv_permuted( + smem_V, RS_f16, RO, d, V_smem_offset_mma); + } + else + { + compute_fp16_sv_permuted_inst_buf( + smem_V, RS_f16, RO, d, V_smem_offset_mma); + } + + __syncthreads(); + + } + + // TODO: thread block sync mdo state for num_warps_k > 0 + + normalize_d(RO, m, d); + + // save the result + // if (get_warp_idx_k() == 0) + // { + + // convert half to bfloat16 + if constexpr (std::is_same::value && std::is_same::value) + { +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fv = 0; fv < num_tiles_v; fv++) + { + ((nv_bfloat162*)RO[fq][fv])[0] = __float22bfloat162_rn(__half22float2(((half2*)RO[fq][fv])[0])); + ((nv_bfloat162*)RO[fq][fv])[1] = __float22bfloat162_rn(__half22float2(((half2*)RO[fq][fv])[1])); + ((nv_bfloat162*)RO[fq][fv])[2] = __float22bfloat162_rn(__half22float2(((half2*)RO[fq][fv])[2])); + ((nv_bfloat162*)RO[fq][fv])[3] = __float22bfloat162_rn(__half22float2(((half2*)RO[fq][fv])[3])); + } + } + } + + // add v_mean + if constexpr (fuse_v_mean) + { + DTypeOut2 v_mean[2]; + DTypeOut *V_mean_lane_ptr = V_mean + batch_id * (num_qo_heads / num_kv_groups) * head_dim + (head_id / num_kv_groups) * head_dim + lane_id % 4 * 2; +#pragma unroll + for (uint32_t fv = 0; fv < num_tiles_v; fv++) + { + v_mean[0] = *((DTypeOut2*)(V_mean_lane_ptr + fv * 16)); + v_mean[1] = *((DTypeOut2*)(V_mean_lane_ptr + 8 + fv * 16)); +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { + ((DTypeOut2*)RO[fq][fv])[0] = __hadd2(((DTypeOut2*)RO[fq][fv])[0], v_mean[0]); + ((DTypeOut2*)RO[fq][fv])[1] = __hadd2(((DTypeOut2*)RO[fq][fv])[1], v_mean[0]); + ((DTypeOut2*)RO[fq][fv])[2] = __hadd2(((DTypeOut2*)RO[fq][fv])[2], v_mean[1]); + ((DTypeOut2*)RO[fq][fv])[3] = __hadd2(((DTypeOut2*)RO[fq][fv])[3], v_mean[1]); + } + } + } + + // save the result to shared memory + uint32_t smem_O_row_base = get_warp_idx_q() * WARP_Q + lane_id / 4; +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fv = 0; fv < num_tiles_v; fv++) + { + uint32_t offset_O = smem_O.get_permuted_offset(smem_O_row_base + fq * MMA_QK_M, fv * (MMA_SV_N / PACK_SIZE_O)); + + if constexpr (std::is_same::value) + { + // convert RO to half + uint32_t RO_f16[4]; +#pragma unroll + for (uint32_t k = 0; k < 4; k++) + { + if constexpr (std::is_same::value) + { + ((half2*)RO_f16)[k] = __float22half2_rn(((float2*)RO[fq][fv])[k]); + } + else if constexpr (std::is_same::value) + { + ((nv_bfloat162*)RO_f16)[k] = __float22bfloat162_rn(((float2*)RO[fq][fv])[k]); + } + } + + ((uint32_t*)(smem_O.base + offset_O))[lane_id % 4] = RO_f16[0]; + ((uint32_t*)(smem_O.base + offset_O + 8 * (O_SMEM_STRIDE / PACK_SIZE_O)))[lane_id % 4] = RO_f16[1]; + + // ! 
permuted, make sure you know what you are doing + ((uint32_t*)(smem_O.base + (offset_O ^ 0x1)))[lane_id % 4] = RO_f16[2]; + ((uint32_t*)(smem_O.base + (offset_O ^ 0x1) + 8 * (O_SMEM_STRIDE / PACK_SIZE_O)))[lane_id % 4] = RO_f16[3]; + } + else if constexpr (std::is_same::value) + { + ((uint32_t*)(smem_O.base + offset_O))[lane_id % 4] = ((uint32_t*)RO[fq][fv])[0]; + ((uint32_t*)(smem_O.base + offset_O + 8 * (O_SMEM_STRIDE / PACK_SIZE_O)))[lane_id % 4] = ((uint32_t*)RO[fq][fv])[1]; + + // ! permuted, make sure you know what you are doing + ((uint32_t*)(smem_O.base + (offset_O ^ 0x1)))[lane_id % 4] = ((uint32_t*)RO[fq][fv])[2]; + ((uint32_t*)(smem_O.base + (offset_O ^ 0x1) + 8 * (O_SMEM_STRIDE / PACK_SIZE_O)))[lane_id % 4] = ((uint32_t*)RO[fq][fv])[3]; + } + } + } + + // ! do we need to sync here? + __syncwarp(); + + // shared memory to global memory + DTypeOut *O_lane_ptr = O + batch_id * stride_bz_o + head_id * stride_h_o + (bx * CTA_Q + WARP_Q * get_warp_idx_q() + lane_id / global_to_shared_line_lanes_O) * stride_seq_o + lane_id % global_to_shared_line_lanes_O * PACK_SIZE_O; + uint32_t offset_O = smem_O.get_permuted_offset(get_warp_idx_q() * WARP_Q + lane_id / global_to_shared_line_lanes_O, lane_id % global_to_shared_line_lanes_O); + uint32_t O_load_idx_lane_base = bx * CTA_Q + CTA_Q / num_warps * warp_id + lane_id / global_to_shared_line_lanes_O; + +#pragma unroll + for (uint32_t i = 0; i < O_smem_iters_col; i++) + { +#pragma unroll + for (uint32_t j = 0; j < O_smem_iters_row; j++) + { + if (O_load_idx_lane_base < qo_len) + { + smem_O.store_128b(offset_O, O_lane_ptr); + } + O_lane_ptr += (global_to_shared_line_lanes_O * PACK_SIZE_O); + offset_O = smem_O.advance_offset_by_column(offset_O); + } + + offset_O = smem_O.advance_offset_by_row(offset_O - (O_smem_iters_row * global_to_shared_line_lanes_O)); + O_lane_ptr += ((global_to_shared_copy_lines_per_warp_O * stride_seq_o) - (O_smem_iters_row * global_to_shared_line_lanes_O * PACK_SIZE_O)); + O_load_idx_lane_base += global_to_shared_copy_lines_per_warp_O; + } + + if constexpr (return_lse) + { + // ! 
this only works for num_tiles_q = 2 + uint32_t lse_idx = bx * CTA_Q + lane_id / 4 + 8 * (lane_id % 4) + WARP_Q * get_warp_idx_q(); + float *lse_lane_ptr = Lse + batch_id * (qo_len * num_qo_heads) + head_id * qo_len + lse_idx; + uint32_t fq = (lane_id % 4) / 2; + uint32_t k = (lane_id % 4) % 2; + + if (lse_idx < qo_len) + { + lse_lane_ptr[0] = (math::ptx_log2(d[fq][k]) + m[fq][k]); + } + } + + // } +} + +// tensor_layout 0 for [B, N, H, D], 1 for [B, H, N, D] +std::vector qk_int8_sv_f16_accum_f32_attn_fwd(paddle::Tensor& query, + paddle::Tensor& key, + paddle::Tensor& value, + paddle::Tensor& output, + paddle::Tensor& query_scale, + paddle::Tensor& key_scale, + int tensor_layout, + int is_causal, + int qk_quant_gran, + float sm_scale, + int return_lse) +{ + CHECK_CUDA(query); + CHECK_CUDA(key); + CHECK_CUDA(value); + CHECK_CUDA(output); + CHECK_CUDA(query_scale); + CHECK_CUDA(key_scale); + + CHECK_CONTIGUOUS(query); + CHECK_CONTIGUOUS(key); + CHECK_LASTDIM_CONTIGUOUS(value); + CHECK_LASTDIM_CONTIGUOUS(output); + CHECK_CONTIGUOUS(query_scale); + CHECK_CONTIGUOUS(key_scale); + + CHECK_DTYPE(query, paddle::DataType::INT8); + CHECK_DTYPE(key, paddle::DataType::INT8); + CHECK_DTYPE(value, paddle::DataType::FLOAT16); // TODO: there maybe some problem, for bf16 type + CHECK_DTYPE(query_scale, paddle::DataType::FLOAT32); + CHECK_DTYPE(key_scale, paddle::DataType::FLOAT32); + + CHECK_DIMS(query, 4); + CHECK_DIMS(key, 4); + CHECK_DIMS(value, 4); + CHECK_DIMS(output, 4); + CHECK_DIMS(query_scale, 3); + CHECK_DIMS(key_scale, 3); + + const int head_dim = query.shape()[3]; + const int batch_size = query.shape()[0]; + + int stride_bz_q = query.strides()[0]; + int stride_bz_k = key.strides()[0]; + int stride_bz_v = value.strides()[0]; + int stride_bz_o = output.strides()[0]; + + int qo_len, kv_len, num_qo_heads, num_kv_heads; + int stride_seq_q, stride_seq_k, stride_seq_v, stride_seq_o; + int stride_h_q, stride_h_k, stride_h_v, stride_h_o; + + if (tensor_layout == 0) + { + qo_len = query.shape()[1]; + kv_len = key.shape()[1]; + num_qo_heads = query.shape()[2]; + num_kv_heads = key.shape()[2]; + CHECK_SHAPE(key, batch_size, kv_len, num_kv_heads, head_dim); + CHECK_SHAPE(value, batch_size, kv_len, num_kv_heads, head_dim); + + stride_seq_q = query.strides()[1]; + stride_seq_k = key.strides()[1]; + stride_seq_v = value.strides()[1]; + stride_seq_o = output.strides()[1]; + + stride_h_q = query.strides()[2]; + stride_h_k = key.strides()[2]; + stride_h_v = value.strides()[2]; + stride_h_o = output.strides()[2]; + } + else if (tensor_layout == 1) + { + qo_len = query.shape()[2]; + kv_len = key.shape()[2]; + num_qo_heads = query.shape()[1]; + num_kv_heads = key.shape()[1]; + CHECK_SHAPE(key, batch_size, num_kv_heads, kv_len, head_dim); + CHECK_SHAPE(value, batch_size, num_kv_heads, kv_len, head_dim); + + stride_seq_q = query.strides()[2]; + stride_seq_k = key.strides()[2]; + stride_seq_v = value.strides()[2]; + stride_seq_o = output.strides()[2]; + + stride_h_q = query.strides()[1]; + stride_h_k = key.strides()[1]; + stride_h_v = value.strides()[1]; + stride_h_o = output.strides()[1]; + } + else + { + throw std::invalid_argument("tensor_layout must be 0 or 1"); + } + + if (num_qo_heads % num_kv_heads != 0) { + std::ostringstream err_msg; + err_msg << "num_qo_heads (" << num_qo_heads << ") must be divisible by num_kv_heads (" << num_kv_heads << ")"; + throw std::invalid_argument(err_msg.str()); + } + + const int num_kv_groups = num_qo_heads / num_kv_heads; + + paddle::Tensor lse = paddle::empty({1}); + if 
(return_lse) + { + lse = paddle::empty({batch_size, num_qo_heads, qo_len}, paddle::DataType::FLOAT32); + } + + auto output_dtype = output.dtype(); // in [bfloat16 or float16] + + DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { + DISPATCH_CAUSAL(is_causal, IS_CAUSAL, { + DISPATCH_QK_QUANT_GRAN(qk_quant_gran, QK_QUANT_GRAN, { + DISPATCH_RETURN_LSE(return_lse, RETURN_LSE, { + DISPATCH_PADDLE_DTYPE_TO_CTYPE_FP16(output_dtype, DTypeOut, { + constexpr int CTA_Q = 128; + constexpr int CTA_K = 64; + constexpr int WARP_Q = 32; + constexpr int WARP_K = 64; + + constexpr MaskMode mask_mode = IS_CAUSAL ? MaskMode::kCausal : MaskMode::kNone; + + if constexpr (QK_QUANT_GRAN == static_cast(QuantGranularity::kPerWarp)) + { + CHECK_SHAPE(query_scale, batch_size, num_qo_heads, div_ceil(qo_len, CTA_Q) * (CTA_Q / WARP_Q)); + CHECK_SHAPE(key_scale, batch_size, num_kv_heads, div_ceil(kv_len, CTA_K) * (CTA_K / WARP_K)); + } + else if constexpr (QK_QUANT_GRAN == static_cast(QuantGranularity::kPerThread)) + { + CHECK_SHAPE(query_scale, batch_size, num_qo_heads, div_ceil(qo_len, CTA_Q) * (CTA_Q / WARP_Q) * 8); + CHECK_SHAPE(key_scale, batch_size, num_kv_heads, div_ceil(kv_len, CTA_K) * (CTA_K / WARP_K) * 4); + } + else + { + static_assert(QK_QUANT_GRAN == static_cast(QuantGranularity::kPerWarp) || QK_QUANT_GRAN == static_cast(QuantGranularity::kPerThread), "Unsupported quantization granularity"); + } + + // smem_Q smem_K smem_V smem_O + size_t smem_max = std::max(CTA_Q * HEAD_DIM * sizeof(int8_t) + CTA_K * HEAD_DIM * sizeof(int8_t) + CTA_K * HEAD_DIM * sizeof(half), CTA_Q * HEAD_DIM * sizeof(half)); + + auto kernel_func = qk_int_sv_f16_attn_kernel(QK_QUANT_GRAN), static_cast(QK_QUANT_GRAN), float, false, DTypeOut, ComputeUnit::kTensorCore, + mask_mode, RETURN_LSE, false>; + + cudaFuncSetAttribute(kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_max); + + dim3 grid(div_ceil(qo_len, CTA_Q), num_qo_heads, batch_size); + dim3 block(32, (CTA_Q / WARP_Q) * (CTA_K / WARP_K)); + + kernel_func<<>>( + query.data(), + key.data(), + reinterpret_cast(value.data()), // reinterpret_cast(reinterpret_cast(value.data())) + reinterpret_cast(output.data()), + (RETURN_LSE) ? 
lse.data() : nullptr, + query_scale.data(), + key_scale.data(), + nullptr, + qo_len, + kv_len, + num_kv_groups, + stride_bz_q, stride_seq_q, stride_h_q, + stride_bz_k, stride_seq_k, stride_h_k, + stride_bz_v, stride_seq_v, stride_h_v, + stride_bz_o, stride_seq_o, stride_h_o, + sm_scale); + }); + }); + }); + }); + }); + + return {lse}; +} + +std::vector> qk_int8_sv_f16_accum_f32_attn_InferShape( + std::vector query_shape, + std::vector key_shape, + std::vector value_shape, + std::vector output_shape, + std::vector query_scale_shape, + std::vector key_scale_shape) { + + // force layout: NHD: [bsz, seq_len, num_heads, head_dim] + int64_t bsz = query_shape[0]; + int64_t seq_len = query_shape[1]; + int64_t h_qo = query_shape[2]; + + std::vector return_shape = {bsz, h_qo, seq_len}; + return {return_shape}; +} + +std::vector qk_int8_sv_f16_accum_f32_attn_InferDtype( + paddle::DataType A_dtype, + paddle::DataType B_dtype, + paddle::DataType C_dtype, + paddle::DataType D_dtype, + paddle::DataType E_dtype, + paddle::DataType F_dtype) { + return {paddle::DataType::FLOAT32}; +} + +PD_BUILD_OP(qk_int8_sv_f16_accum_f32_attn) + .Inputs({"query", "key", "value", "output", "query_scale", "key_scale"}) + .Outputs({"out", "lse"}) + .SetInplaceMap({{"output", "out"}}) // Inplace + .Attrs({"tensor_layout: int", + "is_causal: int", + "qk_quant_gran: int", + "sm_scale: float", + "return_lse: int"}) + .SetKernelFn(PD_KERNEL(qk_int8_sv_f16_accum_f32_attn_fwd)) + .SetInferShapeFn(PD_INFER_SHAPE(qk_int8_sv_f16_accum_f32_attn_InferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(qk_int8_sv_f16_accum_f32_attn_InferDtype)); + +// tensor_layout 0 for [B, N, H, D], 1 for [B, H, N, D] +// impl -> see sageattn.h file +std::vector qk_int8_sv_f16_accum_f16_attn_fwd( + paddle::Tensor& query, + paddle::Tensor& key, + paddle::Tensor& value, + paddle::Tensor& output, + paddle::Tensor& query_scale, + paddle::Tensor& key_scale, + int tensor_layout, + int is_causal, + int qk_quant_gran, + float sm_scale, + int return_lse) +{ + CHECK_CUDA(query); + CHECK_CUDA(key); + CHECK_CUDA(value); + CHECK_CUDA(output); + CHECK_CUDA(query_scale); + CHECK_CUDA(key_scale); + + CHECK_CONTIGUOUS(query); + CHECK_CONTIGUOUS(key); + CHECK_LASTDIM_CONTIGUOUS(value); + CHECK_LASTDIM_CONTIGUOUS(output); + CHECK_CONTIGUOUS(query_scale); + CHECK_CONTIGUOUS(key_scale); + + CHECK_DTYPE(query, paddle::DataType::INT8); + CHECK_DTYPE(key, paddle::DataType::INT8); + CHECK_DTYPE(value, paddle::DataType::FLOAT16); + CHECK_DTYPE(query_scale, paddle::DataType::FLOAT32); + CHECK_DTYPE(key_scale, paddle::DataType::FLOAT32); + + CHECK_DIMS(query, 4); + CHECK_DIMS(key, 4); + CHECK_DIMS(value, 4); + CHECK_DIMS(output, 4); + CHECK_DIMS(query_scale, 3); + CHECK_DIMS(key_scale, 3); + + const int head_dim = query.shape()[3]; + const int batch_size = query.shape()[0]; + + int stride_bz_q = query.strides()[0]; + int stride_bz_k = key.strides()[0]; + int stride_bz_v = value.strides()[0]; + int stride_bz_o = output.strides()[0]; + + int qo_len, kv_len, num_qo_heads, num_kv_heads; + int stride_seq_q, stride_seq_k, stride_seq_v, stride_seq_o; + int stride_h_q, stride_h_k, stride_h_v, stride_h_o; + + if (tensor_layout == 0) + { + qo_len = query.shape()[1]; + kv_len = key.shape()[1]; + num_qo_heads = query.shape()[2]; + num_kv_heads = key.shape()[2]; + CHECK_SHAPE(key, batch_size, kv_len, num_kv_heads, head_dim); + CHECK_SHAPE(value, batch_size, kv_len, num_kv_heads, head_dim); + + stride_seq_q = query.strides()[1]; + stride_seq_k = key.strides()[1]; + stride_seq_v = value.strides()[1]; + 
stride_seq_o = output.strides()[1]; + + stride_h_q = query.strides()[2]; + stride_h_k = key.strides()[2]; + stride_h_v = value.strides()[2]; + stride_h_o = output.strides()[2]; + } + else if (tensor_layout == 1) + { + qo_len = query.shape()[2]; + kv_len = key.shape()[2]; + num_qo_heads = query.shape()[1]; + num_kv_heads = key.shape()[1]; + CHECK_SHAPE(key, batch_size, num_kv_heads, kv_len, head_dim); + CHECK_SHAPE(value, batch_size, num_kv_heads, kv_len, head_dim); + + stride_seq_q = query.strides()[2]; + stride_seq_k = key.strides()[2]; + stride_seq_v = value.strides()[2]; + stride_seq_o = output.strides()[2]; + + stride_h_q = query.strides()[1]; + stride_h_k = key.strides()[1]; + stride_h_v = value.strides()[1]; + stride_h_o = output.strides()[1]; + } + else + { + throw std::invalid_argument("tensor_layout must be 0 or 1"); + } + + if (num_qo_heads % num_kv_heads != 0) { + std::ostringstream err_msg; + err_msg << "num_qo_heads (" << num_qo_heads << ") must be divisible by num_kv_heads (" << num_kv_heads << ")"; + throw std::invalid_argument(err_msg.str()); + } + + const int num_kv_groups = num_qo_heads / num_kv_heads; + + paddle::Tensor lse = paddle::empty({1}); + if (return_lse) + { + lse = paddle::empty({batch_size, num_qo_heads, qo_len}, paddle::DataType::FLOAT32); + } + + auto output_dtype = output.dtype(); + + DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { + DISPATCH_CAUSAL(is_causal, IS_CAUSAL, { + DISPATCH_QK_QUANT_GRAN(qk_quant_gran, QK_QUANT_GRAN, { + DISPATCH_RETURN_LSE(return_lse, RETURN_LSE, { + DISPATCH_PADDLE_DTYPE_TO_CTYPE_FP16(output_dtype, DTypeOut, { + constexpr int CTA_Q = 128; + constexpr int CTA_K = 64; + constexpr int WARP_Q = 32; + constexpr int WARP_K = 64; + + constexpr MaskMode mask_mode = IS_CAUSAL ? MaskMode::kCausal : MaskMode::kNone; + + if constexpr (QK_QUANT_GRAN == static_cast(QuantGranularity::kPerWarp)) + { + CHECK_SHAPE(query_scale, batch_size, num_qo_heads, div_ceil(qo_len, CTA_Q) * (CTA_Q / WARP_Q)); + CHECK_SHAPE(key_scale, batch_size, num_kv_heads, div_ceil(kv_len, CTA_K) * (CTA_K / WARP_K)); + } + else if constexpr (QK_QUANT_GRAN == static_cast(QuantGranularity::kPerThread)) + { + CHECK_SHAPE(query_scale, batch_size, num_qo_heads, div_ceil(qo_len, CTA_Q) * (CTA_Q / WARP_Q) * 8); + CHECK_SHAPE(key_scale, batch_size, num_kv_heads, div_ceil(kv_len, CTA_K) * (CTA_K / WARP_K) * 4); + } + else + { + static_assert(QK_QUANT_GRAN == static_cast(QuantGranularity::kPerWarp) || QK_QUANT_GRAN == static_cast(QuantGranularity::kPerThread), "Unsupported quantization granularity"); + } + + // smem_Q smem_K smem_V smem_O + size_t smem_max = std::max(CTA_Q * HEAD_DIM * sizeof(int8_t) + CTA_K * HEAD_DIM * sizeof(int8_t) + CTA_K * HEAD_DIM * sizeof(half), CTA_Q * HEAD_DIM * sizeof(half)); + + auto kernel_func = qk_int_sv_f16_attn_kernel(QK_QUANT_GRAN), static_cast(QK_QUANT_GRAN), half, false, DTypeOut, ComputeUnit::kTensorCore, + mask_mode, RETURN_LSE, false>; + + cudaFuncSetAttribute(kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_max); + + dim3 grid(div_ceil(qo_len, CTA_Q), num_qo_heads, batch_size); + dim3 block(32, (CTA_Q / WARP_Q) * (CTA_K / WARP_K)); + + kernel_func<<>>( + query.data(), + key.data(), + reinterpret_cast(value.data()), + reinterpret_cast(output.data()), + (RETURN_LSE) ? 
reinterpret_cast(lse.data()) : nullptr, + reinterpret_cast(query_scale.data()), + reinterpret_cast(key_scale.data()), + nullptr, + qo_len, + kv_len, + num_kv_groups, + stride_bz_q, stride_seq_q, stride_h_q, + stride_bz_k, stride_seq_k, stride_h_k, + stride_bz_v, stride_seq_v, stride_h_v, + stride_bz_o, stride_seq_o, stride_h_o, + sm_scale); + }); + }); + }); + }); + }); + + return {lse}; +} + +std::vector> qk_int8_sv_f16_accum_f16_attn_InferShape( + std::vector query_shape, + std::vector key_shape, + std::vector value_shape, + std::vector output_shape, + std::vector query_scale_shape, + std::vector key_scale_shape) { + + // force layout: NHD: [bsz, seq_len, num_heads, head_dim] + int64_t bsz = query_shape[0]; + int64_t seq_len = query_shape[1]; + int64_t h_qo = query_shape[2]; + + std::vector return_shape = {bsz, h_qo, seq_len}; + return {return_shape}; +} + +std::vector qk_int8_sv_f16_accum_f16_attn_InferDtype( + paddle::DataType A_dtype, + paddle::DataType B_dtype, + paddle::DataType C_dtype, + paddle::DataType D_dtype, + paddle::DataType E_dtype, + paddle::DataType F_dtype) { + return {paddle::DataType::FLOAT32}; +} + +PD_BUILD_OP(qk_int8_sv_f16_accum_f16_attn) + .Inputs({"query", "key", "value", "output", "query_scale", "key_scale"}) + .Outputs({"out", "lse"}) + .SetInplaceMap({{"output", "out"}}) // Inplace + .Attrs({"tensor_layout: int", + "is_causal: int", + "qk_quant_gran: int", + "sm_scale: float", + "return_lse: int"}) + .SetKernelFn(PD_KERNEL(qk_int8_sv_f16_accum_f16_attn_fwd)) + .SetInferShapeFn(PD_INFER_SHAPE(qk_int8_sv_f16_accum_f16_attn_InferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(qk_int8_sv_f16_accum_f16_attn_InferDtype)); + + +std::vector qk_int8_sv_f16_accum_f16_attn_inst_buf_fwd(paddle::Tensor& query, + paddle::Tensor& key, + paddle::Tensor& value, + paddle::Tensor& output, + paddle::Tensor& query_scale, + paddle::Tensor& key_scale, + int tensor_layout, + int is_causal, + int qk_quant_gran, + float sm_scale, + int return_lse) +{ + CHECK_CUDA(query); + CHECK_CUDA(key); + CHECK_CUDA(value); + CHECK_CUDA(output); + CHECK_CUDA(query_scale); + CHECK_CUDA(key_scale); + + CHECK_CONTIGUOUS(query); + CHECK_CONTIGUOUS(key); + CHECK_LASTDIM_CONTIGUOUS(value); + CHECK_LASTDIM_CONTIGUOUS(output); + CHECK_CONTIGUOUS(query_scale); + CHECK_CONTIGUOUS(key_scale); + + CHECK_DTYPE(query, paddle::DataType::INT8); + CHECK_DTYPE(key, paddle::DataType::INT8); + CHECK_DTYPE(value, paddle::DataType::FLOAT16); + CHECK_DTYPE(query_scale, paddle::DataType::FLOAT32); + CHECK_DTYPE(key_scale, paddle::DataType::FLOAT32); + + CHECK_DIMS(query, 4); + CHECK_DIMS(key, 4); + CHECK_DIMS(value, 4); + CHECK_DIMS(output, 4); + CHECK_DIMS(query_scale, 3); + CHECK_DIMS(key_scale, 3); + + const int head_dim = query.shape()[3]; + const int batch_size = query.shape()[0]; + + int stride_bz_q = query.strides()[0]; + int stride_bz_k = key.strides()[0]; + int stride_bz_v = value.strides()[0]; + int stride_bz_o = output.strides()[0]; + + int qo_len, kv_len, num_qo_heads, num_kv_heads; + int stride_seq_q, stride_seq_k, stride_seq_v, stride_seq_o; + int stride_h_q, stride_h_k, stride_h_v, stride_h_o; + + if (tensor_layout == 0) + { + qo_len = query.shape()[1]; + kv_len = key.shape()[1]; + num_qo_heads = query.shape()[2]; + num_kv_heads = key.shape()[2]; + CHECK_SHAPE(key, batch_size, kv_len, num_kv_heads, head_dim); + CHECK_SHAPE(value, batch_size, kv_len, num_kv_heads, head_dim); + + stride_seq_q = query.strides()[1]; + stride_seq_k = key.strides()[1]; + stride_seq_v = value.strides()[1]; + stride_seq_o = 
output.strides()[1]; + + stride_h_q = query.strides()[2]; + stride_h_k = key.strides()[2]; + stride_h_v = value.strides()[2]; + stride_h_o = output.strides()[2]; + } + else if (tensor_layout == 1) + { + qo_len = query.shape()[2]; + kv_len = key.shape()[2]; + num_qo_heads = query.shape()[1]; + num_kv_heads = key.shape()[1]; + CHECK_SHAPE(key, batch_size, num_kv_heads, kv_len, head_dim); + CHECK_SHAPE(value, batch_size, num_kv_heads, kv_len, head_dim); + + stride_seq_q = query.strides()[2]; + stride_seq_k = key.strides()[2]; + stride_seq_v = value.strides()[2]; + stride_seq_o = output.strides()[2]; + + stride_h_q = query.strides()[1]; + stride_h_k = key.strides()[1]; + stride_h_v = value.strides()[1]; + stride_h_o = output.strides()[1]; + } + else + { + throw std::invalid_argument("tensor_layout must be 0 or 1"); + } + + if (num_qo_heads % num_kv_heads != 0) { + std::ostringstream err_msg; + err_msg << "num_qo_heads (" << num_qo_heads << ") must be divisible by num_kv_heads (" << num_kv_heads << ")"; + throw std::invalid_argument(err_msg.str()); + } + + paddle::Tensor lse = paddle::empty({0}, paddle::DataType::FLOAT32); + if (return_lse) + { + lse = paddle::empty({batch_size, num_qo_heads, qo_len}, paddle::DataType::FLOAT32); + } + + const int num_kv_groups = num_qo_heads / num_kv_heads; + + auto output_dtype = output.dtype(); + + DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { + DISPATCH_CAUSAL(is_causal, IS_CAUSAL, { + DISPATCH_QK_QUANT_GRAN(qk_quant_gran, QK_QUANT_GRAN, { + DISPATCH_RETURN_LSE(return_lse, RETURN_LSE, { + DISPATCH_PADDLE_DTYPE_TO_CTYPE_FP16(output_dtype, DTypeOut, { + + constexpr int CTA_Q = 128; + constexpr int CTA_K = 64; + constexpr int WARP_Q = (HEAD_DIM == 64) ? 32 : 16; + constexpr int WARP_K = 64; + + constexpr MaskMode mask_mode = IS_CAUSAL ? MaskMode::kCausal : MaskMode::kNone; + + if constexpr (QK_QUANT_GRAN == static_cast(QuantGranularity::kPerWarp)) + { + CHECK_SHAPE(query_scale, batch_size, num_qo_heads, div_ceil(qo_len, CTA_Q) * (CTA_Q / WARP_Q)); + CHECK_SHAPE(key_scale, batch_size, num_kv_heads, div_ceil(kv_len, CTA_K) * (CTA_K / WARP_K)); + } + else if constexpr (QK_QUANT_GRAN == static_cast(QuantGranularity::kPerThread)) + { + CHECK_SHAPE(query_scale, batch_size, num_qo_heads, div_ceil(qo_len, CTA_Q) * (CTA_Q / WARP_Q) * 8); + CHECK_SHAPE(key_scale, batch_size, num_kv_heads, div_ceil(kv_len, CTA_K) * (CTA_K / WARP_K) * 4); + } + else + { + static_assert(QK_QUANT_GRAN == static_cast(QuantGranularity::kPerWarp) || QK_QUANT_GRAN == static_cast(QuantGranularity::kPerThread), "Unsupported quantization granularity"); + } + + // smem_Q smem_K smem_V smem_O + size_t smem_max = std::max(CTA_Q * HEAD_DIM * sizeof(int8_t) + CTA_K * HEAD_DIM * sizeof(int8_t) + CTA_K * HEAD_DIM * sizeof(half), CTA_Q * HEAD_DIM * sizeof(half)); + + auto kernel_func = qk_int_sv_f16_attn_kernel(QK_QUANT_GRAN), static_cast(QK_QUANT_GRAN), float, true, DTypeOut, ComputeUnit::kTensorCore, + mask_mode, RETURN_LSE, false>; + + cudaFuncSetAttribute(kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_max); + + dim3 grid(div_ceil(qo_len, CTA_Q), num_qo_heads, batch_size); + dim3 block(32, (CTA_Q / WARP_Q) * (CTA_K / WARP_K)); + + kernel_func<<>>( + query.data(), + key.data(), + reinterpret_cast(value.data()), + reinterpret_cast(output.data()), + (RETURN_LSE) ? 
reinterpret_cast(lse.data()) : nullptr, + reinterpret_cast(query_scale.data()), + reinterpret_cast(key_scale.data()), + nullptr, + qo_len, + kv_len, + num_kv_groups, + stride_bz_q, stride_seq_q, stride_h_q, + stride_bz_k, stride_seq_k, stride_h_k, + stride_bz_v, stride_seq_v, stride_h_v, + stride_bz_o, stride_seq_o, stride_h_o, + sm_scale); + }); + }); + }); + }); + }); + + return {lse}; +} + +std::vector> qk_int8_sv_f16_accum_f16_attn_inst_buf_InferShape( + std::vector query_shape, + std::vector key_shape, + std::vector value_shape, + std::vector output_shape, + std::vector query_scale_shape, + std::vector key_scale_shape) { + + // force layout: NHD: [bsz, seq_len, num_heads, head_dim] + int64_t bsz = query_shape[0]; + int64_t seq_len = query_shape[1]; + int64_t h_qo = query_shape[2]; + + std::vector return_shape = {bsz, h_qo, seq_len}; + return {return_shape}; +} + +std::vector qk_int8_sv_f16_accum_f16_attn_inst_buf_InferDtype( + paddle::DataType A_dtype, + paddle::DataType B_dtype, + paddle::DataType C_dtype, + paddle::DataType D_dtype, + paddle::DataType E_dtype, + paddle::DataType F_dtype) { + return {paddle::DataType::FLOAT32}; +} + +PD_BUILD_OP(qk_int8_sv_f16_accum_f16_attn_inst_buf) + .Inputs({"query", "key", "value", "output", "query_scale", "key_scale"}) + .Outputs({"out", "lse"}) + .SetInplaceMap({{"output", "out"}}) // Inplace + .Attrs({"tensor_layout: int", + "is_causal: int", + "qk_quant_gran: int", + "sm_scale: float", + "return_lse: int"}) + .SetKernelFn(PD_KERNEL(qk_int8_sv_f16_accum_f16_attn_inst_buf_fwd)) + .SetInferShapeFn(PD_INFER_SHAPE(qk_int8_sv_f16_accum_f16_attn_inst_buf_InferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(qk_int8_sv_f16_accum_f16_attn_inst_buf_InferDtype)); + +std::vector qk_int8_sv_f16_accum_f16_fuse_v_mean_attn_fwd(paddle::Tensor& query, + paddle::Tensor& key, + paddle::Tensor& value, + paddle::Tensor& output, + paddle::Tensor& query_scale, + paddle::Tensor& key_scale, + paddle::Tensor& value_mean, + int tensor_layout, + int is_causal, + int qk_quant_gran, + float sm_scale, + int return_lse) +{ + CHECK_CUDA(query); + CHECK_CUDA(key); + CHECK_CUDA(value); + CHECK_CUDA(output); + CHECK_CUDA(query_scale); + CHECK_CUDA(key_scale); + CHECK_CUDA(value_mean); + + CHECK_CONTIGUOUS(query); + CHECK_CONTIGUOUS(key); + CHECK_LASTDIM_CONTIGUOUS(value); + CHECK_LASTDIM_CONTIGUOUS(output); + CHECK_CONTIGUOUS(query_scale); + CHECK_CONTIGUOUS(key_scale); + CHECK_CONTIGUOUS(value_mean); + + CHECK_DTYPE(query, paddle::DataType::INT8); + CHECK_DTYPE(key, paddle::DataType::INT8); + CHECK_DTYPE(value, paddle::DataType::FLOAT16); + CHECK_DTYPE(query_scale, paddle::DataType::FLOAT32); + CHECK_DTYPE(key_scale, paddle::DataType::FLOAT32); + + CHECK_DIMS(query, 4); + CHECK_DIMS(key, 4); + CHECK_DIMS(value, 4); + CHECK_DIMS(output, 4); + CHECK_DIMS(query_scale, 3); + CHECK_DIMS(key_scale, 3); + CHECK_DIMS(value_mean, 3); + + const int head_dim = query.shape()[3]; + const int batch_size = query.shape()[0]; + + int stride_bz_q = query.strides()[0]; + int stride_bz_k = key.strides()[0]; + int stride_bz_v = value.strides()[0]; + int stride_bz_o = output.strides()[0]; + + int qo_len, kv_len, num_qo_heads, num_kv_heads; + int stride_seq_q, stride_seq_k, stride_seq_v, stride_seq_o; + int stride_h_q, stride_h_k, stride_h_v, stride_h_o; + + if (tensor_layout == 0) + { + qo_len = query.shape()[1]; + kv_len = key.shape()[1]; + num_qo_heads = query.shape()[2]; + num_kv_heads = key.shape()[2]; + CHECK_SHAPE(key, batch_size, kv_len, num_kv_heads, head_dim); + CHECK_SHAPE(value, batch_size, 
kv_len, num_kv_heads, head_dim); + + stride_seq_q = query.strides()[1]; + stride_seq_k = key.strides()[1]; + stride_seq_v = value.strides()[1]; + stride_seq_o = output.strides()[1]; + + stride_h_q = query.strides()[2]; + stride_h_k = key.strides()[2]; + stride_h_v = value.strides()[2]; + stride_h_o = output.strides()[2]; + } + else if (tensor_layout == 1) + { + qo_len = query.shape()[2]; + kv_len = key.shape()[2]; + num_qo_heads = query.shape()[1]; + num_kv_heads = key.shape()[1]; + CHECK_SHAPE(key, batch_size, num_kv_heads, kv_len, head_dim); + CHECK_SHAPE(value, batch_size, num_kv_heads, kv_len, head_dim); + + stride_seq_q = query.strides()[2]; + stride_seq_k = key.strides()[2]; + stride_seq_v = value.strides()[2]; + stride_seq_o = output.strides()[2]; + + stride_h_q = query.strides()[1]; + stride_h_k = key.strides()[1]; + stride_h_v = value.strides()[1]; + stride_h_o = output.strides()[1]; + } + else + { + throw std::invalid_argument("tensor_layout must be 0 or 1"); + } + + if (num_qo_heads % num_kv_heads != 0) { + std::ostringstream err_msg; + err_msg << "num_qo_heads (" << num_qo_heads << ") must be divisible by num_kv_heads (" << num_kv_heads << ")"; + throw std::invalid_argument(err_msg.str()); + } + + paddle::Tensor lse = paddle::empty({0}, paddle::DataType::FLOAT32); + if (return_lse) + { + lse = paddle::empty({batch_size, num_qo_heads, qo_len}, paddle::DataType::FLOAT32); + } + + const int num_kv_groups = num_qo_heads / num_kv_heads; + + auto output_dtype = output.dtype(); + auto value_mean_dtype = value_mean.dtype(); + + PD_CHECK(value_mean_dtype == output_dtype, "value_mean and output must have the same dtype"); + + DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { + DISPATCH_CAUSAL(is_causal, IS_CAUSAL, { + DISPATCH_QK_QUANT_GRAN(qk_quant_gran, QK_QUANT_GRAN, { + DISPATCH_RETURN_LSE(return_lse, RETURN_LSE, { + DISPATCH_PADDLE_DTYPE_TO_CTYPE_FP16(output_dtype, DTypeOut, { + + constexpr int CTA_Q = 128; + constexpr int CTA_K = 64; + constexpr int WARP_Q = 32; + constexpr int WARP_K = 64; + + constexpr MaskMode mask_mode = IS_CAUSAL ? 
MaskMode::kCausal : MaskMode::kNone; + + if constexpr (QK_QUANT_GRAN == static_cast(QuantGranularity::kPerWarp)) + { + CHECK_SHAPE(query_scale, batch_size, num_qo_heads, div_ceil(qo_len, CTA_Q) * (CTA_Q / WARP_Q)); + CHECK_SHAPE(key_scale, batch_size, num_kv_heads, div_ceil(kv_len, CTA_K) * (CTA_K / WARP_K)); + } + else if constexpr (QK_QUANT_GRAN == static_cast(QuantGranularity::kPerThread)) + { + CHECK_SHAPE(query_scale, batch_size, num_qo_heads, div_ceil(qo_len, CTA_Q) * (CTA_Q / WARP_Q) * 8); + CHECK_SHAPE(key_scale, batch_size, num_kv_heads, div_ceil(kv_len, CTA_K) * (CTA_K / WARP_K) * 4); + } + else + { + static_assert(QK_QUANT_GRAN == static_cast(QuantGranularity::kPerWarp) || QK_QUANT_GRAN == static_cast(QuantGranularity::kPerThread), "Unsupported quantization granularity"); + } + + CHECK_SHAPE(value_mean, batch_size, num_kv_heads, head_dim); + + // smem_Q smem_K smem_V smem_O + size_t smem_max = std::max(CTA_Q * HEAD_DIM * sizeof(int8_t) + CTA_K * HEAD_DIM * sizeof(int8_t) + CTA_K * HEAD_DIM * sizeof(half), CTA_Q * HEAD_DIM * sizeof(half)); + + auto kernel_func = qk_int_sv_f16_attn_kernel(QK_QUANT_GRAN), static_cast(QK_QUANT_GRAN), half, false, DTypeOut, ComputeUnit::kTensorCore, + mask_mode, RETURN_LSE, true>; + + cudaFuncSetAttribute(kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_max); + + dim3 grid(div_ceil(qo_len, CTA_Q), num_qo_heads, batch_size); + dim3 block(32, (CTA_Q / WARP_Q) * (CTA_K / WARP_K)); + + kernel_func<<>>( + query.data(), + key.data(), + reinterpret_cast(value.data()), + reinterpret_cast(output.data()), + (RETURN_LSE) ? reinterpret_cast(lse.data()) : nullptr, + reinterpret_cast(query_scale.data()), + reinterpret_cast(key_scale.data()), + reinterpret_cast(value_mean.data()), + qo_len, + kv_len, + num_kv_groups, + stride_bz_q, stride_seq_q, stride_h_q, + stride_bz_k, stride_seq_k, stride_h_k, + stride_bz_v, stride_seq_v, stride_h_v, + stride_bz_o, stride_seq_o, stride_h_o, + sm_scale); + }); + }); + }); + }); + }); + + return {lse}; +} + +std::vector> qk_int8_sv_f16_accum_f16_fuse_v_mean_attn_InferShape( + std::vector query_shape, + std::vector key_shape, + std::vector value_shape, + std::vector output_shape, + std::vector query_scale_shape, + std::vector key_scale_shape, + std::vector value_mean_shape) { + + // force layout: NHD: [bsz, seq_len, num_heads, head_dim] + int64_t bsz = query_shape[0]; + int64_t seq_len = query_shape[1]; + int64_t h_qo = query_shape[2]; + + std::vector return_shape = {bsz, h_qo, seq_len}; + return {return_shape}; +} + +std::vector qk_int8_sv_f16_accum_f16_fuse_v_mean_attn_InferDtype( + paddle::DataType A_dtype, + paddle::DataType B_dtype, + paddle::DataType C_dtype, + paddle::DataType D_dtype, + paddle::DataType E_dtype, + paddle::DataType F_dtype, + paddle::DataType G_dtype) { + return {paddle::DataType::FLOAT32}; +} + +PD_BUILD_OP(qk_int8_sv_f16_accum_f16_fuse_v_mean_attn) + .Inputs({"query", "key", "value", "output", "query_scale", "key_scale", "value_mean"}) + .Outputs({"out", "lse"}) + .SetInplaceMap({{"output", "out"}}) // Inplace + .Attrs({"tensor_layout: int", + "is_causal: int", + "qk_quant_gran: int", + "sm_scale: float", + "return_lse: int"}) + .SetKernelFn(PD_KERNEL(qk_int8_sv_f16_accum_f16_fuse_v_mean_attn_fwd)) + .SetInferShapeFn(PD_INFER_SHAPE(qk_int8_sv_f16_accum_f16_fuse_v_mean_attn_InferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(qk_int8_sv_f16_accum_f16_fuse_v_mean_attn_InferDtype)); \ No newline at end of file diff --git a/csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_dsk_kernel_sm90.cu 
b/csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_dsk_kernel_sm90.cu deleted file mode 100644 index dd843da22d0b..000000000000 --- a/csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_dsk_kernel_sm90.cu +++ /dev/null @@ -1,952 +0,0 @@ -#include -#include -#include "paddle/extension.h" - -#include "sageattn_utils.cuh" - -template -CUtensorMap create_tensor_map_4D(T* gmem_ptr, int d1, int d2, int d3, int d4, int stride1, int stride2, int stride3) { - constexpr int smem_stride = BlockMinorSize * sizeof(T); - static_assert(sizeof(T) == 2 || sizeof(T) == 1); - static_assert(smem_stride == 32 || smem_stride == 64 || smem_stride == 128); - - CUtensorMap tma_map; - void* gmem_address = (void*)gmem_ptr; - uint64_t gmem_prob_shape[5] = {(uint64_t)d4, (uint64_t)d3, (uint64_t)d2, (uint64_t)d1, 1}; - uint64_t gmem_prob_stride[5] = {(uint64_t)stride3 * sizeof(T), (uint64_t)stride2 * sizeof(T), (uint64_t)stride1 * sizeof(T), 0, 0}; - uint32_t smem_box_shape[5] = {uint32_t(BlockMinorSize), uint32_t(BlockMajorSize), 1, 1, 1}; - uint32_t smem_box_stride[5] = {1, 1, 1, 1, 1}; - - CUresult result = cuTensorMapEncodeTiled( - &tma_map, (sizeof(T) == 2) ? CU_TENSOR_MAP_DATA_TYPE_BFLOAT16 : CU_TENSOR_MAP_DATA_TYPE_UINT8, 4, gmem_address, gmem_prob_shape, - gmem_prob_stride, smem_box_shape, smem_box_stride, CU_TENSOR_MAP_INTERLEAVE_NONE, - (swizzle == false) ? CU_TENSOR_MAP_SWIZZLE_NONE : (smem_stride == 128) ? CU_TENSOR_MAP_SWIZZLE_128B : (smem_stride == 64) ? CU_TENSOR_MAP_SWIZZLE_64B : CU_TENSOR_MAP_SWIZZLE_32B, - promotion_mode, CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE); - - assert(result == CUDA_SUCCESS); - - return tma_map; -} - -__device__ __forceinline__ void init_barrier(uint64_t* bar, int thread_count) { - uint32_t bar_ptr = static_cast(__cvta_generic_to_shared(bar)); - asm volatile ( - "mbarrier.init.shared::cta.b64 [%0], %1;\n" - :: "r"(bar_ptr), "r"(thread_count) - ); -} - -template -__device__ __forceinline__ void expect_bytes(uint64_t* bar) { - uint32_t bar_ptr = static_cast(__cvta_generic_to_shared(bar)); - asm volatile ("mbarrier.arrive.expect_tx.shared::cta.b64 _, [%0], %1;\n" - :: "r"(bar_ptr), "n"(bytes)); -} - -template -__device__ __forceinline__ void load_async_4D(T *dst, void const* const src_tma_map, uint64_t* bar, int s0, int s1, int s2, int s3) { - uint64_t tma_ptr = reinterpret_cast(src_tma_map); - uint32_t mbar_ptr = static_cast(__cvta_generic_to_shared(bar)); - uint32_t dst_ptr = static_cast(__cvta_generic_to_shared(dst)); - - asm volatile ( - "cp.async.bulk.tensor.4d.shared::cluster.global.tile.mbarrier::complete_tx::bytes" - " [%0], [%1, {%3, %4, %5, %6}], [%2];" - : - : "r"(dst_ptr), "l"(tma_ptr), "r"(mbar_ptr), - "r"(s0), "r"(s1), "r"(s2), "r"(s3) - : "memory" - ); -} - -template -__device__ __forceinline__ void store_async_4D(void const* dst_tma_map, T *src, int global_token_idx, int global_head_idx, int global_batch_idx) { - uint64_t tma_ptr = reinterpret_cast(dst_tma_map); - uint32_t src_ptr = static_cast(__cvta_generic_to_shared(src)); - - asm volatile ( - "cp.async.bulk.tensor.4d.global.shared::cta.tile.bulk_group" - " [%0, {%2, %3, %4, %5}], [%1];" - : - : "l"(tma_ptr), "r"(src_ptr), - "n"(0), "r"(global_token_idx), "r"(global_head_idx), "r"(global_batch_idx) - : "memory" - ); -} - -__device__ __forceinline__ void wait(uint64_t* bar, int kPhaseBit) { - uint32_t mbar_ptr = static_cast(__cvta_generic_to_shared(bar)); - asm volatile ( - "{\n" - ".reg .pred P1;\n" - "LAB_WAIT:\n" - "mbarrier.try_wait.parity.shared::cta.b64 P1, [%0], %1;\n" - "@P1 bra.uni DONE;\n" - "bra.uni LAB_WAIT;\n" 
- "DONE:\n" - "}\n" - :: "r"(mbar_ptr), - "r"(kPhaseBit) - ); -} - -template -__device__ __forceinline__ void arrive(uint64_t* bar) { - uint32_t mbar_ptr = static_cast(__cvta_generic_to_shared(bar)); - asm volatile ( - "mbarrier.arrive.release.cta.shared::cta.b64 _, [%0], %1;\n" - : - : "r"(mbar_ptr), "n"(count) - : "memory" - ); -} - -// -// ======= kernel impl ======= -// - -template -__global__ void qk_int8_sv_f8_attn_dsk_kernel(const __grid_constant__ CUtensorMap tensorMapQ, - const __grid_constant__ CUtensorMap tensorMapK, - const __grid_constant__ CUtensorMap tensorMapQ_pe, - const __grid_constant__ CUtensorMap tensorMapK_pe, - const __grid_constant__ CUtensorMap tensorMapV, - float *__restrict__ Q_scale, float *__restrict__ K_scale, float *__restrict__ V_scale, - DTypeOut* O, uint32_t stride_bz_o, uint32_t stride_h_o, uint32_t stride_seq_o, - const uint32_t qo_len, const uint32_t kv_len, const uint32_t num_kv_groups, - float sm_scale) -{ - static_assert(NUM_THREADS == 128); - static_assert(CTA_Q <= CTA_K); - - const uint32_t warp_idx = (threadIdx.x % 128) / 32; - const uint32_t lane_id = threadIdx.x % 32; - - constexpr uint32_t num_tiles_q = CTA_Q / 64; - constexpr uint32_t num_tiles_k = CTA_K / 16; - constexpr uint32_t num_tiles_qk_inner = head_dim / 32; - constexpr uint32_t num_tiles_qk_pe_inner = head_dim_pe / 32; - constexpr uint32_t num_tiles_v = head_dim / 16; - constexpr uint32_t num_tiles_pv_inner = CTA_K / 32; - - const uint32_t batch_id = blockIdx.z; - const uint32_t bx = blockIdx.x; - const uint32_t head_id = blockIdx.y; - const uint32_t num_qo_heads = gridDim.y; - const uint32_t kv_head_id = head_id / num_kv_groups; - - sm_scale *= math::log2e; - - extern __shared__ __align__(128) int8_t smem_[]; - - /* // original: - * int8_t *sQ = (int8_t*)smem_; - * int8_t *sK = (int8_t*)(smem_ + CTA_Q * head_dim * sizeof(int8_t)); - * int8_t *sV = (int8_t*)(smem_ + CTA_Q * head_dim * sizeof(int8_t) + CTA_K * head_dim * sizeof(int8_t)); - * half *sO = (half*)smem_; - */ - - int8_t *sQ = (int8_t*)smem_; // 0 - int8_t *sQ_pe = (int8_t*)(smem_ + CTA_Q * (head_dim) * sizeof(int8_t)); // 0 + head_dim - - int8_t *sK = (int8_t*)(smem_ + CTA_Q * (head_dim + head_dim_pe) * sizeof(int8_t)); // 0 + head_dim + pe - int8_t *sK_pe = (int8_t*)(smem_ + CTA_Q * (head_dim + head_dim_pe) * sizeof(int8_t) + CTA_K * (head_dim) * sizeof(int8_t)); // 0 + head_dim + pe + head_dim - int8_t *sV = (int8_t*)(smem_ + CTA_Q * (head_dim + head_dim_pe) * sizeof(int8_t) + CTA_K * (head_dim + head_dim_pe) * sizeof(int8_t)); - half *sO = (half*)smem_; - - int32_t RS[num_tiles_q][num_tiles_k][8]; - int32_t RS_pe[num_tiles_q][num_tiles_k][8]; - float RO[num_tiles_q][num_tiles_v][8]; - float m[num_tiles_q][2]; - float d[num_tiles_q][2]; - - uint32_t q_scale_idx, k_scale_idx; - - // scale shape: (b, h_qo, (qo_len + BLKQ - 1) // BLKQ) - if constexpr (Q_GRAN == QuantGranularity::kPerBlock) - { - const uint32_t num_block_q = gridDim.x; - q_scale_idx = batch_id * num_qo_heads * num_block_q + head_id * num_block_q + bx; - } - else if constexpr (Q_GRAN == QuantGranularity::kPerWarp) - { - const uint32_t num_warp_block_q = gridDim.x * 4; - q_scale_idx = batch_id * num_qo_heads * num_warp_block_q + head_id * num_warp_block_q + bx * 4 + warp_idx; - } - else if constexpr (Q_GRAN == QuantGranularity::kPerThread) - { - const uint32_t num_warp_block_q = gridDim.x * 4; - q_scale_idx = batch_id * num_qo_heads * (num_warp_block_q * 8) + head_id * (num_warp_block_q * 8) + bx * (4 * 8) + warp_idx * 8 + lane_id / 4; - } - - if constexpr 
(K_GRAN == QuantGranularity::kPerBlock || K_GRAN == QuantGranularity::kPerWarp) - { - const uint32_t num_block_k = div_ceil(kv_len, CTA_K); - k_scale_idx = batch_id * (num_qo_heads / num_kv_groups) * num_block_k + (head_id / num_kv_groups) * num_block_k; - } - else if constexpr (K_GRAN == QuantGranularity::kPerThread) - { - const uint32_t num_block_k = div_ceil(kv_len, CTA_K); - k_scale_idx = batch_id * (num_qo_heads / num_kv_groups) * (num_block_k * 4) + (head_id / num_kv_groups) * (num_block_k * 4) + lane_id % 4; - } - - constexpr uint32_t k_scale_advance_offset = (K_GRAN == QuantGranularity::kPerBlock || K_GRAN == QuantGranularity::kPerWarp) ? 1 : 4; - - uint32_t Q_idx_lane_base = bx * CTA_Q + warp_idx * 16 + lane_id / 4; - -#pragma unroll - for (uint32_t fq = 0; fq < num_tiles_q; fq++) - { - m[fq][0] = -5000000.0f; - m[fq][1] = -5000000.0f; - d[fq][0] = 1.0f; - d[fq][1] = 1.0f; - } - -#pragma unroll - for (uint32_t fq = 0; fq < num_tiles_q; fq++) - { -#pragma unroll - for (uint32_t fv = 0; fv < num_tiles_v; fv++) - { -#pragma unroll - for (uint32_t k = 0; k < 8; k++) - { - RO[fq][fv][k] = 0.0f; - } - } - } - - __shared__ __align__(8) uint64_t barrier_Q; - __shared__ __align__(8) uint64_t barrier_K; - __shared__ __align__(8) uint64_t barrier_Q_pe; - __shared__ __align__(8) uint64_t barrier_K_pe; - __shared__ __align__(8) uint64_t barrier_V; - - if (threadIdx.x == 0) { - init_barrier(&barrier_Q, 1); - init_barrier(&barrier_K, 1); - init_barrier(&barrier_Q_pe, 1); - init_barrier(&barrier_K_pe, 1); - init_barrier(&barrier_V, 1); - } - - __syncthreads(); - - // load Q, K, V - if (threadIdx.x == 0) - { - expect_bytes<(CTA_Q * (head_dim)) * sizeof(int8_t)>(&barrier_Q); - expect_bytes<(CTA_K * (head_dim)) * sizeof(int8_t)>(&barrier_K); - expect_bytes<(CTA_Q * (head_dim_pe)) * sizeof(int8_t)>(&barrier_Q_pe); - expect_bytes<(CTA_K * (head_dim_pe)) * sizeof(int8_t)>(&barrier_K_pe); - expect_bytes<(CTA_K * (head_dim)) * sizeof(int8_t)>(&barrier_V); - - load_async_4D(sQ, &tensorMapQ, &barrier_Q, 0, bx * CTA_Q, head_id, batch_id); - load_async_4D(sQ_pe, &tensorMapQ_pe, &barrier_Q_pe, 0, bx * CTA_Q, head_id, batch_id); - load_async_4D(sK, &tensorMapK, &barrier_K, 0, 0, kv_head_id, batch_id); - load_async_4D(sK_pe, &tensorMapK_pe, &barrier_K_pe, 0, 0, kv_head_id, batch_id); - load_async_4D(sV, &tensorMapV, &barrier_V, 0, 0, kv_head_id, batch_id); - } - - float q_scale = Q_scale[q_scale_idx]; - float original_sm_scale = sm_scale; - - // wait for Q - wait(&barrier_Q, 0); - wait(&barrier_Q_pe, 0); - - const uint32_t num_iterations = div_ceil( - mask_mode == MaskMode::kCausal - ? 
min(kv_len, (bx + 1) * CTA_Q) - : kv_len, - CTA_K); - - int p = 1; - for (uint32_t iter = 1; iter < num_iterations; iter++) - { - p ^= 1; - - float dequant_scale = q_scale * K_scale[k_scale_idx + (iter - 1) * k_scale_advance_offset]; - sm_scale = original_sm_scale * dequant_scale; - - // wait for K - wait(&barrier_K, p); - wait(&barrier_K_pe, p); - - // compute QK^T - wgmma::warpgroup_arrive(); -#pragma unroll - for (uint32_t fq = 0; fq < num_tiles_q; fq++) - { - int8_t *sQ_local = sQ + fq * 64 * head_dim; - int8_t *sQ_local_pe = sQ_pe + fq * 64 * head_dim_pe; - wgmma::wgmma_s8s8s32(RS[fq], sQ_local, sK); - wgmma::wgmma_s8s8s32(RS_pe[fq], sQ_local_pe, sK_pe); // add one line -#pragma unroll - for (int k_it = 1; k_it < num_tiles_qk_inner; k_it++) - { - wgmma::wgmma_s8s8s32(RS[fq], &sQ_local[k_it*32], &sK[k_it*32]); - if (k_it < num_tiles_qk_pe_inner) { - wgmma::wgmma_s8s8s32(RS_pe[fq], &sQ_local_pe[k_it*32], &sK_pe[k_it*32]); // add one line - } - } - } - wgmma::warpgroup_commit_batch(); - wgmma::warpgroup_wait<0>(); - - // load K - if (threadIdx.x == 0) - { - expect_bytes<(CTA_K * head_dim) * sizeof(int8_t)>(&barrier_K); - expect_bytes<(CTA_K * head_dim_pe) * sizeof(int8_t)>(&barrier_K_pe); // add one line - load_async_4D(sK, &tensorMapK, &barrier_K, 0, iter * CTA_K, kv_head_id, batch_id); - load_async_4D(sK_pe, &tensorMapK_pe, &barrier_K_pe, 0, iter * CTA_K, kv_head_id, batch_id); // add one line - } - - // convert RS to float - float RS_f32[num_tiles_q][num_tiles_k][8]; -#pragma unroll - for (uint32_t fq = 0; fq < num_tiles_q; fq++) - { -#pragma unroll - for (uint32_t fk = 0; fk < num_tiles_k; fk++) - { -#pragma unroll - for (uint32_t k = 0; k < 8; k++) - { - RS_f32[fq][fk][k] = __int2float_rz(RS[fq][fk][k] + RS_pe[fq][fk][k]); // add one line - } - } - } - - update_mdo(RS_f32, RO, m, d, sm_scale); - - // accumulate d on thread basis -#pragma unroll - for (uint32_t fq = 0; fq < num_tiles_q; fq++) - { -#pragma unrol - for (uint32_t fk = 0; fk < num_tiles_k; fk++) - { - d[fq][0] += (RS_f32[fq][fk][0] + RS_f32[fq][fk][1] + RS_f32[fq][fk][4] + RS_f32[fq][fk][5]); - d[fq][1] += (RS_f32[fq][fk][2] + RS_f32[fq][fk][3] + RS_f32[fq][fk][6] + RS_f32[fq][fk][7]); - } - } - - uint32_t RS_f8[num_tiles_q][num_tiles_pv_inner][4]; - RS_32_to_8(RS_f32, RS_f8); - - // wait for V - wait(&barrier_V, p); - - float RO_temp[num_tiles_q][num_tiles_v][8]; - wgmma::warpgroup_arrive(); -#pragma unroll - for (uint32_t fq = 0; fq < num_tiles_q; fq++) - { - wgmma::wgmma_f8f8f32(RO_temp[fq], RS_f8[fq][0], &sV[0]); -#pragma unroll - for (uint32_t v_it = 1; v_it < num_tiles_pv_inner; v_it++) - { - wgmma::wgmma_f8f8f32(RO_temp[fq], RS_f8[fq][v_it], &sV[v_it * 32]); - } - } - - wgmma::warpgroup_commit_batch(); - wgmma::warpgroup_wait<0>(); - -#pragma unroll - for (uint32_t fq = 0; fq < num_tiles_q; fq++) - { -#pragma unroll - for (uint32_t fv = 0; fv < num_tiles_v; fv++) - { -#pragma unroll - for (uint32_t k = 0; k < 8; k++) - { - RO[fq][fv][k] += RO_temp[fq][fv][k]; - } - } - } - - // load V - if (threadIdx.x == 0) - { - expect_bytes<(CTA_K * head_dim) * sizeof(int8_t)>(&barrier_V); - load_async_4D(sV, &tensorMapV, &barrier_V, iter * CTA_K, 0, kv_head_id, batch_id); - } - } - - { - p ^= 1; - - float dequant_scale = q_scale * K_scale[k_scale_idx + (num_iterations - 1) * k_scale_advance_offset]; - sm_scale = original_sm_scale; - - // wait for K - wait(&barrier_K, p); - wait(&barrier_K_pe, p); - - // compute QK^T - wgmma::warpgroup_arrive(); -#pragma unroll - for (uint32_t fq = 0; fq < num_tiles_q; fq++) - { - int8_t 
*sQ_local = sQ + fq * 64 * head_dim; - int8_t *sQ_local_pe = sQ_pe + fq * 64 * head_dim_pe; - wgmma::wgmma_s8s8s32(RS[fq], sQ_local, sK); - wgmma::wgmma_s8s8s32(RS_pe[fq], sQ_local_pe, sK_pe); -#pragma unroll - for (int k_it = 1; k_it < num_tiles_qk_inner; k_it++) - { - wgmma::wgmma_s8s8s32(RS[fq], &sQ_local[k_it*32], &sK[k_it*32]); - if (k_it < num_tiles_qk_pe_inner) { - wgmma::wgmma_s8s8s32(RS_pe[fq], &sQ_local_pe[k_it*32], &sK_pe[k_it*32]); // add one line - } - } - } - wgmma::warpgroup_commit_batch(); - wgmma::warpgroup_wait<0>(); - - // convert RS to float - float RS_f32[num_tiles_q][num_tiles_k][8]; -#pragma unroll - for (uint32_t fq = 0; fq < num_tiles_q; fq++) - { -#pragma unroll - for (uint32_t fk = 0; fk < num_tiles_k; fk++) - { -#pragma unroll - for (uint32_t k = 0; k < 8; k++) - { - RS_f32[fq][fk][k] = __int2float_rz(RS[fq][fk][k] + RS_pe[fq][fk][k]) * dequant_scale; - } - } - } - - // masking -#pragma unroll - for (uint32_t fq = 0; fq < num_tiles_q; fq++) - { -#pragma unroll - for (uint32_t fk = 0; fk < num_tiles_k; fk++) - { -#pragma unroll - for (uint32_t k = 0; k < 8; k++) - { - const uint32_t q_idx = Q_idx_lane_base + fq * 64 + 8 * ((k % 4) / 2); - const uint32_t k_idx = (num_iterations - 1) * CTA_K + fk * 16 + 2 * (lane_id % 4) + 8 * (k / 4) + k % 2; - - bool is_out_of_bounds; - - if constexpr (mask_mode == MaskMode::kCausal) - { - is_out_of_bounds = (k_idx > q_idx) || (k_idx >= kv_len); - } - else - { - is_out_of_bounds = (k_idx >= kv_len); - } - - if (is_out_of_bounds) - { - RS_f32[fq][fk][k] = -5000000.0f; - } - } - } - } - - update_mdo(RS_f32, RO, m, d, sm_scale); - - // accumulate d on thread basis -#pragma unroll - for (uint32_t fq = 0; fq < num_tiles_q; fq++) - { -#pragma unrol - for (uint32_t fk = 0; fk < num_tiles_k; fk++) - { - d[fq][0] += (RS_f32[fq][fk][0] + RS_f32[fq][fk][1] + RS_f32[fq][fk][4] + RS_f32[fq][fk][5]); - d[fq][1] += (RS_f32[fq][fk][2] + RS_f32[fq][fk][3] + RS_f32[fq][fk][6] + RS_f32[fq][fk][7]); - } - } - - uint32_t RS_f8[num_tiles_q][num_tiles_pv_inner][4]; - RS_32_to_8(RS_f32, RS_f8); - - // wait for V - wait(&barrier_V, p); - - float RO_temp[num_tiles_q][num_tiles_v][8]; - wgmma::warpgroup_arrive(); -#pragma unroll - for (uint32_t fq = 0; fq < num_tiles_q; fq++) - { - wgmma::wgmma_f8f8f32(RO_temp[fq], RS_f8[fq][0], &sV[0]); -#pragma unroll - for (uint32_t v_it = 1; v_it < num_tiles_pv_inner; v_it++) - { - wgmma::wgmma_f8f8f32(RO_temp[fq], RS_f8[fq][v_it], &sV[v_it * 32]); - } - } - - wgmma::warpgroup_commit_batch(); - wgmma::warpgroup_wait<0>(); - -#pragma unroll - for (uint32_t fq = 0; fq < num_tiles_q; fq++) - { -#pragma unroll - for (uint32_t fv = 0; fv < num_tiles_v; fv++) - { -#pragma unroll - for (uint32_t k = 0; k < 8; k++) - { - RO[fq][fv][k] += RO_temp[fq][fv][k]; - } - } - } - } - - normalize_d(RO, m, d); - - if constexpr (fuse_v_scale) - { - float v_scale[4]; - float *V_scale_base_ptr = V_scale + batch_id * (num_qo_heads / num_kv_groups) * head_dim + (head_id / num_kv_groups) * head_dim + (lane_id % 4 ) * 2; - #pragma unroll - for (uint32_t fv = 0; fv < num_tiles_v; fv++) - { - ((float2*)v_scale)[0] = *((float2*)(V_scale_base_ptr + fv * 16)); - ((float2*)v_scale)[1] = *((float2*)(V_scale_base_ptr + fv * 16 + 8)); - - #pragma unroll - for (uint32_t fq = 0; fq < num_tiles_q; fq++) - { - RO[fq][fv][0] *= v_scale[0]; - RO[fq][fv][1] *= v_scale[1]; - RO[fq][fv][2] *= v_scale[0]; - RO[fq][fv][3] *= v_scale[1]; - RO[fq][fv][4] *= v_scale[2]; - RO[fq][fv][5] *= v_scale[3]; - RO[fq][fv][6] *= v_scale[2]; - RO[fq][fv][7] *= v_scale[3]; - } - 
} - } - - DTypeOut *O_lane_ptr = O + batch_id * stride_bz_o + head_id * stride_h_o + (bx * CTA_Q + warp_idx * 16 + (lane_id / 4)) * stride_seq_o + (lane_id % 4) * 2 ; -#pragma unroll - for (uint32_t fq = 0; fq < num_tiles_q; fq++) - { -#pragma unroll - for (uint32_t fv = 0; fv < head_dim/16; fv++) - { - if (Q_idx_lane_base + fq * 64 < qo_len) - { - if constexpr (std::is_same::value) - { - ((half2*)(O_lane_ptr + fq * 64 * stride_seq_o + fv * 16))[0] = __float22half2_rn(((float2*)(RO[fq][fv]))[0]); - ((half2*)(O_lane_ptr + fq * 64 * stride_seq_o + fv * 16 + 8))[0] = __float22half2_rn(((float2*)(RO[fq][fv]))[2]); - } - else - { - ((nv_bfloat162*)(O_lane_ptr + fq * 64 * stride_seq_o + fv * 16))[0] = __float22bfloat162_rn(((float2*)(RO[fq][fv]))[0]); - ((nv_bfloat162*)(O_lane_ptr + fq * 64 * stride_seq_o + fv * 16 + 8))[0] = __float22bfloat162_rn(((float2*)(RO[fq][fv]))[2]); - } - } - - if (Q_idx_lane_base + fq * 64 + 8 < qo_len) - { - if constexpr (std::is_same::value) - { - ((half2*)(O_lane_ptr + fq * 64 * stride_seq_o + fv * 16 + 8 * stride_seq_o))[0] = __float22half2_rn(((float2*)(RO[fq][fv]))[1]); - ((half2*)(O_lane_ptr + fq * 64 * stride_seq_o + fv * 16 + 8 + 8 * stride_seq_o))[0] = __float22half2_rn(((float2*)(RO[fq][fv]))[3]); - } - else - { - ((nv_bfloat162*)(O_lane_ptr + fq * 64 * stride_seq_o + fv * 16 + 8 * stride_seq_o))[0] = __float22bfloat162_rn(((float2*)(RO[fq][fv]))[1]); - ((nv_bfloat162*)(O_lane_ptr + fq * 64 * stride_seq_o + fv * 16 + 8 + 8 * stride_seq_o))[0] = __float22bfloat162_rn(((float2*)(RO[fq][fv]))[3]); - } - } - } - } -} - -std::vector qk_int8_sv_f8_accum_f32_attn_inst_buf_dsk_sm90_fwd( - paddle::Tensor& query, - paddle::Tensor& key, - paddle::Tensor& query_pe, - paddle::Tensor& key_pe, - paddle::Tensor& value, - paddle::Tensor& output, - paddle::Tensor& query_scale, - paddle::Tensor& key_scale, - int tensor_layout, - int is_causal, - int qk_quant_gran, - float sm_scale, - int return_lse) -{ - CHECK_CUDA(query); - CHECK_CUDA(key); - CHECK_CUDA(query_pe); - CHECK_CUDA(key_pe); - CHECK_CUDA(value); - CHECK_CUDA(output); - CHECK_CUDA(query_scale); - CHECK_CUDA(key_scale); - - CHECK_LASTDIM_CONTIGUOUS(query); - CHECK_LASTDIM_CONTIGUOUS(key); - CHECK_LASTDIM_CONTIGUOUS(query_pe); - CHECK_LASTDIM_CONTIGUOUS(key_pe); - CHECK_LASTDIM_CONTIGUOUS(value); - CHECK_LASTDIM_CONTIGUOUS(output); - CHECK_CONTIGUOUS(query_scale); - CHECK_CONTIGUOUS(key_scale); - - CHECK_DTYPE(query, paddle::DataType::INT8); - CHECK_DTYPE(key, paddle::DataType::INT8); - CHECK_DTYPE(query_pe, paddle::DataType::INT8); - CHECK_DTYPE(key_pe, paddle::DataType::INT8); - CHECK_DTYPE(value, paddle::DataType::FLOAT8_E4M3FN); - CHECK_DTYPE(query_scale, paddle::DataType::FLOAT32); - CHECK_DTYPE(key_scale, paddle::DataType::FLOAT32); - - CHECK_DIMS(query, 4); - CHECK_DIMS(key, 4); - CHECK_DIMS(query_pe, 4); - CHECK_DIMS(key_pe, 4); - CHECK_DIMS(value, 4); - CHECK_DIMS(output, 4); - CHECK_DIMS(query_scale, 3); - CHECK_DIMS(key_scale, 3); - - const int batch_size = query.shape()[0]; - const int head_dim = query.shape()[3]; // 现在query是正常的128, 多出来的64在query_pe里面,所以这样做没什么问题 - - int stride_bz_q = query.strides()[0]; - int stride_bz_k = key.strides()[0]; - int stride_bz_v = value.strides()[0]; - int stride_bz_o = output.strides()[0]; - - int qo_len, kv_len, padded_kv_len, num_qo_heads, num_kv_heads; - int stride_seq_q, stride_h_q, stride_seq_k, stride_h_k, stride_h_v, stride_d_v, stride_seq_o, stride_h_o; - - assert(value.shape()[0] == batch_size); - - if (tensor_layout == 0) - { - qo_len = query.shape()[1]; - kv_len = 
key.shape()[1]; - num_qo_heads = query.shape()[2]; - num_kv_heads = key.shape()[2]; - - stride_seq_q = query.strides()[1]; - stride_h_q = query.strides()[2]; - stride_seq_k = key.strides()[1]; - stride_h_k = key.strides()[2]; - stride_h_v = value.strides()[2]; - stride_d_v = value.strides()[1]; - stride_seq_o = output.strides()[1]; - stride_h_o = output.strides()[2]; - - CHECK_SHAPE(key, batch_size, kv_len, num_kv_heads, head_dim); - CHECK_SHAPE(output, batch_size, qo_len, num_qo_heads, head_dim); - assert(value.shape()[1] == head_dim); - assert(value.shape()[2] == num_kv_heads); - } - else - { - qo_len = query.shape()[2]; - kv_len = key.shape()[2]; - num_qo_heads = query.shape()[1]; - num_kv_heads = key.shape()[1]; - - stride_seq_q = query.strides()[2]; - stride_h_q = query.strides()[1]; - stride_seq_k = key.strides()[2]; - stride_h_k = key.strides()[1]; - stride_h_v = value.strides()[1]; - stride_d_v = value.strides()[2]; - stride_seq_o = output.strides()[2]; - stride_h_o = output.strides()[1]; - - CHECK_SHAPE(key, batch_size, num_kv_heads, kv_len, head_dim); - CHECK_SHAPE(output, batch_size, num_qo_heads, qo_len, head_dim); - assert(value.shape()[2] == head_dim); - assert(value.shape()[1] == num_kv_heads); - } - - if (num_qo_heads % num_kv_heads != 0) { - std::ostringstream err_msg; - err_msg << "num_qo_heads (" << num_qo_heads << ") must be divisible by num_kv_heads (" << num_kv_heads << ")"; - throw std::invalid_argument(err_msg.str()); - } - - paddle::Tensor lse = paddle::empty({0}, paddle::DataType::FLOAT32); - if (return_lse) - { - lse = paddle::empty({batch_size, num_qo_heads, qo_len}, paddle::DataType::FLOAT32); - } - - const int num_kv_groups = num_qo_heads / num_kv_heads; - - auto output_type = output.dtype(); - - DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { - DISPATCH_CAUSAL(is_causal, IS_CAUSAL, { - DISPATCH_QK_QUANT_GRAN(qk_quant_gran, QK_QUANT_GRAN, { - DISPATCH_PADDLE_DTYPE_TO_CTYPE_FP16(output_type, DTypeOut, { - constexpr int CTA_Q = 64; - constexpr int CTA_K = 128; - constexpr int NUM_THREADS = 128; - constexpr int HEAD_DIM_PE = 64; - - constexpr MaskMode mask_mode = IS_CAUSAL ? 
MaskMode::kCausal : MaskMode::kNone; - - assert(value.shape()[3] >= div_ceil(kv_len, CTA_K) * CTA_K); - - if constexpr (QK_QUANT_GRAN == static_cast(QuantGranularity::kPerWarp)) - { - CHECK_SHAPE(query_scale, batch_size, num_qo_heads, static_cast(div_ceil(qo_len, CTA_Q) * (NUM_THREADS / 32))); - CHECK_SHAPE(key_scale, batch_size, num_kv_heads, static_cast(div_ceil(kv_len, CTA_K))); - } - else if constexpr (QK_QUANT_GRAN == static_cast(QuantGranularity::kPerThread)) - { - CHECK_SHAPE(query_scale, batch_size, num_qo_heads, static_cast(div_ceil(qo_len, CTA_Q) * (NUM_THREADS / 32) * 8)); - CHECK_SHAPE(key_scale, batch_size, num_kv_heads, static_cast(div_ceil(kv_len, CTA_K) * 4)); - } - else - { - static_assert(QK_QUANT_GRAN == static_cast(QuantGranularity::kPerWarp) || QK_QUANT_GRAN == static_cast(QuantGranularity::kPerThread), "Unsupported quantization granularity"); - } - - CUtensorMap tma_map_Q = create_tensor_map_4D(reinterpret_cast(query.data()), batch_size, num_qo_heads, qo_len, HEAD_DIM, stride_bz_q, stride_h_q, stride_seq_q); - CUtensorMap tma_map_K = create_tensor_map_4D(reinterpret_cast(key.data()), batch_size, num_kv_heads, kv_len, HEAD_DIM, stride_bz_k, stride_h_k, stride_seq_k); - CUtensorMap tma_map_Q_pe = create_tensor_map_4D(reinterpret_cast(query_pe.data()), batch_size, num_qo_heads, qo_len, HEAD_DIM_PE, stride_bz_q, stride_h_q, stride_seq_q); - CUtensorMap tma_map_K_pe = create_tensor_map_4D(reinterpret_cast(key_pe.data()), batch_size, num_kv_heads, kv_len, HEAD_DIM_PE, stride_bz_k, stride_h_k, stride_seq_k); - - CUtensorMap tma_map_V = create_tensor_map_4D(reinterpret_cast(value.data()), batch_size, num_kv_heads, HEAD_DIM, value.shape()[3], stride_bz_v, stride_h_v, stride_d_v); - - auto* kernel = qk_int8_sv_f8_attn_dsk_kernel(QK_QUANT_GRAN), static_cast(QK_QUANT_GRAN), DTypeOut, mask_mode, false>; - size_t sMemSize = CTA_Q * HEAD_DIM * sizeof(int8_t) + CTA_K * HEAD_DIM * sizeof(int8_t) + CTA_K * HEAD_DIM * sizeof(int8_t); - sMemSize += CTA_Q * HEAD_DIM_PE * sizeof(int8_t) + CTA_K * HEAD_DIM_PE * sizeof(int8_t); // add extra space for qk pe - cudaFuncSetAttribute( - kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, sMemSize); - - dim3 grid(div_ceil(qo_len, CTA_Q), num_qo_heads, batch_size); - kernel<<>>( - tma_map_Q, - tma_map_K, - tma_map_Q_pe, - tma_map_K_pe, - tma_map_V, - reinterpret_cast(query_scale.data()), - reinterpret_cast(key_scale.data()), - nullptr, - reinterpret_cast(output.data()), - stride_bz_o, stride_h_o, stride_seq_o, - qo_len, kv_len, num_kv_groups, sm_scale); - }); - }); - }); - }); - - return {lse}; -} - -std::vector qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_dsk_sm90_fwd( - paddle::Tensor& query, - paddle::Tensor& key, - paddle::Tensor& query_pe, - paddle::Tensor& key_pe, - paddle::Tensor& value, - paddle::Tensor& output, - paddle::Tensor& query_scale, - paddle::Tensor& key_scale, - paddle::Tensor& value_scale, - int tensor_layout, - int is_causal, - int qk_quant_gran, - float sm_scale, - int return_lse) -{ - CHECK_CUDA(query); - CHECK_CUDA(key); - CHECK_CUDA(query_pe); - CHECK_CUDA(key_pe); - CHECK_CUDA(value); - CHECK_CUDA(output); - CHECK_CUDA(query_scale); - CHECK_CUDA(key_scale); - CHECK_CUDA(value_scale); - - CHECK_LASTDIM_CONTIGUOUS(query); - CHECK_LASTDIM_CONTIGUOUS(key); - CHECK_LASTDIM_CONTIGUOUS(query_pe); - CHECK_LASTDIM_CONTIGUOUS(key_pe); - CHECK_LASTDIM_CONTIGUOUS(value); - CHECK_LASTDIM_CONTIGUOUS(output); - CHECK_CONTIGUOUS(query_scale); - CHECK_CONTIGUOUS(key_scale); - CHECK_CONTIGUOUS(value_scale); - - CHECK_DTYPE(query, 
paddle::DataType::INT8); - CHECK_DTYPE(key, paddle::DataType::INT8); - CHECK_DTYPE(query_pe, paddle::DataType::INT8); - CHECK_DTYPE(key_pe, paddle::DataType::INT8); - CHECK_DTYPE(value, paddle::DataType::FLOAT8_E4M3FN); - CHECK_DTYPE(query_scale, paddle::DataType::FLOAT32); - CHECK_DTYPE(key_scale, paddle::DataType::FLOAT32); - CHECK_DTYPE(value_scale, paddle::DataType::FLOAT32); - - CHECK_DIMS(query, 4); - CHECK_DIMS(key, 4); - CHECK_DIMS(query_pe, 4); - CHECK_DIMS(key_pe, 4); - CHECK_DIMS(value, 4); - CHECK_DIMS(output, 4); - CHECK_DIMS(query_scale, 3); - CHECK_DIMS(key_scale, 3); - CHECK_DIMS(value_scale, 3); - - const int batch_size = query.shape()[0]; - const int head_dim = query.shape()[3]; - - int stride_bz_q = query.strides()[0]; - int stride_bz_k = key.strides()[0]; - int stride_bz_v = value.strides()[0]; - int stride_bz_o = output.strides()[0]; - - int qo_len, kv_len, padded_kv_len, num_qo_heads, num_kv_heads; - int stride_seq_q, stride_h_q, stride_seq_k, stride_h_k, stride_h_v, stride_d_v, stride_seq_o, stride_h_o; - - assert(value.shape()[0] == batch_size); - - if (tensor_layout == 0) - { - qo_len = query.shape()[1]; - kv_len = key.shape()[1]; - num_qo_heads = query.shape()[2]; - num_kv_heads = key.shape()[2]; - - stride_seq_q = query.strides()[1]; - stride_h_q = query.strides()[2]; - stride_seq_k = key.strides()[1]; - stride_h_k = key.strides()[2]; - stride_h_v = value.strides()[2]; - stride_d_v = value.strides()[1]; - stride_seq_o = output.strides()[1]; - stride_h_o = output.strides()[2]; - - CHECK_SHAPE(key, batch_size, kv_len, num_kv_heads, head_dim); - CHECK_SHAPE(output, batch_size, qo_len, num_qo_heads, head_dim); - - assert(value.shape()[1] == head_dim); - assert(value.shape()[2] == num_kv_heads); - } - else - { - qo_len = query.shape()[2]; - kv_len = key.shape()[2]; - num_qo_heads = query.shape()[1]; - num_kv_heads = key.shape()[1]; - - stride_seq_q = query.strides()[2]; - stride_h_q = query.strides()[1]; - stride_seq_k = key.strides()[2]; - stride_h_k = key.strides()[1]; - stride_h_v = value.strides()[1]; - stride_d_v = value.strides()[2]; - stride_seq_o = output.strides()[2]; - stride_h_o = output.strides()[1]; - - CHECK_SHAPE(key, batch_size, num_kv_heads, kv_len, head_dim); - CHECK_SHAPE(output, batch_size, num_qo_heads, qo_len, head_dim); - assert(value.shape()[2] == head_dim); - assert(value.shape()[1] == num_kv_heads); - } - - if (num_qo_heads % num_kv_heads != 0) { - std::ostringstream err_msg; - err_msg << "num_qo_heads (" << num_qo_heads << ") must be divisible by num_kv_heads (" << num_kv_heads << ")"; - throw std::invalid_argument(err_msg.str()); - } - - paddle::Tensor lse = paddle::empty({1}, paddle::DataType::FLOAT32); - if (return_lse) - { - lse = paddle::empty({batch_size, num_qo_heads, qo_len}, paddle::DataType::FLOAT32); - } - - const int num_kv_groups = num_qo_heads / num_kv_heads; - - auto output_dtype = output.dtype(); - - DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { - DISPATCH_CAUSAL(is_causal, IS_CAUSAL, { - DISPATCH_QK_QUANT_GRAN(qk_quant_gran, QK_QUANT_GRAN, { - DISPATCH_PADDLE_DTYPE_TO_CTYPE_FP16(output_dtype, DTypeOut, { - constexpr int CTA_Q = 64; - constexpr int CTA_K = 128; - constexpr int NUM_THREADS = 128; - constexpr int HEAD_DIM_PE = 64; - - constexpr MaskMode mask_mode = IS_CAUSAL ? 
MaskMode::kCausal : MaskMode::kNone; - - assert(value.shape()[3] >= div_ceil(kv_len, CTA_K) * CTA_K); - - if constexpr (QK_QUANT_GRAN == static_cast(QuantGranularity::kPerWarp)) - { - CHECK_SHAPE(query_scale, batch_size, num_qo_heads, static_cast(div_ceil(qo_len, CTA_Q) * (NUM_THREADS / 32))); - CHECK_SHAPE(key_scale, batch_size, num_kv_heads, static_cast(div_ceil(kv_len, CTA_K))); - } - else if constexpr (QK_QUANT_GRAN == static_cast(QuantGranularity::kPerThread)) - { - CHECK_SHAPE(query_scale, batch_size, num_qo_heads, static_cast(div_ceil(qo_len, CTA_Q) * (NUM_THREADS / 32) * 8)); - CHECK_SHAPE(key_scale, batch_size, num_kv_heads, static_cast(div_ceil(kv_len, CTA_K) * 4)); - } - else - { - static_assert(QK_QUANT_GRAN == static_cast(QuantGranularity::kPerWarp) || QK_QUANT_GRAN == static_cast(QuantGranularity::kPerThread), "Unsupported quantization granularity"); - } - - CHECK_SHAPE(value_scale, batch_size, num_kv_heads, HEAD_DIM); - CUtensorMap tma_map_Q = create_tensor_map_4D(reinterpret_cast(query.data()), batch_size, num_qo_heads, qo_len, HEAD_DIM, stride_bz_q, stride_h_q, stride_seq_q); - CUtensorMap tma_map_Q_pe = create_tensor_map_4D(reinterpret_cast(query_pe.data()), batch_size, num_qo_heads, qo_len, HEAD_DIM_PE, stride_bz_q, stride_h_q, stride_seq_q); - CUtensorMap tma_map_K = create_tensor_map_4D(reinterpret_cast(key.data()), batch_size, num_kv_heads, kv_len, HEAD_DIM, stride_bz_k, stride_h_k, stride_seq_k); - CUtensorMap tma_map_K_pe = create_tensor_map_4D(reinterpret_cast(key_pe.data()), batch_size, num_kv_heads, kv_len, HEAD_DIM_PE, stride_bz_k, stride_h_k, stride_seq_k); - - CUtensorMap tma_map_V = create_tensor_map_4D(reinterpret_cast(value.data()), batch_size, num_kv_heads, HEAD_DIM, value.shape()[3], stride_bz_v, stride_h_v, stride_d_v); - - auto* kernel = qk_int8_sv_f8_attn_dsk_kernel(QK_QUANT_GRAN), static_cast(QK_QUANT_GRAN), DTypeOut, mask_mode, true>; - size_t sMemSize = CTA_Q * HEAD_DIM * sizeof(int8_t) + CTA_K * HEAD_DIM * sizeof(int8_t) + CTA_K * HEAD_DIM * sizeof(int8_t); - sMemSize += CTA_Q * HEAD_DIM_PE * sizeof(int8_t) + CTA_K * HEAD_DIM_PE * sizeof(int8_t); - cudaFuncSetAttribute( - kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, sMemSize); - - dim3 grid(div_ceil(qo_len, CTA_Q), num_qo_heads, batch_size); - kernel<<>>( - tma_map_Q, - tma_map_K, - tma_map_Q_pe, - tma_map_K_pe, - tma_map_V, - reinterpret_cast(query_scale.data()), - reinterpret_cast(key_scale.data()), - reinterpret_cast(value_scale.data()), - reinterpret_cast(output.data()), - stride_bz_o, stride_h_o, stride_seq_o, - qo_len, kv_len, num_kv_groups, sm_scale); - }); - }); - }); - }); - - return {lse}; -} \ No newline at end of file diff --git a/csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_kernel.cu b/csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_kernel_sm89.cu similarity index 67% rename from csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_kernel.cu rename to csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_kernel_sm89.cu index eca067e1aef0..b5853899b574 100644 --- a/csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_kernel.cu +++ b/csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_kernel_sm89.cu @@ -609,12 +609,12 @@ __global__ void qk_int_sv_f8_attn_kernel(int8_t *__restrict__ Q, int8_t *__restr } } - ((int32_t*)(smem_O.base + offset_O))[lane_id % 4] = RO_f16[0]; - ((int32_t*)(smem_O.base + offset_O + 8 * (O_SMEM_STRIDE / PACK_SIZE_O)))[lane_id % 4] = RO_f16[1]; + ((uint32_t*)(smem_O.base + offset_O))[lane_id % 4] = RO_f16[0]; + ((uint32_t*)(smem_O.base + offset_O + 8 * 
(O_SMEM_STRIDE / PACK_SIZE_O)))[lane_id % 4] = RO_f16[1]; offset_O = smem_O.get_permuted_offset(smem_O_row_base + fq * MMA_QK_M, fv * (MMA_SV_N / PACK_SIZE_O) + 1); - ((int32_t*)(smem_O.base + offset_O))[lane_id % 4] = RO_f16[2]; - ((int32_t*)(smem_O.base + offset_O + 8 * (O_SMEM_STRIDE / PACK_SIZE_O)))[lane_id % 4] = RO_f16[3]; + ((uint32_t*)(smem_O.base + offset_O))[lane_id % 4] = RO_f16[2]; + ((uint32_t*)(smem_O.base + offset_O + 8 * (O_SMEM_STRIDE / PACK_SIZE_O)))[lane_id % 4] = RO_f16[3]; } else if constexpr (std::is_same::value) { @@ -671,6 +671,440 @@ __global__ void qk_int_sv_f8_attn_kernel(int8_t *__restrict__ Q, int8_t *__restr } } // kernel impl end +std::vector qk_int8_sv_f8_accum_f32_attn_fwd(paddle::Tensor& query, + paddle::Tensor& key, + paddle::Tensor& value, + paddle::Tensor& output, + paddle::Tensor& query_scale, + paddle::Tensor& key_scale, + int tensor_layout, + int is_causal, + int qk_quant_gran, + float sm_scale, + int return_lse) +{ + CHECK_CUDA(query); + CHECK_CUDA(key); + CHECK_CUDA(value); + CHECK_CUDA(output); + CHECK_CUDA(query_scale); + CHECK_CUDA(key_scale); + + CHECK_LASTDIM_CONTIGUOUS(query); + CHECK_LASTDIM_CONTIGUOUS(key); + CHECK_CONTIGUOUS(value); // ensure value is contiguous to prevent troubles in the kernel + CHECK_LASTDIM_CONTIGUOUS(output); + CHECK_CONTIGUOUS(query_scale); + CHECK_CONTIGUOUS(key_scale); + + CHECK_DTYPE(query, paddle::DataType::INT8); + CHECK_DTYPE(key, paddle::DataType::INT8); + // TODO: how to check fp8 data type? + CHECK_DTYPE(query_scale, paddle::DataType::FLOAT32); + CHECK_DTYPE(key_scale, paddle::DataType::FLOAT32); + + CHECK_DIMS(query, 4); + CHECK_DIMS(key, 4); + CHECK_DIMS(value, 4); + CHECK_DIMS(output, 4); + CHECK_DIMS(query_scale, 3); + CHECK_DIMS(key_scale, 3); + + const int batch_size = query.shape()[0]; + const int head_dim = query.shape()[3]; + + int stride_bz_q = query.strides()[0]; + int stride_bz_k = key.strides()[0]; + int stride_bz_v = value.strides()[0]; + int stride_bz_o = output.strides()[0]; + + int qo_len, kv_len, num_qo_heads, num_kv_heads; + int stride_seq_q, stride_h_q, stride_seq_k, stride_h_k, stride_h_v, stride_d_v, stride_seq_o, stride_h_o; + + if (tensor_layout == 0) + { + qo_len = query.shape()[1]; + kv_len = key.shape()[1]; + num_qo_heads = query.shape()[2]; + num_kv_heads = key.shape()[2]; + + stride_seq_q = query.strides()[1]; + stride_h_q = query.strides()[2]; + stride_seq_k = key.strides()[1]; + stride_h_k = key.strides()[2]; + stride_h_v = value.strides()[2]; + stride_d_v = value.strides()[1]; + stride_seq_o = output.strides()[1]; + stride_h_o = output.strides()[2]; + + CHECK_SHAPE(key, batch_size, kv_len, num_kv_heads, head_dim); + CHECK_SHAPE(output, batch_size, qo_len, num_qo_heads, head_dim); + assert(value.shape()[1] == head_dim); + assert(value.shape()[2] == num_kv_heads); + } + else + { + qo_len = query.shape()[2]; + kv_len = key.shape()[2]; + num_qo_heads = query.shape()[1]; + num_kv_heads = key.shape()[1]; + + stride_seq_q = query.strides()[2]; + stride_h_q = query.strides()[1]; + stride_seq_k = key.strides()[2]; + stride_h_k = key.strides()[1]; + stride_h_v = value.strides()[1]; + stride_d_v = value.strides()[2]; + stride_seq_o = output.strides()[2]; + stride_h_o = output.strides()[1]; + + CHECK_SHAPE(key, batch_size, num_kv_heads, kv_len, head_dim); + CHECK_SHAPE(output, batch_size, num_qo_heads, qo_len, head_dim); + assert(value.shape()[2] == head_dim); + assert(value.shape()[1] == num_kv_heads); + } + + if (num_qo_heads % num_kv_heads != 0) { + std::ostringstream err_msg; + 
err_msg << "num_qo_heads (" << num_qo_heads << ") must be divisible by num_kv_heads (" << num_kv_heads << ")"; + throw std::invalid_argument(err_msg.str()); + } + + paddle::Tensor lse = paddle::empty({0}, paddle::DataType::FLOAT32); + if (return_lse) + { + lse = paddle::empty({batch_size, num_qo_heads, qo_len}, paddle::DataType::FLOAT32); + } + + const int num_kv_groups = num_qo_heads / num_kv_heads; + + auto output_dtype = output.dtype(); + + DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { + DISPATCH_CAUSAL(is_causal, IS_CAUSAL, { + DISPATCH_QK_QUANT_GRAN(qk_quant_gran, QK_QUANT_GRAN, { + DISPATCH_RETURN_LSE(return_lse, RETURN_LSE, { + DISPATCH_PADDLE_DTYPE_TO_CTYPE_FP16(output_dtype, DTypeOut, { + constexpr int CTA_Q = 128; + constexpr int CTA_K = 64; + constexpr int WARP_Q = 32; + constexpr int WARP_K = 64; + + assert(value.shape()[0] == batch_size); + assert(value.shape()[3] >= div_ceil(kv_len, CTA_K) * CTA_K); + + constexpr MaskMode mask_mode = IS_CAUSAL ? MaskMode::kCausal : MaskMode::kNone; + + if constexpr (QK_QUANT_GRAN == static_cast(QuantGranularity::kPerWarp)) + { + CHECK_SHAPE(query_scale, batch_size, num_qo_heads, div_ceil(qo_len, CTA_Q) * (CTA_Q / WARP_Q)); + CHECK_SHAPE(key_scale, batch_size, num_kv_heads, div_ceil(kv_len, CTA_K) * (CTA_K / WARP_K)); + } + else if constexpr (QK_QUANT_GRAN == static_cast(QuantGranularity::kPerThread)) + { + CHECK_SHAPE(query_scale, batch_size, num_qo_heads, div_ceil(qo_len, CTA_Q) * (CTA_Q / WARP_Q) * 8); + CHECK_SHAPE(key_scale, batch_size, num_kv_heads, div_ceil(kv_len, CTA_K) * (CTA_K / WARP_K) * 4); + } + else + { + static_assert(QK_QUANT_GRAN == static_cast(QuantGranularity::kPerWarp) || QK_QUANT_GRAN == static_cast(QuantGranularity::kPerThread), "Unsupported quantization granularity"); + } + + // smem_Q smem_K smem_V smem_O + size_t smem_max = std::max(CTA_Q * HEAD_DIM * sizeof(int8_t) + CTA_K * HEAD_DIM * sizeof(int8_t) + CTA_K * HEAD_DIM * sizeof(int8_t), CTA_Q * HEAD_DIM * sizeof(half)); + + auto kernel_func = qk_int_sv_f8_attn_kernel(QK_QUANT_GRAN), static_cast(QK_QUANT_GRAN), + float, false, DTypeOut, ComputeUnit::kCudaCore, mask_mode, RETURN_LSE, false, false>; + + cudaFuncSetAttribute(kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_max); + + dim3 grid(div_ceil(qo_len, CTA_Q), num_qo_heads, batch_size); + dim3 block(32, (CTA_Q / WARP_Q) * (CTA_K / WARP_K)); + + kernel_func<<>>( + query.data(), + key.data(), + reinterpret_cast(value.data()), + reinterpret_cast(output.data()), + (RETURN_LSE) ? 
reinterpret_cast(lse.data()) : nullptr, + reinterpret_cast(query_scale.data()), + reinterpret_cast(key_scale.data()), + nullptr, + nullptr, + qo_len, + kv_len, + num_kv_groups, + stride_bz_q, stride_seq_q, stride_h_q, + stride_bz_k, stride_seq_k, stride_h_k, + stride_bz_v, stride_h_v, stride_d_v, + stride_bz_o, stride_seq_o, stride_h_o, + sm_scale); + }); + }); + }); + }); + }); + + return {lse}; +} + +std::vector> qk_int8_sv_f8_accum_f32_attn_InferShape( + std::vector query_shape, + std::vector key_shape, + std::vector value_shape, + std::vector output_shape, + std::vector query_scale_shape, + std::vector key_scale_shape) { + + // force layout: NHD: [bsz, seq_len, num_heads, head_dim] + int64_t bsz = query_shape[0]; + int64_t seq_len = query_shape[1]; + int64_t h_qo = query_shape[2]; + + std::vector return_shape = {bsz, h_qo, seq_len}; + return {return_shape}; +} + +std::vector qk_int8_sv_f8_accum_f32_attn_InferDtype( + paddle::DataType A_dtype, + paddle::DataType B_dtype, + paddle::DataType C_dtype, + paddle::DataType D_dtype, + paddle::DataType E_dtype, + paddle::DataType F_dtype) { + return {paddle::DataType::FLOAT32}; +} + +PD_BUILD_OP(qk_int8_sv_f8_accum_f32_attn) + .Inputs({"query", "key", "value", "output", "query_scale", "key_scale"}) + .Outputs({"out", "lse"}) + .SetInplaceMap({{"output", "out"}}) // Inplace + .Attrs({"tensor_layout: int", + "is_causal: int", + "qk_quant_gran: int", + "sm_scale: float", + "return_lse: int"}) + .SetKernelFn(PD_KERNEL(qk_int8_sv_f8_accum_f32_attn_fwd)) + .SetInferShapeFn(PD_INFER_SHAPE(qk_int8_sv_f8_accum_f32_attn_InferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(qk_int8_sv_f8_accum_f32_attn_InferDtype)); + +std::vector qk_int8_sv_f8_accum_f32_attn_inst_buf_fwd(paddle::Tensor& query, + paddle::Tensor& key, + paddle::Tensor& value, + paddle::Tensor& output, + paddle::Tensor& query_scale, + paddle::Tensor& key_scale, + int tensor_layout, + int is_causal, + int qk_quant_gran, + float sm_scale, + int return_lse) +{ + CHECK_CUDA(query); + CHECK_CUDA(key); + CHECK_CUDA(value); + CHECK_CUDA(output); + CHECK_CUDA(query_scale); + CHECK_CUDA(key_scale); + + CHECK_LASTDIM_CONTIGUOUS(query); + CHECK_LASTDIM_CONTIGUOUS(key); + CHECK_CONTIGUOUS(value); // ensure value is contiguous to prevent troubles in the kernel + CHECK_LASTDIM_CONTIGUOUS(output); + CHECK_CONTIGUOUS(query_scale); + CHECK_CONTIGUOUS(key_scale); + + CHECK_DTYPE(query, paddle::DataType::INT8); + CHECK_DTYPE(key, paddle::DataType::INT8); + // TODO: how to check fp8 data type? 
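For orientation, the qk_int8_sv_f8_accum_f32_attn op registered above is callable through paddlenlp_ops once the extension is built. The sketch below is illustrative only and not part of the patch: tensor contents are uninitialized placeholders (real inputs come from the per_warp_int8 and per_channel_fp8 helpers added later in this series), and the shapes assume NHD layout with per-warp quantization.

import paddle
import paddlenlp_ops  # extension built from this patch

# Hypothetical sizes; contents are placeholders that only illustrate the call signature.
b, seq_len, heads, head_dim = 1, 1024, 8, 128
q_int8 = paddle.empty([b, seq_len, heads, head_dim], dtype=paddle.int8)
k_int8 = paddle.empty([b, seq_len, heads, head_dim], dtype=paddle.int8)
# V is stored transposed and padded as (batch, head_dim, kv_heads, padded_kv_len) in fp8.
v_fp8 = paddle.empty([b, head_dim, heads, seq_len], dtype=paddle.float8_e4m3fn)
o = paddle.empty([b, seq_len, heads, head_dim], dtype=paddle.float16)  # written in place
# Per-warp scales: ceil(qo_len/CTA_Q)*(CTA_Q/WARP_Q) for Q, ceil(kv_len/CTA_K)*(CTA_K/WARP_K) for K.
q_scale = paddle.empty([b, heads, (seq_len // 128) * 4], dtype=paddle.float32)
k_scale = paddle.empty([b, heads, seq_len // 64], dtype=paddle.float32)

lse = paddlenlp_ops.qk_int8_sv_f8_accum_f32_attn(
    q_int8, k_int8, v_fp8, o, q_scale, k_scale,
    0,                  # tensor_layout: 0 = NHD, 1 = HND
    1,                  # is_causal
    2,                  # qk_quant_gran: 2 = per_warp, 3 = per_thread
    head_dim ** -0.5,   # sm_scale
    0)                  # return_lse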
+ CHECK_DTYPE(query_scale, paddle::DataType::FLOAT32); + CHECK_DTYPE(key_scale, paddle::DataType::FLOAT32); + + CHECK_DIMS(query, 4); + CHECK_DIMS(key, 4); + CHECK_DIMS(value, 4); + CHECK_DIMS(output, 4); + CHECK_DIMS(query_scale, 3); + CHECK_DIMS(key_scale, 3); + + const int batch_size = query.shape()[0]; + const int head_dim = query.shape()[3]; + + int stride_bz_q = query.strides()[0]; + int stride_bz_k = key.strides()[0]; + int stride_bz_v = value.strides()[0]; + int stride_bz_o = output.strides()[0]; + + int qo_len, kv_len, num_qo_heads, num_kv_heads; + int stride_seq_q, stride_h_q, stride_seq_k, stride_h_k, stride_h_v, stride_d_v, stride_seq_o, stride_h_o; + + if (tensor_layout == 0) + { + qo_len = query.shape()[1]; + kv_len = key.shape()[1]; + num_qo_heads = query.shape()[2]; + num_kv_heads = key.shape()[2]; + + stride_seq_q = query.strides()[1]; + stride_h_q = query.strides()[2]; + stride_seq_k = key.strides()[1]; + stride_h_k = key.strides()[2]; + stride_h_v = value.strides()[2]; + stride_d_v = value.strides()[1]; + stride_seq_o = output.strides()[1]; + stride_h_o = output.strides()[2]; + + CHECK_SHAPE(key, batch_size, kv_len, num_kv_heads, head_dim); + CHECK_SHAPE(output, batch_size, qo_len, num_qo_heads, head_dim); + assert(value.shape()[1] == head_dim); + assert(value.shape()[2] == num_kv_heads); + } + else + { + qo_len = query.shape()[2]; + kv_len = key.shape()[2]; + num_qo_heads = query.shape()[1]; + num_kv_heads = key.shape()[1]; + + stride_seq_q = query.strides()[2]; + stride_h_q = query.strides()[1]; + stride_seq_k = key.strides()[2]; + stride_h_k = key.strides()[1]; + stride_h_v = value.strides()[1]; + stride_d_v = value.strides()[2]; + stride_seq_o = output.strides()[2]; + stride_h_o = output.strides()[1]; + + CHECK_SHAPE(key, batch_size, num_kv_heads, kv_len, head_dim); + CHECK_SHAPE(output, batch_size, num_qo_heads, qo_len, head_dim); + assert(value.shape()[2] == head_dim); + assert(value.shape()[1] == num_kv_heads); + } + + if (num_qo_heads % num_kv_heads != 0) { + std::ostringstream err_msg; + err_msg << "num_qo_heads (" << num_qo_heads << ") must be divisible by num_kv_heads (" << num_kv_heads << ")"; + throw std::invalid_argument(err_msg.str()); + } + + paddle::Tensor lse = paddle::empty({0}, paddle::DataType::FLOAT32); + if (return_lse) + { + lse = paddle::empty({batch_size, num_qo_heads, qo_len}, paddle::DataType::FLOAT32); + } + + const int num_kv_groups = num_qo_heads / num_kv_heads; + + auto output_dtype = output.dtype(); + + DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { + DISPATCH_CAUSAL(is_causal, IS_CAUSAL, { + DISPATCH_QK_QUANT_GRAN(qk_quant_gran, QK_QUANT_GRAN, { + DISPATCH_RETURN_LSE(return_lse, RETURN_LSE, { + DISPATCH_PADDLE_DTYPE_TO_CTYPE_FP16(output_dtype, DTypeOut, { + constexpr int CTA_Q = 128; + constexpr int CTA_K = 64; + constexpr int WARP_Q = 32; + constexpr int WARP_K = 64; + + assert(value.shape()[0] == batch_size); + assert(value.shape()[3] >= div_ceil(kv_len, CTA_K) * CTA_K); + + constexpr MaskMode mask_mode = IS_CAUSAL ? 
MaskMode::kCausal : MaskMode::kNone; + + if constexpr (QK_QUANT_GRAN == static_cast(QuantGranularity::kPerWarp)) + { + CHECK_SHAPE(query_scale, batch_size, num_qo_heads, div_ceil(qo_len, CTA_Q) * (CTA_Q / WARP_Q)); + CHECK_SHAPE(key_scale, batch_size, num_kv_heads, div_ceil(kv_len, CTA_K) * (CTA_K / WARP_K)); + } + else if constexpr (QK_QUANT_GRAN == static_cast(QuantGranularity::kPerThread)) + { + CHECK_SHAPE(query_scale, batch_size, num_qo_heads, div_ceil(qo_len, CTA_Q) * (CTA_Q / WARP_Q) * 8); + CHECK_SHAPE(key_scale, batch_size, num_kv_heads, div_ceil(kv_len, CTA_K) * (CTA_K / WARP_K) * 4); + } + else + { + static_assert(QK_QUANT_GRAN == static_cast(QuantGranularity::kPerWarp) || QK_QUANT_GRAN == static_cast(QuantGranularity::kPerThread), "Unsupported quantization granularity"); + } + + // smem_Q smem_K smem_V smem_O + size_t smem_max = std::max(CTA_Q * HEAD_DIM * sizeof(int8_t) + CTA_K * HEAD_DIM * sizeof(int8_t) + CTA_K * HEAD_DIM * sizeof(int8_t), CTA_Q * HEAD_DIM * sizeof(half)); + + auto kernel_func = qk_int_sv_f8_attn_kernel(QK_QUANT_GRAN), static_cast(QK_QUANT_GRAN), + float, true, DTypeOut, ComputeUnit::kCudaCore, mask_mode, RETURN_LSE, false, false>; + + cudaFuncSetAttribute(kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_max); + + dim3 grid(div_ceil(qo_len, CTA_Q), num_qo_heads, batch_size); + dim3 block(32, (CTA_Q / WARP_Q) * (CTA_K / WARP_K)); + + kernel_func<<>>( + query.data(), + key.data(), + reinterpret_cast(value.data()), + reinterpret_cast(output.data()), + (RETURN_LSE) ? reinterpret_cast(lse.data()) : nullptr, + reinterpret_cast(query_scale.data()), + reinterpret_cast(key_scale.data()), + nullptr, + nullptr, + qo_len, + kv_len, + num_kv_groups, + stride_bz_q, stride_seq_q, stride_h_q, + stride_bz_k, stride_seq_k, stride_h_k, + stride_bz_v, stride_h_v, stride_d_v, + stride_bz_o, stride_seq_o, stride_h_o, + sm_scale); + }); + }); + }); + }); + }); + + return {lse}; +} + +std::vector> qk_int8_sv_f8_accum_f32_attn_inst_buf_InferShape( + std::vector query_shape, + std::vector key_shape, + std::vector value_shape, + std::vector output_shape, + std::vector query_scale_shape, + std::vector key_scale_shape) { + + // force layout: NHD: [bsz, seq_len, num_heads, head_dim] + int64_t bsz = query_shape[0]; + int64_t seq_len = query_shape[1]; + int64_t h_qo = query_shape[2]; + + std::vector return_shape = {bsz, h_qo, seq_len}; + return {return_shape}; +} + +std::vector qk_int8_sv_f8_accum_f32_attn_inst_buf_InferDtype( + paddle::DataType A_dtype, + paddle::DataType B_dtype, + paddle::DataType C_dtype, + paddle::DataType D_dtype, + paddle::DataType E_dtype, + paddle::DataType F_dtype) { + return {paddle::DataType::FLOAT32}; +} + +PD_BUILD_OP(qk_int8_sv_f8_accum_f32_attn_inst_buf) + .Inputs({"query", "key", "value", "output", "query_scale", "key_scale",}) + .Outputs({"out", "lse"}) + .SetInplaceMap({{"output", "out"}}) // Inplace + .Attrs({"tensor_layout: int", + "is_causal: int", + "qk_quant_gran: int", + "sm_scale: float", + "return_lse: int"}) + .SetKernelFn(PD_KERNEL(qk_int8_sv_f8_accum_f32_attn_inst_buf_fwd)) + .SetInferShapeFn(PD_INFER_SHAPE(qk_int8_sv_f8_accum_f32_attn_inst_buf_InferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(qk_int8_sv_f8_accum_f32_attn_inst_buf_InferDtype)); + // impl -> see sageattn.h file std::vector qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn_fwd( paddle::Tensor& query, @@ -708,7 +1142,6 @@ std::vector qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_att CHECK_DTYPE(query, paddle::DataType::INT8); CHECK_DTYPE(key, 
paddle::DataType::INT8); // TODO: how to check fp8 data type? - // CHECK_DTYPE(value, torch::kHalf); CHECK_DTYPE(query_scale, paddle::DataType::FLOAT32); CHECK_DTYPE(key_scale, paddle::DataType::FLOAT32); CHECK_DTYPE(value_scale, paddle::DataType::FLOAT32); @@ -798,10 +1231,10 @@ std::vector qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_att DISPATCH_QK_QUANT_GRAN(qk_quant_gran, QK_QUANT_GRAN, { DISPATCH_RETURN_LSE(return_lse, RETURN_LSE, { DISPATCH_PADDLE_DTYPE_TO_CTYPE_FP16(output_dtype, DTypeOut, { - constexpr int CTA_Q = (HEAD_DIM == 256) ? 64 : 128; - constexpr int CTA_K = (HEAD_DIM == 256) ? 64 : 64; - constexpr int WARP_Q = (HEAD_DIM == 256) ? 16 : 32; - constexpr int WARP_K = (HEAD_DIM == 256) ? 64 : 64; + constexpr int CTA_Q = 128; + constexpr int CTA_K = 64; + constexpr int WARP_Q = 32; + constexpr int WARP_K = 64; assert(value.shape()[0] == batch_size); assert(value.shape()[3] >= div_ceil(kv_len, CTA_K) * CTA_K); @@ -810,13 +1243,13 @@ std::vector qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_att if constexpr (QK_QUANT_GRAN == static_cast(QuantGranularity::kPerWarp)) { - CHECK_SHAPE(query_scale, batch_size, num_qo_heads, static_cast(div_ceil(qo_len, CTA_Q) * (CTA_Q / WARP_Q))); - CHECK_SHAPE(key_scale, batch_size, num_kv_heads, static_cast(div_ceil(kv_len, CTA_K) * (CTA_K / WARP_K))); + CHECK_SHAPE(query_scale, batch_size, num_qo_heads, div_ceil(qo_len, CTA_Q) * (CTA_Q / WARP_Q)); + CHECK_SHAPE(key_scale, batch_size, num_kv_heads, div_ceil(kv_len, CTA_K) * (CTA_K / WARP_K)); } else if constexpr (QK_QUANT_GRAN == static_cast(QuantGranularity::kPerThread)) { - CHECK_SHAPE(query_scale, batch_size, num_qo_heads, static_cast(div_ceil(qo_len, CTA_Q) * (CTA_Q / WARP_Q) * 8)); - CHECK_SHAPE(key_scale, batch_size, num_kv_heads, static_cast(div_ceil(kv_len, CTA_K) * (CTA_K / WARP_K) * 4)); + CHECK_SHAPE(query_scale, batch_size, num_qo_heads, div_ceil(qo_len, CTA_Q) * (CTA_Q / WARP_Q) * 8); + CHECK_SHAPE(key_scale, batch_size, num_kv_heads, div_ceil(kv_len, CTA_K) * (CTA_K / WARP_K) * 4); } else { @@ -864,6 +1297,50 @@ std::vector qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_att return {lse}; } +std::vector> qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn_InferShape( + std::vector query_shape, + std::vector key_shape, + std::vector value_shape, + std::vector output_shape, + std::vector query_scale_shape, + std::vector key_scale_shape, + std::vector value_scale_shape, + std::vector value_mean_shape) { + + // force layout: NHD: [bsz, seq_len, num_heads, head_dim] + int64_t bsz = query_shape[0]; + int64_t seq_len = query_shape[1]; + int64_t h_qo = query_shape[2]; + + std::vector return_shape = {bsz, h_qo, seq_len}; + return {return_shape}; +} + +std::vector qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn_InferDtype( + paddle::DataType A_dtype, + paddle::DataType B_dtype, + paddle::DataType C_dtype, + paddle::DataType D_dtype, + paddle::DataType E_dtype, + paddle::DataType F_dtype, + paddle::DataType G_dtype, + paddle::DataType H_dtype) { + return {paddle::DataType::FLOAT32}; +} + +PD_BUILD_OP(qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn) + .Inputs({"query", "key", "value", "output", "query_scale", "key_scale", "value_scale", "value_mean"}) + .Outputs({"out1", "out2", "out3", "out4", "out5", "out6", "out7", "out8", "lse"}) + .SetInplaceMap({{"query", "out1"}, {"key", "out2"}, {"value", "out3"}, {"output", "out4"}, {"query_scale", "out5"}, {"key_scale", "out6"}, {"value_scale", "out7"}, {"value_mean", "out8"}}) // Inplace + .Attrs({"tensor_layout: int", + 
"is_causal: int", + "qk_quant_gran: int", + "sm_scale: float", + "return_lse: int"}) + .SetKernelFn(PD_KERNEL(qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn_fwd)) + .SetInferShapeFn(PD_INFER_SHAPE(qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn_InferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn_InferDtype)); + std::vector qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_fwd( paddle::Tensor& query, paddle::Tensor& key, @@ -897,7 +1374,6 @@ std::vector qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_fwd( CHECK_DTYPE(query, paddle::DataType::INT8); CHECK_DTYPE(key, paddle::DataType::INT8); // TODO: how to check fp8 data type? - // CHECK_DTYPE(value, torch::kHalf); CHECK_DTYPE(query_scale, paddle::DataType::FLOAT32); CHECK_DTYPE(key_scale, paddle::DataType::FLOAT32); CHECK_DTYPE(value_scale, paddle::DataType::FLOAT32); @@ -986,10 +1462,10 @@ std::vector qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_fwd( DISPATCH_RETURN_LSE(return_lse, RETURN_LSE, { DISPATCH_PADDLE_DTYPE_TO_CTYPE_FP16(output_dtype, DTypeOut, { - constexpr int CTA_Q = (HEAD_DIM == 256) ? 64 : 128; - constexpr int CTA_K = (HEAD_DIM == 256) ? 64 : 64; - constexpr int WARP_Q = (HEAD_DIM == 256) ? 16 : 32; - constexpr int WARP_K = (HEAD_DIM == 256) ? 64 : 64; + constexpr int CTA_Q = 128; + constexpr int CTA_K = 64; + constexpr int WARP_Q = 32; + constexpr int WARP_K = 64; assert(value.shape()[0] == batch_size); assert(value.shape()[3] >= div_ceil(kv_len, CTA_K) * CTA_K); @@ -998,13 +1474,13 @@ std::vector qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_fwd( if constexpr (QK_QUANT_GRAN == static_cast(QuantGranularity::kPerWarp)) { - CHECK_SHAPE(query_scale, batch_size, num_qo_heads, static_cast(div_ceil(qo_len, CTA_Q) * (CTA_Q / WARP_Q))); - CHECK_SHAPE(key_scale, batch_size, num_kv_heads, static_cast(div_ceil(kv_len, CTA_K) * (CTA_K / WARP_K))); + CHECK_SHAPE(query_scale, batch_size, num_qo_heads, div_ceil(qo_len, CTA_Q) * (CTA_Q / WARP_Q)); + CHECK_SHAPE(key_scale, batch_size, num_kv_heads, div_ceil(kv_len, CTA_K) * (CTA_K / WARP_K)); } else if constexpr (QK_QUANT_GRAN == static_cast(QuantGranularity::kPerThread)) { - CHECK_SHAPE(query_scale, batch_size, num_qo_heads, static_cast(div_ceil(qo_len, CTA_Q) * (CTA_Q / WARP_Q) * 8)); - CHECK_SHAPE(key_scale, batch_size, num_kv_heads, static_cast(div_ceil(kv_len, CTA_K) * (CTA_K / WARP_K) * 4)); + CHECK_SHAPE(query_scale, batch_size, num_qo_heads, div_ceil(qo_len, CTA_Q) * (CTA_Q / WARP_Q) * 8); + CHECK_SHAPE(key_scale, batch_size, num_kv_heads, div_ceil(kv_len, CTA_K) * (CTA_K / WARP_K) * 4); } else { @@ -1051,6 +1527,48 @@ std::vector qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_fwd( return {lse}; } +std::vector> qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_InferShape( + std::vector query_shape, + std::vector key_shape, + std::vector value_shape, + std::vector output_shape, + std::vector query_scale_shape, + std::vector key_scale_shape, + std::vector value_scale_shape) { + + // force layout: NHD: [bsz, seq_len, num_heads, head_dim] + int64_t bsz = query_shape[0]; + int64_t seq_len = query_shape[1]; + int64_t h_qo = query_shape[2]; + + std::vector return_shape = {bsz, h_qo, seq_len}; + return {return_shape}; +} + +std::vector qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_InferDtype( + paddle::DataType A_dtype, + paddle::DataType B_dtype, + paddle::DataType C_dtype, + paddle::DataType D_dtype, + paddle::DataType E_dtype, + paddle::DataType F_dtype, + paddle::DataType G_dtype) { + return {paddle::DataType::FLOAT32}; +} + 
+PD_BUILD_OP(qk_int8_sv_f8_accum_f32_fuse_v_scale_attn) + .Inputs({"query", "key", "value", "output", "query_scale", "key_scale", "value_scale"}) + .Outputs({"out1", "out2", "out3", "out4", "out5", "out6", "out7", "lse"}) + .SetInplaceMap({{"query", "out1"}, {"key", "out2"}, {"value", "out3"}, {"output", "out4"}, {"query_scale", "out5"}, {"key_scale", "out6"}, {"value_scale", "out7"}}) // Inplace + .Attrs({"tensor_layout: int", + "is_causal: int", + "qk_quant_gran: int", + "sm_scale: float", + "return_lse: int"}) + .SetKernelFn(PD_KERNEL(qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_fwd)) + .SetInferShapeFn(PD_INFER_SHAPE(qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_InferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_InferDtype)); + std::vector qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm89_fwd( paddle::Tensor& query, paddle::Tensor& key, @@ -1084,7 +1602,6 @@ std::vector qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_s CHECK_DTYPE(query, paddle::DataType::INT8); CHECK_DTYPE(key, paddle::DataType::INT8); // TODO: how to check fp8 data type? - // CHECK_DTYPE(value, torch::kHalf); CHECK_DTYPE(query_scale, paddle::DataType::FLOAT32); CHECK_DTYPE(key_scale, paddle::DataType::FLOAT32); CHECK_DTYPE(value_scale, paddle::DataType::FLOAT32); @@ -1173,10 +1690,10 @@ std::vector qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_s DISPATCH_RETURN_LSE(return_lse, RETURN_LSE, { DISPATCH_PADDLE_DTYPE_TO_CTYPE_FP16(output_dtype, DTypeOut, { - constexpr int CTA_Q = (HEAD_DIM == 256) ? 64 : 128; - constexpr int CTA_K = (HEAD_DIM == 256) ? 64 : 64; - constexpr int WARP_Q = (HEAD_DIM == 256) ? 16 : 32; - constexpr int WARP_K = (HEAD_DIM == 256) ? 64 : 64; + constexpr int CTA_Q = 128; + constexpr int CTA_K = 64; + constexpr int WARP_Q = 32; + constexpr int WARP_K = 64; assert(value.shape()[0] == batch_size); assert(value.shape()[3] >= div_ceil(kv_len, CTA_K) * CTA_K); @@ -1185,13 +1702,13 @@ std::vector qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_s if constexpr (QK_QUANT_GRAN == static_cast(QuantGranularity::kPerWarp)) { - CHECK_SHAPE(query_scale, batch_size, num_qo_heads, static_cast(div_ceil(qo_len, CTA_Q) * (CTA_Q / WARP_Q))); - CHECK_SHAPE(key_scale, batch_size, num_kv_heads, static_cast(div_ceil(kv_len, CTA_K) * (CTA_K / WARP_K))); + CHECK_SHAPE(query_scale, batch_size, num_qo_heads, div_ceil(qo_len, CTA_Q) * (CTA_Q / WARP_Q)); + CHECK_SHAPE(key_scale, batch_size, num_kv_heads, div_ceil(kv_len, CTA_K) * (CTA_K / WARP_K)); } else if constexpr (QK_QUANT_GRAN == static_cast(QuantGranularity::kPerThread)) { - CHECK_SHAPE(query_scale, batch_size, num_qo_heads, static_cast(div_ceil(qo_len, CTA_Q) * (CTA_Q / WARP_Q) * 8)); - CHECK_SHAPE(key_scale, batch_size, num_kv_heads, static_cast(div_ceil(kv_len, CTA_K) * (CTA_K / WARP_K) * 4)); + CHECK_SHAPE(query_scale, batch_size, num_qo_heads, div_ceil(qo_len, CTA_Q) * (CTA_Q / WARP_Q) * 8); + CHECK_SHAPE(key_scale, batch_size, num_kv_heads, div_ceil(kv_len, CTA_K) * (CTA_K / WARP_K) * 4); } else { @@ -1237,3 +1754,45 @@ std::vector qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_s return {lse}; } + +std::vector> qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm89_InferShape( + std::vector query_shape, + std::vector key_shape, + std::vector value_shape, + std::vector output_shape, + std::vector query_scale_shape, + std::vector key_scale_shape, + std::vector value_scale_shape) { + + // force layout: NHD: [bsz, seq_len, num_heads, head_dim] + int64_t bsz = query_shape[0]; + int64_t seq_len = query_shape[1]; + 
int64_t h_qo = query_shape[2]; + + std::vector return_shape = {bsz, h_qo, seq_len}; + return {return_shape}; +} + +std::vector qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm89_InferDtype( + paddle::DataType A_dtype, + paddle::DataType B_dtype, + paddle::DataType C_dtype, + paddle::DataType D_dtype, + paddle::DataType E_dtype, + paddle::DataType F_dtype, + paddle::DataType G_dtype) { + return {paddle::DataType::FLOAT32}; +} + +PD_BUILD_OP(qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm89) + .Inputs({"query", "key", "value", "output", "query_scale", "key_scale", "value_scale"}) + .Outputs({"out1", "out2", "out3", "out4", "out5", "out6", "out7", "lse"}) + .SetInplaceMap({{"query", "out1"}, {"key", "out2"}, {"value", "out3"}, {"output", "out4"}, {"query_scale", "out5"}, {"key_scale", "out6"}, {"value_scale", "out7"}}) // Inplace + .Attrs({"tensor_layout: int", + "is_causal: int", + "qk_quant_gran: int", + "sm_scale: float", + "return_lse: int"}) + .SetKernelFn(PD_KERNEL(qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm89_fwd)) + .SetInferShapeFn(PD_INFER_SHAPE(qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm89_InferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm89_InferDtype)); \ No newline at end of file diff --git a/csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_kernel_sm90.cu b/csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_kernel_sm90.cu index 84083d543afd..0ee91f10a247 100644 --- a/csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_kernel_sm90.cu +++ b/csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_kernel_sm90.cu @@ -702,6 +702,46 @@ std::vector qk_int8_sv_f8_accum_f32_attn_inst_buf_sm90_fwd( return {lse}; } +std::vector> qk_int8_sv_f8_accum_f32_attn_inst_buf_sm90_InferShape( + std::vector query_shape, + std::vector key_shape, + std::vector value_shape, + std::vector output_shape, + std::vector query_scale_shape, + std::vector key_scale_shape) { + + // force layout: NHD: [bsz, seq_len, num_heads, head_dim] + int64_t bsz = query_shape[0]; + int64_t seq_len = query_shape[1]; + int64_t h_qo = query_shape[2]; + + std::vector return_shape = {bsz, h_qo, seq_len}; + return {return_shape}; +} + +std::vector qk_int8_sv_f8_accum_f32_attn_inst_buf_sm90_InferDtype( + paddle::DataType A_dtype, + paddle::DataType B_dtype, + paddle::DataType C_dtype, + paddle::DataType D_dtype, + paddle::DataType E_dtype, + paddle::DataType F_dtype) { + return {paddle::DataType::FLOAT32}; +} + +PD_BUILD_OP(qk_int8_sv_f8_accum_f32_attn_inst_buf_sm90) + .Inputs({"query", "key", "value", "output", "query_scale", "key_scale"}) + .Outputs({"out", "lse"}) + .SetInplaceMap({{"output", "out"}}) // Inplace + .Attrs({"tensor_layout: int", + "is_causal: int", + "qk_quant_gran: int", + "sm_scale: float", + "return_lse: int"}) + .SetKernelFn(PD_KERNEL(qk_int8_sv_f8_accum_f32_attn_inst_buf_sm90_fwd)) + .SetInferShapeFn(PD_INFER_SHAPE(qk_int8_sv_f8_accum_f32_attn_inst_buf_sm90_InferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(qk_int8_sv_f8_accum_f32_attn_inst_buf_sm90_InferDtype)); + std::vector qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90_fwd( paddle::Tensor& query, paddle::Tensor& key, @@ -875,4 +915,46 @@ std::vector qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_s }); return {lse}; -} \ No newline at end of file +} + +std::vector> qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90_InferShape( + std::vector query_shape, + std::vector key_shape, + std::vector value_shape, + std::vector output_shape, + std::vector query_scale_shape, + 
std::vector key_scale_shape, + std::vector value_scale_shape) { + + // force layout: NHD: [bsz, seq_len, num_heads, head_dim] + int64_t bsz = query_shape[0]; + int64_t seq_len = query_shape[1]; + int64_t h_qo = query_shape[2]; + + std::vector return_shape = {bsz, h_qo, seq_len}; + return {return_shape}; +} + +std::vector qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90_InferDtype( + paddle::DataType A_dtype, + paddle::DataType B_dtype, + paddle::DataType C_dtype, + paddle::DataType D_dtype, + paddle::DataType E_dtype, + paddle::DataType F_dtype, + paddle::DataType G_dtype) { + return {paddle::DataType::FLOAT32}; +} + +PD_BUILD_OP(qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90) + .Inputs({"query", "key", "value", "output", "query_scale", "key_scale", "value_scale"}) + .Outputs({"out1", "out2", "out3", "out4", "out5", "out6", "out7", "lse"}) + .SetInplaceMap({{"query", "out1"}, {"key", "out2"}, {"value", "out3"}, {"output", "out4"}, {"query_scale", "out5"}, {"key_scale", "out6"}, {"value_scale", "out7"}}) // Inplace + .Attrs({"tensor_layout: int", + "is_causal: int", + "qk_quant_gran: int", + "sm_scale: float", + "return_lse: int"}) + .SetKernelFn(PD_KERNEL(qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90_fwd)) + .SetInferShapeFn(PD_INFER_SHAPE(qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90_InferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90_InferDtype)); \ No newline at end of file diff --git a/csrc/gpu/sage_attn_kernels/sageattn_utils.cuh b/csrc/gpu/sage_attn_kernels/sageattn_utils.cuh index 49e0fb96893a..3e408208b0d1 100644 --- a/csrc/gpu/sage_attn_kernels/sageattn_utils.cuh +++ b/csrc/gpu/sage_attn_kernels/sageattn_utils.cuh @@ -1,4 +1,5 @@ #pragma once +#include #include #include #include @@ -22,6 +23,8 @@ #define S_FP8_OFFSET_EXP 6680.8477f #define S_FP8_OFFSET_EXP_INV 0.0022326917f +#define div_ceil(M, N) (((M) + (N)-1) / (N)) + #if (__CUDACC_VER_MAJOR__ * 10000 + __CUDACC_VER_MINOR__ * 100 >= 120400) #if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 890)) #define FP8_CAST_ENABLED @@ -49,6 +52,26 @@ throw std::invalid_argument(err_msg.str()); \ } +// add new support to HEAD_DIM = 192 for deepseek +#define DISPATCH_HEAD_DIM_QK(head_dim, HEAD_DIM, ...) \ + if (head_dim == 64) { \ + constexpr int HEAD_DIM = 64; \ + __VA_ARGS__ \ + } else if (head_dim == 128) { \ + constexpr int HEAD_DIM = 128; \ + __VA_ARGS__ \ + } else if (head_dim == 192) { \ + constexpr int HEAD_DIM = 192; \ + __VA_ARGS__ \ + } else if (head_dim == 256) { \ + constexpr int HEAD_DIM = 256; \ + __VA_ARGS__ \ + } else { \ + std::ostringstream err_msg; \ + err_msg << "Unsupported head dim: " << int(head_dim); \ + throw std::invalid_argument(err_msg.str()); \ + } + #define DISPATCH_CAUSAL(is_causal, IS_CAUSAL, ...) \ if (is_causal == 1) { \ constexpr bool IS_CAUSAL = true; \ @@ -1079,6 +1102,10 @@ __device__ __forceinline__ void pred_load_128b(T* smem_ptr, const T* gmem_ptr, b } // namespace cp_async +#ifndef USHORT_TYPE +#define USHORT_TYPE +typedef unsigned short ushort; +#endif // namespace math // math operations using ptx @@ -1516,10 +1543,6 @@ enum class ComputeUnit { kCudaCore, }; -inline __device__ __host__ size_t div_ceil(size_t a, size_t b) { - return (a % b != 0) ? 
(a / b + 1) : (a / b); -} - __device__ __forceinline__ uint32_t get_warp_id() { return threadIdx.y; @@ -1833,11 +1856,24 @@ __device__ __forceinline__ void update_mdo(float RS[][num_tiles_k][8], DTypeSVAc max(RS[fq][fk][k * 2 + 4], RS[fq][fk][k * 2 + 5])); m_temp = max(m_temp, m_local); } - // exchange element with the 4 threads in the row + if constexpr (!fuse_scale) { - m_temp *= sm_scale; + if constexpr (exp_offset) + { + m_temp = fmaf(m_temp, sm_scale, -S_FP8_OFFSET); + } + else + { + m_temp *= sm_scale; + } + } + else if constexpr (exp_offset) + { + m_temp += (-S_FP8_OFFSET); } + + // exchange element with the 4 threads in the row m_temp = max(m_temp, __shfl_xor_sync(0xffffffff, m_temp, 0x1)); // 0 exchange with 1, 2 exchange with 3 m_temp = max(m_temp, __shfl_xor_sync(0xffffffff, m_temp, 0x2)); // 0 exchange with 2, 1 exchange with 3 @@ -1848,7 +1884,11 @@ __device__ __forceinline__ void update_mdo(float RS[][num_tiles_k][8], DTypeSVAc // update denominator d[fq][k] *= o_scale; - half2 o_scale2 = __floats2half2_rn(o_scale, o_scale); + half2 o_scale2; + if constexpr (use_half_o_scale) + { + o_scale2 = __floats2half2_rn(o_scale, o_scale); + } // update RO #pragma unroll @@ -1880,10 +1920,6 @@ __device__ __forceinline__ void update_mdo(float RS[][num_tiles_k][8], DTypeSVAc // raise RS to exponent float negative_m = -m[fq][k]; - if constexpr (exp_offset) - { - negative_m += S_FP8_OFFSET; // times 400 to achieve smaller quantization error of fp8 S - } #pragma unroll for (uint32_t fk = 0; fk < num_tiles_k; fk++) { diff --git a/csrc/setup_cuda.py b/csrc/setup_cuda.py index 12a36c1e306b..9e45711b6fb7 100644 --- a/csrc/setup_cuda.py +++ b/csrc/setup_cuda.py @@ -136,7 +136,6 @@ def get_gencode_flags(): cc = get_sm_version() cuda_version = float(paddle.version.cuda()) -cuda_version = 12.4 if cc >= 80: sources += ["gpu/int8_gemm_with_cutlass/gemm_dequant.cu"] @@ -164,15 +163,26 @@ def get_gencode_flags(): "gpu/fp8_gemm_with_cutlass/fp8_fp8_fp8_dual_gemm.cu", ] -if cc >= 89 and cuda_version >= 12.4: +if cc >= 80 and cuda_version >= 12.4: nvcc_compile_args += [ "-std=c++17", "--use_fast_math", "--threads=8", "-D_GLIBCXX_USE_CXX11_ABI=1", ] - sources += find_end_files("./gpu/sage_attn_kernels", ".cu") - sources += ["./gpu/sage_attn_kernels/sageattn.cc"] + if cc >= 80: + sources += [ + "./gpu/sage_attn_kernels/sageattn_fused.cu", + "./gpu/sage_attn_kernels/sageattn_qk_int_sv_f16_kernel_sm80.cu" + ] + if cc >= 89: + sources += [ + "./gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_kernel_sm89.cu" + ] + if cc >= 90: + sources += [ + "./gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_kernel_sm90.cu" + ] if cc >= 90 and cuda_version >= 12.0: nvcc_compile_args += ["-DNDEBUG"] From 787011bf47b7ce05056e16bfbf163e96687b8084 Mon Sep 17 00:00:00 2001 From: l1cacheDell Date: Tue, 25 Feb 2025 14:37:51 +0800 Subject: [PATCH 05/18] update setup_cuda.py --- csrc/setup_cuda.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/csrc/setup_cuda.py b/csrc/setup_cuda.py index 9e45711b6fb7..20983c909250 100644 --- a/csrc/setup_cuda.py +++ b/csrc/setup_cuda.py @@ -175,14 +175,17 @@ def get_gencode_flags(): "./gpu/sage_attn_kernels/sageattn_fused.cu", "./gpu/sage_attn_kernels/sageattn_qk_int_sv_f16_kernel_sm80.cu" ] + nvcc_compile_args += ["-gencode", f"arch=compute_80,code=compute_80"] if cc >= 89: sources += [ "./gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_kernel_sm89.cu" ] + nvcc_compile_args += ["-gencode", f"arch=compute_89,code=compute_89"] if cc >= 90: sources += [ 
"./gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_kernel_sm90.cu" ] + nvcc_compile_args += ["-gencode", f"arch=compute_90a,code=compute_90a"] if cc >= 90 and cuda_version >= 12.0: nvcc_compile_args += ["-DNDEBUG"] From 1b9e4e0b3ec5d889b47750ec56fdc0446c316f8e Mon Sep 17 00:00:00 2001 From: l1cacheDell Date: Tue, 25 Feb 2025 17:20:35 +0800 Subject: [PATCH 06/18] update dsk MLA kernel --- csrc/gpu/sage_attn_kernels/core.py | 340 ++++++ csrc/gpu/sage_attn_kernels/quant.py | 130 ++ .../sageattn_qk_int_sv_f8_dsk_kernel_sm90.cu | 1066 +++++++++++++++++ csrc/setup_cuda.py | 11 +- 4 files changed, 1542 insertions(+), 5 deletions(-) create mode 100644 csrc/gpu/sage_attn_kernels/core.py create mode 100644 csrc/gpu/sage_attn_kernels/quant.py create mode 100644 csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_dsk_kernel_sm90.cu diff --git a/csrc/gpu/sage_attn_kernels/core.py b/csrc/gpu/sage_attn_kernels/core.py new file mode 100644 index 000000000000..d13557b9aa73 --- /dev/null +++ b/csrc/gpu/sage_attn_kernels/core.py @@ -0,0 +1,340 @@ +import paddle +import paddlenlp_ops + +from typing import Optional, Any +import warnings + +from .quant import per_channel_fp8 +from .quant import per_warp_int8 as per_warp_int8_cuda +from .quant import sub_mean + + +def sageattn_qk_int8_pv_fp16_cuda_sm80( + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, + tensor_layout: str = "HND", + is_causal: bool = False, + qk_quant_gran: str = "per_warp", + sm_scale: Optional[float] = None, + pv_accum_dtype: str = "fp32", + smooth_k: bool = True, + smooth_v: bool = False, + return_lse: bool = False, +) -> paddle.Tensor: + dtype = q.dtype + assert dtype in [paddle.float16, paddle.bfloat16], "Input tensors must be in dtype of torch.float16 or torch.bfloat16" + assert qk_quant_gran in ["per_warp", "per_thread"], "qk_quant_gran must be either 'per_warp' or 'per_thread'." + assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype." + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + _is_caual = 1 if is_causal else 0 + _qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2 + _return_lse = 1 if return_lse else 0 + + head_dim_og = q.shape[-1] + + if head_dim_og < 64: + q = paddle.nn.functional.pad(q, (0, 64 - head_dim_og)) + k = paddle.nn.functional.pad(k, (0, 64 - head_dim_og)) + v = paddle.nn.functional.pad(v, (0, 64 - head_dim_og)) + elif head_dim_og > 64 and head_dim_og < 128: + q = paddle.nn.functional.pad(q, (0, 128 - head_dim_og)) + k = paddle.nn.functional.pad(k, (0, 128 - head_dim_og)) + v = paddle.nn.functional.pad(v, (0, 128 - head_dim_og)) + elif head_dim_og > 128: + raise ValueError(f"Unsupported head_dim: {head_dim_og}") + + # assert last dim is contiguous + assert q.strides[-1] == 1 and k.strides[-1] == 1 and v.strides[-1] == 1, "Last dim of qkv must be contiguous." 
+ + if sm_scale is None: + sm_scale = head_dim_og**-0.5 + + seq_dim = 1 if _tensor_layout == 0 else 2 + + if smooth_k: + km = paddle.mean(k, axis=seq_dim, keepdim=True) + if return_lse: + if tensor_layout == "NHD": + lse_correction = paddle.squeeze(paddle.matmul(paddle.transpose(q, [0, 2, 1, 3], paddle.transpose(km, [0, 2, 3, 1]))), axis=-1) + else: + lse_correction = paddle.squeeze(paddle.matmul(q, paddle.transpose(km, [0, 1, 3, 2])), axis=-1) + else: + km = None + + if qk_quant_gran == "per_warp": + q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda(q, k, km, tensor_layout=tensor_layout, BLKQ=128, WARPQ=(16 if (q.shape[-1] == 128 and pv_accum_dtype == "fp16+fp32") else 32), BLKK=64) + + o = paddle.empty(q.shape, dtype=dtype) + + if pv_accum_dtype in ["fp32", "fp16+fp32"] and smooth_v: + warnings.warn(f"pv_accum_dtype is {pv_accum_dtype}, smooth_v will be ignored.") + smooth_v = False + + if pv_accum_dtype == 'fp32': + v = v.to(paddle.float16) + lse = paddlenlp_ops.qk_int8_sv_f16_accum_f32_attn(q_int8, k_int8, v, o, q_scale, k_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse) + elif pv_accum_dtype == "fp16": + if smooth_v: + smoothed_v, vm = sub_mean(v, tensor_layout=tensor_layout) + lse = paddlenlp_ops.qk_int8_sv_f16_accum_f16_fuse_v_mean_attn(q_int8, k_int8, smoothed_v, o, q_scale, k_scale, vm, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse) + else: + v = v.to(paddle.float16) + lse = paddlenlp_ops.qk_int8_sv_f16_accum_f16_attn(q_int8, k_int8, v, o, q_scale, k_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse) + elif pv_accum_dtype == "fp16+fp32": + v = v.to(paddle.float16) + lse = paddlenlp_ops.qk_int8_sv_f16_accum_f16_attn_inst_buf(q_int8, k_int8, v, o, q_scale, k_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse) + else: + raise ValueError(f"Unsupported pv_accum_dtype: {pv_accum_dtype}") + + o = o[..., :head_dim_og] + + if return_lse: + return o, lse / 1.44269504 + lse_correction * sm_scale if smooth_k else lse / 1.44269504 + else: + return o + +def sageattn_qk_int8_pv_fp8_cuda_sm89( + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, + tensor_layout: str = "NHD", + is_causal: bool = False, + qk_quant_gran: str = "per_warp", + sm_scale: Optional[float] = None, + pv_accum_dtype: str = "fp32+fp32", + smooth_k: bool = True, + smooth_v: bool = False, + return_lse: bool = False, +): + dtype = q.dtype + assert dtype in [paddle.float16, paddle.bfloat16], "Input tensors must be in dtype of torch.float16 or torch.bfloat16" + assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype." + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + _is_causal = 1 if is_causal else 0 + _qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2 + _return_lse = 1 if return_lse else 0 + + head_dim_og = q.shape[-1] + + if head_dim_og < 64: + q = paddle.nn.functional.pad(q, (0, 64 - head_dim_og)) + k = paddle.nn.functional.pad(k, (0, 64 - head_dim_og)) + v = paddle.nn.functional.pad(v, (0, 64 - head_dim_og)) + elif head_dim_og > 64 and head_dim_og < 128: + q = paddle.nn.functional.pad(q, (0, 128 - head_dim_og)) + k = paddle.nn.functional.pad(k, (0, 128 - head_dim_og)) + v = paddle.nn.functional.pad(v, (0, 128 - head_dim_og)) + elif head_dim_og > 128: + raise ValueError(f"Unsupported head_dim: {head_dim_og}") + + assert q.strides[-1] == 1 and k.strides[-1] == 1 and v.strides[-1] == 1, "Last dim of qkv must be contiguous." 
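For reference, here is a minimal usage sketch of the sm80 wrapper completed above; the import path, shapes, and sizes are hypothetical, and any head_dim up to 128 is padded internally:

import paddle
# Hypothetical module path: adjust to wherever core.py is installed.
from sage_attn_kernels.core import sageattn_qk_int8_pv_fp16_cuda_sm80

# HND layout: (batch, heads, seq_len, head_dim)
q = paddle.randn([2, 16, 1024, 128]).astype(paddle.float16)
k = paddle.randn([2, 16, 1024, 128]).astype(paddle.float16)
v = paddle.randn([2, 16, 1024, 128]).astype(paddle.float16)

o = sageattn_qk_int8_pv_fp16_cuda_sm80(q, k, v, tensor_layout="HND", is_causal=True)
print(o.shape)  # [2, 16, 1024, 128]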
+ + if sm_scale is None: + sm_scale = head_dim_og**-0.5 + + seq_dim = 1 if _tensor_layout == 0 else 2 + + if smooth_k: + km = paddle.mean(k, axis=seq_dim, keepdim=True) + if return_lse: + if tensor_layout == "NHD": + lse_correction = paddle.squeeze(paddle.matmul(paddle.transpose(q, [0, 2, 1, 3], paddle.transpose(km, [0, 2, 3, 1]))), axis=-1) + else: + lse_correction = paddle.squeeze(paddle.matmul(q, paddle.transpose(km, [0, 1, 3, 2])), axis=-1) + else: + km = None + + q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda(q, k, km, tensor_layout=tensor_layout) + + o = paddle.empty(q.shape, dtype=dtype) + if pv_accum_dtype == 'fp32+fp32' and smooth_v: + warnings.warn("pv_accum_dtype is 'fp32+fp32', smooth_v will be ignored.") + smooth_v = False + + v_fp8, v_scale, vm = per_channel_fp8(v, tensor_layout=tensor_layout, smooth_v=smooth_v) + + if pv_accum_dtype == "fp32": + if smooth_v: + lse = paddlenlp_ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, vm, _tensor_layout, _is_causal, _qk_quant_gran, sm_scale, _return_lse) + else: + lse = paddlenlp_ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_causal, _qk_quant_gran, sm_scale, _return_lse) + elif pv_accum_dtype == "fp32+fp32": + lse = paddlenlp_ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm89(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_causal, _qk_quant_gran, sm_scale, _return_lse) + + o = o[..., :head_dim_og] + + if return_lse: + return o, lse / 1.44269504 + lse_correction * sm_scale if smooth_k else lse / 1.44269504 + else: + return o + + +def sageattn_qk_int8_pv_fp8_cuda_sm90( + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, + tensor_layout: str = "HND", + is_causal: bool = False, + qk_quant_gran: str = "per_warp", + sm_scale: Optional[float] = None, + pv_accum_dtype: str = "fp32+fp32", + smooth_k: bool = True, + return_lse: bool = False, + **kwargs: Any, +) -> paddle.Tensor: + dtype = q.dtype + assert dtype in [paddle.float16, paddle.bfloat16], "Input tensors must be in dtype of torch.float16 or torch.bfloat16" + assert qk_quant_gran in ["per_warp", "per_thread"], "qk_quant_gran must be either 'per_warp' or 'per_thread'." + assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype." + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + _is_causal = 1 if is_causal else 0 + _qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2 + _return_lse = 1 if return_lse else 0 + + head_dim_og = q.shape[-1] + + if head_dim_og < 64: + q = paddle.nn.functional.pad(q, (0, 64 - head_dim_og)) + k = paddle.nn.functional.pad(k, (0, 64 - head_dim_og)) + v = paddle.nn.functional.pad(v, (0, 64 - head_dim_og)) + elif head_dim_og > 64 and head_dim_og < 128: + q = paddle.nn.functional.pad(q, (0, 128 - head_dim_og)) + k = paddle.nn.functional.pad(k, (0, 128 - head_dim_og)) + v = paddle.nn.functional.pad(v, (0, 128 - head_dim_og)) + elif head_dim_og > 128: + raise ValueError(f"Unsupported head_dim: {head_dim_og}") + + assert q.strides[-1] == 1 and k.strides[-1] == 1 and v.strides[-1] == 1, "Last dim of qkv must be contiguous." 
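The sm90 path below quantizes Q per warp with BLKQ = 64, WARPQ = 16, BLKK = 128 and pads V along the sequence dimension, which suggests the sm90 fp8 kernel consumes K/V in 128-wide tiles. A small sketch of that padding arithmetic, with a hypothetical kv_len:

# Mirrors the v_pad_len computation in the wrapper below.
kv_len = 1000
v_pad_len = 128 - (kv_len % 128) if kv_len % 128 != 0 else 0
padded_kv_len = kv_len + v_pad_len
print(v_pad_len, padded_kv_len)  # 24 1024 (a multiple of 128)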
+ + if sm_scale is None: + sm_scale = head_dim_og**-0.5 + + seq_dim = 1 if _tensor_layout == 0 else 2 + + if smooth_k: + km = paddle.mean(k, axis=seq_dim, keepdim=True) + if return_lse: + if tensor_layout == "NHD": + lse_correction = paddle.squeeze(paddle.matmul(paddle.transpose(q, [0, 2, 1, 3], paddle.transpose(km, [0, 2, 3, 1]))), axis=-1) + else: + lse_correction = paddle.squeeze(paddle.matmul(q, paddle.transpose(km, [0, 1, 3, 2])), axis=-1) + else: + km = None + + if qk_quant_gran == "per_warp": + q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda(q, k, km, tensor_layout=tensor_layout, BLKQ=64, WARPQ=16, BLKK=128) + + o = paddle.empty(q.shape, dtype=dtype) + + kv_len = k.shape[seq_dim] + v_pad_len = 128 - (kv_len % 128) if kv_len % 128 != 0 else 0 + if v_pad_len > 0: + if tensor_layout == "HND": + v = paddle.concat([v, paddle.zeros(shape=[v.shape[0], v.shape[1], v_pad_len, v.shape[3]], dtype=v.dtype)], axis=2) + else: + v = paddle.concat([v, paddle.zeros(shape=[v.shape[0], v_pad_len, v.shape[2], v.shape[3]], dtype=v.dtype)], axis=1) + + + v_fp8, v_scale, _ = per_channel_fp8(v, tensor_layout=tensor_layout, smooth_v=False) + + lse = paddlenlp_ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_causal, _qk_quant_gran, sm_scale, _return_lse) + + o = o[..., :head_dim_og] + + if return_lse: + return o, lse / 1.44269504 + lse_correction * sm_scale if smooth_k else lse / 1.44269504 + else: + return o + + +def sageattn_qk_int8_pv_fp8_cuda_dsk_sm90( + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, + tensor_layout: str = "HND", + is_causal: bool = False, + qk_quant_gran: str = "per_warp", + sm_scale: Optional[float] = None, + pv_accum_dtype: str = "fp32+fp32", + smooth_k: bool = True, + return_lse: bool = False, +) -> paddle.Tensor: + dtype = q.dtype + assert dtype in [paddle.float16, paddle.bfloat16], "Input tensors must be in dtype of torch.float16 or torch.bfloat16" + assert qk_quant_gran in ["per_warp", "per_thread"], "qk_quant_gran must be either 'per_warp' or 'per_thread'." + assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype." + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + _is_causal = 1 if is_causal else 0 + _qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2 + _return_lse = 1 if return_lse else 0 + + head_dim_og = q.shape[-1] + + pad_dim_tgt = 256 + + if head_dim_og < 64: + q = paddle.nn.functional.pad(q, (0, 64 - head_dim_og)) + k = paddle.nn.functional.pad(k, (0, 64 - head_dim_og)) + v = paddle.nn.functional.pad(v, (0, 64 - head_dim_og)) + elif head_dim_og > 64 and head_dim_og < 128: + q = paddle.nn.functional.pad(q, (0, 128 - head_dim_og)) + k = paddle.nn.functional.pad(k, (0, 128 - head_dim_og)) + v = paddle.nn.functional.pad(v, (0, 128 - head_dim_og)) + elif head_dim_og > 128 and head_dim_og < pad_dim_tgt: + q = paddle.nn.functional.pad(q, (0, pad_dim_tgt - head_dim_og)) + k = paddle.nn.functional.pad(k, (0, pad_dim_tgt - head_dim_og)) + elif head_dim_og > pad_dim_tgt: + raise ValueError(f"Unsupported head_dim: {head_dim_og}") + + assert q.strides[-1] == 1 and k.strides[-1] == 1 and v.strides[-1] == 1, "Last dim of qkv must be contiguous." 
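In the DeepSeek (MLA) path below, Q and K carry a 192-dim head, which this series treats as 128 "nope" channels plus a 64-dim RoPE part (matching the new 192 entry in DISPATCH_HEAD_DIM_QK and the HEAD_DIM_PE = 64 constant in the sm90 dsk launch code). After padding to 256, the split drops the 64 padding columns; a sketch with a hypothetical tensor:

import paddle

# Padded from head_dim 192 to 256; HND layout here: (batch, heads, seq_len, head_dim).
q_int8 = paddle.empty([1, 8, 1024, 256], dtype=paddle.int8)
q_nope, q_pe, _pad = paddle.split(q_int8, [128, 64, 64], axis=-1)
print(q_nope.shape[-1], q_pe.shape[-1])  # 128 64, i.e. HEAD_DIM and HEAD_DIM_PE seen by the kernel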
+ + if sm_scale is None: + sm_scale = head_dim_og**-0.5 + + seq_dim = 1 if _tensor_layout == 0 else 2 + + if smooth_k: + km = paddle.mean(k, axis=seq_dim, keepdim=True) + if return_lse: + if tensor_layout == "NHD": + lse_correction = paddle.squeeze(paddle.matmul(paddle.transpose(q, [0, 2, 1, 3], paddle.transpose(km, [0, 2, 3, 1]))), axis=-1) + else: + lse_correction = paddle.squeeze(paddle.matmul(q, paddle.transpose(km, [0, 1, 3, 2])), axis=-1) + else: + km = None + + if qk_quant_gran == "per_warp": + q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda(q, k, km, tensor_layout=tensor_layout, BLKQ=64, WARPQ=16, BLKK=128) + + o = paddle.empty(v.shape, dtype=dtype) + + kv_len = k.shape[seq_dim] + v_pad_len = 128 - (kv_len % 128) if kv_len % 128 != 0 else 0 + if v_pad_len > 0: + if tensor_layout == "HND": + v = paddle.concat([v, paddle.zeros(shape=[v.shape[0], v.shape[1], v_pad_len, v.shape[3]], dtype=v.dtype)], axis=2) + else: + v = paddle.concat([v, paddle.zeros(shape=[v.shape[0], v_pad_len, v.shape[2], v.shape[3]], dtype=v.dtype)], axis=1) + + v_fp8, v_scale, _ = per_channel_fp8(v, tensor_layout=tensor_layout, smooth_v=False) + if pad_dim_tgt == 256: + q_int8_nope, q_int8_pe, _ = paddle.split(q_int8, [128, 64, 64], axis=-1) + k_int8_nope, k_int8_pe, _ = paddle.split(k_int8, [128, 64, 64], axis=-1) + else: + q_int8_nope, q_int8_pe = paddle.split(q_int8, [128, 64], axis=-1) + k_int8_nope, k_int8_pe = paddle.split(k_int8, [128, 64], axis=-1) + + lse = paddlenlp_ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_dsk_sm90(q_int8_nope, k_int8_nope, q_int8_pe, k_int8_pe, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_causal, _qk_quant_gran, sm_scale, _return_lse) + + head_dim_og = v.shape[-1] + o = o[..., :head_dim_og] + + if return_lse: + return o, lse / 1.44269504 + lse_correction * sm_scale if smooth_k else lse / 1.44269504 + else: + return o diff --git a/csrc/gpu/sage_attn_kernels/quant.py b/csrc/gpu/sage_attn_kernels/quant.py new file mode 100644 index 000000000000..9124810bc44b --- /dev/null +++ b/csrc/gpu/sage_attn_kernels/quant.py @@ -0,0 +1,130 @@ +import paddle +import paddlenlp_ops + +from typing import Optional + +def per_block_int8( + q: paddle.Tensor, + k: paddle.Tensor, + km: Optional[paddle.Tensor] = None, + BLKQ: int =128, + BLKK: int =64, + sm_scale: Optional[float] = None, + tensor_layout: str ="HND" +): + q_int8 = paddle.empty(q.shape, dtype=paddle.int8) + k_int8 = paddle.empty(k.shape, dtype=paddle.int8) + + if tensor_layout == "HND": + b, h_qo, qo_len, head_dim = q.shape + _, h_kv, kv_len, _ = k.shape + elif tensor_layout == "NHD": + b, qo_len, h_qo, head_dim = q.shape + _, kv_len, h_kv, _ = k.shape + else: + raise ValueError(f"Unknown tensor layout: {tensor_layout}") + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + + q_scale = paddle.empty((b, h_qo, (qo_len + BLKQ - 1) // BLKQ), dtype=paddle.float32) + k_scale = paddle.empty((b, h_kv, (kv_len + BLKK - 1) // BLKK), dtype=paddle.float32) + + if sm_scale is None: + sm_scale = head_dim**-0.5 + + sm_scale *= 1.44269504 + + paddlenlp_ops.quant_per_block_int8_cuda(q, q_int8, q_scale, sm_scale, BLKQ, _tensor_layout) + if km is not None: + km = km.squeeze(1) if _tensor_layout == 0 else km.squeeze(2) + paddlenlp_ops.quant_per_block_int8_fuse_sub_mean_cuda(k, km, k_int8, k_scale, BLKK, _tensor_layout) + else: + paddlenlp_ops.quant_per_block_int8_cuda(k, k_int8, k_scale, BLKK, _tensor_layout) + + return q_int8, q_scale, k_int8, k_scale + + +def per_warp_int8( + q: paddle.Tensor, + k: paddle.Tensor, + km: 
Optional[paddle.Tensor] = None, + BLKQ: int =128, + WARPQ: int =32, + BLKK: int =64, + tensor_layout: str ="HND" +): + q_int8 = paddle.empty(shape=q.shape, dtype=paddle.int8) + k_int8 = paddle.empty(shape=k.shape, dtype=paddle.int8) + + if tensor_layout == "HND": + b, h_qo, qo_len, head_dim = q.shape + _, h_kv, kv_len, _ = k.shape + + elif tensor_layout == "NHD": + b, qo_len, h_qo, head_dim = q.shape + _, kv_len, h_kv, _ = k.shape + + else: + raise ValueError(f"Unknown tensor layout: {tensor_layout}") + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + + q_scale = paddle.empty((b, h_qo, ((qo_len + BLKQ - 1) // BLKQ) * (BLKQ // WARPQ)), dtype=paddle.float32) + k_scale = paddle.empty((b, h_kv, (kv_len + BLKK - 1) // BLKK), dtype=paddle.float32) + + paddlenlp_ops.quant_per_warp_int8_cuda(q, q_int8, q_scale, BLKQ, WARPQ, _tensor_layout) + + if km is not None: + km = km.squeeze(1) if _tensor_layout == 0 else km.squeeze(2) + paddlenlp_ops.quant_per_block_int8_fuse_sub_mean_cuda(k, km, k_int8, k_scale, BLKK, _tensor_layout) + else: + paddlenlp_ops.quant_per_block_int8_cuda(k, k_int8, k_scale, BLKK, _tensor_layout) + + return q_int8, q_scale, k_int8, k_scale + + +def per_channel_fp8( + v: paddle.Tensor, + tensor_layout: str ="NHD", + scale_max: float = 448.0, + smooth_v: bool = True +): + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + + if tensor_layout == "HND": + b, h_kv, kv_len, head_dim = v.shape + padded_len = (kv_len + 63) // 64 * 64 + v_transposed_permutted = paddle.empty((b, h_kv, head_dim, padded_len), dtype=v.dtype) + + elif tensor_layout == "NHD": + b, kv_len, h_kv, head_dim = v.shape + padded_len = (kv_len + 63) // 64 * 64 + v_transposed_permutted = paddle.empty((b, head_dim, h_kv, padded_len), dtype=v.dtype) + paddlenlp_ops.transpose_pad_permute_cuda(v, v_transposed_permutted, _tensor_layout) + + v_fp8 = paddle.empty(v_transposed_permutted.shape, dtype=paddle.float8_e4m3fn) + + v_scale = paddle.empty((b, h_kv, head_dim), dtype=paddle.float32) + vm = paddle.empty((b, h_kv, head_dim), dtype=paddle.float32) + + if smooth_v: + paddlenlp_ops.mean_scale_fuse_quant_cuda(v_transposed_permutted, v_fp8, vm, v_scale, kv_len, scale_max, _tensor_layout) + return v_fp8, v_scale, vm + else: + paddlenlp_ops.scale_fuse_quant_cuda(v_transposed_permutted, v_fp8, v_scale, kv_len, scale_max, _tensor_layout) + return v_fp8, v_scale, None + + +def sub_mean( + v: paddle.Tensor, + tensor_layout: str ="HND" +): + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + vm = v.mean(dim=1 if _tensor_layout == 0 else 2) + + v_smoothed = paddle.empty(v.shape, dtype=paddle.float16) + + # subtract mean and store the result as fp16 + paddlenlp_ops.sub_mean_cuda(v, vm, v_smoothed, _tensor_layout) + + return v_smoothed, vm \ No newline at end of file diff --git a/csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_dsk_kernel_sm90.cu b/csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_dsk_kernel_sm90.cu new file mode 100644 index 000000000000..7d16f911396a --- /dev/null +++ b/csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_dsk_kernel_sm90.cu @@ -0,0 +1,1066 @@ +#include + +#include "paddle/extension.h" + +#include "sageattn_utils.cuh" + +template +CUtensorMap create_tensor_map_4D(T* gmem_ptr, int d1, int d2, int d3, int d4, int stride1, int stride2, int stride3) { + constexpr int smem_stride = BlockMinorSize * sizeof(T); + static_assert(sizeof(T) == 2 || sizeof(T) == 1); + static_assert(smem_stride == 32 || smem_stride == 64 || smem_stride == 128); + + CUtensorMap tma_map; + void* gmem_address = (void*)gmem_ptr; 
+ uint64_t gmem_prob_shape[5] = {(uint64_t)d4, (uint64_t)d3, (uint64_t)d2, (uint64_t)d1, 1}; + uint64_t gmem_prob_stride[5] = {(uint64_t)stride3 * sizeof(T), (uint64_t)stride2 * sizeof(T), (uint64_t)stride1 * sizeof(T), 0, 0}; + uint32_t smem_box_shape[5] = {uint32_t(BlockMinorSize), uint32_t(BlockMajorSize), 1, 1, 1}; + uint32_t smem_box_stride[5] = {1, 1, 1, 1, 1}; + + CUresult result = cuTensorMapEncodeTiled( + &tma_map, (sizeof(T) == 2) ? CU_TENSOR_MAP_DATA_TYPE_BFLOAT16 : CU_TENSOR_MAP_DATA_TYPE_UINT8, 4, gmem_address, gmem_prob_shape, + gmem_prob_stride, smem_box_shape, smem_box_stride, CU_TENSOR_MAP_INTERLEAVE_NONE, + (swizzle == false) ? CU_TENSOR_MAP_SWIZZLE_NONE : (smem_stride == 128) ? CU_TENSOR_MAP_SWIZZLE_128B : (smem_stride == 64) ? CU_TENSOR_MAP_SWIZZLE_64B : CU_TENSOR_MAP_SWIZZLE_32B, + promotion_mode, CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE); + + assert(result == CUDA_SUCCESS); + + return tma_map; +} + +__device__ __forceinline__ void init_barrier(uint64_t* bar, int thread_count) { + uint32_t bar_ptr = static_cast(__cvta_generic_to_shared(bar)); + asm volatile ( + "mbarrier.init.shared::cta.b64 [%0], %1;\n" + :: "r"(bar_ptr), "r"(thread_count) + ); +} + +template +__device__ __forceinline__ void expect_bytes(uint64_t* bar) { + uint32_t bar_ptr = static_cast(__cvta_generic_to_shared(bar)); + asm volatile ("mbarrier.arrive.expect_tx.shared::cta.b64 _, [%0], %1;\n" + :: "r"(bar_ptr), "n"(bytes)); +} + +template +__device__ __forceinline__ void load_async_4D(T *dst, void const* const src_tma_map, uint64_t* bar, int s0, int s1, int s2, int s3) { + uint64_t tma_ptr = reinterpret_cast(src_tma_map); + uint32_t mbar_ptr = static_cast(__cvta_generic_to_shared(bar)); + uint32_t dst_ptr = static_cast(__cvta_generic_to_shared(dst)); + + asm volatile ( + "cp.async.bulk.tensor.4d.shared::cluster.global.tile.mbarrier::complete_tx::bytes" + " [%0], [%1, {%3, %4, %5, %6}], [%2];" + : + : "r"(dst_ptr), "l"(tma_ptr), "r"(mbar_ptr), + "r"(s0), "r"(s1), "r"(s2), "r"(s3) + : "memory" + ); +} + +template +__device__ __forceinline__ void store_async_4D(void const* dst_tma_map, T *src, int global_token_idx, int global_head_idx, int global_batch_idx) { + uint64_t tma_ptr = reinterpret_cast(dst_tma_map); + uint32_t src_ptr = static_cast(__cvta_generic_to_shared(src)); + + asm volatile ( + "cp.async.bulk.tensor.4d.global.shared::cta.tile.bulk_group" + " [%0, {%2, %3, %4, %5}], [%1];" + : + : "l"(tma_ptr), "r"(src_ptr), + "n"(0), "r"(global_token_idx), "r"(global_head_idx), "r"(global_batch_idx) + : "memory" + ); +} + +__device__ __forceinline__ void wait(uint64_t* bar, int kPhaseBit) { + uint32_t mbar_ptr = static_cast(__cvta_generic_to_shared(bar)); + asm volatile ( + "{\n" + ".reg .pred P1;\n" + "LAB_WAIT:\n" + "mbarrier.try_wait.parity.shared::cta.b64 P1, [%0], %1;\n" + "@P1 bra.uni DONE;\n" + "bra.uni LAB_WAIT;\n" + "DONE:\n" + "}\n" + :: "r"(mbar_ptr), + "r"(kPhaseBit) + ); +} + +template +__device__ __forceinline__ void arrive(uint64_t* bar) { + uint32_t mbar_ptr = static_cast(__cvta_generic_to_shared(bar)); + asm volatile ( + "mbarrier.arrive.release.cta.shared::cta.b64 _, [%0], %1;\n" + : + : "r"(mbar_ptr), "n"(count) + : "memory" + ); +} + +// +// ======= kernel impl ======= +// + +template +__global__ void qk_int8_sv_f8_attn_dsk_kernel(const __grid_constant__ CUtensorMap tensorMapQ, + const __grid_constant__ CUtensorMap tensorMapK, + const __grid_constant__ CUtensorMap tensorMapQ_pe, + const __grid_constant__ CUtensorMap tensorMapK_pe, + const __grid_constant__ CUtensorMap tensorMapV, + float 
*__restrict__ Q_scale, float *__restrict__ K_scale, float *__restrict__ V_scale, + DTypeOut* O, uint32_t stride_bz_o, uint32_t stride_h_o, uint32_t stride_seq_o, + const uint32_t qo_len, const uint32_t kv_len, const uint32_t num_kv_groups, + float sm_scale) +{ + static_assert(NUM_THREADS == 128); + static_assert(CTA_Q <= CTA_K); + + const uint32_t warp_idx = (threadIdx.x % 128) / 32; + const uint32_t lane_id = threadIdx.x % 32; + + constexpr uint32_t num_tiles_q = CTA_Q / 64; + constexpr uint32_t num_tiles_k = CTA_K / 16; + constexpr uint32_t num_tiles_qk_inner = head_dim / 32; + constexpr uint32_t num_tiles_qk_pe_inner = head_dim_pe / 32; + constexpr uint32_t num_tiles_v = head_dim / 16; + constexpr uint32_t num_tiles_pv_inner = CTA_K / 32; + + const uint32_t batch_id = blockIdx.z; + const uint32_t bx = blockIdx.x; + const uint32_t head_id = blockIdx.y; + const uint32_t num_qo_heads = gridDim.y; + const uint32_t kv_head_id = head_id / num_kv_groups; + + sm_scale *= math::log2e; + + extern __shared__ __align__(128) int8_t smem_[]; + + /* // original: + * int8_t *sQ = (int8_t*)smem_; + * int8_t *sK = (int8_t*)(smem_ + CTA_Q * head_dim * sizeof(int8_t)); + * int8_t *sV = (int8_t*)(smem_ + CTA_Q * head_dim * sizeof(int8_t) + CTA_K * head_dim * sizeof(int8_t)); + * half *sO = (half*)smem_; + */ + + int8_t *sQ = (int8_t*)smem_; + int8_t *sK = (int8_t*)(smem_ + CTA_Q * head_dim * sizeof(int8_t)); + int8_t *sV = (int8_t*)(smem_ + CTA_Q * head_dim * sizeof(int8_t) + CTA_K * head_dim * sizeof(int8_t)); + int8_t *sQ_pe = (int8_t*)(smem_ + CTA_Q * head_dim * sizeof(int8_t) + CTA_K * head_dim * sizeof(int8_t) + CTA_K * head_dim * sizeof(int8_t)); + int8_t *sK_pe = (int8_t*)(smem_ + CTA_Q * head_dim * sizeof(int8_t) + CTA_K * head_dim * sizeof(int8_t) + CTA_K * head_dim * sizeof(int8_t) + CTA_Q * head_dim_pe * sizeof(int8_t)); + + half *sO = (half*)smem_; + + int32_t RS[num_tiles_q][num_tiles_k][8]; + int32_t RS_pe[num_tiles_q][num_tiles_k][8]; + float RO[num_tiles_q][num_tiles_v][8]; + float m[num_tiles_q][2]; + float d[num_tiles_q][2]; + + uint32_t q_scale_idx, k_scale_idx; + + // scale shape: (b, h_qo, (qo_len + BLKQ - 1) // BLKQ) + if constexpr (Q_GRAN == QuantGranularity::kPerBlock) + { + const uint32_t num_block_q = gridDim.x; + q_scale_idx = batch_id * num_qo_heads * num_block_q + head_id * num_block_q + bx; + } + else if constexpr (Q_GRAN == QuantGranularity::kPerWarp) + { + const uint32_t num_warp_block_q = gridDim.x * 4; + q_scale_idx = batch_id * num_qo_heads * num_warp_block_q + head_id * num_warp_block_q + bx * 4 + warp_idx; + } + else if constexpr (Q_GRAN == QuantGranularity::kPerThread) + { + const uint32_t num_warp_block_q = gridDim.x * 4; + q_scale_idx = batch_id * num_qo_heads * (num_warp_block_q * 8) + head_id * (num_warp_block_q * 8) + bx * (4 * 8) + warp_idx * 8 + lane_id / 4; + } + + if constexpr (K_GRAN == QuantGranularity::kPerBlock || K_GRAN == QuantGranularity::kPerWarp) + { + const uint32_t num_block_k = div_ceil(kv_len, CTA_K); + k_scale_idx = batch_id * (num_qo_heads / num_kv_groups) * num_block_k + (head_id / num_kv_groups) * num_block_k; + } + else if constexpr (K_GRAN == QuantGranularity::kPerThread) + { + const uint32_t num_block_k = div_ceil(kv_len, CTA_K); + k_scale_idx = batch_id * (num_qo_heads / num_kv_groups) * (num_block_k * 4) + (head_id / num_kv_groups) * (num_block_k * 4) + lane_id % 4; + } + + constexpr uint32_t k_scale_advance_offset = (K_GRAN == QuantGranularity::kPerBlock || K_GRAN == QuantGranularity::kPerWarp) ? 
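+  // Illustrative note on the scale bookkeeping: per-block / per-warp K granularity stores one
+  // K scale per CTA_K-wide KV block, so the index advances by 1 per outer iteration, while
+  // per-thread granularity stores 4 scales per KV block (selected by lane_id % 4 above), so it
+  // advances by 4.  E.g. with assumed kv_len = 1024 and CTA_K = 128, a per-thread K_scale row
+  // holds div_ceil(1024, 128) * 4 = 32 entries per (batch, kv_head).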
1 : 4; + + uint32_t Q_idx_lane_base = bx * CTA_Q + warp_idx * 16 + lane_id / 4; + +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { + m[fq][0] = -5000000.0f; + m[fq][1] = -5000000.0f; + d[fq][0] = 1.0f; + d[fq][1] = 1.0f; + } + +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fv = 0; fv < num_tiles_v; fv++) + { +#pragma unroll + for (uint32_t k = 0; k < 8; k++) + { + RO[fq][fv][k] = 0.0f; + } + } + } + + __shared__ __align__(8) uint64_t barrier_Q; + __shared__ __align__(8) uint64_t barrier_K; + __shared__ __align__(8) uint64_t barrier_Q_pe; + __shared__ __align__(8) uint64_t barrier_K_pe; + __shared__ __align__(8) uint64_t barrier_V; + + if (threadIdx.x == 0) { + init_barrier(&barrier_Q, 1); + init_barrier(&barrier_K, 1); + init_barrier(&barrier_Q_pe, 1); + init_barrier(&barrier_K_pe, 1); + init_barrier(&barrier_V, 1); + } + + __syncthreads(); + + // load Q, K, V + if (threadIdx.x == 0) + { + expect_bytes<(CTA_Q * (head_dim)) * sizeof(int8_t)>(&barrier_Q); + expect_bytes<(CTA_K * (head_dim)) * sizeof(int8_t)>(&barrier_K); + expect_bytes<(CTA_Q * (head_dim_pe)) * sizeof(int8_t)>(&barrier_Q_pe); + expect_bytes<(CTA_K * (head_dim_pe)) * sizeof(int8_t)>(&barrier_K_pe); + expect_bytes<(CTA_K * (head_dim)) * sizeof(int8_t)>(&barrier_V); + + load_async_4D(sQ, &tensorMapQ, &barrier_Q, 0, bx * CTA_Q, head_id, batch_id); + load_async_4D(sK, &tensorMapK, &barrier_K, 0, 0, kv_head_id, batch_id); + load_async_4D(sV, &tensorMapV, &barrier_V, 0, 0, kv_head_id, batch_id); + load_async_4D(sQ_pe, &tensorMapQ_pe, &barrier_Q_pe, 0, bx * CTA_Q, head_id, batch_id); + load_async_4D(sK_pe, &tensorMapK_pe, &barrier_K_pe, 0, 0, kv_head_id, batch_id); + } + + float q_scale = Q_scale[q_scale_idx]; + float original_sm_scale = sm_scale; + + // wait for Q + wait(&barrier_Q, 0); + wait(&barrier_Q_pe, 0); + + const uint32_t num_iterations = div_ceil( + mask_mode == MaskMode::kCausal + ? 
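+      // Causal early exit (worked example with assumed sizes): the queries of block bx end at row
+      // (bx + 1) * CTA_Q - 1, so no KV block beyond that row can contribute.  With CTA_Q = 64,
+      // CTA_K = 128, kv_len = 4096 and bx = 2 this gives div_ceil(min(4096, 192), 128) = 2
+      // iterations instead of div_ceil(4096, 128) = 32.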
min(kv_len, (bx + 1) * CTA_Q) + : kv_len, + CTA_K); + + int p = 1; + for (uint32_t iter = 1; iter < num_iterations; iter++) + { + p ^= 1; + + float dequant_scale = q_scale * K_scale[k_scale_idx + (iter - 1) * k_scale_advance_offset]; + sm_scale = original_sm_scale * dequant_scale; + + // wait for K + wait(&barrier_K, p); + wait(&barrier_K_pe, p); + + // compute QK^T + wgmma::warpgroup_arrive(); +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { + int8_t *sQ_local = sQ + fq * 64 * head_dim; + int8_t *sQ_local_pe = sQ_pe + fq * 64 * head_dim_pe; + wgmma::wgmma_s8s8s32(RS[fq], sQ_local, sK); + wgmma::wgmma_s8s8s32(RS_pe[fq], sQ_local_pe, sK_pe); // add one line +#pragma unroll + for (int k_it = 1; k_it < num_tiles_qk_inner; k_it++) + { + wgmma::wgmma_s8s8s32(RS[fq], &sQ_local[k_it*32], &sK[k_it*32]); + if (k_it < num_tiles_qk_pe_inner) { + wgmma::wgmma_s8s8s32(RS_pe[fq], &sQ_local_pe[k_it*32], &sK_pe[k_it*32]); // add one line + } + } + } + wgmma::warpgroup_commit_batch(); + wgmma::warpgroup_wait<0>(); + + // load K + if (threadIdx.x == 0) + { + expect_bytes<(CTA_K * head_dim) * sizeof(int8_t)>(&barrier_K); + expect_bytes<(CTA_K * head_dim_pe) * sizeof(int8_t)>(&barrier_K_pe); // add one line + load_async_4D(sK, &tensorMapK, &barrier_K, 0, iter * CTA_K, kv_head_id, batch_id); + load_async_4D(sK_pe, &tensorMapK_pe, &barrier_K_pe, 0, iter * CTA_K, kv_head_id, batch_id); // add one line + } + + // convert RS to float + float RS_f32[num_tiles_q][num_tiles_k][8]; +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fk = 0; fk < num_tiles_k; fk++) + { +#pragma unroll + for (uint32_t k = 0; k < 8; k++) + { + RS_f32[fq][fk][k] = __int2float_rz(RS[fq][fk][k] + RS_pe[fq][fk][k]); // add one line + } + } + } + + update_mdo(RS_f32, RO, m, d, sm_scale); + + // accumulate d on thread basis +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unrol + for (uint32_t fk = 0; fk < num_tiles_k; fk++) + { + d[fq][0] += (RS_f32[fq][fk][0] + RS_f32[fq][fk][1] + RS_f32[fq][fk][4] + RS_f32[fq][fk][5]); + d[fq][1] += (RS_f32[fq][fk][2] + RS_f32[fq][fk][3] + RS_f32[fq][fk][6] + RS_f32[fq][fk][7]); + } + } + + uint32_t RS_f8[num_tiles_q][num_tiles_pv_inner][4]; + RS_32_to_8(RS_f32, RS_f8); + + // wait for V + wait(&barrier_V, p); + + float RO_temp[num_tiles_q][num_tiles_v][8]; + wgmma::warpgroup_arrive(); +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { + wgmma::wgmma_f8f8f32(RO_temp[fq], RS_f8[fq][0], &sV[0]); +#pragma unroll + for (uint32_t v_it = 1; v_it < num_tiles_pv_inner; v_it++) + { + wgmma::wgmma_f8f8f32(RO_temp[fq], RS_f8[fq][v_it], &sV[v_it * 32]); + } + } + + wgmma::warpgroup_commit_batch(); + wgmma::warpgroup_wait<0>(); + +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fv = 0; fv < num_tiles_v; fv++) + { +#pragma unroll + for (uint32_t k = 0; k < 8; k++) + { + RO[fq][fv][k] += RO_temp[fq][fv][k]; + } + } + } + + // load V + if (threadIdx.x == 0) + { + expect_bytes<(CTA_K * head_dim) * sizeof(int8_t)>(&barrier_V); + load_async_4D(sV, &tensorMapV, &barrier_V, iter * CTA_K, 0, kv_head_id, batch_id); + } + } + + { + p ^= 1; + + float dequant_scale = q_scale * K_scale[k_scale_idx + (num_iterations - 1) * k_scale_advance_offset]; + sm_scale = original_sm_scale; + + // wait for K + wait(&barrier_K, p); + wait(&barrier_K_pe, p); + + // compute QK^T + wgmma::warpgroup_arrive(); +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { + int8_t 
*sQ_local = sQ + fq * 64 * head_dim; + int8_t *sQ_local_pe = sQ_pe + fq * 64 * head_dim_pe; + wgmma::wgmma_s8s8s32(RS[fq], sQ_local, sK); + wgmma::wgmma_s8s8s32(RS_pe[fq], sQ_local_pe, sK_pe); +#pragma unroll + for (int k_it = 1; k_it < num_tiles_qk_inner; k_it++) + { + wgmma::wgmma_s8s8s32(RS[fq], &sQ_local[k_it*32], &sK[k_it*32]); + if (k_it < num_tiles_qk_pe_inner) { + wgmma::wgmma_s8s8s32(RS_pe[fq], &sQ_local_pe[k_it*32], &sK_pe[k_it*32]); // add one line + } + } + } + wgmma::warpgroup_commit_batch(); + wgmma::warpgroup_wait<0>(); + + // convert RS to float + float RS_f32[num_tiles_q][num_tiles_k][8]; +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fk = 0; fk < num_tiles_k; fk++) + { +#pragma unroll + for (uint32_t k = 0; k < 8; k++) + { + RS_f32[fq][fk][k] = __int2float_rz(RS[fq][fk][k] + RS_pe[fq][fk][k]) * dequant_scale; + } + } + } + + // masking +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fk = 0; fk < num_tiles_k; fk++) + { +#pragma unroll + for (uint32_t k = 0; k < 8; k++) + { + const uint32_t q_idx = Q_idx_lane_base + fq * 64 + 8 * ((k % 4) / 2); + const uint32_t k_idx = (num_iterations - 1) * CTA_K + fk * 16 + 2 * (lane_id % 4) + 8 * (k / 4) + k % 2; + + bool is_out_of_bounds; + + if constexpr (mask_mode == MaskMode::kCausal) + { + is_out_of_bounds = (k_idx > q_idx) || (k_idx >= kv_len); + } + else + { + is_out_of_bounds = (k_idx >= kv_len); + } + + if (is_out_of_bounds) + { + RS_f32[fq][fk][k] = -5000000.0f; + } + } + } + } + + update_mdo(RS_f32, RO, m, d, sm_scale); + + // accumulate d on thread basis +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unrol + for (uint32_t fk = 0; fk < num_tiles_k; fk++) + { + d[fq][0] += (RS_f32[fq][fk][0] + RS_f32[fq][fk][1] + RS_f32[fq][fk][4] + RS_f32[fq][fk][5]); + d[fq][1] += (RS_f32[fq][fk][2] + RS_f32[fq][fk][3] + RS_f32[fq][fk][6] + RS_f32[fq][fk][7]); + } + } + + uint32_t RS_f8[num_tiles_q][num_tiles_pv_inner][4]; + RS_32_to_8(RS_f32, RS_f8); + + // wait for V + wait(&barrier_V, p); + + float RO_temp[num_tiles_q][num_tiles_v][8]; + wgmma::warpgroup_arrive(); +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { + wgmma::wgmma_f8f8f32(RO_temp[fq], RS_f8[fq][0], &sV[0]); +#pragma unroll + for (uint32_t v_it = 1; v_it < num_tiles_pv_inner; v_it++) + { + wgmma::wgmma_f8f8f32(RO_temp[fq], RS_f8[fq][v_it], &sV[v_it * 32]); + } + } + + wgmma::warpgroup_commit_batch(); + wgmma::warpgroup_wait<0>(); + +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fv = 0; fv < num_tiles_v; fv++) + { +#pragma unroll + for (uint32_t k = 0; k < 8; k++) + { + RO[fq][fv][k] += RO_temp[fq][fv][k]; + } + } + } + } + + normalize_d(RO, m, d); + + if constexpr (fuse_v_scale) + { + float v_scale[4]; + float *V_scale_base_ptr = V_scale + batch_id * (num_qo_heads / num_kv_groups) * head_dim + (head_id / num_kv_groups) * head_dim + (lane_id % 4 ) * 2; + #pragma unroll + for (uint32_t fv = 0; fv < num_tiles_v; fv++) + { + ((float2*)v_scale)[0] = *((float2*)(V_scale_base_ptr + fv * 16)); + ((float2*)v_scale)[1] = *((float2*)(V_scale_base_ptr + fv * 16 + 8)); + + #pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { + RO[fq][fv][0] *= v_scale[0]; + RO[fq][fv][1] *= v_scale[1]; + RO[fq][fv][2] *= v_scale[0]; + RO[fq][fv][3] *= v_scale[1]; + RO[fq][fv][4] *= v_scale[2]; + RO[fq][fv][5] *= v_scale[3]; + RO[fq][fv][6] *= v_scale[2]; + RO[fq][fv][7] *= v_scale[3]; + } + 
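+      // Descriptive note: with c = fv * 16 + (lane_id % 4) * 2, registers {0,1}/{2,3} of this
+      // accumulator tile hold columns c and c+1 (for rows r and r+8) and registers {4,5}/{6,7}
+      // hold columns c+8 and c+9, so four scale values per thread cover the whole 16-wide tile.
+      // Folding the per-channel dequantization of V here avoids a separate pass over RO.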
} + } + + DTypeOut *O_lane_ptr = O + batch_id * stride_bz_o + head_id * stride_h_o + (bx * CTA_Q + warp_idx * 16 + (lane_id / 4)) * stride_seq_o + (lane_id % 4) * 2 ; +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fv = 0; fv < head_dim/16; fv++) + { + if (Q_idx_lane_base + fq * 64 < qo_len) + { + if constexpr (std::is_same::value) + { + ((half2*)(O_lane_ptr + fq * 64 * stride_seq_o + fv * 16))[0] = __float22half2_rn(((float2*)(RO[fq][fv]))[0]); + ((half2*)(O_lane_ptr + fq * 64 * stride_seq_o + fv * 16 + 8))[0] = __float22half2_rn(((float2*)(RO[fq][fv]))[2]); + } + else + { + ((nv_bfloat162*)(O_lane_ptr + fq * 64 * stride_seq_o + fv * 16))[0] = __float22bfloat162_rn(((float2*)(RO[fq][fv]))[0]); + ((nv_bfloat162*)(O_lane_ptr + fq * 64 * stride_seq_o + fv * 16 + 8))[0] = __float22bfloat162_rn(((float2*)(RO[fq][fv]))[2]); + } + } + + if (Q_idx_lane_base + fq * 64 + 8 < qo_len) + { + if constexpr (std::is_same::value) + { + ((half2*)(O_lane_ptr + fq * 64 * stride_seq_o + fv * 16 + 8 * stride_seq_o))[0] = __float22half2_rn(((float2*)(RO[fq][fv]))[1]); + ((half2*)(O_lane_ptr + fq * 64 * stride_seq_o + fv * 16 + 8 + 8 * stride_seq_o))[0] = __float22half2_rn(((float2*)(RO[fq][fv]))[3]); + } + else + { + ((nv_bfloat162*)(O_lane_ptr + fq * 64 * stride_seq_o + fv * 16 + 8 * stride_seq_o))[0] = __float22bfloat162_rn(((float2*)(RO[fq][fv]))[1]); + ((nv_bfloat162*)(O_lane_ptr + fq * 64 * stride_seq_o + fv * 16 + 8 + 8 * stride_seq_o))[0] = __float22bfloat162_rn(((float2*)(RO[fq][fv]))[3]); + } + } + } + } +} + +std::vector qk_int8_sv_f8_accum_f32_attn_inst_buf_dsk_sm90_fwd( + paddle::Tensor& query, + paddle::Tensor& key, + paddle::Tensor& query_pe, + paddle::Tensor& key_pe, + paddle::Tensor& value, + paddle::Tensor& output, + paddle::Tensor& query_scale, + paddle::Tensor& key_scale, + int tensor_layout, + int is_causal, + int qk_quant_gran, + float sm_scale, + int return_lse) +{ + CHECK_CUDA(query); + CHECK_CUDA(key); + CHECK_CUDA(query_pe); + CHECK_CUDA(key_pe); + CHECK_CUDA(value); + CHECK_CUDA(output); + CHECK_CUDA(query_scale); + CHECK_CUDA(key_scale); + + CHECK_LASTDIM_CONTIGUOUS(query); + CHECK_LASTDIM_CONTIGUOUS(key); + CHECK_LASTDIM_CONTIGUOUS(query_pe); + CHECK_LASTDIM_CONTIGUOUS(key_pe); + CHECK_LASTDIM_CONTIGUOUS(value); + CHECK_LASTDIM_CONTIGUOUS(output); + CHECK_CONTIGUOUS(query_scale); + CHECK_CONTIGUOUS(key_scale); + + CHECK_DTYPE(query, paddle::DataType::INT8); + CHECK_DTYPE(key, paddle::DataType::INT8); + CHECK_DTYPE(query_pe, paddle::DataType::INT8); + CHECK_DTYPE(key_pe, paddle::DataType::INT8); + CHECK_DTYPE(value, paddle::DataType::FLOAT8_E4M3FN); + CHECK_DTYPE(query_scale, paddle::DataType::FLOAT32); + CHECK_DTYPE(key_scale, paddle::DataType::FLOAT32); + + CHECK_DIMS(query, 4); + CHECK_DIMS(key, 4); + CHECK_DIMS(query_pe, 4); + CHECK_DIMS(key_pe, 4); + CHECK_DIMS(value, 4); + CHECK_DIMS(output, 4); + CHECK_DIMS(query_scale, 3); + CHECK_DIMS(key_scale, 3); + + const int batch_size = query.shape()[0]; + const int head_dim = query.shape()[3]; // 现在query是正常的128, 多出来的64在query_pe里面,所以这样做没什么问题 + + int stride_bz_q = query.strides()[0]; + int stride_bz_q_pe = query_pe.strides()[0]; + int stride_bz_k = key.strides()[0]; + int stride_bz_k_pe = key_pe.strides()[0]; + int stride_bz_v = value.strides()[0]; + int stride_bz_o = output.strides()[0]; + + int qo_len, kv_len, padded_kv_len, num_qo_heads, num_kv_heads; + int stride_seq_q, stride_h_q, stride_seq_k, stride_h_k, + stride_seq_q_pe, stride_h_q_pe, stride_seq_k_pe, stride_h_k_pe, + 
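+  // Layout convention used by the two branches below: tensor_layout == 0 is NHD
+  // [bsz, seq_len, num_heads, head_dim], tensor_layout == 1 is HND [bsz, num_heads, seq_len, head_dim].
+  // V is expected to arrive already transposed/padded by the per-channel FP8 quantization step,
+  // so its innermost axis is the padded kv length and stride_d_v walks the head_dim axis.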
stride_h_v, stride_d_v, stride_seq_o, stride_h_o; + + assert(value.shape()[0] == batch_size); + + if (tensor_layout == 0) + { + qo_len = query.shape()[1]; + kv_len = key.shape()[1]; + num_qo_heads = query.shape()[2]; + num_kv_heads = key.shape()[2]; + + stride_seq_q = query.strides()[1]; + stride_h_q = query.strides()[2]; + stride_seq_q_pe = query_pe.strides()[1]; + stride_h_q_pe = query_pe.strides()[2]; + stride_seq_k = key.strides()[1]; + stride_h_k = key.strides()[2]; + stride_seq_k_pe = key_pe.strides()[1]; + stride_h_k_pe = key_pe.strides()[2]; + stride_h_v = value.strides()[2]; + stride_d_v = value.strides()[1]; + stride_seq_o = output.strides()[1]; + stride_h_o = output.strides()[2]; + + CHECK_SHAPE(key, batch_size, kv_len, num_kv_heads, head_dim); + CHECK_SHAPE(output, batch_size, qo_len, num_qo_heads, head_dim); + + assert(value.shape()[1] == head_dim); + assert(value.shape()[2] == num_kv_heads); + } + else + { + qo_len = query.shape()[2]; + kv_len = key.shape()[2]; + num_qo_heads = query.shape()[1]; + num_kv_heads = key.shape()[1]; + + stride_seq_q = query.strides()[2]; + stride_h_q = query.strides()[1]; + stride_seq_q_pe = query_pe.strides()[2]; + stride_h_q_pe = query_pe.strides()[1]; + stride_seq_k = key.strides()[2]; + stride_h_k = key.strides()[1]; + stride_seq_k_pe = key_pe.strides()[2]; + stride_h_k_pe = key_pe.strides()[1]; + stride_h_v = value.strides()[1]; + stride_d_v = value.strides()[2]; + stride_seq_o = output.strides()[2]; + stride_h_o = output.strides()[1]; + + CHECK_SHAPE(key, batch_size, num_kv_heads, kv_len, head_dim); + CHECK_SHAPE(output, batch_size, num_qo_heads, qo_len, head_dim); + assert(value.shape()[2] == head_dim); + assert(value.shape()[1] == num_kv_heads); + } + + if (num_qo_heads % num_kv_heads != 0) { + std::ostringstream err_msg; + err_msg << "num_qo_heads (" << num_qo_heads << ") must be divisible by num_kv_heads (" << num_kv_heads << ")"; + throw std::invalid_argument(err_msg.str()); + } + + paddle::Tensor lse = paddle::empty({0}, paddle::DataType::FLOAT32); + if (return_lse) + { + lse = paddle::empty({batch_size, num_qo_heads, qo_len}, paddle::DataType::FLOAT32); + } + + const int num_kv_groups = num_qo_heads / num_kv_heads; + + auto output_type = output.dtype(); + + DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { + DISPATCH_CAUSAL(is_causal, IS_CAUSAL, { + DISPATCH_QK_QUANT_GRAN(qk_quant_gran, QK_QUANT_GRAN, { + DISPATCH_PADDLE_DTYPE_TO_CTYPE_FP16(output_type, DTypeOut, { + constexpr int CTA_Q = 64; + constexpr int CTA_K = 128; + constexpr int NUM_THREADS = 128; + constexpr int HEAD_DIM_PE = 64; + + constexpr MaskMode mask_mode = IS_CAUSAL ? 
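+      // Scale shapes enforced by the checks below (worked example; qo_len = kv_len = 4096 is an
+      // assumption):
+      //   per-warp:   query_scale (b, h_qo, div_ceil(4096, 64) * 4)  = (b, h_qo, 256)
+      //               key_scale   (b, h_kv, div_ceil(4096, 128))     = (b, h_kv, 32)
+      //   per-thread: query_scale (b, h_qo, div_ceil(4096, 64) * 32) = (b, h_qo, 2048)
+      //               key_scale   (b, h_kv, div_ceil(4096, 128) * 4) = (b, h_kv, 128)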
MaskMode::kCausal : MaskMode::kNone; + + assert(value.shape()[3] >= div_ceil(kv_len, CTA_K) * CTA_K); + + if constexpr (QK_QUANT_GRAN == static_cast(QuantGranularity::kPerWarp)) + { + CHECK_SHAPE(query_scale, batch_size, num_qo_heads, static_cast(div_ceil(qo_len, CTA_Q) * (NUM_THREADS / 32))); + CHECK_SHAPE(key_scale, batch_size, num_kv_heads, static_cast(div_ceil(kv_len, CTA_K))); + } + else if constexpr (QK_QUANT_GRAN == static_cast(QuantGranularity::kPerThread)) + { + CHECK_SHAPE(query_scale, batch_size, num_qo_heads, static_cast(div_ceil(qo_len, CTA_Q) * (NUM_THREADS / 32) * 8)); + CHECK_SHAPE(key_scale, batch_size, num_kv_heads, static_cast(div_ceil(kv_len, CTA_K) * 4)); + } + else + { + static_assert(QK_QUANT_GRAN == static_cast(QuantGranularity::kPerWarp) || QK_QUANT_GRAN == static_cast(QuantGranularity::kPerThread), "Unsupported quantization granularity"); + } + + CUtensorMap tma_map_Q = create_tensor_map_4D(reinterpret_cast(query.data()), batch_size, num_qo_heads, qo_len, HEAD_DIM, stride_bz_q, stride_h_q, stride_seq_q); + CUtensorMap tma_map_K = create_tensor_map_4D(reinterpret_cast(key.data()), batch_size, num_kv_heads, kv_len, HEAD_DIM, stride_bz_k, stride_h_k, stride_seq_k); + CUtensorMap tma_map_Q_pe = create_tensor_map_4D(reinterpret_cast(query_pe.data()), batch_size, num_qo_heads, qo_len, HEAD_DIM_PE, stride_bz_q_pe, stride_h_q_pe, stride_seq_q_pe); + CUtensorMap tma_map_K_pe = create_tensor_map_4D(reinterpret_cast(key_pe.data()), batch_size, num_kv_heads, kv_len, HEAD_DIM_PE, stride_bz_k_pe, stride_h_k_pe, stride_seq_k_pe); + + CUtensorMap tma_map_V = create_tensor_map_4D(reinterpret_cast(value.data()), batch_size, num_kv_heads, HEAD_DIM, value.shape()[3], stride_bz_v, stride_h_v, stride_d_v); + + const size_t sMemSize = CTA_Q * HEAD_DIM * sizeof(int8_t) + CTA_K * HEAD_DIM * sizeof(int8_t) + CTA_K * HEAD_DIM * sizeof(int8_t) + CTA_Q * HEAD_DIM_PE * sizeof(int8_t) + CTA_K * HEAD_DIM_PE * sizeof(int8_t); + auto* kernel = qk_int8_sv_f8_attn_dsk_kernel(QK_QUANT_GRAN), static_cast(QK_QUANT_GRAN), DTypeOut, mask_mode, false>; + + cudaFuncSetAttribute( + kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, sMemSize); + + dim3 grid(div_ceil(qo_len, CTA_Q), num_qo_heads, batch_size); + kernel<<>>( + tma_map_Q, + tma_map_K, + tma_map_Q_pe, + tma_map_K_pe, + tma_map_V, + reinterpret_cast(query_scale.data()), + reinterpret_cast(key_scale.data()), + nullptr, + reinterpret_cast(output.data()), + stride_bz_o, stride_h_o, stride_seq_o, + qo_len, kv_len, num_kv_groups, sm_scale); + }); + }); + }); + }); + + return {lse}; +} + +std::vector> qk_int8_sv_f8_accum_f32_attn_inst_buf_dsk_sm90_InferShape( + std::vector query_shape, + std::vector key_shape, + std::vector query_pe_shape, + std::vector key_pe_shape, + std::vector value_shape, + std::vector output_shape, + std::vector query_scale_shape, + std::vector key_scale_shape) { + + // force layout: NHD: [bsz, seq_len, num_heads, head_dim] + int64_t bsz = query_shape[0]; + int64_t seq_len = query_shape[1]; + int64_t h_qo = query_shape[2]; + + std::vector return_shape = {bsz, h_qo, seq_len}; + return {return_shape}; +} + +std::vector qk_int8_sv_f8_accum_f32_attn_inst_buf_dsk_sm90_InferDtype( + paddle::DataType A_dtype, + paddle::DataType B_dtype, + paddle::DataType C_dtype, + paddle::DataType D_dtype, + paddle::DataType E_dtype, + paddle::DataType F_dtype, + paddle::DataType G_dtype, + paddle::DataType H_dtype) { + return {paddle::DataType::FLOAT32}; +} + +PD_BUILD_OP(qk_int8_sv_f8_accum_f32_attn_inst_buf_dsk_sm90) + .Inputs({"query", "key", 
"query_pe", "key_pe", "value", "output", "query_scale", "key_scale"}) + .Outputs({"out", "lse"}) + .SetInplaceMap({{"output", "out"}}) // Inplace + .Attrs({"tensor_layout: int", + "is_causal: int", + "qk_quant_gran: int", + "sm_scale: float", + "return_lse: int"}) + .SetKernelFn(PD_KERNEL(qk_int8_sv_f8_accum_f32_attn_inst_buf_dsk_sm90_fwd)) + .SetInferShapeFn(PD_INFER_SHAPE(qk_int8_sv_f8_accum_f32_attn_inst_buf_dsk_sm90_InferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(qk_int8_sv_f8_accum_f32_attn_inst_buf_dsk_sm90_InferDtype)); + +std::vector qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_dsk_sm90_fwd( + paddle::Tensor& query, + paddle::Tensor& key, + paddle::Tensor& query_pe, + paddle::Tensor& key_pe, + paddle::Tensor& value, + paddle::Tensor& output, + paddle::Tensor& query_scale, + paddle::Tensor& key_scale, + paddle::Tensor& value_scale, + int tensor_layout, + int is_causal, + int qk_quant_gran, + float sm_scale, + int return_lse) +{ + CHECK_CUDA(query); + CHECK_CUDA(key); + CHECK_CUDA(query_pe); + CHECK_CUDA(key_pe); + CHECK_CUDA(value); + CHECK_CUDA(output); + CHECK_CUDA(query_scale); + CHECK_CUDA(key_scale); + CHECK_CUDA(value_scale); + + CHECK_LASTDIM_CONTIGUOUS(query); + CHECK_LASTDIM_CONTIGUOUS(key); + CHECK_LASTDIM_CONTIGUOUS(query_pe); + CHECK_LASTDIM_CONTIGUOUS(key_pe); + CHECK_LASTDIM_CONTIGUOUS(value); + CHECK_LASTDIM_CONTIGUOUS(output); + CHECK_CONTIGUOUS(query_scale); + CHECK_CONTIGUOUS(key_scale); + CHECK_CONTIGUOUS(value_scale); + + CHECK_DTYPE(query, paddle::DataType::INT8); + CHECK_DTYPE(key, paddle::DataType::INT8); + CHECK_DTYPE(query_pe, paddle::DataType::INT8); + CHECK_DTYPE(key_pe, paddle::DataType::INT8); + CHECK_DTYPE(value, paddle::DataType::FLOAT8_E4M3FN); + CHECK_DTYPE(query_scale, paddle::DataType::FLOAT32); + CHECK_DTYPE(key_scale, paddle::DataType::FLOAT32); + CHECK_DTYPE(value_scale, paddle::DataType::FLOAT32); + + CHECK_DIMS(query, 4); + CHECK_DIMS(key, 4); + CHECK_DIMS(query_pe, 4); + CHECK_DIMS(key_pe, 4); + CHECK_DIMS(value, 4); + CHECK_DIMS(output, 4); + CHECK_DIMS(query_scale, 3); + CHECK_DIMS(key_scale, 3); + CHECK_DIMS(value_scale, 3); + + const int batch_size = query.shape()[0]; + const int head_dim = query.shape()[3]; + + int stride_bz_q = query.strides()[0]; + int stride_bz_q_pe = query_pe.strides()[0]; + int stride_bz_k = key.strides()[0]; + int stride_bz_k_pe = key_pe.strides()[0]; + int stride_bz_v = value.strides()[0]; + int stride_bz_o = output.strides()[0]; + + int qo_len, kv_len, padded_kv_len, num_qo_heads, num_kv_heads; + int stride_seq_q, stride_h_q, stride_seq_k, stride_h_k, + stride_seq_q_pe, stride_h_q_pe, stride_seq_k_pe, stride_h_k_pe, + stride_h_v, stride_d_v, stride_seq_o, stride_h_o; + + assert(value.shape()[0] == batch_size); + + if (tensor_layout == 0) + { + qo_len = query.shape()[1]; + kv_len = key.shape()[1]; + num_qo_heads = query.shape()[2]; + num_kv_heads = key.shape()[2]; + + stride_seq_q = query.strides()[1]; + stride_h_q = query.strides()[2]; + stride_seq_q_pe = query_pe.strides()[1]; + stride_h_q_pe = query_pe.strides()[2]; + stride_seq_k = key.strides()[1]; + stride_h_k = key.strides()[2]; + stride_seq_k_pe = key_pe.strides()[1]; + stride_h_k_pe = key_pe.strides()[2]; + stride_h_v = value.strides()[2]; + stride_d_v = value.strides()[1]; + stride_seq_o = output.strides()[1]; + stride_h_o = output.strides()[2]; + + CHECK_SHAPE(key, batch_size, kv_len, num_kv_heads, head_dim); + CHECK_SHAPE(output, batch_size, qo_len, num_qo_heads, head_dim); + + assert(value.shape()[1] == head_dim); + assert(value.shape()[2] == 
num_kv_heads); + } + else + { + qo_len = query.shape()[2]; + kv_len = key.shape()[2]; + num_qo_heads = query.shape()[1]; + num_kv_heads = key.shape()[1]; + + stride_seq_q = query.strides()[2]; + stride_h_q = query.strides()[1]; + stride_seq_q_pe = query_pe.strides()[2]; + stride_h_q_pe = query_pe.strides()[1]; + stride_seq_k = key.strides()[2]; + stride_h_k = key.strides()[1]; + stride_seq_k_pe = key_pe.strides()[2]; + stride_h_k_pe = key_pe.strides()[1]; + stride_h_v = value.strides()[1]; + stride_d_v = value.strides()[2]; + stride_seq_o = output.strides()[2]; + stride_h_o = output.strides()[1]; + + CHECK_SHAPE(key, batch_size, num_kv_heads, kv_len, head_dim); + CHECK_SHAPE(output, batch_size, num_qo_heads, qo_len, head_dim); + assert(value.shape()[2] == head_dim); + assert(value.shape()[1] == num_kv_heads); + } + + if (num_qo_heads % num_kv_heads != 0) { + std::ostringstream err_msg; + err_msg << "num_qo_heads (" << num_qo_heads << ") must be divisible by num_kv_heads (" << num_kv_heads << ")"; + throw std::invalid_argument(err_msg.str()); + } + + paddle::Tensor lse = paddle::empty({1}, paddle::DataType::FLOAT32); + if (return_lse) + { + lse = paddle::empty({batch_size, num_qo_heads, qo_len}, paddle::DataType::FLOAT32); + } + + const int num_kv_groups = num_qo_heads / num_kv_heads; + + auto output_dtype = output.dtype(); + + DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { + DISPATCH_CAUSAL(is_causal, IS_CAUSAL, { + DISPATCH_QK_QUANT_GRAN(qk_quant_gran, QK_QUANT_GRAN, { + DISPATCH_PADDLE_DTYPE_TO_CTYPE_FP16(output_dtype, DTypeOut, { + constexpr int CTA_Q = 64; + constexpr int CTA_K = 128; + constexpr int NUM_THREADS = 128; + constexpr int HEAD_DIM_PE = 64; + + constexpr MaskMode mask_mode = IS_CAUSAL ? MaskMode::kCausal : MaskMode::kNone; + + assert(value.shape()[3] >= div_ceil(kv_len, CTA_K) * CTA_K); + + if constexpr (QK_QUANT_GRAN == static_cast(QuantGranularity::kPerWarp)) + { + CHECK_SHAPE(query_scale, batch_size, num_qo_heads, static_cast(div_ceil(qo_len, CTA_Q) * (NUM_THREADS / 32))); + CHECK_SHAPE(key_scale, batch_size, num_kv_heads, static_cast(div_ceil(kv_len, CTA_K))); + } + else if constexpr (QK_QUANT_GRAN == static_cast(QuantGranularity::kPerThread)) + { + CHECK_SHAPE(query_scale, batch_size, num_qo_heads, static_cast(div_ceil(qo_len, CTA_Q) * (NUM_THREADS / 32) * 8)); + CHECK_SHAPE(key_scale, batch_size, num_kv_heads, static_cast(div_ceil(kv_len, CTA_K) * 4)); + } + else + { + static_assert(QK_QUANT_GRAN == static_cast(QuantGranularity::kPerWarp) || QK_QUANT_GRAN == static_cast(QuantGranularity::kPerThread), "Unsupported quantization granularity"); + } + + CHECK_SHAPE(value_scale, batch_size, num_kv_heads, HEAD_DIM); + CUtensorMap tma_map_Q = create_tensor_map_4D(reinterpret_cast(query.data()), batch_size, num_qo_heads, qo_len, HEAD_DIM, stride_bz_q, stride_h_q, stride_seq_q); + CUtensorMap tma_map_K = create_tensor_map_4D(reinterpret_cast(key.data()), batch_size, num_kv_heads, kv_len, HEAD_DIM, stride_bz_k, stride_h_k, stride_seq_k); + CUtensorMap tma_map_Q_pe = create_tensor_map_4D(reinterpret_cast(query_pe.data()), batch_size, num_qo_heads, qo_len, HEAD_DIM_PE, stride_bz_q_pe, stride_h_q_pe, stride_seq_q_pe); + CUtensorMap tma_map_K_pe = create_tensor_map_4D(reinterpret_cast(key_pe.data()), batch_size, num_kv_heads, kv_len, HEAD_DIM_PE, stride_bz_k_pe, stride_h_k_pe, stride_seq_k_pe); + + CUtensorMap tma_map_V = create_tensor_map_4D(reinterpret_cast(value.data()), batch_size, num_kv_heads, HEAD_DIM, value.shape()[3], stride_bz_v, stride_h_v, stride_d_v); + + const size_t sMemSize 
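+      // Shared-memory budget (illustrative, assuming HEAD_DIM = 128): Q 64x128 + K 128x128 +
+      // V 128x128 + Q_pe 64x64 + K_pe 128x64 one-byte elements = 8 + 16 + 16 + 4 + 8 = 52 KiB,
+      // which is above the 48 KiB default limit, hence the opt-in via
+      // cudaFuncAttributeMaxDynamicSharedMemorySize right after.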
= CTA_Q * HEAD_DIM * sizeof(int8_t) + CTA_K * HEAD_DIM * sizeof(int8_t) + CTA_K * HEAD_DIM * sizeof(int8_t) + CTA_Q * HEAD_DIM_PE * sizeof(int8_t) + CTA_K * HEAD_DIM_PE * sizeof(int8_t); + auto* kernel = qk_int8_sv_f8_attn_dsk_kernel(QK_QUANT_GRAN), static_cast(QK_QUANT_GRAN), DTypeOut, mask_mode, true>; + + cudaFuncSetAttribute( + kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, sMemSize); + dim3 grid(div_ceil(qo_len, CTA_Q), num_qo_heads, batch_size); + kernel<<>>( + tma_map_Q, + tma_map_K, + tma_map_Q_pe, + tma_map_K_pe, + tma_map_V, + reinterpret_cast(query_scale.data()), + reinterpret_cast(key_scale.data()), + reinterpret_cast(value_scale.data()), + reinterpret_cast(output.data()), + stride_bz_o, stride_h_o, stride_seq_o, + qo_len, kv_len, num_kv_groups, sm_scale); + }); + }); + }); + }); + + return {lse}; +} + +std::vector> qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_dsk_sm90_InferShape( + std::vector query_shape, + std::vector key_shape, + std::vector query_pe_shape, + std::vector key_pe_shape, + std::vector value_shape, + std::vector output_shape, + std::vector query_scale_shape, + std::vector key_scale_shape, + std::vector value_scale_shape) { + + // force layout: NHD: [bsz, seq_len, num_heads, head_dim] + int64_t bsz = query_shape[0]; + int64_t seq_len = query_shape[1]; + int64_t h_qo = query_shape[2]; + + std::vector return_shape = {bsz, h_qo, seq_len}; + return {return_shape}; +} + +std::vector qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_dsk_sm90_InferDtype( + paddle::DataType A_dtype, + paddle::DataType B_dtype, + paddle::DataType C_dtype, + paddle::DataType D_dtype, + paddle::DataType E_dtype, + paddle::DataType F_dtype, + paddle::DataType G_dtype, + paddle::DataType H_dtype, + paddle::DataType I_dtype) { + return {paddle::DataType::FLOAT32}; +} + +PD_BUILD_OP(qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_dsk_sm90) + .Inputs({"query", "key", "query_pe", "key_pe", "value", "output", "query_scale", "key_scale", "value_scale"}) + .Outputs({"out", "lse"}) + .SetInplaceMap({{"output", "out"}}) // Inplace + .Attrs({"tensor_layout: int", + "is_causal: int", + "qk_quant_gran: int", + "sm_scale: float", + "return_lse: int"}) + .SetKernelFn(PD_KERNEL(qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_dsk_sm90_fwd)) + .SetInferShapeFn(PD_INFER_SHAPE(qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_dsk_sm90_InferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_dsk_sm90_InferDtype)); \ No newline at end of file diff --git a/csrc/setup_cuda.py b/csrc/setup_cuda.py index 20983c909250..296482b8a3b1 100644 --- a/csrc/setup_cuda.py +++ b/csrc/setup_cuda.py @@ -170,20 +170,21 @@ def get_gencode_flags(): "--threads=8", "-D_GLIBCXX_USE_CXX11_ABI=1", ] - if cc >= 80: + sources += ["./gpu/sage_attn_kernels/sageattn_fused.cu"] + if cc >= 80 and cc < 89: sources += [ - "./gpu/sage_attn_kernels/sageattn_fused.cu", "./gpu/sage_attn_kernels/sageattn_qk_int_sv_f16_kernel_sm80.cu" ] nvcc_compile_args += ["-gencode", f"arch=compute_80,code=compute_80"] - if cc >= 89: + elif cc >= 89 and cc < 90: sources += [ "./gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_kernel_sm89.cu" ] nvcc_compile_args += ["-gencode", f"arch=compute_89,code=compute_89"] - if cc >= 90: + elif cc >= 90: sources += [ - "./gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_kernel_sm90.cu" + "./gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_kernel_sm90.cu", + "./gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_dsk_kernel_sm90.cu" ] nvcc_compile_args += ["-gencode", 
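+            # Note: compute_90a (rather than plain compute_90) is needed here because the
+            # sm90 kernels use Hopper-specific WGMMA/TMA instructions that are only exposed
+            # for the architecture-specific "a" target.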
f"arch=compute_90a,code=compute_90a"] From c460ccebfe58d7d260bac0bb7501bd80d31c3df0 Mon Sep 17 00:00:00 2001 From: l1cacheDell Date: Thu, 27 Feb 2025 23:34:15 +0800 Subject: [PATCH 07/18] clean PR branch --- csrc/gpu/sage_attn_kernels/quant.py | 130 ------------------ .../transformers/fused_transformer_layers.py | 48 +++++-- .../transformers/sageattention.py | 128 ++++++++++++++++- 3 files changed, 161 insertions(+), 145 deletions(-) delete mode 100644 csrc/gpu/sage_attn_kernels/quant.py rename csrc/gpu/sage_attn_kernels/core.py => paddlenlp/experimental/transformers/sageattention.py (76%) diff --git a/csrc/gpu/sage_attn_kernels/quant.py b/csrc/gpu/sage_attn_kernels/quant.py deleted file mode 100644 index 9124810bc44b..000000000000 --- a/csrc/gpu/sage_attn_kernels/quant.py +++ /dev/null @@ -1,130 +0,0 @@ -import paddle -import paddlenlp_ops - -from typing import Optional - -def per_block_int8( - q: paddle.Tensor, - k: paddle.Tensor, - km: Optional[paddle.Tensor] = None, - BLKQ: int =128, - BLKK: int =64, - sm_scale: Optional[float] = None, - tensor_layout: str ="HND" -): - q_int8 = paddle.empty(q.shape, dtype=paddle.int8) - k_int8 = paddle.empty(k.shape, dtype=paddle.int8) - - if tensor_layout == "HND": - b, h_qo, qo_len, head_dim = q.shape - _, h_kv, kv_len, _ = k.shape - elif tensor_layout == "NHD": - b, qo_len, h_qo, head_dim = q.shape - _, kv_len, h_kv, _ = k.shape - else: - raise ValueError(f"Unknown tensor layout: {tensor_layout}") - - _tensor_layout = 0 if tensor_layout == "NHD" else 1 - - q_scale = paddle.empty((b, h_qo, (qo_len + BLKQ - 1) // BLKQ), dtype=paddle.float32) - k_scale = paddle.empty((b, h_kv, (kv_len + BLKK - 1) // BLKK), dtype=paddle.float32) - - if sm_scale is None: - sm_scale = head_dim**-0.5 - - sm_scale *= 1.44269504 - - paddlenlp_ops.quant_per_block_int8_cuda(q, q_int8, q_scale, sm_scale, BLKQ, _tensor_layout) - if km is not None: - km = km.squeeze(1) if _tensor_layout == 0 else km.squeeze(2) - paddlenlp_ops.quant_per_block_int8_fuse_sub_mean_cuda(k, km, k_int8, k_scale, BLKK, _tensor_layout) - else: - paddlenlp_ops.quant_per_block_int8_cuda(k, k_int8, k_scale, BLKK, _tensor_layout) - - return q_int8, q_scale, k_int8, k_scale - - -def per_warp_int8( - q: paddle.Tensor, - k: paddle.Tensor, - km: Optional[paddle.Tensor] = None, - BLKQ: int =128, - WARPQ: int =32, - BLKK: int =64, - tensor_layout: str ="HND" -): - q_int8 = paddle.empty(shape=q.shape, dtype=paddle.int8) - k_int8 = paddle.empty(shape=k.shape, dtype=paddle.int8) - - if tensor_layout == "HND": - b, h_qo, qo_len, head_dim = q.shape - _, h_kv, kv_len, _ = k.shape - - elif tensor_layout == "NHD": - b, qo_len, h_qo, head_dim = q.shape - _, kv_len, h_kv, _ = k.shape - - else: - raise ValueError(f"Unknown tensor layout: {tensor_layout}") - - _tensor_layout = 0 if tensor_layout == "NHD" else 1 - - q_scale = paddle.empty((b, h_qo, ((qo_len + BLKQ - 1) // BLKQ) * (BLKQ // WARPQ)), dtype=paddle.float32) - k_scale = paddle.empty((b, h_kv, (kv_len + BLKK - 1) // BLKK), dtype=paddle.float32) - - paddlenlp_ops.quant_per_warp_int8_cuda(q, q_int8, q_scale, BLKQ, WARPQ, _tensor_layout) - - if km is not None: - km = km.squeeze(1) if _tensor_layout == 0 else km.squeeze(2) - paddlenlp_ops.quant_per_block_int8_fuse_sub_mean_cuda(k, km, k_int8, k_scale, BLKK, _tensor_layout) - else: - paddlenlp_ops.quant_per_block_int8_cuda(k, k_int8, k_scale, BLKK, _tensor_layout) - - return q_int8, q_scale, k_int8, k_scale - - -def per_channel_fp8( - v: paddle.Tensor, - tensor_layout: str ="NHD", - scale_max: float = 448.0, - smooth_v: 
bool = True -): - _tensor_layout = 0 if tensor_layout == "NHD" else 1 - - if tensor_layout == "HND": - b, h_kv, kv_len, head_dim = v.shape - padded_len = (kv_len + 63) // 64 * 64 - v_transposed_permutted = paddle.empty((b, h_kv, head_dim, padded_len), dtype=v.dtype) - - elif tensor_layout == "NHD": - b, kv_len, h_kv, head_dim = v.shape - padded_len = (kv_len + 63) // 64 * 64 - v_transposed_permutted = paddle.empty((b, head_dim, h_kv, padded_len), dtype=v.dtype) - paddlenlp_ops.transpose_pad_permute_cuda(v, v_transposed_permutted, _tensor_layout) - - v_fp8 = paddle.empty(v_transposed_permutted.shape, dtype=paddle.float8_e4m3fn) - - v_scale = paddle.empty((b, h_kv, head_dim), dtype=paddle.float32) - vm = paddle.empty((b, h_kv, head_dim), dtype=paddle.float32) - - if smooth_v: - paddlenlp_ops.mean_scale_fuse_quant_cuda(v_transposed_permutted, v_fp8, vm, v_scale, kv_len, scale_max, _tensor_layout) - return v_fp8, v_scale, vm - else: - paddlenlp_ops.scale_fuse_quant_cuda(v_transposed_permutted, v_fp8, v_scale, kv_len, scale_max, _tensor_layout) - return v_fp8, v_scale, None - - -def sub_mean( - v: paddle.Tensor, - tensor_layout: str ="HND" -): - _tensor_layout = 0 if tensor_layout == "NHD" else 1 - vm = v.mean(dim=1 if _tensor_layout == 0 else 2) - - v_smoothed = paddle.empty(v.shape, dtype=paddle.float16) - - # subtract mean and store the result as fp16 - paddlenlp_ops.sub_mean_cuda(v, vm, v_smoothed, _tensor_layout) - - return v_smoothed, vm \ No newline at end of file diff --git a/paddlenlp/experimental/transformers/fused_transformer_layers.py b/paddlenlp/experimental/transformers/fused_transformer_layers.py index 591964d7675d..7a244195598d 100644 --- a/paddlenlp/experimental/transformers/fused_transformer_layers.py +++ b/paddlenlp/experimental/transformers/fused_transformer_layers.py @@ -71,6 +71,9 @@ def use_cutlass_fp8_gemm(): transpose_remove_padding, write_cache_kv, ) + from .sageattention import ( + sageattn_qk_int8_pv_fp8_cuda_dsk_sm90 + ) except: pass @@ -2973,6 +2976,8 @@ def compute_mla_absorb( ): from paddlenlp_ops import decode_mla_write_cache, multi_head_latent_attention + use_sageattn = False if os.getenv("USE_SAGEATTN", "0") == "1" else True + ln_out = qkv_out latent_cache = caches[i] @@ -2981,18 +2986,37 @@ def compute_mla_absorb( if kwargs["max_enc_len_this_time"]: # prefill phase query, key, value = self.compute_qkv_linear(ln_out, i, latent_cache=latent_cache, **kwargs) - fmha_out_prefill = paddle.nn.functional.flash_attention.flash_attn_unpadded( - query, - key, - value, - kwargs.get("cu_seqlens_q", None), - kwargs.get("cu_seqlens_k", None), - kwargs.get("max_enc_len_this_time", -1), - kwargs.get("max_enc_len_this_time", -1), - self.softmax_scale, - causal=True, - training=False, - )[0] + if not use_sageattn: + fmha_out_prefill = paddle.nn.functional.flash_attention.flash_attn_unpadded( + query, + key, + value, + kwargs.get("cu_seqlens_q", None), + kwargs.get("cu_seqlens_k", None), + kwargs.get("max_enc_len_this_time", -1), + kwargs.get("max_enc_len_this_time", -1), + self.softmax_scale, + causal=True, + training=False, + )[0] + else: + query_192 = paddle.unsqueeze(query, axis=0) + key_192 = paddle.unsqueeze(key, axis=0) + + value_128, _ = paddle.split(value, [128, 64], axis=-1) + value_128 = paddle.unsqueeze(value_128, axis=0) + + tensor_layout = "NHD" + fmha_out_prefill = sageattn_qk_int8_pv_fp8_cuda_dsk_sm90( + query_192, + key_192, + value_128, + is_causal=True, + sm_scale=self.softmax_scale, + tensor_layout=tensor_layout, + ) + fmha_out_prefill = 
paddle.nn.functional.pad(fmha_out_prefill, (0, 192 - 128)) + fmha_out_prefill = paddle.squeeze(fmha_out_prefill, axis=0) fmha_out_prefill = fmha_out_prefill.reshape([-1, self.num_heads, self.config.mla_config.qk_head_dim]) fmha_out_prefill = fmha_out_prefill[:, :, : self.config.mla_config.v_head_dim] diff --git a/csrc/gpu/sage_attn_kernels/core.py b/paddlenlp/experimental/transformers/sageattention.py similarity index 76% rename from csrc/gpu/sage_attn_kernels/core.py rename to paddlenlp/experimental/transformers/sageattention.py index d13557b9aa73..fad59a1729ae 100644 --- a/csrc/gpu/sage_attn_kernels/core.py +++ b/paddlenlp/experimental/transformers/sageattention.py @@ -4,9 +4,131 @@ from typing import Optional, Any import warnings -from .quant import per_channel_fp8 -from .quant import per_warp_int8 as per_warp_int8_cuda -from .quant import sub_mean +def per_block_int8( + q: paddle.Tensor, + k: paddle.Tensor, + km: Optional[paddle.Tensor] = None, + BLKQ: int =128, + BLKK: int =64, + sm_scale: Optional[float] = None, + tensor_layout: str ="HND" +): + q_int8 = paddle.empty(q.shape, dtype=paddle.int8) + k_int8 = paddle.empty(k.shape, dtype=paddle.int8) + + if tensor_layout == "HND": + b, h_qo, qo_len, head_dim = q.shape + _, h_kv, kv_len, _ = k.shape + elif tensor_layout == "NHD": + b, qo_len, h_qo, head_dim = q.shape + _, kv_len, h_kv, _ = k.shape + else: + raise ValueError(f"Unknown tensor layout: {tensor_layout}") + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + + q_scale = paddle.empty((b, h_qo, (qo_len + BLKQ - 1) // BLKQ), dtype=paddle.float32) + k_scale = paddle.empty((b, h_kv, (kv_len + BLKK - 1) // BLKK), dtype=paddle.float32) + + if sm_scale is None: + sm_scale = head_dim**-0.5 + + sm_scale *= 1.44269504 + + paddlenlp_ops.quant_per_block_int8_cuda(q, q_int8, q_scale, sm_scale, BLKQ, _tensor_layout) + if km is not None: + km = km.squeeze(1) if _tensor_layout == 0 else km.squeeze(2) + paddlenlp_ops.quant_per_block_int8_fuse_sub_mean_cuda(k, km, k_int8, k_scale, BLKK, _tensor_layout) + else: + paddlenlp_ops.quant_per_block_int8_cuda(k, k_int8, k_scale, BLKK, _tensor_layout) + + return q_int8, q_scale, k_int8, k_scale + + +def per_warp_int8_cuda( + q: paddle.Tensor, + k: paddle.Tensor, + km: Optional[paddle.Tensor] = None, + BLKQ: int =128, + WARPQ: int =32, + BLKK: int =64, + tensor_layout: str ="HND" +): + q_int8 = paddle.empty(shape=q.shape, dtype=paddle.int8) + k_int8 = paddle.empty(shape=k.shape, dtype=paddle.int8) + + if tensor_layout == "HND": + b, h_qo, qo_len, head_dim = q.shape + _, h_kv, kv_len, _ = k.shape + + elif tensor_layout == "NHD": + b, qo_len, h_qo, head_dim = q.shape + _, kv_len, h_kv, _ = k.shape + + else: + raise ValueError(f"Unknown tensor layout: {tensor_layout}") + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + + q_scale = paddle.empty((b, h_qo, ((qo_len + BLKQ - 1) // BLKQ) * (BLKQ // WARPQ)), dtype=paddle.float32) + k_scale = paddle.empty((b, h_kv, (kv_len + BLKK - 1) // BLKK), dtype=paddle.float32) + + paddlenlp_ops.quant_per_warp_int8_cuda(q, q_int8, q_scale, BLKQ, WARPQ, _tensor_layout) + + if km is not None: + km = km.squeeze(1) if _tensor_layout == 0 else km.squeeze(2) + paddlenlp_ops.quant_per_block_int8_fuse_sub_mean_cuda(k, km, k_int8, k_scale, BLKK, _tensor_layout) + else: + paddlenlp_ops.quant_per_block_int8_cuda(k, k_int8, k_scale, BLKK, _tensor_layout) + + return q_int8, q_scale, k_int8, k_scale + + +def per_channel_fp8( + v: paddle.Tensor, + tensor_layout: str ="NHD", + scale_max: float = 448.0, + smooth_v: bool = True +): + 
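+    """Quantize V per channel to FP8 (e4m3) for the SageAttention kernels.
+
+    Rough sketch of what the fused ops below do (the exact scaling formula lives in the CUDA
+    kernels): V is transposed/padded so head_dim becomes its own axis and kv_len is rounded up
+    to a multiple of 64, then each (batch, head, channel) slice is scaled into the e4m3 range
+    using scale_max and cast to float8_e4m3fn, with the per-channel scales returned in v_scale.
+    With smooth_v=True the per-channel mean over the sequence is subtracted first and returned
+    as vm so the caller can compensate later; otherwise the third return value is None.
+    """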
_tensor_layout = 0 if tensor_layout == "NHD" else 1 + + if tensor_layout == "HND": + b, h_kv, kv_len, head_dim = v.shape + padded_len = (kv_len + 63) // 64 * 64 + v_transposed_permutted = paddle.empty((b, h_kv, head_dim, padded_len), dtype=v.dtype) + + elif tensor_layout == "NHD": + b, kv_len, h_kv, head_dim = v.shape + padded_len = (kv_len + 63) // 64 * 64 + v_transposed_permutted = paddle.empty((b, head_dim, h_kv, padded_len), dtype=v.dtype) + paddlenlp_ops.transpose_pad_permute_cuda(v, v_transposed_permutted, _tensor_layout) + + v_fp8 = paddle.empty(v_transposed_permutted.shape, dtype=paddle.float8_e4m3fn) + + v_scale = paddle.empty((b, h_kv, head_dim), dtype=paddle.float32) + vm = paddle.empty((b, h_kv, head_dim), dtype=paddle.float32) + + if smooth_v: + paddlenlp_ops.mean_scale_fuse_quant_cuda(v_transposed_permutted, v_fp8, vm, v_scale, kv_len, scale_max, _tensor_layout) + return v_fp8, v_scale, vm + else: + paddlenlp_ops.scale_fuse_quant_cuda(v_transposed_permutted, v_fp8, v_scale, kv_len, scale_max, _tensor_layout) + return v_fp8, v_scale, None + + +def sub_mean( + v: paddle.Tensor, + tensor_layout: str ="HND" +): + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + vm = v.mean(dim=1 if _tensor_layout == 0 else 2) + + v_smoothed = paddle.empty(v.shape, dtype=paddle.float16) + + # subtract mean and store the result as fp16 + paddlenlp_ops.sub_mean_cuda(v, vm, v_smoothed, _tensor_layout) + + return v_smoothed, vm def sageattn_qk_int8_pv_fp16_cuda_sm80( From 452a9de810988d4a0619dba3529ca00107afcab5 Mon Sep 17 00:00:00 2001 From: l1cacheDell Date: Fri, 28 Feb 2025 14:35:01 +0800 Subject: [PATCH 08/18] fix sa usage --- paddlenlp/experimental/transformers/fused_transformer_layers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddlenlp/experimental/transformers/fused_transformer_layers.py b/paddlenlp/experimental/transformers/fused_transformer_layers.py index 7a244195598d..7c2fc040f3fb 100644 --- a/paddlenlp/experimental/transformers/fused_transformer_layers.py +++ b/paddlenlp/experimental/transformers/fused_transformer_layers.py @@ -2976,7 +2976,7 @@ def compute_mla_absorb( ): from paddlenlp_ops import decode_mla_write_cache, multi_head_latent_attention - use_sageattn = False if os.getenv("USE_SAGEATTN", "0") == "1" else True + use_sageattn = False if os.getenv("USE_SAGEATTN", 0) == 1 else True ln_out = qkv_out latent_cache = caches[i] From ec505b6bfc56188d1f6899d7a1ed683f20b61921 Mon Sep 17 00:00:00 2001 From: l1cacheDell Date: Fri, 28 Feb 2025 19:37:03 +0800 Subject: [PATCH 09/18] bugfix --- .../experimental/transformers/fused_transformer_layers.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/paddlenlp/experimental/transformers/fused_transformer_layers.py b/paddlenlp/experimental/transformers/fused_transformer_layers.py index 7c2fc040f3fb..b067f0aa3e81 100644 --- a/paddlenlp/experimental/transformers/fused_transformer_layers.py +++ b/paddlenlp/experimental/transformers/fused_transformer_layers.py @@ -2976,8 +2976,7 @@ def compute_mla_absorb( ): from paddlenlp_ops import decode_mla_write_cache, multi_head_latent_attention - use_sageattn = False if os.getenv("USE_SAGEATTN", 0) == 1 else True - + use_sageattn = False if os.getenv("USE_SAGEATTN", "0") == "0" else True ln_out = qkv_out latent_cache = caches[i] From 17a6bd8f5ad93f74ccb1ea5fb1bd3acae9354252 Mon Sep 17 00:00:00 2001 From: l1cacheDell Date: Sat, 1 Mar 2025 23:00:57 +0800 Subject: [PATCH 10/18] modify, for static mode inference SA --- csrc/gpu/sage_attn_kernels/sageattn_fused.cu | 
28 +++++++++++-------- .../transformers/sageattention.py | 6 ++-- 2 files changed, 19 insertions(+), 15 deletions(-) diff --git a/csrc/gpu/sage_attn_kernels/sageattn_fused.cu b/csrc/gpu/sage_attn_kernels/sageattn_fused.cu index 84d25794bb36..6dd9a64bb6b9 100644 --- a/csrc/gpu/sage_attn_kernels/sageattn_fused.cu +++ b/csrc/gpu/sage_attn_kernels/sageattn_fused.cu @@ -817,7 +817,7 @@ void scale_fuse_quant_cuda_fwd( paddle::Tensor& input, paddle::Tensor& output, paddle::Tensor& scale, - int num_tokens, + paddle::Tensor& v, // for static graph mode float scale_max, int tensor_layout) { @@ -842,13 +842,14 @@ void scale_fuse_quant_cuda_fwd( int stride_bz_input = input.strides()[0]; int stride_bz_output = output.strides()[0]; - int num_heads, head_dim; + int num_heads, head_dim, num_tokens; int stride_d_input, stride_h_input, stride_d_output, stride_h_output; if (tensor_layout == 0) { num_heads = input.shape()[2]; head_dim = input.shape()[1]; + num_tokens = v.shape()[1]; stride_d_input = input.strides()[1]; stride_h_input = input.strides()[2]; stride_d_output = output.strides()[1]; @@ -858,6 +859,7 @@ void scale_fuse_quant_cuda_fwd( { num_heads = input.shape()[1]; head_dim = input.shape()[2]; + num_tokens = v.shape()[2]; stride_d_input = input.strides()[2]; stride_h_input = input.strides()[1]; stride_d_output = output.strides()[2]; @@ -891,10 +893,10 @@ void scale_fuse_quant_cuda_fwd( } PD_BUILD_OP(scale_fuse_quant_cuda) - .Inputs({"input", "output", "scale"}) - .Outputs({"out1", "out2", "out3"}) - .SetInplaceMap({{"input", "out1"}, {"output", "out2"}, {"scale", "out3"}}) // Inplace - .Attrs({"num_tokens: int", "scale_max: float", "tensor_layout: int"}) + .Inputs({"input", "output", "scale", "v"}) + .Outputs({"out1", "out2", "out3", "out4"}) + .SetInplaceMap({{"input", "out1"}, {"output", "out2"}, {"scale", "out3"}, {"v", "out4"}}) // Inplace + .Attrs({"scale_max: float", "tensor_layout: int"}) .SetKernelFn(PD_KERNEL(scale_fuse_quant_cuda_fwd)); // smooth v @@ -903,7 +905,7 @@ void mean_scale_fuse_quant_cuda_fwd( paddle::Tensor& output, paddle::Tensor& mean, paddle::Tensor& scale, - int num_tokens, + paddle::Tensor& v, // for static graph mode float scale_max, int tensor_layout) { @@ -932,13 +934,14 @@ void mean_scale_fuse_quant_cuda_fwd( int stride_bz_input = input.strides()[0]; int stride_bz_output = output.strides()[0]; - int num_heads, head_dim; + int num_heads, head_dim, num_tokens; int stride_d_input, stride_h_input, stride_d_output, stride_h_output; if (tensor_layout == 0) { num_heads = input.shape()[2]; head_dim = input.shape()[1]; + num_tokens = v.shape()[1]; stride_d_input = input.strides()[1]; stride_h_input = input.strides()[2]; stride_d_output = output.strides()[1]; @@ -948,6 +951,7 @@ void mean_scale_fuse_quant_cuda_fwd( { num_heads = input.shape()[1]; head_dim = input.shape()[2]; + num_tokens = v.shape()[2]; stride_d_input = input.strides()[2]; stride_h_input = input.strides()[1]; stride_d_output = output.strides()[2]; @@ -982,8 +986,8 @@ void mean_scale_fuse_quant_cuda_fwd( } PD_BUILD_OP(mean_scale_fuse_quant_cuda) - .Inputs({"input", "output", "mean", "scale"}) - .Outputs({"out1", "out2", "out3", "out4"}) - .SetInplaceMap({{"input", "out1"}, {"output", "out2"}, {"mean", "out3"}, {"scale", "out4"}}) // Inplace - .Attrs({"num_tokens: int", "scale_max: float", "tensor_layout: int"}) + .Inputs({"input", "output", "mean", "scale", "v"}) + .Outputs({"out1", "out2", "out3", "out4", "out5"}) + .SetInplaceMap({{"input", "out1"}, {"output", "out2"}, {"mean", "out3"}, {"scale", "out4"}, 
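+                  // Note: `v` is passed (and aliased to out5) purely so the op can read the token
+                  // count from a runtime tensor shape; the former `num_tokens: int` attribute is
+                  // gone, which the commit message ties to static-graph inference.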
{"v", "out5"}}) // Inplace + .Attrs({"scale_max: float", "tensor_layout: int"}) .SetKernelFn(PD_KERNEL(mean_scale_fuse_quant_cuda_fwd)); \ No newline at end of file diff --git a/paddlenlp/experimental/transformers/sageattention.py b/paddlenlp/experimental/transformers/sageattention.py index fad59a1729ae..6619dedaca2f 100644 --- a/paddlenlp/experimental/transformers/sageattention.py +++ b/paddlenlp/experimental/transformers/sageattention.py @@ -109,10 +109,10 @@ def per_channel_fp8( vm = paddle.empty((b, h_kv, head_dim), dtype=paddle.float32) if smooth_v: - paddlenlp_ops.mean_scale_fuse_quant_cuda(v_transposed_permutted, v_fp8, vm, v_scale, kv_len, scale_max, _tensor_layout) + paddlenlp_ops.mean_scale_fuse_quant_cuda(v_transposed_permutted, v_fp8, vm, v_scale, v, scale_max, _tensor_layout) # modified: use `v` instead of kv_len for static mode return v_fp8, v_scale, vm else: - paddlenlp_ops.scale_fuse_quant_cuda(v_transposed_permutted, v_fp8, v_scale, kv_len, scale_max, _tensor_layout) + paddlenlp_ops.scale_fuse_quant_cuda(v_transposed_permutted, v_fp8, v_scale, v, scale_max, _tensor_layout) # modified: use `v` instead of kv_len for static mode return v_fp8, v_scale, None @@ -413,7 +413,7 @@ def sageattn_qk_int8_pv_fp8_cuda_dsk_sm90( elif head_dim_og > pad_dim_tgt: raise ValueError(f"Unsupported head_dim: {head_dim_og}") - assert q.strides[-1] == 1 and k.strides[-1] == 1 and v.strides[-1] == 1, "Last dim of qkv must be contiguous." + # assert q.strides[-1] == 1 and k.strides[-1] == 1 and v.strides[-1] == 1, "Last dim of qkv must be contiguous." if sm_scale is None: sm_scale = head_dim_og**-0.5 From 7035e1ef75d34c47a4efa8e4da575394ffe49811 Mon Sep 17 00:00:00 2001 From: l1cacheDell Date: Tue, 4 Mar 2025 18:49:22 +0800 Subject: [PATCH 11/18] add license info --- csrc/gpu/sage_attn_kernels/sageattn_fused.cu | 13 ++++++ .../sageattn_qk_int_sv_f16_kernel_sm80.cu | 13 ++++++ .../sageattn_qk_int_sv_f8_dsk_kernel_sm90.cu | 13 ++++++ .../sageattn_qk_int_sv_f8_kernel_sm89.cu | 13 ++++++ .../sageattn_qk_int_sv_f8_kernel_sm90.cu | 13 ++++++ csrc/gpu/sage_attn_kernels/sageattn_utils.cuh | 13 ++++++ .../transformers/fused_transformer_layers.py | 41 ++++++++++--------- 7 files changed, 100 insertions(+), 19 deletions(-) diff --git a/csrc/gpu/sage_attn_kernels/sageattn_fused.cu b/csrc/gpu/sage_attn_kernels/sageattn_fused.cu index 6dd9a64bb6b9..46aabc08d5f0 100644 --- a/csrc/gpu/sage_attn_kernels/sageattn_fused.cu +++ b/csrc/gpu/sage_attn_kernels/sageattn_fused.cu @@ -1,3 +1,16 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include #include #include diff --git a/csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f16_kernel_sm80.cu b/csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f16_kernel_sm80.cu index cb666b9a7a35..081758cff8b5 100644 --- a/csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f16_kernel_sm80.cu +++ b/csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f16_kernel_sm80.cu @@ -1,3 +1,16 @@ +// Copyright (c) 2024 PaddlePaddle Authors. 
All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include #include "paddle/extension.h" diff --git a/csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_dsk_kernel_sm90.cu b/csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_dsk_kernel_sm90.cu index 7d16f911396a..e9e69e67e473 100644 --- a/csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_dsk_kernel_sm90.cu +++ b/csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_dsk_kernel_sm90.cu @@ -1,3 +1,16 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include #include "paddle/extension.h" diff --git a/csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_kernel_sm89.cu b/csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_kernel_sm89.cu index b5853899b574..4e9c2144a52e 100644 --- a/csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_kernel_sm89.cu +++ b/csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_kernel_sm89.cu @@ -1,3 +1,16 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include #include "paddle/extension.h" diff --git a/csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_kernel_sm90.cu b/csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_kernel_sm90.cu index 0ee91f10a247..7ce1162691c7 100644 --- a/csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_kernel_sm90.cu +++ b/csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_kernel_sm90.cu @@ -1,3 +1,16 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. #include #include "paddle/extension.h" diff --git a/csrc/gpu/sage_attn_kernels/sageattn_utils.cuh b/csrc/gpu/sage_attn_kernels/sageattn_utils.cuh index 3e408208b0d1..4ce282c91383 100644 --- a/csrc/gpu/sage_attn_kernels/sageattn_utils.cuh +++ b/csrc/gpu/sage_attn_kernels/sageattn_utils.cuh @@ -1,3 +1,16 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once #include #include diff --git a/paddlenlp/experimental/transformers/fused_transformer_layers.py b/paddlenlp/experimental/transformers/fused_transformer_layers.py index b067f0aa3e81..52badec58ff5 100644 --- a/paddlenlp/experimental/transformers/fused_transformer_layers.py +++ b/paddlenlp/experimental/transformers/fused_transformer_layers.py @@ -71,9 +71,8 @@ def use_cutlass_fp8_gemm(): transpose_remove_padding, write_cache_kv, ) - from .sageattention import ( - sageattn_qk_int8_pv_fp8_cuda_dsk_sm90 - ) + + from .sageattention import sageattn_qk_int8_pv_fp8_cuda_dsk_sm90 except: pass @@ -2984,38 +2983,42 @@ def compute_mla_absorb( if kwargs["max_enc_len_this_time"]: # prefill phase query, key, value = self.compute_qkv_linear(ln_out, i, latent_cache=latent_cache, **kwargs) + seq_len_q_slices = kwargs.get("cu_seqlens_q", None) + seq_len_k_slices = kwargs.get("cu_seqlens_k", None) + bsz = 1 + if seq_len_q_slices is not None: + bsz = seq_len_q_slices.shape[0] - 1 - if not use_sageattn: - fmha_out_prefill = paddle.nn.functional.flash_attention.flash_attn_unpadded( - query, - key, - value, - kwargs.get("cu_seqlens_q", None), - kwargs.get("cu_seqlens_k", None), - kwargs.get("max_enc_len_this_time", -1), - kwargs.get("max_enc_len_this_time", -1), - self.softmax_scale, - causal=True, - training=False, - )[0] - else: + if use_sageattn and bsz == 1: # batch size == 1 query_192 = paddle.unsqueeze(query, axis=0) key_192 = paddle.unsqueeze(key, axis=0) value_128, _ = paddle.split(value, [128, 64], axis=-1) value_128 = paddle.unsqueeze(value_128, axis=0) - tensor_layout = "NHD" fmha_out_prefill = sageattn_qk_int8_pv_fp8_cuda_dsk_sm90( query_192, key_192, value_128, is_causal=True, sm_scale=self.softmax_scale, - tensor_layout=tensor_layout, + tensor_layout="NHD", ) fmha_out_prefill = paddle.nn.functional.pad(fmha_out_prefill, (0, 192 - 128)) fmha_out_prefill = paddle.squeeze(fmha_out_prefill, axis=0) + else: + fmha_out_prefill = paddle.nn.functional.flash_attention.flash_attn_unpadded( + query, + key, + value, + seq_len_q_slices, + seq_len_k_slices, + kwargs.get("max_enc_len_this_time", -1), + kwargs.get("max_enc_len_this_time", -1), + self.softmax_scale, + causal=True, + training=False, + )[0] fmha_out_prefill = fmha_out_prefill.reshape([-1, self.num_heads, self.config.mla_config.qk_head_dim]) fmha_out_prefill = fmha_out_prefill[:, :, : self.config.mla_config.v_head_dim] From 38ea097e78afb5f8a613691a5ba72248ffb66d24 Mon Sep 17 00:00:00 2001 From: l1cacheDell Date: Tue, 
4 Mar 2025 20:04:53 +0800 Subject: [PATCH 12/18] add license info for py file --- .../transformers/sageattention.py | 309 +++++++++++++----- 1 file changed, 228 insertions(+), 81 deletions(-) diff --git a/paddlenlp/experimental/transformers/sageattention.py b/paddlenlp/experimental/transformers/sageattention.py index 6619dedaca2f..f47fe878d716 100644 --- a/paddlenlp/experimental/transformers/sageattention.py +++ b/paddlenlp/experimental/transformers/sageattention.py @@ -1,17 +1,31 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import warnings +from typing import Any, Optional + import paddle import paddlenlp_ops -from typing import Optional, Any -import warnings def per_block_int8( - q: paddle.Tensor, - k: paddle.Tensor, - km: Optional[paddle.Tensor] = None, - BLKQ: int =128, - BLKK: int =64, - sm_scale: Optional[float] = None, - tensor_layout: str ="HND" + q: paddle.Tensor, + k: paddle.Tensor, + km: Optional[paddle.Tensor] = None, + BLKQ: int = 128, + BLKK: int = 64, + sm_scale: Optional[float] = None, + tensor_layout: str = "HND", ): q_int8 = paddle.empty(q.shape, dtype=paddle.int8) k_int8 = paddle.empty(k.shape, dtype=paddle.int8) @@ -32,7 +46,7 @@ def per_block_int8( if sm_scale is None: sm_scale = head_dim**-0.5 - + sm_scale *= 1.44269504 paddlenlp_ops.quant_per_block_int8_cuda(q, q_int8, q_scale, sm_scale, BLKQ, _tensor_layout) @@ -46,13 +60,13 @@ def per_block_int8( def per_warp_int8_cuda( - q: paddle.Tensor, - k: paddle.Tensor, - km: Optional[paddle.Tensor] = None, - BLKQ: int =128, - WARPQ: int =32, - BLKK: int =64, - tensor_layout: str ="HND" + q: paddle.Tensor, + k: paddle.Tensor, + km: Optional[paddle.Tensor] = None, + BLKQ: int = 128, + WARPQ: int = 32, + BLKK: int = 64, + tensor_layout: str = "HND", ): q_int8 = paddle.empty(shape=q.shape, dtype=paddle.int8) k_int8 = paddle.empty(shape=k.shape, dtype=paddle.int8) @@ -67,7 +81,7 @@ def per_warp_int8_cuda( else: raise ValueError(f"Unknown tensor layout: {tensor_layout}") - + _tensor_layout = 0 if tensor_layout == "NHD" else 1 q_scale = paddle.empty((b, h_qo, ((qo_len + BLKQ - 1) // BLKQ) * (BLKQ // WARPQ)), dtype=paddle.float32) @@ -80,16 +94,11 @@ def per_warp_int8_cuda( paddlenlp_ops.quant_per_block_int8_fuse_sub_mean_cuda(k, km, k_int8, k_scale, BLKK, _tensor_layout) else: paddlenlp_ops.quant_per_block_int8_cuda(k, k_int8, k_scale, BLKK, _tensor_layout) - + return q_int8, q_scale, k_int8, k_scale -def per_channel_fp8( - v: paddle.Tensor, - tensor_layout: str ="NHD", - scale_max: float = 448.0, - smooth_v: bool = True -): +def per_channel_fp8(v: paddle.Tensor, tensor_layout: str = "NHD", scale_max: float = 448.0, smooth_v: bool = True): _tensor_layout = 0 if tensor_layout == "NHD" else 1 if tensor_layout == "HND": @@ -109,22 +118,23 @@ def per_channel_fp8( vm = paddle.empty((b, h_kv, head_dim), dtype=paddle.float32) if smooth_v: - paddlenlp_ops.mean_scale_fuse_quant_cuda(v_transposed_permutted, v_fp8, vm, v_scale, v, scale_max, _tensor_layout) # modified: use 
`v` instead of kv_len for static mode + paddlenlp_ops.mean_scale_fuse_quant_cuda( + v_transposed_permutted, v_fp8, vm, v_scale, v, scale_max, _tensor_layout + ) # modified: use `v` instead of kv_len for static mode return v_fp8, v_scale, vm else: - paddlenlp_ops.scale_fuse_quant_cuda(v_transposed_permutted, v_fp8, v_scale, v, scale_max, _tensor_layout) # modified: use `v` instead of kv_len for static mode + paddlenlp_ops.scale_fuse_quant_cuda( + v_transposed_permutted, v_fp8, v_scale, v, scale_max, _tensor_layout + ) # modified: use `v` instead of kv_len for static mode return v_fp8, v_scale, None - -def sub_mean( - v: paddle.Tensor, - tensor_layout: str ="HND" -): + +def sub_mean(v: paddle.Tensor, tensor_layout: str = "HND"): _tensor_layout = 0 if tensor_layout == "NHD" else 1 vm = v.mean(dim=1 if _tensor_layout == 0 else 2) v_smoothed = paddle.empty(v.shape, dtype=paddle.float16) - + # subtract mean and store the result as fp16 paddlenlp_ops.sub_mean_cuda(v, vm, v_smoothed, _tensor_layout) @@ -132,8 +142,8 @@ def sub_mean( def sageattn_qk_int8_pv_fp16_cuda_sm80( - q: paddle.Tensor, - k: paddle.Tensor, + q: paddle.Tensor, + k: paddle.Tensor, v: paddle.Tensor, tensor_layout: str = "HND", is_causal: bool = False, @@ -145,7 +155,10 @@ def sageattn_qk_int8_pv_fp16_cuda_sm80( return_lse: bool = False, ) -> paddle.Tensor: dtype = q.dtype - assert dtype in [paddle.float16, paddle.bfloat16], "Input tensors must be in dtype of torch.float16 or torch.bfloat16" + assert dtype in [ + paddle.float16, + paddle.bfloat16, + ], "Input tensors must be in dtype of torch.float16 or torch.bfloat16" assert qk_quant_gran in ["per_warp", "per_thread"], "qk_quant_gran must be either 'per_warp' or 'per_thread'." assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype." 
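
Reviewer aside (not part of the patch): a minimal sketch of how the quantization helpers reformatted in the hunks above are meant to be combined, assuming the paddlenlp_ops SageAttention custom ops have been built from the kernels in this series. The shapes (batch 1, 4096 tokens, 32 heads, head_dim 128, NHD layout) and the random inputs are illustrative assumptions only.

    # Illustrative sketch; requires the compiled paddlenlp_ops sage_attn kernels.
    import paddle

    from paddlenlp.experimental.transformers.sageattention import (
        per_channel_fp8,
        per_warp_int8_cuda,
    )

    b, s, h, d = 1, 4096, 32, 128  # assumed NHD shapes: [batch, seq, heads, head_dim]
    q = paddle.randn([b, s, h, d], dtype="float16")
    k = paddle.randn([b, s, h, d], dtype="float16")
    v = paddle.randn([b, s, h, d], dtype="float16")

    # smooth_k: per-channel K mean over the sequence dim, subtracted during quantization
    km = paddle.mean(k, axis=1, keepdim=True)

    q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda(q, k, km, tensor_layout="NHD")
    v_fp8, v_scale, vm = per_channel_fp8(v, tensor_layout="NHD", smooth_v=True)
    # q_int8/k_int8/v_fp8 feed the qk_int8_sv_f8_* attention ops; q_scale, k_scale,
    # v_scale and vm are the dequantization factors those ops consume.
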
@@ -179,14 +192,24 @@ def sageattn_qk_int8_pv_fp16_cuda_sm80( km = paddle.mean(k, axis=seq_dim, keepdim=True) if return_lse: if tensor_layout == "NHD": - lse_correction = paddle.squeeze(paddle.matmul(paddle.transpose(q, [0, 2, 1, 3], paddle.transpose(km, [0, 2, 3, 1]))), axis=-1) + lse_correction = paddle.squeeze( + paddle.matmul(paddle.transpose(q, [0, 2, 1, 3], paddle.transpose(km, [0, 2, 3, 1]))), axis=-1 + ) else: lse_correction = paddle.squeeze(paddle.matmul(q, paddle.transpose(km, [0, 1, 3, 2])), axis=-1) else: km = None if qk_quant_gran == "per_warp": - q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda(q, k, km, tensor_layout=tensor_layout, BLKQ=128, WARPQ=(16 if (q.shape[-1] == 128 and pv_accum_dtype == "fp16+fp32") else 32), BLKK=64) + q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda( + q, + k, + km, + tensor_layout=tensor_layout, + BLKQ=128, + WARPQ=(16 if (q.shape[-1] == 128 and pv_accum_dtype == "fp16+fp32") else 32), + BLKK=64, + ) o = paddle.empty(q.shape, dtype=dtype) @@ -194,19 +217,48 @@ def sageattn_qk_int8_pv_fp16_cuda_sm80( warnings.warn(f"pv_accum_dtype is {pv_accum_dtype}, smooth_v will be ignored.") smooth_v = False - if pv_accum_dtype == 'fp32': + if pv_accum_dtype == "fp32": v = v.to(paddle.float16) - lse = paddlenlp_ops.qk_int8_sv_f16_accum_f32_attn(q_int8, k_int8, v, o, q_scale, k_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse) + lse = paddlenlp_ops.qk_int8_sv_f16_accum_f32_attn( + q_int8, k_int8, v, o, q_scale, k_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse + ) elif pv_accum_dtype == "fp16": if smooth_v: smoothed_v, vm = sub_mean(v, tensor_layout=tensor_layout) - lse = paddlenlp_ops.qk_int8_sv_f16_accum_f16_fuse_v_mean_attn(q_int8, k_int8, smoothed_v, o, q_scale, k_scale, vm, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse) + lse = paddlenlp_ops.qk_int8_sv_f16_accum_f16_fuse_v_mean_attn( + q_int8, + k_int8, + smoothed_v, + o, + q_scale, + k_scale, + vm, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) else: v = v.to(paddle.float16) - lse = paddlenlp_ops.qk_int8_sv_f16_accum_f16_attn(q_int8, k_int8, v, o, q_scale, k_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse) + lse = paddlenlp_ops.qk_int8_sv_f16_accum_f16_attn( + q_int8, + k_int8, + v, + o, + q_scale, + k_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) elif pv_accum_dtype == "fp16+fp32": v = v.to(paddle.float16) - lse = paddlenlp_ops.qk_int8_sv_f16_accum_f16_attn_inst_buf(q_int8, k_int8, v, o, q_scale, k_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse) + lse = paddlenlp_ops.qk_int8_sv_f16_accum_f16_attn_inst_buf( + q_int8, k_int8, v, o, q_scale, k_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse + ) else: raise ValueError(f"Unsupported pv_accum_dtype: {pv_accum_dtype}") @@ -217,9 +269,10 @@ def sageattn_qk_int8_pv_fp16_cuda_sm80( else: return o + def sageattn_qk_int8_pv_fp8_cuda_sm89( - q: paddle.Tensor, - k: paddle.Tensor, + q: paddle.Tensor, + k: paddle.Tensor, v: paddle.Tensor, tensor_layout: str = "NHD", is_causal: bool = False, @@ -231,16 +284,19 @@ def sageattn_qk_int8_pv_fp8_cuda_sm89( return_lse: bool = False, ): dtype = q.dtype - assert dtype in [paddle.float16, paddle.bfloat16], "Input tensors must be in dtype of torch.float16 or torch.bfloat16" + assert dtype in [ + paddle.float16, + paddle.bfloat16, + ], "Input tensors must be in dtype of torch.float16 or torch.bfloat16" assert 
q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype." _tensor_layout = 0 if tensor_layout == "NHD" else 1 _is_causal = 1 if is_causal else 0 _qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2 _return_lse = 1 if return_lse else 0 - + head_dim_og = q.shape[-1] - + if head_dim_og < 64: q = paddle.nn.functional.pad(q, (0, 64 - head_dim_og)) k = paddle.nn.functional.pad(k, (0, 64 - head_dim_og)) @@ -251,40 +307,82 @@ def sageattn_qk_int8_pv_fp8_cuda_sm89( v = paddle.nn.functional.pad(v, (0, 128 - head_dim_og)) elif head_dim_og > 128: raise ValueError(f"Unsupported head_dim: {head_dim_og}") - + assert q.strides[-1] == 1 and k.strides[-1] == 1 and v.strides[-1] == 1, "Last dim of qkv must be contiguous." if sm_scale is None: sm_scale = head_dim_og**-0.5 - + seq_dim = 1 if _tensor_layout == 0 else 2 - + if smooth_k: km = paddle.mean(k, axis=seq_dim, keepdim=True) if return_lse: if tensor_layout == "NHD": - lse_correction = paddle.squeeze(paddle.matmul(paddle.transpose(q, [0, 2, 1, 3], paddle.transpose(km, [0, 2, 3, 1]))), axis=-1) + lse_correction = paddle.squeeze( + paddle.matmul(paddle.transpose(q, [0, 2, 1, 3], paddle.transpose(km, [0, 2, 3, 1]))), axis=-1 + ) else: lse_correction = paddle.squeeze(paddle.matmul(q, paddle.transpose(km, [0, 1, 3, 2])), axis=-1) else: km = None - + q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda(q, k, km, tensor_layout=tensor_layout) o = paddle.empty(q.shape, dtype=dtype) - if pv_accum_dtype == 'fp32+fp32' and smooth_v: + if pv_accum_dtype == "fp32+fp32" and smooth_v: warnings.warn("pv_accum_dtype is 'fp32+fp32', smooth_v will be ignored.") smooth_v = False - + v_fp8, v_scale, vm = per_channel_fp8(v, tensor_layout=tensor_layout, smooth_v=smooth_v) if pv_accum_dtype == "fp32": if smooth_v: - lse = paddlenlp_ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, vm, _tensor_layout, _is_causal, _qk_quant_gran, sm_scale, _return_lse) + lse = paddlenlp_ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn( + q_int8, + k_int8, + v_fp8, + o, + q_scale, + k_scale, + v_scale, + vm, + _tensor_layout, + _is_causal, + _qk_quant_gran, + sm_scale, + _return_lse, + ) else: - lse = paddlenlp_ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_causal, _qk_quant_gran, sm_scale, _return_lse) + lse = paddlenlp_ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn( + q_int8, + k_int8, + v_fp8, + o, + q_scale, + k_scale, + v_scale, + _tensor_layout, + _is_causal, + _qk_quant_gran, + sm_scale, + _return_lse, + ) elif pv_accum_dtype == "fp32+fp32": - lse = paddlenlp_ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm89(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_causal, _qk_quant_gran, sm_scale, _return_lse) + lse = paddlenlp_ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm89( + q_int8, + k_int8, + v_fp8, + o, + q_scale, + k_scale, + v_scale, + _tensor_layout, + _is_causal, + _qk_quant_gran, + sm_scale, + _return_lse, + ) o = o[..., :head_dim_og] @@ -292,11 +390,11 @@ def sageattn_qk_int8_pv_fp8_cuda_sm89( return o, lse / 1.44269504 + lse_correction * sm_scale if smooth_k else lse / 1.44269504 else: return o - + def sageattn_qk_int8_pv_fp8_cuda_sm90( - q: paddle.Tensor, - k: paddle.Tensor, + q: paddle.Tensor, + k: paddle.Tensor, v: paddle.Tensor, tensor_layout: str = "HND", is_causal: bool = False, @@ -308,7 +406,10 @@ def sageattn_qk_int8_pv_fp8_cuda_sm90( **kwargs: Any, ) -> paddle.Tensor: dtype 
= q.dtype - assert dtype in [paddle.float16, paddle.bfloat16], "Input tensors must be in dtype of torch.float16 or torch.bfloat16" + assert dtype in [ + paddle.float16, + paddle.bfloat16, + ], "Input tensors must be in dtype of torch.float16 or torch.bfloat16" assert qk_quant_gran in ["per_warp", "per_thread"], "qk_quant_gran must be either 'per_warp' or 'per_thread'." assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype." @@ -329,26 +430,30 @@ def sageattn_qk_int8_pv_fp8_cuda_sm90( v = paddle.nn.functional.pad(v, (0, 128 - head_dim_og)) elif head_dim_og > 128: raise ValueError(f"Unsupported head_dim: {head_dim_og}") - + assert q.strides[-1] == 1 and k.strides[-1] == 1 and v.strides[-1] == 1, "Last dim of qkv must be contiguous." if sm_scale is None: sm_scale = head_dim_og**-0.5 - + seq_dim = 1 if _tensor_layout == 0 else 2 if smooth_k: km = paddle.mean(k, axis=seq_dim, keepdim=True) if return_lse: if tensor_layout == "NHD": - lse_correction = paddle.squeeze(paddle.matmul(paddle.transpose(q, [0, 2, 1, 3], paddle.transpose(km, [0, 2, 3, 1]))), axis=-1) + lse_correction = paddle.squeeze( + paddle.matmul(paddle.transpose(q, [0, 2, 1, 3], paddle.transpose(km, [0, 2, 3, 1]))), axis=-1 + ) else: lse_correction = paddle.squeeze(paddle.matmul(q, paddle.transpose(km, [0, 1, 3, 2])), axis=-1) else: km = None if qk_quant_gran == "per_warp": - q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda(q, k, km, tensor_layout=tensor_layout, BLKQ=64, WARPQ=16, BLKK=128) + q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda( + q, k, km, tensor_layout=tensor_layout, BLKQ=64, WARPQ=16, BLKK=128 + ) o = paddle.empty(q.shape, dtype=dtype) @@ -356,14 +461,30 @@ def sageattn_qk_int8_pv_fp8_cuda_sm90( v_pad_len = 128 - (kv_len % 128) if kv_len % 128 != 0 else 0 if v_pad_len > 0: if tensor_layout == "HND": - v = paddle.concat([v, paddle.zeros(shape=[v.shape[0], v.shape[1], v_pad_len, v.shape[3]], dtype=v.dtype)], axis=2) + v = paddle.concat( + [v, paddle.zeros(shape=[v.shape[0], v.shape[1], v_pad_len, v.shape[3]], dtype=v.dtype)], axis=2 + ) else: - v = paddle.concat([v, paddle.zeros(shape=[v.shape[0], v_pad_len, v.shape[2], v.shape[3]], dtype=v.dtype)], axis=1) + v = paddle.concat( + [v, paddle.zeros(shape=[v.shape[0], v_pad_len, v.shape[2], v.shape[3]], dtype=v.dtype)], axis=1 + ) - v_fp8, v_scale, _ = per_channel_fp8(v, tensor_layout=tensor_layout, smooth_v=False) - lse = paddlenlp_ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_causal, _qk_quant_gran, sm_scale, _return_lse) + lse = paddlenlp_ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90( + q_int8, + k_int8, + v_fp8, + o, + q_scale, + k_scale, + v_scale, + _tensor_layout, + _is_causal, + _qk_quant_gran, + sm_scale, + _return_lse, + ) o = o[..., :head_dim_og] @@ -374,8 +495,8 @@ def sageattn_qk_int8_pv_fp8_cuda_sm90( def sageattn_qk_int8_pv_fp8_cuda_dsk_sm90( - q: paddle.Tensor, - k: paddle.Tensor, + q: paddle.Tensor, + k: paddle.Tensor, v: paddle.Tensor, tensor_layout: str = "HND", is_causal: bool = False, @@ -386,7 +507,10 @@ def sageattn_qk_int8_pv_fp8_cuda_dsk_sm90( return_lse: bool = False, ) -> paddle.Tensor: dtype = q.dtype - assert dtype in [paddle.float16, paddle.bfloat16], "Input tensors must be in dtype of torch.float16 or torch.bfloat16" + assert dtype in [ + paddle.float16, + paddle.bfloat16, + ], "Input tensors must be in dtype of torch.float16 or torch.bfloat16" assert qk_quant_gran in ["per_warp", "per_thread"], "qk_quant_gran 
must be either 'per_warp' or 'per_thread'." assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype." @@ -412,26 +536,30 @@ def sageattn_qk_int8_pv_fp8_cuda_dsk_sm90( k = paddle.nn.functional.pad(k, (0, pad_dim_tgt - head_dim_og)) elif head_dim_og > pad_dim_tgt: raise ValueError(f"Unsupported head_dim: {head_dim_og}") - + # assert q.strides[-1] == 1 and k.strides[-1] == 1 and v.strides[-1] == 1, "Last dim of qkv must be contiguous." if sm_scale is None: sm_scale = head_dim_og**-0.5 - + seq_dim = 1 if _tensor_layout == 0 else 2 if smooth_k: km = paddle.mean(k, axis=seq_dim, keepdim=True) if return_lse: if tensor_layout == "NHD": - lse_correction = paddle.squeeze(paddle.matmul(paddle.transpose(q, [0, 2, 1, 3], paddle.transpose(km, [0, 2, 3, 1]))), axis=-1) + lse_correction = paddle.squeeze( + paddle.matmul(paddle.transpose(q, [0, 2, 1, 3], paddle.transpose(km, [0, 2, 3, 1]))), axis=-1 + ) else: lse_correction = paddle.squeeze(paddle.matmul(q, paddle.transpose(km, [0, 1, 3, 2])), axis=-1) else: km = None if qk_quant_gran == "per_warp": - q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda(q, k, km, tensor_layout=tensor_layout, BLKQ=64, WARPQ=16, BLKK=128) + q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda( + q, k, km, tensor_layout=tensor_layout, BLKQ=64, WARPQ=16, BLKK=128 + ) o = paddle.empty(v.shape, dtype=dtype) @@ -439,9 +567,13 @@ def sageattn_qk_int8_pv_fp8_cuda_dsk_sm90( v_pad_len = 128 - (kv_len % 128) if kv_len % 128 != 0 else 0 if v_pad_len > 0: if tensor_layout == "HND": - v = paddle.concat([v, paddle.zeros(shape=[v.shape[0], v.shape[1], v_pad_len, v.shape[3]], dtype=v.dtype)], axis=2) - else: - v = paddle.concat([v, paddle.zeros(shape=[v.shape[0], v_pad_len, v.shape[2], v.shape[3]], dtype=v.dtype)], axis=1) + v = paddle.concat( + [v, paddle.zeros(shape=[v.shape[0], v.shape[1], v_pad_len, v.shape[3]], dtype=v.dtype)], axis=2 + ) + else: + v = paddle.concat( + [v, paddle.zeros(shape=[v.shape[0], v_pad_len, v.shape[2], v.shape[3]], dtype=v.dtype)], axis=1 + ) v_fp8, v_scale, _ = per_channel_fp8(v, tensor_layout=tensor_layout, smooth_v=False) if pad_dim_tgt == 256: @@ -451,7 +583,22 @@ def sageattn_qk_int8_pv_fp8_cuda_dsk_sm90( q_int8_nope, q_int8_pe = paddle.split(q_int8, [128, 64], axis=-1) k_int8_nope, k_int8_pe = paddle.split(k_int8, [128, 64], axis=-1) - lse = paddlenlp_ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_dsk_sm90(q_int8_nope, k_int8_nope, q_int8_pe, k_int8_pe, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_causal, _qk_quant_gran, sm_scale, _return_lse) + lse = paddlenlp_ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_dsk_sm90( + q_int8_nope, + k_int8_nope, + q_int8_pe, + k_int8_pe, + v_fp8, + o, + q_scale, + k_scale, + v_scale, + _tensor_layout, + _is_causal, + _qk_quant_gran, + sm_scale, + _return_lse, + ) head_dim_og = v.shape[-1] o = o[..., :head_dim_og] From c42d1f5c36bea7ce149bac51b5a2b1bbc13b82f0 Mon Sep 17 00:00:00 2001 From: l1cacheDell Date: Tue, 4 Mar 2025 20:27:53 +0800 Subject: [PATCH 13/18] modify license info --- csrc/gpu/sage_attn_kernels/sageattn_fused.cu | 32 +++++++++++-------- .../sageattn_qk_int_sv_f16_kernel_sm80.cu | 31 ++++++++++-------- .../sageattn_qk_int_sv_f8_dsk_kernel_sm90.cu | 31 ++++++++++-------- .../sageattn_qk_int_sv_f8_kernel_sm89.cu | 31 ++++++++++-------- .../sageattn_qk_int_sv_f8_kernel_sm90.cu | 31 ++++++++++-------- csrc/gpu/sage_attn_kernels/sageattn_utils.cuh | 31 ++++++++++-------- .../transformers/sageattention.py | 3 ++ 7 files changed, 112 insertions(+), 78 deletions(-) 
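
Reviewer aside (not part of the patch): before the license-only diffs that follow, a hedged end-to-end usage sketch of the SM90 FP8 wrapper reformatted above. It assumes a Hopper (SM90) GPU, the compiled custom ops, and illustrative float16 NHD inputs with head_dim 128 so the padding branches are not exercised.

    # Illustrative sketch; inputs and shapes are assumptions, not from the patch.
    import paddle

    from paddlenlp.experimental.transformers.sageattention import (
        sageattn_qk_int8_pv_fp8_cuda_sm90,
    )

    q = paddle.randn([1, 4096, 32, 128], dtype="float16")  # [batch, seq, heads, head_dim]
    k = paddle.randn([1, 4096, 32, 128], dtype="float16")
    v = paddle.randn([1, 4096, 32, 128], dtype="float16")

    out = sageattn_qk_int8_pv_fp8_cuda_sm90(
        q, k, v,
        tensor_layout="NHD",
        is_causal=True,
        sm_scale=128 ** -0.5,
    )
    print(out.shape)  # [1, 4096, 32, 128]; quantization and FP8 PV happen inside
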
diff --git a/csrc/gpu/sage_attn_kernels/sageattn_fused.cu b/csrc/gpu/sage_attn_kernels/sageattn_fused.cu index 46aabc08d5f0..680eed4e9714 100644 --- a/csrc/gpu/sage_attn_kernels/sageattn_fused.cu +++ b/csrc/gpu/sage_attn_kernels/sageattn_fused.cu @@ -1,16 +1,22 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. +/* + * Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. + * + * This file is inspired by SageAttention project (https://github.com/thu-ml/SageAttention), + * but has been re-implemented by PaddlePaddle. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + #include #include #include diff --git a/csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f16_kernel_sm80.cu b/csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f16_kernel_sm80.cu index 081758cff8b5..9c8111462e81 100644 --- a/csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f16_kernel_sm80.cu +++ b/csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f16_kernel_sm80.cu @@ -1,16 +1,21 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. +/* + * Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. + * + * This file is inspired by SageAttention project (https://github.com/thu-ml/SageAttention), + * but has been re-implemented by PaddlePaddle. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ #include #include "paddle/extension.h" diff --git a/csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_dsk_kernel_sm90.cu b/csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_dsk_kernel_sm90.cu index e9e69e67e473..cab88256411a 100644 --- a/csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_dsk_kernel_sm90.cu +++ b/csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_dsk_kernel_sm90.cu @@ -1,16 +1,21 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. +/* + * Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. + * + * This file is inspired by SageAttention project (https://github.com/thu-ml/SageAttention), + * but has been re-implemented by PaddlePaddle. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #include #include "paddle/extension.h" diff --git a/csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_kernel_sm89.cu b/csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_kernel_sm89.cu index 4e9c2144a52e..269ad2973cf8 100644 --- a/csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_kernel_sm89.cu +++ b/csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_kernel_sm89.cu @@ -1,16 +1,21 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. +/* + * Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. + * + * This file is inspired by SageAttention project (https://github.com/thu-ml/SageAttention), + * but has been re-implemented by PaddlePaddle. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ #include #include "paddle/extension.h" diff --git a/csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_kernel_sm90.cu b/csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_kernel_sm90.cu index 7ce1162691c7..ec08510c18eb 100644 --- a/csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_kernel_sm90.cu +++ b/csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_kernel_sm90.cu @@ -1,16 +1,21 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. +/* + * Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. + * + * This file is inspired by SageAttention project (https://github.com/thu-ml/SageAttention), + * but has been re-implemented by PaddlePaddle. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #include #include "paddle/extension.h" diff --git a/csrc/gpu/sage_attn_kernels/sageattn_utils.cuh b/csrc/gpu/sage_attn_kernels/sageattn_utils.cuh index 4ce282c91383..81df5a3be21c 100644 --- a/csrc/gpu/sage_attn_kernels/sageattn_utils.cuh +++ b/csrc/gpu/sage_attn_kernels/sageattn_utils.cuh @@ -1,16 +1,21 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. +/* + * Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. + * + * This file is inspired by SageAttention project (https://github.com/thu-ml/SageAttention), + * but has been re-implemented by PaddlePaddle. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ #pragma once #include #include diff --git a/paddlenlp/experimental/transformers/sageattention.py b/paddlenlp/experimental/transformers/sageattention.py index f47fe878d716..7c94c11026ec 100644 --- a/paddlenlp/experimental/transformers/sageattention.py +++ b/paddlenlp/experimental/transformers/sageattention.py @@ -1,5 +1,8 @@ # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. # +# This file is inspired by the SageAttention project (https://github.com/thu-ml/SageAttention), +# but has been re-implemented by PaddlePaddle. +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at From 22894d7ee0e7ce14b7d0dca66ba0c7ac84f7c101 Mon Sep 17 00:00:00 2001 From: l1cacheDell Date: Tue, 4 Mar 2025 21:15:36 +0800 Subject: [PATCH 14/18] modify license info --- csrc/gpu/sage_attn_kernels/sageattn_fused.cu | 25 ++++++++++++++---- .../sageattn_qk_int_sv_f16_kernel_sm80.cu | 26 +++++++++++++++---- .../sageattn_qk_int_sv_f8_dsk_kernel_sm90.cu | 26 +++++++++++++++---- .../sageattn_qk_int_sv_f8_kernel_sm89.cu | 26 +++++++++++++++---- .../sageattn_qk_int_sv_f8_kernel_sm90.cu | 26 +++++++++++++++---- csrc/gpu/sage_attn_kernels/sageattn_utils.cuh | 26 +++++++++++++++---- .../transformers/sageattention.py | 22 +++++++++++++--- 7 files changed, 144 insertions(+), 33 deletions(-) diff --git a/csrc/gpu/sage_attn_kernels/sageattn_fused.cu b/csrc/gpu/sage_attn_kernels/sageattn_fused.cu index 680eed4e9714..8bdbe9b016d2 100644 --- a/csrc/gpu/sage_attn_kernels/sageattn_fused.cu +++ b/csrc/gpu/sage_attn_kernels/sageattn_fused.cu @@ -1,21 +1,36 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + /* - * Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. + * This file is adapted from + * https://github.com/thu-ml/SageAttention/blob/main/csrc/fused/fused.cu. + * The original license is kept as-is: * - * This file is inspired by SageAttention project (https://github.com/thu-ml/SageAttention), - * but has been re-implemented by PaddlePaddle. + * Copyright (c) 2024 by SageAttention team. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
- */ +*/ #include #include diff --git a/csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f16_kernel_sm80.cu b/csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f16_kernel_sm80.cu index 9c8111462e81..2089bd1fbc5e 100644 --- a/csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f16_kernel_sm80.cu +++ b/csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f16_kernel_sm80.cu @@ -1,21 +1,37 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + /* - * Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. + * This file is adapted from + * https://github.com/thu-ml/SageAttention/blob/main/csrc/qattn/qk_int_sv_f16_cuda_sm80.cu. + * The original license is kept as-is: * - * This file is inspired by SageAttention project (https://github.com/thu-ml/SageAttention), - * but has been re-implemented by PaddlePaddle. + * Copyright (c) 2024 by SageAttention team. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. - */ +*/ + #include #include "paddle/extension.h" diff --git a/csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_dsk_kernel_sm90.cu b/csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_dsk_kernel_sm90.cu index cab88256411a..66869f18e6e1 100644 --- a/csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_dsk_kernel_sm90.cu +++ b/csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_dsk_kernel_sm90.cu @@ -1,21 +1,37 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + /* - * Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. + * This file is adapted from + * https://github.com/thu-ml/SageAttention/blob/main/csrc/qattn/qk_int_sv_f8_cuda_sm90.cu. + * The original license is kept as-is: * - * This file is inspired by SageAttention project (https://github.com/thu-ml/SageAttention), - * but has been re-implemented by PaddlePaddle. + * Copyright (c) 2024 by SageAttention team. 
* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. - */ +*/ + #include #include "paddle/extension.h" diff --git a/csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_kernel_sm89.cu b/csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_kernel_sm89.cu index 269ad2973cf8..e4aad5c2aed7 100644 --- a/csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_kernel_sm89.cu +++ b/csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_kernel_sm89.cu @@ -1,21 +1,37 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + /* - * Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. + * This file is adapted from + * https://github.com/thu-ml/SageAttention/blob/main/csrc/qattn/qk_int_sv_f8_cuda_sm89.cu. + * The original license is kept as-is: * - * This file is inspired by SageAttention project (https://github.com/thu-ml/SageAttention), - * but has been re-implemented by PaddlePaddle. + * Copyright (c) 2024 by SageAttention team. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. - */ +*/ + #include #include "paddle/extension.h" diff --git a/csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_kernel_sm90.cu b/csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_kernel_sm90.cu index ec08510c18eb..3425b532c4b7 100644 --- a/csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_kernel_sm90.cu +++ b/csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_kernel_sm90.cu @@ -1,21 +1,37 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + /* - * Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. + * This file is adapted from + * https://github.com/thu-ml/SageAttention/blob/main/csrc/qattn/qk_int_sv_f8_cuda_sm90.cu. + * The original license is kept as-is: * - * This file is inspired by SageAttention project (https://github.com/thu-ml/SageAttention), - * but has been re-implemented by PaddlePaddle. + * Copyright (c) 2024 by SageAttention team. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. - */ +*/ + #include #include "paddle/extension.h" diff --git a/csrc/gpu/sage_attn_kernels/sageattn_utils.cuh b/csrc/gpu/sage_attn_kernels/sageattn_utils.cuh index 81df5a3be21c..2849d707009e 100644 --- a/csrc/gpu/sage_attn_kernels/sageattn_utils.cuh +++ b/csrc/gpu/sage_attn_kernels/sageattn_utils.cuh @@ -1,21 +1,37 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + /* - * Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. + * This file is adapted from + * https://github.com/thu-ml/SageAttention/blob/main/csrc/*.cuh. + * The original license is kept as-is: * - * This file is inspired by SageAttention project (https://github.com/thu-ml/SageAttention), - * but has been re-implemented by PaddlePaddle. + * Copyright (c) 2024 by SageAttention team. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. - */ +*/ + #pragma once #include #include diff --git a/paddlenlp/experimental/transformers/sageattention.py b/paddlenlp/experimental/transformers/sageattention.py index 7c94c11026ec..3d63fd86a21a 100644 --- a/paddlenlp/experimental/transformers/sageattention.py +++ b/paddlenlp/experimental/transformers/sageattention.py @@ -1,7 +1,22 @@ -# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. 
# -# This file is inspired by the SageAttention project (https://github.com/thu-ml/SageAttention), -# but has been re-implemented by PaddlePaddle. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This file is adapted from: +# https://github.com/thu-ml/SageAttention/blob/main/sageattention/core.py. +# The original license is kept as-is: +# +# Copyright (c) 2024 by SageAttention team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,6 +29,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import warnings from typing import Any, Optional From 41f2900911db749ca88d3d826677afa765a81e48 Mon Sep 17 00:00:00 2001 From: l1cacheDell Date: Thu, 6 Mar 2025 12:19:44 +0800 Subject: [PATCH 15/18] bsz=1 assert --- .../sageattn_qk_int_sv_f8_dsk_kernel_sm90.cu | 14 +++++++++++--- .../transformers/fused_transformer_layers.py | 14 ++++++-------- .../experimental/transformers/sageattention.py | 4 ++++ 3 files changed, 21 insertions(+), 11 deletions(-) diff --git a/csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_dsk_kernel_sm90.cu b/csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_dsk_kernel_sm90.cu index 66869f18e6e1..53ecbc9e078b 100644 --- a/csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_dsk_kernel_sm90.cu +++ b/csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_dsk_kernel_sm90.cu @@ -855,6 +855,8 @@ std::vector qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_d paddle::Tensor& key, paddle::Tensor& query_pe, paddle::Tensor& key_pe, + paddle::Tensor& q_seq_indices, + paddle::Tensor& k_seq_indices, paddle::Tensor& value, paddle::Tensor& output, paddle::Tensor& query_scale, @@ -905,7 +907,9 @@ std::vector qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_d CHECK_DIMS(key_scale, 3); CHECK_DIMS(value_scale, 3); - const int batch_size = query.shape()[0]; + const int batch_size = q_seq_indices.shape()[0] - 1; + PD_CHECK(batch_size == 1, "Sage Attention only support batch_size == 1"); + const int head_dim = query.shape()[3]; int stride_bz_q = query.strides()[0]; @@ -1058,6 +1062,8 @@ std::vector> qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst std::vector key_shape, std::vector query_pe_shape, std::vector key_pe_shape, + std::vector q_seq_indices_shape, + std::vector k_seq_indices_shape, std::vector value_shape, std::vector output_shape, std::vector query_scale_shape, @@ -1082,12 +1088,14 @@ std::vector qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf paddle::DataType F_dtype, paddle::DataType G_dtype, paddle::DataType H_dtype, - paddle::DataType I_dtype) { + paddle::DataType I_dtype, + paddle::DataType J_dtype, + paddle::DataType K_dtype) { return {paddle::DataType::FLOAT32}; } PD_BUILD_OP(qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_dsk_sm90) - .Inputs({"query", "key", "query_pe", "key_pe", "value", "output", "query_scale", "key_scale", "value_scale"}) + .Inputs({"query", "key", 
"query_pe", "key_pe", "q_seq_indices", "k_seq_indices", "value", "output", "query_scale", "key_scale", "value_scale"}) .Outputs({"out", "lse"}) .SetInplaceMap({{"output", "out"}}) // Inplace .Attrs({"tensor_layout: int", diff --git a/paddlenlp/experimental/transformers/fused_transformer_layers.py b/paddlenlp/experimental/transformers/fused_transformer_layers.py index 52badec58ff5..11a47acb9846 100644 --- a/paddlenlp/experimental/transformers/fused_transformer_layers.py +++ b/paddlenlp/experimental/transformers/fused_transformer_layers.py @@ -2983,13 +2983,9 @@ def compute_mla_absorb( if kwargs["max_enc_len_this_time"]: # prefill phase query, key, value = self.compute_qkv_linear(ln_out, i, latent_cache=latent_cache, **kwargs) - seq_len_q_slices = kwargs.get("cu_seqlens_q", None) - seq_len_k_slices = kwargs.get("cu_seqlens_k", None) - bsz = 1 - if seq_len_q_slices is not None: - bsz = seq_len_q_slices.shape[0] - 1 - if use_sageattn and bsz == 1: # batch size == 1 + if use_sageattn: + query_192 = paddle.unsqueeze(query, axis=0) key_192 = paddle.unsqueeze(key, axis=0) @@ -2999,6 +2995,8 @@ def compute_mla_absorb( fmha_out_prefill = sageattn_qk_int8_pv_fp8_cuda_dsk_sm90( query_192, key_192, + kwargs.get("cu_seqlens_q", None), + kwargs.get("cu_seqlens_k", None), value_128, is_causal=True, sm_scale=self.softmax_scale, @@ -3011,8 +3009,8 @@ def compute_mla_absorb( query, key, value, - seq_len_q_slices, - seq_len_k_slices, + kwargs.get("cu_seqlens_q", None), + kwargs.get("cu_seqlens_k", None), kwargs.get("max_enc_len_this_time", -1), kwargs.get("max_enc_len_this_time", -1), self.softmax_scale, diff --git a/paddlenlp/experimental/transformers/sageattention.py b/paddlenlp/experimental/transformers/sageattention.py index 3d63fd86a21a..df77f5c5c7ef 100644 --- a/paddlenlp/experimental/transformers/sageattention.py +++ b/paddlenlp/experimental/transformers/sageattention.py @@ -516,6 +516,8 @@ def sageattn_qk_int8_pv_fp8_cuda_sm90( def sageattn_qk_int8_pv_fp8_cuda_dsk_sm90( q: paddle.Tensor, k: paddle.Tensor, + q_seq_indices: paddle.Tensor, + k_seq_indices: paddle.Tensor, v: paddle.Tensor, tensor_layout: str = "HND", is_causal: bool = False, @@ -607,6 +609,8 @@ def sageattn_qk_int8_pv_fp8_cuda_dsk_sm90( k_int8_nope, q_int8_pe, k_int8_pe, + q_seq_indices, + k_seq_indices, v_fp8, o, q_scale, From 44350a1352c06b301b7915e3da2f42dbff29753d Mon Sep 17 00:00:00 2001 From: l1cacheDell Date: Thu, 6 Mar 2025 12:46:05 +0800 Subject: [PATCH 16/18] fix kernel --- csrc/gpu/sage_attn_kernels/sageattn_fused.cu | 28 +++++++++---------- .../transformers/fused_transformer_layers.py | 5 ++-- paddlenlp/utils/env.py | 1 + 3 files changed, 18 insertions(+), 16 deletions(-) diff --git a/csrc/gpu/sage_attn_kernels/sageattn_fused.cu b/csrc/gpu/sage_attn_kernels/sageattn_fused.cu index 8bdbe9b016d2..788600d81a92 100644 --- a/csrc/gpu/sage_attn_kernels/sageattn_fused.cu +++ b/csrc/gpu/sage_attn_kernels/sageattn_fused.cu @@ -492,8 +492,8 @@ void quant_per_block_int8_fuse_sub_mean_cuda_fwd( PD_BUILD_OP(quant_per_block_int8_fuse_sub_mean_cuda) .Inputs({"input", "mean", "output", "scale"}) - .Outputs({"out1", "out2", "out3", "out4"}) - .SetInplaceMap({{"input", "out1"}, {"mean", "out2"}, {"output", "out3"}, {"scale", "out4"}}) // Inplace + .Outputs({"out_mean", "out", "out_scale"}) + .SetInplaceMap({{"mean", "out_mean"}, {"output", "out"}, {"scale", "out_scale"}}) // Inplace .Attrs({"block_size: int", "tensor_layout: int"}) .SetKernelFn(PD_KERNEL(quant_per_block_int8_fuse_sub_mean_cuda_fwd)); @@ -580,8 +580,8 @@ void 
quant_per_warp_int8_cuda_fwd( PD_BUILD_OP(quant_per_warp_int8_cuda) .Inputs({"input", "output", "scale"}) - .Outputs({"out1", "out2", "out3"}) - .SetInplaceMap({{"input", "out1"}, {"output", "out2"}, {"scale", "out3"}}) // Inplace + .Outputs({"out", "out_scale"}) + .SetInplaceMap({{"output", "out"}, {"scale", "out_scale"}}) // Inplace .Attrs({"block_size: int", "warp_block_size: int", "tensor_layout: int"}) .SetKernelFn(PD_KERNEL(quant_per_warp_int8_cuda_fwd)); @@ -670,8 +670,8 @@ void quant_per_block_int8_cuda_scale_fwd( PD_BUILD_OP(quant_per_block_int8_cuda_scale) .Inputs({"input", "output", "scale"}) - .Outputs({"out1", "out2", "out3"}) - .SetInplaceMap({{"input", "out1"}, {"output", "out2"}, {"scale", "out3"}}) // Inplace + .Outputs({"out", "out_scale"}) + .SetInplaceMap({{"output", "out"}, {"scale", "out_scale"}}) // Inplace .Attrs({"sm_scale: float", "block_size: int", "tensor_layout: int"}) .SetKernelFn(PD_KERNEL(quant_per_block_int8_cuda_scale_fwd)); @@ -759,8 +759,8 @@ void quant_per_block_int8_cuda_fwd( PD_BUILD_OP(quant_per_block_int8_cuda) .Inputs({"input", "output", "scale"}) - .Outputs({"out1", "out2", "out3"}) - .SetInplaceMap({{"input", "out1"}, {"output", "out2"}, {"scale", "out3"}}) // Inplace + .Outputs({"out", "out_scale"}) + .SetInplaceMap({{"output", "out"}, {"scale", "out_scale"}}) // Inplace .Attrs({"sm_scale: float", "block_size: int", "tensor_layout: int"}) .SetKernelFn(PD_KERNEL(quant_per_block_int8_cuda_fwd)); @@ -842,8 +842,8 @@ void transpose_pad_permute_cuda_fwd( PD_BUILD_OP(transpose_pad_permute_cuda) .Inputs({"input", "output"}) - .Outputs({"out1", "out2"}) - .SetInplaceMap({{"input", "out1"}, {"output", "out2"}}) // Inplace + .Outputs({"out"}) + .SetInplaceMap({{"output", "out"}}) // Inplace .Attrs({"tensor_layout: int"}) .SetKernelFn(PD_KERNEL(transpose_pad_permute_cuda_fwd)); @@ -928,8 +928,8 @@ void scale_fuse_quant_cuda_fwd( PD_BUILD_OP(scale_fuse_quant_cuda) .Inputs({"input", "output", "scale", "v"}) - .Outputs({"out1", "out2", "out3", "out4"}) - .SetInplaceMap({{"input", "out1"}, {"output", "out2"}, {"scale", "out3"}, {"v", "out4"}}) // Inplace + .Outputs({"out", "out_scale"}) + .SetInplaceMap({{"output", "out"}, {"scale", "out_scale"}}) // Inplace .Attrs({"scale_max: float", "tensor_layout: int"}) .SetKernelFn(PD_KERNEL(scale_fuse_quant_cuda_fwd)); @@ -1021,7 +1021,7 @@ void mean_scale_fuse_quant_cuda_fwd( PD_BUILD_OP(mean_scale_fuse_quant_cuda) .Inputs({"input", "output", "mean", "scale", "v"}) - .Outputs({"out1", "out2", "out3", "out4", "out5"}) - .SetInplaceMap({{"input", "out1"}, {"output", "out2"}, {"mean", "out3"}, {"scale", "out4"}, {"v", "out5"}}) // Inplace + .Outputs({"out", "out_mean", "out_scale"}) + .SetInplaceMap({{"output", "out"}, {"mean", "out_mean"}, {"scale", "out_scale"}}) // Inplace .Attrs({"scale_max: float", "tensor_layout: int"}) .SetKernelFn(PD_KERNEL(mean_scale_fuse_quant_cuda_fwd)); \ No newline at end of file diff --git a/paddlenlp/experimental/transformers/fused_transformer_layers.py b/paddlenlp/experimental/transformers/fused_transformer_layers.py index 11a47acb9846..e437c79a8c66 100644 --- a/paddlenlp/experimental/transformers/fused_transformer_layers.py +++ b/paddlenlp/experimental/transformers/fused_transformer_layers.py @@ -2975,7 +2975,6 @@ def compute_mla_absorb( ): from paddlenlp_ops import decode_mla_write_cache, multi_head_latent_attention - use_sageattn = False if os.getenv("USE_SAGEATTN", "0") == "0" else True ln_out = qkv_out latent_cache = caches[i] @@ -2984,7 +2983,9 @@ def compute_mla_absorb( if 
kwargs["max_enc_len_this_time"]: # prefill phase query, key, value = self.compute_qkv_linear(ln_out, i, latent_cache=latent_cache, **kwargs) - if use_sageattn: + from paddlenlp.utils.env import PREFILL_USE_SAGE_ATTN + + if PREFILL_USE_SAGE_ATTN: query_192 = paddle.unsqueeze(query, axis=0) key_192 = paddle.unsqueeze(key, axis=0) diff --git a/paddlenlp/utils/env.py b/paddlenlp/utils/env.py index ac7396a48828..be212ec295fd 100644 --- a/paddlenlp/utils/env.py +++ b/paddlenlp/utils/env.py @@ -148,3 +148,4 @@ def _get_bool_env(env_key: str, default_value: str) -> bool: PADDLE_INFERENCE_WEIGHTS_SUFFIX = ".pdiparams" USE_FAST_TOKENIZER: bool = _get_bool_env("USE_FAST_TOKENIZER", "false") +PREFILL_USE_SAGE_ATTN: bool = _get_bool_env("PREFILL_USE_SAGE_ATTN", "false") From 490db5af5ab3ea1a6ac572e725018145227d3e20 Mon Sep 17 00:00:00 2001 From: l1cacheDell Date: Thu, 6 Mar 2025 12:51:02 +0800 Subject: [PATCH 17/18] move to import line --- paddlenlp/experimental/transformers/fused_transformer_layers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddlenlp/experimental/transformers/fused_transformer_layers.py b/paddlenlp/experimental/transformers/fused_transformer_layers.py index e437c79a8c66..ae6cdd95fccb 100644 --- a/paddlenlp/experimental/transformers/fused_transformer_layers.py +++ b/paddlenlp/experimental/transformers/fused_transformer_layers.py @@ -72,7 +72,6 @@ def use_cutlass_fp8_gemm(): write_cache_kv, ) - from .sageattention import sageattn_qk_int8_pv_fp8_cuda_dsk_sm90 except: pass @@ -2986,6 +2985,7 @@ def compute_mla_absorb( from paddlenlp.utils.env import PREFILL_USE_SAGE_ATTN if PREFILL_USE_SAGE_ATTN: + from .sageattention import sageattn_qk_int8_pv_fp8_cuda_dsk_sm90 query_192 = paddle.unsqueeze(query, axis=0) key_192 = paddle.unsqueeze(key, axis=0) From 2004a8aa1adcef7a7584e1ec5cc67d18e9f1385a Mon Sep 17 00:00:00 2001 From: l1cacheDell Date: Thu, 6 Mar 2025 13:07:14 +0800 Subject: [PATCH 18/18] merge develop & support wint8&fp8 --- .../transformers/fused_transformer_layers.py | 96 ++++++++++++++----- 1 file changed, 72 insertions(+), 24 deletions(-) diff --git a/paddlenlp/experimental/transformers/fused_transformer_layers.py b/paddlenlp/experimental/transformers/fused_transformer_layers.py index de33011f00ab..16a6cf678122 100644 --- a/paddlenlp/experimental/transformers/fused_transformer_layers.py +++ b/paddlenlp/experimental/transformers/fused_transformer_layers.py @@ -3327,18 +3327,42 @@ def compute_mla_absorb( if kwargs["max_enc_len_this_time"]: # prefill phase query, key, value = self.compute_qkv_linear(ln_out, i, latent_cache=latent_cache, **kwargs) - fmha_out_prefill = paddle.nn.functional.flash_attention.flash_attn_unpadded( - query, - key, - value, - kwargs.get("cu_seqlens_q", None), - kwargs.get("cu_seqlens_k", None), - kwargs.get("max_enc_len_this_time", -1), - kwargs.get("max_enc_len_this_time", -1), - self.softmax_scale, - causal=True, - training=False, - )[0] + from paddlenlp.utils.env import PREFILL_USE_SAGE_ATTN + + if PREFILL_USE_SAGE_ATTN: + from .sageattention import sageattn_qk_int8_pv_fp8_cuda_dsk_sm90 + + query_192 = paddle.unsqueeze(query, axis=0) + key_192 = paddle.unsqueeze(key, axis=0) + + value_128, _ = paddle.split(value, [128, 64], axis=-1) + value_128 = paddle.unsqueeze(value_128, axis=0) + + fmha_out_prefill = sageattn_qk_int8_pv_fp8_cuda_dsk_sm90( + query_192, + key_192, + kwargs.get("cu_seqlens_q", None), + kwargs.get("cu_seqlens_k", None), + value_128, + is_causal=True, + sm_scale=self.softmax_scale, + tensor_layout="NHD", + ) + 
fmha_out_prefill = paddle.nn.functional.pad(fmha_out_prefill, (0, 192 - 128)) + fmha_out_prefill = paddle.squeeze(fmha_out_prefill, axis=0) + else: + fmha_out_prefill = paddle.nn.functional.flash_attention.flash_attn_unpadded( + query, + key, + value, + kwargs.get("cu_seqlens_q", None), + kwargs.get("cu_seqlens_k", None), + kwargs.get("max_enc_len_this_time", -1), + kwargs.get("max_enc_len_this_time", -1), + self.softmax_scale, + causal=True, + training=False, + )[0] fmha_out_prefill = fmha_out_prefill.reshape([-1, self.num_heads, self.config.mla_config.qk_head_dim]) fmha_out_prefill = fmha_out_prefill[:, :, : self.config.mla_config.v_head_dim] @@ -5022,18 +5046,42 @@ def compute_mla_absorb( if kwargs["max_enc_len_this_time"]: # prefill phase query, key, value = self.compute_qkv_linear(ln_out, i, latent_cache=latent_cache, **kwargs) - fmha_out_prefill = paddle.nn.functional.flash_attention.flash_attn_unpadded( - query, - key, - value, - kwargs.get("cu_seqlens_q", None), - kwargs.get("cu_seqlens_k", None), - kwargs.get("max_enc_len_this_time", -1), - kwargs.get("max_enc_len_this_time", -1), - self.softmax_scale, - causal=True, - training=False, - )[0] + from paddlenlp.utils.env import PREFILL_USE_SAGE_ATTN + + if PREFILL_USE_SAGE_ATTN: + from .sageattention import sageattn_qk_int8_pv_fp8_cuda_dsk_sm90 + + query_192 = paddle.unsqueeze(query, axis=0) + key_192 = paddle.unsqueeze(key, axis=0) + + value_128, _ = paddle.split(value, [128, 64], axis=-1) + value_128 = paddle.unsqueeze(value_128, axis=0) + + fmha_out_prefill = sageattn_qk_int8_pv_fp8_cuda_dsk_sm90( + query_192, + key_192, + kwargs.get("cu_seqlens_q", None), + kwargs.get("cu_seqlens_k", None), + value_128, + is_causal=True, + sm_scale=self.softmax_scale, + tensor_layout="NHD", + ) + fmha_out_prefill = paddle.nn.functional.pad(fmha_out_prefill, (0, 192 - 128)) + fmha_out_prefill = paddle.squeeze(fmha_out_prefill, axis=0) + else: + fmha_out_prefill = paddle.nn.functional.flash_attention.flash_attn_unpadded( + query, + key, + value, + kwargs.get("cu_seqlens_q", None), + kwargs.get("cu_seqlens_k", None), + kwargs.get("max_enc_len_this_time", -1), + kwargs.get("max_enc_len_this_time", -1), + self.softmax_scale, + causal=True, + training=False, + )[0] fmha_out_prefill = fmha_out_prefill.reshape([-1, self.num_heads, self.config.mla_config.qk_head_dim]) fmha_out_prefill = fmha_out_prefill[:, :, : self.config.mla_config.v_head_dim]
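
Usage note (illustration only, not part of the patch): the sketch below walks through the tensor plumbing that the PREFILL_USE_SAGE_ATTN branch above performs around sageattn_qk_int8_pv_fp8_cuda_dsk_sm90, assuming the MLA head dims used in these hunks (qk_head_dim = 192, v_head_dim = 128). The kernel call itself is stubbed with a zero tensor, and token_num / num_heads are arbitrary stand-ins, so this is a shape-level sketch rather than a runnable invocation of the fused layer.

# Enable the SageAttention prefill path (env var registered in paddlenlp/utils/env.py):
#   export PREFILL_USE_SAGE_ATTN=true
import paddle

token_num, num_heads = 16, 8          # arbitrary sizes, illustration only
qk_head_dim, v_head_dim = 192, 128    # MLA dims assumed by the hunks above

query = paddle.randn([token_num, num_heads, qk_head_dim])
key = paddle.randn([token_num, num_heads, qk_head_dim])
value = paddle.randn([token_num, num_heads, qk_head_dim])

# The dsk_sm90 kernel takes NHD tensors with an explicit batch dimension and
# asserts batch_size == 1, so the unpadded [token_num, heads, dim] tensors are
# unsqueezed before the call.
query_192 = paddle.unsqueeze(query, axis=0)
key_192 = paddle.unsqueeze(key, axis=0)

# Only the first v_head_dim (128) channels of V feed the FP8 PV product;
# the remaining 64 channels are dropped before the call.
value_128, _ = paddle.split(value, [v_head_dim, qk_head_dim - v_head_dim], axis=-1)
value_128 = paddle.unsqueeze(value_128, axis=0)

# fmha_out = sageattn_qk_int8_pv_fp8_cuda_dsk_sm90(
#     query_192, key_192, cu_seqlens_q, cu_seqlens_k, value_128,
#     is_causal=True, sm_scale=softmax_scale, tensor_layout="NHD")
fmha_out = paddle.zeros([1, token_num, num_heads, v_head_dim])  # stand-in for the kernel output

# Pad the 128-dim output back to 192 and drop the batch dimension so the
# downstream reshape/slice in compute_mla_absorb keeps working unchanged.
fmha_out = paddle.nn.functional.pad(fmha_out, (0, qk_head_dim - v_head_dim))
fmha_out = paddle.squeeze(fmha_out, axis=0)
print(fmha_out.shape)  # [token_num, num_heads, 192]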