
Commit fbfd59e

seanx92 authored and facebook-github-bot committed
avx512 based int8 -> bf16 dequantization (#4912)
Summary: Use AVX512-BF16 intrinsics for int8 -> bf16 dequantization.

Differential Revision: D82507938
1 parent: be84b43

8 files changed: +125 −11 lines changed
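
The commit exposes the new dequantization path as a CPU operator, `Fused8BitRowwiseQuantizedToBfloat16` (registered in quantize_ops_cpu.cpp below). A minimal usage sketch — assuming the fbgemm_gpu extension is installed so that importing it registers the ops under `torch.ops.fbgemm`:

```python
import torch
import fbgemm_gpu  # noqa: F401  # importing registers the torch.ops.fbgemm operators

# Quantize float rows to the fused 8-bit rowwise format
# (uint8 payload plus a per-row fp32 scale and bias).
x = torch.randn(4, 128, dtype=torch.float32)
q = torch.ops.fbgemm.FloatToFused8BitRowwiseQuantized(x)

# New in this commit: dequantize directly to bfloat16 on CPU.
y = torch.ops.fbgemm.Fused8BitRowwiseQuantizedToBfloat16(q)

assert y.dtype == torch.bfloat16 and y.shape == x.shape
```

On machines with AVX512-BF16 the dequantization runs through the new intrinsics kernel; otherwise it falls back to the reference implementation (see src/QuantUtils.cc below).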

fbgemm_gpu/src/quantize_ops/quantize_ops_cpu.cpp
Lines changed: 27 additions & 2 deletions

@@ -220,10 +220,18 @@ Tensor& _fused8bitrowwise_to_float_cpu_out(
   return _fused8bitrowwise_to_float_cpu_out_t<float>(output, input);
 }
 
-Tensor& fused8bitrowwise_to_half_cpu_out(Tensor& output, const Tensor& input) {
+static Tensor& fused8bitrowwise_to_half_cpu_out(
+    Tensor& output,
+    const Tensor& input) {
   return _fused8bitrowwise_to_float_cpu_out_t<fbgemm::float16>(output, input);
 }
 
+static Tensor& fused8bitrowwise_to_bfloat16_cpu_out(
+    Tensor& output,
+    const Tensor& input) {
+  return _fused8bitrowwise_to_float_cpu_out_t<fbgemm::bfloat16>(output, input);
+}
+
 /// @ingroup quantize-data-cpu
 ///
 Tensor& _float_to_fused8bitrowwise_cpu_out(
@@ -232,7 +240,9 @@ Tensor& _float_to_fused8bitrowwise_cpu_out(
   return _float_to_fused8bitrowwise_cpu_out_t<float>(output, input);
 }
 
-Tensor& _half_to_fused8bitrowwise_cpu_out(Tensor& output, const Tensor& input) {
+static Tensor& _half_to_fused8bitrowwise_cpu_out(
+    Tensor& output,
+    const Tensor& input) {
   return _float_to_fused8bitrowwise_cpu_out_t<fbgemm::float16>(output, input);
 }
 
@@ -285,6 +295,13 @@ Tensor fused8bitrowwise_to_half_cpu(const Tensor& input) {
   return fused8bitrowwise_to_half_cpu_out(output, input);
 }
 
+/// @ingroup quantize-data-cpu
+///
+Tensor fused8bitrowwise_to_bfloat16_cpu(const Tensor& input) {
+  auto output = at::empty({0}, input.options().dtype(at::kBFloat16));
+  return fused8bitrowwise_to_bfloat16_cpu_out(output, input);
+}
+
 /// @ingroup quantize-data-cpu
 ///
 Tensor fused8bitrowwise_to_float_or_half_cpu(
@@ -305,6 +322,10 @@ Tensor fused8bitrowwise_to_float_or_half_cpu(
       output = at::empty({0}, input.options().dtype(at::kHalf));
       output = fused8bitrowwise_to_half_cpu_out(output, input);
       break;
+    case SparseType::BF16:
+      output = at::empty({0}, input.options().dtype(at::kBFloat16));
+      output = fused8bitrowwise_to_bfloat16_cpu_out(output, input);
+      break;
     default:
       TORCH_CHECK(false);
   }
@@ -582,6 +603,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
       "FP8RowwiseQuantizedToFloat(Tensor input, bool forward, int output_dtype=0) -> Tensor",
       {PT2_COMPLIANT_TAG});
   m.def("Fused8BitRowwiseQuantizedToHalf(Tensor input) -> Tensor");
+  m.def("Fused8BitRowwiseQuantizedToBfloat16(Tensor input) -> Tensor");
   m.def(
       "Fused8BitRowwiseQuantizedToFloatOrHalf(Tensor input, int output_dtype=0, bool scale_bias_last=True, bool quant_padding_float_type=True) -> Tensor");
   m.def(
@@ -648,6 +670,9 @@ TORCH_LIBRARY_IMPL(fbgemm, CPU, m) {
   DISPATCH_TO_CPU(
       "Fused8BitRowwiseQuantizedToHalf",
       fbgemm_gpu::fused8bitrowwise_to_half_cpu);
+  DISPATCH_TO_CPU(
+      "Fused8BitRowwiseQuantizedToBfloat16",
+      fbgemm_gpu::fused8bitrowwise_to_bfloat16_cpu);
   DISPATCH_TO_CPU(
       "Fused8BitRowwiseQuantizedToFloatOrHalf",
       fbgemm_gpu::fused8bitrowwise_to_float_or_half_cpu);
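
Besides the dedicated op, the switch in `fused8bitrowwise_to_float_or_half_cpu` above now accepts `SparseType::BF16`, so the generic `Fused8BitRowwiseQuantizedToFloatOrHalf` op can also produce bfloat16. A hedged sketch of that path — the Python `SparseType` import path and the assumption that `SparseType.BF16.as_int()` matches the C++ enum value are not shown in this diff:

```python
import torch
import fbgemm_gpu  # noqa: F401
# Assumed import path for the SparseType enum used by the test file.
from fbgemm_gpu.split_embedding_configs import SparseType

x = torch.randn(2, 64, dtype=torch.float32)
q = torch.ops.fbgemm.FloatOrHalfToFused8BitRowwiseQuantized(x)

# output_dtype is an int; SparseType.BF16.as_int() is assumed to map to the
# SparseType::BF16 case added to the C++ switch.
y = torch.ops.fbgemm.Fused8BitRowwiseQuantizedToFloatOrHalf(
    q, output_dtype=SparseType.BF16.as_int()
)
assert y.dtype == torch.bfloat16
```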

fbgemm_gpu/test/quantize/fused_8bit_rowwise_test.py
Lines changed: 15 additions & 2 deletions

@@ -141,9 +141,9 @@ def quantize_and_dequantize_op_test_helper( # noqa: C901
 
     assume(ncols % (2 * num_elem_per_byte) == 0)
     if not test_cuda:
-        # cpu path does not support bf16
+        # cpu path only supports bf16 dequantization
         if output_dtype == SparseType.BF16:
-            return
+            input_data = input_data.float()
         if test_generic_op:
             quantized_data = (
                 torch.ops.fbgemm.FloatOrHalfToFused8BitRowwiseQuantized(input_data)
@@ -171,6 +171,15 @@ def quantize_and_dequantize_op_test_helper( # noqa: C901
             dequantized_data = torch.ops.fbgemm.Fused8BitRowwiseQuantizedToHalf(
                 quantized_data
             )
+        elif output_dtype == SparseType.BF16:
+            quantized_data = torch.ops.fbgemm.FloatToFused8BitRowwiseQuantized(
+                input_data,
+            )
+            dequantized_data = (
+                torch.ops.fbgemm.Fused8BitRowwiseQuantizedToBfloat16(
+                    quantized_data,
+                )
+            )
         else:
             raise NotImplementedError("Unsupported dtype")
 
@@ -185,6 +194,10 @@ def quantize_and_dequantize_op_test_helper( # noqa: C901
             torch.testing.assert_close(dequantized_data.float(), reference.float())
         elif output_dtype == SparseType.FP16:
            torch.testing.assert_close(dequantized_data.half(), reference.half())
+        elif output_dtype == SparseType.BF16:
+            torch.testing.assert_close(
+                dequantized_data.bfloat16(), reference.bfloat16()
+            )
     if test_cuda and gpu_available:
         if nrows == 0 or ncols == 0:
             return
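
Outside the hypothesis-driven test helper, the same bf16 check can be written in a few lines; `Fused8BitRowwiseQuantizedToFloat` is the pre-existing float dequantization op, and the comparison mirrors the `assert_close` added above (a sketch, not part of the commit):

```python
import torch
import fbgemm_gpu  # noqa: F401

x = torch.randn(8, 32, dtype=torch.float32)
q = torch.ops.fbgemm.FloatToFused8BitRowwiseQuantized(x)

bf16_out = torch.ops.fbgemm.Fused8BitRowwiseQuantizedToBfloat16(q)
fp32_out = torch.ops.fbgemm.Fused8BitRowwiseQuantizedToFloat(q)

# Both outputs decode the same uint8 payload, so they should agree up to the
# float -> bfloat16 rounding applied to the float reference.
torch.testing.assert_close(bf16_out, fp32_out.bfloat16())
```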

include/fbgemm/QuantUtils.h
Lines changed: 1 addition & 0 deletions

@@ -10,6 +10,7 @@
 
 #include "./FbgemmBuild.h" // @manual
 #include "./QuantUtilsAvx2.h" // @manual
+#include "./QuantUtilsAvx512.h" // @manual
 #include "./QuantUtilsNeon.h" // @manual
 #include "./Types.h" // @manual
 #include "./Utils.h" // @manual

include/fbgemm/QuantUtilsAvx512.h
Lines changed: 7 additions & 0 deletions

@@ -37,6 +37,13 @@ FBGEMM_API void requantizeOutputProcessingGConvAvx512(
     int ld_out,
     int ld_in,
     const requantizationParams_t<BIAS_TYPE>& r);
+
+template <typename OutputType>
+void Fused8BitRowwiseQuantizedSBFloatToBfloat16Avx512(
+    const std::uint8_t* input,
+    size_t input_rows,
+    int input_columns,
+    OutputType* output);
 } // namespace fbgemm
 
 #endif

include/fbgemm/Utils.h
Lines changed: 5 additions & 0 deletions

@@ -177,6 +177,11 @@ FBGEMM_API bool fbgemmHasAvx2Support();
  */
 FBGEMM_API bool fbgemmHasAvx512VnniSupport();
 
+/**
+ * @brief Are we running on a AVX512_BF16 supported cpu?
+ */
+FBGEMM_API bool fbgemmHasAvx512Bf16Support();
+
 /**
  * @brief Are we running on a ARM Neon supported cpu?
  */

src/QuantUtils.cc
Lines changed: 18 additions & 7 deletions

@@ -825,6 +825,8 @@ void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfRef(
         input_row[col] * input_row_scale_bias[0] + input_row_scale_bias[1];
     if constexpr (std::is_same<OutputType, float>()) {
       output_row[col] = output_value;
+    } else if constexpr (std::is_same_v<OutputType, bfloat16>) {
+      output_row[col] = cpu_float2bfloat16(output_value);
     } else {
       output_row[col] = cpu_float2half_rn(output_value);
     }
@@ -842,15 +844,24 @@ void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf(
   Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfNeon<OutputType>(
       input, input_rows, input_columns, output);
 #else
-  if (cpuinfo_initialize() && fbgemmHasAvx2Support()) {
 #if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
-    Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfAvx2<OutputType>(
-        input, input_rows, input_columns, output);
-#endif
-  } else {
-    Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfRef<OutputType>(
-        input, input_rows, input_columns, output);
+  if (cpuinfo_initialize() && fbgemmHasAvx2Support()) {
+    if (fbgemmHasAvx512Bf16Support() && std::is_same_v<OutputType, bfloat16>) {
+      // Avx512 bfloat16 native support
+      Fused8BitRowwiseQuantizedSBFloatToBfloat16Avx512<OutputType>(
+          input, input_rows, input_columns, output);
+      return;
+    } else if (!std::is_same_v<OutputType, bfloat16>) {
+      // Avx2 does not support bfloat16
+      Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfAvx2<OutputType>(
+          input, input_rows, input_columns, output);
+      return;
+    }
   }
+#endif
+  // Fallback to ref kernel
+  Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfRef<OutputType>(
+      input, input_rows, input_columns, output);
 #endif
 }

src/QuantUtilsAvx512.cc
Lines changed: 48 additions & 0 deletions

@@ -12,6 +12,7 @@
     (defined(_MSC_VER) && (defined(_M_X64) || defined(_M_IX86)))
 #include <immintrin.h>
 #endif
+#include <fbgemm/FloatConversion.h>
 #include <cassert>
 
 namespace fbgemm {
@@ -381,6 +382,46 @@ void requantizeOutputProcessingGConvAvx512(
   } // i loop
 }
 
+template <typename OutputType>
+void Fused8BitRowwiseQuantizedSBFloatToBfloat16Avx512(
+    const std::uint8_t* input,
+    size_t input_rows,
+    int input_columns,
+    OutputType* output) {
+  constexpr int VLEN = 8;
+  int output_columns = input_columns - 2 * sizeof(float);
+
+  for (size_t row = 0; row < input_rows; ++row) {
+    const std::uint8_t* input_row = input + row * input_columns;
+    const float* input_row_scale_bias =
+        reinterpret_cast<const float*>(input_row + output_columns);
+    OutputType* output_row = output + row * output_columns;
+
+    __m256 scale_v = _mm256_set1_ps(input_row_scale_bias[0]);
+    __m256 bias_v = _mm256_set1_ps(input_row_scale_bias[1]);
+
+    int col = 0;
+    for (col = 0; col < output_columns / VLEN * VLEN; col += VLEN) {
+      __m256 in_v = _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
+          _mm_loadl_epi64(reinterpret_cast<const __m128i*>(input_row + col))));
+#ifdef __FMA__
+      __m256 dequantzed_v = _mm256_fmadd_ps(in_v, scale_v, bias_v);
+#else
+      __m256 dequantzed_v = _mm256_add_ps(_mm256_mul_ps(in_v, scale_v), bias_v);
+#endif
+      _mm_storeu_si128(
+          reinterpret_cast<__m128i*>(output_row + col),
+          _mm256_cvtneps_pbh(dequantzed_v));
+    }
+
+    for (; col < output_columns; ++col) {
+      float output_value =
+          input_row[col] * input_row_scale_bias[0] + input_row_scale_bias[1];
+      output_row[col] = cpu_float2bfloat16(output_value);
+    }
+  } // for each row
+}
+
 #define INSTANTIATE_REQUANTIZE_BIAS_TYPE( \
     A_SYM, B_SYM, Q_GRAN, BIAS, RELU, BIAS_TYPE) \
   template void requantizeOutputProcessingGConvAvx512< \
@@ -468,4 +509,11 @@ INSTANTIATE_BIAS(false)
 #undef INSTANTIATE_B_SYM
 #undef INSTANTIATE_Q_GRANS
 #undef INSTANTIATE_BIAS
+
+template void Fused8BitRowwiseQuantizedSBFloatToBfloat16Avx512<bfloat16>(
+    const std::uint8_t* input,
+    size_t input_rows,
+    int input_columns,
+    bfloat16* output);
+
 } // namespace fbgemm
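
The kernel walks each row in blocks of VLEN = 8: widen uint8 to fp32, apply the row's scale and bias (FMA when available), then pack to bf16 with `_mm256_cvtneps_pbh`; a scalar tail loop uses `cpu_float2bfloat16` for the remaining columns. As a mental model of the row layout it vectorizes, here is a scalar PyTorch sketch (not a drop-in replacement for the intrinsics path):

```python
import torch

def dequantize_fused_8bit_rowwise_to_bf16(q: torch.Tensor) -> torch.Tensor:
    """q: (rows, ncols + 8) uint8; the last 8 bytes of each row hold the fp32 scale and bias."""
    ncols = q.shape[1] - 8  # output_columns = input_columns - 2 * sizeof(float)
    payload = q[:, :ncols].to(torch.float32)
    scale_bias = q[:, ncols:].contiguous().view(torch.float32)  # (rows, 2)
    scale, bias = scale_bias[:, 0:1], scale_bias[:, 1:2]
    # Dequantize in fp32, then round to bfloat16 -- the same math the
    # AVX512 main loop and the scalar tail perform per element.
    return (payload * scale + bias).to(torch.bfloat16)
```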

src/Utils.cc
Lines changed: 4 additions & 0 deletions

@@ -319,6 +319,10 @@ bool fbgemmHasAvx512VnniSupport() {
   return cpuinfo_has_x86_avx512vnni();
 }
 
+bool fbgemmHasAvx512Bf16Support() {
+  return cpuinfo_has_x86_avx512bf16();
+}
+
 bool fbgemmHasArmNeonSupport() {
   return cpuinfo_has_arm_neon();
 }
