@@ -581,6 +581,26 @@ at::Tensor f8i4bf16_rowwise_meta(
   return Y;
 }

+std::tuple<at::Tensor, at::Tensor> preshuffle_i4_meta(
+    at::Tensor WQ,
+    at::Tensor w_scale) {
+  return {
+      at::empty_like(WQ),
+      at::empty({w_scale.size(0), 8, w_scale.size(1)}, w_scale.options())};
+}
+
+at::Tensor f8i4bf16_shuffled_meta(
+    at::Tensor XQ, // FP8
+    at::Tensor WQ, // INT4
+    at::Tensor /* x_scale */,
+    at::Tensor /* w_scale */,
+    at::Tensor /* w_scale_group */) {
+  const at::SymInt M = XQ.sym_size(0);
+  const at::SymInt N = WQ.sym_size(0);
+  auto Y = at::empty_symint({M, N}, XQ.options().dtype(at::kBFloat16));
+  return Y;
+}
+
 at::Tensor bf16i4bf16_rowwise_meta(
     at::Tensor X, // BF16
     at::Tensor W, // INT4
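
The two functions added above are Meta-dispatch ("fake") implementations: they only allocate empty outputs with the correct shapes and dtypes, so FakeTensor tracing and torch.compile can propagate shapes without running the real kernels. A minimal sketch of how preshuffle_i4 could be exercised through this registration follows; the import path, tensor shapes, and dtypes are assumptions for illustration, not part of the diff.

import torch

# Assumed import path for the FBGEMM GenAI ops; adjust to your build.
import fbgemm_gpu.experimental.gen_ai  # noqa: F401

# Illustrative shapes: WQ is a packed INT4 weight stored as int8.
WQ = torch.empty(128, 64, dtype=torch.int8, device="meta")
w_scale = torch.empty(16, 128, dtype=torch.bfloat16, device="meta")

# The Meta kernel returns (empty_like(WQ), empty({S0, 8, S1})), so only shapes flow.
shuffled_wq, shuffled_scale = torch.ops.fbgemm.preshuffle_i4(WQ, w_scale)
assert shuffled_wq.shape == WQ.shape
assert shuffled_scale.shape == (w_scale.size(0), 8, w_scale.size(1))
assert shuffled_wq.is_meta and shuffled_scale.is_meta
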
@@ -702,6 +722,8 @@ TORCH_LIBRARY_IMPL(fbgemm, Meta, m) {
   m.impl("bf16i4bf16_rowwise_batched", bf16i4bf16_rowwise_batched_meta);
   m.impl("f8f8bf16_lite", f8f8bf16_lite_meta);
   m.impl("scaled_fp4_quant", scaled_fp4_quant_meta);
+  m.impl("preshuffle_i4", preshuffle_i4_meta);
+  m.impl("f8i4bf16_shuffled", f8i4bf16_shuffled_meta);
 #endif
 #ifdef USE_ROCM
   m.impl("f8f8f16_rowwise", f8f8f16_rowwise_meta);