|
| 1 | +#include "quant_utils.h" |
| 2 | + |
| 3 | +template <typename T, int VecSize> |
| 4 | +struct __align__(sizeof(T) * VecSize) VecType { |
| 5 | + T val[VecSize]; |
| 6 | + __host__ __device__ inline T& operator[](size_t i) { return val[i]; } |
| 7 | + __host__ __device__ inline const T& operator[](size_t i) const { |
| 8 | + return val[i]; |
| 9 | + } |
| 10 | +}; |
| 11 | + |
| 12 | +template <int VecSize> |
| 13 | +__device__ void BlockLoad(const phi::bfloat16* X, |
| 14 | + __nv_bfloat16 input[4][4], |
| 15 | + size_t M, |
| 16 | + size_t K) { |
| 17 | + for (size_t i = 0; i < 4; i++) { |
| 18 | + size_t off_n = blockIdx.z; |
| 19 | + size_t off_m = blockIdx.y * 128 + threadIdx.y + i * 32; |
| 20 | + size_t off_k = blockIdx.x * 128 + threadIdx.x * VecSize; |
| 21 | + size_t offset = (off_n * M + off_m) * K + off_k; |
| 22 | + |
| 23 | + for (size_t j = 0; j < 4; j += VecSize) { |
| 24 | + if (off_k + j * 32 < K) { |
| 25 | + size_t idx = offset + j * 32; |
| 26 | + using LoadT = VecType<__nv_bfloat16, VecSize>; |
| 27 | + LoadT data = *reinterpret_cast<const LoadT*>(X + idx); |
| 28 | + for (int k = 0; k < VecSize; k++) { |
| 29 | + input[i][j + k] = data[k]; |
| 30 | + } |
| 31 | + } |
| 32 | + } |
| 33 | + } |
| 34 | +} |
| 35 | + |
| 36 | +__device__ void BlockColumnMax(const __nv_bfloat16 input[4][4], |
| 37 | + __nv_bfloat16 amax[4], |
| 38 | + __nv_bfloat16* shm) { |
| 39 | + // Reduce [(4), 32, 32, 4] => [32, 32, 4] |
| 40 | + __nv_bfloat16 warp_max[4]; |
| 41 | + for (int i = 0; i < 4; i++) { |
| 42 | + for (int j = 0; j < 4; j++) { |
| 43 | + __nv_bfloat16 t = __habs(input[i][j]); |
| 44 | + warp_max[j] = i == 0 ? t : __hmax(warp_max[j], t); |
| 45 | + } |
| 46 | + } |
| 47 | + |
| 48 | + // Reduce [(32), 32, 4] => [32, 4] |
| 49 | + for (int i = 0; i < 4; i++) { |
| 50 | + shm[threadIdx.y * 128 + i * 32 + threadIdx.x] = warp_max[i]; |
| 51 | + } |
| 52 | + __syncthreads(); |
| 53 | + for (int offset = 16; offset > 0; offset /= 2) { |
| 54 | + if (threadIdx.y < offset) { |
| 55 | + for (int i = 0; i < 4; i++) { |
| 56 | + shm[threadIdx.y * 128 + i * 32 + threadIdx.x] = |
| 57 | + __hmax(shm[threadIdx.y * 128 + i * 32 + threadIdx.x], |
| 58 | + shm[(threadIdx.y + offset) * 128 + i * 32 + threadIdx.x]); |
| 59 | + } |
| 60 | + } |
| 61 | + __syncthreads(); |
| 62 | + } |
| 63 | + |
| 64 | + for (int i = 0; i < 4; i++) { |
| 65 | + amax[i] = shm[i * 32 + threadIdx.x]; |
| 66 | + } |
| 67 | +} |
| 68 | + |
| 69 | +template <typename OutT, int VecSize> |
| 70 | +__device__ void BlockStoreScale(float* scale, |
| 71 | + __nv_bfloat16 amax[4], |
| 72 | + float scale_inv[4], |
| 73 | + size_t M, |
| 74 | + size_t K) { |
| 75 | + float scale_out[4]; |
| 76 | + for (int i = 0; i < 4; i++) { |
| 77 | + scale_out[i] = ComputeScale<__nv_bfloat16, OutT>(amax[i], 0.0f); |
| 78 | + scale_inv[i] = __frcp_rn(scale_out[i]); |
| 79 | + } |
| 80 | + if (threadIdx.y == 0) { |
| 81 | + size_t off_n = blockIdx.z; |
| 82 | + size_t off_m = blockIdx.y; |
| 83 | + size_t off_k = blockIdx.x * 128 + threadIdx.x * VecSize; |
| 84 | + size_t offset = (off_n * (M / 128) + off_m) * K + off_k; |
| 85 | + |
| 86 | + for (size_t j = 0; j < 4; j += VecSize) { |
| 87 | + if (off_k + j * 32 < K) { |
| 88 | + size_t idx = offset + j * 32; |
| 89 | + using StoreT = VecType<float, VecSize>; |
| 90 | + StoreT data; |
| 91 | + for (int k = 0; k < VecSize; k++) { |
| 92 | + data[k] = scale_out[j + k]; |
| 93 | + } |
| 94 | + *reinterpret_cast<StoreT*>(scale + idx) = data; |
| 95 | + } |
| 96 | + } |
| 97 | + } |
| 98 | +} |
| 99 | + |
| 100 | +template <typename OutT, int VecSize> |
| 101 | +__device__ void BlockStoreOut(OutT* out, |
| 102 | + const OutT shm[128][129], |
| 103 | + size_t M, |
| 104 | + size_t K) { |
| 105 | + for (size_t i = 0; i < 4; i++) { |
| 106 | + size_t idx_n = blockIdx.z; |
| 107 | + size_t idx_k = blockIdx.x * 128 + threadIdx.y + i * 32; |
| 108 | + size_t idx_m = blockIdx.y * 128 + threadIdx.x * 4; |
| 109 | + size_t idx = (idx_n * K + idx_k) * M + idx_m; |
| 110 | + |
| 111 | + if (idx_k < K) { |
| 112 | + using StoreT = VecType<OutT, VecSize>; |
| 113 | + StoreT data; |
| 114 | + for (int j = 0; j < VecSize; j++) { |
| 115 | + data[j] = shm[i * 32 + threadIdx.y][threadIdx.x * 4 + j]; |
| 116 | + } |
| 117 | + *reinterpret_cast<StoreT*>(out + idx) = data; |
| 118 | + } |
| 119 | + } |
| 120 | +} |
| 121 | + |
| 122 | +template <typename OutT, int VecSize> |
| 123 | +__global__ void __launch_bounds__(1024, 2) |
| 124 | + FusedTransposeQuantKernel(const phi::bfloat16* __restrict__ X, |
| 125 | + OutT* __restrict__ out, |
| 126 | + float* __restrict__ scale, |
| 127 | + size_t M, |
| 128 | + size_t K) { |
| 129 | + __shared__ OutT shm[128][129]; |
| 130 | + |
| 131 | + // Load 128x128 elements from X |
| 132 | + __nv_bfloat16 input[4][4]; |
| 133 | + BlockLoad<VecSize>(X, input, M, K); |
| 134 | + |
| 135 | + // Find the maximum of each 128 elements on the M axis |
| 136 | + __nv_bfloat16 amax[4]; |
| 137 | + BlockColumnMax(input, amax, reinterpret_cast<__nv_bfloat16*>(shm)); |
| 138 | + |
| 139 | + // Compute scale and scale_inv, save scale to output |
| 140 | + float scale_inv[4]; |
| 141 | + BlockStoreScale<OutT, VecSize>(scale, amax, scale_inv, M, K); |
| 142 | + |
| 143 | + // Scale X and save into shared memory with transposed layout |
| 144 | + for (int i = 0; i < 4; i++) { |
| 145 | + for (int j = 0; j < 4; j += VecSize) { |
| 146 | + for (int k = 0; k < VecSize; k++) { |
| 147 | + float input_fp32 = static_cast<float>(input[i][j + k]); |
| 148 | + float output_scaled = input_fp32 * scale_inv[j + k]; |
| 149 | + shm[threadIdx.x * VecSize + j * 32 + k][i * 32 + threadIdx.y] = |
| 150 | + static_cast<OutT>(output_scaled); |
| 151 | + } |
| 152 | + } |
| 153 | + } |
| 154 | + __syncthreads(); |
| 155 | + |
| 156 | + // Store 128x128 elements back |
| 157 | + // Note: out is always 4x vectorizable. |
| 158 | + BlockStoreOut<OutT, 4>(out, shm, M, K); |
| 159 | +} |
| 160 | + |
| 161 | +/** |
| 162 | + * Doing quantization on dim[-2] of X, then transpose dim[-1] and dim[-2] of X. |
| 163 | + * |
| 164 | + * Inputs: |
| 165 | + * X : [*, M, K], bfloat16 |
| 166 | + * |
| 167 | + * Outputs: |
| 168 | + * out : [*, K, M], float8_e4m3fn |
| 169 | + * scale: [*, M/128, K], float32 |
| 170 | + * |
| 171 | + * Requirements: |
| 172 | + * 1) batch_size <= 65535 |
| 173 | + * 2) M <= 65535 * 128 and M % 128 == 0 |
| 174 | + */ |
| 175 | +std::vector<paddle::Tensor> fused_transpose_quant(const paddle::Tensor& X) { |
| 176 | + PD_CHECK(X.dtype() == paddle::DataType::BFLOAT16); |
| 177 | + |
| 178 | + std::vector<int64_t> shape = X.shape(); |
| 179 | + PD_CHECK(shape.size() >= 2); |
| 180 | + |
| 181 | + int64_t M = shape[shape.size() - 2]; |
| 182 | + int64_t K = shape[shape.size() - 1]; |
| 183 | + int64_t N = X.numel() / (M * K); |
| 184 | + |
| 185 | + PADDLE_ENFORCE_LE( |
| 186 | + N, |
| 187 | + 65535, |
| 188 | + common::errors::InvalidArgument("The batch size (X.shape[0:-2] in total) " |
| 189 | + "must be no larger than 65535.")); |
| 190 | + PADDLE_ENFORCE_LE(M, |
| 191 | + 65535 * 128, |
| 192 | + common::errors::InvalidArgument( |
| 193 | + "X.shape[-2] must be no larger than 65535 * 128.")); |
| 194 | + PADDLE_ENFORCE_EQ( |
| 195 | + M % 128, |
| 196 | + 0, |
| 197 | + common::errors::InvalidArgument("X.shape[-2] must be multiple of 128.")); |
| 198 | + |
| 199 | + // Allocate for out and scale |
| 200 | + std::vector<int64_t> out_shape = shape; |
| 201 | + out_shape[shape.size() - 2] = K; |
| 202 | + out_shape[shape.size() - 1] = M; |
| 203 | + paddle::Tensor out = |
| 204 | + paddle::empty(out_shape, paddle::DataType::FLOAT8_E4M3FN, X.place()); |
| 205 | + |
| 206 | + std::vector<int64_t> scale_shape = shape; |
| 207 | + scale_shape[shape.size() - 2] = M / 128; |
| 208 | + paddle::Tensor scale = |
| 209 | + paddle::empty(scale_shape, paddle::DataType::FLOAT32, X.place()); |
| 210 | + |
| 211 | + // Skip 0-size |
| 212 | + if (N == 0 || M == 0 || K == 0) { |
| 213 | + return {out, scale}; |
| 214 | + } |
| 215 | + |
| 216 | + // Launch kernel |
| 217 | + dim3 grid((K + 127) / 128, M / 128, N); |
| 218 | + dim3 block(32, 32); |
| 219 | + |
| 220 | +#define LAUNCH_KERNEL(VEC_SIZE) \ |
| 221 | + FusedTransposeQuantKernel<phi::float8_e4m3fn, VEC_SIZE> \ |
| 222 | + <<<grid, block>>>(X.data<phi::bfloat16>(), \ |
| 223 | + out.data<phi::float8_e4m3fn>(), \ |
| 224 | + scale.data<float>(), \ |
| 225 | + M, \ |
| 226 | + K); |
| 227 | + if (K % 4 == 0) { |
| 228 | + LAUNCH_KERNEL(4); |
| 229 | + } else if (K % 2 == 0) { |
| 230 | + LAUNCH_KERNEL(2); |
| 231 | + } else { |
| 232 | + LAUNCH_KERNEL(1); |
| 233 | + } |
| 234 | +#undef LAUNCH_KERNEL |
| 235 | + |
| 236 | + return {out, scale}; |
| 237 | +} |
| 238 | + |
| 239 | +PD_BUILD_OP(fused_transpose_quant) |
| 240 | + .Inputs({"X"}) |
| 241 | + .Outputs({"output", "scale"}) |
| 242 | + .SetKernelFn(PD_KERNEL(fused_transpose_quant)); |
0 commit comments