Skip to content

Commit dfd6758

Browse files
committed
Adjust extra_generation_kwargs handling
Signed-off-by: Andrea Fasoli <andrea.fasoli@ibm.com>
1 parent 0437c50 commit dfd6758

File tree

1 file changed

+4
-9
lines changed

1 file changed

+4
-9
lines changed

aiu_fms_testing_utils/utils/decoders_utils.py

Lines changed: 4 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -192,7 +192,7 @@ def process_eval_set(self):
192192
ids = prompts
193193
if isinstance(ids, list) and len(ids) == 1:
194194
ids = ids[0].unsqueeze(0)
195-
extra_generation_kwargs = None
195+
extra_generation_kwargs = {}
196196

197197
self.extra_generation_kwargs = extra_generation_kwargs
198198

@@ -252,15 +252,10 @@ def infer(self, ids, warmup):
252252
max_seq_len = self.model.config.max_expected_seq_len
253253

254254
# Add only_last_token optimization
255-
extra_generation_kwargs = (
256-
{}
257-
if self.extra_generation_kwargs is None
258-
else self.extra_generation_kwargs
259-
)
260-
extra_generation_kwargs["only_last_token"] = True
255+
self.extra_generation_kwargs["only_last_token"] = True
261256

262257
if args.device_type == "cpu":
263-
extra_generation_kwargs["attn_algorithm"] = "math"
258+
self.extra_generation_kwargs["attn_algorithm"] = "math"
264259

265260
if not args.no_early_termination and not warmup:
266261
eos_token_id = self.tokenizer.eos_token_id
@@ -277,7 +272,7 @@ def infer(self, ids, warmup):
277272
timing=args.timing,
278273
eos_token_id=eos_token_id,
279274
contiguous_cache=True,
280-
extra_kwargs=extra_generation_kwargs,
275+
extra_kwargs=self.extra_generation_kwargs,
281276
)
282277
if args.timing != "":
283278
result, timings = result

0 commit comments

Comments
 (0)