
Conversation

wallashss

This PR fixes a reproducibility problem in inference.py when input prompts are loaded from files via --prompt_path.

The issue is that POSIX text files must be terminated with a newline character; when the script reads a prompt from a file it does not remove this trailing newline, which changes the tokenization of the prompt. The solution is simply to trim the prompt text on the right (with rstrip()).

The default behavior is to trim on the right, but users can disable this with --no_prompt_trim if they wish.
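
A minimal sketch of the trimming idea (the load_prompts helper and the trim_prompt flag below are illustrative names, not the PR's actual code):

from pathlib import Path

def load_prompts(prompt_dir: str, trim_prompt: bool = True) -> list[str]:
    """Read one prompt per file, optionally stripping trailing whitespace."""
    prompts = []
    for file_path in sorted(Path(prompt_dir).glob("*.txt")):
        text = file_path.read_text(encoding="utf-8")
        if trim_prompt:
            # POSIX text files end with a newline; stripping it keeps the
            # tokenization identical to passing the prompt string directly.
            text = text.rstrip()
        prompts.append(text)
    return prompts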

I also made a small change to the response format: the response contains the prompt as well, so I split the two apart, since they were somewhat difficult to read when concatenated.

Signed-off-by: Wallas Santos <wallashss@ibm.com>
parser.add_argument(
    '--no_prompt_trim',
    action="store_true",
    help="Disable rtrip() from input prompts defined in --prompt_path. "
Contributor

fix typo

if args.output_path != "":
    output_path = Path(args.output_path)
    output_path.mkdir(parents=True, exist_ok=True)
    if output_path.is_dir():
        file_path = output_path / f"{result_idx}.txt"
        with file_path.open("w", encoding="utf-8") as file:
            file.write(output_str + "\n")
dprint(output_str)
dprint(f"prompt #{result_idx}: \n'{input_str}'")
dprint(f"generation #{result_idx}: \n'{output_str}'")
Contributor

doesn't the output_str already contain the input_str?

Comment on lines +556 to +563
prompt_len = prompts_lens[result_idx]
if add_special_tokens:
    prompt_len -= 1
prompt = result[:prompt_len]
result = result[prompt_len:]
input_str = tokenizer.convert_tokens_to_string(
    tokenizer.convert_ids_to_tokens(prompt)
)
Contributor

add comment to code why this needs to happen
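
One possible set of comments for this snippet, reflecting my reading of it (the rationale for the -1 adjustment is an assumption, not confirmed by the PR):

# result holds the prompt's token ids followed by the generated token ids,
# so split it to display the prompt and the generation separately.
prompt_len = prompts_lens[result_idx]
if add_special_tokens:
    # assumption: a special token (e.g. BOS) counted in prompts_lens has
    # already been trimmed from result, so the prompt slice is one shorter
    prompt_len -= 1
prompt = result[:prompt_len]
result = result[prompt_len:]
input_str = tokenizer.convert_tokens_to_string(
    tokenizer.convert_ids_to_tokens(prompt)
)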

result = generation.trim_prefix(result, tokenizer.bos_token_id)

Contributor

remove extra space

if has_padding:
    result = generation.trim_prefix(result)

Contributor

remove extra space

@@ -460,11 +467,17 @@ def truncate_prompts_to_max_length(prompts, max_len, max_allowed_length):
len(prompt_file_paths) >= args.batch_size
), f"Not enough prompt files at {prompt_path} for a batch size of {args.batch_size}"

no_prompt_trim = args.no_prompt_trim
Contributor

why define variable here if it's only used once?

Contributor

@ani300 left a comment

asking for some changes to code style and formatting, as well as some clarifications

@tharapalanivel
Collaborator

Hi @wallashss, is this change still required and actively being worked on? Thank you

@wallashss
Author

Hey, these were just minor improvements I thought of the first time I used these scripts. I don't think they are needed anymore, and I don't have the bandwidth to look at them again. You can close this if you wish. Thank you!
