Commit e7892ab

seanx92 authored and facebook-github-bot committed

avx512 based int8 -> bf16 dequantization (#4912)

Summary: X-link: facebookresearch/FBGEMM#1949

Use AVX512-BF16 intrinsics for int8 -> bf16 dequantization.

Differential Revision: D82507938
1 parent 947e4e5 commit e7892ab
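
The new kernel dequantizes FBGEMM's fused 8-bit rowwise format: each row stores its quantized uint8 values followed by an fp32 scale and an fp32 bias, and dequantization is out = q * scale + bias, rounded to bf16. Below is a minimal scalar sketch of that layout and math, mirroring the reference path added in src/QuantUtils.cc; the helper names (float_to_bf16, dequantize_int8_rowwise_to_bf16) are illustrative, not part of the commit.

#include <cstddef>
#include <cstdint>
#include <cstring>

// Illustrative stand-in for fbgemm's cpu_float2bfloat16: round-to-nearest-even
// on the raw bit pattern (NaN handling omitted for brevity).
static inline std::uint16_t float_to_bf16(float f) {
  std::uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  bits += 0x7FFFu + ((bits >> 16) & 1u);
  return static_cast<std::uint16_t>(bits >> 16);
}

// Scalar reference of the fused 8-bit rowwise dequantization: each input row
// holds `output_columns` uint8 values followed by an fp32 scale and bias.
void dequantize_int8_rowwise_to_bf16(const std::uint8_t* input,
                                     std::size_t input_rows,
                                     int input_columns,
                                     std::uint16_t* output) {
  const int output_columns = input_columns - 2 * static_cast<int>(sizeof(float));
  for (std::size_t row = 0; row < input_rows; ++row) {
    const std::uint8_t* input_row = input + row * input_columns;
    float scale = 0.f, bias = 0.f;
    std::memcpy(&scale, input_row + output_columns, sizeof(float));
    std::memcpy(&bias, input_row + output_columns + sizeof(float), sizeof(float));
    std::uint16_t* output_row = output + row * output_columns;
    for (int col = 0; col < output_columns; ++col) {
      output_row[col] = float_to_bf16(input_row[col] * scale + bias);
    }
  }
}

The AVX512 path added in src/QuantUtilsAvx512.cc vectorizes exactly this loop, eight elements at a time.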

File tree: 10 files changed, +153 −20 lines

fbgemm_gpu/cmake/Fbgemm.cmake

Lines changed: 3 additions & 2 deletions

@@ -26,7 +26,8 @@ set(fbgemm_sources_avx2
     "${FBGEMM}/src/QuantUtilsAvx2.cc")

 set(fbgemm_sources_avx512
-    "${FBGEMM}/src/EmbeddingSpMDMAvx512.cc")
+    "${FBGEMM}/src/EmbeddingSpMDMAvx512.cc"
+    "${FBGEMM}/src/QuantUtilsAvx512.cc")

 if(CXX_AVX2_FOUND)
   set_source_files_properties(${fbgemm_sources_avx2}
@@ -46,7 +47,7 @@ if(CXX_AVX2_FOUND)
       ${fbgemm_sources}
       ${fbgemm_sources_avx2})
 endif()
-if((NOT FBGEMM_BUILD_VARIANT STREQUAL BUILD_VARIANT_ROCM) AND CXX_AVX512_FOUND)
+if(CXX_AVX512_FOUND)
   set(fbgemm_sources
       ${fbgemm_sources}
       ${fbgemm_sources_avx2}

fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h

Lines changed: 4 additions & 0 deletions

@@ -411,6 +411,7 @@ at::Tensor FP8rowwise_to_float_cpu(
     const bool forward = true,
     const int64_t output_dtype = 0);
 at::Tensor fused8bitrowwise_to_half_cpu(const at::Tensor& input);
+at::Tensor fused8bitrowwise_to_bfloat16_cpu(const at::Tensor& input);
 at::Tensor fused8bitrowwise_to_float_or_half_cpu(
     const at::Tensor& input,
     const int64_t output_dtype,
@@ -469,6 +470,9 @@ at::Tensor _fusednbitrowwise_to_float_or_half_gpu(
 at::Tensor& _fused8bitrowwise_to_float_cpu_out(
     at::Tensor& output,
     const at::Tensor& input);
+at::Tensor& _fused8bitrowwise_to_bfloat16_cpu_out(
+    at::Tensor& output,
+    const at::Tensor& input);
 at::Tensor& _float_to_fused8bitrowwise_cpu_out(
     at::Tensor& output,
     const at::Tensor& input);

fbgemm_gpu/src/quantize_ops/quantize_ops_cpu.cpp

Lines changed: 29 additions & 4 deletions

@@ -55,7 +55,7 @@ Tensor& _float_to_fused8bitrowwise_cpu_out_t(
   return output;
 }

-template <typename output_t>
+template <typename output_t, bool is_uint16_t_of_type_bf16 = false>
 Tensor& _fused8bitrowwise_to_float_cpu_out_t(
     Tensor& output,
     const Tensor& input) {
@@ -78,7 +78,9 @@ Tensor& _fused8bitrowwise_to_float_cpu_out_t(
   auto output_data = static_cast<output_t*>(
       output.data_ptr()); // output.data_ptr<output_t>(); -> Yields
                           // unresolved data_ptr symbol.
-  fbgemm::Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf<output_t>(
+  fbgemm::Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf<
+      output_t,
+      is_uint16_t_of_type_bf16>(
       input.data_ptr<uint8_t>(), nrows, ncols, output_data);

   return output;
@@ -217,11 +219,19 @@ Tensor _fusednbitrowwise_sbfront_to_float_or_half_cpu(
 Tensor& _fused8bitrowwise_to_float_cpu_out(
     Tensor& output,
     const Tensor& input) {
-  return _fused8bitrowwise_to_float_cpu_out_t<float>(output, input);
+  return _fused8bitrowwise_to_float_cpu_out_t<float, false>(output, input);
 }

 Tensor& fused8bitrowwise_to_half_cpu_out(Tensor& output, const Tensor& input) {
-  return _fused8bitrowwise_to_float_cpu_out_t<fbgemm::float16>(output, input);
+  return _fused8bitrowwise_to_float_cpu_out_t<fbgemm::float16, false>(
+      output, input);
+}
+
+Tensor& _fused8bitrowwise_to_bfloat16_cpu_out(
+    Tensor& output,
+    const Tensor& input) {
+  return _fused8bitrowwise_to_float_cpu_out_t<fbgemm::bfloat16, true>(
+      output, input);
 }

 /// @ingroup quantize-data-cpu
@@ -285,6 +295,13 @@ Tensor fused8bitrowwise_to_half_cpu(const Tensor& input) {
   return fused8bitrowwise_to_half_cpu_out(output, input);
 }

+/// @ingroup quantize-data-cpu
+///
+Tensor fused8bitrowwise_to_bfloat16_cpu(const Tensor& input) {
+  auto output = at::empty({0}, input.options().dtype(at::kBFloat16));
+  return _fused8bitrowwise_to_bfloat16_cpu_out(output, input);
+}
+
 /// @ingroup quantize-data-cpu
 ///
 Tensor fused8bitrowwise_to_float_or_half_cpu(
@@ -305,6 +322,10 @@ Tensor fused8bitrowwise_to_float_or_half_cpu(
       output = at::empty({0}, input.options().dtype(at::kHalf));
       output = fused8bitrowwise_to_half_cpu_out(output, input);
       break;
+    case SparseType::BF16:
+      output = at::empty({0}, input.options().dtype(at::kBFloat16));
+      output = _fused8bitrowwise_to_bfloat16_cpu_out(output, input);
+      break;
     default:
       TORCH_CHECK(false);
   }
@@ -582,6 +603,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
       "FP8RowwiseQuantizedToFloat(Tensor input, bool forward, int output_dtype=0) -> Tensor",
       {PT2_COMPLIANT_TAG});
   m.def("Fused8BitRowwiseQuantizedToHalf(Tensor input) -> Tensor");
+  m.def("Fused8BitRowwiseQuantizedToBfloat16(Tensor input) -> Tensor");
   m.def(
       "Fused8BitRowwiseQuantizedToFloatOrHalf(Tensor input, int output_dtype=0, bool scale_bias_last=True, bool quant_padding_float_type=True) -> Tensor");
   m.def(
@@ -648,6 +670,9 @@ TORCH_LIBRARY_IMPL(fbgemm, CPU, m) {
   DISPATCH_TO_CPU(
       "Fused8BitRowwiseQuantizedToHalf",
       fbgemm_gpu::fused8bitrowwise_to_half_cpu);
+  DISPATCH_TO_CPU(
+      "Fused8BitRowwiseQuantizedToBfloat16",
+      fbgemm_gpu::fused8bitrowwise_to_bfloat16_cpu);
   DISPATCH_TO_CPU(
       "Fused8BitRowwiseQuantizedToFloatOrHalf",
       fbgemm_gpu::fused8bitrowwise_to_float_or_half_cpu);
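
With the registrations above, the new dequantization is reachable from Python as torch.ops.fbgemm.Fused8BitRowwiseQuantizedToBfloat16(tensor), or from C++ through the CPU entry point declared in sparse_ops.h. A minimal C++ sketch, assuming the fbgemm_gpu library is linked (the wrapper name is illustrative):

#include <ATen/ATen.h>
#include "fbgemm_gpu/sparse_ops.h"

// Dequantize a fused 8-bit rowwise uint8 tensor (the last 8 bytes of each row
// hold the fp32 scale and bias) to a BFloat16 tensor on CPU.
at::Tensor dequantize_to_bf16(const at::Tensor& quantized) {
  return fbgemm_gpu::fused8bitrowwise_to_bfloat16_cpu(quantized);
}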

fbgemm_gpu/test/quantize/fused_8bit_rowwise_test.py

Lines changed: 15 additions & 2 deletions

@@ -141,9 +141,9 @@ def quantize_and_dequantize_op_test_helper( # noqa: C901

     assume(ncols % (2 * num_elem_per_byte) == 0)
     if not test_cuda:
-        # cpu path does not support bf16
+        # cpu path only supports bf16 dequantization
         if output_dtype == SparseType.BF16:
-            return
+            input_data = input_data.float()
         if test_generic_op:
             quantized_data = (
                 torch.ops.fbgemm.FloatOrHalfToFused8BitRowwiseQuantized(input_data)
@@ -171,6 +171,15 @@ def quantize_and_dequantize_op_test_helper( # noqa: C901
             dequantized_data = torch.ops.fbgemm.Fused8BitRowwiseQuantizedToHalf(
                 quantized_data
             )
+        elif output_dtype == SparseType.BF16:
+            quantized_data = torch.ops.fbgemm.FloatToFused8BitRowwiseQuantized(
+                input_data,
+            )
+            dequantized_data = (
+                torch.ops.fbgemm.Fused8BitRowwiseQuantizedToBfloat16(
+                    quantized_data,
+                )
+            )
         else:
             raise NotImplementedError("Unsupported dtype")

@@ -185,6 +194,10 @@ def quantize_and_dequantize_op_test_helper( # noqa: C901
             torch.testing.assert_close(dequantized_data.float(), reference.float())
         elif output_dtype == SparseType.FP16:
             torch.testing.assert_close(dequantized_data.half(), reference.half())
+        elif output_dtype == SparseType.BF16:
+            torch.testing.assert_close(
+                dequantized_data.bfloat16(), reference.bfloat16()
+            )
         if test_cuda and gpu_available:
             if nrows == 0 or ncols == 0:
                 return

include/fbgemm/QuantUtils.h

Lines changed: 3 additions & 2 deletions

@@ -10,6 +10,7 @@

 #include "./FbgemmBuild.h" // @manual
 #include "./QuantUtilsAvx2.h" // @manual
+#include "./QuantUtilsAvx512.h" // @manual
 #include "./QuantUtilsNeon.h" // @manual
 #include "./Types.h" // @manual
 #include "./Utils.h" // @manual
@@ -330,7 +331,7 @@ FBGEMM_API void FloatOrHalfToFused8BitRowwiseQuantizedSBFloat(
  * This version intentionally supports only 8-bit because
  * the corresponding quantize version only supports 8-bit.
  */
-template <typename OutputType>
+template <typename OutputType, bool is_uint16_t_of_type_bf16 = false>
 FBGEMM_API void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf(
     const uint8_t* input,
     size_t input_rows,
@@ -377,7 +378,7 @@ FBGEMM_API void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef(
  * Same as Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf but unoptimized.
  * This should not be called directly except in testing.
  */
-template <typename OutputType>
+template <typename OutputType, bool is_uint16_t_of_type_bf16 = false>
 FBGEMM_API void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfRef(
     const uint8_t* input,
     size_t input_rows,

include/fbgemm/QuantUtilsAvx512.h

Lines changed: 7 additions & 0 deletions

@@ -8,6 +8,7 @@

 #pragma once

+#include "Types.h"
 #if defined(FBGEMM_FBCODE) || !defined(__aarch64__)

 #include <cstdint>
@@ -37,6 +38,12 @@ FBGEMM_API void requantizeOutputProcessingGConvAvx512(
     int ld_out,
     int ld_in,
     const requantizationParams_t<BIAS_TYPE>& r);
+
+void Fused8BitRowwiseQuantizedSBFloatToBfloat16Avx512(
+    const std::uint8_t* input,
+    size_t input_rows,
+    int input_columns,
+    bfloat16* output);
 } // namespace fbgemm

 #endif
include/fbgemm/Utils.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,11 @@ FBGEMM_API bool fbgemmHasAvx2Support();
177177
*/
178178
FBGEMM_API bool fbgemmHasAvx512VnniSupport();
179179

180+
/**
181+
* @brief Are we running on a AVX512_BF16 supported cpu?
182+
*/
183+
FBGEMM_API bool fbgemmHasAvx512Bf16Support();
184+
180185
/**
181186
* @brief Are we running on a ARM Neon supported cpu?
182187
*/

src/QuantUtils.cc

Lines changed: 36 additions & 10 deletions

@@ -806,7 +806,7 @@ void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalf(
   }
 }

-template <typename OutputType>
+template <typename OutputType, bool is_uint16_t_of_type_bf16>
 void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfRef(
     const std::uint8_t* input,
     size_t input_rows,
@@ -826,13 +826,17 @@ void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfRef(
       if constexpr (std::is_same<OutputType, float>()) {
         output_row[col] = output_value;
       } else {
-        output_row[col] = cpu_float2half_rn(output_value);
+        if constexpr (is_uint16_t_of_type_bf16) {
+          output_row[col] = cpu_float2bfloat16(output_value);
+        } else {
+          output_row[col] = cpu_float2half_rn(output_value);
+        }
       }
     }
   }
 }

-template <typename OutputType>
+template <typename OutputType, bool is_uint16_t_of_type_bf16>
 void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf(
     const std::uint8_t* input,
     size_t input_rows,
@@ -844,13 +848,23 @@ void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf(
 #else
   if (cpuinfo_initialize() && fbgemmHasAvx2Support()) {
 #if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
-    Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfAvx2<OutputType>(
-        input, input_rows, input_columns, output);
+    if (is_uint16_t_of_type_bf16 && fbgemmHasAvx512Bf16Support()) {
+      Fused8BitRowwiseQuantizedSBFloatToBfloat16Avx512(
+          input,
+          input_rows,
+          input_columns,
+          reinterpret_cast<bfloat16*>(output));
+      return;
+    } else if (!is_uint16_t_of_type_bf16) {
+      Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfAvx2<OutputType>(
+          input, input_rows, input_columns, output);
+      return;
+    }
 #endif
-  } else {
-    Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfRef<OutputType>(
-        input, input_rows, input_columns, output);
   }
+  Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfRef<
+      OutputType,
+      is_uint16_t_of_type_bf16>(input, input_rows, input_columns, output);
 #endif
 }

@@ -906,13 +920,25 @@ void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf(
       std::uint8_t* output, \
       const type* rowwise_min_max); \
   template FBGEMM_API void \
-  Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfRef<type>( \
+  Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfRef<type, false>( \
+      const uint8_t* input, \
+      size_t input_rows, \
+      int input_columns, \
+      type* output); \
+  template FBGEMM_API void \
+  Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfRef<type, true>( \
+      const uint8_t* input, \
+      size_t input_rows, \
+      int input_columns, \
+      type* output); \
+  template FBGEMM_API void \
+  Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf<type, false>( \
       const uint8_t* input, \
       size_t input_rows, \
       int input_columns, \
       type* output); \
   template FBGEMM_API void \
-  Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf<type>( \
+  Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf<type, true>( \
      const uint8_t* input, \
      size_t input_rows, \
      int input_columns, \
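
At the library level, callers reach the new path through the extra boolean template parameter: OutputType selects the 16-bit output buffer type, and is_uint16_t_of_type_bf16 selects bf16 semantics, dispatching to the AVX512-BF16 kernel when the CPU supports it and falling back to the reference loop otherwise. A minimal sketch of such a call, assuming the fbgemm::bfloat16 instantiation is linked in (the wrapper name is illustrative):

#include <cstddef>
#include <cstdint>
#include "fbgemm/QuantUtils.h"

// Dequantize `rows` fused 8-bit rows of `input_columns` bytes each
// (uint8 values followed by fp32 scale/bias) into a bf16 output buffer.
void dequantize_rows_to_bf16(const std::uint8_t* quantized,
                             std::size_t rows,
                             int input_columns,
                             fbgemm::bfloat16* out) {
  fbgemm::Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf<fbgemm::bfloat16, true>(
      quantized, rows, input_columns, out);
}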

src/QuantUtilsAvx512.cc

Lines changed: 47 additions & 0 deletions

@@ -6,12 +6,14 @@
  * LICENSE file in the root directory of this source tree.
  */

+#include <stdexcept>
 #define FBGEMM_EXPORTS
 #include "fbgemm/QuantUtilsAvx512.h"
 #if defined(__x86_64__) || defined(__i386__) || \
     (defined(_MSC_VER) && (defined(_M_X64) || defined(_M_IX86)))
 #include <immintrin.h>
 #endif
+#include <fbgemm/FloatConversion.h>
 #include <cassert>

 namespace fbgemm {
@@ -381,6 +383,50 @@ void requantizeOutputProcessingGConvAvx512(
   } // i loop
 }

+void Fused8BitRowwiseQuantizedSBFloatToBfloat16Avx512(
+    const std::uint8_t* input,
+    size_t input_rows,
+    int input_columns,
+    bfloat16* output) {
+#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
+  constexpr int VLEN = 8;
+  int output_columns = input_columns - 2 * sizeof(float);
+
+  for (size_t row = 0; row < input_rows; ++row) {
+    const std::uint8_t* input_row = input + row * input_columns;
+    const float* input_row_scale_bias =
+        reinterpret_cast<const float*>(input_row + output_columns);
+    bfloat16* output_row = output + row * output_columns;
+
+    __m256 scale_v = _mm256_set1_ps(input_row_scale_bias[0]);
+    __m256 bias_v = _mm256_set1_ps(input_row_scale_bias[1]);
+
+    int col = 0;
+    for (col = 0; col < output_columns / VLEN * VLEN; col += VLEN) {
+      __m256 in_v = _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
+          _mm_loadl_epi64(reinterpret_cast<const __m128i*>(input_row + col))));
+#ifdef __FMA__
+      __m256 dequantzed_v = _mm256_fmadd_ps(in_v, scale_v, bias_v);
+#else
+      __m256 dequantzed_v = _mm256_add_ps(_mm256_mul_ps(in_v, scale_v), bias_v);
+#endif
+      _mm_storeu_si128(
+          reinterpret_cast<__m128i*>(output_row + col),
+          (__m128i)(_mm256_cvtneps_pbh(dequantzed_v)));
+    }
+
+    for (; col < output_columns; ++col) {
+      float output_value =
+          input_row[col] * input_row_scale_bias[0] + input_row_scale_bias[1];
+      output_row[col] = cpu_float2bfloat16(output_value);
+    }
+  } // for each row
+#else
+  throw std::runtime_error(
+      "Fused8BitRowwiseQuantizedSBFloatToBfloat16Avx512 not implemented for non x86");
+#endif
+}
+
 #define INSTANTIATE_REQUANTIZE_BIAS_TYPE( \
     A_SYM, B_SYM, Q_GRAN, BIAS, RELU, BIAS_TYPE) \
   template void requantizeOutputProcessingGConvAvx512< \
@@ -468,4 +514,5 @@ INSTANTIATE_BIAS(false)
 #undef INSTANTIATE_B_SYM
 #undef INSTANTIATE_Q_GRANS
 #undef INSTANTIATE_BIAS
+
 } // namespace fbgemm
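
The key instruction in the vector loop above is _mm256_cvtneps_pbh (VCVTNEPS2BF16), which narrows eight fp32 lanes to bf16 with round-to-nearest-even; it requires AVX512-BF16 plus AVX512VL, which is why the dispatcher checks fbgemmHasAvx512Bf16Support() before taking this path. A standalone sketch of just that conversion step, illustrative and not part of the commit (compile with e.g. -mavx512bf16 -mavx512vl):

#include <cstdint>
#include <immintrin.h>

// Convert eight fp32 values to eight bf16 values using AVX512-BF16.
void floats_to_bf16x8(const float* in, std::uint16_t* out) {
  __m256 v = _mm256_loadu_ps(in);       // load 8 floats
  __m128bh bf = _mm256_cvtneps_pbh(v);  // narrow to 8 bf16 lanes (RNE)
  _mm_storeu_si128(reinterpret_cast<__m128i*>(out), (__m128i)bf);
}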

src/Utils.cc

Lines changed: 4 additions & 0 deletions

@@ -319,6 +319,10 @@ bool fbgemmHasAvx512VnniSupport() {
   return cpuinfo_has_x86_avx512vnni();
 }

+bool fbgemmHasAvx512Bf16Support() {
+  return cpuinfo_has_x86_avx512bf16();
+}
+
 bool fbgemmHasArmNeonSupport() {
   return cpuinfo_has_arm_neon();
 }
