@@ -588,6 +588,26 @@ at::Tensor f8i4bf16_rowwise_meta(
588
588
return Y;
589
589
}
590
590
591
+ std::tuple<at::Tensor, at::Tensor> preshuffle_i4_meta (
592
+ at::Tensor WQ,
593
+ at::Tensor w_scale) {
594
+ return {
595
+ at::empty_like (WQ),
596
+ at::empty ({w_scale.size (0 ), 8 , w_scale.size (1 )}, w_scale.options ())};
597
+ }
598
+
599
+ at::Tensor f8i4bf16_shuffled_meta (
600
+ at::Tensor XQ, // FP8
601
+ at::Tensor WQ, // INT4
602
+ at::Tensor /* x_scale */ ,
603
+ at::Tensor /* w_scale */ ,
604
+ at::Tensor /* w_scale_group */ ) {
605
+ const at::SymInt M = XQ.sym_size (0 );
606
+ const at::SymInt N = WQ.sym_size (0 );
607
+ auto Y = at::empty_symint ({M, N}, XQ.options ().dtype (at::kBFloat16 ));
608
+ return Y;
609
+ }
610
+
591
611
at::Tensor bf16i4bf16_rowwise_meta (
592
612
at::Tensor X, // BF16
593
613
at::Tensor W, // INT4
@@ -723,6 +743,8 @@ TORCH_LIBRARY_IMPL(fbgemm, Meta, m) {
723
743
m.impl (" bf16i4bf16_rowwise_batched" , bf16i4bf16_rowwise_batched_meta);
724
744
m.impl (" f8f8bf16_lite" , f8f8bf16_lite_meta);
725
745
m.impl (" scaled_fp4_quant" , scaled_fp4_quant_meta);
746
+ m.impl (" preshuffle_i4" , preshuffle_i4_meta);
747
+ m.impl (" f8i4bf16_shuffled" , f8i4bf16_shuffled_meta);
726
748
#endif
727
749
#ifdef USE_ROCM
728
750
m.impl (" f8f8f16_rowwise" , f8f8f16_rowwise_meta);
0 commit comments