
Commit 422e190

Add fused_transpose_quant op
1 parent 670cbd9 commit 422e190

3 files changed: +280 -0 lines changed
fused_quanted_ops/fused_transpose_quant.cu

Lines changed: 242 additions & 0 deletions
#include "quant_utils.h"

template <typename T, int VecSize>
struct __align__(sizeof(T) * VecSize) VecType {
  T val[VecSize];
  __host__ __device__ inline T& operator[](size_t i) { return val[i]; }
  __host__ __device__ inline const T& operator[](size_t i) const {
    return val[i];
  }
};

template <int VecSize>
__device__ void BlockLoad(const phi::bfloat16* X,
                          __nv_bfloat16 input[4][4],
                          size_t M,
                          size_t K) {
  for (size_t i = 0; i < 4; i++) {
    size_t off_n = blockIdx.z;
    size_t off_m = blockIdx.y * 128 + threadIdx.y + i * 32;
    size_t off_k = blockIdx.x * 128 + threadIdx.x * VecSize;
    size_t offset = (off_n * M + off_m) * K + off_k;

    for (size_t j = 0; j < 4; j += VecSize) {
      if (off_k + j * 32 < K) {
        size_t idx = offset + j * 32;
        using LoadT = VecType<__nv_bfloat16, VecSize>;
        LoadT data = *reinterpret_cast<const LoadT*>(X + idx);
        for (int k = 0; k < VecSize; k++) {
          input[i][j + k] = data[k];
        }
      }
    }
  }
}

__device__ void BlockColumnMax(const __nv_bfloat16 input[4][4],
                               __nv_bfloat16 amax[4],
                               __nv_bfloat16* shm) {
  // Reduce [(4), 32, 32, 4] => [32, 32, 4]
  __nv_bfloat16 warp_max[4];
  for (int i = 0; i < 4; i++) {
    for (int j = 0; j < 4; j++) {
      __nv_bfloat16 t = __habs(input[i][j]);
      warp_max[j] = i == 0 ? t : __hmax(warp_max[j], t);
    }
  }

  // Reduce [(32), 32, 4] => [32, 4]
  for (int i = 0; i < 4; i++) {
    shm[threadIdx.y * 128 + i * 32 + threadIdx.x] = warp_max[i];
  }
  __syncthreads();
  for (int offset = 16; offset > 0; offset /= 2) {
    if (threadIdx.y < offset) {
      for (int i = 0; i < 4; i++) {
        shm[threadIdx.y * 128 + i * 32 + threadIdx.x] =
            __hmax(shm[threadIdx.y * 128 + i * 32 + threadIdx.x],
                   shm[(threadIdx.y + offset) * 128 + i * 32 + threadIdx.x]);
      }
    }
    __syncthreads();
  }

  for (int i = 0; i < 4; i++) {
    amax[i] = shm[i * 32 + threadIdx.x];
  }
}

template <typename OutT, int VecSize>
__device__ void BlockStoreScale(float* scale,
                                __nv_bfloat16 amax[4],
                                float scale_inv[4],
                                size_t M,
                                size_t K) {
  float scale_out[4];
  for (int i = 0; i < 4; i++) {
    scale_out[i] = ComputeScale<__nv_bfloat16, OutT>(amax[i], 0.0f);
    scale_inv[i] = __frcp_rn(scale_out[i]);
  }
  if (threadIdx.y == 0) {
    size_t off_n = blockIdx.z;
    size_t off_m = blockIdx.y;
    size_t off_k = blockIdx.x * 128 + threadIdx.x * VecSize;
    size_t offset = (off_n * (M / 128) + off_m) * K + off_k;

    for (size_t j = 0; j < 4; j += VecSize) {
      if (off_k + j * 32 < K) {
        size_t idx = offset + j * 32;
        using StoreT = VecType<float, VecSize>;
        StoreT data;
        for (int k = 0; k < VecSize; k++) {
          data[k] = scale_out[j + k];
        }
        *reinterpret_cast<StoreT*>(scale + idx) = data;
      }
    }
  }
}

template <typename OutT, int VecSize>
__device__ void BlockStoreOut(OutT* out,
                              const OutT shm[128][129],
                              size_t M,
                              size_t K) {
  for (size_t i = 0; i < 4; i++) {
    size_t idx_n = blockIdx.z;
    size_t idx_k = blockIdx.x * 128 + threadIdx.y + i * 32;
    size_t idx_m = blockIdx.y * 128 + threadIdx.x * 4;
    size_t idx = (idx_n * K + idx_k) * M + idx_m;

    if (idx_k < K) {
      using StoreT = VecType<OutT, VecSize>;
      StoreT data;
      for (int j = 0; j < VecSize; j++) {
        data[j] = shm[i * 32 + threadIdx.y][threadIdx.x * 4 + j];
      }
      *reinterpret_cast<StoreT*>(out + idx) = data;
    }
  }
}

template <typename OutT, int VecSize>
__global__ void __launch_bounds__(1024, 2)
    FusedTransposeQuantKernel(const phi::bfloat16* __restrict__ X,
                              OutT* __restrict__ out,
                              float* __restrict__ scale,
                              size_t M,
                              size_t K) {
  // 129 columns pad the shared-memory tile so the transposed writes below
  // do not hit bank conflicts.
  __shared__ OutT shm[128][129];

  // Load 128x128 elements from X
  __nv_bfloat16 input[4][4];
  BlockLoad<VecSize>(X, input, M, K);

  // Find the maximum of each 128 elements on the M axis
  __nv_bfloat16 amax[4];
  BlockColumnMax(input, amax, reinterpret_cast<__nv_bfloat16*>(shm));

  // Compute scale and scale_inv, save scale to output
  float scale_inv[4];
  BlockStoreScale<OutT, VecSize>(scale, amax, scale_inv, M, K);

  // Scale X and save into shared memory with transposed layout
  for (int i = 0; i < 4; i++) {
    for (int j = 0; j < 4; j += VecSize) {
      for (int k = 0; k < VecSize; k++) {
        float input_fp32 = static_cast<float>(input[i][j + k]);
        float output_scaled = input_fp32 * scale_inv[j + k];
        shm[threadIdx.x * VecSize + j * 32 + k][i * 32 + threadIdx.y] =
            static_cast<OutT>(output_scaled);
      }
    }
  }
  __syncthreads();

  // Store 128x128 elements back
  // Note: out is always 4x vectorizable.
  BlockStoreOut<OutT, 4>(out, shm, M, K);
}

/**
 * Quantizes along dim[-2] of X, then transposes dim[-1] and dim[-2] of X.
 *
 * Inputs:
 *   X    : [*, M, K], bfloat16
 *
 * Outputs:
 *   out  : [*, K, M], float8_e4m3fn
 *   scale: [*, M/128, K], float32
 *
 * Requirements:
 *   1) batch_size <= 65535
 *   2) M <= 65535 * 128 and M % 128 == 0
 */
std::vector<paddle::Tensor> fused_transpose_quant(const paddle::Tensor& X) {
  PD_CHECK(X.dtype() == paddle::DataType::BFLOAT16);

  std::vector<int64_t> shape = X.shape();
  PD_CHECK(shape.size() >= 2);

  int64_t M = shape[shape.size() - 2];
  int64_t K = shape[shape.size() - 1];
  int64_t N = X.numel() / (M * K);

  PADDLE_ENFORCE_LE(
      N,
      65535,
      common::errors::InvalidArgument("The batch size (X.shape[0:-2] in total) "
                                      "must be no larger than 65535."));
  PADDLE_ENFORCE_LE(M,
                    65535 * 128,
                    common::errors::InvalidArgument(
                        "X.shape[-2] must be no larger than 65535 * 128."));
  PADDLE_ENFORCE_EQ(
      M % 128,
      0,
      common::errors::InvalidArgument("X.shape[-2] must be multiple of 128."));

  // Allocate for out and scale
  std::vector<int64_t> out_shape = shape;
  out_shape[shape.size() - 2] = K;
  out_shape[shape.size() - 1] = M;
  paddle::Tensor out =
      paddle::empty(out_shape, paddle::DataType::FLOAT8_E4M3FN, X.place());

  std::vector<int64_t> scale_shape = shape;
  scale_shape[shape.size() - 2] = M / 128;
  paddle::Tensor scale =
      paddle::empty(scale_shape, paddle::DataType::FLOAT32, X.place());

  // Skip 0-size
  if (N == 0 || M == 0 || K == 0) {
    return {out, scale};
  }

  // Launch kernel
  dim3 grid((K + 127) / 128, M / 128, N);
  dim3 block(32, 32);

#define LAUNCH_KERNEL(VEC_SIZE)                            \
  FusedTransposeQuantKernel<phi::float8_e4m3fn, VEC_SIZE>  \
      <<<grid, block>>>(X.data<phi::bfloat16>(),           \
                        out.data<phi::float8_e4m3fn>(),    \
                        scale.data<float>(),               \
                        M,                                 \
                        K);
  if (K % 4 == 0) {
    LAUNCH_KERNEL(4);
  } else if (K % 2 == 0) {
    LAUNCH_KERNEL(2);
  } else {
    LAUNCH_KERNEL(1);
  }
#undef LAUNCH_KERNEL

  return {out, scale};
}

PD_BUILD_OP(fused_transpose_quant)
    .Inputs({"X"})
    .Outputs({"output", "scale"})
    .SetKernelFn(PD_KERNEL(fused_transpose_quant));
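
For reference, the op's end-to-end math can be sketched on the host side. This is an illustrative sketch only, not the kernel: it assumes ComputeScale(amax, 0.0f) from quant_utils.h behaves like amax / 448 (the float8_e4m3fn maximum) with a small floor to avoid division by zero, and it skips the FP8 rounding the real kernel performs.

# Hypothetical NumPy reference for the op's semantics; the exact scale formula
# is an assumption, since ComputeScale lives in quant_utils.h.
import numpy as np

def transpose_quant_reference(x):
    # x: [N, M, K] float32 with M % 128 == 0
    n, m, k = x.shape
    blocks = x.reshape(n, m // 128, 128, k)      # 128-row groups along M
    amax = np.abs(blocks).max(axis=2)            # [N, M/128, K] per-block column amax
    scale = np.maximum(amax, 1e-12) / 448.0      # assumed ComputeScale behavior
    q = blocks / scale[:, :, None, :]            # values the kernel would cast to FP8
    out = q.reshape(n, m, k).transpose(0, 2, 1)  # [N, K, M]
    return out, scale                            # matches the op's (out, scale) outputs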

slm/model_zoo/gpt-3/external_ops/setup_fp8.py

Lines changed: 1 addition & 0 deletions

@@ -41,6 +41,7 @@ def setup_fused_quant_ops():
             "fused_quanted_ops/fused_act_dequant.cu",
             "fused_quanted_ops/fused_act_dequant_transpose_act_quant.cu",
             "fused_quanted_ops/fused_spaq.cu",
+            "fused_quanted_ops/fused_transpose_quant.cu",
         ],
         extra_compile_args={
             "cxx": ["-O3", "-w", "-Wno-abi", "-fPIC", "-std=c++17"],
Lines changed: 37 additions & 0 deletions
import FusedQuantOps as FQO
import numpy as np

import paddle


def restore_transpose_quant(out, scale):
    # Dequantize and transpose back to recover an approximation of x.
    out = out.transpose([0, 2, 1]).astype('float32')
    scale = paddle.repeat_interleave(scale, repeats=128, axis=1)
    x = out * scale
    return x


def test_fused_transpose_quant(batch_size, seq_len, hidden_size):
    print(batch_size, seq_len, hidden_size)
    x = paddle.randn([batch_size, seq_len, hidden_size], dtype='bfloat16')
    x = paddle.clip(x, min=-50, max=50)

    out, scale = FQO.fused_transpose_quant(x)

    x_fp32 = x.astype('float32')
    x_restored = restore_transpose_quant(out, scale)

    np.testing.assert_allclose(
        x_fp32, x_restored, rtol=0.01, atol=0.3
    )  # Truncation error exists, so atol=0.3; the error is usually around 1e-6.


def run():
    for batch_size in [1, 4]:
        for seq_len in [2048, 7168]:
            for hidden_size in [1, 257, 2114, 4096]:
                test_fused_transpose_quant(batch_size, seq_len, hidden_size)


if __name__ == "__main__":
    run()
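
As a concrete shape example from the test above: with batch_size=1, seq_len=2048, hidden_size=4096, X is [1, 2048, 4096] bfloat16, out comes back as [1, 4096, 2048] float8_e4m3fn, and scale as [1, 16, 4096] float32 (2048 / 128 = 16 row blocks), which is exactly what restore_transpose_quant undoes by repeating each scale row 128 times and multiplying.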
