From a84cce4c5cb78544422cbc3b792e5f53f2fa2c4c Mon Sep 17 00:00:00 2001 From: Wallas Santos Date: Fri, 11 Apr 2025 15:02:29 -0300 Subject: [PATCH] fix: trim trailing whitespace from prompts read via --prompt_path. Signed-off-by: Wallas Santos --- scripts/inference.py | 39 +++++++++++++++++++++++++++++++++------ 1 file changed, 33 insertions(+), 6 deletions(-) diff --git a/scripts/inference.py b/scripts/inference.py index 472ae0fb..d2144590 100644 --- a/scripts/inference.py +++ b/scripts/inference.py @@ -193,6 +193,13 @@ default=0, help="Set verbosity level (pass flag as `-v`, `-vv`, `-vvv`)" ) +parser.add_argument( + '--no_prompt_trim', + action="store_true", + help="Disable rtrip() from input prompts defined in --prompt_path. " + "By default, inputs prompts are trimmed at end due to POSIX enforce files ends with new line char. " + "This behavior can change the tokenization of the prompts and impact reproducibility." +) args = parser.parse_args() if args.quantization == "gptq": @@ -460,11 +467,17 @@ def truncate_prompts_to_max_length(prompts, max_len, max_allowed_length): len(prompt_file_paths) >= args.batch_size ), f"Not enough prompt files at {prompt_path} for a batch size of {args.batch_size}" + no_prompt_trim = args.no_prompt_trim prompts = [] for i, prompt_file_path in enumerate(prompt_file_paths): if i == args.batch_size: break - prompts.append(ids_for_prompt(prompt_file_path.read_text(encoding="utf-8"))) + + prompt_text = prompt_file_path.read_text(encoding="utf-8") + if not no_prompt_trim: + prompt_text = prompt_text.rstrip() + + prompts.append(ids_for_prompt(prompt_text)) else: if args.prompt_type == "chat": @@ -510,7 +523,8 @@ def truncate_prompts_to_max_length(prompts, max_len, max_allowed_length): max_allowed_length = args.max_prompt_length has_padding = args.batch_size > 1 or padding_length != 0 -max_len = max([len(prompt) for prompt in prompts]) +prompts_lens = [len(prompt) for prompt in prompts] +max_len = max(prompts_lens) if args.fixed_prompt_length != 0 and 
args.fixed_prompt_length < max_len: dprint( @@ -518,6 +532,7 @@ def truncate_prompts_to_max_length(prompts, max_len, max_allowed_length): ) exit(1) prompts = truncate_prompts_to_max_length(prompts, max_len, max_allowed_length) + if has_padding: ids, extra_generation_kwargs = pad_input_ids(prompts, min_pad_length=padding_length) else: @@ -528,19 +543,29 @@ def truncate_prompts_to_max_length(prompts, max_len, max_allowed_length): def print_result(result, result_idx: int): if local_rank != 0: return + if has_padding: result = generation.trim_prefix(result) - + result = generation.trim_prefix(result, tokenizer.bos_token_id) - + # stop at EOS token if present and remove padding if not args.no_early_termination: result = generation.truncate_after_eos(result, tokenizer.eos_token_id) + prompt_len = prompts_lens[result_idx] + if add_special_tokens: + prompt_len -= 1 + prompt = result[:prompt_len] + result = result[prompt_len:] + input_str = tokenizer.convert_tokens_to_string( + tokenizer.convert_ids_to_tokens(prompt) + ) + output_str = tokenizer.convert_tokens_to_string( tokenizer.convert_ids_to_tokens(result) ) - + if args.output_path != "": output_path = Path(args.output_path) output_path.mkdir(parents=True, exist_ok=True) @@ -548,7 +573,9 @@ def print_result(result, result_idx: int): file_path = output_path / f"{result_idx}.txt" with file_path.open("w", encoding="utf-8") as file: file.write(output_str + "\n") - dprint(output_str) + dprint(f"prompt #{result_idx}: \n'{input_str}'") + dprint(f"generation #{result_idx}: \n'{output_str}'") + # print('original', original) print()