
Commit ace25e1

Add fused_transpose_split_quant kernel
1 parent 6c206e1 commit ace25e1

File tree

3 files changed: +337 -0 lines changed
slm/model_zoo/gpt-3/external_ops/fused_quanted_ops/fused_transpose_split_quant.cu

Lines changed: 284 additions & 0 deletions

@@ -0,0 +1,284 @@
#include "quant_utils.h"

template <typename T, int VecSize>
struct __align__(sizeof(T) * VecSize) VecType {
  T val[VecSize];
  __host__ __device__ inline T& operator[](size_t i) { return val[i]; }
  __host__ __device__ inline const T& operator[](size_t i) const {
    return val[i];
  }
};

template <int VecSize>
__device__ void BlockLoad(const phi::bfloat16* X,
                          __nv_bfloat16 input[4][4],
                          size_t K) {
  for (size_t i = 0; i < 4; i++) {
    size_t off_m = blockIdx.x * 128 + threadIdx.y + i * 32;
    size_t off_k = blockIdx.y * 128 + threadIdx.x * VecSize;
    size_t offset = off_m * K + off_k;

    for (size_t j = 0; j < 4; j += VecSize) {
      if (off_k + j * 32 < K) {
        size_t idx = offset + j * 32;
        using LoadT = VecType<__nv_bfloat16, VecSize>;
        LoadT data = *reinterpret_cast<const LoadT*>(X + idx);
        for (int k = 0; k < VecSize; k++) {
          input[i][j + k] = data[k];
        }
      }
    }
  }
}

__device__ void BlockColumnMax(const __nv_bfloat16 input[4][4],
                               __nv_bfloat16 amax[4],
                               __nv_bfloat16* shm) {
  // Reduce [(4), 32, 32, 4] => [32, 32, 4]
  __nv_bfloat16 warp_max[4];
  for (int i = 0; i < 4; i++) {
    for (int j = 0; j < 4; j++) {
      __nv_bfloat16 t = __habs(input[i][j]);
      warp_max[j] = i == 0 ? t : __hmax(warp_max[j], t);
    }
  }

  // Reduce [(32), 32, 4] => [32, 4]
  for (int i = 0; i < 4; i++) {
    shm[threadIdx.y * 128 + i * 32 + threadIdx.x] = warp_max[i];
  }
  __syncthreads();
  for (int offset = 16; offset > 0; offset /= 2) {
    if (threadIdx.y < offset) {
      for (int i = 0; i < 4; i++) {
        shm[threadIdx.y * 128 + i * 32 + threadIdx.x] =
            __hmax(shm[threadIdx.y * 128 + i * 32 + threadIdx.x],
                   shm[(threadIdx.y + offset) * 128 + i * 32 + threadIdx.x]);
      }
    }
    __syncthreads();
  }

  for (int i = 0; i < 4; i++) {
    amax[i] = shm[i * 32 + threadIdx.x];
  }
}

template <typename OutT, bool Pow2Scales, int VecSize>
__device__ void BlockStoreScale(float* scale,
                                size_t off_m,
                                __nv_bfloat16 amax[4],
                                float scale_inv[4],
                                size_t K) {
  float scale_out[4];
  for (int i = 0; i < 4; i++) {
    scale_inv[i] = ComputeScale<__nv_bfloat16, OutT, Pow2Scales>(
        static_cast<float>(amax[i]), 0.0f);
    scale_out[i] = __frcp_rn(scale_inv[i]);
  }
  if (threadIdx.y == 0) {
    size_t idx_m = blockIdx.x - off_m / 128;
    size_t off_k = blockIdx.y * 128 + threadIdx.x * VecSize;
    size_t offset = idx_m * K + off_k;

    for (size_t j = 0; j < 4; j += VecSize) {
      if (off_k + j * 32 < K) {
        size_t idx = offset + j * 32;
        using StoreT = VecType<float, VecSize>;
        StoreT data;
        for (int k = 0; k < VecSize; k++) {
          data[k] = scale_out[j + k];
        }
        *reinterpret_cast<StoreT*>(scale + idx) = data;
      }
    }
  }
}

template <typename OutT, int VecSize>
__device__ void BlockStoreOut(OutT* out,
                              size_t off_m,
                              size_t cur_tokens,
                              const OutT shm[128][129],
                              size_t K) {
  for (size_t i = 0; i < 4; i++) {
    size_t idx_m = blockIdx.x * 128 + threadIdx.x * 4;
    size_t idx_k = blockIdx.y * 128 + threadIdx.y + i * 32;
    size_t idx = idx_k * cur_tokens + (idx_m - off_m);

    if (idx_k < K) {
      using StoreT = VecType<OutT, VecSize>;
      StoreT data;
      for (int j = 0; j < VecSize; j++) {
        data[j] = shm[i * 32 + threadIdx.y][threadIdx.x * 4 + j];
      }
      *reinterpret_cast<StoreT*>(out + idx) = data;
    }
  }
}

template <typename OutT, bool Pow2Scales, int VecSize>
__global__ void __launch_bounds__(1024)
    FusedTransposeSplitQuantKernel(const phi::bfloat16* __restrict__ X,
                                   int64_t* __restrict__ meta,
                                   size_t num_experts,
                                   size_t K) {
  __shared__ OutT shm[128][129];
  int64_t* tokens_per_expert = meta;
  OutT** out_ptrs = reinterpret_cast<OutT**>(meta + num_experts);
  float** scale_ptrs = reinterpret_cast<float**>(meta + num_experts * 2);

  // Get expert_idx and offset at the M dim of the current block
  size_t idx_m = blockIdx.x * 128 + threadIdx.x * 4;
  size_t off_m = 0, next_off_m = 0;
  size_t expert_idx;
  for (expert_idx = 0; expert_idx < num_experts; expert_idx++) {
    next_off_m += tokens_per_expert[expert_idx];
    if (idx_m >= off_m && idx_m < next_off_m) {
      break;
    }
    off_m = next_off_m;
  }

  // Load 128x128 elements from X
  __nv_bfloat16 input[4][4];
  BlockLoad<VecSize>(X, input, K);

  // Find the maximum of each 128 elements on the M axis
  __nv_bfloat16 amax[4];
  BlockColumnMax(input, amax, reinterpret_cast<__nv_bfloat16*>(shm));

  // Compute scale and scale_inv, then store scale back
  float scale_inv[4];
  BlockStoreScale<OutT, Pow2Scales, VecSize>(
      scale_ptrs[expert_idx], off_m, amax, scale_inv, K);

  // Scale X and save into shared memory with transposed layout
  for (int i = 0; i < 4; i++) {
    for (int j = 0; j < 4; j += VecSize) {
      for (int k = 0; k < VecSize; k++) {
        float input_fp32 = static_cast<float>(input[i][j + k]);
        float output_scaled = input_fp32 * scale_inv[j + k];
        shm[threadIdx.x * VecSize + j * 32 + k][i * 32 + threadIdx.y] =
            static_cast<OutT>(output_scaled);
      }
    }
  }
  __syncthreads();

  // Store 128x128 elements back
  // Note: out is always 4x vectorizable.
  BlockStoreOut<OutT, 4>(
      out_ptrs[expert_idx], off_m, tokens_per_expert[expert_idx], shm, K);
}

/**
 * Quantize on dim[0] of X, transpose dim[0] and dim[1] of X, then
 * split the result into out and scale.
 *
 * Inputs:
 *   X     : [SUM(M_1...M_N), K], bfloat16
 *
 * Outputs:
 *   out   : {[K, M_1], [K, M_2], ..., [K, M_N]}, float8_e4m3fn
 *   scale : {[M_1/128, K], [M_2/128, K], ..., [M_N/128, K]}, float
 *
 * Attrs:
 *   pow_2_scales
 *         : bool that indicates whether to use power-of-2 scaling
 *
 * Requirements:
 *   1) M_i % 128 == 0 for M_i in [M_1, M_2, ..., M_N]
 *   2) K <= 65535 * 128
 */
void fused_transpose_split_quant(const paddle::Tensor& X,
                                 std::vector<paddle::Tensor>& outs,
                                 std::vector<paddle::Tensor>& scales,
                                 bool pow_2_scales) {
  // Check X
  PD_CHECK(X.dtype() == paddle::DataType::BFLOAT16);

  std::vector<int64_t> shape = X.shape();
  PD_CHECK(shape.size() == 2);
  const int64_t M = shape[0];
  const int64_t K = shape[1];

  // Check outs and scales
  const size_t num_experts = outs.size();
  PD_CHECK(scales.size() == num_experts);

  std::vector<int64_t> tokens_per_expert;
  int64_t sum_tokens = 0;
  for (size_t i = 0; i < num_experts; i++) {
    PD_CHECK(outs[i].dtype() == paddle::DataType::FLOAT8_E4M3FN);
    PD_CHECK(scales[i].dtype() == paddle::DataType::FLOAT32);

    std::vector<int64_t> out_shape = outs[i].shape();
    PD_CHECK(out_shape.size() == 2);
    PD_CHECK(out_shape[0] == K);
    PD_CHECK(out_shape[1] % 128 == 0);
    tokens_per_expert.push_back(out_shape[1]);
    sum_tokens += out_shape[1];

    std::vector<int64_t> scale_shape = scales[i].shape();
    PD_CHECK(scale_shape.size() == 2);
    PD_CHECK(scale_shape[0] == out_shape[1] / 128);
    PD_CHECK(scale_shape[1] == K);
  }

  PD_CHECK(sum_tokens == M,
           "sum of out[i].shape[1] must be equal to X.shape[0]");
  PD_CHECK(K <= 65535 * 128, "only supports K <= 65535 * 128");

  // Skip 0-size
  if (M == 0 || K == 0) {
    return;
  }

  // Copy meta (tokens_per_expert, out_ptrs, scale_ptrs) to device
  paddle::Tensor meta_cpu = paddle::empty(
      {static_cast<int64_t>(num_experts * 3)}, paddle::DataType::INT64);
  int64_t* meta_ptr = meta_cpu.data<int64_t>();
  for (size_t i = 0; i < num_experts; i++) {
    meta_ptr[i] = static_cast<int64_t>(tokens_per_expert[i]);
  }
  for (size_t i = 0; i < num_experts; i++) {
    meta_ptr[num_experts + i] =
        reinterpret_cast<int64_t>(outs[i].data<phi::float8_e4m3fn>());
  }
  for (size_t i = 0; i < num_experts; i++) {
    meta_ptr[num_experts * 2 + i] =
        reinterpret_cast<int64_t>(scales[i].data<float>());
  }
  paddle::Tensor meta_gpu = meta_cpu.copy_to(X.place(), /*blocking=*/false);

  // Launch kernel
  dim3 grid(M / 128, (K + 127) / 128);
  dim3 block(32, 32);

#define LAUNCH_KERNEL(POW_2_SCALES, VEC_SIZE)                                \
  FusedTransposeSplitQuantKernel<phi::float8_e4m3fn, POW_2_SCALES, VEC_SIZE> \
      <<<grid, block>>>(                                                     \
          X.data<phi::bfloat16>(), meta_gpu.data<int64_t>(), num_experts, K);
#define LAUNCH_KERNEL_PARTIAL(VEC_SIZE) \
  if (pow_2_scales) {                   \
    LAUNCH_KERNEL(true, VEC_SIZE);      \
  } else {                              \
    LAUNCH_KERNEL(false, VEC_SIZE);     \
  }

  if (K % 4 == 0) {
    LAUNCH_KERNEL_PARTIAL(4);
  } else if (K % 2 == 0) {
    LAUNCH_KERNEL_PARTIAL(2);
  } else {
    LAUNCH_KERNEL_PARTIAL(1);
  }
#undef LAUNCH_KERNEL_PARTIAL
#undef LAUNCH_KERNEL
}

PD_BUILD_OP(fused_transpose_split_quant)
    .Inputs({"X", paddle::Vec("outs"), paddle::Vec("scales")})
    .Attrs({"pow_2_scales: bool"})
    .SetKernelFn(PD_KERNEL(fused_transpose_split_quant));
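
For reference, a minimal Python usage sketch of the op defined above, assuming the extension has been built via setup_fp8.py and is importable as FusedQuantOps (the module name used by the test script in this commit). The sizes are illustrative; outputs are caller-allocated with the shapes described in the doc comment: each out_i is [K, M_i] float8_e4m3fn and each scale_i is [M_i/128, K] float32.

import paddle
import FusedQuantOps as FQO  # module name taken from the test script below

# Illustrative sizes: two experts with 128 and 256 tokens, hidden size K = 1024.
tokens_per_expert = [128, 256]   # each M_i must be a multiple of 128
K = 1024                         # K <= 65535 * 128
x = paddle.randn([sum(tokens_per_expert), K], dtype='bfloat16')

# Caller-allocated outputs: out_i is [K, M_i] fp8, scale_i is [M_i/128, K] fp32
# (one scale per 128-token block per column).
outs = [paddle.empty([K, m], dtype='float8_e4m3fn') for m in tokens_per_expert]
scales = [paddle.empty([m // 128, K], dtype='float32') for m in tokens_per_expert]

# Quantize along dim 0, transpose, and split per expert, writing into outs/scales.
FQO.fused_transpose_split_quant(x, outs, scales, False)  # pow_2_scales=False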

slm/model_zoo/gpt-3/external_ops/setup_fp8.py

Lines changed: 1 addition & 0 deletions

@@ -42,6 +42,7 @@ def setup_fused_quant_ops():
             "fused_quanted_ops/fused_act_dequant_transpose_act_quant.cu",
             "fused_quanted_ops/fused_spaq.cu",
             "fused_quanted_ops/fused_stack_transpose_quant.cu",
+            "fused_quanted_ops/fused_transpose_split_quant.cu",
         ],
         extra_compile_args={
             "cxx": ["-O3", "-w", "-Wno-abi", "-fPIC", "-std=c++17"],
Lines changed: 52 additions & 0 deletions

@@ -0,0 +1,52 @@
import FusedQuantOps as FQO
import numpy as np

import paddle


def restore_transpose_split_quant(out, scale):
    out = [t.astype('float32') for t in out]
    out = paddle.concat(out, axis=1).transpose([1, 0])
    scale = paddle.concat(scale, axis=0)
    scale = paddle.repeat_interleave(scale, repeats=128, axis=0)
    return out * scale


def test_fused_transpose_split_quant(tokens_per_expert, seq_len, pow_2_scales):
    print(tokens_per_expert, seq_len, pow_2_scales)

    x = paddle.randn([sum(tokens_per_expert), seq_len], dtype='bfloat16')
    x = paddle.clip(x, min=-50, max=50)

    out, scale = [], []
    for tokens in tokens_per_expert:
        out.append(paddle.empty([seq_len, tokens], dtype='float8_e4m3fn'))
        scale.append(paddle.empty([tokens//128, seq_len], dtype='float32'))

    FQO.fused_transpose_split_quant(x, out, scale, pow_2_scales)

    x_restore = restore_transpose_split_quant(out, scale)
    x_cast = x.astype('float32')

    np.testing.assert_allclose(x_cast, x_restore, rtol=0.01, atol=0.3)


def run():
    test_fused_transpose_split_quant([0, 0], 1024, False)
    test_fused_transpose_split_quant([128, 2*128], 0, True)
    test_fused_transpose_split_quant([128], 1, False)
    test_fused_transpose_split_quant([0, 128, 0, 2*128], 127, True)
    test_fused_transpose_split_quant([3*128, 4*128, 5*128], 233, False)
    test_fused_transpose_split_quant(
        [24*128, 128, 50*128, 16*128], 2162, True
    )
    test_fused_transpose_split_quant(
        [7*128, 29*128, 3*128, 128*128, 13*128], 4000, False
    )
    test_fused_transpose_split_quant(
        [18*128, 5*128, 24*128, 128, 6*128, 0, 27*128, 7*128], 7168, True
    )


if __name__ == '__main__':
    run()
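
As a further sanity check, not part of the committed test: the per-element relationship that restore_transpose_split_quant relies on can also be spot-checked directly inside test_fused_transpose_split_quant, since the stored scale is the dequantization factor shared by every token in a 128-token block. The indices below are illustrative, and the variable names (x, out, scale, tokens_per_expert) are those of the test function.

# Hypothetical spot check: out[e][k, m] holds the quantized value and
# scale[e][m // 128, k] its dequantization factor, so their product should
# match x[off_e + m, k] up to fp8 rounding error.
e, m, k = 0, 5, 7                       # illustrative indices
off_e = sum(tokens_per_expert[:e])      # row offset of expert e within x
approx = out[e].astype('float32')[k, m] * scale[e][m // 128, k]
exact = x.astype('float32')[off_e + m, k]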
