@@ -44,6 +44,7 @@ def __init__(
 
         self._seq_len_cached = attn_mask.shape[0]
         self.attn_mask_cache = attn_mask
+        self.chunked_prefill_attn_mask = torch.triu(torch.ones(2048, 2048), diagonal=1).to(torch.int8)
 
     @staticmethod
     def get_mask_scale_factor(dtype: torch.dtype = torch.float16):
@@ -66,24 +67,9 @@ def get_attn_mask(self, max_seq_len: int, dtype: torch.dtype,
 
     def get_splitfuse_attn_mask(
         self,
-        seq_lens: torch.Tensor,
-        position: torch.Tensor,
-        dtype: torch.dtype,
         device: torch.device,
     ) -> torch.Tensor:
-        if dtype not in [torch.float16, torch.bfloat16]:
-            raise ValueError(
-                "splitfuse_attn_mask now only supports bf16 and fp16")
-        max_seq_len = max(seq_lens, default=0)
-        self._update_attn_cache(max_seq_len, dtype)
-        # FIXME: Currently the mask value of chunked-prefill situation and Prefill-Only situation
-        # is not the same. Fix this in the future when kernel is ready.
-        mask_scale_factor = AttentionMaskBuilder.get_mask_scale_factor(dtype)
-        attn_mask = torch.index_select(self.attn_mask_cache,
-                                       dim=0,
-                                       index=position)[:, :max_seq_len]
-        attn_mask *= mask_scale_factor
-        return attn_mask.contiguous().to(device, non_blocking=True)
+        return self.chunked_prefill_attn_mask.to(device)
 
     def _update_attn_cache(self, seqlen: int, dtype: torch.dtype):
         if seqlen > self._seq_len_cached:
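Note (illustrative sketch, not part of the diff): the change replaces the per-request, position-indexed, dtype-scaled splitfuse mask with a single precomputed 2048x2048 upper-triangular int8 mask that get_splitfuse_attn_mask now returns directly. The standalone sketch below shows how such a mask is built and how a caller might slice a causal block from it; the names chunked_prefill_attn_mask and the 2048 cap come from the diff, while q_start/q_end/kv_len and the slicing are hypothetical and only demonstrate the mask's contents, not the kernel's actual interface.

    import torch

    # Precompute once: 1 strictly above the diagonal (masked), 0 on/below it (visible).
    chunked_prefill_attn_mask = torch.triu(torch.ones(2048, 2048),
                                           diagonal=1).to(torch.int8)

    # Hypothetical use: a chunk of query tokens at positions [q_start, q_end)
    # attending to the first kv_len tokens takes the corresponding block.
    q_start, q_end, kv_len = 4, 8, 8
    mask_block = chunked_prefill_attn_mask[q_start:q_end, :kv_len]
    print(mask_block)
    # tensor([[0, 0, 0, 0, 0, 1, 1, 1],
    #         [0, 0, 0, 0, 0, 0, 1, 1],
    #         [0, 0, 0, 0, 0, 0, 0, 1],
    #         [0, 0, 0, 0, 0, 0, 0, 0]], dtype=torch.int8)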