@@ -873,7 +873,6 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
         sequence_length: int,
         target_length: int,
         dtype: torch.dtype,
-        device: torch.device,
         cache_position: torch.Tensor,
         batch_size: int,
         **kwargs,
@@ -906,16 +905,18 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
         else:
             min_dtype = torch.finfo(dtype).min
             causal_mask = torch.full(
-                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
+                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
             )
             if sequence_length != 1:
                 causal_mask = torch.triu(causal_mask, diagonal=1)
-            causal_mask *= torch.arange(target_length, device=device) > cache_position.to(device).reshape(-1, 1)
+            causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
             causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
             if attention_mask is not None:
                 causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
                 mask_length = attention_mask.shape[-1]
-                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(device)
+                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
+                    cache_position.device
+                )
                 padding_mask = padding_mask == 0
                 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                     padding_mask, min_dtype
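
For reference, below is a minimal, self-contained sketch of the mask construction after this change, with the device taken from cache_position rather than a separate device argument. The wrapper name build_causal_mask_sketch, the parameter order, and the example tensors are illustrative assumptions for demonstration only, not the actual signature of _prepare_4d_causal_attention_mask_with_cache_position.

import torch


def build_causal_mask_sketch(
    attention_mask,                 # optional (batch, kv_len) padding mask, or None (assumed parameter)
    sequence_length: int,           # number of query positions in this step
    target_length: int,             # total key/value length (cache + current step)
    dtype: torch.dtype,
    cache_position: torch.Tensor,   # 1D absolute positions of the current queries
    batch_size: int,
):
    min_dtype = torch.finfo(dtype).min
    # Device is inferred from cache_position instead of being passed explicitly.
    causal_mask = torch.full(
        (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
    )
    if sequence_length != 1:
        causal_mask = torch.triu(causal_mask, diagonal=1)
    # Mask out key positions that lie beyond each query's cache position.
    causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
    causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
    if attention_mask is not None:
        causal_mask = causal_mask.clone()  # contiguous copy for the in-place edit below
        mask_length = attention_mask.shape[-1]
        padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
            cache_position.device
        )
        padding_mask = padding_mask == 0
        causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
            padding_mask, min_dtype
        )
    return causal_mask


# Example call with assumed shapes: 4 query tokens, kv length 6, first key position padded.
mask = build_causal_mask_sketch(
    attention_mask=torch.tensor([[0, 1, 1, 1, 1, 1]]),
    sequence_length=4,
    target_length=6,
    dtype=torch.float32,
    cache_position=torch.arange(2, 6),
    batch_size=1,
)
print(mask.shape)  # torch.Size([1, 1, 4, 6])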