Skip to content

Commit a29ff41

Browse files
committed
add pingpong buffer for b_frag
1 parent 81c704b commit a29ff41

File tree

1 file changed

+14
-22
lines changed

1 file changed

+14
-22
lines changed

custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h

Lines changed: 14 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ class Wint2xMmaMultistage :
209209
WarpTransformedFragmentA warp_frag_A_[2];
210210

211211
/// Pair of B fragments used to overlap shared memory loads and math instructions
212-
WarpLoadedFragmentB warp_loaded_frag_B_;
212+
WarpLoadedFragmentB warp_loaded_frag_B_[2];
213213
WarpTransformedFragmentB warp_frag_B_;
214214
};
215215

@@ -691,10 +691,10 @@ class Wint2xMmaMultistage :
691691
int warp_k_compute_offset_B = warp_mma_k % Base::kWarpGemmIterationsPerLoadForB;
692692
int warp_mma_k_for_B = warp_mma_k / Base::kWarpGemmIterationsPerLoadForB;
693693

694-
if (warp_k_compute_offset_B == Base::kWarpGemmIterationsPerLoadForB - 1) {
694+
if (warp_k_compute_offset_B == Base::kWarpGemmIterationsPerLoadForB - 1) {
695695
// Load the next warp-tile's B fragment from shared memory
696696
this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k_for_B + 1) % Base::kWarpGemmIterations);
697-
this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_);
697+
this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_[(warp_mma_k_for_B + 1) % 2]);
698698
++this->warp_tile_iterator_B_;
699699

700700
warp_dequantizer_.load(pipe_state.warp_frag_local_scale_);
@@ -718,6 +718,16 @@ class Wint2xMmaMultistage :
718718
// static_cast<int>(reg_uint8_ptr[14]), static_cast<int>(reg_uint8_ptr[15]),
719719
// sizeof_bits<typename PipeState::WarpLoadedFragmentB>::value / 8);
720720

721+
if (warp_k_compute_offset_B == 0) {
722+
warp_dequantizer_.dequantize(pipe_state.warp_frag_local_scale_,
723+
pipe_state.warp_frag_code_scale_,
724+
pipe_state.warp_frag_code_zp_,
725+
pipe_state.warp_frag_super_scale_,
726+
pipe_state.warp_loaded_frag_B_[warp_mma_k_for_B % 2],
727+
pipe_state.warp_frag_B_,
728+
(stage - Base::kStages + 2) * Shape::kK);
729+
}
730+
721731
if (Detail::kStagedAccumulation) {
722732
//CUTLASS_TRACE_DEVICE(" [MMa-kStagedAccumulation][stage=%d] warp_mma_k=%d, warp_k_compute_offset_B=%d", stage, warp_mma_k, warp_k_compute_offset_B);
723733
warp_mma_(
@@ -814,16 +824,6 @@ class Wint2xMmaMultistage :
814824
iterator_B.clear_mask(gemm_k_iterations == 0);
815825
quant_params_accessor_B_.clear_mask(mma_quant_args, gemm_k_iterations == 0);
816826
}
817-
818-
if (warp_k_compute_offset_B == Base::kWarpGemmIterationsPerLoadForB - 1) {
819-
warp_dequantizer_.dequantize(pipe_state.warp_frag_local_scale_,
820-
pipe_state.warp_frag_code_scale_,
821-
pipe_state.warp_frag_code_zp_,
822-
pipe_state.warp_frag_super_scale_,
823-
pipe_state.warp_loaded_frag_B_,
824-
pipe_state.warp_frag_B_,
825-
(stage - Base::kStages + 2) * Shape::kK);
826-
}
827827
}
828828
}
829829

@@ -861,7 +861,7 @@ class Wint2xMmaMultistage :
861861

862862
// Load first warp-tile's B fragment from shared memory
863863
this->warp_tile_iterator_B_.set_kgroup_index(0);
864-
this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_);
864+
this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_[0]);
865865
++this->warp_tile_iterator_B_;
866866

867867
#if 0
@@ -907,14 +907,6 @@ class Wint2xMmaMultistage :
907907
}
908908
#endif
909909

910-
warp_dequantizer_.dequantize(pipe_state.warp_frag_local_scale_,
911-
pipe_state.warp_frag_code_scale_,
912-
pipe_state.warp_frag_code_zp_,
913-
pipe_state.warp_frag_super_scale_,
914-
pipe_state.warp_loaded_frag_B_,
915-
pipe_state.warp_frag_B_,
916-
0);
917-
918910
#if 0
919911
if (TransformBAfterLDS::result_type::kElements == 64) {
920912
CUTLASS_TRACE_DEVICE(" TransformBAfterLDS::result_type::kElements: 64, %d bytes", sizeof_bits<typename TransformBAfterLDS::result_type>::value / 8);

0 commit comments

Comments
 (0)