4 files changed: +12 −8 lines

@@ -87,10 +87,11 @@ def generate(
             max_new_tokens is None
         ), "You should set `max_length` or `max_new_tokens` (but not both) to reserve server-side attention caches"

+        session_max_length = self.transformer.config.pre_seq_len
         if max_length is not None:
-            session_max_length = max_length
+            session_max_length += max_length
         else:
-            session_max_length = (inputs.shape[1] if inputs is not None else 0) + max_new_tokens
+            session_max_length += (inputs.shape[1] if inputs is not None else 0) + max_new_tokens
         context_manager = self.inference_session(max_length=session_max_length)

         with context_manager as session:
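
The hunk above changes the session length to start from `config.pre_seq_len`, so the server-side attention caches also reserve room for the trainable prompt tokens that p-tuning prepends. A minimal sketch of the corrected arithmetic, assuming `pre_seq_len` is 0 when prompt tuning is disabled (the standalone helper and its argument names are hypothetical, not part of the diff):

```python
def compute_session_max_length(pre_seq_len, max_length, max_new_tokens, n_input_tokens):
    # The cache must hold the prepended prompt tokens *plus* the visible sequence.
    session_max_length = pre_seq_len  # 0 when prompt tuning is off
    if max_length is not None:
        session_max_length += max_length
    else:
        session_max_length += n_input_tokens + max_new_tokens
    return session_max_length

# E.g. with 16 prompt tokens, a 10-token input, and max_new_tokens=50,
# the session must reserve 16 + 10 + 50 = 76 cache slots, not 60.
assert compute_session_max_length(16, None, 50, 10) == 76
assert compute_session_max_length(16, 100, None, 0) == 116
```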
@@ -71,7 +71,8 @@ def forward(
         if inputs_embeds is None:
             inputs_embeds = self.word_embeddings(input_ids)

-        if self.config.tuning_mode and "ptune" in self.config.tuning_mode and self.h.position == 0:
+        use_prompts = self.config.tuning_mode and "ptune" in self.config.tuning_mode and self.h.position == 0
+        if use_prompts:
             batch_size = inputs_embeds.shape[0]
             prompts, intermediate_prompts = self.get_prompt(batch_size)
             inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
@@ -88,7 +89,7 @@ def forward(
         )

         # Remove prefix
-        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
+        if use_prompts:
             hidden_states = hidden_states[:, self.pre_seq_len :]

         # Add last hidden state
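
Each model file gets the same fix: the gating condition is evaluated once into `use_prompts` and reused, so the prompt prefix is stripped if and only if it was actually prepended. A minimal sketch of the hoisted condition (the helper `forward_sketch` and its flat argument names are hypothetical; `position == 0` is assumed to mean this client holds the first transformer blocks):

```python
import torch

def forward_sketch(hidden_states, prompts, tuning_mode, position, pre_seq_len):
    # Evaluate the gate once so the strip below mirrors the prepend exactly.
    use_prompts = bool(tuning_mode) and "ptune" in tuning_mode and position == 0
    if use_prompts:
        # Prepend the trainable prompt embeddings (first blocks only).
        hidden_states = torch.cat([prompts, hidden_states], dim=1)
    # ... the transformer blocks would run on `hidden_states` here ...
    if use_prompts:  # the old code re-checked tuning_mode alone here
        hidden_states = hidden_states[:, pre_seq_len:]
    return hidden_states
```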
@@ -77,7 +77,8 @@ def forward(
         if inputs_embeds is None:
             inputs_embeds = self.word_embeddings(input_ids)

-        if self.config.tuning_mode and "ptune" in self.config.tuning_mode and self.h.position == 0:
+        use_prompts = self.config.tuning_mode and "ptune" in self.config.tuning_mode and self.h.position == 0
+        if use_prompts:
             batch_size = inputs_embeds.shape[0]
             prompts, intermediate_prompts = self.get_prompt(batch_size)
             inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
@@ -94,7 +95,7 @@ def forward(
         )

         # Remove prefix
-        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
+        if use_prompts:
             hidden_states = hidden_states[:, self.pre_seq_len :]

         # Add last hidden state
@@ -73,7 +73,8 @@ def forward(
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)

-        if self.config.tuning_mode and "ptune" in self.config.tuning_mode and self.layers.position == 0:
+        use_prompts = self.config.tuning_mode and "ptune" in self.config.tuning_mode and self.layers.position == 0
+        if use_prompts:
             batch_size = inputs_embeds.shape[0]
             prompts, intermediate_prompts = self.get_prompt(batch_size)
             inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
@@ -90,7 +91,7 @@ def forward(
         )

         # Remove prefix
-        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
+        if use_prompts:
             hidden_states = hidden_states[:, self.pre_seq_len :]

         # Add last hidden state
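
A self-contained toy check of the bug these hunks fix, under the same assumption about `position` (all sizes hypothetical): before the change, any stage with "ptune" enabled stripped `pre_seq_len` tokens, even when no prompts had been prepended.

```python
import torch

tuning_mode, position, pre_seq_len = "ptune", 1, 3  # position=1: not the first blocks
hidden_states = torch.randn(2, 5, 8)  # (batch, seq, hidden); no prompts prepended here

# Old gate: tuning_mode alone -> wrongly drops 3 real tokens on this stage.
old_out = hidden_states[:, pre_seq_len:] if "ptune" in tuning_mode else hidden_states
assert old_out.shape[1] == 2

# New gate: identical to the prepend condition -> sequence left intact.
use_prompts = bool(tuning_mode) and "ptune" in tuning_mode and position == 0
new_out = hidden_states[:, pre_seq_len:] if use_prompts else hidden_states
assert new_out.shape[1] == 5
```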