Commit 67aece7

format and use hf tokenizer api
Signed-off-by: kcirred <16872435+kcirred@users.noreply.github.com>
1 parent d77e570 commit 67aece7

File tree

15 files changed: +733 / -378 lines

aiu_fms_testing_utils/testing/validation.py

Lines changed: 114 additions & 48 deletions
@@ -3,48 +3,80 @@
 
 import torch
 from fms.utils.generation import generate
-from aiu_fms_testing_utils.utils import ids_for_prompt
 from aiu_fms_testing_utils.utils.aiu_setup import dprint
 import os
 
-class LogitsExtractorHook(Callable[[int, torch.Tensor, torch.Tensor, MutableMapping[str, Any]], Tuple[torch.Tensor, MutableMapping[str, Any]],]):
 
+class LogitsExtractorHook(
+    Callable[
+        [int, torch.Tensor, torch.Tensor, MutableMapping[str, Any]],
+        Tuple[torch.Tensor, MutableMapping[str, Any]],
+    ]
+):
     def __init__(self):
         super().__init__()
         self.extracted_logits: Optional[torch.Tensor] = None
 
-    def __call__(self, token_position: torch.Tensor, logits: torch.Tensor, next_val: torch.Tensor, kwargs):
+    def __call__(
+        self,
+        token_position: torch.Tensor,
+        logits: torch.Tensor,
+        next_val: torch.Tensor,
+        kwargs,
+    ):
         if self.extracted_logits is None:
             self.extracted_logits = logits.unsqueeze(1)
         else:
-            self.extracted_logits = torch.cat((self.extracted_logits, logits.unsqueeze(1)), dim=1)
+            self.extracted_logits = torch.cat(
+                (self.extracted_logits, logits.unsqueeze(1)), dim=1
+            )
         return next_val, kwargs
 
-class StaticTokenInjectorHook(Callable[[int, torch.Tensor, torch.Tensor, MutableMapping[str, Any]], Tuple[torch.Tensor, MutableMapping[str, Any]],]):
 
-    def __init__(self, static_tokens: List[torch.Tensor], device_type: str="cpu"):
+class StaticTokenInjectorHook(
+    Callable[
+        [int, torch.Tensor, torch.Tensor, MutableMapping[str, Any]],
+        Tuple[torch.Tensor, MutableMapping[str, Any]],
+    ]
+):
+    def __init__(self, static_tokens: List[torch.Tensor], device_type: str = "cpu"):
         super().__init__()
-        self.static_tokens = torch.tensor(static_tokens, device=device_type).t() # transposing so batch tokens per token_position
+        self.static_tokens = torch.tensor(
+            static_tokens, device=device_type
+        ).t()  # transposing so batch tokens per token_position
 
-    def __call__(self, token_position: int, logits: torch.Tensor, next_val: torch.Tensor, kwargs):
+    def __call__(
+        self, token_position: int, logits: torch.Tensor, next_val: torch.Tensor, kwargs
+    ):
         next_val.copy_(self.static_tokens[token_position].unsqueeze(1))
         return next_val, kwargs
 
-class GoldenTokenHook(Callable[[int, torch.Tensor, torch.Tensor, MutableMapping[str, Any]], Tuple[torch.Tensor, MutableMapping[str, Any]],]):
 
-    def __init__(self, static_tokens: torch.Tensor, device_type: str="cpu"):
+class GoldenTokenHook(
+    Callable[
+        [int, torch.Tensor, torch.Tensor, MutableMapping[str, Any]],
+        Tuple[torch.Tensor, MutableMapping[str, Any]],
+    ]
+):
+    def __init__(self, static_tokens: torch.Tensor, device_type: str = "cpu"):
         super().__init__()
         self.logits_extractor = LogitsExtractorHook()
         self.extracted_logits = None
-        self.token_injector = StaticTokenInjectorHook(static_tokens, device_type=device_type)
+        self.token_injector = StaticTokenInjectorHook(
+            static_tokens, device_type=device_type
+        )
 
-    def __call__(self, token_position: int, logits: torch.Tensor, next_val: torch.Tensor, kwargs):
-        next_val, kwargs = self.logits_extractor(token_position, logits, next_val, kwargs)
+    def __call__(
+        self, token_position: int, logits: torch.Tensor, next_val: torch.Tensor, kwargs
+    ):
+        next_val, kwargs = self.logits_extractor(
+            token_position, logits, next_val, kwargs
+        )
         self.extracted_logits = self.logits_extractor.extracted_logits
         return self.token_injector(token_position, logits, next_val, kwargs)
 
-class ValidationInfo:
 
+class ValidationInfo:
     def __init__(self, validation_info_list):
         super().__init__()
 
@@ -55,7 +87,10 @@ def __iter__(self):
             yield vi
 
     def get_info(self, info_name):
-        return [[t.unsqueeze(0) for t in sentence[info_name]] for sentence in self._validation_info_list]
+        return [
+            [t.unsqueeze(0) for t in sentence[info_name]]
+            for sentence in self._validation_info_list
+        ]
 
     def save(self, save_dir_path: str):
         """Save the validation information into a directory.
@@ -87,12 +122,17 @@ def save(self, save_dir_path: str):
 
     def __len__(self):
         return len(self._validation_info_list)
-
-def get_default_validation_prefix(model_id: str, max_new_tokens: int, batch_size: int, seq_length: int, dtype: str):
+
+
+def get_default_validation_prefix(
+    model_id: str, max_new_tokens: int, batch_size: int, seq_length: int, dtype: str
+):
     return f"{model_id.replace('/', '--')}_max-new-tokens-{max_new_tokens}_batch-size-{batch_size}_seq-length-{seq_length}_dtype-{dtype}"
 
 
-def load_validation_information(validation_path, validation_files_type, batch_size, tokenizer=None):
+def load_validation_information(
+    validation_path, validation_files_type, batch_size, tokenizer=None
+):
     """Load the validation information from a directory
 
     The files will be assumed to be in the following structure:
@@ -108,17 +148,15 @@ def load_validation_information(validation_path, validation_files_type, batch_si
         if containing only tokens - torch.tensor
         if containing tokens and logits - dict[tokens -> torch.tensor, logits -> torch.tensor]
         if containing text - str
-
+
    :param validation_path: path to validation info files
    :param validation_files_type: validation file type to load, one of text, tokens, or logits
    :param batch_size: the number of prompts to load
    :param tokenizer: an optional tokenizer, required when validation_files_type=text
    :return: a new validation info
    """
    if isinstance(validation_path, str):
-        validation_files_path, sep, glob_pattern = validation_path.partition(
-            "*"
-        )
+        validation_files_path, sep, glob_pattern = validation_path.partition("*")
    else:
        sep = ""
        glob_pattern = ""
@@ -147,27 +185,29 @@ def load_validation_information(validation_path, validation_files_type, batch_si
        validation_files_paths = [validation_files_path]
 
    # Check if we found some files
-    assert (
-        len(validation_files_paths) > 0
-    ), f"Can't find any validation files at {validation_files_path}"
+    assert len(validation_files_paths) > 0, (
+        f"Can't find any validation files at {validation_files_path}"
+    )
 
    # Check if we have enough files
-    assert (
-        len(validation_files_paths) >= batch_size
-    ), f"Not enough validation files at {validation_files_path} for a batch size of {batch_size}"
+    assert len(validation_files_paths) >= batch_size, (
+        f"Not enough validation files at {validation_files_path} for a batch size of {batch_size}"
+    )
 
    validation_info = []
    for i, validation_file_path in enumerate(validation_files_paths):
        if i == batch_size:
            break
        if validation_files_type == "text":
            if tokenizer is None:
-                raise ValueError("must provide a tokenizer when validation_files_type=text")
+                raise ValueError(
+                    "must provide a tokenizer when validation_files_type=text"
+                )
            # Text format will get tokenized
            validation_info.append(
                {
-                    "tokens": ids_for_prompt(
-                        validation_file_path.read_text(encoding="utf-8"), tokenizer
+                    "tokens": tokenizer.encode(
+                        validation_file_path.read_text(encoding="utf-8"), return_tensors="pt"
                    ),
                    "logits": None,
                }
@@ -188,7 +228,18 @@ def load_validation_information(validation_path, validation_files_type, batch_si
 
    return ValidationInfo(validation_info)
 
-def extract_validation_information(model, input_ids, max_new_tokens, post_iteration_hook, attn_algorithm=None, eos_token_id = None, only_last_token=False, timing="", **padding_kwargs):
+
+def extract_validation_information(
+    model,
+    input_ids,
+    max_new_tokens,
+    post_iteration_hook,
+    attn_algorithm=None,
+    eos_token_id=None,
+    only_last_token=False,
+    timing="",
+    **padding_kwargs,
+):
    max_seq_len = model.config.max_expected_seq_len
 
    # Add only_last_token optimization
@@ -220,7 +271,7 @@ def extract_validation_information(model, input_ids, max_new_tokens, post_iterat
    if timing == "e2e":
        dprint(f"E2E timing information: {timings[0]:.3f}s")
    elif timing == "per-token":
-        timings = [f"{t*1000:.3f}" for t in timings]
+        timings = [f"{t * 1000:.3f}" for t in timings]
        dprint(f"Per-token timing information: {', '.join(timings)} ms")
 
    if len(result.shape) == 1:
@@ -229,75 +280,88 @@ def extract_validation_information(model, input_ids, max_new_tokens, post_iterat
    if hasattr(post_iteration_hook, "extracted_logits"):
        validation_info = [
            {"tokens": t.to("cpu"), "logits": l.to("cpu")}
-            for t, l in zip(torch.unbind(result), torch.unbind(post_iteration_hook.extracted_logits))
+            for t, l in zip(
+                torch.unbind(result), torch.unbind(post_iteration_hook.extracted_logits)
+            )
        ]
    else:
        validation_info = [{"tokens": t.to("cpu")} for t in torch.unbind(result)]
    return ValidationInfo(validation_info)
 
+
 def validate_level_0(aiu_tokens_per_sentence, validation_tokens_per_sentence):
    failed_cases = []
 
    for sentence_idx, (aiu_sentence, validation_sentence) in enumerate(
-            zip(aiu_tokens_per_sentence, validation_tokens_per_sentence)
+        zip(aiu_tokens_per_sentence, validation_tokens_per_sentence)
    ):
        for token_idx, (aiu_token, validation_token) in enumerate(
-                zip(aiu_sentence, validation_sentence)
+            zip(aiu_sentence, validation_sentence)
        ):
            if aiu_token != validation_token:
                failed_cases.append((sentence_idx, token_idx))
    return failed_cases
 
-def top_k_loss_calculator(top_k: int, loss_f: Callable[[torch.Tensor, torch.Tensor], float]):
+
+def top_k_loss_calculator(
+    top_k: int, loss_f: Callable[[torch.Tensor, torch.Tensor], float]
+):
    """
    Function which will take the top_k logits indexes / values from a reference validation info and retrieve the same indexes from the test validation info logits
    and perform a loss function over the 2 tensors
 
    :param top_k: number of values to take from reference
    :param loss_f: a loss function between the reference and test logits
    """
+
    def loss_func(reference_logits, test_logits):
        reference_logits_prob = reference_logits.to(dtype=torch.float32)
        test_logits_prob = test_logits.to(dtype=torch.float32)
 
-        reference_values, reference_indices = torch.topk(reference_logits_prob, top_k, dim=1)
+        reference_values, reference_indices = torch.topk(
+            reference_logits_prob, top_k, dim=1
+        )
        test_values = test_logits_prob[:, reference_indices.squeeze(0)]
 
        return loss_f(reference_values, test_values)
+
    return loss_func
 
 
-def capture_level_1_metrics(reference_logits_per_sentence, test_logits_per_sentence, metrics_calculator=None):
+def capture_level_1_metrics(
+    reference_logits_per_sentence, test_logits_per_sentence, metrics_calculator=None
+):
    loss_metrics = []
 
    for sentence_idx, (reference_sentence, test_sentence) in enumerate(
-            zip(reference_logits_per_sentence, test_logits_per_sentence)
+        zip(reference_logits_per_sentence, test_logits_per_sentence)
    ):
        for token_idx, (reference_logits, test_logits) in enumerate(
-                zip(reference_sentence, test_sentence)
+            zip(reference_sentence, test_sentence)
        ):
            # computing cross entropy loss per token
            if metrics_calculator is None:
                loss_fn = torch.nn.CrossEntropyLoss()
                metrics_value = loss_fn(
                    reference_logits.to(dtype=torch.float32),
-                    test_logits.softmax(dim=1).to(dtype=torch.float32)
+                    test_logits.softmax(dim=1).to(dtype=torch.float32),
                )
            else:
                metrics_value = metrics_calculator(reference_logits, test_logits)
 
            loss_metrics.append((sentence_idx, token_idx, metrics_value))
 
    return loss_metrics
-
+
+
 def filter_failed_level_1_cases(level_1_loss_metrics, fail_f, print_failed=False):
    failed_cases = []
-    for (sentence_idx, token_idx, metrics_value) in level_1_loss_metrics:
+    for sentence_idx, token_idx, metrics_value in level_1_loss_metrics:
        if fail_f(metrics_value):
            failed_cases.append((sentence_idx, token_idx, metrics_value))
            if print_failed:
                dprint(
-                    f"In sentence {sentence_idx+1}, the metric for token {token_idx} is {metrics_value}"
+                    f"In sentence {sentence_idx + 1}, the metric for token {token_idx} is {metrics_value}"
                )
    return failed_cases
 
@@ -307,6 +371,8 @@ def print_failed_cases(failed_cases, aiu_tokens, validation_tokens, tokenizer):
        aiu_token = aiu_tokens[sentence_index][token_index]
        validation_token = validation_tokens[sentence_index][token_index]
 
-        aiu_str = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(aiu_token))
-        validation_str = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(validation_token))
-        print(f"In sentence {sentence_index+1}/{len(aiu_tokens)}, token {token_index}, AIU outputs {aiu_token} instead of {validation_token} -- AIU val={aiu_str} -- CPU val={validation_str}")
+        aiu_str = tokenizer.decode(aiu_token)
+        validation_str = tokenizer.decode(validation_token)
+        print(
+            f"In sentence {sentence_index + 1}/{len(aiu_tokens)}, token {token_index}, AIU outputs {aiu_token} instead of {validation_token} -- AIU val={aiu_str} -- CPU val={validation_str}"
+        )
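Beyond the formatting changes, the substantive change in validation.py is the move from the repo's ids_for_prompt helper and the convert_ids_to_tokens / convert_tokens_to_string pair to the standard Hugging Face tokenizer API (tokenizer.encode and tokenizer.decode). The snippet below is a minimal sketch of that API for readers unfamiliar with it; it is not part of the commit, and the "gpt2" model id is only an illustrative placeholder.

# Minimal sketch of the Hugging Face tokenizer calls adopted by this commit.
# Assumptions: the `transformers` package is installed, and "gpt2" stands in
# for whatever model the validation scripts actually use.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Encoding text: returns a 2D LongTensor of shape (1, seq_len), which is what
# the "tokens" entry of the validation info now holds for text inputs.
token_ids = tokenizer.encode("hello world", return_tensors="pt")

# Decoding: accepts a single id or a sequence of ids and returns a string,
# replacing the old convert_ids_to_tokens / convert_tokens_to_string pair.
first_token_str = tokenizer.decode(token_ids[0, 0])
full_str = tokenizer.decode(token_ids[0])
print(first_token_str, "|", full_str)

Note that with return_tensors="pt" the encoded result carries a leading batch dimension of 1; whether downstream code squeezes it away depends on the rest of the commit, which is not shown here.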
