@@ -192,7 +192,7 @@ def process_eval_set(self):
         ids = prompts
         if isinstance(ids, list) and len(ids) == 1:
             ids = ids[0].unsqueeze(0)
-        extra_generation_kwargs = None
+        extra_generation_kwargs = {}
 
         self.extra_generation_kwargs = extra_generation_kwargs
 
@@ -252,15 +252,10 @@ def infer(self, ids, warmup):
         max_seq_len = self.model.config.max_expected_seq_len
 
         # Add only_last_token optimization
-        extra_generation_kwargs = (
-            {}
-            if self.extra_generation_kwargs is None
-            else self.extra_generation_kwargs
-        )
-        extra_generation_kwargs["only_last_token"] = True
+        self.extra_generation_kwargs["only_last_token"] = True
 
         if args.device_type == "cpu":
-            extra_generation_kwargs["attn_algorithm"] = "math"
+            self.extra_generation_kwargs["attn_algorithm"] = "math"
 
         if not args.no_early_termination and not warmup:
             eos_token_id = self.tokenizer.eos_token_id
@@ -277,7 +272,7 @@ def infer(self, ids, warmup):
             timing=args.timing,
             eos_token_id=eos_token_id,
             contiguous_cache=True,
-            extra_kwargs=extra_generation_kwargs,
+            extra_kwargs=self.extra_generation_kwargs,
         )
         if args.timing != "":
             result, timings = result
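The change initializes extra_generation_kwargs as an empty dict once in process_eval_set and then mutates self.extra_generation_kwargs in place inside infer, which removes the None check the old code performed on every call. Below is a minimal, self-contained sketch of that pattern; the Runner class, the generate() stub, and the args namespace are placeholders for illustration, not the project's actual code.

from types import SimpleNamespace

# Placeholder CLI arguments; the real script parses these with argparse.
args = SimpleNamespace(device_type="cpu", no_early_termination=False, timing="")


def generate(ids, extra_kwargs=None, **kwargs):
    # Stand-in for the real generation call; just echoes what it was given.
    return {"ids": ids, "extra_kwargs": extra_kwargs, **kwargs}


class Runner:
    def process_eval_set(self, prompts):
        ids = prompts
        # Create the dict up front so later code can add keys without a None check.
        self.extra_generation_kwargs = {}
        return ids

    def infer(self, ids, warmup=False):
        # Mutate the shared dict in place instead of rebuilding it per call.
        self.extra_generation_kwargs["only_last_token"] = True
        if args.device_type == "cpu":
            self.extra_generation_kwargs["attn_algorithm"] = "math"
        return generate(ids, extra_kwargs=self.extra_generation_kwargs)


runner = Runner()
ids = runner.process_eval_set(["example prompt"])
print(runner.infer(ids))

One trade-off of sharing a single dict on the instance is that keys set in one infer() call persist into later calls; here that is harmless because the same keys are reassigned each time.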