Skip to content

Commit a07ff1b

Browse files
authored
fix unpacking error (#507)
key_value_memory_dict in MHA module is a tuple of kv_cache and conv1d
1 parent 49ddf83 commit a07ff1b

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

mamba_ssm/modules/mha.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,8 @@ def _update_kvcache_attention(self, q, kv, inference_params):
180180
).transpose(1, 2)
181181
else:
182182
batch = q.shape[0]
183-
kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
183+
kv_cache, _ = inference_params.key_value_memory_dict[self.layer_idx]
184+
kv_cache = kv_cache[:batch]
184185
cache_seqlens = (
185186
inference_params.lengths_per_sample[:batch]
186187
if inference_params.lengths_per_sample is not None

0 commit comments

Comments
 (0)