@@ -114,15 +114,20 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_32_uint6_values(
114
114
uint8x8_t packed1 = vld1_u8 (packed + 8 );
115
115
uint8x8_t packed2 = vld1_u8 (packed + 16 );
116
116
117
- // unpacked[3] = ((packed[0] & 0b1100'0000u) >> 6) |
118
- // ((packed[1] & 0b1100'0000u) >> 4) |
119
- // ((packed[2] & 0b1100'0000u) >> 2);
120
- const uint8x8_t high = vdup_n_u8 (0b1100'0000u );
121
117
uint8x8_t unpacked3;
122
- unpacked3 = vorr_u8 (
123
- vshr_n_u8 (vand_u8 (packed0, high), 6 ),
124
- vshr_n_u8 (vand_u8 (packed1, high), 4 ));
125
- unpacked3 = vorr_u8 (unpacked3, vshr_n_u8 (vand_u8 (packed2, high), 2 ));
118
+ // We want to extract bits 123456 and place them in unpacked3.
119
+ // Packed structure is:
120
+ //
121
+ // packed0: 56 | abcdef
122
+ // packed1: 34 | ghijkl
123
+ // packed2: 12 | mnopqr
124
+ //
125
+ // unpacked3 = 1234 ghij
126
+ unpacked3 = vsri_n_u8 (packed2, packed1, 2 );
127
+ // unpacked3 = 1234 56ab
128
+ unpacked3 = vsri_n_u8 (unpacked3, packed0, 4 );
129
+ // unpacked3 = 0012 3456
130
+ unpacked3 = vshr_n_u8 (unpacked3, 2 );
126
131
127
132
// unpacked[i] = packed[i] & 0b11'1111u;
128
133
const uint8x8_t mask = vdup_n_u8 (0b11'1111u );
@@ -183,14 +188,19 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_64_uint6_values(
183
188
unpacked1 = vld1q_u8 (packed + 16 );
184
189
unpacked2 = vld1q_u8 (packed + 32 );
185
190
186
- // unpacked[3] = ((packed[0] & 0b1100'0000u) >> 6) |
187
- // ((packed[1] & 0b1100'0000u) >> 4) |
188
- // ((packed[2] & 0b1100'0000u) >> 2);
189
- const uint8x16_t high = vdupq_n_u8 (0b1100'0000u );
190
- unpacked3 = vorrq_u8 (
191
- vshrq_n_u8 (vandq_u8 (unpacked0, high), 6 ),
192
- vshrq_n_u8 (vandq_u8 (unpacked1, high), 4 ));
193
- unpacked3 = vorrq_u8 (unpacked3, vshrq_n_u8 (vandq_u8 (unpacked2, high), 2 ));
191
+ // We want to extract bits 123456 and place them in unpacked3.
192
+ // Packed structure is:
193
+ //
194
+ // packed0: 56 | abcdef
195
+ // packed1: 34 | ghijkl
196
+ // packed2: 12 | mnopqr
197
+ //
198
+ // unpacked3 = 1234 ghij
199
+ unpacked3 = vsriq_n_u8 (unpacked2, unpacked1, 2 );
200
+ // unpacked3 = 1234 56ab
201
+ unpacked3 = vsriq_n_u8 (unpacked3, unpacked0, 4 );
202
+ // unpacked3 = 0012 3456
203
+ unpacked3 = vshrq_n_u8 (unpacked3, 2 );
194
204
195
205
// unpacked[i] = packed[i] & 0b11'1111u;
196
206
const uint8x16_t mask = vdupq_n_u8 (0b11'1111u );
0 commit comments