
Commit 3434641

Update import of ids_for_prompt and fix some formatting

Signed-off-by: Andrea Fasoli <andrea.fasoli@ibm.com>
1 parent f7c458e

File tree: 2 files changed, +49 -48 lines changed


aiu_fms_testing_utils/utils/__init__.py

Lines changed: 13 additions & 15 deletions
@@ -37,17 +37,17 @@ def __download_file(url, filename):
     try:
         response = requests.get(url, stream=True)
         response.raise_for_status()
-
+
         with open(filename, 'wb') as file:
             for chunk in response.iter_content(chunk_size=8192):
                 file.write(chunk)
         print(f"Successfully downloaded {filename}")
-
+
     except requests.exceptions.RequestException as e:
         print(f"An error occurred: {e}")
 
 def __sample_requests(
-    prompt_list: List[str],
+    prompt_list: List[str],
     num_requests: int,
     tokenizer: BaseTokenizer,
     prompt_length_min: int = 32,
@@ -67,16 +67,14 @@ def __sample_requests(
         # Tokenize the prompts and completions.
         prompt = prompt_list[i]
         prompt_token_ids = ids_for_prompt(prompt, tokenizer)
-
+
         prompt_len = len(prompt_token_ids)
         if prompt_len < prompt_length_min or prompt_len > prompt_length_max:
             # Prune too short or too long sequences.
             continue
         filtered_dataset.append((prompt, prompt_len))
-
-    return filtered_dataset
-
 
+    return filtered_dataset
 
 def sample_sharegpt_requests(
     dataset_path: str,
@@ -96,15 +94,15 @@ def sample_sharegpt_requests(
     # Filter out the conversations with less than 2 turns.
     dataset = [data for data in dataset if len(data["conversations"]) >= 2]
     dataset = [data["conversations"][0]["value"] for data in dataset]
-
+
     return __sample_requests(dataset, num_requests, tokenizer, prompt_length_min, prompt_length_max, seed)
 
 def sample_squad_v2_qa_requests(
     dataset_path: str,
-    num_requests: int,
-    tokenizer: BaseTokenizer,
-    prompt_length_min: int = 32,
-    prompt_length_max: int = 64,
+    num_requests: int,
+    tokenizer: BaseTokenizer,
+    prompt_length_min: int = 32,
+    prompt_length_max: int = 64,
     seed: Optional[int] = None
 ) -> List[Tuple[str, int]]:
     from datasets import load_dataset
@@ -113,10 +111,10 @@ def sample_squad_v2_qa_requests(
         ds = load_dataset(dataset_path)['train']
     else:
         ds = load_dataset("rajpurkar/squad_v2", cache_dir=dataset_path)['train']
-
-
+
+
     ds = [f"{data['context']}\n{data['question']}" for data in ds]
 
     return __sample_requests(ds, num_requests, tokenizer, prompt_length_min, prompt_length_max, seed)
-
+
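
For context, a minimal usage sketch of the sampling helpers touched above. The tokenizer factory (fms.utils.tokenizers.get_tokenizer) and all paths below are illustrative assumptions, not part of this commit:

from fms.utils import tokenizers  # assumed source of a BaseTokenizer

from aiu_fms_testing_utils.utils import sample_squad_v2_qa_requests

# Hypothetical tokenizer location.
tokenizer = tokenizers.get_tokenizer("/path/to/tokenizer")

# Returns (prompt, prompt_len) tuples; prompts whose tokenized length
# falls outside [prompt_length_min, prompt_length_max] are pruned.
requests = sample_squad_v2_qa_requests(
    dataset_path="/tmp/squad_v2",  # loaded directly if present, else used as the HF cache dir
    num_requests=8,
    tokenizer=tokenizer,
    prompt_length_min=32,
    prompt_length_max=64,
    seed=0,
)
for prompt, prompt_len in requests:
    print(prompt_len, prompt[:40])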

aiu_fms_testing_utils/utils/decoders_utils.py

Lines changed: 36 additions & 33 deletions
@@ -15,7 +15,7 @@
 import torch
 
 # Local Packages
-from aiu_fms_testing_utils.utils import warmup_model
+from aiu_fms_testing_utils.utils import ids_for_prompt, warmup_model
 from aiu_fms_testing_utils.utils.aiu_setup import dprint, local_rank
 
 
@@ -34,12 +34,10 @@ def __init__(
         self.args = args
         self.device = device
 
-        self.add_special_tokens = False
         self.has_padding = True
         self.max_len = 0
         self.extra_generation_kwargs = {}
 
-        # !!! Inference arguments (hardcoded, as in the original script)
         self.do_sample = [False]
         self.use_cache = [args.no_use_cache]  # True/False identical with greedy iff `torch.use_deterministic_algorithms(True)`
 
@@ -58,16 +56,6 @@ def validate_decoder_arguments(self):
                 f"Architecture {args.architecture} should be run as an encoder model."
             )
 
-    def ids_for_prompt(self, prompt):
-        """Process textual prompt and return tokenized ids."""
-
-        tokens = self.tokenizer.tokenize(prompt)
-        ids = self.tokenizer.convert_tokens_to_ids(tokens)
-        if self.add_special_tokens:
-            ids = [self.tokenizer.bos_token_id] + ids
-        ids = torch.tensor(ids, dtype=torch.long, device=self.device)
-        return ids
-
     def truncate_prompts_to_max_length(self, prompts, max_len, max_allowed_length):
         """Truncate a series of prompts to a selected max length.
         This function ensures prompt truncation prior to padding the input ids."""
@@ -83,10 +71,6 @@ def process_eval_set(self):
         """
 
         args = self.args
-        self.add_special_tokens = (
-            self.tokenizer.bos_token_id != self.tokenizer.eos_token_id
-        )
-
         if args.prompt_path != "":
             # Before creating the Path object, check if prompt_path has a glob pattern
             if isinstance(args.prompt_path, str):
@@ -114,50 +98,69 @@ def process_eval_set(self):
                 prompt_file_paths = [prompt_path]
 
             # Check if we found some files
-            assert len(prompt_file_paths) > 0, f"Can't find any prompt files at {prompt_path}"
+            assert len(prompt_file_paths) > 0, (
+                f"Can't find any prompt files at {prompt_path}"
+            )
 
             # Check if we have enough files
-            assert (
-                len(prompt_file_paths) >= args.batch_size
-            ), f"Not enough prompt files at {prompt_path} for a batch size of {args.batch_size}"
+            assert len(prompt_file_paths) >= args.batch_size, (
+                f"Not enough prompt files at {prompt_path} "
+                f"for a batch size of {args.batch_size}"
+            )
 
             prompts = []
             for i, prompt_file_path in enumerate(prompt_file_paths):
                 if i == args.batch_size:
                     break
-                prompts.append(self.ids_for_prompt(prompt_file_path.read_text(encoding="utf-8")))
+                prompts.append(
+                    ids_for_prompt(
+                        prompt_file_path.read_text(encoding="utf-8"),
+                        self.tokenizer,
+                    )
+                )
         else:
             if args.prompt_type == "chat":
-                template = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{}\n\n### Response:"
-
+                template = (
+                    "Below is an instruction that describes a task. Write a response "
+                    "that appropriately completes the request.\n\n### Instruction:"
+                    "\n{}\n\n### Response:"
+                )
                 prompt1 = template.format(
                     "Provide a list of instructions for preparing chicken soup."
                 )
                 prompt2 = template.format("Explain some popular greetings in Spanish.")
                 prompt3 = template.format("Explain to me why ignorance is bliss.")
                 prompt4 = template.format(
-                    "I have just come into a very large sum of money. Provide me a list of things that I can do with my new found wealth."
+                    "I have just come into a very large sum of money. Provide me a "
+                    "list of things that I can do with my new found wealth."
                 )
             elif args.prompt_type == "code":
-                template = "[INST] Write code to solve the following coding problem that obeys the constraints and passes the example test cases. Please wrap your code answer using ```:\n{}\n[/INST]"
+                template = (
+                    "[INST] Write code to solve the following coding problem that "
+                    "obeys the constraints and passes the example test cases. "
+                    "Please wrap your code answer using ```:\n{}\n[/INST]"
+                )
                 prompt1 = template.format("Write a bubble sort function in python.")
                 prompt2 = template.format(
-                    "Using the Java streams API, write a simple function which will get the cumulative sum of a list of integers."
+                    "Using the Java streams API, write a simple function which will "
+                    "get the cumulative sum of a list of integers."
                 )
                 prompt3 = template.format(
-                    "In bash, how do I list all directories and sub-directories which contain a .py file."
+                    "In bash, how do I list all directories and sub-directories which "
+                    "contain a .py file."
                )
                 prompt4 = template.format(
-                    "Write a simple decorator in python which will modify all string inputs to ints if possible."
+                    "Write a simple decorator in python which will modify all string "
+                    "inputs to ints if possible."
                 )
             else:
                 dprint("prompt_type must be one of chat or code")
                 exit()
 
-            prompt1 = self.ids_for_prompt(prompt1)
-            prompt2 = self.ids_for_prompt(prompt2)
-            prompt3 = self.ids_for_prompt(prompt3)
-            prompt4 = self.ids_for_prompt(prompt4)
+            prompt1 = ids_for_prompt(prompt1, self.tokenizer)
+            prompt2 = ids_for_prompt(prompt2, self.tokenizer)
+            prompt3 = ids_for_prompt(prompt3, self.tokenizer)
+            prompt4 = ids_for_prompt(prompt4, self.tokenizer)
             prompts = [prompt1, prompt2, prompt3, prompt4]
             prompts = prompts * ((args.batch_size // 4) + 1)
             prompts = prompts[: args.batch_size]
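
With this commit the decoder utilities call the shared ids_for_prompt imported from aiu_fms_testing_utils.utils instead of the deleted method above. The module-level implementation is not shown in this diff; the sketch below is a plausible reconstruction from the removed method body and the new two-argument call sites. The BOS-vs-EOS check stands in for the deleted per-instance add_special_tokens flag, and leaving device placement to the caller is likewise an assumption:

import torch

def ids_for_prompt(prompt, tokenizer):
    """Tokenize a textual prompt and return a 1D tensor of token ids."""
    tokens = tokenizer.tokenize(prompt)
    ids = tokenizer.convert_tokens_to_ids(tokens)
    # Prepend BOS only when it differs from EOS, mirroring the
    # add_special_tokens logic removed from process_eval_set above.
    if tokenizer.bos_token_id != tokenizer.eos_token_id:
        ids = [tokenizer.bos_token_id] + ids
    # The removed method placed the tensor on self.device; a shared helper
    # has no device, so placement is assumed to be the caller's job.
    return torch.tensor(ids, dtype=torch.long)

Hoisting tokenization out of the class removes the add_special_tokens and device state from the decoder class, which is presumably what lets the testing utilities share a single tokenization path.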
