File tree Expand file tree Collapse file tree 1 file changed +3
-7
lines changed Expand file tree Collapse file tree 1 file changed +3
-7
lines changed Original file line number Diff line number Diff line change 17
17
18
18
def _generate_attn_mask (max_seq_len , dtype ):
19
19
# Construct lower triangle matrix.
20
- mask_flag = torch .ones (
21
- (max_seq_len , max_seq_len ),
22
- dtype = torch .bool ).tril_ ().view (max_seq_len , max_seq_len )
20
+ mask_flag = torch .ones ((max_seq_len , max_seq_len ),
21
+ dtype = torch .bool ).tril_ ()
23
22
# Create upper triangle matrix used to mark mask positions.
24
23
mask_flag = ~ mask_flag
25
24
# Currently for fp16 dtype, the mask value should be set to -inf.
26
25
# TODO: Eliminate this part in the future.
27
- if dtype == torch .float16 :
28
- mask_value = torch .finfo (torch .float32 ).min
29
- else :
30
- mask_value = 1
26
+ mask_value = float ('-inf' ) if dtype == torch .float16 else 1
31
27
attn_mask = torch .zeros (size = (max_seq_len , max_seq_len ), dtype = dtype ) \
32
28
.masked_fill_ (mask_flag , mask_value )
33
29
return attn_mask
You can’t perform that action at this time.
0 commit comments