Commit cb49846

fix
2 parents f4bb9e1 + ae560af

File tree

9 files changed: +831 -78 lines

ops/csrc/fp8/deep_gemm/jit_kernels/gemm.py

Lines changed: 1 addition & 1 deletion

@@ -171,7 +171,7 @@ def auto_tuning_with_compilation(m, n, k, num_sms):
     return runtime, num_sms, smem_size


-def gemm_fp8_fp8_bf16_nt(lhs: Tuple[Tensor, Tensor], rhs: Tuple[Tensor, Tensor], out: Tensor, num_sms=112) -> None:
+def gemm_fp8_fp8_bf16_nt(lhs: Tuple[Tensor, Tensor], rhs: Tuple[Tensor, Tensor], out: Tensor, num_sms=132) -> None:
     """
     Do a normal GEMM with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
     LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format.
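The only functional change here is the default num_sms, raised from 112 to 132, presumably to match the SM count of the Hopper parts these kernels target. For orientation, here is a minimal call-site sketch; the FP8 dtype string and the exact scale shapes are assumptions inferred from the 1x128 LHS / 128x128 RHS scaling described in the docstring, not something this diff shows.

# Hedged call-site sketch; dtype strings and scale shapes are assumptions,
# inferred from the 1x128 LHS / 128x128 RHS scaling in the docstring.
import paddle

m, n, k = 4096, 7168, 2048
lhs = (paddle.empty([m, k], dtype='float8_e4m3fn'),            # FP8 activations
       paddle.empty([m, k // 128], dtype='float32'))           # 1x128 per-row scales
rhs = (paddle.empty([n, k], dtype='float8_e4m3fn'),            # FP8 weights, row-major (nt)
       paddle.empty([n // 128, k // 128], dtype='float32'))    # 128x128 block scales
out = paddle.empty([m, n], dtype='bfloat16')

gemm_fp8_fp8_bf16_nt(lhs, rhs, out, num_sms=132)               # passes the new default explicitly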

ops/csrc/fp8/deep_gemm/jit_kernels/m_grouped_gemm.py

Lines changed: 2 additions & 2 deletions

@@ -98,7 +98,7 @@ def auto_tuning_with_compilation_grouped_gemm_contiguous(m, n, k, num_groups, nu


 def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
-    lhs: Tuple[Tensor, Tensor], rhs: Tuple[Tensor, Tensor], out: Tensor, m_indices: Tensor, num_sms=112
+    lhs: Tuple[Tensor, Tensor], rhs: Tuple[Tensor, Tensor], out: Tensor, m_indices: Tensor, num_sms=132
 ) -> None:
     """
     Do a grouped GEMM (contiguous format) with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.

@@ -215,7 +215,7 @@ def auto_tuning_with_compilation_grouped_gemm_masked(m, expected_m, n, k, num_gr


 def m_grouped_gemm_fp8_fp8_bf16_nt_masked(
-    lhs: Tuple[Tensor, Tensor], rhs: Tuple[Tensor, Tensor], out: Tensor, masked_m: Tensor, expected_m: int, num_sms=112
+    lhs: Tuple[Tensor, Tensor], rhs: Tuple[Tensor, Tensor], out: Tensor, masked_m: Tensor, expected_m: int, num_sms=132
 ) -> None:
     """
     Do a grouped GEMM (masked format) with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
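The same default bump (112 to 132) is applied to both grouped variants. As a companion to the contiguous-format docstring, a hedged usage sketch follows; the m_indices convention (row i of lhs belongs to group m_indices[i], with each group's rows stored contiguously) mirrors upstream DeepGEMM and is an assumption here, as are the dtype strings and scale shapes.

# Hedged sketch of the contiguous grouped call; m_indices semantics, dtype strings,
# and scale shapes are assumed from upstream DeepGEMM conventions.
import paddle

num_groups, m_per_group, n, k = 4, 1024, 4096, 7168
m = num_groups * m_per_group
lhs = (paddle.empty([m, k], dtype='float8_e4m3fn'),
       paddle.empty([m, k // 128], dtype='float32'))
rhs = (paddle.empty([num_groups, n, k], dtype='float8_e4m3fn'),
       paddle.empty([num_groups, n // 128, k // 128], dtype='float32'))
out = paddle.empty([m, n], dtype='bfloat16')
# Rows 0..1023 go to group 0, rows 1024..2047 to group 1, and so on.
m_indices = paddle.repeat_interleave(paddle.arange(num_groups, dtype='int32'), m_per_group)

m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs, rhs, out, m_indices, num_sms=132)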

ops/csrc/fp8/deep_gemm/jit_kernels/utils.py

Lines changed: 7 additions & 0 deletions

@@ -109,6 +109,13 @@ def get_col_major_tma_aligned_tensor(x: Tensor) -> Tensor:
     assert x.dim() in (2, 3)
     remove_dim = False
     if x.dim() == 2:
+        m, n = x.shape
+
+        aligned_m = get_tma_aligned_size(m, x.element_size())
+
+        if aligned_m == m and x.strides[0] == 1 and x.strides[1] == aligned_m:
+            return x
+
         x, remove_dim = x.unsqueeze(0), True

     b, m, n = x.shape
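The seven added lines give get_col_major_tma_aligned_tensor an early return when a 2-D tensor is already in the column-major, TMA-aligned layout the kernels expect, so the padding copy is skipped. A small sketch of the condition being tested is below; it assumes get_tma_aligned_size rounds m up to a 16-byte multiple (as in upstream DeepGEMM), and the helper name is hypothetical.

# Hypothetical helper mirroring the fast-path condition above; assumes
# get_tma_aligned_size rounds m up to a 16-byte multiple, as in upstream DeepGEMM.
def would_skip_copy(m: int, strides: tuple, element_size: int) -> bool:
    alignment = 16 // element_size                      # elements per 16-byte TMA unit
    aligned_m = (m + alignment - 1) // alignment * alignment
    # Column-major with a padded leading dim: element (i, j) lives at j * aligned_m + i,
    # i.e. stride 1 over rows and stride aligned_m over columns.
    return aligned_m == m and strides == (1, aligned_m)

# For float32 scales (element_size = 4): m = 128 is already aligned, m = 130 is not.
assert would_skip_copy(128, (1, 128), 4)
assert not would_skip_copy(130, (1, 132), 4)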
Lines changed: 353 additions & 0 deletions

#include "quant_utils.h"

#define LAUNCH_FUSED_SPAQ(__using_pow2_scaling, __with_prob)              \
  do {                                                                    \
    auto kernel = FusedSPAQKernel<__using_pow2_scaling, __with_prob>;     \
    kernel<<<grid, block, 0, X.stream()>>>(                               \
        X.data<phi::bfloat16>(),                                          \
        prob ? prob->data<float>() : nullptr,                             \
        out.data<phi::float8_e4m3fn>(),                                   \
        scale.data<float>(),                                              \
        rows,                                                             \
        cols);                                                            \
  } while (0)

#define LAUNCH_FUSED_SPAQ_VEC4(__using_pow2_scaling, __with_prob)         \
  do {                                                                    \
    auto kernel = FusedSPAQKernelVec4<__using_pow2_scaling,               \
                                      __with_prob,                        \
                                      thread_per_block>;                  \
    kernel<<<grid, block, 0, X.stream()>>>(                               \
        X.data<phi::bfloat16>(),                                          \
        prob ? prob->data<float>() : nullptr,                             \
        out.data<phi::float8_e4m3fn>(),                                   \
        scale.data<float>(),                                              \
        rows,                                                             \
        cols,                                                             \
        scale_cols);                                                      \
  } while (0)


typedef struct __align__(8) {
  __nv_bfloat16 x;
  __nv_bfloat16 y;
  __nv_bfloat16 z;
  __nv_bfloat16 w;
} bfloat16x4_t;

typedef struct __align__(4) {
  __nv_fp8_e4m3 x;
  __nv_fp8_e4m3 y;
  __nv_fp8_e4m3 z;
  __nv_fp8_e4m3 w;
} fp8_e4m3x4_t;

__device__ __forceinline__ float fast_swiglu(const __nv_bfloat16 x,
                                             const __nv_bfloat16 y) {
  const float x_f = __bfloat162float(x);
  const float y_f = __bfloat162float(y);
  const float silu = x_f * __frcp_rn(1.0f + __expf(-x_f));
  const float result = silu * y_f;
  return result;
}

__device__ __forceinline__ float4 fast_swiglu_vec4(const bfloat16x4_t &lhs,
                                                   const bfloat16x4_t &rhs) {
  const float x_f_x = __bfloat162float(lhs.x);
  const float x_f_y = __bfloat162float(lhs.y);
  const float x_f_z = __bfloat162float(lhs.z);
  const float x_f_w = __bfloat162float(lhs.w);

  const float y_f_x = __bfloat162float(rhs.x);
  const float y_f_y = __bfloat162float(rhs.y);
  const float y_f_z = __bfloat162float(rhs.z);
  const float y_f_w = __bfloat162float(rhs.w);

  const float silu_x = x_f_x * __frcp_rn(1.0f + __expf(-x_f_x));
  const float silu_y = x_f_y * __frcp_rn(1.0f + __expf(-x_f_y));
  const float silu_z = x_f_z * __frcp_rn(1.0f + __expf(-x_f_z));
  const float silu_w = x_f_w * __frcp_rn(1.0f + __expf(-x_f_w));

  return {silu_x * y_f_x, silu_y * y_f_y, silu_z * y_f_z, silu_w * y_f_w};
}

__device__ __forceinline__ float amax_float4(const float4 &vec) {
  return fmaxf(fmaxf(fabsf(vec.x), fabsf(vec.y)),
               fmaxf(fabsf(vec.z), fabsf(vec.w)));
}

__device__ __forceinline__ fp8_e4m3x4_t
scale_fp32x4_to_fp8x4(const float4 &vec, const float scale) {
  return {static_cast<__nv_fp8_e4m3>(vec.x * scale),
          static_cast<__nv_fp8_e4m3>(vec.y * scale),
          static_cast<__nv_fp8_e4m3>(vec.z * scale),
          static_cast<__nv_fp8_e4m3>(vec.w * scale)};
}


template <bool using_pow2_scaling, bool with_prob, int thread_per_block>
__global__ void FusedSPAQKernelVec4(const phi::bfloat16 *__restrict__ Xin,
                                    const float *__restrict__ prob,
                                    phi::float8_e4m3fn *__restrict__ out,
                                    float *__restrict__ scales,
                                    const int rows,
                                    const int cols,
                                    const int scale_cols) {
  constexpr int elements_per_thread = 4;
  constexpr int warp_size = 32;
  constexpr int warp_num = thread_per_block / warp_size;
  const int scale_stride = scale_cols;
  const int lane = threadIdx.x % warp_size;
  const int x_offset = threadIdx.x * elements_per_thread;
  const int in_y_idx = blockIdx.y;
  const int in_x_idx = blockIdx.x * blockDim.x * elements_per_thread + x_offset;
  const int src_idx = in_y_idx * cols + in_x_idx;
  const unsigned int mask = 0xffffffff;  // whole warp mask
  float p_t0 = 0.0f;
  if (in_x_idx >= cols / 2 || in_y_idx >= rows) [[unlikely]]
    return;

  if constexpr (with_prob) {
    // Prefetch the per-row probability on lane 0
    if (lane == 0) p_t0 = prob[in_y_idx];
  }

  const __nv_bfloat16 *X = reinterpret_cast<const __nv_bfloat16 *>(Xin);

  // Initialize activation storage
  float4 act_f32x4;
  bfloat16x4_t lhs_bf16x4, rhs_bf16x4;

  // Reinterpret input pointer as bfloat16x4_t* for vectorized loading
  const bfloat16x4_t *X_lhs_vec =
      reinterpret_cast<const bfloat16x4_t *>(X + src_idx);
  const bfloat16x4_t *X_rhs_vec =
      reinterpret_cast<const bfloat16x4_t *>(X + src_idx + cols / 2);

  lhs_bf16x4 = *X_lhs_vec;
  rhs_bf16x4 = *X_rhs_vec;

  act_f32x4 = fast_swiglu_vec4(lhs_bf16x4, rhs_bf16x4);

  if constexpr (with_prob) {
    // Broadcast via warp shuffle, avoiding a __syncthreads()
    const float p = __shfl_sync(mask, p_t0, 0);
    act_f32x4.x *= p;
    act_f32x4.y *= p;
    act_f32x4.z *= p;
    act_f32x4.w *= p;
  }

  // Phase 2: block reduction to find the per-quant-block absolute maximum
  float thread_amax = amax_float4(act_f32x4);

  // All-reduce within the warp
#pragma unroll
  for (int offset = 16; offset > 0; offset /= 2) {
    const float val = __shfl_down_sync(mask, thread_amax, offset);
    thread_amax = fmaxf(thread_amax, val);
  }
  const float final_amax = __shfl_sync(mask, thread_amax, 0);

  // Phase 3: compute scales and quantize the outputs
  const float scale =
      ComputeScale<float, __nv_fp8_e4m3, using_pow2_scaling>(final_amax, 0.0f);
  const float inv_scale = __frcp_rn(scale);

  const fp8_e4m3x4_t act_fp8x4 = scale_fp32x4_to_fp8x4(act_f32x4, scale);
  fp8_e4m3x4_t *const out_vec_addr =
      reinterpret_cast<fp8_e4m3x4_t *>(out + in_y_idx * cols / 2 + in_x_idx);
  *out_vec_addr = act_fp8x4;
  if (lane == 0) scales[in_y_idx * scale_stride + in_x_idx / 128] = inv_scale;
}

template <bool using_pow2_scaling, bool with_prob>
__global__ void FusedSPAQKernel(const phi::bfloat16 *__restrict__ Xin,
                                const float *__restrict__ prob,
                                phi::float8_e4m3fn *__restrict__ out,
                                float *__restrict__ scales,
                                const int rows,
                                const int cols) {
  // Configure shared memory
  __shared__ float smem_tile[256];  // Shared memory for activation values
  __shared__ float warp_max[2][4];  // Shared memory for warp maxima (2 quant
                                    // blocks x 4 warps)
  __shared__ __nv_bfloat16
      quant_block_amax[2];  // Shared memory for quant block maxima

  const __nv_bfloat16 *X = reinterpret_cast<const __nv_bfloat16 *>(Xin);
  const int x_offset = threadIdx.x;
  const int quant_block_idx =
      threadIdx.x / 128;  // 0 or 1, two quant blocks per block
  const int in_y_idx = blockIdx.y;
  const int in_x_idx = blockIdx.x * blockDim.x + x_offset;
  const int src_idx = in_y_idx * cols + in_x_idx;

  // Load data and compute the SwiGLU activation
  if (in_x_idx < cols / 2) [[likely]] {
    __nv_bfloat16 x1 = X[src_idx];             // First half of the input
    __nv_bfloat16 x2 = X[src_idx + cols / 2];  // Second half of the input

    if constexpr (with_prob) {
      float row_prob = prob[in_y_idx];
      smem_tile[x_offset] = fast_swiglu(x1, x2) * row_prob;
    } else {
      smem_tile[x_offset] = fast_swiglu(x1, x2);
    }
  }

  __syncthreads();  // Ensure all threads have loaded their data

  // Phase 2: block reduction to find the per-quant-block absolute maxima
  float local_max = (in_x_idx < (cols / 2)) ? fabsf(smem_tile[x_offset]) : 0.0f;

  // Warp-level reduction
  unsigned int mask = 0xffffffff;
  int lane = threadIdx.x % 32;
  int warp_id =
      (threadIdx.x % 128) / 32;  // Warp ID within the quant block (0-3)

  // Reduce within the warp
  for (int offset = 16; offset > 0; offset /= 2) {
    float val = __shfl_down_sync(mask, local_max, offset);
    local_max = fmaxf(local_max, val);
  }

  // Store warp maxima
  if (lane == 0) {
    warp_max[quant_block_idx][warp_id] = local_max;
  }

  __syncthreads();

  // Reduce warp maxima to get quant block maxima
  if (warp_id == 0 && lane < 4) {
    if (threadIdx.x < 256) {  // Ensure only valid threads participate
      float block_max = warp_max[quant_block_idx][lane];
      // Reduce over the 4 warp maxima
      if (lane == 0) {
        block_max = fmaxf(block_max, warp_max[quant_block_idx][1]);
        block_max = fmaxf(block_max, warp_max[quant_block_idx][2]);
        block_max = fmaxf(block_max, warp_max[quant_block_idx][3]);
        quant_block_amax[quant_block_idx] = __float2bfloat16(block_max);
      }
    }
  }

  __syncthreads();

  // Phase 3: compute scales and quantize the outputs
  const float block_max_float = (float)quant_block_amax[quant_block_idx];
  const int scale_stride = (cols / 2 + 127) / 128;

  float scale = ComputeScale<float, __nv_fp8_e4m3, using_pow2_scaling>(
      block_max_float, 0.0f);
  float inv_scale = __frcp_rn(scale);

  // Quantize
  float output_scaled_fp32 = smem_tile[x_offset] * scale;

  const int g_output_y_offset = in_y_idx;
  const int g_output_x_offset = in_x_idx;

  // Write output and scales
  if (g_output_y_offset < rows && g_output_x_offset < cols / 2) {
    out[g_output_y_offset * (cols / 2) + g_output_x_offset] =
        static_cast<phi::float8_e4m3fn>(output_scaled_fp32);
    if (x_offset % 128 == 0) {
      // Only one thread per quant block writes the scale
      scales[g_output_y_offset * scale_stride + in_x_idx / 128] = inv_scale;
    }
  }
}


void dispatch_fused_spaq(const paddle::Tensor &X,
                         const paddle::optional<paddle::Tensor> &prob,
                         paddle::Tensor &out,
                         paddle::Tensor &scale,
                         const int rows,
                         const int cols,
                         const bool &using_pow2_scaling,
                         const bool &with_prob) {
  constexpr int thread_per_block = 256;
  dim3 grid;
  dim3 block;
  if (cols % 8 == 0) {
    // Mixed vectorized strategy, used when cols is a multiple of 8 (4 x 2):
    // each thread processes 4 bfloat16 elements of the same row, each warp
    // handles one 1x128 vector, and each block covers several sub-rows
    // (numel = 4 x blockDim.x) of the input.
    block.x = thread_per_block;
    constexpr int vec_numel = 4;
    const int scale_cols = scale.shape().back();
    DISPATCH_BOOL(
        using_pow2_scaling,
        k_using_pow2_scaling,
        DISPATCH_BOOL(
            with_prob, k_with_prob, grid.y = rows;
            grid.x =
                ((cols / 2) + block.x * vec_numel - 1) / (block.x * vec_numel);
            LAUNCH_FUSED_SPAQ_VEC4(k_using_pow2_scaling, k_with_prob);))

  } else {
    // Plain elementwise strategy:
    // each block processes a sub-row (numel = blockDim.x) of the input tensor.
    block.x = thread_per_block;
    DISPATCH_BOOL(
        using_pow2_scaling,
        k_using_pow2_scaling,
        DISPATCH_BOOL(with_prob, k_with_prob, grid.y = rows;
                      grid.x = ((cols / 2) + block.x - 1) / block.x;
                      LAUNCH_FUSED_SPAQ(k_using_pow2_scaling, k_with_prob);))
  }
}


std::vector<paddle::Tensor> fused_spaq(
    const paddle::Tensor &X,
    const paddle::optional<paddle::Tensor> &prob,
    const bool &using_pow2_scaling) {
  // ---------------- Arguments check --------------------
  PD_CHECK(X.dtype() == paddle::DataType::BFLOAT16);
  if (prob) PD_CHECK(prob.get().dtype() == paddle::DataType::FLOAT32);
  int64_t rows = size_to_dim(X.shape().size() - 1, X.shape());
  int64_t cols = X.shape().back();
  PADDLE_ENFORCE_EQ(cols % 2,
                    0,
                    common::errors::InvalidArgument(
                        "The last dim of Input(X) should be exactly divisible "
                        "by 2, but got %d",
                        cols));
  if (prob) {
    PADDLE_ENFORCE_EQ(prob.get().shape()[0],
                      rows,
                      common::errors::InvalidArgument(
                          "The first dim of Input(X) should be equal to the "
                          "first dim of Input(prob) but got X.shape[0]: %d, "
                          "prob.shape[0]: %d",
                          rows,
                          prob.get().shape()[0]));
  }

  paddle::Tensor out;
  paddle::Tensor scale;

  out = paddle::empty(
      {rows, cols / 2}, paddle::DataType::FLOAT8_E4M3FN, X.place());
  scale = paddle::empty(
      {rows, ((cols / 2) + 127) / 128}, paddle::DataType::FLOAT32, X.place());

  dispatch_fused_spaq(
      X, prob, out, scale, rows, cols, using_pow2_scaling, !!prob);
  return {out, scale};
}

PD_BUILD_OP(fused_spaq)
    .Inputs({"X", paddle::Optional("prob")})
    .Outputs({"output", "scale"})
    .Attrs({"using_pow2_scaling: bool"})
    .SetKernelFn(PD_KERNEL(fused_spaq));
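End to end, fused_spaq takes a [rows, cols] bfloat16 tensor whose last dimension concatenates the two SwiGLU halves, optionally multiplies each row by prob, and returns an FP8-E4M3 tensor of shape [rows, cols/2] plus one float32 inverse scale per 1x128 block. A NumPy reference of that math, useful when unit-testing the kernels, is sketched below; the FP8 maximum of 448 and the power-of-two rounding are assumptions about ComputeScale (defined in quant_utils.h, not shown in this diff).

# Hedged NumPy reference for fused_spaq; the 448 FP8-E4M3 max and the pow2 rounding
# are assumptions about ComputeScale, which lives in quant_utils.h and is not shown here.
import numpy as np

def fused_spaq_ref(x, prob=None, using_pow2_scaling=False, block=128):
    rows, cols = x.shape
    assert cols % 2 == 0
    lhs = x[:, : cols // 2].astype(np.float32)
    rhs = x[:, cols // 2 :].astype(np.float32)
    act = lhs / (1.0 + np.exp(-lhs)) * rhs                    # SwiGLU: silu(lhs) * rhs
    if prob is not None:
        act *= prob[:, None]                                  # per-row routing probability

    n_blocks = (cols // 2 + block - 1) // block
    out = np.zeros_like(act)
    inv_scales = np.zeros((rows, n_blocks), dtype=np.float32)
    for b in range(n_blocks):
        chunk = act[:, b * block : (b + 1) * block]
        amax = np.abs(chunk).max(axis=1)                      # per-(row, block) absolute max
        scale = 448.0 / np.maximum(amax, 1e-12)               # multiplier into the FP8 range
        if using_pow2_scaling:
            scale = 2.0 ** np.floor(np.log2(scale))           # assumed pow2 behaviour
        out[:, b * block : (b + 1) * block] = chunk * scale[:, None]
        inv_scales[:, b] = 1.0 / scale                        # the kernel stores inverse scales
    return out, inv_scales                                    # out is what gets cast to FP8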
