@@ -220,10 +220,18 @@ Tensor& _fused8bitrowwise_to_float_cpu_out(
   return _fused8bitrowwise_to_float_cpu_out_t<float>(output, input);
 }
 
-Tensor& fused8bitrowwise_to_half_cpu_out(Tensor& output, const Tensor& input) {
+static Tensor& fused8bitrowwise_to_half_cpu_out(
+    Tensor& output,
+    const Tensor& input) {
   return _fused8bitrowwise_to_float_cpu_out_t<fbgemm::float16>(output, input);
 }
 
+static Tensor& fused8bitrowwise_to_bfloat16_cpu_out(
+    Tensor& output,
+    const Tensor& input) {
+  return _fused8bitrowwise_to_float_cpu_out_t<fbgemm::bfloat16>(output, input);
+}
+
 /// @ingroup quantize-data-cpu
 ///
 Tensor& _float_to_fused8bitrowwise_cpu_out(
@@ -232,7 +240,9 @@ Tensor& _float_to_fused8bitrowwise_cpu_out(
   return _float_to_fused8bitrowwise_cpu_out_t<float>(output, input);
 }
 
-Tensor& _half_to_fused8bitrowwise_cpu_out(Tensor& output, const Tensor& input) {
+static Tensor& _half_to_fused8bitrowwise_cpu_out(
+    Tensor& output,
+    const Tensor& input) {
   return _float_to_fused8bitrowwise_cpu_out_t<fbgemm::float16>(output, input);
 }
 
@@ -285,6 +295,13 @@ Tensor fused8bitrowwise_to_half_cpu(const Tensor& input) {
   return fused8bitrowwise_to_half_cpu_out(output, input);
 }
 
+/// @ingroup quantize-data-cpu
+///
+Tensor fused8bitrowwise_to_bfloat16_cpu(const Tensor& input) {
+  auto output = at::empty({0}, input.options().dtype(at::kBFloat16));
+  return fused8bitrowwise_to_bfloat16_cpu_out(output, input);
+}
+
 /// @ingroup quantize-data-cpu
 ///
 Tensor fused8bitrowwise_to_float_or_half_cpu(
@@ -305,6 +322,10 @@ Tensor fused8bitrowwise_to_float_or_half_cpu(
       output = at::empty({0}, input.options().dtype(at::kHalf));
       output = fused8bitrowwise_to_half_cpu_out(output, input);
       break;
+    case SparseType::BF16:
+      output = at::empty({0}, input.options().dtype(at::kBFloat16));
+      output = fused8bitrowwise_to_bfloat16_cpu_out(output, input);
+      break;
     default:
       TORCH_CHECK(false);
   }
@@ -582,6 +603,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
       "FP8RowwiseQuantizedToFloat(Tensor input, bool forward, int output_dtype=0) -> Tensor",
       {PT2_COMPLIANT_TAG});
   m.def("Fused8BitRowwiseQuantizedToHalf(Tensor input) -> Tensor");
+  m.def("Fused8BitRowwiseQuantizedToBfloat16(Tensor input) -> Tensor");
   m.def(
       "Fused8BitRowwiseQuantizedToFloatOrHalf(Tensor input, int output_dtype=0, bool scale_bias_last=True, bool quant_padding_float_type=True) -> Tensor");
   m.def(
@@ -648,6 +670,9 @@ TORCH_LIBRARY_IMPL(fbgemm, CPU, m) {
   DISPATCH_TO_CPU(
       "Fused8BitRowwiseQuantizedToHalf",
       fbgemm_gpu::fused8bitrowwise_to_half_cpu);
+  DISPATCH_TO_CPU(
+      "Fused8BitRowwiseQuantizedToBfloat16",
+      fbgemm_gpu::fused8bitrowwise_to_bfloat16_cpu);
   DISPATCH_TO_CPU(
       "Fused8BitRowwiseQuantizedToFloatOrHalf",
       fbgemm_gpu::fused8bitrowwise_to_float_or_half_cpu);
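For context, here is a minimal sketch of a round trip through the new BF16 dequantization path. This is illustrative only: it assumes the two entry points shown in this diff (`_float_to_fused8bitrowwise_cpu_out` and `fused8bitrowwise_to_bfloat16_cpu`) are reachable from fbgemm_gpu's headers; the forward declarations below stand in for the real includes.

```cpp
// Sketch: float -> fused 8-bit rowwise -> bfloat16 round trip.
// These declarations mirror the signatures in this diff; in a real build
// they would come from the fbgemm_gpu headers (assumed, not verified here).
#include <ATen/ATen.h>

namespace fbgemm_gpu {
at::Tensor& _float_to_fused8bitrowwise_cpu_out(
    at::Tensor& output,
    const at::Tensor& input);
at::Tensor fused8bitrowwise_to_bfloat16_cpu(const at::Tensor& input);
} // namespace fbgemm_gpu

int main() {
  // Each row is quantized to uint8 values plus a per-row float scale and
  // bias, so the fused buffer is a byte tensor with 8 extra bytes per row.
  at::Tensor input = at::rand({4, 16}, at::kFloat);
  at::Tensor quantized = at::empty({0}, input.options().dtype(at::kByte));
  fbgemm_gpu::_float_to_fused8bitrowwise_cpu_out(quantized, input);

  // New in this diff: dequantize straight to bfloat16 rather than
  // converting to float or half first.
  at::Tensor dequant = fbgemm_gpu::fused8bitrowwise_to_bfloat16_cpu(quantized);
  TORCH_CHECK(dequant.scalar_type() == at::kBFloat16);
  return 0;
}
```

From Python, the same path is reachable through the operator registered above, i.e. `torch.ops.fbgemm.Fused8BitRowwiseQuantizedToBfloat16(input)`, or via `Fused8BitRowwiseQuantizedToFloatOrHalf` with `output_dtype` set to the BF16 sparse type.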