
Commit 452bb11

Add fused_transpose_split_quant kernel
1 parent 6c206e1 commit 452bb11

File tree

3 files changed: +319 -0 lines changed

slm/model_zoo/gpt-3/external_ops/fused_quanted_ops/fused_transpose_split_quant.cu

Lines changed: 269 additions & 0 deletions
@@ -0,0 +1,269 @@
#include "quant_utils.h"

template <typename T, int VecSize>
struct __align__(sizeof(T) * VecSize) VecType {
  T val[VecSize];
  __host__ __device__ inline T& operator[](size_t i) { return val[i]; }
  __host__ __device__ inline const T& operator[](size_t i) const {
    return val[i];
  }
};

template <int VecSize>
__device__ void BlockLoad(const phi::bfloat16* X,
                          __nv_bfloat16 input[4][4],
                          size_t K) {
  for (size_t i = 0; i < 4; i++) {
    size_t off_m = blockIdx.x * 128 + threadIdx.y + i * 32;
    size_t off_k = blockIdx.y * 128 + threadIdx.x * VecSize;
    size_t offset = off_m * K + off_k;

    for (size_t j = 0; j < 4; j += VecSize) {
      if (off_k + j * 32 < K) {
        size_t idx = offset + j * 32;
        using LoadT = VecType<__nv_bfloat16, VecSize>;
        LoadT data = *reinterpret_cast<const LoadT*>(X + idx);
        for (int k = 0; k < VecSize; k++) {
          input[i][j + k] = data[k];
        }
      }
    }
  }
}

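// Reduce the 128x128 tile to one absolute maximum per column (the M axis):
// each thread first folds the 4 rows it owns in registers, then the 32
// partial results along threadIdx.y are tree-reduced through shared memory.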
__device__ void BlockColumnMax(const __nv_bfloat16 input[4][4],
                               __nv_bfloat16 amax[4],
                               __nv_bfloat16* shm) {
  // Reduce [(4), 32, 32, 4] => [32, 32, 4]
  __nv_bfloat16 warp_max[4];
  for (int i = 0; i < 4; i++) {
    for (int j = 0; j < 4; j++) {
      __nv_bfloat16 t = __habs(input[i][j]);
      warp_max[j] = i == 0 ? t : __hmax(warp_max[j], t);
    }
  }

  // Reduce [(32), 32, 4] => [32, 4]
  for (int i = 0; i < 4; i++) {
    shm[threadIdx.y * 128 + i * 32 + threadIdx.x] = warp_max[i];
  }
  __syncthreads();
  for (int offset = 16; offset > 0; offset /= 2) {
    if (threadIdx.y < offset) {
      for (int i = 0; i < 4; i++) {
        shm[threadIdx.y * 128 + i * 32 + threadIdx.x] =
            __hmax(shm[threadIdx.y * 128 + i * 32 + threadIdx.x],
                   shm[(threadIdx.y + offset) * 128 + i * 32 + threadIdx.x]);
      }
    }
    __syncthreads();
  }

  for (int i = 0; i < 4; i++) {
    amax[i] = shm[i * 32 + threadIdx.x];
  }
}

template <typename OutT, int VecSize>
__device__ void BlockStoreScale(float* scale,
                                __nv_bfloat16 amax[4],
                                float scale_inv[4],
                                size_t K) {
  float scale_out[4];
  for (int i = 0; i < 4; i++) {
    float amax_fp32 = static_cast<float>(amax[i]);
    scale_inv[i] = ComputeScale<__nv_bfloat16, OutT, true>(amax_fp32, 0.0f);
    scale_out[i] = __frcp_rn(scale_inv[i]);
  }
  if (threadIdx.y == 0) {
    size_t off_m = blockIdx.x;
    size_t off_k = blockIdx.y * 128 + threadIdx.x * VecSize;
    size_t offset = off_m * K + off_k;

    for (size_t j = 0; j < 4; j += VecSize) {
      if (off_k + j * 32 < K) {
        size_t idx = offset + j * 32;
        using StoreT = VecType<float, VecSize>;
        StoreT data;
        for (int k = 0; k < VecSize; k++) {
          data[k] = scale_out[j + k];
        }
        *reinterpret_cast<StoreT*>(scale + idx) = data;
      }
    }
  }
}

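// BlockStoreOut writes the quantized, transposed tile into the per-expert
// output: expert e's block is stored row-major as [K, M_e] at flat offset
// K * (M_1 + ... + M_{e-1}). Each thread walks tokens_per_expert to find the
// expert that owns its rows; since every M_i is a multiple of 128, a 128-row
// tile never straddles two experts.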
template <typename OutT, int VecSize>
__device__ void BlockStoreOut(OutT* out,
                              const OutT shm[128][129],
                              const int64_t* __restrict__ tokens_per_expert,
                              size_t num_experts,
                              size_t K) {
  // Find the current expert_idx
  size_t idx_m = blockIdx.x * 128 + threadIdx.x * 4;
  size_t expert_idx = 0;
  size_t tokens_offset = 0;
  size_t next_tokens_offset = 0;
  for (; expert_idx < num_experts; expert_idx++) {
    next_tokens_offset += tokens_per_expert[expert_idx];
    if (idx_m >= tokens_offset && idx_m < next_tokens_offset) {
      break;
    }
    tokens_offset = next_tokens_offset;
  }

  for (size_t i = 0; i < 4; i++) {
    size_t idx_k = blockIdx.y * 128 + threadIdx.y + i * 32;
    size_t idx = tokens_offset * K + (idx_m - tokens_offset) +
                 idx_k * tokens_per_expert[expert_idx];

    if (idx_k < K) {
      using StoreT = VecType<OutT, VecSize>;
      StoreT data;
      for (int j = 0; j < VecSize; j++) {
        data[j] = shm[i * 32 + threadIdx.y][threadIdx.x * 4 + j];
      }
      *reinterpret_cast<StoreT*>(out + idx) = data;
    }
  }
}

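// Each thread block handles one 128(M) x 128(K) tile of X with 32x32 threads,
// i.e. 4x4 elements per thread. The tile is quantized with one scale per
// column (128 elements along M), written transposed into shared memory, and
// then flushed to the per-expert output. The shared tile is padded to
// [128][129] so the transposed accesses avoid shared-memory bank conflicts.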
template <typename OutT, int VecSize>
__global__ void __launch_bounds__(1024, 2) FusedTransposeSplitQuantKernel(
    const phi::bfloat16* __restrict__ X,
    OutT* __restrict__ out,
    float* __restrict__ scale,
    const int64_t* __restrict__ tokens_per_expert,
    size_t num_experts,
    size_t K) {
  __shared__ OutT shm[128][129];

  // Load 128x128 elements from X
  __nv_bfloat16 input[4][4];
  BlockLoad<VecSize>(X, input, K);

  // Find the maximum of each 128 elements on the M axis
  __nv_bfloat16 amax[4];
  BlockColumnMax(input, amax, reinterpret_cast<__nv_bfloat16*>(shm));

  // Compute scale and scale_inv, then store scale back
  float scale_inv[4];
  BlockStoreScale<OutT, VecSize>(scale, amax, scale_inv, K);

  // Scale X and save into shared memory with transposed layout
  for (int i = 0; i < 4; i++) {
    for (int j = 0; j < 4; j += VecSize) {
      for (int k = 0; k < VecSize; k++) {
        float input_fp32 = static_cast<float>(input[i][j + k]);
        float output_scaled = input_fp32 * scale_inv[j + k];
        shm[threadIdx.x * VecSize + j * 32 + k][i * 32 + threadIdx.y] =
            static_cast<OutT>(output_scaled);
      }
    }
  }
  __syncthreads();

  // Store 128x128 elements back
  // Note: out is always 4x vectorizable.
  BlockStoreOut<OutT, 4>(out, shm, tokens_per_expert, num_experts, K);
}

/**
 * Quantize X along dim[0], transpose dim[0] and dim[1] of X, then split the
 * output and scale according to tokens_per_expert.
 *
 * Inputs:
 *   X                 : [SUM(M_1...M_N), K], bfloat16
 *   tokens_per_expert : Python list of values [M_1, M_2, ..., M_N]
 *
 * Outputs:
 *   out   : [K * M_1 + K * M_2 + ... + K * M_N], float8_e4m3fn
 *   scale : [SUM(M_1...M_N)/128, K], float32
 *
 * Requirements:
 *   1) M_i % 128 == 0 for each M_i in tokens_per_expert
 *   2) K <= 65535 * 128
 */
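//
// Worked example (hypothetical sizes, for illustration only):
//   tokens_per_expert = [256, 384] and K = 512, so X is [640, 512] bf16.
//   out is a flat fp8 buffer of 512*256 + 512*384 elements: the first 512*256
//   elements hold expert 0's transposed block in row-major [K, M_1] = [512, 256]
//   order, followed by expert 1's [512, 384] block.
//   scale is [640/128, 512] = [5, 512] fp32, one scale per (128 rows x 1 column)
//   block of X.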
std::vector<paddle::Tensor> fused_transpose_split_quant(
    const paddle::Tensor& X, const std::vector<int64_t>& tokens_per_expert) {
  PD_CHECK(X.dtype() == paddle::DataType::BFLOAT16);

  std::vector<int64_t> shape = X.shape();
  PD_CHECK(shape.size() == 2);
  const int64_t M = shape[0];
  const int64_t K = shape[1];

  int64_t sum_tokens = 0;
  for (size_t i = 0; i < tokens_per_expert.size(); i++) {
    PADDLE_ENFORCE_EQ(tokens_per_expert[i] % 128,
                      0,
                      common::errors::InvalidArgument(
                          "Each tokens_per_expert must be a multiple of 128. "
                          "However, got tokens_per_expert[%d] = %lld.",
                          i,
                          tokens_per_expert[i]));
    sum_tokens += tokens_per_expert[i];
  }

  PADDLE_ENFORCE_EQ(
      sum_tokens,
      M,
      common::errors::InvalidArgument(
          "Sum of tokens_per_expert must be equal to X.shape[0]."));
  PADDLE_ENFORCE_LE(K,
                    65535 * 128,
                    common::errors::InvalidArgument(
                        "X.shape[1] must be no larger than 65535 * 128."));

  // Allocate for out and scale
  paddle::Tensor out =
      paddle::empty({K * M}, paddle::DataType::FLOAT8_E4M3FN, X.place());
  paddle::Tensor scale =
      paddle::empty({M / 128, K}, paddle::DataType::FLOAT32, X.place());

  // Skip 0-size
  if (M == 0 || K == 0) {
    return {out, scale};
  }

  // Copy tokens_per_expert to device
  paddle::Tensor tokens_per_expert_cpu =
      paddle::empty({static_cast<int64_t>(tokens_per_expert.size())},
                    paddle::DataType::INT64);
  std::memcpy(tokens_per_expert_cpu.data(),
              tokens_per_expert.data(),
              sizeof(int64_t) * tokens_per_expert.size());
  paddle::Tensor tokens_per_expert_gpu =
      tokens_per_expert_cpu.copy_to(X.place(), /* blocking= */ false);

  // Launch kernel
  dim3 grid(M / 128, (K + 127) / 128);
  dim3 block(32, 32);

#define LAUNCH_KERNEL(VEC_SIZE)                                   \
  FusedTransposeSplitQuantKernel<phi::float8_e4m3fn, VEC_SIZE>    \
      <<<grid, block>>>(X.data<phi::bfloat16>(),                  \
                        out.data<phi::float8_e4m3fn>(),           \
                        scale.data<float>(),                      \
                        tokens_per_expert_gpu.data<int64_t>(),    \
                        tokens_per_expert.size(),                 \
                        K);
  if (K % 4 == 0) {
    LAUNCH_KERNEL(4);
  } else if (K % 2 == 0) {
    LAUNCH_KERNEL(2);
  } else {
    LAUNCH_KERNEL(1);
  }
#undef LAUNCH_KERNEL

  return {out, scale};
}

PD_BUILD_OP(fused_transpose_split_quant)
    .Inputs({"X"})
    .Outputs({"output", "scale"})
    .Attrs({"tokens_per_expert: std::vector<int64_t>"})
    .SetKernelFn(PD_KERNEL(fused_transpose_split_quant));
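
For reference, a minimal Python usage sketch of the new op (assuming the extension has been built via setup_fp8.py and is importable as FusedQuantOps, as in the test file below; shapes are illustrative):

import paddle
import FusedQuantOps as FQO

tokens_per_expert = [2 * 128, 3 * 128]   # each entry must be a multiple of 128
K = 512
x = paddle.randn([sum(tokens_per_expert), K], dtype='bfloat16')

out, scale = FQO.fused_transpose_split_quant(x, tokens_per_expert)
# out   : flat float8_e4m3fn tensor with K * sum(tokens_per_expert) elements;
#         expert i occupies a row-major [K, M_i] block starting at K * (M_1 + ... + M_{i-1})
# scale : [sum(tokens_per_expert) // 128, K] float32, one scale per 128 tokens per column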

slm/model_zoo/gpt-3/external_ops/setup_fp8.py

Lines changed: 1 addition & 0 deletions
@@ -42,6 +42,7 @@ def setup_fused_quant_ops():
         "fused_quanted_ops/fused_act_dequant_transpose_act_quant.cu",
         "fused_quanted_ops/fused_spaq.cu",
         "fused_quanted_ops/fused_stack_transpose_quant.cu",
+        "fused_quanted_ops/fused_transpose_split_quant.cu",
     ],
     extra_compile_args={
         "cxx": ["-O3", "-w", "-Wno-abi", "-fPIC", "-std=c++17"],
Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
import FusedQuantOps as FQO
import numpy as np

import paddle


def restore_transpose_split_quant(out, scale):
    out = [t.astype('float32') for t in out]
    out = paddle.concat(out, axis=1).transpose([1, 0])
    scale = paddle.concat(scale, axis=0)
    scale = paddle.repeat_interleave(scale, repeats=128, axis=0)
    return out * scale


def run():
    tokens_per_expert = [24*128, 50*128, 1*128, 128*128, 13*128]

    for seq_len in [1, 127, 2562, 4001, 7168]:
        print(tokens_per_expert, seq_len)

        x = paddle.randn([sum(tokens_per_expert), seq_len], dtype='bfloat16')
        x = paddle.clip(x, min=-50, max=50)

        out_raw, scale_raw = FQO.fused_transpose_split_quant(
            x, tokens_per_expert
        )

        out, scale = [], []
        token_offset = 0
        for tokens in tokens_per_expert:
            out_offset = seq_len * token_offset
            out_size = seq_len * tokens
            out.append(
                out_raw[out_offset : out_offset + out_size]
                .reshape([seq_len, tokens])
            )
            scale.append(
                scale_raw[token_offset // 128 : (token_offset + tokens) // 128]
            )
            token_offset += tokens

        x_restore = restore_transpose_split_quant(out, scale)
        x_cast = x.astype('float32')

        np.testing.assert_allclose(x_cast, x_restore, rtol=0.01, atol=0.3)


if __name__ == '__main__':
    run()
