#include "quant_utils.h"

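// fused_spaq: fuses the SwiGLU activation over the two halves of the last
// dimension, an optional per-row probability multiply, and 1x128 blockwise
// quantization to FP8 E4M3, producing the quantized tensor plus the per-block
// dequantization scales. The macros below instantiate and launch the scalar
// and vectorized kernels from the dispatcher.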
#define LAUNCH_FUSED_SPAQ(__using_pow2_scaling, __with_prob)            \
  do {                                                                  \
    auto kernel = FusedSPAQKernel<__using_pow2_scaling, __with_prob>;   \
    kernel<<<grid, block, 0, X.stream()>>>(                             \
        X.data<phi::bfloat16>(),                                        \
        prob ? prob->data<float>() : nullptr,                           \
        out.data<phi::float8_e4m3fn>(),                                 \
        scale.data<float>(),                                            \
        rows,                                                           \
        cols);                                                          \
  } while (0)

#define LAUNCH_FUSED_SPAQ_VEC4(__using_pow2_scaling, __with_prob) \
  do {                                                            \
    auto kernel = FusedSPAQKernelVec4<__using_pow2_scaling,       \
                                      __with_prob,                \
                                      thread_per_block>;          \
    kernel<<<grid, block, 0, X.stream()>>>(                       \
        X.data<phi::bfloat16>(),                                  \
        prob ? prob->data<float>() : nullptr,                     \
        out.data<phi::float8_e4m3fn>(),                           \
        scale.data<float>(),                                      \
        rows,                                                     \
        cols,                                                     \
        scale_cols);                                              \
  } while (0)

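// 8-byte / 4-byte aligned value types so the vectorized kernel can load four
// bf16 inputs and store four FP8 outputs with a single memory transaction each.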
typedef struct __align__(8) {
  __nv_bfloat16 x;
  __nv_bfloat16 y;
  __nv_bfloat16 z;
  __nv_bfloat16 w;
} bfloat16x4_t;

typedef struct __align__(4) {
  __nv_fp8_e4m3 x;
  __nv_fp8_e4m3 y;
  __nv_fp8_e4m3 z;
  __nv_fp8_e4m3 w;
} fp8_e4m3x4_t;

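// SwiGLU on an (x, y) pair: silu(x) * y, with silu(x) = x * sigmoid(x)
// = x / (1 + exp(-x)); computed in fp32 using fast-math intrinsics.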
__device__ __forceinline__ float fast_swiglu(const __nv_bfloat16 x,
                                             const __nv_bfloat16 y) {
  const float x_f = __bfloat162float(x);
  const float y_f = __bfloat162float(y);
  const float silu = x_f * __frcp_rn(1.0f + __expf(-x_f));
  const float result = silu * y_f;
  return result;
}

__device__ __forceinline__ float4 fast_swiglu_vec4(const bfloat16x4_t &lhs,
                                                   const bfloat16x4_t &rhs) {
  const float x_f_x = __bfloat162float(lhs.x);
  const float x_f_y = __bfloat162float(lhs.y);
  const float x_f_z = __bfloat162float(lhs.z);
  const float x_f_w = __bfloat162float(lhs.w);

  const float y_f_x = __bfloat162float(rhs.x);
  const float y_f_y = __bfloat162float(rhs.y);
  const float y_f_z = __bfloat162float(rhs.z);
  const float y_f_w = __bfloat162float(rhs.w);

  const float silu_x = x_f_x * __frcp_rn(1.0f + __expf(-x_f_x));
  const float silu_y = x_f_y * __frcp_rn(1.0f + __expf(-x_f_y));
  const float silu_z = x_f_z * __frcp_rn(1.0f + __expf(-x_f_z));
  const float silu_w = x_f_w * __frcp_rn(1.0f + __expf(-x_f_w));

  return {silu_x * y_f_x, silu_y * y_f_y, silu_z * y_f_z, silu_w * y_f_w};
}

__device__ __forceinline__ float amax_float4(const float4 &vec) {
  return fmaxf(fmaxf(fabsf(vec.x), fabsf(vec.y)),
               fmaxf(fabsf(vec.z), fabsf(vec.w)));
}

__device__ __forceinline__ fp8_e4m3x4_t
scale_fp32x4_to_fp8x4(const float4 &vec, const float scale) {
  return {static_cast<__nv_fp8_e4m3>(vec.x * scale),
          static_cast<__nv_fp8_e4m3>(vec.y * scale),
          static_cast<__nv_fp8_e4m3>(vec.z * scale),
          static_cast<__nv_fp8_e4m3>(vec.w * scale)};
}

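// Vectorized kernel: each thread handles 4 contiguous bf16 elements, so one
// 32-thread warp covers exactly one 128-element quant block and the block
// amax can be reduced entirely with warp shuffles (no shared memory needed).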
template <bool using_pow2_scaling, bool with_prob, int thread_per_block>
__global__ void FusedSPAQKernelVec4(const phi::bfloat16 *__restrict__ Xin,
                                    const float *__restrict__ prob,
                                    phi::float8_e4m3fn *__restrict__ out,
                                    float *__restrict__ scales,
                                    const int rows,
                                    const int cols,
                                    const int scale_cols) {
  constexpr int elements_per_thread = 4;
  constexpr int warp_size = 32;
  constexpr int warp_num = thread_per_block / warp_size;
  const int scale_stride = scale_cols;
  const int lane = threadIdx.x % warp_size;
  const int x_offset = threadIdx.x * elements_per_thread;
  const int in_y_idx = blockIdx.y;
  const int in_x_idx = blockIdx.x * blockDim.x * elements_per_thread + x_offset;
  const int src_idx = in_y_idx * cols + in_x_idx;
  const unsigned int mask = 0xffffffff;  // whole warp mask
  float p_t0 = 0.0f;
  if (in_x_idx >= cols / 2 || in_y_idx >= rows) [[unlikely]]
    return;

  if constexpr (with_prob) {
    // Prefetch prob
    if (lane == 0) p_t0 = prob[in_y_idx];
  }

  const __nv_bfloat16 *X = reinterpret_cast<const __nv_bfloat16 *>(Xin);

  // Initialize activation storage
  float4 act_f32x4;
  bfloat16x4_t lhs_bf16x4, rhs_bf16x4;

  // Reinterpret input pointer as bfloat16x4_t* for vectorized loading
  const bfloat16x4_t *X_lhs_vec =
      reinterpret_cast<const bfloat16x4_t *>(X + src_idx);
  const bfloat16x4_t *X_rhs_vec =
      reinterpret_cast<const bfloat16x4_t *>(X + src_idx + cols / 2);

  lhs_bf16x4 = *X_lhs_vec;
  rhs_bf16x4 = *X_rhs_vec;

  act_f32x4 = fast_swiglu_vec4(lhs_bf16x4, rhs_bf16x4);

  if constexpr (with_prob) {
    // Warp-level broadcast to avoid a __syncthreads()
    const float p = __shfl_sync(mask, p_t0, 0);
    act_f32x4.x *= p;
    act_f32x4.y *= p;
    act_f32x4.z *= p;
    act_f32x4.w *= p;
  }

  // Phase 2: Block Reduction to find per-quant block absolute maxima
  // Compute absolute values
  float thread_amax = amax_float4(act_f32x4);

  // All-Reduce within the warp
#pragma unroll
  for (int offset = 16; offset > 0; offset /= 2) {
    const float val = __shfl_down_sync(mask, thread_amax, offset);
    thread_amax = fmaxf(thread_amax, val);
  }
  const float final_amax = __shfl_sync(mask, thread_amax, 0);

  // Phase 3: Compute scales and quantize the outputs
  const float scale =
      ComputeScale<float, __nv_fp8_e4m3, using_pow2_scaling>(final_amax, 0.0f);
  const float inv_scale = __frcp_rn(scale);

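  // Note: `out` receives act * scale, while `scales` stores the reciprocal,
  // i.e. the factor a consumer multiplies by to dequantize.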
  const fp8_e4m3x4_t act_fp8x4 = scale_fp32x4_to_fp8x4(act_f32x4, scale);
  fp8_e4m3x4_t *const out_vec_addr =
      reinterpret_cast<fp8_e4m3x4_t *>(out + in_y_idx * (cols / 2) + in_x_idx);
  *out_vec_addr = act_fp8x4;
  if (lane == 0) scales[in_y_idx * scale_stride + in_x_idx / 128] = inv_scale;
}

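// Fallback scalar kernel: one bf16 pair per thread. A 256-thread block covers
// two 128-element quant blocks; the per-block amax is found with a shared
// memory reduction instead of the warp shuffles used by the vec4 path.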
template <bool using_pow2_scaling, bool with_prob>
__global__ void FusedSPAQKernel(const phi::bfloat16 *__restrict__ Xin,
                                const float *__restrict__ prob,
                                phi::float8_e4m3fn *__restrict__ out,
                                float *__restrict__ scales,
                                const int rows,
                                const int cols) {
  // Configure shared memory
  __shared__ float smem_tile[256];  // Shared memory for activation values
  __shared__ float warp_max[2][4];  // Warp maxima (2 quant blocks x 4 warps)
  __shared__ __nv_bfloat16 quant_block_amax[2];  // Quant block maxima

  const __nv_bfloat16 *X = reinterpret_cast<const __nv_bfloat16 *>(Xin);
  const int x_offset = threadIdx.x;
  const int quant_block_idx =
      threadIdx.x / 128;  // 0 or 1, two quant blocks per block
  const int in_y_idx = blockIdx.y;
  const int in_x_idx = blockIdx.x * blockDim.x + x_offset;
  const int src_idx = in_y_idx * cols + in_x_idx;

  // Load data and compute swiGLU activation
  if (in_x_idx < cols / 2) [[likely]] {
    __nv_bfloat16 x1 = X[src_idx];             // First half of the input
    __nv_bfloat16 x2 = X[src_idx + cols / 2];  // Second half of the input

    if constexpr (with_prob) {
      float row_prob = prob[in_y_idx];
      smem_tile[x_offset] = fast_swiglu(x1, x2) * row_prob;
    } else {
      smem_tile[x_offset] = fast_swiglu(x1, x2);
    }
  }

  __syncthreads();  // Ensure all threads have loaded their data

  // Phase 2: Block Reduction to find per-quant block absolute maxima
  float local_max = (in_x_idx < (cols / 2)) ? fabsf(smem_tile[x_offset]) : 0.0f;

  // Warp-level reduction
  unsigned int mask = 0xffffffff;
  int lane = threadIdx.x % 32;
  int warp_id =
      (threadIdx.x % 128) / 32;  // Warp ID within the quant block (0-3)

  // Reduce within the warp
  for (int offset = 16; offset > 0; offset /= 2) {
    float val = __shfl_down_sync(mask, local_max, offset);
    local_max = fmaxf(local_max, val);
  }

  // Store warp maxima
  if (lane == 0) {
    warp_max[quant_block_idx][warp_id] = local_max;
  }

  __syncthreads();

  // Reduce the 4 warp maxima to get the quant block maximum
  // (one thread per quant block)
  if (warp_id == 0 && lane == 0) {
    float block_max = fmaxf(fmaxf(warp_max[quant_block_idx][0],
                                  warp_max[quant_block_idx][1]),
                            fmaxf(warp_max[quant_block_idx][2],
                                  warp_max[quant_block_idx][3]));
    quant_block_amax[quant_block_idx] = __float2bfloat16(block_max);
  }

  __syncthreads();

  // Phase 3: Compute scales and quantize the outputs
  const float block_max_float =
      __bfloat162float(quant_block_amax[quant_block_idx]);
  const int scale_stride = (cols / 2 + 127) / 128;

  float scale = ComputeScale<float, __nv_fp8_e4m3, using_pow2_scaling>(
      block_max_float, 0.0f);
  float inv_scale = __frcp_rn(scale);

  // Quantize
  float output_scaled_fp32 = smem_tile[x_offset] * scale;

  const int g_output_y_offset = in_y_idx;
  const int g_output_x_offset = in_x_idx;

  // Write output and scales
  if (g_output_y_offset < rows && g_output_x_offset < cols / 2) {
    out[g_output_y_offset * (cols / 2) + g_output_x_offset] =
        static_cast<phi::float8_e4m3fn>(output_scaled_fp32);
    if (x_offset % 128 == 0) {
      // Only one thread per quant block writes the scale
      scales[g_output_y_offset * scale_stride + in_x_idx / 128] = inv_scale;
    }
  }
}

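// Dispatcher: picks the vectorized kernel when the last dim allows 4-wide
// bf16 loads. Worked example (hypothetical shapes): cols = 8192 gives
// cols / 2 = 4096 output columns; with 256 threads x 4 elements per block,
// grid.x = ceil(4096 / 1024) = 4 and grid.y = rows.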
void dispatch_fused_spaq(const paddle::Tensor &X,
                         const paddle::optional<paddle::Tensor> &prob,
                         paddle::Tensor &out,
                         paddle::Tensor &scale,
                         const int rows,
                         const int cols,
                         const bool &using_pow2_scaling,
                         const bool &with_prob) {
  constexpr int thread_per_block = 256;
  dim3 grid;
  dim3 block;
  if (cols % 8 == 0) {
    // Vectorized strategy (requires cols to be a multiple of 8, i.e. cols/2 a
    // multiple of 4): each thread processes 4 bfloat16 elements of the same
    // row, each warp handles one 1x128 vector, and each block handles a
    // sub-row of 4 * blockDim.x input elements.
    block.x = thread_per_block;
    constexpr int vec_numel = 4;
    const int scale_cols = scale.shape().back();
    DISPATCH_BOOL(
        using_pow2_scaling,
        k_using_pow2_scaling,
        DISPATCH_BOOL(
            with_prob, k_with_prob, grid.y = rows;
            grid.x =
                ((cols / 2) + block.x * vec_numel - 1) / (block.x * vec_numel);
            LAUNCH_FUSED_SPAQ_VEC4(k_using_pow2_scaling, k_with_prob);))

  } else {
    // Plain elementwise strategy: each block processes a sub-row
    // (numel = blockDim.x) of the input tensor.
    block.x = thread_per_block;
    DISPATCH_BOOL(
        using_pow2_scaling,
        k_using_pow2_scaling,
        DISPATCH_BOOL(with_prob, k_with_prob, grid.y = rows;
                      grid.x = ((cols / 2) + block.x - 1) / block.x;
                      LAUNCH_FUSED_SPAQ(k_using_pow2_scaling, k_with_prob);))
  }
}

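// Custom-op entry point. Worked example (hypothetical shapes): X of shape
// [1024, 8192] in bf16 yields `output` of shape [1024, 4096] in FP8 E4M3 and
// `scale` of shape [1024, 32] in fp32 (one inverse scale per 128 output
// columns).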
std::vector<paddle::Tensor> fused_spaq(
    const paddle::Tensor &X,
    const paddle::optional<paddle::Tensor> &prob,
    const bool &using_pow2_scaling) {
  // ---------------- Arguments check --------------------
  PD_CHECK(X.dtype() == paddle::DataType::BFLOAT16);
  if (prob) PD_CHECK(prob.get().dtype() == paddle::DataType::FLOAT32);
  int64_t rows = size_to_dim(X.shape().size() - 1, X.shape());
  int64_t cols = X.shape().back();
  PADDLE_ENFORCE_EQ(cols % 2,
                    0,
                    common::errors::InvalidArgument(
                        "The last dim of Input(X) should be exactly divisible "
                        "by 2, but got %d",
                        cols));
  if (prob) {
    PADDLE_ENFORCE_EQ(prob.get().shape()[0],
                      rows,
                      common::errors::InvalidArgument(
                          "The first dim of Input(prob) should be equal to "
                          "the number of rows of Input(X) (product of all "
                          "dims except the last), but got rows: %d, "
                          "prob.shape[0]: %d",
                          rows,
                          prob.get().shape()[0]));
  }

  paddle::Tensor out = paddle::empty(
      {rows, cols / 2}, paddle::DataType::FLOAT8_E4M3FN, X.place());
  paddle::Tensor scale = paddle::empty(
      {rows, ((cols / 2) + 127) / 128}, paddle::DataType::FLOAT32, X.place());

  dispatch_fused_spaq(
      X, prob, out, scale, rows, cols, using_pow2_scaling, !!prob);
  return {out, scale};
}

PD_BUILD_OP(fused_spaq)
    .Inputs({"X", paddle::Optional("prob")})
    .Outputs({"output", "scale"})
    .Attrs({"using_pow2_scaling: bool"})
    .SetKernelFn(PD_KERNEL(fused_spaq));
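// A minimal usage sketch, assuming the op is built with paddle's C++
// extension tooling (the Python binding name below is an assumption, not
// part of this file):
//   out, scale = fused_spaq(x, prob, using_pow2_scaling)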