Commit d40eb6c

Fix prompt tuning after #464 (#501)
Unfortunately, running inference in models with `"ptune" in config.tuning_mode` was broken after #464.
1 parent dd4a323 commit d40eb6c

File tree

4 files changed, +12 -8 lines:

src/petals/client/remote_generation.py
src/petals/models/bloom/model.py
src/petals/models/falcon/model.py
src/petals/models/llama/model.py


src/petals/client/remote_generation.py

Lines changed: 3 additions & 2 deletions
@@ -87,10 +87,11 @@ def generate(
             max_new_tokens is None
         ), "You should set `max_length` or `max_new_tokens` (but not both) to reserve server-side attention caches"

+        session_max_length = self.transformer.config.pre_seq_len
         if max_length is not None:
-            session_max_length = max_length
+            session_max_length += max_length
         else:
-            session_max_length = (inputs.shape[1] if inputs is not None else 0) + max_new_tokens
+            session_max_length += (inputs.shape[1] if inputs is not None else 0) + max_new_tokens
         context_manager = self.inference_session(max_length=session_max_length)

         with context_manager as session:
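
As a standalone illustration of the corrected budget above (a minimal sketch, not part of the commit; `compute_session_max_length` is a hypothetical helper, and it assumes `config.pre_seq_len` is 0 when prompt tuning is disabled, so the extra term is a no-op in that case):

def compute_session_max_length(pre_seq_len, input_len, max_length=None, max_new_tokens=None):
    # Always start from the number of trainable prompt tokens, so the
    # server-side attention cache also has room for the p-tuning prefix.
    session_max_length = pre_seq_len
    if max_length is not None:
        session_max_length += max_length
    else:
        session_max_length += input_len + max_new_tokens
    return session_max_length

# e.g. a 16-token prefix, a 10-token input, and 50 new tokens reserve 76 positions,
# whereas the code before this commit reserved only 60.
assert compute_session_max_length(16, 10, max_new_tokens=50) == 76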

src/petals/models/bloom/model.py

Lines changed: 3 additions & 2 deletions
@@ -71,7 +71,8 @@ def forward(
         if inputs_embeds is None:
             inputs_embeds = self.word_embeddings(input_ids)

-        if self.config.tuning_mode and "ptune" in self.config.tuning_mode and self.h.position == 0:
+        use_prompts = self.config.tuning_mode and "ptune" in self.config.tuning_mode and self.h.position == 0
+        if use_prompts:
             batch_size = inputs_embeds.shape[0]
             prompts, intermediate_prompts = self.get_prompt(batch_size)
             inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
@@ -88,7 +89,7 @@ def forward(
         )

         # Remove prefix
-        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
+        if use_prompts:
             hidden_states = hidden_states[:, self.pre_seq_len :]

         # Add last hidden state
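
The same change is repeated for the Falcon and LLaMA models below. As a rough, self-contained sketch of the pattern (the helper name `apply_prompt_tuning` and the toy identity backbone are illustrative, not Petals API): the condition is evaluated once, so the prefix is stripped only when it was actually prepended.

import torch

def apply_prompt_tuning(inputs_embeds, prompts, tuning_mode, first_block_position, backbone):
    # Decide once whether trainable prompts apply (only the client holding
    # the first block prepends them), then reuse the flag on both sides.
    use_prompts = bool(tuning_mode) and "ptune" in tuning_mode and first_block_position == 0
    if use_prompts:
        inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)  # prepend the learned prefix
    hidden_states = backbone(inputs_embeds)
    if use_prompts:  # strip the prefix only if it was added above
        hidden_states = hidden_states[:, prompts.shape[1]:]
    return hidden_states

# Toy check with an identity "backbone": the output keeps the original sequence length.
embeds = torch.zeros(1, 10, 8)
prefix = torch.zeros(1, 16, 8)
assert apply_prompt_tuning(embeds, prefix, "ptune", 0, lambda h: h).shape == (1, 10, 8)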

src/petals/models/falcon/model.py

Lines changed: 3 additions & 2 deletions
@@ -77,7 +77,8 @@ def forward(
         if inputs_embeds is None:
             inputs_embeds = self.word_embeddings(input_ids)

-        if self.config.tuning_mode and "ptune" in self.config.tuning_mode and self.h.position == 0:
+        use_prompts = self.config.tuning_mode and "ptune" in self.config.tuning_mode and self.h.position == 0
+        if use_prompts:
             batch_size = inputs_embeds.shape[0]
             prompts, intermediate_prompts = self.get_prompt(batch_size)
             inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
@@ -94,7 +95,7 @@ def forward(
         )

         # Remove prefix
-        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
+        if use_prompts:
             hidden_states = hidden_states[:, self.pre_seq_len :]

         # Add last hidden state

src/petals/models/llama/model.py

Lines changed: 3 additions & 2 deletions
@@ -73,7 +73,8 @@ def forward(
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)

-        if self.config.tuning_mode and "ptune" in self.config.tuning_mode and self.layers.position == 0:
+        use_prompts = self.config.tuning_mode and "ptune" in self.config.tuning_mode and self.layers.position == 0
+        if use_prompts:
             batch_size = inputs_embeds.shape[0]
             prompts, intermediate_prompts = self.get_prompt(batch_size)
             inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
@@ -90,7 +91,7 @@ def forward(
         )

         # Remove prefix
-        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
+        if use_prompts:
             hidden_states = hidden_states[:, self.pre_seq_len :]

         # Add last hidden state
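
Taken together, the four changes make client-side generation consistent with prompt tuning again. A hedged end-to-end usage sketch follows: the class name `AutoDistributedModelForCausalLM` and the placeholder checkpoint are assumptions based on Petals' public examples, while `tuning_mode`, `pre_seq_len`, and `max_new_tokens` correspond to the attributes used in the diffs above.

import torch
from transformers import AutoTokenizer
from petals import AutoDistributedModelForCausalLM  # assumed import path from Petals examples

model_name = "bigscience/bloom-560m"  # placeholder checkpoint, not a recommendation
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoDistributedModelForCausalLM.from_pretrained(
    model_name,
    tuning_mode="ptune",  # enables the "ptune" branches patched above
    pre_seq_len=16,       # 16 trainable prompt tokens per sequence
)

inputs = tokenizer("A cat sat on", return_tensors="pt")["input_ids"]
with torch.inference_mode():
    # With the fix, the inference session reserves pre_seq_len extra positions
    # for the prompt prefix in addition to the input and generated tokens.
    outputs = model.generate(inputs, max_new_tokens=8)
print(tokenizer.decode(outputs[0]))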
