Skip to content

Commit 893cafe

Browse files
authored
Experimental 6-bit quantization for Llama in torchchat
Differential Revision: D64437228 Pull Request resolved: #1094
1 parent ce4822b commit 893cafe

File tree

2 files changed

+250
-0
lines changed

2 files changed

+250
-0
lines changed

torchao/experimental/kernels/cpu/aarch64/bitpacking/uint6.h

Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,183 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_64_uint6_values(
234234
unpacked3 = vorrq_u8(unpacked3, vshrq_n_u8(b3210, 4));
235235
}
236236

237+
TORCHAO_ALWAYS_INLINE inline void pack_4_uint6_values_v2(
238+
uint8_t* packed,
239+
const uint8_t* unpacked) {
240+
// Given 4 unpacked uint6 values: abcdef, ghijkl, mnopqr, 123456
241+
// this function packs them as:
242+
// packed[0]: 56 | abcdef
243+
// packed[1]: 34 | ghijkl
244+
// packed[2]: 12 | mnopqr
245+
//
246+
// Input is 4 bytes
247+
// Output is 6 * 4 bits/8 = 3 bytes
248+
packed[0] = unpacked[0];
249+
packed[1] = unpacked[1];
250+
packed[2] = unpacked[2];
251+
// Last value is packed in the upper 2 bits of the three bytes
252+
packed[0] |= ((unpacked[3] & 0b00'0011u) << 6);
253+
packed[1] |= ((unpacked[3] & 0b00'1100u) << 4);
254+
packed[2] |= ((unpacked[3] & 0b11'0000u) << 2);
255+
}
256+
257+
TORCHAO_ALWAYS_INLINE inline void unpack_4_uint6_values_v2(
258+
uint8_t* unpacked,
259+
const uint8_t* packed) {
260+
// Unpacks data packed by pack_4_uint6_values_v2
261+
//
262+
// Input is 24 bits = 3 bytes
263+
// Output is 4 bytes
264+
unpacked[0] = packed[0] & 0b111111u;
265+
unpacked[1] = packed[1] & 0b111111u;
266+
unpacked[2] = packed[2] & 0b111111u;
267+
// Last value is packed in the upper 2 bits of the three bytes
268+
unpacked[3] = ((packed[0] & 0b1100'0000u) >> 6) |
269+
((packed[1] & 0b1100'0000u) >> 4) |
270+
((packed[2] & 0b1100'0000u) >> 2);
271+
}
272+
273+
TORCHAO_ALWAYS_INLINE inline void vec_pack_32_uint6_values_v2(
274+
uint8_t* packed,
275+
const uint8x16_t& unpacked0,
276+
const uint8x16_t& unpacked1) {
277+
// This function is a vectorized version of pack_4_uint6_values_v2.
278+
// To understand the following code, please see pack_4_uint6_values_v2 first and
279+
// consider the following mapping for the unpacked parameter of that function:
280+
//
281+
// unpacked[0] -> vget_low_u8(unpacked0)
282+
// unpacked[1] -> vget_high_u8(unpacked0)
283+
// unpacked[2] -> vget_low_u8(unpacked1)
284+
// unpacked[3] -> vget_high_u8(unpacked1)
285+
//
286+
// Before each code section, there is a comment indicating the
287+
// code in pack_4_uint6_values_v2 that is being vectorized.
288+
//
289+
// Input is 32 bytes.
290+
// Output is 6*32= 192 bits = 24 bytes.
291+
uint8x8_t r;
292+
293+
// packed[0] = unpacked[0]
294+
// packed[0] |= ((unpacked[3] & 0b00'0011u) << 6)
295+
r = vget_low_u8(unpacked0);
296+
r = vorr_u8(r, vshl_n_u8(vand_u8(vget_high_u8(unpacked1), vdup_n_u8(0b00'0011u)), 6));
297+
vst1_u8(packed, r);
298+
299+
// packed[1] = unpacked[1]
300+
// packed[1] |= ((unpacked[3] & 0b00'1100u) << 4)
301+
r = vget_high_u8(unpacked0);
302+
r = vorr_u8(r, vshl_n_u8(vand_u8(vget_high_u8(unpacked1), vdup_n_u8(0b00'1100u)), 4));
303+
vst1_u8(packed + 8, r);
304+
305+
// packed[2] = unpacked[2]
306+
// packed[2] |= ((unpacked[3] & 0b11'0000u) << 2)
307+
r = vget_low_u8(unpacked1);
308+
r = vorr_u8(r, vshl_n_u8(vand_u8(vget_high_u8(unpacked1), vdup_n_u8(0b11'0000u)), 2));
309+
vst1_u8(packed + 16, r);
310+
}
311+
312+
TORCHAO_ALWAYS_INLINE inline void vec_unpack_32_uint6_values_v2(
313+
uint8x16_t& unpacked0,
314+
uint8x16_t& unpacked1,
315+
const uint8_t* packed) {
316+
// Unpacks data packed by vec_pack_32_uint6_values_v2.
317+
//
318+
// This function vectorizes unpack_4_uint6_values_v2.
319+
// To understand it, please see unpack_4_uint6_values_v2 first.
320+
// Before each code section, there is a comment indicating the
321+
// code in unpack_4_uint6_values_v2 that is being vectorized.
322+
//
323+
// Input is 24 bytes.
324+
// Output is 32 bytes.
325+
uint8x8_t packed0 = vld1_u8(packed);
326+
uint8x8_t packed1 = vld1_u8(packed + 8);
327+
uint8x8_t packed2 = vld1_u8(packed + 16);
328+
329+
// unpacked[3] = ((packed[0] & 0b1100'0000u) >> 6) |
330+
// ((packed[1] & 0b1100'0000u) >> 4) |
331+
// ((packed[2] & 0b1100'0000u) >> 2);
332+
const uint8x8_t high = vdup_n_u8(0b1100'0000u);
333+
uint8x8_t unpacked3;
334+
unpacked3 = vorr_u8(vshr_n_u8(vand_u8(packed0, high), 6),
335+
vshr_n_u8(vand_u8(packed1, high), 4));
336+
unpacked3 = vorr_u8(unpacked3,
337+
vshr_n_u8(vand_u8(packed2, high), 2));
338+
339+
// unpacked[i] = packed[i] & 0b11'1111u;
340+
const uint8x8_t mask = vdup_n_u8(0b11'1111u);
341+
unpacked0 = vcombine_u8(vand_u8(packed0, mask), vand_u8(packed1, mask));
342+
unpacked1 = vcombine_u8(vand_u8(packed2, mask), unpacked3);
343+
}
344+
345+
TORCHAO_ALWAYS_INLINE inline void vec_pack_64_uint6_values_v2(
346+
uint8_t* packed,
347+
const uint8x16_t& unpacked0,
348+
const uint8x16_t& unpacked1,
349+
const uint8x16_t& unpacked2,
350+
const uint8x16_t& unpacked3) {
351+
// This function is a vectorized version of pack_4_uint6_values_v2.
352+
// To understand the following code, please see pack_4_uint6_values_v2 first.
353+
// Before each code section, there is a comment indicating the
354+
// code in pack_4_uint6_values_v2 that is being vectorized.
355+
//
356+
// Input is 48 bytes.
357+
// Output is 64 bytes.
358+
uint8x16_t r;
359+
360+
// packed[0] = unpacked[0]
361+
// packed[0] |= ((unpacked[3] & 0b00'0011u) << 6)
362+
r = unpacked0;
363+
r = vorrq_u8(r, vshlq_n_u8(vandq_u8(unpacked3, vdupq_n_u8(0b00'0011u)), 6));
364+
vst1q_u8(packed, r);
365+
366+
// packed[1] = unpacked[1]
367+
// packed[1] |= ((unpacked[3] & 0b00'1100u) << 4)
368+
r = unpacked1;
369+
r = vorrq_u8(r, vshlq_n_u8(vandq_u8(unpacked3, vdupq_n_u8(0b00'1100u)), 4));
370+
vst1q_u8(packed + 16, r);
371+
372+
// packed[2] = unpacked[2]
373+
// packed[2] |= ((unpacked[3] & 0b11'0000u) << 2)
374+
r = unpacked2;
375+
r = vorrq_u8(r, vshlq_n_u8(vandq_u8(unpacked3, vdupq_n_u8(0b11'0000u)), 2));
376+
vst1q_u8(packed + 32, r);
377+
}
378+
379+
TORCHAO_ALWAYS_INLINE inline void vec_unpack_64_uint6_values_v2(
380+
uint8x16_t& unpacked0,
381+
uint8x16_t& unpacked1,
382+
uint8x16_t& unpacked2,
383+
uint8x16_t& unpacked3,
384+
const uint8_t* packed) {
385+
// Unpacks data packed by vec_pack_64_uint6_values_v2.
386+
//
387+
// This function vectorizes unpack_4_uint6_values_v2.
388+
// To understand it, please see unpack_4_uint6_values_v2 first.
389+
// Before each code section, there is a comment indicating the
390+
// code in unpack_4_uint6_values that is being vectorized
391+
392+
// Input is 48 bytes.
393+
// Output is 64 bytes.
394+
unpacked0 = vld1q_u8(packed);
395+
unpacked1 = vld1q_u8(packed + 16);
396+
unpacked2 = vld1q_u8(packed + 32);
397+
398+
// unpacked[3] = ((packed[0] & 0b1100'0000u) >> 6) |
399+
// ((packed[1] & 0b1100'0000u) >> 4) |
400+
// ((packed[2] & 0b1100'0000u) >> 2);
401+
const uint8x16_t high = vdupq_n_u8(0b1100'0000u);
402+
unpacked3 = vorrq_u8(vshrq_n_u8(vandq_u8(unpacked0, high), 6),
403+
vshrq_n_u8(vandq_u8(unpacked1, high), 4));
404+
unpacked3 = vorrq_u8(unpacked3,
405+
vshrq_n_u8(vandq_u8(unpacked2, high), 2));
406+
407+
// unpacked[i] = packed[i] & 0b11'1111u;
408+
const uint8x16_t mask = vdupq_n_u8(0b11'1111u);
409+
unpacked0 = vandq_u8(unpacked0, mask);
410+
unpacked1 = vandq_u8(unpacked1, mask);
411+
unpacked2 = vandq_u8(unpacked2, mask);
412+
}
413+
237414
} // namespace internal
238415
} // namespace bitpacking
239416
} // namespace torchao

torchao/experimental/kernels/cpu/aarch64/tests/test_bitpacking.cpp

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -504,6 +504,23 @@ TEST(test_bitpacking_4_uint6_values, PackUnpackAreSame) {
504504
}
505505
}
506506

507+
TEST(test_bitpacking_4_uint6_values_v2, PackUnpackAreSame) {
508+
int unpacked_bytes = 4;
509+
int packed_bytes = 3;
510+
auto input = torchao::get_random_lowbit_vector(unpacked_bytes, 6);
511+
std::vector<uint8_t> packed(packed_bytes, 0);
512+
std::vector<uint8_t> unpacked(unpacked_bytes, 0);
513+
514+
torchao::bitpacking::internal::pack_4_uint6_values_v2(
515+
packed.data(), input.data());
516+
torchao::bitpacking::internal::unpack_4_uint6_values_v2(
517+
unpacked.data(), packed.data());
518+
for (int i = 0; i < unpacked_bytes; ++i) {
519+
EXPECT_EQ(input[i], unpacked[i]);
520+
}
521+
}
522+
523+
507524
TEST(test_bitpacking_32_uint6_values, PackUnpackAreSame) {
508525
int unpacked_bytes = 32;
509526
int packed_bytes = 24;
@@ -529,6 +546,31 @@ TEST(test_bitpacking_32_uint6_values, PackUnpackAreSame) {
529546
}
530547
}
531548

549+
TEST(test_bitpacking_32_uint6_values_v2, PackUnpackAreSame) {
550+
int unpacked_bytes = 32;
551+
int packed_bytes = 24;
552+
auto input = torchao::get_random_lowbit_vector(unpacked_bytes, 6);
553+
std::vector<uint8_t> packed(packed_bytes, 0);
554+
555+
uint8x16_t input0;
556+
uint8x16_t input1;
557+
558+
uint8x16_t unpacked0;
559+
uint8x16_t unpacked1;
560+
561+
input0 = vld1q_u8(input.data());
562+
input1 = vld1q_u8(input.data() + 16);
563+
torchao::bitpacking::internal::vec_pack_32_uint6_values_v2(
564+
packed.data(), input0, input1);
565+
torchao::bitpacking::internal::vec_unpack_32_uint6_values_v2(
566+
unpacked0, unpacked1, packed.data());
567+
568+
for (int i = 0; i < 16; ++i) {
569+
EXPECT_EQ(input0[i], unpacked0[i]);
570+
EXPECT_EQ(input1[i], unpacked1[i]);
571+
}
572+
}
573+
532574
TEST(test_bitpacking_64_uint6_values, PackUnpackAreSame) {
533575
int unpacked_bytes = 64;
534576
int packed_bytes = 48;
@@ -560,6 +602,37 @@ TEST(test_bitpacking_64_uint6_values, PackUnpackAreSame) {
560602
}
561603
}
562604

605+
TEST(test_bitpacking_64_uint6_values_v2, PackUnpackAreSame) {
606+
int unpacked_bytes = 64;
607+
int packed_bytes = 48;
608+
auto input = torchao::get_random_lowbit_vector(unpacked_bytes, 6);
609+
std::vector<uint8_t> packed(packed_bytes, 0);
610+
611+
uint8x16_t input0;
612+
uint8x16_t input1;
613+
uint8x16_t input2;
614+
uint8x16_t input3;
615+
616+
uint8x16_t unpacked0;
617+
uint8x16_t unpacked1;
618+
uint8x16_t unpacked2;
619+
uint8x16_t unpacked3;
620+
621+
torchao::bitpacking::internal::vec_load_64_uint8_values(
622+
input0, input1, input2, input3, input.data());
623+
torchao::bitpacking::internal::vec_pack_64_uint6_values_v2(
624+
packed.data(), input0, input1, input2, input3);
625+
torchao::bitpacking::internal::vec_unpack_64_uint6_values_v2(
626+
unpacked0, unpacked1, unpacked2, unpacked3, packed.data());
627+
628+
for (int i = 0; i < 16; ++i) {
629+
EXPECT_EQ(input0[i], unpacked0[i]);
630+
EXPECT_EQ(input1[i], unpacked1[i]);
631+
EXPECT_EQ(input2[i], unpacked2[i]);
632+
EXPECT_EQ(input3[i], unpacked3[i]);
633+
}
634+
}
635+
563636
// Universal bitpacking tests
564637
template <int nbit>
565638
void test_bitpacking_32_lowbit_values() {

0 commit comments

Comments
 (0)