@@ -209,7 +209,7 @@ class Wint2xMmaMultistage :
209
209
WarpTransformedFragmentA warp_frag_A_[2 ];
210
210
211
211
/// Pair of B fragments used to overlap shared memory loads and math instructions
212
- WarpLoadedFragmentB warp_loaded_frag_B_;
212
+ WarpLoadedFragmentB warp_loaded_frag_B_[ 2 ] ;
213
213
WarpTransformedFragmentB warp_frag_B_;
214
214
};
215
215
@@ -691,10 +691,10 @@ class Wint2xMmaMultistage :
691
691
int warp_k_compute_offset_B = warp_mma_k % Base::kWarpGemmIterationsPerLoadForB ;
692
692
int warp_mma_k_for_B = warp_mma_k / Base::kWarpGemmIterationsPerLoadForB ;
693
693
694
- if (warp_k_compute_offset_B == Base::kWarpGemmIterationsPerLoadForB - 1 ) {
694
+ if (warp_k_compute_offset_B == Base::kWarpGemmIterationsPerLoadForB - 1 ) {
695
695
// Load the next warp-tile's B fragment from shared memory
696
696
this ->warp_tile_iterator_B_ .set_kgroup_index ((warp_mma_k_for_B + 1 ) % Base::kWarpGemmIterations );
697
- this ->warp_tile_iterator_B_ .load (pipe_state.warp_loaded_frag_B_ );
697
+ this ->warp_tile_iterator_B_ .load (pipe_state.warp_loaded_frag_B_ [(warp_mma_k_for_B + 1 ) % 2 ] );
698
698
++this ->warp_tile_iterator_B_ ;
699
699
700
700
warp_dequantizer_.load (pipe_state.warp_frag_local_scale_ );
@@ -718,6 +718,16 @@ class Wint2xMmaMultistage :
718
718
// static_cast<int>(reg_uint8_ptr[14]), static_cast<int>(reg_uint8_ptr[15]),
719
719
// sizeof_bits<typename PipeState::WarpLoadedFragmentB>::value / 8);
720
720
721
+ if (warp_k_compute_offset_B == 0 ) {
722
+ warp_dequantizer_.dequantize (pipe_state.warp_frag_local_scale_ ,
723
+ pipe_state.warp_frag_code_scale_ ,
724
+ pipe_state.warp_frag_code_zp_ ,
725
+ pipe_state.warp_frag_super_scale_ ,
726
+ pipe_state.warp_loaded_frag_B_ [warp_mma_k_for_B % 2 ],
727
+ pipe_state.warp_frag_B_ ,
728
+ (stage - Base::kStages + 2 ) * Shape::kK );
729
+ }
730
+
721
731
if (Detail::kStagedAccumulation ) {
722
732
// CUTLASS_TRACE_DEVICE(" [MMa-kStagedAccumulation][stage=%d] warp_mma_k=%d, warp_k_compute_offset_B=%d", stage, warp_mma_k, warp_k_compute_offset_B);
723
733
warp_mma_ (
@@ -814,16 +824,6 @@ class Wint2xMmaMultistage :
814
824
iterator_B.clear_mask (gemm_k_iterations == 0 );
815
825
quant_params_accessor_B_.clear_mask (mma_quant_args, gemm_k_iterations == 0 );
816
826
}
817
-
818
- if (warp_k_compute_offset_B == Base::kWarpGemmIterationsPerLoadForB - 1 ) {
819
- warp_dequantizer_.dequantize (pipe_state.warp_frag_local_scale_ ,
820
- pipe_state.warp_frag_code_scale_ ,
821
- pipe_state.warp_frag_code_zp_ ,
822
- pipe_state.warp_frag_super_scale_ ,
823
- pipe_state.warp_loaded_frag_B_ ,
824
- pipe_state.warp_frag_B_ ,
825
- (stage - Base::kStages + 2 ) * Shape::kK );
826
- }
827
827
}
828
828
}
829
829
@@ -861,7 +861,7 @@ class Wint2xMmaMultistage :
861
861
862
862
// Load first warp-tile's B fragment from shared memory
863
863
this ->warp_tile_iterator_B_ .set_kgroup_index (0 );
864
- this ->warp_tile_iterator_B_ .load (pipe_state.warp_loaded_frag_B_ );
864
+ this ->warp_tile_iterator_B_ .load (pipe_state.warp_loaded_frag_B_ [ 0 ] );
865
865
++this ->warp_tile_iterator_B_ ;
866
866
867
867
#if 0
@@ -907,14 +907,6 @@ class Wint2xMmaMultistage :
907
907
}
908
908
#endif
909
909
910
- warp_dequantizer_.dequantize (pipe_state.warp_frag_local_scale_ ,
911
- pipe_state.warp_frag_code_scale_ ,
912
- pipe_state.warp_frag_code_zp_ ,
913
- pipe_state.warp_frag_super_scale_ ,
914
- pipe_state.warp_loaded_frag_B_ ,
915
- pipe_state.warp_frag_B_ ,
916
- 0 );
917
-
918
910
#if 0
919
911
if (TransformBAfterLDS::result_type::kElements == 64) {
920
912
CUTLASS_TRACE_DEVICE(" TransformBAfterLDS::result_type::kElements: 64, %d bytes", sizeof_bits<typename TransformBAfterLDS::result_type>::value / 8);
0 commit comments