Skip to content

Commit 103dfbd

Browse files
committed
[BugFix]fix ring_mla invalid param
Signed-off-by: weiguihua2 <weiguihua2@huawei.com>
1 parent 1648906 commit 103dfbd

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

vllm_ascend/worker/model_runner_v1.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1326,7 +1326,7 @@ def _list_to_tensor(lst, device, dtype=torch.int32):
1326 1326
attn_mask_seqlens = torch.tensor([chunk_seqlens, chunk_seqlens], dtype=torch.int32)
1327 1327
head_attn_nomask_seqlens = torch.tensor([chunk_seqlens, kv_with_q_head_nomask_seqlens], dtype=torch.int32)
1328 1328
tail_attn_nomask_seqlens = torch.tensor([chunk_seqlens, kv_with_q_tail_nomask_seqlens], dtype=torch.int32)
1329 -
cp_prefill_mask = torch.triu(torch.ones(512, 512, device=self.device, dtype=torch.bfloat16), 1)
1329 +
cp_prefill_mask = torch.triu(torch.ones(512, 512, device=self.device, dtype=self.dtype), 1)
1330 1330

1331 1331
self.extra_long_seq_kwargs = {
1332 1332
'attn_mask_seqlens': attn_mask_seqlens,

0 commit comments

Comments
 (0)