@@ -571,6 +571,97 @@ def _chunk_state_bwd_ddAcs_stable_kernel(
    tl.atomic_add(ddA_cumsum_ptrs + stride_ddA_cs_csize, ddA_cs, mask=offs_m < chunk_size - 1)


+@triton.autotune(
+    configs=[
+        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8),
+        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
+        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
+        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
+        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
+        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
+        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),
+        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),
+        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2),
+    ],
+    key=['hdim', 'dstate', 'chunk_size'],
+)
+@triton.jit
+def _chunk_state_varlen_kernel(
+    # Pointers to matrices
+    x_ptr, b_ptr, dt_ptr, dA_cumsum_ptr, chunk_states_ptr, cu_seqlens_ptr, states_ptr,
+    # Matrix dimensions
+    hdim, dstate, chunk_size,
+    seqlen, nheads_ngroups_ratio,
+    # Strides
+    stride_x_seqlen, stride_x_head, stride_x_hdim,
+    stride_b_seqlen, stride_b_head, stride_b_dstate,
+    stride_dt_chunk, stride_dt_head, stride_dt_csize,
+    stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
+    stride_chunk_states_chunk, stride_chunk_states_head, stride_chunk_states_hdim, stride_chunk_states_dstate,
+    stride_states_batch, stride_states_head, stride_states_hdim, stride_states_dstate,
+    # Meta-parameters
+    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
+):
+    pid_b = tl.program_id(axis=1)
+    pid_h = tl.program_id(axis=2)
+    num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
+    pid_m = tl.program_id(axis=0) // num_pid_n
+    pid_n = tl.program_id(axis=0) % num_pid_n
+    end_idx = tl.load(cu_seqlens_ptr + pid_b + 1)
+    pid_c = (end_idx - 1) // chunk_size
+    b_ptr += pid_c * chunk_size * stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head
+    x_ptr += pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
+    dt_ptr += pid_c * stride_dt_chunk + pid_h * stride_dt_head
+    dA_cumsum_ptr += pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
+    chunk_states_ptr += pid_c * stride_chunk_states_chunk + pid_h * stride_chunk_states_head
+
+    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+    offs_k = tl.arange(0, BLOCK_SIZE_K)
+    x_ptrs = x_ptr + (offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen)
+    b_ptrs = b_ptr + (offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen)
+    dt_ptrs = dt_ptr + offs_k * stride_dt_csize
+    dA_cs_last = tl.load(dA_cumsum_ptr + (end_idx - pid_c * chunk_size - 1) * stride_dA_cs_csize).to(tl.float32)
+    dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
+
+    chunk_size_limit = end_idx - pid_c * chunk_size
+    start_idx = tl.load(cu_seqlens_ptr + pid_b)
+    start_idx_cur = tl.maximum(start_idx - pid_c * chunk_size, 0)
+
+    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+    for k in range(0, chunk_size_limit, BLOCK_SIZE_K):
+        x = tl.load(x_ptrs, mask=(offs_m[:, None] < hdim) & (offs_k[None, :] < chunk_size_limit - k) & (offs_k[None, :] >= start_idx_cur - k), other=0.0)
+        b = tl.load(b_ptrs, mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < dstate) & (offs_k[:, None] >= start_idx_cur - k), other=0.0).to(tl.float32)
+        dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32)
+        dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32)
+        scale = tl.where((offs_k >= start_idx_cur - k) & (offs_k < chunk_size_limit - k),
+                         tl.exp((dA_cs_last - dA_cs_k)) * dt_k, 0.0)
+        b *= scale[:, None]
+        b = b.to(x_ptr.dtype.element_ty)
+        acc += tl.dot(x, b)
+        x_ptrs += BLOCK_SIZE_K * stride_x_seqlen
+        b_ptrs += BLOCK_SIZE_K * stride_b_seqlen
+        dt_ptrs += BLOCK_SIZE_K * stride_dt_csize
+        dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize
+
+    # If the sequence starts after the last chunk idx, we don't need to add the contribution from the last chunk
+    if start_idx < pid_c * chunk_size:
+        chunk_states_ptrs = chunk_states_ptr + (offs_m[:, None] * stride_chunk_states_hdim + offs_n[None, :] * stride_chunk_states_dstate)
+        chunk_states = tl.load(chunk_states_ptrs, mask=(offs_m[:, None] < hdim) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32)
+        # scale = tl.where(start_idx < pid_c * chunk_size, tl.exp(dA_cs_last), 0.0)
+        scale = tl.exp(dA_cs_last)
+        acc += chunk_states * scale
+
+    states = acc.to(states_ptr.dtype.element_ty)
+
+    states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head
+    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+    states_ptrs = states_ptr + (offs_m[:, None] * stride_states_hdim + offs_n[None, :] * stride_states_dstate)
+    c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate)
+    tl.store(states_ptrs, states, mask=c_mask)
+
+
def _chunk_cumsum_fwd(dt, A, chunk_size, dt_bias=None, dt_softplus=False, dt_limit=(0.0, float("inf"))):
    batch, seqlen, nheads = dt.shape
    assert A.shape == (nheads,)
@@ -790,6 +881,35 @@ def _chunk_state_bwd_ddAcs_stable(B, x, dt, dA_cumsum, dstates, seq_idx=None):
    return ddA_cumsum


+def chunk_state_varlen(B, x, dt, dA_cumsum, cu_seqlens, chunk_states):
+    total_seqlen, nheads, headdim = x.shape
+    _, nchunks, chunk_size = dt.shape
+    _, ngroups, dstate = B.shape
+    batch = cu_seqlens.shape[0] - 1
+    cu_seqlens = cu_seqlens.contiguous()
+    assert nheads % ngroups == 0
+    assert B.shape == (total_seqlen, ngroups, dstate)
+    assert dt.shape == (nheads, nchunks, chunk_size)
+    assert dA_cumsum.shape == dt.shape
+    assert chunk_states.shape == (nchunks, nheads, headdim, dstate)
+    states = torch.empty(batch, nheads, headdim, dstate, dtype=chunk_states.dtype, device=chunk_states.device)
+    grid = lambda META: (triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton.cdiv(dstate, META['BLOCK_SIZE_N']),
+                         batch, nheads)
+    with torch.cuda.device(x.device.index):
+        _chunk_state_varlen_kernel[grid](
+            x, B, dt, dA_cumsum, chunk_states, cu_seqlens, states,
+            headdim, dstate, chunk_size,
+            total_seqlen, nheads // ngroups,
+            x.stride(0), x.stride(1), x.stride(2),
+            B.stride(0), B.stride(1), B.stride(2),
+            dt.stride(1), dt.stride(0), dt.stride(2),
+            dA_cumsum.stride(1), dA_cumsum.stride(0), dA_cumsum.stride(2),
+            chunk_states.stride(0), chunk_states.stride(1), chunk_states.stride(2), chunk_states.stride(3),
+            states.stride(0), states.stride(1), states.stride(2), states.stride(3),
+        )
+    return states
+
+
class ChunkStateFn(torch.autograd.Function):

    @staticmethod
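For context, a minimal usage sketch of the new `chunk_state_varlen` entry point, assuming only the shape conventions enforced by its asserts. The sizes, dtypes, sequence lengths, and the way `dA_cumsum` is filled below are made-up placeholders for illustration, not values from this commit; in the real pipeline `dt` and `dA_cumsum` come from the chunked cumsum pass and `chunk_states` from the chunk-state forward.

```python
import torch

# Hypothetical sizes, chosen only for illustration
total_seqlen, nheads, headdim = 512, 8, 64
ngroups, dstate, chunk_size = 1, 16, 128
nchunks = (total_seqlen + chunk_size - 1) // chunk_size  # 4

x = torch.randn(total_seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
B = torch.randn(total_seqlen, ngroups, dstate, device="cuda", dtype=torch.float16)
dt = torch.rand(nheads, nchunks, chunk_size, device="cuda", dtype=torch.float32)
# Placeholder values; the real dA_cumsum is produced by _chunk_cumsum_fwd
dA_cumsum = torch.cumsum(-dt, dim=-1)
chunk_states = torch.randn(nchunks, nheads, headdim, dstate, device="cuda", dtype=torch.float32)

# Two packed sequences of lengths 200 and 312; cu_seqlens holds batch + 1 boundaries
cu_seqlens = torch.tensor([0, 200, 512], device="cuda", dtype=torch.int32)

states = chunk_state_varlen(B, x, dt, dA_cumsum, cu_seqlens, chunk_states)
print(states.shape)  # (2, nheads, headdim, dstate): one final state per packed sequence
```

The kernel locates each sequence's last chunk via `cu_seqlens`, accumulates the scaled outer products of `x` and `B` over the tokens of that chunk that belong to the sequence, and adds the carried-over `chunk_states` contribution only when the sequence started before that chunk.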