Skip to content

Commit 6653b45

Browse files
authored
Add quantized embedding kernels to torchao
Differential Revision: D63839255 Pull Request resolved: #1018
1 parent 7aaf0ff commit 6653b45

File tree

13 files changed

+609
-18
lines changed

13 files changed

+609
-18
lines changed

torchao/experimental/kernels/cpu/aarch64/bitpacking/bitpack.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,13 @@
99
#if defined(__aarch64__) || defined(__ARM_NEON)
1010

1111
#include <arm_neon.h>
12-
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/macro.h>
1312
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/uint1.h>
1413
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/uint2.h>
1514
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/uint3.h>
1615
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/uint4.h>
1716
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/uint5.h>
1817
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/uint6.h>
18+
#include <torchao/experimental/kernels/cpu/aarch64/macro.h>
1919
#include <cassert>
2020

2121
namespace torchao {
@@ -142,7 +142,7 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_32_lowbit_values(
142142
break;
143143
case 6:
144144
torchao::bitpacking::internal::vec_pack_32_uint6_values(
145-
packed, shifted0, shifted1);
145+
packed, shifted0, shifted1);
146146
break;
147147
default:
148148
assert(false);
@@ -153,7 +153,7 @@ template <int nbit>
153153
TORCHAO_ALWAYS_INLINE inline void vec_unpack_32_lowbit_values(
154154
int8x16_t& unpacked0,
155155
int8x16_t& unpacked1,
156-
uint8_t* packed) {
156+
const uint8_t* packed) {
157157
static_assert(nbit < 8);
158158
static_assert(nbit >= 1);
159159

@@ -217,7 +217,7 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_32_lowbit_values(
217217
break;
218218
case 6:
219219
torchao::bitpacking::internal::vec_unpack_32_uint6_values(
220-
shifted0, shifted1, packed);
220+
shifted0, shifted1, packed);
221221
break;
222222
default:
223223
assert(false);
@@ -288,7 +288,7 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_64_lowbit_values(
288288
int8x16_t& unpacked1,
289289
int8x16_t& unpacked2,
290290
int8x16_t& unpacked3,
291-
uint8_t* packed) {
291+
const uint8_t* packed) {
292292
static_assert(nbit < 8);
293293
static_assert(nbit >= 1);
294294

@@ -443,7 +443,7 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_128_lowbit_values(
443443
int8x16_t& unpacked5,
444444
int8x16_t& unpacked6,
445445
int8x16_t& unpacked7,
446-
uint8_t* packed) {
446+
const uint8_t* packed) {
447447
static_assert(nbit < 8);
448448
static_assert(nbit >= 1);
449449

torchao/experimental/kernels/cpu/aarch64/bitpacking/uint1.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
#if defined(__aarch64__) || defined(__ARM_NEON)
1010
#include <arm_neon.h>
11-
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/macro.h>
11+
#include <torchao/experimental/kernels/cpu/aarch64/macro.h>
1212

1313
// This file contains bitpacking and unpacking methods for uint1.
1414
// These are not inteded to be used outside of bitpacking directory.

torchao/experimental/kernels/cpu/aarch64/bitpacking/uint2.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
#if defined(__aarch64__) || defined(__ARM_NEON)
1010

1111
#include <arm_neon.h>
12-
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/macro.h>
12+
#include <torchao/experimental/kernels/cpu/aarch64/macro.h>
1313

1414
// This file contains bitpacking and unpacking methods for uint4.
1515
// These are not inteded to be used outside of bitpacking directory.

torchao/experimental/kernels/cpu/aarch64/bitpacking/uint3.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
#if defined(__aarch64__) || defined(__ARM_NEON)
1010

1111
#include <arm_neon.h>
12-
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/macro.h>
12+
#include <torchao/experimental/kernels/cpu/aarch64/macro.h>
1313

1414
// This file contains bitpacking and unpacking methods for uint3.
1515
// These are not inteded to be used outside of bitpacking directory.

torchao/experimental/kernels/cpu/aarch64/bitpacking/uint4.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
#if defined(__aarch64__) || defined(__ARM_NEON)
1010

1111
#include <arm_neon.h>
12-
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/macro.h>
12+
#include <torchao/experimental/kernels/cpu/aarch64/macro.h>
1313

1414
// This file contains bitpacking and unpacking methods for uint4.
1515
// These are not inteded to be used outside of bitpacking directory.

torchao/experimental/kernels/cpu/aarch64/bitpacking/uint5.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
#if defined(__aarch64__) || defined(__ARM_NEON)
1010

1111
#include <arm_neon.h>
12-
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/macro.h>
12+
#include <torchao/experimental/kernels/cpu/aarch64/macro.h>
1313

1414
// This file contains bitpacking and unpacking methods for uint5.
1515
// These are not inteded to be used outside of bitpacking directory.

torchao/experimental/kernels/cpu/aarch64/bitpacking/uint6.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
#if defined(__aarch64__) || defined(__ARM_NEON)
1010

1111
#include <arm_neon.h>
12-
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/macro.h>
12+
#include <torchao/experimental/kernels/cpu/aarch64/macro.h>
1313

1414
// This file contains bitpacking and unpacking methods for uint5.
1515
// These are not inteded to be used outside of bitpacking directory.

0 commit comments

Comments
 (0)