@@ -234,6 +234,183 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_64_uint6_values(
234
234
unpacked3 = vorrq_u8 (unpacked3, vshrq_n_u8 (b3210, 4 ));
235
235
}
236
236
237
+ TORCHAO_ALWAYS_INLINE inline void pack_4_uint6_values_v2 (
238
+ uint8_t * packed,
239
+ const uint8_t * unpacked) {
240
+ // Given 4 unpacked uint6 values: abcdef, ghijkl, mnopqr, 123456
241
+ // this function packs them as:
242
+ // packed[0]: 56 | abcdef
243
+ // packed[1]: 34 | ghijkl
244
+ // packed[2]: 12 | mnopqr
245
+ //
246
+ // Input is 4 bytes
247
+ // Output is 6 * 4 bits/8 = 3 bytes
248
+ packed[0 ] = unpacked[0 ];
249
+ packed[1 ] = unpacked[1 ];
250
+ packed[2 ] = unpacked[2 ];
251
+ // Last value is packed in the upper 2 bits of the three bytes
252
+ packed[0 ] |= ((unpacked[3 ] & 0b00'0011u ) << 6 );
253
+ packed[1 ] |= ((unpacked[3 ] & 0b00'1100u ) << 4 );
254
+ packed[2 ] |= ((unpacked[3 ] & 0b11'0000u ) << 2 );
255
+ }
256
+
257
+ TORCHAO_ALWAYS_INLINE inline void unpack_4_uint6_values_v2 (
258
+ uint8_t * unpacked,
259
+ const uint8_t * packed) {
260
+ // Unpacks data packed by pack_4_uint6_values_v2
261
+ //
262
+ // Input is 24 bits = 3 bytes
263
+ // Output is 4 bytes
264
+ unpacked[0 ] = packed[0 ] & 0b111111u ;
265
+ unpacked[1 ] = packed[1 ] & 0b111111u ;
266
+ unpacked[2 ] = packed[2 ] & 0b111111u ;
267
+ // Last value is packed in the upper 2 bits of the three bytes
268
+ unpacked[3 ] = ((packed[0 ] & 0b1100'0000u ) >> 6 ) |
269
+ ((packed[1 ] & 0b1100'0000u ) >> 4 ) |
270
+ ((packed[2 ] & 0b1100'0000u ) >> 2 );
271
+ }
272
+
273
+ TORCHAO_ALWAYS_INLINE inline void vec_pack_32_uint6_values_v2 (
274
+ uint8_t * packed,
275
+ const uint8x16_t & unpacked0,
276
+ const uint8x16_t & unpacked1) {
277
+ // This function is a vectorized version of pack_4_uint6_values_v2.
278
+ // To understand the following code, please see pack_4_uint6_values_v2 first and
279
+ // consider the following mapping for the unpacked parameter of that function:
280
+ //
281
+ // unpacked[0] -> vget_low_u8(unpacked0)
282
+ // unpacked[1] -> vget_high_u8(unpacked0)
283
+ // unpacked[2] -> vget_low_u8(unpacked1)
284
+ // unpacked[3] -> vget_high_u8(unpacked1)
285
+ //
286
+ // Before each code section, there is a comment indicating the
287
+ // code in pack_4_uint6_values_v2 that is being vectorized.
288
+ //
289
+ // Input is 32 bytes.
290
+ // Output is 6*32= 192 bits = 24 bytes.
291
+ uint8x8_t r;
292
+
293
+ // packed[0] = unpacked[0]
294
+ // packed[0] |= ((unpacked[3] & 0b00'0011u) << 6)
295
+ r = vget_low_u8 (unpacked0);
296
+ r = vorr_u8 (r, vshl_n_u8 (vand_u8 (vget_high_u8 (unpacked1), vdup_n_u8 (0b00'0011u )), 6 ));
297
+ vst1_u8 (packed, r);
298
+
299
+ // packed[1] = unpacked[1]
300
+ // packed[1] |= ((unpacked[3] & 0b00'1100u) << 4)
301
+ r = vget_high_u8 (unpacked0);
302
+ r = vorr_u8 (r, vshl_n_u8 (vand_u8 (vget_high_u8 (unpacked1), vdup_n_u8 (0b00'1100u )), 4 ));
303
+ vst1_u8 (packed + 8 , r);
304
+
305
+ // packed[2] = unpacked[2]
306
+ // packed[2] |= ((unpacked[3] & 0b11'0000u) << 2)
307
+ r = vget_low_u8 (unpacked1);
308
+ r = vorr_u8 (r, vshl_n_u8 (vand_u8 (vget_high_u8 (unpacked1), vdup_n_u8 (0b11'0000u )), 2 ));
309
+ vst1_u8 (packed + 16 , r);
310
+ }
311
+
312
+ TORCHAO_ALWAYS_INLINE inline void vec_unpack_32_uint6_values_v2 (
313
+ uint8x16_t & unpacked0,
314
+ uint8x16_t & unpacked1,
315
+ const uint8_t * packed) {
316
+ // Unpacks data packed by vec_pack_32_uint6_values_v2.
317
+ //
318
+ // This function vectorizes unpack_4_uint6_values_v2.
319
+ // To understand it, please see unpack_4_uint6_values_v2 first.
320
+ // Before each code section, there is a comment indicating the
321
+ // code in unpack_4_uint6_values_v2 that is being vectorized.
322
+ //
323
+ // Input is 24 bytes.
324
+ // Output is 32 bytes.
325
+ uint8x8_t packed0 = vld1_u8 (packed);
326
+ uint8x8_t packed1 = vld1_u8 (packed + 8 );
327
+ uint8x8_t packed2 = vld1_u8 (packed + 16 );
328
+
329
+ // unpacked[3] = ((packed[0] & 0b1100'0000u) >> 6) |
330
+ // ((packed[1] & 0b1100'0000u) >> 4) |
331
+ // ((packed[2] & 0b1100'0000u) >> 2);
332
+ const uint8x8_t high = vdup_n_u8 (0b1100'0000u );
333
+ uint8x8_t unpacked3;
334
+ unpacked3 = vorr_u8 (vshr_n_u8 (vand_u8 (packed0, high), 6 ),
335
+ vshr_n_u8 (vand_u8 (packed1, high), 4 ));
336
+ unpacked3 = vorr_u8 (unpacked3,
337
+ vshr_n_u8 (vand_u8 (packed2, high), 2 ));
338
+
339
+ // unpacked[i] = packed[i] & 0b11'1111u;
340
+ const uint8x8_t mask = vdup_n_u8 (0b11'1111u );
341
+ unpacked0 = vcombine_u8 (vand_u8 (packed0, mask), vand_u8 (packed1, mask));
342
+ unpacked1 = vcombine_u8 (vand_u8 (packed2, mask), unpacked3);
343
+ }
344
+
345
+ TORCHAO_ALWAYS_INLINE inline void vec_pack_64_uint6_values_v2 (
346
+ uint8_t * packed,
347
+ const uint8x16_t & unpacked0,
348
+ const uint8x16_t & unpacked1,
349
+ const uint8x16_t & unpacked2,
350
+ const uint8x16_t & unpacked3) {
351
+ // This function is a vectorized version of pack_4_uint6_values_v2.
352
+ // To understand the following code, please see pack_4_uint6_values_v2 first.
353
+ // Before each code section, there is a comment indicating the
354
+ // code in pack_4_uint6_values_v2 that is being vectorized.
355
+ //
356
+ // Input is 48 bytes.
357
+ // Output is 64 bytes.
358
+ uint8x16_t r;
359
+
360
+ // packed[0] = unpacked[0]
361
+ // packed[0] |= ((unpacked[3] & 0b00'0011u) << 6)
362
+ r = unpacked0;
363
+ r = vorrq_u8 (r, vshlq_n_u8 (vandq_u8 (unpacked3, vdupq_n_u8 (0b00'0011u )), 6 ));
364
+ vst1q_u8 (packed, r);
365
+
366
+ // packed[1] = unpacked[1]
367
+ // packed[1] |= ((unpacked[3] & 0b00'1100u) << 4)
368
+ r = unpacked1;
369
+ r = vorrq_u8 (r, vshlq_n_u8 (vandq_u8 (unpacked3, vdupq_n_u8 (0b00'1100u )), 4 ));
370
+ vst1q_u8 (packed + 16 , r);
371
+
372
+ // packed[2] = unpacked[2]
373
+ // packed[2] |= ((unpacked[3] & 0b11'0000u) << 2)
374
+ r = unpacked2;
375
+ r = vorrq_u8 (r, vshlq_n_u8 (vandq_u8 (unpacked3, vdupq_n_u8 (0b11'0000u )), 2 ));
376
+ vst1q_u8 (packed + 32 , r);
377
+ }
378
+
379
+ TORCHAO_ALWAYS_INLINE inline void vec_unpack_64_uint6_values_v2 (
380
+ uint8x16_t & unpacked0,
381
+ uint8x16_t & unpacked1,
382
+ uint8x16_t & unpacked2,
383
+ uint8x16_t & unpacked3,
384
+ const uint8_t * packed) {
385
+ // Unpacks data packed by vec_pack_64_uint6_values_v2.
386
+ //
387
+ // This function vectorizes unpack_4_uint6_values_v2.
388
+ // To understand it, please see unpack_4_uint6_values_v2 first.
389
+ // Before each code section, there is a comment indicating the
390
+ // code in unpack_4_uint6_values that is being vectorized
391
+
392
+ // Input is 48 bytes.
393
+ // Output is 64 bytes.
394
+ unpacked0 = vld1q_u8 (packed);
395
+ unpacked1 = vld1q_u8 (packed + 16 );
396
+ unpacked2 = vld1q_u8 (packed + 32 );
397
+
398
+ // unpacked[3] = ((packed[0] & 0b1100'0000u) >> 6) |
399
+ // ((packed[1] & 0b1100'0000u) >> 4) |
400
+ // ((packed[2] & 0b1100'0000u) >> 2);
401
+ const uint8x16_t high = vdupq_n_u8 (0b1100'0000u );
402
+ unpacked3 = vorrq_u8 (vshrq_n_u8 (vandq_u8 (unpacked0, high), 6 ),
403
+ vshrq_n_u8 (vandq_u8 (unpacked1, high), 4 ));
404
+ unpacked3 = vorrq_u8 (unpacked3,
405
+ vshrq_n_u8 (vandq_u8 (unpacked2, high), 2 ));
406
+
407
+ // unpacked[i] = packed[i] & 0b11'1111u;
408
+ const uint8x16_t mask = vdupq_n_u8 (0b11'1111u );
409
+ unpacked0 = vandq_u8 (unpacked0, mask);
410
+ unpacked1 = vandq_u8 (unpacked1, mask);
411
+ unpacked2 = vandq_u8 (unpacked2, mask);
412
+ }
413
+
237
414
} // namespace internal
238
415
} // namespace bitpacking
239
416
} // namespace torchao
0 commit comments