|
| 1 | +#include "quant_utils.h" |
| 2 | + |
| 3 | +template <typename T, int VecSize> |
| 4 | +struct __align__(sizeof(T) * VecSize) VecType { |
| 5 | + T val[VecSize]; |
| 6 | + __host__ __device__ inline T& operator[](size_t i) { return val[i]; } |
| 7 | + __host__ __device__ inline const T& operator[](size_t i) const { |
| 8 | + return val[i]; |
| 9 | + } |
| 10 | +}; |
| 11 | + |
| 12 | +struct FastDiv { |
| 13 | + FastDiv() {} |
| 14 | + FastDiv(uint64_t d) { |
| 15 | + for (shift_val = 0; shift_val < 64; ++shift_val) { |
| 16 | + uint64_t shift_limit = uint64_t(1) << shift_val; |
| 17 | + if (shift_limit >= d) break; |
| 18 | + } |
| 19 | + |
| 20 | + // quotient = ((uint128_t)n_hi << 64) / d |
| 21 | + uint64_t quotient = 0; |
| 22 | + uint64_t n_hi = (uint64_t(1) << shift_val) - d, n_lo = 0; |
| 23 | + for (int i = 63; i >= 0; --i) { |
| 24 | + uint64_t d_hi = i == 0 ? 0 : d >> (64 - i); |
| 25 | + uint64_t d_lo = d << i; |
| 26 | + if (n_hi == 0 && n_lo == 0) break; |
| 27 | + if ((d_hi < n_hi) || (d_hi <= n_hi && d_lo <= n_lo)) { |
| 28 | + quotient |= uint64_t(1) << i; |
| 29 | + n_hi -= d_hi + (d_lo > n_lo); |
| 30 | + n_lo -= d_lo; |
| 31 | + } |
| 32 | + } |
| 33 | + multiplier = quotient + 1; |
| 34 | + } |
| 35 | + |
| 36 | + __device__ uint64_t Div(uint64_t n) const { |
| 37 | + uint64_t t = __umul64hi(n, multiplier); |
| 38 | + return (t + n) >> shift_val; |
| 39 | + } |
| 40 | + |
| 41 | + int shift_val; |
| 42 | + uint64_t multiplier; |
| 43 | +}; |
| 44 | + |
| 45 | +__device__ void BlockLoad(const int64_t* __restrict__ X_ptrs, |
| 46 | + __nv_bfloat16 input[4][4], |
| 47 | + size_t K, |
| 48 | + size_t block_y, |
| 49 | + size_t block_x) { |
| 50 | + const __nv_bfloat16* X = |
| 51 | + reinterpret_cast<const __nv_bfloat16*>(X_ptrs[blockIdx.z]); |
| 52 | + |
| 53 | + for (size_t i = 0; i < 4; i++) { |
| 54 | + size_t idx_m = block_y * 128 + threadIdx.y + i * 32; |
| 55 | + size_t idx_k = block_x * 128 + threadIdx.x * 4; |
| 56 | + size_t idx = idx_m * K + idx_k; |
| 57 | + |
| 58 | + using LoadT = VecType<__nv_bfloat16, 4>; |
| 59 | + LoadT data = *reinterpret_cast<const LoadT*>(X + idx); |
| 60 | + for (int j = 0; j < 4; j++) { |
| 61 | + input[i][j] = data[j]; |
| 62 | + } |
| 63 | + } |
| 64 | +} |
| 65 | + |
| 66 | +__device__ __nv_bfloat16 WarpReduceMax(__nv_bfloat16 x) { |
| 67 | + for (int offset = 16; offset > 0; offset /= 2) { |
| 68 | + __nv_bfloat16 t = __shfl_down_sync(0xffffffff, x, offset); |
| 69 | + x = __hmax(x, t); |
| 70 | + } |
| 71 | + return x; |
| 72 | +} |
| 73 | + |
| 74 | +__device__ __nv_bfloat16 BlockReduceMax(__nv_bfloat16 input[4][4]) { |
| 75 | + // [(4), 32, 32, (4)] => [32, 32] |
| 76 | + __nv_bfloat16 local_max; |
| 77 | + for (int i = 0; i < 4; i++) { |
| 78 | + for (int j = 0; j < 4; j++) { |
| 79 | + __nv_bfloat16 t = __habs(input[i][j]); |
| 80 | + local_max = (i == 0 && j == 0) ? t : __hmax(local_max, t); |
| 81 | + } |
| 82 | + } |
| 83 | + |
| 84 | + // [32, (32)] => [32] |
| 85 | + __nv_bfloat16 warp_max = WarpReduceMax(local_max); |
| 86 | + |
| 87 | + // [(32)] => [1] |
| 88 | + __shared__ __nv_bfloat16 block_max[32]; |
| 89 | + if (threadIdx.x == 0) { |
| 90 | + block_max[threadIdx.y] = warp_max; |
| 91 | + } |
| 92 | + __syncthreads(); |
| 93 | + if (threadIdx.y == 0) { |
| 94 | + warp_max = WarpReduceMax(block_max[threadIdx.x]); |
| 95 | + if (threadIdx.x == 0) { |
| 96 | + block_max[0] = warp_max; |
| 97 | + } |
| 98 | + } |
| 99 | + __syncthreads(); |
| 100 | + |
| 101 | + return block_max[0]; |
| 102 | +} |
| 103 | + |
| 104 | +template <typename OutT> |
| 105 | +__global__ void __launch_bounds__(1024) |
| 106 | + FusedStackQuantKernel(const int64_t* __restrict__ X_ptrs, |
| 107 | + OutT* __restrict__ out, |
| 108 | + float* __restrict__ scale, |
| 109 | + size_t M, |
| 110 | + size_t K, |
| 111 | + FastDiv K_div_128) { |
| 112 | + size_t block_y = K_div_128.Div(blockIdx.x); |
| 113 | + size_t block_x = blockIdx.x - block_y * (K / 128); |
| 114 | + |
| 115 | + // Load 128x128 elements from X |
| 116 | + __nv_bfloat16 input[4][4]; |
| 117 | + BlockLoad(X_ptrs, input, K, block_y, block_x); |
| 118 | + |
| 119 | + // Find the maximum in all elements |
| 120 | + __nv_bfloat16 amax = BlockReduceMax(input); |
| 121 | + |
| 122 | + // Compute scale and store back |
| 123 | + float scale_inv = ComputeScale<__nv_bfloat16, OutT>(amax, 0.0f); |
| 124 | + float scale_out = __frcp_rn(scale_inv); |
| 125 | + if (threadIdx.x == 0 && threadIdx.y == 0) { |
| 126 | + size_t idx_n = blockIdx.z; |
| 127 | + size_t idx_m = block_y; |
| 128 | + size_t idx_k = block_x; |
| 129 | + size_t idx = (idx_n * (M / 128) + idx_m) * (K / 128) + idx_k; |
| 130 | + scale[idx] = scale_out; |
| 131 | + } |
| 132 | + |
| 133 | + // Scale X and store to out |
| 134 | + for (size_t i = 0; i < 4; i++) { |
| 135 | + size_t idx_n = blockIdx.z; |
| 136 | + size_t idx_m = block_y * 128 + threadIdx.y + i * 32; |
| 137 | + size_t idx_k = block_x * 128 + threadIdx.x * 4; |
| 138 | + size_t idx = (idx_n * M + idx_m) * K + idx_k; |
| 139 | + |
| 140 | + using StoreT = VecType<OutT, 4>; |
| 141 | + StoreT data; |
| 142 | + for (int j = 0; j < 4; j++) { |
| 143 | + float input_fp32 = static_cast<float>(input[i][j]); |
| 144 | + float output_scaled = input_fp32 * scale_inv; |
| 145 | + data[j] = static_cast<OutT>(output_scaled); |
| 146 | + } |
| 147 | + *reinterpret_cast<StoreT*>(out + idx) = data; |
| 148 | + } |
| 149 | +} |
| 150 | + |
| 151 | +template <typename OutT> |
| 152 | +__global__ void __launch_bounds__(1024) |
| 153 | + FusedStackTransposeQuantKernel(const int64_t* __restrict__ X_ptrs, |
| 154 | + OutT* __restrict__ out, |
| 155 | + float* __restrict__ scale, |
| 156 | + size_t M, |
| 157 | + size_t K, |
| 158 | + FastDiv K_div_128) { |
| 159 | + size_t block_y = K_div_128.Div(blockIdx.x); |
| 160 | + size_t block_x = blockIdx.x - block_y * (K / 128); |
| 161 | + |
| 162 | + // Load 128x128 elements from X |
| 163 | + __nv_bfloat16 input[4][4]; |
| 164 | + BlockLoad(X_ptrs, input, K, block_y, block_x); |
| 165 | + |
| 166 | + // Find the maximum in all elements |
| 167 | + __nv_bfloat16 amax = BlockReduceMax(input); |
| 168 | + |
| 169 | + // Compute scale and store back |
| 170 | + float scale_inv = ComputeScale<__nv_bfloat16, OutT>(amax, 0.0f); |
| 171 | + float scale_out = __frcp_rn(scale_inv); |
| 172 | + if (threadIdx.x == 0 && threadIdx.y == 0) { |
| 173 | + size_t idx_n = blockIdx.z; |
| 174 | + size_t idx_k = block_x; |
| 175 | + size_t idx_m = block_y; |
| 176 | + size_t idx = (idx_n * (K / 128) + idx_k) * (M / 128) + idx_m; |
| 177 | + scale[idx] = scale_out; |
| 178 | + } |
| 179 | + |
| 180 | + // Scale X and transpose in shared memory |
| 181 | + __shared__ OutT shm[128][129]; |
| 182 | + for (int i = 0; i < 4; i++) { |
| 183 | + for (int j = 0; j < 4; j++) { |
| 184 | + float input_fp32 = static_cast<float>(input[i][j]); |
| 185 | + float output_scaled = input_fp32 * scale_inv; |
| 186 | + shm[threadIdx.x * 4 + j][i * 32 + threadIdx.y] = |
| 187 | + static_cast<OutT>(output_scaled); |
| 188 | + } |
| 189 | + } |
| 190 | + __syncthreads(); |
| 191 | + |
| 192 | + // Store X back to out |
| 193 | + for (size_t i = 0; i < 4; i++) { |
| 194 | + size_t idx_n = blockIdx.z; |
| 195 | + size_t idx_k = block_x * 128 + threadIdx.y + i * 32; |
| 196 | + size_t idx_m = block_y * 128 + threadIdx.x * 4; |
| 197 | + size_t idx = (idx_n * K + idx_k) * M + idx_m; |
| 198 | + |
| 199 | + using StoreT = VecType<OutT, 4>; |
| 200 | + StoreT data; |
| 201 | + for (int j = 0; j < 4; j++) { |
| 202 | + data[j] = shm[i * 32 + threadIdx.y][threadIdx.x * 4 + j]; |
| 203 | + } |
| 204 | + *reinterpret_cast<StoreT*>(out + idx) = data; |
| 205 | + } |
| 206 | +} |
| 207 | + |
| 208 | +/** |
| 209 | + * Stack tensors in X, optionally transpose dim[-1] and dim[-2], and do |
| 210 | + * quantization on both dim[-1] and dim[-2]. |
| 211 | + * |
| 212 | + * Inputs: |
| 213 | + * X : N tensors of [M, K], bfloat16 |
| 214 | + * |
| 215 | + * Outputs: |
| 216 | + * if Transpose: |
| 217 | + * out : [N * K, M], float8_e4m3fn |
| 218 | + * scale: [N * K / 128, M / 128], float |
| 219 | + * else: |
| 220 | + * out : [N * M, K], float8_e4m3fn |
| 221 | + * scale: [N * M / 128, K / 128], float |
| 222 | + * |
| 223 | + * Requirements: |
| 224 | + * 1) N <= 65535 |
| 225 | + * 2) M % 128 == 0 |
| 226 | + * 3) K % 128 == 0 |
| 227 | + */ |
| 228 | +template <bool Transpose> |
| 229 | +std::vector<paddle::Tensor> fused_stack_transpose_quant( |
| 230 | + const std::vector<paddle::Tensor>& X) { |
| 231 | + int64_t N = X.size(); |
| 232 | + PD_CHECK(N > 0); |
| 233 | + for (int64_t i = 0; i < N; i++) { |
| 234 | + PD_CHECK(X[i].dtype() == paddle::DataType::BFLOAT16); |
| 235 | + } |
| 236 | + |
| 237 | + std::vector<int64_t> shape = X[0].shape(); |
| 238 | + PD_CHECK(shape.size() == 2); |
| 239 | + int64_t M = shape[0]; |
| 240 | + int64_t K = shape[1]; |
| 241 | + |
| 242 | + for (int64_t i = 1; i < N; i++) { |
| 243 | + std::vector<int64_t> shape = X[i].shape(); |
| 244 | + PD_CHECK(shape.size() == 2); |
| 245 | + PD_CHECK(shape[0] == M); |
| 246 | + PD_CHECK(shape[1] == K); |
| 247 | + } |
| 248 | + |
| 249 | + PADDLE_ENFORCE_LE(N, |
| 250 | + 65535, |
| 251 | + common::errors::InvalidArgument( |
| 252 | + "The batch size (N) must be no larger than 65535.")); |
| 253 | + PADDLE_ENFORCE_EQ(M % 128, |
| 254 | + 0, |
| 255 | + common::errors::InvalidArgument( |
| 256 | + "The upper dim (M) must be multiple of 128.")); |
| 257 | + PADDLE_ENFORCE_EQ(K % 128, |
| 258 | + 0, |
| 259 | + common::errors::InvalidArgument( |
| 260 | + "The lower dim (K) must be multiple of 128.")); |
| 261 | + |
| 262 | + // Allocate for out and scale |
| 263 | + std::vector<int64_t> out_shape, scale_shape; |
| 264 | + if (Transpose) { |
| 265 | + out_shape = {N * K, M}; |
| 266 | + scale_shape = {N * K / 128, M / 128}; |
| 267 | + } else { |
| 268 | + out_shape = {N * M, K}; |
| 269 | + scale_shape = {N * M / 128, K / 128}; |
| 270 | + } |
| 271 | + |
| 272 | + const auto& place = X[0].place(); |
| 273 | + paddle::Tensor out = |
| 274 | + paddle::empty(out_shape, paddle::DataType::FLOAT8_E4M3FN, place); |
| 275 | + paddle::Tensor scale = |
| 276 | + paddle::empty(scale_shape, paddle::DataType::FLOAT32, place); |
| 277 | + |
| 278 | + // Skip 0-size |
| 279 | + if (M == 0 || K == 0) { |
| 280 | + return {out, scale}; |
| 281 | + } |
| 282 | + |
| 283 | + // Copy the pointers in X to device |
| 284 | + paddle::Tensor X_ptrs_cpu = paddle::empty({N}, paddle::DataType::INT64); |
| 285 | + int64_t* X_ptrs_data = X_ptrs_cpu.data<int64_t>(); |
| 286 | + for (int64_t i = 0; i < N; i++) { |
| 287 | + X_ptrs_data[i] = reinterpret_cast<int64_t>(X[i].data()); |
| 288 | + } |
| 289 | + paddle::Tensor X_ptrs_gpu = X_ptrs_cpu.copy_to(place, /* blocking= */ false); |
| 290 | + |
| 291 | + // Launch kernel |
| 292 | + dim3 grid((M / 128) * (K / 128), 1, N); |
| 293 | + dim3 block(32, 32); |
| 294 | + |
| 295 | +#define LAUNCH_KERNEL(KERNEL) \ |
| 296 | + KERNEL<<<grid, block>>>(X_ptrs_gpu.data<int64_t>(), \ |
| 297 | + out.data<phi::float8_e4m3fn>(), \ |
| 298 | + scale.data<float>(), \ |
| 299 | + M, \ |
| 300 | + K, \ |
| 301 | + FastDiv(K / 128)) |
| 302 | + if (Transpose) { |
| 303 | + LAUNCH_KERNEL(FusedStackTransposeQuantKernel); |
| 304 | + } else { |
| 305 | + LAUNCH_KERNEL(FusedStackQuantKernel); |
| 306 | + } |
| 307 | +#undef LAUNCH_KERNEL |
| 308 | + |
| 309 | + return {out, scale}; |
| 310 | +} |
| 311 | + |
| 312 | +PD_BUILD_OP(fused_stack_quant) |
| 313 | + .Inputs({paddle::Vec("X")}) |
| 314 | + .Outputs({"output", "scale"}) |
| 315 | + .SetKernelFn(PD_KERNEL(fused_stack_transpose_quant<false>)); |
| 316 | + |
| 317 | +PD_BUILD_OP(fused_stack_transpose_quant) |
| 318 | + .Inputs({paddle::Vec("X")}) |
| 319 | + .Outputs({"output", "scale"}) |
| 320 | + .SetKernelFn(PD_KERNEL(fused_stack_transpose_quant<true>)); |
0 commit comments