Skip to content

Commit 3475aed

Browse files
authored
Use fewer instructions when unpacking uint6s.
Differential Revision: D64548639
Pull Request resolved: #1109
1 parent 6653b45 commit 3475aed

File tree

1 file changed

+26
-16
lines changed
  • torchao/experimental/kernels/cpu/aarch64/bitpacking

1 file changed

+26
-16
lines changed

torchao/experimental/kernels/cpu/aarch64/bitpacking/uint6.h

Lines changed: 26 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -114,15 +114,20 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_32_uint6_values(
114114
uint8x8_t packed1 = vld1_u8(packed + 8);
115115
uint8x8_t packed2 = vld1_u8(packed + 16);
116116

117-
// unpacked[3] = ((packed[0] & 0b1100'0000u) >> 6) |
118-
// ((packed[1] & 0b1100'0000u) >> 4) |
119-
// ((packed[2] & 0b1100'0000u) >> 2);
120-
const uint8x8_t high = vdup_n_u8(0b1100'0000u);
121117
uint8x8_t unpacked3;
122-
unpacked3 = vorr_u8(
123-
vshr_n_u8(vand_u8(packed0, high), 6),
124-
vshr_n_u8(vand_u8(packed1, high), 4));
125-
unpacked3 = vorr_u8(unpacked3, vshr_n_u8(vand_u8(packed2, high), 2));
118+
// We want to gather the bits labelled 1-6 below (the top two bits of each packed byte) into the low six bits of unpacked3.
119+
// Packed structure is:
120+
//
121+
// packed0: 56 | abcdef
122+
// packed1: 34 | ghijkl
123+
// packed2: 12 | mnopqr
124+
//
125+
// unpacked3 = 1234 ghij
126+
unpacked3 = vsri_n_u8(packed2, packed1, 2);
127+
// unpacked3 = 1234 56ab
128+
unpacked3 = vsri_n_u8(unpacked3, packed0, 4);
129+
// unpacked3 = 0012 3456
130+
unpacked3 = vshr_n_u8(unpacked3, 2);
126131

127132
// unpacked[i] = packed[i] & 0b11'1111u;
128133
const uint8x8_t mask = vdup_n_u8(0b11'1111u);
@@ -183,14 +188,19 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_64_uint6_values(
183188
unpacked1 = vld1q_u8(packed + 16);
184189
unpacked2 = vld1q_u8(packed + 32);
185190

186-
// unpacked[3] = ((packed[0] & 0b1100'0000u) >> 6) |
187-
// ((packed[1] & 0b1100'0000u) >> 4) |
188-
// ((packed[2] & 0b1100'0000u) >> 2);
189-
const uint8x16_t high = vdupq_n_u8(0b1100'0000u);
190-
unpacked3 = vorrq_u8(
191-
vshrq_n_u8(vandq_u8(unpacked0, high), 6),
192-
vshrq_n_u8(vandq_u8(unpacked1, high), 4));
193-
unpacked3 = vorrq_u8(unpacked3, vshrq_n_u8(vandq_u8(unpacked2, high), 2));
191+
// We want to gather the bits labelled 1-6 below (the top two bits of each packed byte) into the low six bits of unpacked3.
192+
// Packed structure is (packedN below = the raw bytes currently held in unpackedN):
193+
//
194+
// packed0: 56 | abcdef
195+
// packed1: 34 | ghijkl
196+
// packed2: 12 | mnopqr
197+
//
198+
// unpacked3 = 1234 ghij
199+
unpacked3 = vsriq_n_u8(unpacked2, unpacked1, 2);
200+
// unpacked3 = 1234 56ab
201+
unpacked3 = vsriq_n_u8(unpacked3, unpacked0, 4);
202+
// unpacked3 = 0012 3456
203+
unpacked3 = vshrq_n_u8(unpacked3, 2);
194204

195205
// unpacked[i] = packed[i] & 0b11'1111u;
196206
const uint8x16_t mask = vdupq_n_u8(0b11'1111u);

0 commit comments

Comments
 (0)