fix: inputs of prompt path by trimming at end. #25

Open
wants to merge 1 commit into base: main
39 changes: 33 additions & 6 deletions scripts/inference.py
@@ -193,6 +193,13 @@
    default=0,
    help="Set verbosity level (pass flag as `-v`, `-vv`, `-vvv`)"
)
+parser.add_argument(
+    '--no_prompt_trim',
+    action="store_true",
+    help="Disable rtrip() from input prompts defined in --prompt_path. "

[Contributor] fix typo

"By default, inputs prompts are trimmed at end due to POSIX enforce files ends with new line char. "
"This behavior can change the tokenization of the prompts and impact reproducibility."
)
args = parser.parse_args()

if args.quantization == "gptq":
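
Note on the help text above: a POSIX-terminated file ends with a newline that most BPE tokenizers encode as its own token, so the prompt on disk and the prompt the user intended tokenize differently. A minimal sketch of the effect, assuming the Hugging Face transformers library and the gpt2 tokenizer (neither is implied by this PR):

# Sketch: a trailing POSIX newline changes the token sequence.
# Assumes `pip install transformers`; gpt2 is an illustrative choice only.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")

raw = "Tell me a story.\n"   # what read_text() returns for a POSIX file
trimmed = raw.rstrip()       # what the script now tokenizes by default

print(tok.encode(raw))       # typically ends with a separate newline token
print(tok.encode(trimmed))   # same ids without that trailing token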
@@ -460,11 +467,17 @@ def truncate_prompts_to_max_length(prompts, max_len, max_allowed_length):
        len(prompt_file_paths) >= args.batch_size
    ), f"Not enough prompt files at {prompt_path} for a batch size of {args.batch_size}"

+    no_prompt_trim = args.no_prompt_trim

[Contributor] why define variable here if it's only used once?

    prompts = []
    for i, prompt_file_path in enumerate(prompt_file_paths):
        if i == args.batch_size:
            break
-        prompts.append(ids_for_prompt(prompt_file_path.read_text(encoding="utf-8")))
+
+        prompt_text = prompt_file_path.read_text(encoding="utf-8")
+        if not no_prompt_trim:
+            prompt_text = prompt_text.rstrip()
+
+        prompts.append(ids_for_prompt(prompt_text))

else:
    if args.prompt_type == "chat":
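
One subtlety in the loop above: str.rstrip() with no arguments removes all trailing whitespace, not just the single newline POSIX mandates, so prompts that deliberately end in blank lines or spaces are trimmed as well. A quick illustration in plain Python (no project code involved):

# rstrip() strips every trailing whitespace character, not only one "\n".
assert "story.\n".rstrip() == "story."
assert "story.\n\n".rstrip() == "story."    # intentional blank line removed too
assert "story. \t\n".rstrip() == "story."   # trailing spaces/tabs removed as well

# If only the single final newline were wanted, a narrower alternative
# (not what this PR does) would be:
text = "story.\n\n"
if text.endswith("\n"):
    text = text[:-1]                        # leaves "story.\n"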
Expand Down Expand Up @@ -510,14 +523,16 @@ def truncate_prompts_to_max_length(prompts, max_len, max_allowed_length):
    max_allowed_length = args.max_prompt_length

has_padding = args.batch_size > 1 or padding_length != 0
-max_len = max([len(prompt) for prompt in prompts])
+prompts_lens = [len(prompt) for prompt in prompts]
+max_len = max(prompts_lens)

if args.fixed_prompt_length != 0 and args.fixed_prompt_length < max_len:
    dprint(
        f"One or more prompts require truncation. Truncation has been disabled as fixed_prompt_length has been set."
    )
    exit(1)
prompts = truncate_prompts_to_max_length(prompts, max_len, max_allowed_length)

if has_padding:
    ids, extra_generation_kwargs = pad_input_ids(prompts, min_pad_length=padding_length)
else:
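
The prompts_lens list introduced above is not only for max_len: print_result below indexes into it to split each generated sequence back into its prompt and generation halves. A self-contained sketch of that bookkeeping, with made-up token ids:

# Sketch: per-prompt lengths let each result row be split back into
# (prompt, generation). Token ids are illustrative only.
prompts = [[11, 12, 13], [11, 12, 13, 14, 15]]
prompts_lens = [len(p) for p in prompts]   # [3, 5]
max_len = max(prompts_lens)                # 5, drives padding/truncation

# Pretend the model echoed each prompt and appended two new tokens.
results = [p + [90, 91] for p in prompts]

for result_idx, result in enumerate(results):
    prompt_len = prompts_lens[result_idx]
    prompt, generation = result[:prompt_len], result[prompt_len:]
    print(result_idx, prompt, generation)  # e.g. 0 [11, 12, 13] [90, 91]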
@@ -528,27 +543,39 @@ def truncate_prompts_to_max_length(prompts, max_len, max_allowed_length):
def print_result(result, result_idx: int):
    if local_rank != 0:
        return

    if has_padding:
        result = generation.trim_prefix(result)
+

[Contributor] remove extra space

    result = generation.trim_prefix(result, tokenizer.bos_token_id)
+

[Contributor] remove extra space

    # stop at EOS token if present and remove padding
    if not args.no_early_termination:
        result = generation.truncate_after_eos(result, tokenizer.eos_token_id)

+    prompt_len = prompts_lens[result_idx]
+    if add_special_tokens:
+        prompt_len -= 1
+    prompt = result[:prompt_len]
+    result = result[prompt_len:]
+    input_str = tokenizer.convert_tokens_to_string(
+        tokenizer.convert_ids_to_tokens(prompt)
+    )

[Contributor, on lines +556 to +563] add comment to code why this needs to happen

    output_str = tokenizer.convert_tokens_to_string(
        tokenizer.convert_ids_to_tokens(result)
    )

    if args.output_path != "":
        output_path = Path(args.output_path)
        output_path.mkdir(parents=True, exist_ok=True)
        if output_path.is_dir():
            file_path = output_path / f"{result_idx}.txt"
            with file_path.open("w", encoding="utf-8") as file:
                file.write(output_str + "\n")
-    dprint(output_str)
+    dprint(f"prompt #{result_idx}: \n'{input_str}'")
+    dprint(f"generation #{result_idx}: \n'{output_str}'")

[Contributor] doesn't the output_str already contain the input_str?

+    # print('original', original)
    print()


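On the reviewer's request to document why prompt_len -= 1 is needed: a plausible reading, assuming ids_for_prompt prepends a BOS token when add_special_tokens is set and that trim_prefix has already stripped that BOS from result, is that the recorded prompt length overcounts the surviving tokens by one. A hedged sketch of that off-by-one (the assumption should be confirmed against ids_for_prompt before landing):

# Sketch of the suspected off-by-one. The ids and the BOS assumption are
# illustrative; this is not the project's actual tokenizer output.
BOS = 0

prompt_ids = [BOS, 11, 12, 13]   # length recorded in prompts_lens -> 4
result = [11, 12, 13, 40, 41]    # BOS already stripped by trim_prefix

prompt_len = len(prompt_ids)     # 4 would cut one token into the generation
prompt_len -= 1                  # compensate for the stripped BOS
assert result[:prompt_len] == [11, 12, 13]   # prompt part -> input_str
assert result[prompt_len:] == [40, 41]       # generation part -> output_str

This split is also why the two new dprint lines no longer duplicate the prompt inside the generation output, which answers the last question above.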