@@ -95,8 +95,10 @@ def model_input_split_v1_mla_attn(
95
95
[num_prefills_pre , num_prefills_post
96
96
] = split_attn_int_type (attn_metadata .num_prefills ,
97
97
max (0 , seq_index - attn_metadata .num_decodes ))
98
- seq_lens = attn_metadata .prefill .seq_lens if attn_metadata .num_prefills > 0 else attn_metadata .decode .seq_lens
99
- [seq_lens_pre , seq_lens_post ] = split_attn_tensor_type (seq_lens , seq_index )
98
+ seq_lens = attn_metadata .seq_lens if attn_metadata .num_prefills > 0 else attn_metadata .decode .seq_lens
99
+ [seq_lens_pre , seq_lens_post
100
+ ] = split_attn_tensor_type (seq_lens ,
101
+ max (0 , seq_index - attn_metadata .num_decodes ))
100
102
101
103
query_start_loc_pre = query_start_loc_post = None
102
104
if attn_metadata .query_start_loc is not None :
@@ -153,7 +155,7 @@ def model_input_split_v1_mla_attn(
153
155
attn_metadata .num_decodes :]
154
156
) - attn_metadata .prefill .query_start_loc [seq_index -
155
157
attn_metadata .num_decodes ]
156
- context_len_pre = seq_lens_pre [ attn_metadata . num_decodes :]
158
+ context_len_pre = seq_lens_pre
157
159
context_len_post = seq_lens_post
158
160
prefill_max_query_len_pre = max (prefill_query_lens_pre )
159
161
prefill_max_query_len_post = max (prefill_query_lens_post )
0 commit comments