
fix: inputs of prompt path by trimming at end. #25


Open
wants to merge 1 commit into main

Conversation

wallashss

This PR fixes a reproducibility problem in inference.py when input prompts are read from files via --prompt_path.

The issue is that POSIX text files must be terminated with a newline character; when the script reads a prompt from a file, it does not remove that trailing newline, and this affects the tokenization of the prompt. The fix is simply to right-trim the prompt text (with rstrip()).
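A minimal sketch of the effect, assuming a Hugging Face tokenizer; the model name and file path here are placeholders, not the actual setup:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder tokenizer

# POSIX text files end with a newline, which read() keeps.
raw = open("prompt.txt", encoding="utf-8").read()
trimmed = raw.rstrip()

# The trailing "\n" becomes extra token id(s), changing the prompt length
# and therefore the generation, which hurts reproducibility.
print(tokenizer(raw)["input_ids"])
print(tokenizer(trimmed)["input_ids"])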

The default behavior is to trim on the right, but users can disable this with --no_prompt_trim if they wish.
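Sketched in code, the intended behavior would look roughly like this (variable names are illustrative, not necessarily the exact diff):

prompt_text = prompt_file_path.read_text(encoding="utf-8")
if not args.no_prompt_trim:  # trimming is the default; --no_prompt_trim opts out
    prompt_text = prompt_text.rstrip()

Keeping trimming on by default preserves reproducible tokenization, while the opt-out flag lets users keep trailing whitespace when it is intentional.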

I also made a small change to the response format: the response contains the prompt as well, so I split them apart, since it was difficult to read them concatenated.

Signed-off-by: Wallas Santos <wallashss@ibm.com>
parser.add_argument(
    '--no_prompt_trim',
    action="store_true",
    help="Disable rtrip() from input prompts defined in --prompt_path. "
Contributor


fix typo

if args.output_path != "":
    output_path = Path(args.output_path)
    output_path.mkdir(parents=True, exist_ok=True)
    if output_path.is_dir():
        file_path = output_path / f"{result_idx}.txt"
        with file_path.open("w", encoding="utf-8") as file:
            file.write(output_str + "\n")
dprint(output_str)
dprint(f"prompt #{result_idx}: \n'{input_str}'")
dprint(f"generation #{result_idx}: \n'{output_str}'")
Contributor


doesn't the output_str already contain the input_str?

Comment on lines +556 to +563
prompt_len = prompts_lens[result_idx]
if add_special_tokens:
    prompt_len -= 1
prompt = result[:prompt_len]
result = result[prompt_len:]
input_str = tokenizer.convert_tokens_to_string(
    tokenizer.convert_ids_to_tokens(prompt)
)
Contributor


add comment to code why this needs to happen

result = generation.trim_prefix(result, tokenizer.bos_token_id)

Contributor


remove extra space

if has_padding:
    result = generation.trim_prefix(result)

Contributor


remove extra space

@@ -460,11 +467,17 @@ def truncate_prompts_to_max_length(prompts, max_len, max_allowed_length):
    len(prompt_file_paths) >= args.batch_size
), f"Not enough prompt files at {prompt_path} for a batch size of {args.batch_size}"

no_prompt_trim = args.no_prompt_trim
Contributor


why define variable here if it's only used once?

Contributor

@ani300 left a comment


asking for some changes to code style and formatting, as well as some clarifications
