@@ -55,7 +55,7 @@ Tensor& _float_to_fused8bitrowwise_cpu_out_t(
   return output;
 }
 
-template <typename output_t>
+template <typename output_t, bool is_uint16_t_of_type_bf16 = false>
 Tensor& _fused8bitrowwise_to_float_cpu_out_t(
     Tensor& output,
     const Tensor& input) {
@@ -78,7 +78,9 @@ Tensor& _fused8bitrowwise_to_float_cpu_out_t(
   auto output_data = static_cast<output_t*>(
       output.data_ptr()); // output.data_ptr<output_t>(); -> Yields
                           // unresolved data_ptr symbol.
-  fbgemm::Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf<output_t>(
+  fbgemm::Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf<
+      output_t,
+      is_uint16_t_of_type_bf16>(
       input.data_ptr<uint8_t>(), nrows, ncols, output_data);
 
   return output;
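Background on the new template flag (not part of the diff): fp16 and bf16 are both 16-bit formats, so the same uint16-typed output buffer can hold either, and is_uint16_t_of_type_bf16 tells the fbgemm kernel which conversion to emit. A minimal sketch of the float32-to-bfloat16 conversion with round-to-nearest-even, the conventional choice; illustrative only (NaN handling omitted), not fbgemm's actual implementation:

#include <cstdint>
#include <cstring>

// bfloat16 keeps the sign, the full 8-bit exponent, and the top 7 mantissa
// bits of an IEEE-754 float32, i.e. its upper 16 bits.
inline std::uint16_t float_to_bfloat16(float x) {
  std::uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  // Round to nearest even before truncating the low 16 bits.
  bits += 0x7FFFu + ((bits >> 16) & 1u);
  return static_cast<std::uint16_t>(bits >> 16);
}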
@@ -217,11 +219,19 @@ Tensor _fusednbitrowwise_sbfront_to_float_or_half_cpu(
 Tensor& _fused8bitrowwise_to_float_cpu_out(
     Tensor& output,
     const Tensor& input) {
-  return _fused8bitrowwise_to_float_cpu_out_t<float>(output, input);
+  return _fused8bitrowwise_to_float_cpu_out_t<float, false>(output, input);
 }
 
 Tensor& fused8bitrowwise_to_half_cpu_out(Tensor& output, const Tensor& input) {
-  return _fused8bitrowwise_to_float_cpu_out_t<fbgemm::float16>(output, input);
+  return _fused8bitrowwise_to_float_cpu_out_t<fbgemm::float16, false>(
+      output, input);
+}
+
+Tensor& _fused8bitrowwise_to_bfloat16_cpu_out(
+    Tensor& output,
+    const Tensor& input) {
+  return _fused8bitrowwise_to_float_cpu_out_t<fbgemm::bfloat16, true>(
+      output, input);
 }
 
 /// @ingroup quantize-data-cpu
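For context on what these out-variants decode: in the fused 8-bit rowwise format, each row carries its uint8 codes followed by a float scale and a float bias, and dequantization is code * scale + bias. A standalone sketch under that assumed field order (scale then bias at the end of the row; the helper name is hypothetical):

#include <cstdint>
#include <cstring>
#include <vector>

// Dequantize one fused row of `ncols` uint8 codes; the row's trailing
// 8 bytes are assumed to hold the float scale and bias, in that order.
std::vector<float> dequantize_fused_row(const std::uint8_t* row,
                                        std::int64_t ncols) {
  float scale, bias;
  std::memcpy(&scale, row + ncols, sizeof(float));
  std::memcpy(&bias, row + ncols + sizeof(float), sizeof(float));
  std::vector<float> out(ncols);
  for (std::int64_t i = 0; i < ncols; ++i) {
    out[i] = static_cast<float>(row[i]) * scale + bias;
  }
  return out;
}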
@@ -285,6 +295,13 @@ Tensor fused8bitrowwise_to_half_cpu(const Tensor& input) {
   return fused8bitrowwise_to_half_cpu_out(output, input);
 }
 
+/// @ingroup quantize-data-cpu
+///
+Tensor fused8bitrowwise_to_bfloat16_cpu(const Tensor& input) {
+  auto output = at::empty({0}, input.options().dtype(at::kBFloat16));
+  return _fused8bitrowwise_to_bfloat16_cpu_out(output, input);
+}
+
 /// @ingroup quantize-data-cpu
 ///
 Tensor fused8bitrowwise_to_float_or_half_cpu(
@@ -305,6 +322,10 @@ Tensor fused8bitrowwise_to_float_or_half_cpu(
       output = at::empty({0}, input.options().dtype(at::kHalf));
       output = fused8bitrowwise_to_half_cpu_out(output, input);
       break;
+    case SparseType::BF16:
+      output = at::empty({0}, input.options().dtype(at::kBFloat16));
+      output = _fused8bitrowwise_to_bfloat16_cpu_out(output, input);
+      break;
     default:
       TORCH_CHECK(false);
   }
@@ -582,6 +603,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
       "FP8RowwiseQuantizedToFloat(Tensor input, bool forward, int output_dtype=0) -> Tensor",
       {PT2_COMPLIANT_TAG});
   m.def("Fused8BitRowwiseQuantizedToHalf(Tensor input) -> Tensor");
+  m.def("Fused8BitRowwiseQuantizedToBfloat16(Tensor input) -> Tensor");
   m.def(
       "Fused8BitRowwiseQuantizedToFloatOrHalf(Tensor input, int output_dtype=0, bool scale_bias_last=True, bool quant_padding_float_type=True) -> Tensor");
   m.def(
@@ -648,6 +670,9 @@ TORCH_LIBRARY_IMPL(fbgemm, CPU, m) {
   DISPATCH_TO_CPU(
       "Fused8BitRowwiseQuantizedToHalf",
       fbgemm_gpu::fused8bitrowwise_to_half_cpu);
+  DISPATCH_TO_CPU(
+      "Fused8BitRowwiseQuantizedToBfloat16",
+      fbgemm_gpu::fused8bitrowwise_to_bfloat16_cpu);
   DISPATCH_TO_CPU(
       "Fused8BitRowwiseQuantizedToFloatOrHalf",
       fbgemm_gpu::fused8bitrowwise_to_float_or_half_cpu);
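With the schema def'ed and the CPU kernel registered, the new op is reachable through the PyTorch dispatcher. A hedged C++ call-site sketch using standard dispatcher boilerplate (the wrapper name here is hypothetical; only the op name and schema come from the diff):

#include <ATen/ATen.h>
#include <ATen/core/dispatch/Dispatcher.h>

// Look up and invoke fbgemm::Fused8BitRowwiseQuantizedToBfloat16 through
// the dispatcher; schema per the m.def above: (Tensor input) -> Tensor.
at::Tensor fused8bit_to_bf16(const at::Tensor& quantized) {
  static auto op =
      c10::Dispatcher::singleton()
          .findSchemaOrThrow("fbgemm::Fused8BitRowwiseQuantizedToBfloat16", "")
          .typed<at::Tensor(const at::Tensor&)>();
  return op.call(quantized);
}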