Commit 6c206e1

Add fused_stack_transpose_quant kernel (optional transpose) (#10649)
1 parent a4d90ab commit 6c206e1

3 files changed: +377 -0

slm/model_zoo/gpt-3/external_ops/fused_quanted_ops/fused_stack_transpose_quant.cu

Lines changed: 320 additions & 0 deletions
@@ -0,0 +1,320 @@
#include "quant_utils.h"

template <typename T, int VecSize>
struct __align__(sizeof(T) * VecSize) VecType {
  T val[VecSize];
  __host__ __device__ inline T& operator[](size_t i) { return val[i]; }
  __host__ __device__ inline const T& operator[](size_t i) const {
    return val[i];
  }
};

// Fast unsigned division by a runtime divisor d: precompute shift_val =
// ceil(log2(d)) and a 64-bit magic multiplier on the host, so that Div()
// only needs a multiply-high and a shift on the device.
struct FastDiv {
  FastDiv() {}
  FastDiv(uint64_t d) {
    for (shift_val = 0; shift_val < 64; ++shift_val) {
      uint64_t shift_limit = uint64_t(1) << shift_val;
      if (shift_limit >= d) break;
    }

    // quotient = ((uint128_t)n_hi << 64) / d
    uint64_t quotient = 0;
    uint64_t n_hi = (uint64_t(1) << shift_val) - d, n_lo = 0;
    for (int i = 63; i >= 0; --i) {
      uint64_t d_hi = i == 0 ? 0 : d >> (64 - i);
      uint64_t d_lo = d << i;
      if (n_hi == 0 && n_lo == 0) break;
      if ((d_hi < n_hi) || (d_hi <= n_hi && d_lo <= n_lo)) {
        quotient |= uint64_t(1) << i;
        n_hi -= d_hi + (d_lo > n_lo);
        n_lo -= d_lo;
      }
    }
    multiplier = quotient + 1;
  }

  __device__ uint64_t Div(uint64_t n) const {
    uint64_t t = __umul64hi(n, multiplier);
    return (t + n) >> shift_val;
  }

  int shift_val;
  uint64_t multiplier;
};

__device__ void BlockLoad(const int64_t* __restrict__ X_ptrs,
                          __nv_bfloat16 input[4][4],
                          size_t K,
                          size_t block_y,
                          size_t block_x) {
  const __nv_bfloat16* X =
      reinterpret_cast<const __nv_bfloat16*>(X_ptrs[blockIdx.z]);

  for (size_t i = 0; i < 4; i++) {
    size_t idx_m = block_y * 128 + threadIdx.y + i * 32;
    size_t idx_k = block_x * 128 + threadIdx.x * 4;
    size_t idx = idx_m * K + idx_k;

    using LoadT = VecType<__nv_bfloat16, 4>;
    LoadT data = *reinterpret_cast<const LoadT*>(X + idx);
    for (int j = 0; j < 4; j++) {
      input[i][j] = data[j];
    }
  }
}

__device__ __nv_bfloat16 WarpReduceMax(__nv_bfloat16 x) {
  for (int offset = 16; offset > 0; offset /= 2) {
    __nv_bfloat16 t = __shfl_down_sync(0xffffffff, x, offset);
    x = __hmax(x, t);
  }
  return x;
}

__device__ __nv_bfloat16 BlockReduceMax(__nv_bfloat16 input[4][4]) {
  // [(4), 32, 32, (4)] => [32, 32]
  __nv_bfloat16 local_max;
  for (int i = 0; i < 4; i++) {
    for (int j = 0; j < 4; j++) {
      __nv_bfloat16 t = __habs(input[i][j]);
      local_max = (i == 0 && j == 0) ? t : __hmax(local_max, t);
    }
  }

  // [32, (32)] => [32]
  __nv_bfloat16 warp_max = WarpReduceMax(local_max);

  // [(32)] => [1]
  __shared__ __nv_bfloat16 block_max[32];
  if (threadIdx.x == 0) {
    block_max[threadIdx.y] = warp_max;
  }
  __syncthreads();
  if (threadIdx.y == 0) {
    warp_max = WarpReduceMax(block_max[threadIdx.x]);
    if (threadIdx.x == 0) {
      block_max[0] = warp_max;
    }
  }
  __syncthreads();

  return block_max[0];
}

template <typename OutT>
__global__ void __launch_bounds__(1024)
    FusedStackQuantKernel(const int64_t* __restrict__ X_ptrs,
                          OutT* __restrict__ out,
                          float* __restrict__ scale,
                          size_t M,
                          size_t K,
                          FastDiv K_div_128) {
  size_t block_y = K_div_128.Div(blockIdx.x);
  size_t block_x = blockIdx.x - block_y * (K / 128);

  // Load 128x128 elements from X
  __nv_bfloat16 input[4][4];
  BlockLoad(X_ptrs, input, K, block_y, block_x);

  // Find the maximum in all elements
  __nv_bfloat16 amax = BlockReduceMax(input);

  // Compute scale and store back
  float scale_inv = ComputeScale<__nv_bfloat16, OutT>(amax, 0.0f);
  float scale_out = __frcp_rn(scale_inv);
  if (threadIdx.x == 0 && threadIdx.y == 0) {
    size_t idx_n = blockIdx.z;
    size_t idx_m = block_y;
    size_t idx_k = block_x;
    size_t idx = (idx_n * (M / 128) + idx_m) * (K / 128) + idx_k;
    scale[idx] = scale_out;
  }

  // Scale X and store to out
  for (size_t i = 0; i < 4; i++) {
    size_t idx_n = blockIdx.z;
    size_t idx_m = block_y * 128 + threadIdx.y + i * 32;
    size_t idx_k = block_x * 128 + threadIdx.x * 4;
    size_t idx = (idx_n * M + idx_m) * K + idx_k;

    using StoreT = VecType<OutT, 4>;
    StoreT data;
    for (int j = 0; j < 4; j++) {
      float input_fp32 = static_cast<float>(input[i][j]);
      float output_scaled = input_fp32 * scale_inv;
      data[j] = static_cast<OutT>(output_scaled);
    }
    *reinterpret_cast<StoreT*>(out + idx) = data;
  }
}

template <typename OutT>
__global__ void __launch_bounds__(1024)
    FusedStackTransposeQuantKernel(const int64_t* __restrict__ X_ptrs,
                                   OutT* __restrict__ out,
                                   float* __restrict__ scale,
                                   size_t M,
                                   size_t K,
                                   FastDiv K_div_128) {
  size_t block_y = K_div_128.Div(blockIdx.x);
  size_t block_x = blockIdx.x - block_y * (K / 128);

  // Load 128x128 elements from X
  __nv_bfloat16 input[4][4];
  BlockLoad(X_ptrs, input, K, block_y, block_x);

  // Find the maximum in all elements
  __nv_bfloat16 amax = BlockReduceMax(input);

  // Compute scale and store back
  float scale_inv = ComputeScale<__nv_bfloat16, OutT>(amax, 0.0f);
  float scale_out = __frcp_rn(scale_inv);
  if (threadIdx.x == 0 && threadIdx.y == 0) {
    size_t idx_n = blockIdx.z;
    size_t idx_k = block_x;
    size_t idx_m = block_y;
    size_t idx = (idx_n * (K / 128) + idx_k) * (M / 128) + idx_m;
    scale[idx] = scale_out;
  }

  // Scale X and transpose in shared memory
  __shared__ OutT shm[128][129];
  for (int i = 0; i < 4; i++) {
    for (int j = 0; j < 4; j++) {
      float input_fp32 = static_cast<float>(input[i][j]);
      float output_scaled = input_fp32 * scale_inv;
      shm[threadIdx.x * 4 + j][i * 32 + threadIdx.y] =
          static_cast<OutT>(output_scaled);
    }
  }
  __syncthreads();

  // Store X back to out
  for (size_t i = 0; i < 4; i++) {
    size_t idx_n = blockIdx.z;
    size_t idx_k = block_x * 128 + threadIdx.y + i * 32;
    size_t idx_m = block_y * 128 + threadIdx.x * 4;
    size_t idx = (idx_n * K + idx_k) * M + idx_m;

    using StoreT = VecType<OutT, 4>;
    StoreT data;
    for (int j = 0; j < 4; j++) {
      data[j] = shm[i * 32 + threadIdx.y][threadIdx.x * 4 + j];
    }
    *reinterpret_cast<StoreT*>(out + idx) = data;
  }
}

/**
 * Stack the tensors in X, optionally transpose dim[-1] and dim[-2], and
 * quantize the result in 128x128 blocks over dim[-2] and dim[-1].
 *
 * Inputs:
 *   X    : N tensors of [M, K], bfloat16
 *
 * Outputs:
 *   if Transpose:
 *     out  : [N * K, M], float8_e4m3fn
 *     scale: [N * K / 128, M / 128], float
 *   else:
 *     out  : [N * M, K], float8_e4m3fn
 *     scale: [N * M / 128, K / 128], float
 *
 * Requirements:
 *   1) N <= 65535
 *   2) M % 128 == 0
 *   3) K % 128 == 0
 */
template <bool Transpose>
std::vector<paddle::Tensor> fused_stack_transpose_quant(
    const std::vector<paddle::Tensor>& X) {
  int64_t N = X.size();
  PD_CHECK(N > 0);
  for (int64_t i = 0; i < N; i++) {
    PD_CHECK(X[i].dtype() == paddle::DataType::BFLOAT16);
  }

  std::vector<int64_t> shape = X[0].shape();
  PD_CHECK(shape.size() == 2);
  int64_t M = shape[0];
  int64_t K = shape[1];

  for (int64_t i = 1; i < N; i++) {
    std::vector<int64_t> shape = X[i].shape();
    PD_CHECK(shape.size() == 2);
    PD_CHECK(shape[0] == M);
    PD_CHECK(shape[1] == K);
  }

  PADDLE_ENFORCE_LE(N,
                    65535,
                    common::errors::InvalidArgument(
                        "The batch size (N) must be no larger than 65535."));
  PADDLE_ENFORCE_EQ(M % 128,
                    0,
                    common::errors::InvalidArgument(
                        "The upper dim (M) must be a multiple of 128."));
  PADDLE_ENFORCE_EQ(K % 128,
                    0,
                    common::errors::InvalidArgument(
                        "The lower dim (K) must be a multiple of 128."));

  // Allocate for out and scale
  std::vector<int64_t> out_shape, scale_shape;
  if (Transpose) {
    out_shape = {N * K, M};
    scale_shape = {N * K / 128, M / 128};
  } else {
    out_shape = {N * M, K};
    scale_shape = {N * M / 128, K / 128};
  }

  const auto& place = X[0].place();
  paddle::Tensor out =
      paddle::empty(out_shape, paddle::DataType::FLOAT8_E4M3FN, place);
  paddle::Tensor scale =
      paddle::empty(scale_shape, paddle::DataType::FLOAT32, place);

  // Skip 0-size
  if (M == 0 || K == 0) {
    return {out, scale};
  }

  // Copy the pointers in X to device
  paddle::Tensor X_ptrs_cpu = paddle::empty({N}, paddle::DataType::INT64);
  int64_t* X_ptrs_data = X_ptrs_cpu.data<int64_t>();
  for (int64_t i = 0; i < N; i++) {
    X_ptrs_data[i] = reinterpret_cast<int64_t>(X[i].data());
  }
  paddle::Tensor X_ptrs_gpu = X_ptrs_cpu.copy_to(place, /* blocking= */ false);

  // Launch kernel
  dim3 grid((M / 128) * (K / 128), 1, N);
  dim3 block(32, 32);

#define LAUNCH_KERNEL(KERNEL)                             \
  KERNEL<<<grid, block>>>(X_ptrs_gpu.data<int64_t>(),     \
                          out.data<phi::float8_e4m3fn>(), \
                          scale.data<float>(),            \
                          M,                              \
                          K,                              \
                          FastDiv(K / 128))
  if (Transpose) {
    LAUNCH_KERNEL(FusedStackTransposeQuantKernel);
  } else {
    LAUNCH_KERNEL(FusedStackQuantKernel);
  }
#undef LAUNCH_KERNEL

  return {out, scale};
}

PD_BUILD_OP(fused_stack_quant)
    .Inputs({paddle::Vec("X")})
    .Outputs({"output", "scale"})
    .SetKernelFn(PD_KERNEL(fused_stack_transpose_quant<false>));

PD_BUILD_OP(fused_stack_transpose_quant)
    .Inputs({paddle::Vec("X")})
    .Outputs({"output", "scale"})
    .SetKernelFn(PD_KERNEL(fused_stack_transpose_quant<true>));
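
The FastDiv helper above replaces the per-block integer division blockIdx.x / (K / 128) with a multiply-high and a shift. As a sanity check of the magic-number math, here is a hedged Python sketch (not part of the commit); it assumes a divisor d >= 1, which holds for the kernels since d = K / 128 with K a positive multiple of 128:

def fastdiv_params(d):
    # shift = smallest s with 2**s >= d; multiplier = floor(((2**s - d) << 64) / d) + 1,
    # matching the long division in the FastDiv constructor.
    shift = (d - 1).bit_length()
    multiplier = ((((1 << shift) - d) << 64) // d) + 1
    return shift, multiplier

def fastdiv(n, shift, multiplier):
    t = (n * multiplier) >> 64        # __umul64hi(n, multiplier)
    return (t + n) >> shift

for d in [1, 3, 56, 129, 4096 // 128]:
    shift, multiplier = fastdiv_params(d)
    assert all(fastdiv(n, shift, multiplier) == n // d for n in range(100000))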

slm/model_zoo/gpt-3/external_ops/setup_fp8.py

Lines changed: 1 addition & 0 deletions
@@ -41,6 +41,7 @@ def setup_fused_quant_ops():
        "fused_quanted_ops/fused_act_dequant.cu",
        "fused_quanted_ops/fused_act_dequant_transpose_act_quant.cu",
        "fused_quanted_ops/fused_spaq.cu",
+       "fused_quanted_ops/fused_stack_transpose_quant.cu",
    ],
    extra_compile_args={
        "cxx": ["-O3", "-w", "-Wno-abi", "-fPIC", "-std=c++17"],
Lines changed: 56 additions & 0 deletions
@@ -0,0 +1,56 @@
import FusedQuantOps as FQO
import numpy as np

import paddle


def restore_stack_quant(out, scale):
    scale = paddle.repeat_interleave(scale, repeats=128, axis=0)
    scale = paddle.repeat_interleave(scale, repeats=128, axis=1)
    x = out.astype('float32') * scale
    return x


def test_fused_stack_transpose_quant(
    num_experts, seq_len, hidden_size, transpose
):
    print(num_experts, seq_len, hidden_size, transpose)

    x_vec = []
    for _ in range(num_experts):
        x = paddle.randn([seq_len, hidden_size], dtype='bfloat16')
        x = paddle.clip(x, min=-50, max=50)
        x_vec.append(x)

    if transpose:
        out, scale = FQO.fused_stack_transpose_quant(x_vec)
    else:
        out, scale = FQO.fused_stack_quant(x_vec)

    x_fp32 = paddle.stack(x_vec).reshape([-1, hidden_size]).astype('float32')
    x_restored = restore_stack_quant(out, scale)

    if transpose:
        x_restored = (
            x_restored.reshape([num_experts, hidden_size, seq_len])
            .transpose([0, 2, 1])
            .reshape([-1, hidden_size])
        )

    np.testing.assert_allclose(
        x_fp32, x_restored, rtol=0.01, atol=0.2
    )  # Quantization truncation error exists, hence atol=0.2; in practice the error is usually around 1e-6.


def run():
    for batch_size in [1, 4]:
        for seq_len in [2048, 7168]:
            for hidden_size in [128, 4096]:
                for transpose in [False, True]:
                    test_fused_stack_transpose_quant(
                        batch_size, seq_len, hidden_size, transpose
                    )


if __name__ == "__main__":
    run()
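
For reference, the forward transform that restore_stack_quant inverts can be sketched on the CPU. The following is a hedged NumPy sketch, not part of the commit: it assumes ComputeScale in quant_utils.h maps a block's absolute maximum to amax / 448 (the float8_e4m3fn maximum), gives all-zero blocks a scale of 1.0 (which may differ from ComputeScale's handling), takes array-like float inputs, and keeps the quantized values in float32 instead of casting to float8_e4m3fn:

import numpy as np

FP8_E4M3_MAX = 448.0  # assumed float8_e4m3fn max used by ComputeScale

def stack_quant_reference(x_list, transpose=False):
    # Stack N tensors of [M, K]; optionally transpose the last two dims.
    x = np.stack([np.asarray(t, dtype=np.float32) for t in x_list])
    if transpose:
        x = x.transpose(0, 2, 1)
    n, rows, cols = x.shape
    # One scale per 128x128 block, as the kernels compute via BlockReduceMax.
    blocks = x.reshape(n, rows // 128, 128, cols // 128, 128)
    amax = np.abs(blocks).max(axis=(2, 4))
    scale = np.where(amax > 0, amax / FP8_E4M3_MAX, 1.0)
    quant = blocks / scale[:, :, None, :, None]
    # Shapes match the doc comment: out [n * rows, cols], scale [n * rows / 128, cols / 128].
    return quant.reshape(n * rows, cols), scale.reshape(n * rows // 128, cols // 128)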
