@@ -7,8 +7,9 @@
 import time
 
 # Third Party
-from datasets import load_dataset
+from datasets import Dataset, load_dataset
 from fms.models.hf import to_hf_api
+from fms.models.hf.modeling_hf_adapter import HFModelArchitecture
 from fms.utils import has_package
 from fms.utils.tokenizers import BaseTokenizer
 from torch import nn
@@ -32,9 +33,15 @@
 )
 
 
-def wrap_encoder(model):
+def wrap_encoder(model: nn.Module) -> HFModelArchitecture:
     """Add config info and wrapper to run pipeline for RoBERTa MaskedLM."""
 
+    if not has_hf:
+        raise ImportError(
+            "MaskedLM Encoder requires transformers package but import "
+            "was unsuccessful."
+        )
+
     model.config.linear_config.pop("linear_type", None)
     return to_hf_api(model, task_specific_params=None)
 
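For orientation, here is a minimal sketch of how an encoder wrapped by `wrap_encoder` can be exercised, per the docstring's mention of a RoBERTa MaskedLM pipeline. The `pipeline("fill-mask", ...)` call is the standard `transformers` API; treating the `to_hf_api` output as a drop-in `PreTrainedModel` and the `hf_tokenizer` name are assumptions, not code from this diff.

```python
# Sketch only: assumes `model` is an FMS RoBERTa-style encoder and that the
# object returned by to_hf_api() is accepted wherever a Hugging Face
# PreTrainedModel is expected.
from transformers import pipeline

hf_model = wrap_encoder(model)  # raises ImportError without transformers
# `hf_tokenizer`: the underlying HF tokenizer (hypothetical name)
unmasker = pipeline("fill-mask", model=hf_model, tokenizer=hf_tokenizer)
print(unmasker("the dog chased the cat while<mask> aggressively"))
```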
@@ -47,9 +54,9 @@ def __init__(
         model: nn.Module,
         tokenizer: BaseTokenizer,
         args: argparse.Namespace,
-    ):
+    ) -> None:
         self.model = model
-        self.tokenizer = tokenizer
+        self.tokenizer = tokenizer.tokenizer  # extract original HF tokenizer
         self.args = args
 
         self.question_column_name = ""
@@ -59,7 +66,7 @@ def __init__(
 
         self.validate_encoder_arguments()
 
-    def validate_encoder_arguments(self):
+    def validate_encoder_arguments(self) -> None:
         """Ensure arguments compatibility with Encoder models."""
 
         args = self.args
@@ -85,10 +92,14 @@ def validate_encoder_arguments(self):
             )
 
 
-    def prepare_validation_features(self, examples):
+    def prepare_validation_features(
+        self,
+        examples: dict[str, list[str | dict]],
+    ) -> dict[str, list]:
         """Validation preprocessing"""
 
         args = self.args
+
         q_col_name = self.question_column_name
         c_col_name = self.context_column_name
         pad_on_right = self.pad_on_right
@@ -109,7 +120,7 @@ def prepare_validation_features(self, examples):
         # using a stride. This results in one example possibly giving several features
         # when a context is long, each of those features having a context that overlaps
         # a bit the context of the previous feature.
-        tokenized_examples = self.tokenizer.tokenize(
+        tokenized_examples = self.tokenizer(
             examples[q_col_name if pad_on_right else c_col_name],
             examples[c_col_name if pad_on_right else q_col_name],
             truncation="only_second" if pad_on_right else "only_first",
@@ -149,12 +160,15 @@ def prepare_validation_features(self, examples):
 
         return tokenized_examples
 
-    def convert_batch_to_fms_style(self, batch):
+    def convert_batch_to_fms_style(
+        self,
+        batch: dict[str, torch.Tensor],
+    ) -> dict[str, torch.Tensor]:
         """FMS uses a different standard than HF for encoder inputs."""
 
         return {'x': batch['input_ids'], 'mask': batch['attention_mask']}
 
-    def process_eval_set(self):
+    def process_eval_set(self) -> None:
         """Pre-process evaluation dataset for QuestionAnswering task."""
 
         if not has_hf:
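In other words, `convert_batch_to_fms_style` is a pure key rename between the HF collator output and the FMS forward signature. A hedged usage sketch (the keyword-argument forward call is an assumption based on the `x`/`mask` names, and the input tensors are placeholders):

```python
# HF data collators emit input_ids/attention_mask; FMS encoders expect x/mask.
hf_batch = {"input_ids": input_ids, "attention_mask": attention_mask}  # placeholder tensors
fms_batch = self.convert_batch_to_fms_style(hf_batch)
logits = self.model(**fms_batch)  # i.e., model(x=..., mask=...), assumed signature
```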
@@ -192,7 +206,7 @@ def process_eval_set(self):
         # Padding side determines if we do (question|context) or (context|question)
         self.pad_on_right = self.tokenizer.padding_side == "right"
 
-        model_max_length = self.tokenizer.tokenizer.model_max_length  # TODO: add model_max_length to FMS _HFTokenizer
+        model_max_length = self.tokenizer.model_max_length
         if args.max_prompt_length > model_max_length:
             dprint(
                 f"max_prompt_length ({args.max_prompt_length}) is larger than the "
@@ -259,16 +273,16 @@ def process_eval_set(self):
 
     def postprocess_qa_predictions(
         self,
-        examples,
-        features,
+        examples: Dataset,
+        features: Dataset,
         predictions: tuple[np.ndarray, np.ndarray],
         version_2_with_negative: bool = False,
         n_best_size: int = 20,
         max_answer_length: int = 30,
         null_score_diff_threshold: float = 0.0,
         output_dir: str | None = None,
         prefix: str | None = None,
-    ):
+    ) -> dict[str, str]:
         """
         Post-processes the predictions of a question-answering model to convert them to answers that are substrings of the
         original contexts. This is the base postprocessing function for models that only return start and end logits.
@@ -476,7 +490,13 @@ def postprocess_qa_predictions(
 
         return all_predictions
 
-    def post_processing_function(self, examples, features, predictions, stage="eval"):
+    def post_processing_function(
+        self,
+        examples: Dataset,
+        features: Dataset,
+        predictions: list[np.ndarray],
+        stage: str = "eval",
+    ) -> EvalPrediction:
         """Post-processing: we match the start logits and end logits to answers in
         the original context."""
 
@@ -492,6 +512,7 @@ def post_processing_function(self, examples, features, predictions, stage="eval"
492
512
output_dir = None ,
493
513
prefix = stage ,
494
514
)
515
+ breakpoint ()
495
516
# Format the result to the format the metric expects.
496
517
if args .version_2_with_negative :
497
518
formatted_predictions = [
@@ -508,7 +529,12 @@ def post_processing_function(self, examples, features, predictions, stage="eval"
508
529
]
509
530
return EvalPrediction (predictions = formatted_predictions , label_ids = references )
510
531
511
- def create_and_fill_np_array (self , start_or_end_logits , dataset , max_len ):
532
+ def create_and_fill_np_array (
533
+ self ,
534
+ start_or_end_logits : list [np .ndarray ],
535
+ dataset : Dataset ,
536
+ max_len : int ,
537
+ ) -> np .ndarray :
512
538
"""
513
539
Create and fill numpy array of size
514
540
len_of_validation_data * max_length_of_output_tensor
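The body of `create_and_fill_np_array` is outside this diff; in the upstream pytorch/transformers question-answering example this file cites, the concat-and-pad step looks roughly like the following. Treat it as a sketch (including the conventional -100 pad value), not necessarily this repo's exact code:

```python
import numpy as np

# Pack per-batch logits of varying sequence length into one padded array so
# they can be consumed row-wise by the postprocessing step.
logits_concat = np.full((len(dataset), max_len), -100, dtype=np.float64)
step = 0
for output_logit in start_or_end_logits:  # one np.ndarray per batch
    batch_size, cols = output_logit.shape
    logits_concat[step : step + batch_size, :cols] = output_logit
    step += batch_size
# the method then returns logits_concat
```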
@@ -543,7 +569,7 @@ def create_and_fill_np_array(self, start_or_end_logits, dataset, max_len):
 
         return logits_concat
 
-    def run_warmup(self):
+    def run_warmup(self) -> None:
         """Run warmup cycle of compiled encoder model set for QuestionAnswering task."""
 
         dprint(f"Starting warm-up...")
@@ -559,7 +585,7 @@ def run_warmup(self):
         if rank == 0:
             dprint(f"Warmup completed in {time.time() - warmup_start_time:.1f}s\n---")
 
-    def run_evaluation(self):
+    def run_evaluation(self) -> None:
         """Run QuestionAnswering evaluation."""
 
         args = self.args
@@ -587,7 +613,7 @@ def run_evaluation(self):
             f"(tot = {len(eval_dataloader) * args.batch_size}, "
             f"bs = {args.batch_size})"
         )
 
         # concatenate the numpy array
         max_len = max([x.shape[1] for x in all_start_logits])
         start_logits_concat = self.create_and_fill_np_array(
@@ -622,21 +648,27 @@ class EncoderMLMInfer():
 
     def __init__(
         self,
-        model: nn.Module,
+        model: HFModelArchitecture,
         tokenizer: BaseTokenizer,
         args: argparse.Namespace,
-    ):
+    ) -> None:
         self.model = model
         self.tokenizer = tokenizer
         self.args = args
 
 
-    def process_eval_set(self):
+    def process_eval_set(self) -> None:
         """Barebone function that sets up a single example prompt (for now)."""
 
+        if not has_hf:
+            raise ImportError(
+                "MaskedLM Encoder requires transformers package but import "
+                "was unsuccessful."
+            )
+
         self.prompt = "the dog chased the cat while<mask> aggressively"
 
-    def run_evaluation(self, warmup=False):
+    def run_evaluation(self, warmup: bool = False) -> None:
         """Run evaluation cycle of compiled encoder model set for MaskedLM task.
         No output printout if warmup is True.
         """
@@ -658,10 +690,10 @@ def run_evaluation(self, warmup=False):
 
 
 def run_encoder_eval_qa(
-    model: nn.Module,
+    model: nn.Module,  # FMS-style model
     tokenizer: BaseTokenizer,
     args: argparse.Namespace,
-):
+) -> None:
     """Entry point to run QuestionAnswering Evaluation of encoder model.
 
     Processing based on pytorch example:
@@ -677,10 +709,10 @@ def run_encoder_eval_qa(
 
 
 def run_encoder_eval_mlm(
-    model: nn.Module,
+    model: HFModelArchitecture,  # model wrapped by to_hf_api
     tokenizer: BaseTokenizer,
     args: argparse.Namespace,
-):
+) -> None:
     """Entry point to run evaluation of encoder models."""
 
     encoder_mlm_infer = EncoderMLMInfer(model, tokenizer, args)
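Taken together, the two entry points now expect differently typed models: the QA path takes the raw FMS `nn.Module`, while the MLM path takes the `to_hf_api`-wrapped model. A hedged sketch of caller-side wiring (the `args.task` flag is hypothetical, not from this diff):

```python
# Hypothetical dispatch; only the function names and the model typing are
# taken from the diff above.
if args.task == "mlm":
    run_encoder_eval_mlm(wrap_encoder(model), tokenizer, args)
else:
    run_encoder_eval_qa(model, tokenizer, args)
```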