@@ -22,221 +22,6 @@ namespace internal {
22
22
TORCHAO_ALWAYS_INLINE inline void pack_4_uint6_values(
    uint8_t* packed,
    const uint8_t* unpacked) {
  // Given 4 unpacked uint6 values: 01abcd, 23efgh, 45ijkl, 67mnop
  // this function packs them as:
  //    b54:     67|45|23|01 (holds the upper 2 bits of all 4 values)
  //    b3210_0: efgh|abcd   (lower 4 bits of the first 2 values)
  //    b3210_1: mnop|ijkl   (lower 4 bits of the last 2 values)
  //
  // These are stored in packed as: b54, b3210_0, b3210_1
  //
  // Input is 4 bytes
  // Output is 4 * 6 bits = 24 bits = 3 bytes

  // Masks are written as binary literals for consistency with the *_v2
  // routines in this file (0b11'0000u == 48, 0b1111u == 15).

  // b54
  packed[0] = ((unpacked[0] & 0b11'0000u) >> 4) |
      ((unpacked[1] & 0b11'0000u) >> 2) | (unpacked[2] & 0b11'0000u) |
      ((unpacked[3] & 0b11'0000u) << 2);

  // b3210_0
  packed[1] = (unpacked[0] & 0b1111u) | ((unpacked[1] & 0b1111u) << 4);

  // b3210_1
  packed[2] = (unpacked[2] & 0b1111u) | ((unpacked[3] & 0b1111u) << 4);
}
46
-
47
- TORCHAO_ALWAYS_INLINE inline void unpack_4_uint6_values (
48
- uint8_t * unpacked,
49
- const uint8_t * packed) {
50
- // Unpacks data packed by pack_4_uint6_values
51
- //
52
- // Input is 24 bits = 3 bytes
53
- // Output is 4 bytes
54
-
55
- uint8_t b54 = packed[0 ];
56
- uint8_t b3210_0 = packed[1 ];
57
- uint8_t b3210_1 = packed[2 ];
58
-
59
- unpacked[0 ] = ((b54 & 3 ) << 4 ) | (b3210_0 & 15 );
60
- unpacked[1 ] = ((b54 & 12 ) << 2 ) | (b3210_0 >> 4 );
61
-
62
- unpacked[2 ] = (b54 & 48 ) | (b3210_1 & 15 );
63
- unpacked[3 ] = ((b54 & 192 ) >> 2 ) | (b3210_1 >> 4 );
64
- }
65
-
66
TORCHAO_ALWAYS_INLINE inline void vec_pack_32_uint6_values(
    uint8_t* packed,
    const uint8x16_t& unpacked0,
    const uint8x16_t& unpacked1) {
  // Vectorized version of pack_4_uint6_values: each 8-lane half below plays
  // the role of one scalar value in that function. To understand the bit
  // layout, please read pack_4_uint6_values first.
  //
  // Input is 32 bytes
  // Output is 32 * 6 bits = 192 bits = 24 bytes

  const uint8x8_t u0 = vget_low_u8(unpacked0);
  const uint8x8_t u1 = vget_high_u8(unpacked0);
  const uint8x8_t u2 = vget_low_u8(unpacked1);
  const uint8x8_t u3 = vget_high_u8(unpacked1);

  // b54
  // packed[0] = ((unpacked[0] & 48) >> 4) | ((unpacked[1] & 48) >> 2) |
  //     ((unpacked[2] & 48)) | ((unpacked[3] & 48) << 2);
  const uint8x8_t hi_mask = vdup_n_u8(48);
  uint8x8_t hi2 = vshr_n_u8(vand_u8(u0, hi_mask), 4);
  hi2 = vorr_u8(hi2, vshr_n_u8(vand_u8(u1, hi_mask), 2));
  hi2 = vorr_u8(hi2, vand_u8(u2, hi_mask));
  hi2 = vorr_u8(hi2, vshl_n_u8(vand_u8(u3, hi_mask), 2));
  vst1_u8(packed, hi2);

  const uint8x8_t lo_mask = vdup_n_u8(15);

  // b3210_0
  // packed[1] = (unpacked[0] & 15) | ((unpacked[1] & 15) << 4);
  uint8x8_t lo4 =
      vorr_u8(vand_u8(u0, lo_mask), vshl_n_u8(vand_u8(u1, lo_mask), 4));
  vst1_u8(packed + 8, lo4);

  // b3210_1
  // packed[2] = (unpacked[2] & 15) | ((unpacked[3] & 15) << 4);
  lo4 = vorr_u8(vand_u8(u2, lo_mask), vshl_n_u8(vand_u8(u3, lo_mask), 4));
  vst1_u8(packed + 16, lo4);
}
108
-
109
TORCHAO_ALWAYS_INLINE inline void vec_unpack_32_uint6_values(
    uint8x16_t& unpacked0,
    uint8x16_t& unpacked1,
    const uint8_t* packed) {
  // Unpacks data packed by vec_pack_32_uint6_values; this is the lane-wise
  // vectorization of unpack_4_uint6_values — see that function for the bit
  // layout being decoded.
  //
  // Input is 24 bytes
  // Output is 32 bytes

  const uint8x8_t b54 = vld1_u8(packed);
  const uint8x8_t lo_mask = vdup_n_u8(15);

  // unpacked[0] = ((b54 & 3) << 4) | (b3210_0 & 15);
  // unpacked[1] = ((b54 & 12) << 2) | (b3210_0 >> 4);
  uint8x8_t b3210 = vld1_u8(packed + 8);
  uint8x8_t out_lo = vorr_u8(
      vshl_n_u8(vand_u8(b54, vdup_n_u8(3)), 4), vand_u8(b3210, lo_mask));
  uint8x8_t out_hi = vorr_u8(
      vshl_n_u8(vand_u8(b54, vdup_n_u8(12)), 2), vshr_n_u8(b3210, 4));
  unpacked0 = vcombine_u8(out_lo, out_hi);

  // unpacked[2] = (b54 & 48) | (b3210_1 & 15);
  // unpacked[3] = ((b54 & 192) >> 2) | (b3210_1 >> 4);
  b3210 = vld1_u8(packed + 16);
  out_lo = vorr_u8(vand_u8(b54, vdup_n_u8(48)), vand_u8(b3210, lo_mask));
  out_hi = vorr_u8(
      vshr_n_u8(vand_u8(b54, vdup_n_u8(192)), 2), vshr_n_u8(b3210, 4));
  unpacked1 = vcombine_u8(out_lo, out_hi);
}
152
-
153
TORCHAO_ALWAYS_INLINE inline void vec_pack_64_uint6_values(
    uint8_t* packed,
    const uint8x16_t& unpacked0,
    const uint8x16_t& unpacked1,
    const uint8x16_t& unpacked2,
    const uint8x16_t& unpacked3) {
  // 16-lane-wide vectorization of pack_4_uint6_values: each unpackedN
  // argument plays the role of one scalar value in that function. See
  // pack_4_uint6_values for the bit layout.
  //
  // Input is 64 bytes
  // Output is 64 * 6 bits = 384 bits = 48 bytes

  // b54
  // packed[0] = ((unpacked[0] & 48) >> 4) | ((unpacked[1] & 48) >> 2) |
  //     ((unpacked[2] & 48)) | ((unpacked[3] & 48) << 2);
  const uint8x16_t hi_mask = vdupq_n_u8(48);
  uint8x16_t hi2 = vshrq_n_u8(vandq_u8(unpacked0, hi_mask), 4);
  hi2 = vorrq_u8(hi2, vshrq_n_u8(vandq_u8(unpacked1, hi_mask), 2));
  hi2 = vorrq_u8(hi2, vandq_u8(unpacked2, hi_mask));
  hi2 = vorrq_u8(hi2, vshlq_n_u8(vandq_u8(unpacked3, hi_mask), 2));
  vst1q_u8(packed, hi2);

  const uint8x16_t lo_mask = vdupq_n_u8(15);

  // b3210_0
  // packed[1] = (unpacked[0] & 15) | ((unpacked[1] & 15) << 4);
  uint8x16_t lo4 = vorrq_u8(
      vandq_u8(unpacked0, lo_mask), vshlq_n_u8(vandq_u8(unpacked1, lo_mask), 4));
  vst1q_u8(packed + 16, lo4);

  // b3210_1
  // packed[2] = (unpacked[2] & 15) | ((unpacked[3] & 15) << 4);
  lo4 = vorrq_u8(
      vandq_u8(unpacked2, lo_mask), vshlq_n_u8(vandq_u8(unpacked3, lo_mask), 4));
  vst1q_u8(packed + 32, lo4);
}
196
-
197
TORCHAO_ALWAYS_INLINE inline void vec_unpack_64_uint6_values(
    uint8x16_t& unpacked0,
    uint8x16_t& unpacked1,
    uint8x16_t& unpacked2,
    uint8x16_t& unpacked3,
    const uint8_t* packed) {
  // Unpacks data packed by vec_pack_64_uint6_values; 16-lane-wide
  // vectorization of unpack_4_uint6_values — see that function for the
  // bit layout being decoded.
  //
  // Input is 48 bytes
  // Output is 64 bytes

  const uint8x16_t b54 = vld1q_u8(packed);
  const uint8x16_t lo_mask = vdupq_n_u8(15);

  // unpacked[0] = ((b54 & 3) << 4) | (b3210_0 & 15);
  // unpacked[1] = ((b54 & 12) << 2) | (b3210_0 >> 4);
  uint8x16_t b3210 = vld1q_u8(packed + 16);
  unpacked0 = vorrq_u8(
      vshlq_n_u8(vandq_u8(b54, vdupq_n_u8(3)), 4), vandq_u8(b3210, lo_mask));
  unpacked1 = vorrq_u8(
      vshlq_n_u8(vandq_u8(b54, vdupq_n_u8(12)), 2), vshrq_n_u8(b3210, 4));

  // unpacked[2] = (b54 & 48) | (b3210_1 & 15);
  // unpacked[3] = ((b54 & 192) >> 2) | (b3210_1 >> 4);
  b3210 = vld1q_u8(packed + 32);
  unpacked2 = vorrq_u8(
      vandq_u8(b54, vdupq_n_u8(48)), vandq_u8(b3210, lo_mask));
  unpacked3 = vorrq_u8(
      vshrq_n_u8(vandq_u8(b54, vdupq_n_u8(192)), 2), vshrq_n_u8(b3210, 4));
}
236
-
237
- TORCHAO_ALWAYS_INLINE inline void pack_4_uint6_values_v2 (
238
- uint8_t * packed,
239
- const uint8_t * unpacked) {
240
25
// Given 4 unpacked uint6 values: abcdef, ghijkl, mnopqr, 123456
241
26
// this function packs them as:
242
27
// packed[0]: 56 | abcdef
@@ -254,9 +39,9 @@ TORCHAO_ALWAYS_INLINE inline void pack_4_uint6_values_v2(
254
39
packed[2 ] |= ((unpacked[3 ] & 0b11'0000u ) << 2 );
255
40
}
256
41
257
- TORCHAO_ALWAYS_INLINE inline void unpack_4_uint6_values_v2 (
258
- uint8_t * unpacked,
259
- const uint8_t * packed) {
42
+ TORCHAO_ALWAYS_INLINE inline void unpack_4_uint6_values (
43
+ uint8_t * unpacked,
44
+ const uint8_t * packed) {
260
45
// Unpacks data packed by pack_4_uint6_values_v2
261
46
//
262
47
// Input is 24 bits = 3 bytes
@@ -266,17 +51,17 @@ TORCHAO_ALWAYS_INLINE inline void unpack_4_uint6_values_v2(
266
51
unpacked[2 ] = packed[2 ] & 0b111111u ;
267
52
// Last value is packed in the upper 2 bits of the three bytes
268
53
unpacked[3 ] = ((packed[0 ] & 0b1100'0000u ) >> 6 ) |
269
- ((packed[1 ] & 0b1100'0000u ) >> 4 ) |
270
- ((packed[2 ] & 0b1100'0000u ) >> 2 );
54
+ ((packed[1 ] & 0b1100'0000u ) >> 4 ) | ((packed[2 ] & 0b1100'0000u ) >> 2 );
271
55
}
272
56
273
- TORCHAO_ALWAYS_INLINE inline void vec_pack_32_uint6_values_v2 (
274
- uint8_t * packed,
275
- const uint8x16_t & unpacked0,
276
- const uint8x16_t & unpacked1) {
57
+ TORCHAO_ALWAYS_INLINE inline void vec_pack_32_uint6_values (
58
+ uint8_t * packed,
59
+ const uint8x16_t & unpacked0,
60
+ const uint8x16_t & unpacked1) {
277
61
// This function is a vectorized version of pack_4_uint6_values_v2.
278
- // To understand the following code, please see pack_4_uint6_values_v2 first and
279
- // consider the following mapping for the unpacked parameter of that function:
62
+ // To understand the following code, please see pack_4_uint6_values_v2 first
63
+ // and consider the following mapping for the unpacked parameter of that
64
+ // function:
280
65
//
281
66
// unpacked[0] -> vget_low_u8(unpacked0)
282
67
// unpacked[1] -> vget_high_u8(unpacked0)
@@ -293,23 +78,26 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_32_uint6_values_v2(
293
78
// packed[0] = unpacked[0]
294
79
// packed[0] |= ((unpacked[3] & 0b00'0011u) << 6)
295
80
r = vget_low_u8 (unpacked0);
296
- r = vorr_u8 (r, vshl_n_u8 (vand_u8 (vget_high_u8 (unpacked1), vdup_n_u8 (0b00'0011u )), 6 ));
81
+ r = vorr_u8 (
82
+ r, vshl_n_u8 (vand_u8 (vget_high_u8 (unpacked1), vdup_n_u8 (0b00'0011u )), 6 ));
297
83
vst1_u8 (packed, r);
298
84
299
85
// packed[1] = unpacked[1]
300
86
// packed[1] |= ((unpacked[3] & 0b00'1100u) << 4)
301
87
r = vget_high_u8 (unpacked0);
302
- r = vorr_u8 (r, vshl_n_u8 (vand_u8 (vget_high_u8 (unpacked1), vdup_n_u8 (0b00'1100u )), 4 ));
88
+ r = vorr_u8 (
89
+ r, vshl_n_u8 (vand_u8 (vget_high_u8 (unpacked1), vdup_n_u8 (0b00'1100u )), 4 ));
303
90
vst1_u8 (packed + 8 , r);
304
91
305
92
// packed[2] = unpacked[2]
306
93
// packed[2] |= ((unpacked[3] & 0b11'0000u) << 2)
307
94
r = vget_low_u8 (unpacked1);
308
- r = vorr_u8 (r, vshl_n_u8 (vand_u8 (vget_high_u8 (unpacked1), vdup_n_u8 (0b11'0000u )), 2 ));
95
+ r = vorr_u8 (
96
+ r, vshl_n_u8 (vand_u8 (vget_high_u8 (unpacked1), vdup_n_u8 (0b11'0000u )), 2 ));
309
97
vst1_u8 (packed + 16 , r);
310
98
}
311
99
312
- TORCHAO_ALWAYS_INLINE inline void vec_unpack_32_uint6_values_v2 (
100
+ TORCHAO_ALWAYS_INLINE inline void vec_unpack_32_uint6_values (
313
101
uint8x16_t & unpacked0,
314
102
uint8x16_t & unpacked1,
315
103
const uint8_t * packed) {
@@ -331,18 +119,18 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_32_uint6_values_v2(
331
119
// ((packed[2] & 0b1100'0000u) >> 2);
332
120
const uint8x8_t high = vdup_n_u8 (0b1100'0000u );
333
121
uint8x8_t unpacked3;
334
- unpacked3 = vorr_u8 (vshr_n_u8 ( vand_u8 (packed0, high), 6 ),
335
- vshr_n_u8 (vand_u8 (packed1 , high), 4 ));
336
- unpacked3 = vorr_u8 (unpacked3,
337
- vshr_n_u8 (vand_u8 (packed2, high), 2 ));
122
+ unpacked3 = vorr_u8 (
123
+ vshr_n_u8 (vand_u8 (packed0 , high), 6 ),
124
+ vshr_n_u8 ( vand_u8 (packed1, high), 4 ));
125
+ unpacked3 = vorr_u8 (unpacked3, vshr_n_u8 (vand_u8 (packed2, high), 2 ));
338
126
339
127
// unpacked[i] = packed[i] & 0b11'1111u;
340
128
const uint8x8_t mask = vdup_n_u8 (0b11'1111u );
341
129
unpacked0 = vcombine_u8 (vand_u8 (packed0, mask), vand_u8 (packed1, mask));
342
130
unpacked1 = vcombine_u8 (vand_u8 (packed2, mask), unpacked3);
343
131
}
344
132
345
- TORCHAO_ALWAYS_INLINE inline void vec_pack_64_uint6_values_v2 (
133
+ TORCHAO_ALWAYS_INLINE inline void vec_pack_64_uint6_values (
346
134
uint8_t * packed,
347
135
const uint8x16_t & unpacked0,
348
136
const uint8x16_t & unpacked1,
@@ -376,7 +164,7 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_64_uint6_values_v2(
376
164
vst1q_u8 (packed + 32 , r);
377
165
}
378
166
379
- TORCHAO_ALWAYS_INLINE inline void vec_unpack_64_uint6_values_v2 (
167
+ TORCHAO_ALWAYS_INLINE inline void vec_unpack_64_uint6_values (
380
168
uint8x16_t & unpacked0,
381
169
uint8x16_t & unpacked1,
382
170
uint8x16_t & unpacked2,
@@ -399,10 +187,10 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_64_uint6_values_v2(
399
187
// ((packed[1] & 0b1100'0000u) >> 4) |
400
188
// ((packed[2] & 0b1100'0000u) >> 2);
401
189
const uint8x16_t high = vdupq_n_u8 (0b1100'0000u );
402
- unpacked3 = vorrq_u8 (vshrq_n_u8 ( vandq_u8 (unpacked0, high), 6 ),
403
- vshrq_n_u8 (vandq_u8 (unpacked1 , high), 4 ));
404
- unpacked3 = vorrq_u8 (unpacked3,
405
- vshrq_n_u8 (vandq_u8 (unpacked2, high), 2 ));
190
+ unpacked3 = vorrq_u8 (
191
+ vshrq_n_u8 (vandq_u8 (unpacked0 , high), 6 ),
192
+ vshrq_n_u8 ( vandq_u8 (unpacked1, high), 4 ));
193
+ unpacked3 = vorrq_u8 (unpacked3, vshrq_n_u8 (vandq_u8 (unpacked2, high), 2 ));
406
194
407
195
// unpacked[i] = packed[i] & 0b11'1111u;
408
196
const uint8x16_t mask = vdupq_n_u8 (0b11'1111u );
0 commit comments