
Commit d386fc9

Update import of ids_for_prompt and fix some formatting
Signed-off-by: Andrea Fasoli <andrea.fasoli@ibm.com>
1 parent f7c458e

File tree

2 files changed: +49 −38 lines changed

aiu_fms_testing_utils/utils/__init__.py

Lines changed: 13 additions & 15 deletions
@@ -37,17 +37,17 @@ def __download_file(url, filename):
     try:
         response = requests.get(url, stream=True)
         response.raise_for_status()
-        
+
         with open(filename, 'wb') as file:
             for chunk in response.iter_content(chunk_size=8192):
                 file.write(chunk)
         print(f"Successfully downloaded {filename}")
-        
+
     except requests.exceptions.RequestException as e:
         print(f"An error occurred: {e}")
 
 def __sample_requests(
-        prompt_list: List[str],
+    prompt_list: List[str],
     num_requests: int,
     tokenizer: BaseTokenizer,
     prompt_length_min: int = 32,
@@ -67,16 +67,14 @@ def __sample_requests(
         # Tokenize the prompts and completions.
         prompt = prompt_list[i]
         prompt_token_ids = ids_for_prompt(prompt, tokenizer)
-        
+
         prompt_len = len(prompt_token_ids)
         if prompt_len < prompt_length_min or prompt_len > prompt_length_max:
             # Prune too short or too long sequences.
             continue
         filtered_dataset.append((prompt, prompt_len))
-        
-    return filtered_dataset
-        
 
+    return filtered_dataset
 
 def sample_sharegpt_requests(
     dataset_path: str,
@@ -96,15 +94,15 @@ def sample_sharegpt_requests(
    # Filter out the conversations with less than 2 turns.
    dataset = [data for data in dataset if len(data["conversations"]) >= 2]
    dataset = [data["conversations"][0]["value"] for data in dataset]
-    
+
    return __sample_requests(dataset, num_requests, tokenizer, prompt_length_min, prompt_length_max, seed)
 
 def sample_squad_v2_qa_requests(
     dataset_path: str,
-        num_requests: int,
-        tokenizer: BaseTokenizer,
-        prompt_length_min: int = 32,
-        prompt_length_max: int = 64,
+    num_requests: int,
+    tokenizer: BaseTokenizer,
+    prompt_length_min: int = 32,
+    prompt_length_max: int = 64,
     seed: Optional[int] = None
 ) -> List[Tuple[str, int]]:
     from datasets import load_dataset
@@ -113,10 +111,10 @@ def sample_squad_v2_qa_requests(
         ds = load_dataset(dataset_path)['train']
     else:
         ds = load_dataset("rajpurkar/squad_v2", cache_dir=dataset_path)['train']
-    
-    
+
+
     ds = [f"{data['context']}\n{data['question']}" for data in ds]
 
     return __sample_requests(ds, num_requests, tokenizer, prompt_length_min, prompt_length_max, seed)
-    
+
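For context, both public helpers in this file delegate to __sample_requests, which tokenizes each candidate prompt with ids_for_prompt(prompt, tokenizer) and keeps only prompts whose token count falls within [prompt_length_min, prompt_length_max]. A minimal sketch of driving sample_squad_v2_qa_requests; the tokenizer choice and cache path are illustrative assumptions, not part of this commit:

from fms.utils import tokenizers
from aiu_fms_testing_utils.utils import sample_squad_v2_qa_requests

# Illustrative tokenizer; any FMS BaseTokenizer is accepted.
tokenizer = tokenizers.get_tokenizer("char_tokenizer")

# First use downloads SQuAD v2 into the (hypothetical) cache directory,
# then returns up to num_requests (prompt, token_count) pairs whose
# tokenized length fits the requested window.
sampled = sample_squad_v2_qa_requests(
    dataset_path="/tmp/squad_v2_cache",  # hypothetical cache location
    num_requests=8,
    tokenizer=tokenizer,
    prompt_length_min=32,
    prompt_length_max=64,
    seed=0,
)
for prompt, prompt_len in sampled:
    print(prompt_len, prompt[:60].replace("\n", " "))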

aiu_fms_testing_utils/utils/decoders_utils.py

Lines changed: 36 additions & 23 deletions
@@ -15,7 +15,7 @@
 import torch
 
 # Local Packages
-from aiu_fms_testing_utils.utils import warmup_model
+from aiu_fms_testing_utils.utils import ids_for_prompt, warmup_model
 from aiu_fms_testing_utils.utils.aiu_setup import dprint, local_rank
 
 
@@ -34,12 +34,10 @@ def __init__(
         self.args = args
         self.device = device
 
-        self.add_special_tokens = False
         self.has_padding = True
         self.max_len = 0
         self.extra_generation_kwargs = {}
 
-        # !!! Inference arguments (hardcoded, as in the original script)
         self.do_sample = [False]
         self.use_cache = [args.no_use_cache]  # True/False identical with greedy iff `torch.use_deterministic_algorithms(True)`
 
@@ -83,10 +81,6 @@ def process_eval_set(self):
         """
 
         args = self.args
-        self.add_special_tokens = (
-            self.tokenizer.bos_token_id != self.tokenizer.eos_token_id
-        )
-
         if args.prompt_path != "":
             # Before creating the Path object, check if prompt_path has a glob pattern
             if isinstance(args.prompt_path, str):
@@ -114,50 +108,69 @@ def process_eval_set(self):
                 prompt_file_paths = [prompt_path]
 
             # Check if we found some files
-            assert len(prompt_file_paths) > 0, f"Can't find any prompt files at {prompt_path}"
+            assert len(prompt_file_paths) > 0, (
+                f"Can't find any prompt files at {prompt_path}"
+            )
 
             # Check if we have enough files
-            assert (
-                len(prompt_file_paths) >= args.batch_size
-            ), f"Not enough prompt files at {prompt_path} for a batch size of {args.batch_size}"
+            assert len(prompt_file_paths) >= args.batch_size, (
+                f"Not enough prompt files at {prompt_path} "
+                f"for a batch size of {args.batch_size}"
+            )
 
             prompts = []
             for i, prompt_file_path in enumerate(prompt_file_paths):
                 if i == args.batch_size:
                     break
-                prompts.append(self.ids_for_prompt(prompt_file_path.read_text(encoding="utf-8")))
+                prompts.append(
+                    ids_for_prompt(
+                        prompt_file_path.read_text(encoding="utf-8"),
+                        self.tokenizer,
+                    )
+                )
         else:
             if args.prompt_type == "chat":
-                template = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{}\n\n### Response:"
-
+                template = (
+                    "Below is an instruction that describes a task. Write a response "
+                    "that appropriately completes the request.\n\n### Instruction:"
+                    "\n{}\n\n### Response:"
+                )
                 prompt1 = template.format(
                     "Provide a list of instructions for preparing chicken soup."
                 )
                 prompt2 = template.format("Explain some popular greetings in Spanish.")
                 prompt3 = template.format("Explain to me why ignorance is bliss.")
                 prompt4 = template.format(
-                    "I have just come into a very large sum of money. Provide me a list of things that I can do with my new found wealth."
+                    "I have just come into a very large sum of money. Provide me a "
+                    "list of things that I can do with my new found wealth."
                 )
             elif args.prompt_type == "code":
-                template = "[INST] Write code to solve the following coding problem that obeys the constraints and passes the example test cases. Please wrap your code answer using ```:\n{}\n[/INST]"
+                template = (
+                    "[INST] Write code to solve the following coding problem that "
+                    "obeys the constraints and passes the example test cases. "
+                    "Please wrap your code answer using ```:\n{}\n[/INST]"
+                )
                 prompt1 = template.format("Write a bubble sort function in python.")
                 prompt2 = template.format(
-                    "Using the Java streams API, write a simple function which will get the cumulative sum of a list of integers."
+                    "Using the Java streams API, write a simple function which will "
+                    "get the cumulative sum of a list of integers."
                )
                 prompt3 = template.format(
-                    "In bash, how do I list all directories and sub-directories which contain a .py file."
+                    "In bash, how do I list all directories and sub-directories which "
+                    "contain a .py file."
                )
                 prompt4 = template.format(
-                    "Write a simple decorator in python which will modify all string inputs to ints if possible."
+                    "Write a simple decorator in python which will modify all string "
+                    "inputs to ints if possible."
                )
             else:
                 dprint("prompt_type must be one of chat or code")
                 exit()
 
-            prompt1 = self.ids_for_prompt(prompt1)
-            prompt2 = self.ids_for_prompt(prompt2)
-            prompt3 = self.ids_for_prompt(prompt3)
-            prompt4 = self.ids_for_prompt(prompt4)
+            prompt1 = ids_for_prompt(prompt1, self.tokenizer)
+            prompt2 = ids_for_prompt(prompt2, self.tokenizer)
+            prompt3 = ids_for_prompt(prompt3, self.tokenizer)
+            prompt4 = ids_for_prompt(prompt4, self.tokenizer)
             prompts = [prompt1, prompt2, prompt3, prompt4]
             prompts = prompts * ((args.batch_size // 4) + 1)
             prompts = prompts[: args.batch_size]
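The substantive change in this file is replacing the decoder-local self.ids_for_prompt(...) method with the shared ids_for_prompt helper imported from aiu_fms_testing_utils.utils, which takes the tokenizer explicitly. A minimal sketch of the new call pattern; the tokenizer choice is an illustrative assumption, not part of this commit:

from fms.utils import tokenizers
from aiu_fms_testing_utils.utils import ids_for_prompt

# Illustrative tokenizer; inside the decoder class, self.tokenizer is
# passed here instead.
tokenizer = tokenizers.get_tokenizer("char_tokenizer")

# Old call site (method form, removed by this commit):
#     ids = self.ids_for_prompt(prompt)
# New call site (shared helper, tokenizer explicit):
ids = ids_for_prompt("Explain some popular greetings in Spanish.", tokenizer)
print(len(ids))  # token count, as used by __sample_requests for filtering

Passing the tokenizer as an argument is what lets __sample_requests in utils/__init__.py reuse the same helper, as seen in the first file of this commit.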
