diff --git a/slm/model_zoo/gpt-3/external_ops/fused_quanted_ops/fused_transpose_split_quant.cu b/slm/model_zoo/gpt-3/external_ops/fused_quanted_ops/fused_transpose_split_quant.cu
new file mode 100644
index 000000000000..630cf95c9b4f
--- /dev/null
+++ b/slm/model_zoo/gpt-3/external_ops/fused_quanted_ops/fused_transpose_split_quant.cu
@@ -0,0 +1,301 @@
+#include "quant_utils.h"
+
+template <typename T, int VecSize>
+struct __align__(sizeof(T) * VecSize) VecType {
+  T val[VecSize];
+  __host__ __device__ inline T& operator[](size_t i) { return val[i]; }
+  __host__ __device__ inline const T& operator[](size_t i) const {
+    return val[i];
+  }
+};
+
+// Each block loads a 128x128 tile of X with 32x32 threads: every thread
+// covers 4 rows (stride 32) and 4 columns (vectorized by VecSize along K).
+template <int VecSize>
+__device__ void BlockLoad(const phi::bfloat16* X,
+                          __nv_bfloat16 input[4][4],
+                          size_t K) {
+  for (size_t i = 0; i < 4; i++) {
+    size_t off_m = blockIdx.x * 128 + threadIdx.y + i * 32;
+    size_t off_k = blockIdx.y * 128 + threadIdx.x * VecSize;
+    size_t offset = off_m * K + off_k;
+
+    for (size_t j = 0; j < 4; j += VecSize) {
+      if (off_k + j * 32 < K) {
+        size_t idx = offset + j * 32;
+        using LoadT = VecType<__nv_bfloat16, VecSize>;
+        LoadT data = *reinterpret_cast<const LoadT*>(X + idx);
+        for (int k = 0; k < VecSize; k++) {
+          input[i][j + k] = data[k];
+        }
+      }
+    }
+  }
+}
+
+__device__ void BlockColumnMax(const __nv_bfloat16 input[4][4],
+                               __nv_bfloat16 amax[4],
+                               __nv_bfloat16* shm) {
+  // Reduce [(4), 32, 32, 4] => [32, 32, 4]
+  __nv_bfloat16 warp_max[4];
+  for (int i = 0; i < 4; i++) {
+    for (int j = 0; j < 4; j++) {
+      __nv_bfloat16 t = __habs(input[i][j]);
+      warp_max[j] = i == 0 ? t : __hmax(warp_max[j], t);
+    }
+  }
+
+  // Reduce [(32), 32, 4] => [32, 4]
+  for (int i = 0; i < 4; i++) {
+    shm[threadIdx.y * 128 + i * 32 + threadIdx.x] = warp_max[i];
+  }
+  __syncthreads();
+  for (int offset = 16; offset > 0; offset /= 2) {
+    if (threadIdx.y < offset) {
+      for (int i = 0; i < 4; i++) {
+        shm[threadIdx.y * 128 + i * 32 + threadIdx.x] =
+            __hmax(shm[threadIdx.y * 128 + i * 32 + threadIdx.x],
+                   shm[(threadIdx.y + offset) * 128 + i * 32 + threadIdx.x]);
+      }
+    }
+    __syncthreads();
+  }
+
+  for (int i = 0; i < 4; i++) {
+    amax[i] = shm[i * 32 + threadIdx.x];
+  }
+  __syncthreads();
+}
+
+template <typename OutT, bool Pow2Scales, int VecSize>
+__device__ void BlockStoreScale(float* scale,
+                                size_t off_m,
+                                __nv_bfloat16 amax[4],
+                                float scale_inv[4],
+                                size_t K) {
+  float scale_out[4];
+  for (int i = 0; i < 4; i++) {
+    scale_inv[i] = ComputeScale<__nv_bfloat16, OutT, Pow2Scales>(
+        static_cast<float>(amax[i]), 0.0f);
+    scale_out[i] = __frcp_rn(scale_inv[i]);
+  }
+  if (threadIdx.y == 0) {
+    size_t idx_m = blockIdx.x - off_m / 128;
+    size_t off_k = blockIdx.y * 128 + threadIdx.x * VecSize;
+    size_t offset = idx_m * K + off_k;
+
+    for (size_t j = 0; j < 4; j += VecSize) {
+      if (off_k + j * 32 < K) {
+        size_t idx = offset + j * 32;
+        using StoreT = VecType<float, VecSize>;
+        StoreT data;
+        for (int k = 0; k < VecSize; k++) {
+          data[k] = scale_out[j + k];
+        }
+        *reinterpret_cast<StoreT*>(scale + idx) = data;
+      }
+    }
+  }
+}
+
+template <typename OutT, int VecSize>
+__device__ void BlockStoreOut(OutT* out,
+                              size_t off_m,
+                              size_t cur_tokens,
+                              const OutT shm[128][129],
+                              size_t K) {
+  for (size_t i = 0; i < 4; i++) {
+    size_t idx_m = blockIdx.x * 128 + threadIdx.x * 4;
+    size_t idx_k = blockIdx.y * 128 + threadIdx.y + i * 32;
+    size_t idx = idx_k * cur_tokens + (idx_m - off_m);
+
+    if (idx_k < K) {
+      using StoreT = VecType<OutT, VecSize>;
+      StoreT data;
+      for (int j = 0; j < VecSize; j++) {
+        data[j] = shm[i * 32 + threadIdx.y][threadIdx.x * 4 + j];
+      }
+      *reinterpret_cast<StoreT*>(out + idx) = data;
+    }
+  }
+}
+
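+// Map this block's row offset (blockIdx.x * 128) to the expert that owns it
+// by a linear scan over the per-expert token counts. Zero-token experts span
+// an empty interval and are skipped naturally. One thread performs the scan
+// and the result is broadcast to the block through shared memory.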
+__device__ std::pair<size_t, size_t> GetExpertIdx(int64_t* tokens_per_expert,
+                                                  size_t num_experts) {
+  __shared__ size_t expert_idx_, off_m_;
+
+  if (threadIdx.x == 0 && threadIdx.y == 0) {
+    size_t idx_m = blockIdx.x * 128;
+    size_t off_m = 0, next_off_m = 0;
+    size_t expert_idx;
+    for (expert_idx = 0; expert_idx < num_experts; expert_idx++) {
+      next_off_m += tokens_per_expert[expert_idx];
+      if (idx_m >= off_m && idx_m < next_off_m) {
+        break;
+      }
+      off_m = next_off_m;
+    }
+    expert_idx_ = expert_idx;
+    off_m_ = off_m;
+  }
+  __syncthreads();
+
+  return {expert_idx_, off_m_};
+}
+
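+// The launcher packs all per-expert metadata into one int64 device buffer:
+//   meta = [tokens_per_expert[0..N) | out_ptrs[0..N) | scale_ptrs[0..N)]
+// Pointers are stored as int64 values and reinterpreted below, so a single
+// host-to-device copy covers any number of experts.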
+template <typename OutT, bool Pow2Scales, int VecSize>
+__global__ void __launch_bounds__(1024)
+    FusedTransposeSplitQuantKernel(const phi::bfloat16* __restrict__ X,
+                                   int64_t* __restrict__ meta,
+                                   size_t num_experts,
+                                   size_t K) {
+  __shared__ OutT shm[128][129];  // padded to 129 columns to avoid bank conflicts
+  int64_t* tokens_per_expert = meta;
+  OutT** out_ptrs = reinterpret_cast<OutT**>(meta + num_experts);
+  float** scale_ptrs = reinterpret_cast<float**>(meta + num_experts * 2);
+
+  // Get expert_idx and the offset on the M dim of the current block
+  auto expert_info = GetExpertIdx(tokens_per_expert, num_experts);
+  size_t expert_idx = expert_info.first;
+  size_t off_m = expert_info.second;
+
+  // Load 128x128 elements from X
+  __nv_bfloat16 input[4][4];
+  BlockLoad<VecSize>(X, input, K);
+
+  // Find the maximum of each 128 elements on the M axis
+  __nv_bfloat16 amax[4];
+  BlockColumnMax(input, amax, reinterpret_cast<__nv_bfloat16*>(shm));
+
+  // Compute scale and scale_inv, then store scale back
+  float scale_inv[4];
+  BlockStoreScale<OutT, Pow2Scales, VecSize>(
+      scale_ptrs[expert_idx], off_m, amax, scale_inv, K);
+
+  // Scale X and save into shared memory with transposed layout
+  for (int i = 0; i < 4; i++) {
+    for (int j = 0; j < 4; j += VecSize) {
+      for (int k = 0; k < VecSize; k++) {
+        float input_fp32 = static_cast<float>(input[i][j + k]);
+        float output_scaled = input_fp32 * scale_inv[j + k];
+        shm[threadIdx.x * VecSize + j * 32 + k][i * 32 + threadIdx.y] =
+            static_cast<OutT>(output_scaled);
+      }
+    }
+  }
+  __syncthreads();
+
+  // Store 128x128 elements back
+  // Note: out is always 4x vectorizable.
+  BlockStoreOut<OutT, 4>(
+      out_ptrs[expert_idx], off_m, tokens_per_expert[expert_idx], shm, K);
+}
+
+/**
+ * Quantize on dim[0] of X, transpose dim[0] and dim[1] of X, then
+ * split the result into out and scale.
+ *
+ * Inputs:
+ *   X     : [SUM(M_1...M_N), K], bfloat16
+ *
+ * Outputs:
+ *   out   : {[K, M_1], [K, M_2], ..., [K, M_N]}, float8_e4m3fn
+ *   scale : {[M_1/128, K], [M_2/128, K], ..., [M_N/128, K]}, float
+ *
+ * Attrs:
+ *   pow_2_scales : bool that indicates whether to use power-of-2 scaling
+ *
+ * Requirements:
+ *   1) M_i % 128 == 0 for M_i in [M_1, M_2, ..., M_N]
+ *   2) K <= 65535 * 128
+ */
+void fused_transpose_split_quant(const paddle::Tensor& X,
+                                 std::vector<paddle::Tensor>& outs,
+                                 std::vector<paddle::Tensor>& scales,
+                                 bool pow_2_scales) {
+  // Check X
+  PD_CHECK(X.dtype() == paddle::DataType::BFLOAT16);
+
+  std::vector<int64_t> shape = X.shape();
+  PD_CHECK(shape.size() == 2);
+  const int64_t M = shape[0];
+  const int64_t K = shape[1];
+
+  // Check outs and scales
+  const size_t num_experts = outs.size();
+  PD_CHECK(scales.size() == num_experts);
+
+  std::vector<int64_t> tokens_per_expert;
+  int64_t sum_tokens = 0;
+  for (size_t i = 0; i < num_experts; i++) {
+    PD_CHECK(outs[i].dtype() == paddle::DataType::FLOAT8_E4M3FN);
+    PD_CHECK(scales[i].dtype() == paddle::DataType::FLOAT32);
+
+    std::vector<int64_t> out_shape = outs[i].shape();
+    PD_CHECK(out_shape.size() == 2);
+    PD_CHECK(out_shape[0] == K);
+    PD_CHECK(out_shape[1] % 128 == 0);
+    tokens_per_expert.push_back(out_shape[1]);
+    sum_tokens += out_shape[1];
+
+    std::vector<int64_t> scale_shape = scales[i].shape();
+    PD_CHECK(scale_shape.size() == 2);
+    PD_CHECK(scale_shape[0] == out_shape[1] / 128);
+    PD_CHECK(scale_shape[1] == K);
+  }
+
+  PD_CHECK(sum_tokens == M,
+           "sum of out[i].shape[1] must be equal to X.shape[0]");
+  PD_CHECK(K <= 65535 * 128, "only supports K <= 65535 * 128");
+
+  // Skip 0-size
+  if (M == 0 || K == 0) {
+    return;
+  }
+
+  // Copy meta (tokens_per_expert, out_ptrs, scale_ptrs) to device
+  paddle::Tensor meta_cpu = paddle::empty(
+      {static_cast<int64_t>(num_experts * 3)}, paddle::DataType::INT64);
+  int64_t* meta_ptr = meta_cpu.data<int64_t>();
+  for (size_t i = 0; i < num_experts; i++) {
+    meta_ptr[i] = static_cast<int64_t>(tokens_per_expert[i]);
+  }
+  for (size_t i = 0; i < num_experts; i++) {
+    meta_ptr[num_experts + i] =
+        reinterpret_cast<int64_t>(outs[i].data());
+  }
+  for (size_t i = 0; i < num_experts; i++) {
+    meta_ptr[num_experts * 2 + i] =
+        reinterpret_cast<int64_t>(scales[i].data());
+  }
+  paddle::Tensor meta_gpu = meta_cpu.copy_to(X.place(), /*blocking=*/false);
+
+  // Launch kernel
+  dim3 grid(M / 128, (K + 127) / 128);
+  dim3 block(32, 32);
+
+#define LAUNCH_KERNEL(POW_2_SCALES, VEC_SIZE)                                \
+  FusedTransposeSplitQuantKernel<phi::float8_e4m3fn, POW_2_SCALES, VEC_SIZE> \
+      <<<grid, block, 0, X.stream()>>>(                                     \
+          X.data<phi::bfloat16>(), meta_gpu.data<int64_t>(), num_experts, K);
+#define LAUNCH_KERNEL_PARTIAL(VEC_SIZE) \
+  if (pow_2_scales) {                   \
+    LAUNCH_KERNEL(true, VEC_SIZE);      \
+  } else {                              \
+    LAUNCH_KERNEL(false, VEC_SIZE);     \
+  }
+
+  // Use the widest vector width that K permits
+  if (K % 4 == 0) {
+    LAUNCH_KERNEL_PARTIAL(4);
+  } else if (K % 2 == 0) {
+    LAUNCH_KERNEL_PARTIAL(2);
+  } else {
+    LAUNCH_KERNEL_PARTIAL(1);
+  }
+#undef LAUNCH_KERNEL_PARTIAL
+#undef LAUNCH_KERNEL
+}
+
+// outs and scales are declared as inputs and written in place.
+PD_BUILD_OP(fused_transpose_split_quant)
+    .Inputs({"X", paddle::Vec("outs"), paddle::Vec("scales")})
+    .Attrs({"pow_2_scales: bool"})
+    .SetKernelFn(PD_KERNEL(fused_transpose_split_quant));
diff --git a/slm/model_zoo/gpt-3/external_ops/setup_fp8.py b/slm/model_zoo/gpt-3/external_ops/setup_fp8.py
index 8f243e62e8a5..94af76f78122 100644
--- a/slm/model_zoo/gpt-3/external_ops/setup_fp8.py
+++ b/slm/model_zoo/gpt-3/external_ops/setup_fp8.py
@@ -42,6 +42,7 @@ def setup_fused_quant_ops():
             "fused_quanted_ops/fused_act_dequant_transpose_act_quant.cu",
             "fused_quanted_ops/fused_spaq.cu",
             "fused_quanted_ops/fused_stack_transpose_quant.cu",
+            "fused_quanted_ops/fused_transpose_split_quant.cu",
         ],
         extra_compile_args={
             "cxx": ["-O3", "-w", "-Wno-abi", "-fPIC", "-std=c++17"],
diff --git a/tests/ops/test_fused_transpose_split_quant.py b/tests/ops/test_fused_transpose_split_quant.py
new file mode 100644
index 000000000000..9ad1b6283b52
--- /dev/null
+++ b/tests/ops/test_fused_transpose_split_quant.py
@@ -0,0 +1,52 @@
+import FusedQuantOps as FQO
+import numpy as np
+
+import paddle
+
+
+def restore_transpose_split_quant(out, scale):
+    out = [t.astype('float32') for t in out]
+    out = paddle.concat(out, axis=1).transpose([1, 0])
+    scale = paddle.concat(scale, axis=0)
+    scale = paddle.repeat_interleave(scale, repeats=128, axis=0)
+    return out * scale
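+
+# restore_transpose_split_quant dequantizes and undoes the fused op, giving
+# the reference for the roundtrip check below: the per-expert [K, M_i] fp8
+# outputs are concatenated along axis 1 and transposed back to [M, K], and
+# each scale row is repeated 128 times so it broadcasts over the 128-token
+# block it was computed from.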
"fused_quanted_ops/fused_transpose_split_quant.cu", ], extra_compile_args={ "cxx": ["-O3", "-w", "-Wno-abi", "-fPIC", "-std=c++17"], diff --git a/tests/ops/test_fused_transpose_split_quant.py b/tests/ops/test_fused_transpose_split_quant.py new file mode 100644 index 000000000000..9ad1b6283b52 --- /dev/null +++ b/tests/ops/test_fused_transpose_split_quant.py @@ -0,0 +1,52 @@ +import FusedQuantOps as FQO +import numpy as np + +import paddle + + +def restore_transpose_split_quant(out, scale): + out = [t.astype('float32') for t in out] + out = paddle.concat(out, axis=1).transpose([1, 0]) + scale = paddle.concat(scale, axis=0) + scale = paddle.repeat_interleave(scale, repeats=128, axis=0) + return out * scale + + +def test_fused_transpose_split_quant(tokens_per_expert, seq_len, pow_2_scales): + print(tokens_per_expert, seq_len, pow_2_scales) + + x = paddle.randn([sum(tokens_per_expert), seq_len], dtype='bfloat16') + x = paddle.clip(x, min=-50, max=50) + + out, scale = [], [] + for tokens in tokens_per_expert: + out.append(paddle.empty([seq_len, tokens], dtype='float8_e4m3fn')) + scale.append(paddle.empty([tokens//128, seq_len], dtype='float32')) + + FQO.fused_transpose_split_quant(x, out, scale, pow_2_scales) + + x_restore = restore_transpose_split_quant(out, scale) + x_cast = x.astype('float32') + + np.testing.assert_allclose(x_cast, x_restore, rtol=0.01, atol=0.3) + + +def run(): + test_fused_transpose_split_quant([0, 0], 1024, False) + test_fused_transpose_split_quant([128, 2*128], 0, True) + test_fused_transpose_split_quant([128], 1, False) + test_fused_transpose_split_quant([0, 128, 0, 2*128], 127, True) + test_fused_transpose_split_quant([3*128, 4*128, 5*128], 233, False) + test_fused_transpose_split_quant( + [24*128, 128, 50*128, 16*128], 2162, True + ) + test_fused_transpose_split_quant( + [7*128, 29*128, 3*128, 128*128, 13*128], 4000, False + ) + test_fused_transpose_split_quant( + [18*128, 5*128, 24*128, 128, 6*128, 0, 27*128, 7*128], 7168, True + ) + + +if __name__ == '__main__': + run()