fix: inputs of prompt path by trimming at end. #25
base: main
@@ -193,6 +193,13 @@
     default=0,
     help="Set verbosity level (pass flag as `-v`, `-vv`, `-vvv`)"
 )
+parser.add_argument(
+    '--no_prompt_trim',
+    action="store_true",
+    help="Disable rtrip() from input prompts defined in --prompt_path. "
+         "By default, inputs prompts are trimmed at end due to POSIX enforce files ends with new line char. "
+         "This behavior can change the tokenization of the prompts and impact reproducibility."
+)
 args = parser.parse_args()

 if args.quantization == "gptq":
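For context on the new flag: POSIX-conformant text files end with a newline, so a prompt read from a file normally carries a trailing "\n", and that extra character changes the token ids the tokenizer produces. A minimal sketch of the effect, assuming a Hugging Face transformers tokenizer ("gpt2" is only an example model, not necessarily the one used with this script):

# Sketch only: shows how a trailing newline changes tokenization.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")   # example tokenizer (assumption)

raw = "Summarize the following article.\n"    # as read from a POSIX text file
trimmed = raw.rstrip()                        # the script's default behavior with this PR

print(tok(raw).input_ids)       # ends with an extra token encoding "\n"
print(tok(trimmed).input_ids)   # same ids minus the trailing-newline token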
@@ -460,11 +467,17 @@ def truncate_prompts_to_max_length(prompts, max_len, max_allowed_length):
         len(prompt_file_paths) >= args.batch_size
     ), f"Not enough prompt files at {prompt_path} for a batch size of {args.batch_size}"

+    no_prompt_trim = args.no_prompt_trim

Review comment: why define variable here if it's only used once?
     prompts = []
     for i, prompt_file_path in enumerate(prompt_file_paths):
         if i == args.batch_size:
             break
-        prompts.append(ids_for_prompt(prompt_file_path.read_text(encoding="utf-8")))
+        prompt_text = prompt_file_path.read_text(encoding="utf-8")
+        if not no_prompt_trim:
+            prompt_text = prompt_text.rstrip()
+
+        prompts.append(ids_for_prompt(prompt_text))

 else:
     if args.prompt_type == "chat":
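One way to read the review comment above: since no_prompt_trim is only consulted once, the parsed flag could be used directly inside the loop. A hypothetical restructuring of the loop body (not part of the PR; it reuses args, prompt_file_path, ids_for_prompt, and prompts from the surrounding script):

# Hypothetical alternative the reviewer may have in mind:
# drop the local no_prompt_trim and read args.no_prompt_trim where needed.
prompt_text = prompt_file_path.read_text(encoding="utf-8")
if not args.no_prompt_trim:
    prompt_text = prompt_text.rstrip()
prompts.append(ids_for_prompt(prompt_text))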
@@ -510,14 +523,16 @@ def truncate_prompts_to_max_length(prompts, max_len, max_allowed_length):
     max_allowed_length = args.max_prompt_length

 has_padding = args.batch_size > 1 or padding_length != 0
-max_len = max([len(prompt) for prompt in prompts])
+prompts_lens = [len(prompt) for prompt in prompts]
+max_len = max(prompts_lens)

 if args.fixed_prompt_length != 0 and args.fixed_prompt_length < max_len:
     dprint(
         f"One or more prompts require truncation. Truncation has been disabled as fixed_prompt_length has been set."
     )
     exit(1)
 prompts = truncate_prompts_to_max_length(prompts, max_len, max_allowed_length)

 if has_padding:
     ids, extra_generation_kwargs = pad_input_ids(prompts, min_pad_length=padding_length)
 else:
@@ -528,27 +543,39 @@ def truncate_prompts_to_max_length(prompts, max_len, max_allowed_length):
 def print_result(result, result_idx: int):
     if local_rank != 0:
         return

     if has_padding:
         result = generation.trim_prefix(result)

Review comment: remove extra space

     result = generation.trim_prefix(result, tokenizer.bos_token_id)

Review comment: remove extra space

     # stop at EOS token if present and remove padding
     if not args.no_early_termination:
         result = generation.truncate_after_eos(result, tokenizer.eos_token_id)

+    prompt_len = prompts_lens[result_idx]
+    if add_special_tokens:
+        prompt_len -= 1
+    prompt = result[:prompt_len]
+    result = result[prompt_len:]
+    input_str = tokenizer.convert_tokens_to_string(
+        tokenizer.convert_ids_to_tokens(prompt)
+    )

Review comment on lines +556 to +563: add comment to code why this needs to happen

     output_str = tokenizer.convert_tokens_to_string(
         tokenizer.convert_ids_to_tokens(result)
     )

     if args.output_path != "":
         output_path = Path(args.output_path)
         output_path.mkdir(parents=True, exist_ok=True)
         if output_path.is_dir():
             file_path = output_path / f"{result_idx}.txt"
             with file_path.open("w", encoding="utf-8") as file:
                 file.write(output_str + "\n")
-    dprint(output_str)
+    dprint(f"prompt #{result_idx}: \n'{input_str}'")
+    dprint(f"generation #{result_idx}: \n'{output_str}'")

Review comment: doesn't the output_str already contain the input_str?

     # print('original', original)
     print()
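For readers following the slicing in the added print_result lines: prompts_lens stores each prompt's tokenized length, and the sequence handed to print_result begins with the echoed prompt tokens, so cutting at prompt_len separates the prompt from the generation; the prompt_len -= 1 appears to compensate for the BOS token that trim_prefix() has already removed (my reading of the diff, not stated in the PR). A self-contained toy example with invented token ids:

# Toy example: token ids are invented; the real script gets them from the tokenizer.
prompts_lens = [5]          # tokenized prompt lengths, including a BOS token
add_special_tokens = True   # tokenization added a BOS token to each prompt

# Full sequence for batch element 0 after trim_prefix() stripped the BOS:
result = [11, 12, 13, 14, 90, 91, 92]   # 4 remaining prompt tokens + 3 generated tokens

prompt_len = prompts_lens[0]
if add_special_tokens:
    prompt_len -= 1                     # BOS counted in prompts_lens but gone from result

prompt_ids = result[:prompt_len]        # [11, 12, 13, 14] -> the echoed prompt
generated_ids = result[prompt_len:]     # [90, 91, 92]     -> newly generated tokens only
print(prompt_ids, generated_ids)

Under that reading, output_str is decoded only from the post-prompt slice, which would address the question above about input_str being duplicated.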
Review comment: fix typo