
Commit 7849875

Swap in faster uint6 bitpacking function
Differential Revision: D64504890
Pull Request resolved: #1098
1 parent 3103e7e commit 7849875
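
The diff below renames the "_v2" scalar and NEON uint6 pack/unpack routines to be the defaults and deletes the older layout. As a quick reference, here is a minimal standalone sketch (not the torchao source; the `_sketch` names and the example values are made up for illustration) of the layout the retained functions use: each of the first three uint6 values occupies the low 6 bits of its own packed byte, and the fourth value's 6 bits are spread, two at a time, across the upper 2 bits of those three bytes.

#include <cassert>
#include <cstdint>
#include <cstdio>

// Pack 4 uint6 values (each assumed < 64) into 3 bytes, mirroring the layout
// described in the pack_4_uint6_values_v2 comments in the diff below.
inline void pack_4_uint6_values_sketch(uint8_t* packed, const uint8_t* unpacked) {
  packed[0] = unpacked[0] | ((unpacked[3] & 0b00'0011u) << 6);
  packed[1] = unpacked[1] | ((unpacked[3] & 0b00'1100u) << 4);
  packed[2] = unpacked[2] | ((unpacked[3] & 0b11'0000u) << 2);
}

// Invert the packing: the low 6 bits of each byte give the first three values;
// the upper 2 bits of the three bytes reassemble the fourth value.
inline void unpack_4_uint6_values_sketch(uint8_t* unpacked, const uint8_t* packed) {
  unpacked[0] = packed[0] & 0b11'1111u;
  unpacked[1] = packed[1] & 0b11'1111u;
  unpacked[2] = packed[2] & 0b11'1111u;
  unpacked[3] = ((packed[0] & 0b1100'0000u) >> 6) |
      ((packed[1] & 0b1100'0000u) >> 4) | ((packed[2] & 0b1100'0000u) >> 2);
}

int main() {
  const uint8_t in[4] = {0b01'1010, 0b00'0001, 0b11'1111, 0b10'1100};
  uint8_t packed[3];
  uint8_t out[4];
  pack_4_uint6_values_sketch(packed, in);
  unpack_4_uint6_values_sketch(out, packed);
  for (int i = 0; i < 4; ++i) {
    assert(out[i] == in[i]);
  }
  // For these example inputs the packed bytes come out as 0x1a 0xc1 0xbf.
  printf("%02x %02x %02x\n", packed[0], packed[1], packed[2]);
  return 0;
}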

File tree

2 files changed: 28 additions & 313 deletions


torchao/experimental/kernels/cpu/aarch64/bitpacking/uint6.h

Lines changed: 28 additions & 240 deletions
@@ -22,221 +22,6 @@ namespace internal {
 TORCHAO_ALWAYS_INLINE inline void pack_4_uint6_values(
     uint8_t* packed,
     const uint8_t* unpacked) {
-  // Given 4 unpacked uint6 values: 01abcd, 23efgh, 45ijkl, 67mnop
-  // this function packs them as:
-  // b54: 67|45|23|01 (to hold upper 2 bits on all values)
-  // b3210_0: efgh|abcd (lower 4 bits for first 2 values)
-  // b3210_1: mnop|ijkl (lower 4 bits for last 2 values)
-
-  // These are stored in packed as: b54, b3210_0, b3210_1
-  //
-  // Input is 4 bytes
-  // Output is 6 * 4 bits / 8 = 3 bytes
-
-  // b54
-  packed[0] = ((unpacked[0] & 48) >> 4) | ((unpacked[1] & 48) >> 2) |
-      ((unpacked[2] & 48)) | ((unpacked[3] & 48) << 2);
-
-  // b3210_0
-  packed[1] = (unpacked[0] & 15) | ((unpacked[1] & 15) << 4);
-
-  // b3210_1
-  packed[2] = (unpacked[2] & 15) | ((unpacked[3] & 15) << 4);
-}
-
-TORCHAO_ALWAYS_INLINE inline void unpack_4_uint6_values(
-    uint8_t* unpacked,
-    const uint8_t* packed) {
-  // Unpacks data packed by pack_4_uint6_values
-  //
-  // Input is 24 bits = 3 bytes
-  // Output is 4 bytes
-
-  uint8_t b54 = packed[0];
-  uint8_t b3210_0 = packed[1];
-  uint8_t b3210_1 = packed[2];
-
-  unpacked[0] = ((b54 & 3) << 4) | (b3210_0 & 15);
-  unpacked[1] = ((b54 & 12) << 2) | (b3210_0 >> 4);
-
-  unpacked[2] = (b54 & 48) | (b3210_1 & 15);
-  unpacked[3] = ((b54 & 192) >> 2) | (b3210_1 >> 4);
-}
-
-TORCHAO_ALWAYS_INLINE inline void vec_pack_32_uint6_values(
-    uint8_t* packed,
-    const uint8x16_t& unpacked0,
-    const uint8x16_t& unpacked1) {
-  // This function is a vectorized version of pack_8_uint6_values
-  // To understand it, please see pack_8_uint6_values first.
-  // Before each code section, there is a comment indicating the
-  // code in pack_8_uint6_values that is being vectorized
-  //
-  // Input is 32 bytes
-  // Output is 6*32 = 192 bits = 24 bytes
-
-  uint8x8_t b54;
-  uint8x8_t mask;
-
-  // // b54
-  // packed[0] = ((unpacked[0] & 48) >> 4) | ((unpacked[1] & 48) >> 2) |
-  //     ((unpacked[2] & 48)) | ((unpacked[3] & 48) << 2);
-  mask = vdup_n_u8(48);
-  b54 = vshr_n_u8(vand_u8(vget_low_u8(unpacked0), mask), 4);
-  b54 = vorr_u8(b54, vshr_n_u8(vand_u8(vget_high_u8(unpacked0), mask), 2));
-
-  b54 = vorr_u8(b54, vand_u8(vget_low_u8(unpacked1), mask));
-  b54 = vorr_u8(b54, vshl_n_u8(vand_u8(vget_high_u8(unpacked1), mask), 2));
-
-  vst1_u8(packed, b54);
-
-  mask = vdup_n_u8(15);
-  uint8x8_t b3210;
-
-  // b3210_0
-  // packed[1] = (unpacked[0] & 15) | ((unpacked[1] & 15) << 4);
-  b3210 = vand_u8(vget_low_u8(unpacked0), mask);
-  b3210 = vorr_u8(b3210, vshl_n_u8(vand_u8(vget_high_u8(unpacked0), mask), 4));
-  vst1_u8(packed + 8, b3210);
-
-  // b3210_1
-  // packed[2] = (unpacked[2] & 15) | ((unpacked[3] & 15) << 4);
-  b3210 = vand_u8(vget_low_u8(unpacked1), mask);
-  b3210 = vorr_u8(b3210, vshl_n_u8(vand_u8(vget_high_u8(unpacked1), mask), 4));
-  vst1_u8(packed + 16, b3210);
-}
-
-TORCHAO_ALWAYS_INLINE inline void vec_unpack_32_uint6_values(
-    uint8x16_t& unpacked0,
-    uint8x16_t& unpacked1,
-    const uint8_t* packed) {
-  // Unpacks data packed by pack_32_uint6_values
-  //
-  // This function vectorizes vec_unpack_4_uint6_values
-  // To understand it, please see vec_unpack_4_uint6_values first.
-  // Before each code section, there is a comment indicating the
-  // code in vec_unpack_4_uint6_values that is being vectorized
-
-  // Input is 24 bytes
-  // Output is 32 bytes
-
-  uint8x8_t b54 = vld1_u8(packed);
-  uint8x8_t b3210;
-  uint8x8_t unpacked_tmp0;
-  uint8x8_t unpacked_tmp1;
-
-  // unpacked[0] = ((b54 & 3) << 4) | (b3210_0 & 15);
-  // unpacked[1] = ((b54 & 12) << 2) | (b3210_0 >> 4);
-  b3210 = vld1_u8(packed + 8);
-
-  unpacked_tmp0 = vshl_n_u8(vand_u8(b54, vdup_n_u8(3)), 4);
-  unpacked_tmp0 = vorr_u8(unpacked_tmp0, vand_u8(b3210, vdup_n_u8(15)));
-
-  unpacked_tmp1 = vshl_n_u8(vand_u8(b54, vdup_n_u8(12)), 2);
-  unpacked_tmp1 = vorr_u8(unpacked_tmp1, vshr_n_u8(b3210, 4));
-
-  unpacked0 = vcombine_u8(unpacked_tmp0, unpacked_tmp1);
-
-  // unpacked[2] = (b54 & 48) | (b3210_1 & 15);
-  // unpacked[3] = ((b54 & 192) >> 2) | (b3210_1 >> 4);
-  b3210 = vld1_u8(packed + 16);
-
-  unpacked_tmp0 = vand_u8(b54, vdup_n_u8(48));
-  unpacked_tmp0 = vorr_u8(unpacked_tmp0, vand_u8(b3210, vdup_n_u8(15)));
-
-  unpacked_tmp1 = vshr_n_u8(vand_u8(b54, vdup_n_u8(192)), 2);
-  unpacked_tmp1 = vorr_u8(unpacked_tmp1, vshr_n_u8(b3210, 4));
-
-  unpacked1 = vcombine_u8(unpacked_tmp0, unpacked_tmp1);
-}
-
-TORCHAO_ALWAYS_INLINE inline void vec_pack_64_uint6_values(
-    uint8_t* packed,
-    const uint8x16_t& unpacked0,
-    const uint8x16_t& unpacked1,
-    const uint8x16_t& unpacked2,
-    const uint8x16_t& unpacked3) {
-  // This function is a vectorized version of pack_4_uint6_values
-  // To understand it, please see pack_4_uint6_values first.
-  // Before each code section, there is a comment indicating the
-  // code in pack_4_uint6_values that is being vectorized
-  //
-  // Input is 64 bytes
-  // Output is 6*64 = 384 bits = 48 bytes
-
-  uint8x16_t b54;
-  uint8x16_t mask;
-
-  // b54
-  // packed[0] = ((unpacked[0] & 48) >> 4) | ((unpacked[1] & 48) >> 2) |
-  //     ((unpacked[2] & 48)) | ((unpacked[3] & 48) << 2);
-  mask = vdupq_n_u8(48);
-  b54 = vshrq_n_u8(vandq_u8(unpacked0, mask), 4);
-  b54 = vorrq_u8(b54, vshrq_n_u8(vandq_u8(unpacked1, mask), 2));
-  b54 = vorrq_u8(b54, vandq_u8(unpacked2, mask));
-  b54 = vorrq_u8(b54, vshlq_n_u8(vandq_u8(unpacked3, mask), 2));
-
-  vst1q_u8(packed, b54);
-
-  mask = vdupq_n_u8(15);
-  uint8x16_t b3210;
-
-  // b3210_0
-  // packed[1] = (unpacked[0] & 15) | ((unpacked[1] & 15) << 4);
-  b3210 = vandq_u8(unpacked0, mask);
-  b3210 = vorrq_u8(b3210, vshlq_n_u8(vandq_u8(unpacked1, mask), 4));
-  vst1q_u8(packed + 16, b3210);
-
-  // b3210_1
-  // packed[2] = (unpacked[2] & 15) | ((unpacked[3] & 15) << 4);
-  b3210 = vandq_u8(unpacked2, mask);
-  b3210 = vorrq_u8(b3210, vshlq_n_u8(vandq_u8(unpacked3, mask), 4));
-  vst1q_u8(packed + 32, b3210);
-}
-
-TORCHAO_ALWAYS_INLINE inline void vec_unpack_64_uint6_values(
-    uint8x16_t& unpacked0,
-    uint8x16_t& unpacked1,
-    uint8x16_t& unpacked2,
-    uint8x16_t& unpacked3,
-    const uint8_t* packed) {
-  // Unpacks data packed by pack_64_uint6_values
-  //
-  // This function vectorizes vec_unpack_4_uint6_values
-  // To understand it, please see vec_unpack_4_uint6_values first.
-  // Before each code section, there is a comment indicating the
-  // code in vec_unpack_4_uint6_values that is being vectorized
-
-  // Input is 48 bytes
-  // Output is 64 bytes
-
-  uint8x16_t b54 = vld1q_u8(packed);
-  uint8x16_t b3210;
-
-  // unpacked[0] = ((b54 & 3) << 4) | (b3210_0 & 15);
-  // unpacked[1] = ((b54 & 12) << 2) | (b3210_0 >> 4);
-  b3210 = vld1q_u8(packed + 16);
-
-  unpacked0 = vshlq_n_u8(vandq_u8(b54, vdupq_n_u8(3)), 4);
-  unpacked0 = vorrq_u8(unpacked0, vandq_u8(b3210, vdupq_n_u8(15)));
-
-  unpacked1 = vshlq_n_u8(vandq_u8(b54, vdupq_n_u8(12)), 2);
-  unpacked1 = vorrq_u8(unpacked1, vshrq_n_u8(b3210, 4));
-
-  // unpacked[2] = (b54 & 48) | (b3210_1 & 15);
-  // unpacked[3] = ((b54 & 192) >> 2) | (b3210_1 >> 4);
-  b3210 = vld1q_u8(packed + 32);
-
-  unpacked2 = vandq_u8(b54, vdupq_n_u8(48));
-  unpacked2 = vorrq_u8(unpacked2, vandq_u8(b3210, vdupq_n_u8(15)));
-
-  unpacked3 = vshrq_n_u8(vandq_u8(b54, vdupq_n_u8(192)), 2);
-  unpacked3 = vorrq_u8(unpacked3, vshrq_n_u8(b3210, 4));
-}
-
-TORCHAO_ALWAYS_INLINE inline void pack_4_uint6_values_v2(
-    uint8_t* packed,
-    const uint8_t* unpacked) {
   // Given 4 unpacked uint6 values: abcdef, ghijkl, mnopqr, 123456
   // this function packs them as:
   // packed[0]: 56 | abcdef
@@ -254,9 +39,9 @@ TORCHAO_ALWAYS_INLINE inline void pack_4_uint6_values_v2(
   packed[2] |= ((unpacked[3] & 0b11'0000u) << 2);
 }
 
-TORCHAO_ALWAYS_INLINE inline void unpack_4_uint6_values_v2(
-    uint8_t* unpacked,
-    const uint8_t* packed) {
+TORCHAO_ALWAYS_INLINE inline void unpack_4_uint6_values(
+    uint8_t* unpacked,
+    const uint8_t* packed) {
   // Unpacks data packed by pack_4_uint6_values_v2
   //
   // Input is 24 bits = 3 bytes
@@ -266,17 +51,17 @@ TORCHAO_ALWAYS_INLINE inline void unpack_4_uint6_values_v2(
   unpacked[2] = packed[2] & 0b111111u;
   // Last value is packed in the upper 2 bits of the three bytes
   unpacked[3] = ((packed[0] & 0b1100'0000u) >> 6) |
-      ((packed[1] & 0b1100'0000u) >> 4) |
-      ((packed[2] & 0b1100'0000u) >> 2);
+      ((packed[1] & 0b1100'0000u) >> 4) | ((packed[2] & 0b1100'0000u) >> 2);
 }
 
-TORCHAO_ALWAYS_INLINE inline void vec_pack_32_uint6_values_v2(
-    uint8_t* packed,
-    const uint8x16_t& unpacked0,
-    const uint8x16_t& unpacked1) {
+TORCHAO_ALWAYS_INLINE inline void vec_pack_32_uint6_values(
+    uint8_t* packed,
+    const uint8x16_t& unpacked0,
+    const uint8x16_t& unpacked1) {
   // This function is a vectorized version of pack_4_uint6_values_v2.
-  // To understand the following code, please see pack_4_uint6_values_v2 first and
-  // consider the following mapping for the unpacked parameter of that function:
+  // To understand the following code, please see pack_4_uint6_values_v2 first
+  // and consider the following mapping for the unpacked parameter of that
+  // function:
   //
   // unpacked[0] -> vget_low_u8(unpacked0)
   // unpacked[1] -> vget_high_u8(unpacked0)
@@ -293,23 +78,26 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_32_uint6_values_v2(
   // packed[0] = unpacked[0]
   // packed[0] |= ((unpacked[3] & 0b00'0011u) << 6)
   r = vget_low_u8(unpacked0);
-  r = vorr_u8(r, vshl_n_u8(vand_u8(vget_high_u8(unpacked1), vdup_n_u8(0b00'0011u)), 6));
+  r = vorr_u8(
+      r, vshl_n_u8(vand_u8(vget_high_u8(unpacked1), vdup_n_u8(0b00'0011u)), 6));
   vst1_u8(packed, r);
 
   // packed[1] = unpacked[1]
   // packed[1] |= ((unpacked[3] & 0b00'1100u) << 4)
   r = vget_high_u8(unpacked0);
-  r = vorr_u8(r, vshl_n_u8(vand_u8(vget_high_u8(unpacked1), vdup_n_u8(0b00'1100u)), 4));
+  r = vorr_u8(
+      r, vshl_n_u8(vand_u8(vget_high_u8(unpacked1), vdup_n_u8(0b00'1100u)), 4));
   vst1_u8(packed + 8, r);
 
   // packed[2] = unpacked[2]
   // packed[2] |= ((unpacked[3] & 0b11'0000u) << 2)
   r = vget_low_u8(unpacked1);
-  r = vorr_u8(r, vshl_n_u8(vand_u8(vget_high_u8(unpacked1), vdup_n_u8(0b11'0000u)), 2));
+  r = vorr_u8(
+      r, vshl_n_u8(vand_u8(vget_high_u8(unpacked1), vdup_n_u8(0b11'0000u)), 2));
   vst1_u8(packed + 16, r);
 }
 
-TORCHAO_ALWAYS_INLINE inline void vec_unpack_32_uint6_values_v2(
+TORCHAO_ALWAYS_INLINE inline void vec_unpack_32_uint6_values(
     uint8x16_t& unpacked0,
     uint8x16_t& unpacked1,
     const uint8_t* packed) {
@@ -331,18 +119,18 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_32_uint6_values_v2(
   //     ((packed[2] & 0b1100'0000u) >> 2);
   const uint8x8_t high = vdup_n_u8(0b1100'0000u);
   uint8x8_t unpacked3;
-  unpacked3 = vorr_u8(vshr_n_u8(vand_u8(packed0, high), 6),
-      vshr_n_u8(vand_u8(packed1, high), 4));
-  unpacked3 = vorr_u8(unpacked3,
-      vshr_n_u8(vand_u8(packed2, high), 2));
+  unpacked3 = vorr_u8(
+      vshr_n_u8(vand_u8(packed0, high), 6),
+      vshr_n_u8(vand_u8(packed1, high), 4));
+  unpacked3 = vorr_u8(unpacked3, vshr_n_u8(vand_u8(packed2, high), 2));
 
   // unpacked[i] = packed[i] & 0b11'1111u;
   const uint8x8_t mask = vdup_n_u8(0b11'1111u);
   unpacked0 = vcombine_u8(vand_u8(packed0, mask), vand_u8(packed1, mask));
   unpacked1 = vcombine_u8(vand_u8(packed2, mask), unpacked3);
 }
 
-TORCHAO_ALWAYS_INLINE inline void vec_pack_64_uint6_values_v2(
+TORCHAO_ALWAYS_INLINE inline void vec_pack_64_uint6_values(
     uint8_t* packed,
     const uint8x16_t& unpacked0,
     const uint8x16_t& unpacked1,
@@ -376,7 +164,7 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_64_uint6_values_v2(
   vst1q_u8(packed + 32, r);
 }
 
-TORCHAO_ALWAYS_INLINE inline void vec_unpack_64_uint6_values_v2(
+TORCHAO_ALWAYS_INLINE inline void vec_unpack_64_uint6_values(
     uint8x16_t& unpacked0,
     uint8x16_t& unpacked1,
     uint8x16_t& unpacked2,
@@ -399,10 +187,10 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_64_uint6_values_v2(
   //     ((packed[1] & 0b1100'0000u) >> 4) |
   //     ((packed[2] & 0b1100'0000u) >> 2);
   const uint8x16_t high = vdupq_n_u8(0b1100'0000u);
-  unpacked3 = vorrq_u8(vshrq_n_u8(vandq_u8(unpacked0, high), 6),
-      vshrq_n_u8(vandq_u8(unpacked1, high), 4));
-  unpacked3 = vorrq_u8(unpacked3,
-      vshrq_n_u8(vandq_u8(unpacked2, high), 2));
+  unpacked3 = vorrq_u8(
+      vshrq_n_u8(vandq_u8(unpacked0, high), 6),
+      vshrq_n_u8(vandq_u8(unpacked1, high), 4));
+  unpacked3 = vorrq_u8(unpacked3, vshrq_n_u8(vandq_u8(unpacked2, high), 2));
 
   // unpacked[i] = packed[i] & 0b11'1111u;
   const uint8x16_t mask = vdupq_n_u8(0b11'1111u);
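
For reference, a hedged usage sketch of the renamed 32-value NEON entry points follows; it is not part of this commit. It assumes an aarch64 build, that the header path shown above is on the include path, and that the caller sits inside (or has a using-directive for) the namespace enclosing these kernels; only the inner "namespace internal {" is visible in the first hunk, so any further qualification is an assumption. The lane layout mirrors the mapping comment in the diff: each uint8x16_t holds 16 values, and the low/high halves of the two registers play the roles of the scalar function's unpacked[0..3] groups.

#include <arm_neon.h>
#include <cassert>
#include <cstdint>

#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/uint6.h>

void roundtrip_32_uint6_values_example() {
  // 32 arbitrary 6-bit test values.
  uint8_t unpacked_in[32];
  for (int i = 0; i < 32; ++i) {
    unpacked_in[i] = (i * 7) & 0b11'1111u;
  }

  // Two 16-lane registers cover the 32 values; their low/high halves are the
  // four 8-lane groups the pack routine treats as unpacked[0..3].
  uint8x16_t in0 = vld1q_u8(unpacked_in);
  uint8x16_t in1 = vld1q_u8(unpacked_in + 16);

  uint8_t packed[24];  // 32 values * 6 bits = 192 bits = 24 bytes
  // Assumes the enclosing torchao bitpacking namespace is in scope (see note above).
  vec_pack_32_uint6_values(packed, in0, in1);

  uint8x16_t out0;
  uint8x16_t out1;
  vec_unpack_32_uint6_values(out0, out1, packed);

  uint8_t unpacked_out[32];
  vst1q_u8(unpacked_out, out0);
  vst1q_u8(unpacked_out + 16, out1);

  // Pack followed by unpack should restore every value in its original position.
  for (int i = 0; i < 32; ++i) {
    assert(unpacked_out[i] == unpacked_in[i]);
  }
}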
