import torch
from fms.utils.generation import generate
- from aiu_fms_testing_utils.utils import ids_for_prompt
from aiu_fms_testing_utils.utils.aiu_setup import dprint
import os

- class LogitsExtractorHook(Callable[[int, torch.Tensor, torch.Tensor, MutableMapping[str, Any]], Tuple[torch.Tensor, MutableMapping[str, Any]],]):

+ class LogitsExtractorHook(
+     Callable[
+         [int, torch.Tensor, torch.Tensor, MutableMapping[str, Any]],
+         Tuple[torch.Tensor, MutableMapping[str, Any]],
+     ]
+ ):
    def __init__(self):
        super().__init__()
        self.extracted_logits: Optional[torch.Tensor] = None

-     def __call__(self, token_position: torch.Tensor, logits: torch.Tensor, next_val: torch.Tensor, kwargs):
+     def __call__(
+         self,
+         token_position: torch.Tensor,
+         logits: torch.Tensor,
+         next_val: torch.Tensor,
+         kwargs,
+     ):
        if self.extracted_logits is None:
            self.extracted_logits = logits.unsqueeze(1)
        else:
-             self.extracted_logits = torch.cat((self.extracted_logits, logits.unsqueeze(1)), dim=1)
+             self.extracted_logits = torch.cat(
+                 (self.extracted_logits, logits.unsqueeze(1)), dim=1
+             )
        return next_val, kwargs

- class StaticTokenInjectorHook(Callable[[int, torch.Tensor, torch.Tensor, MutableMapping[str, Any]], Tuple[torch.Tensor, MutableMapping[str, Any]],]):

-     def __init__(self, static_tokens: List[torch.Tensor], device_type: str = "cpu"):
+ class StaticTokenInjectorHook(
+     Callable[
+         [int, torch.Tensor, torch.Tensor, MutableMapping[str, Any]],
+         Tuple[torch.Tensor, MutableMapping[str, Any]],
+     ]
+ ):
+     def __init__(self, static_tokens: List[torch.Tensor], device_type: str = "cpu"):
        super().__init__()
-         self.static_tokens = torch.tensor(static_tokens, device=device_type).t()  # transposing so batch tokens per token_position
+         self.static_tokens = torch.tensor(
+             static_tokens, device=device_type
+         ).t()  # transposing so batch tokens per token_position

-     def __call__(self, token_position: int, logits: torch.Tensor, next_val: torch.Tensor, kwargs):
+     def __call__(
+         self, token_position: int, logits: torch.Tensor, next_val: torch.Tensor, kwargs
+     ):
        next_val.copy_(self.static_tokens[token_position].unsqueeze(1))
        return next_val, kwargs

- class GoldenTokenHook(Callable[[int, torch.Tensor, torch.Tensor, MutableMapping[str, Any]], Tuple[torch.Tensor, MutableMapping[str, Any]],]):

-     def __init__(self, static_tokens: torch.Tensor, device_type: str = "cpu"):
+ class GoldenTokenHook(
+     Callable[
+         [int, torch.Tensor, torch.Tensor, MutableMapping[str, Any]],
+         Tuple[torch.Tensor, MutableMapping[str, Any]],
+     ]
+ ):
+     def __init__(self, static_tokens: torch.Tensor, device_type: str = "cpu"):
        super().__init__()
        self.logits_extractor = LogitsExtractorHook()
        self.extracted_logits = None
-         self.token_injector = StaticTokenInjectorHook(static_tokens, device_type=device_type)
+         self.token_injector = StaticTokenInjectorHook(
+             static_tokens, device_type=device_type
+         )

-     def __call__(self, token_position: int, logits: torch.Tensor, next_val: torch.Tensor, kwargs):
-         next_val, kwargs = self.logits_extractor(token_position, logits, next_val, kwargs)
+     def __call__(
+         self, token_position: int, logits: torch.Tensor, next_val: torch.Tensor, kwargs
+     ):
+         next_val, kwargs = self.logits_extractor(
+             token_position, logits, next_val, kwargs
+         )
        self.extracted_logits = self.logits_extractor.extracted_logits
        return self.token_injector(token_position, logits, next_val, kwargs)

- class ValidationInfo:

+ class ValidationInfo:
    def __init__(self, validation_info_list):
        super().__init__()

@@ -55,7 +87,10 @@ def __iter__(self):
            yield vi

    def get_info(self, info_name):
-         return [[t.unsqueeze(0) for t in sentence[info_name]] for sentence in self._validation_info_list]
+         return [
+             [t.unsqueeze(0) for t in sentence[info_name]]
+             for sentence in self._validation_info_list
+         ]

    def save(self, save_dir_path: str):
        """Save the validation information into a directory.
@@ -87,12 +122,17 @@ def save(self, save_dir_path: str):

    def __len__(self):
        return len(self._validation_info_list)
-
- def get_default_validation_prefix(model_id: str, max_new_tokens: int, batch_size: int, seq_length: int, dtype: str):
+
+
+ def get_default_validation_prefix(
+     model_id: str, max_new_tokens: int, batch_size: int, seq_length: int, dtype: str
+ ):
    return f"{model_id.replace('/', '--')}_max-new-tokens-{max_new_tokens}_batch-size-{batch_size}_seq-length-{seq_length}_dtype-{dtype}"


- def load_validation_information(validation_path, validation_files_type, batch_size, tokenizer=None):
+ def load_validation_information(
+     validation_path, validation_files_type, batch_size, tokenizer=None
+ ):
    """Load the validation information from a directory

    The files will be assumed to be in the following structure:
@@ -108,17 +148,15 @@ def load_validation_information(validation_path, validation_files_type, batch_si
        if containing only tokens - torch.tensor
        if containing tokens and logits - dict[tokens -> torch.tensor, logits -> torch.tensor]
        if containing text - str
-
+
    :param validation_path: path to validation info files
    :param validation_files_type: validation file type to load, one of text, tokens, or logits
    :param batch_size: the number of prompts to load
    :param tokenizer: an optional tokenizer, required when validation_files_type=text
    :return: a new validation info
    """
    if isinstance(validation_path, str):
-         validation_files_path, sep, glob_pattern = validation_path.partition(
-             "*"
-         )
+         validation_files_path, sep, glob_pattern = validation_path.partition("*")
    else:
        sep = ""
        glob_pattern = ""
@@ -147,27 +185,29 @@ def load_validation_information(validation_path, validation_files_type, batch_si
        validation_files_paths = [validation_files_path]

    # Check if we found some files
-     assert (
-         len(validation_files_paths) > 0
-     ), f"Can't find any validation files at {validation_files_path}"
+     assert len(validation_files_paths) > 0, (
+         f"Can't find any validation files at {validation_files_path}"
+     )

    # Check if we have enough files
-     assert (
-         len(validation_files_paths) >= batch_size
-     ), f"Not enough validation files at {validation_files_path} for a batch size of {batch_size}"
+     assert len(validation_files_paths) >= batch_size, (
+         f"Not enough validation files at {validation_files_path} for a batch size of {batch_size}"
+     )

    validation_info = []
    for i, validation_file_path in enumerate(validation_files_paths):
        if i == batch_size:
            break
        if validation_files_type == "text":
            if tokenizer is None:
-                 raise ValueError("must provide a tokenizer when validation_files_type=text")
+                 raise ValueError(
+                     "must provide a tokenizer when validation_files_type=text"
+                 )
            # Text format will get tokenized
            validation_info.append(
                {
-                     "tokens": ids_for_prompt(
-                         validation_file_path.read_text(encoding="utf-8"), tokenizer
+                     "tokens": tokenizer.encode(
+                         validation_file_path.read_text(encoding="utf-8"), return_tensors="pt"
                    ),
                    "logits": None,
                }
@@ -188,7 +228,18 @@ def load_validation_information(validation_path, validation_files_type, batch_si

    return ValidationInfo(validation_info)

- def extract_validation_information(model, input_ids, max_new_tokens, post_iteration_hook, attn_algorithm=None, eos_token_id=None, only_last_token=False, timing="", **padding_kwargs):
+
+ def extract_validation_information(
+     model,
+     input_ids,
+     max_new_tokens,
+     post_iteration_hook,
+     attn_algorithm=None,
+     eos_token_id=None,
+     only_last_token=False,
+     timing="",
+     **padding_kwargs,
+ ):
    max_seq_len = model.config.max_expected_seq_len

    # Add only_last_token optimization
@@ -220,7 +271,7 @@ def extract_validation_information(model, input_ids, max_new_tokens, post_iterat
    if timing == "e2e":
        dprint(f"E2E timing information: {timings[0]:.3f}s")
    elif timing == "per-token":
-         timings = [f"{t*1000:.3f}" for t in timings]
+         timings = [f"{t * 1000:.3f}" for t in timings]
        dprint(f"Per-token timing information: {', '.join(timings)} ms")

    if len(result.shape) == 1:
@@ -229,75 +280,88 @@ def extract_validation_information(model, input_ids, max_new_tokens, post_iterat
    if hasattr(post_iteration_hook, "extracted_logits"):
        validation_info = [
            {"tokens": t.to("cpu"), "logits": l.to("cpu")}
-             for t, l in zip(torch.unbind(result), torch.unbind(post_iteration_hook.extracted_logits))
+             for t, l in zip(
+                 torch.unbind(result), torch.unbind(post_iteration_hook.extracted_logits)
+             )
        ]
    else:
        validation_info = [{"tokens": t.to("cpu")} for t in torch.unbind(result)]
    return ValidationInfo(validation_info)

+
def validate_level_0(aiu_tokens_per_sentence, validation_tokens_per_sentence):
    failed_cases = []

    for sentence_idx, (aiu_sentence, validation_sentence) in enumerate(
-             zip(aiu_tokens_per_sentence, validation_tokens_per_sentence)
+         zip(aiu_tokens_per_sentence, validation_tokens_per_sentence)
    ):
        for token_idx, (aiu_token, validation_token) in enumerate(
-                 zip(aiu_sentence, validation_sentence)
+             zip(aiu_sentence, validation_sentence)
        ):
            if aiu_token != validation_token:
                failed_cases.append((sentence_idx, token_idx))
    return failed_cases

- def top_k_loss_calculator(top_k: int, loss_f: Callable[[torch.Tensor, torch.Tensor], float]):
+
+ def top_k_loss_calculator(
+     top_k: int, loss_f: Callable[[torch.Tensor, torch.Tensor], float]
+ ):
    """
    Function which will take the top_k logits indexes / values from a reference validation info and retrieve the same indexes from the test validation info logits
    and perform a loss function over the 2 tensors

    :param top_k: number of values to take from reference
    :param loss_f: a loss function between the reference and test logits
    """
+
    def loss_func(reference_logits, test_logits):
        reference_logits_prob = reference_logits.to(dtype=torch.float32)
        test_logits_prob = test_logits.to(dtype=torch.float32)

-         reference_values, reference_indices = torch.topk(reference_logits_prob, top_k, dim=1)
+         reference_values, reference_indices = torch.topk(
+             reference_logits_prob, top_k, dim=1
+         )
        test_values = test_logits_prob[:, reference_indices.squeeze(0)]

        return loss_f(reference_values, test_values)
+
    return loss_func


- def capture_level_1_metrics(reference_logits_per_sentence, test_logits_per_sentence, metrics_calculator=None):
+ def capture_level_1_metrics(
+     reference_logits_per_sentence, test_logits_per_sentence, metrics_calculator=None
+ ):
    loss_metrics = []

    for sentence_idx, (reference_sentence, test_sentence) in enumerate(
-             zip(reference_logits_per_sentence, test_logits_per_sentence)
+         zip(reference_logits_per_sentence, test_logits_per_sentence)
    ):
        for token_idx, (reference_logits, test_logits) in enumerate(
-                 zip(reference_sentence, test_sentence)
+             zip(reference_sentence, test_sentence)
        ):
            # computing cross entropy loss per token
            if metrics_calculator is None:
                loss_fn = torch.nn.CrossEntropyLoss()
                metrics_value = loss_fn(
                    reference_logits.to(dtype=torch.float32),
-                     test_logits.softmax(dim=1).to(dtype=torch.float32)
+                     test_logits.softmax(dim=1).to(dtype=torch.float32),
                )
            else:
                metrics_value = metrics_calculator(reference_logits, test_logits)

            loss_metrics.append((sentence_idx, token_idx, metrics_value))

    return loss_metrics
-
+
+
def filter_failed_level_1_cases(level_1_loss_metrics, fail_f, print_failed=False):
    failed_cases = []
-     for (sentence_idx, token_idx, metrics_value) in level_1_loss_metrics:
+     for sentence_idx, token_idx, metrics_value in level_1_loss_metrics:
        if fail_f(metrics_value):
            failed_cases.append((sentence_idx, token_idx, metrics_value))
        if print_failed:
            dprint(
-                 f"In sentence {sentence_idx+1}, the metric for token {token_idx} is {metrics_value}"
+                 f"In sentence {sentence_idx + 1}, the metric for token {token_idx} is {metrics_value}"
            )
    return failed_cases

@@ -307,6 +371,8 @@ def print_failed_cases(failed_cases, aiu_tokens, validation_tokens, tokenizer):
        aiu_token = aiu_tokens[sentence_index][token_index]
        validation_token = validation_tokens[sentence_index][token_index]

-         aiu_str = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(aiu_token))
-         validation_str = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(validation_token))
-         print(f"In sentence {sentence_index + 1}/{len(aiu_tokens)}, token {token_index}, AIU outputs {aiu_token} instead of {validation_token} -- AIU val={aiu_str} -- CPU val={validation_str}")
+         aiu_str = tokenizer.decode(aiu_token)
+         validation_str = tokenizer.decode(validation_token)
+         print(
+             f"In sentence {sentence_index + 1}/{len(aiu_tokens)}, token {token_index}, AIU outputs {aiu_token} instead of {validation_token} -- AIU val={aiu_str} -- CPU val={validation_str}"
+         )
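For context, the helpers touched by this diff are normally chained together: capture per-step logits from a reference (CPU) run with LogitsExtractorHook, replay the reference tokens on the device under test with GoldenTokenHook so each step stays comparable, then check level 0 (exact token match) and level 1 (per-token logit metrics). The sketch below is inferred only from the signatures shown above; cpu_model, aiu_model, input_ids, the mean-absolute-difference metric, and the failure threshold are illustrative placeholders, not part of this change.

# Rough usage sketch (assumptions: cpu_model, aiu_model, and input_ids already
# exist, and GoldenTokenHook accepts the token lists returned by get_info()).
max_new_tokens = 64

# Reference run on CPU while recording the logits produced at every step.
cpu_validation_info = extract_validation_information(
    cpu_model,
    input_ids,
    max_new_tokens,
    LogitsExtractorHook(),
)
cpu_tokens = cpu_validation_info.get_info("tokens")

# Device-under-test run: inject the reference tokens back in at each step
# (teacher forcing) so the captured logits line up token by token.
aiu_validation_info = extract_validation_information(
    aiu_model,
    input_ids,
    max_new_tokens,
    GoldenTokenHook(cpu_tokens),
)

# Level 0: exact token equality per sentence and position.
failed_tokens = validate_level_0(
    aiu_validation_info.get_info("tokens"), cpu_tokens
)

# Level 1: compare logits over the top-k reference positions with a simple
# mean absolute difference (placeholder metric and threshold).
mad_on_top_k = top_k_loss_calculator(20, lambda ref, test: (ref - test).abs().mean())
level_1_metrics = capture_level_1_metrics(
    cpu_validation_info.get_info("logits"),
    aiu_validation_info.get_info("logits"),
    metrics_calculator=mad_on_top_k,
)
failed_metrics = filter_failed_level_1_cases(
    level_1_metrics, lambda m: m > 1.0, print_failed=True
)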