File tree Expand file tree Collapse file tree 1 file changed +4
-0
lines changed Expand file tree Collapse file tree 1 file changed +4
-0
lines changed Original file line number Diff line number Diff line change @@ -47,6 +47,7 @@ def forward(
47
47
input_ids : Optional [torch .LongTensor ] = None ,
48
48
past_key_values : Optional [RemotePastKeyValues ] = None ,
49
49
attention_mask : Optional [torch .Tensor ] = None ,
50
+ position_ids : Optional [torch .LongTensor ] = None ,
50
51
head_mask : Optional [torch .LongTensor ] = None ,
51
52
inputs_embeds : Optional [torch .LongTensor ] = None ,
52
53
use_cache : Optional [bool ] = None ,
@@ -68,6 +69,9 @@ def forward(
68
69
assert (
69
70
attention_mask is None or (attention_mask == 1 ).all ()
70
71
), f"Custom attention masks are not supported, { attention_mask = } "
72
+ assert (
73
+ position_ids is None or (position_ids [:, 1 :] - position_ids [:, :- 1 ] == 1 ).all ()
74
+ ), f"Non-consecutive position_ids are not supported, { position_ids = } "
71
75
assert head_mask is None , f"Custom head masks are not supported, { head_mask = } "
72
76
assert use_cache is None or use_cache , f"{ use_cache = } is not supported"
73
77
assert not output_attentions , f"{ output_attentions = } is not supported"
You can’t perform that action at this time.
0 commit comments