Commit ae19b65

Add position_ids argument to DistributedFalconModel (#525)
1 parent 1d9401d commit ae19b65

1 file changed

src/petals/models/falcon/model.py

Lines changed: 4 additions & 0 deletions
@@ -47,6 +47,7 @@ def forward(
         input_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[RemotePastKeyValues] = None,
         attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
         head_mask: Optional[torch.LongTensor] = None,
         inputs_embeds: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
@@ -68,6 +69,9 @@ def forward(
         assert (
             attention_mask is None or (attention_mask == 1).all()
         ), f"Custom attention masks are not supported, {attention_mask=}"
+        assert (
+            position_ids is None or (position_ids[:, 1:] - position_ids[:, :-1] == 1).all()
+        ), f"Non-consecutive position_ids are not supported, {position_ids=}"
         assert head_mask is None, f"Custom head masks are not supported, {head_mask=}"
         assert use_cache is None or use_cache, f"{use_cache=} is not supported"
         assert not output_attentions, f"{output_attentions=} is not supported"
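
The new assertion only accepts position_ids that increase by exactly 1 along the sequence dimension, i.e. the standard positions produced during ordinary left-to-right generation. Below is a minimal sketch of that check in isolation, using only torch; the example tensors are illustrative and not part of the commit.

import torch

# Consecutive positions along the sequence dimension pass the check, e.g. [[3, 4, 5, 6]]:
ok = torch.arange(3, 7).unsqueeze(0)   # shape (batch=1, seq_len=4)
assert (ok[:, 1:] - ok[:, :-1] == 1).all()

# Any gap makes some pairwise difference != 1, so the assertion in forward() would fail:
bad = torch.tensor([[0, 2, 3, 5]])
assert not (bad[:, 1:] - bad[:, :-1] == 1).all()

Since only pairwise differences are checked, positions may start at an arbitrary offset, but gaps or reorderings are rejected, consistent with the existing restrictions on custom attention masks and head masks.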
