Skip to content

Commit 6780770

Browse files
committed
Update type hints
Signed-off-by: Andrea Fasoli <andrea.fasoli@ibm.com>
1 parent 17be9a7 commit 6780770

File tree

2 files changed

+59
-28
lines changed

2 files changed

+59
-28
lines changed

aiu_fms_testing_utils/utils/decoders_utils.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -260,7 +260,6 @@ def infer(self, ids, warmup):
260260
extra_generation_kwargs["only_last_token"] = True
261261

262262
if args.device_type == "cpu":
263-
# Bug in 2.3.1 fixed in 2.4.1 for SDPA flash cpu impl when pad too much
264263
extra_generation_kwargs["attn_algorithm"] = "math"
265264

266265
if not args.no_early_termination and not warmup:
@@ -338,7 +337,7 @@ def run_decoder_eval(
338337
tokenizer: BaseTokenizer,
339338
args: argparse.Namespace,
340339
device: torch.device,
341-
):
340+
) -> None:
342341
"""Entry point to run evaluation of LLM decoder models."""
343342

344343
decoder_infer = DecoderInfer(model, tokenizer, args, device)

aiu_fms_testing_utils/utils/encoders_utils.py

Lines changed: 58 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -7,8 +7,9 @@
77
import time
88

99
# Third Party
10-
from datasets import load_dataset
10+
from datasets import Dataset, load_dataset
1111
from fms.models.hf import to_hf_api
12+
from fms.models.hf.modeling_hf_adapter import HFModelArchitecture
1213
from fms.utils import has_package
1314
from fms.utils.tokenizers import BaseTokenizer
1415
from torch import nn
@@ -32,9 +33,15 @@
3233
)
3334

3435

35-
def wrap_encoder(model):
36+
def wrap_encoder(model: nn.Module) -> HFModelArchitecture:
3637
"""Add config info and wrapper to run pipeline for RoBERTa MaskedLM."""
3738

39+
if not has_hf:
40+
raise ImportError(
41+
"MaskedLM Encoder requires transformer package but import "
42+
"was unsuccessful."
43+
)
44+
3845
model.config.linear_config.pop("linear_type", None)
3946
return to_hf_api(model, task_specific_params=None)
4047

@@ -47,9 +54,9 @@ def __init__(
4754
model: nn.Module,
4855
tokenizer: BaseTokenizer,
4956
args: argparse.Namespace,
50-
):
57+
) -> None:
5158
self.model = model
52-
self.tokenizer = tokenizer
59+
self.tokenizer = tokenizer.tokenizer # extract original HF tokenizer
5360
self.args = args
5461

5562
self.question_column_name = ""
@@ -59,7 +66,7 @@ def __init__(
5966

6067
self.validate_encoder_arguments()
6168

62-
def validate_encoder_arguments(self):
69+
def validate_encoder_arguments(self) -> None:
6370
"""Ensure arguments compatibility with Encoder models."""
6471

6572
args = self.args
@@ -85,10 +92,14 @@ def validate_encoder_arguments(self):
8592
)
8693

8794

88-
def prepare_validation_features(self, examples):
95+
def prepare_validation_features(
96+
self,
97+
examples: dict[str, list[str | dict]],
98+
) -> dict[str, list]:
8999
"""Validation preprocessing"""
90100

91101
args = self.args
102+
92103
q_col_name = self.question_column_name
93104
c_col_name = self.context_column_name
94105
pad_on_right = self.pad_on_right
@@ -109,7 +120,7 @@ def prepare_validation_features(self, examples):
109120
# using a stride. This results in one example possibly giving several features
110121
# when a context is long, each of those features having a context that overlaps
111122
# a bit the context of the previous feature.
112-
tokenized_examples = self.tokenizer.tokenize(
123+
tokenized_examples = self.tokenizer(
113124
examples[q_col_name if pad_on_right else c_col_name],
114125
examples[c_col_name if pad_on_right else q_col_name],
115126
truncation="only_second" if pad_on_right else "only_first",
@@ -149,12 +160,15 @@ def prepare_validation_features(self, examples):
149160

150161
return tokenized_examples
151162

152-
def convert_batch_to_fms_style(self, batch):
163+
def convert_batch_to_fms_style(
164+
self,
165+
batch: dict[str, torch.Tensor],
166+
) -> dict[str, torch.Tensor]:
153167
"""FMS uses a different standard than HF for encoder inputs."""
154168

155169
return {'x': batch['input_ids'], 'mask': batch['attention_mask']}
156170

157-
def process_eval_set(self):
171+
def process_eval_set(self) -> None:
158172
"""Pre-process evaluation dataset for QuestionAnswering task."""
159173

160174
if not has_hf:
@@ -192,7 +206,7 @@ def process_eval_set(self):
192206
# Padding side determines if we do (question|context) or (context|question)
193207
self.pad_on_right = self.tokenizer.padding_side == "right"
194208

195-
model_max_length = self.tokenizer.tokenizer.model_max_length # TODO: add model_max_length to FMS _HFTokenizer
209+
model_max_length = self.tokenizer.model_max_length
196210
if args.max_prompt_length > model_max_length:
197211
dprint(
198212
f"max_prompt_length ({args.max_prompt_length}) is larger than the "
@@ -259,16 +273,16 @@ def process_eval_set(self):
259273

260274
def postprocess_qa_predictions(
261275
self,
262-
examples,
263-
features,
276+
examples: Dataset,
277+
features: Dataset,
264278
predictions: tuple[np.ndarray, np.ndarray],
265279
version_2_with_negative: bool = False,
266280
n_best_size: int = 20,
267281
max_answer_length: int = 30,
268282
null_score_diff_threshold: float = 0.0,
269283
output_dir: str | None = None,
270284
prefix: str | None = None,
271-
):
285+
) -> None:
272286
"""
273287
Post-processes the predictions of a question-answering model to convert them to answers that are substrings of the
274288
original contexts. This is the base postprocessing functions for models that only return start and end logits.
@@ -476,7 +490,13 @@ def postprocess_qa_predictions(
476490

477491
return all_predictions
478492

479-
def post_processing_function(self, examples, features, predictions, stage="eval"):
493+
def post_processing_function(
494+
self,
495+
examples: Dataset,
496+
features: Dataset,
497+
predictions: list[np.ndarray],
498+
stage: str = "eval",
499+
) -> dict[list[str, str]]:
480500
"""Post-processing: we match the start logits and end logits to answers in
481501
the original context."""
482502

@@ -492,6 +512,7 @@ def post_processing_function(self, examples, features, predictions, stage="eval"
492512
output_dir=None,
493513
prefix=stage,
494514
)
515+
breakpoint()
495516
# Format the result to the format the metric expects.
496517
if args.version_2_with_negative:
497518
formatted_predictions = [
@@ -508,7 +529,12 @@ def post_processing_function(self, examples, features, predictions, stage="eval"
508529
]
509530
return EvalPrediction(predictions=formatted_predictions, label_ids=references)
510531

511-
def create_and_fill_np_array(self, start_or_end_logits, dataset, max_len):
532+
def create_and_fill_np_array(
533+
self,
534+
start_or_end_logits: list[np.ndarray],
535+
dataset: Dataset,
536+
max_len: int,
537+
) -> np.ndarray:
512538
"""
513539
Create and fill numpy array of size
514540
len_of_validation_data * max_length_of_output_tensor
@@ -543,7 +569,7 @@ def create_and_fill_np_array(self, start_or_end_logits, dataset, max_len):
543569

544570
return logits_concat
545571

546-
def run_warmup(self):
572+
def run_warmup(self) -> None:
547573
"""Run warmup cycle of compiled encoder model set for QuestionAnswering task."""
548574

549575
dprint(f"Starting warm-up...")
@@ -559,7 +585,7 @@ def run_warmup(self):
559585
if rank == 0:
560586
dprint(f"Warmup completed in {time.time() - warmup_start_time:.1f} s\n---")
561587

562-
def run_evaluation(self):
588+
def run_evaluation(self) -> None:
563589
"""Run QuestionAnswering evaluation."""
564590

565591
args = self.args
@@ -587,7 +613,7 @@ def run_evaluation(self):
587613
f"(tot = {len(eval_dataloader) * args.batch_size}, "
588614
f"bs = {args.batch_size})"
589615
)
590-
616+
breakpoint()
591617
# concatenate the numpy array
592618
max_len = max([x.shape[1] for x in all_start_logits])
593619
start_logits_concat = self.create_and_fill_np_array(
@@ -622,21 +648,27 @@ class EncoderMLMInfer():
622648

623649
def __init__(
624650
self,
625-
model: nn.Module,
651+
model: HFModelArchitecture,
626652
tokenizer: BaseTokenizer,
627653
args: argparse.Namespace,
628-
):
654+
) -> None:
629655
self.model = model
630656
self.tokenizer = tokenizer
631657
self.args = args
632658

633659

634-
def process_eval_set(self):
660+
def process_eval_set(self) -> None:
635661
"""Barebone function that sets up a single example prompt (for now)."""
636662

663+
if not has_hf:
664+
raise ImportError(
665+
"MaskedLM Encoder requires transformer package but import "
666+
"was unsuccessful."
667+
)
668+
637669
self.prompt = "the dog chased the cat while<mask> aggressively"
638670

639-
def run_evaluation(self, warmup=False):
671+
def run_evaluation(self, warmup: bool = False) -> None:
640672
"""Run evaluation cycle of compiled encoder model set for MaskedLM task.
641673
No output printout if warmup is True.
642674
"""
@@ -658,10 +690,10 @@ def run_evaluation(self, warmup=False):
658690

659691

660692
def run_encoder_eval_qa(
661-
model: nn.Module,
693+
model: nn.Module, # FMS-style model
662694
tokenizer: BaseTokenizer,
663695
args: argparse.Namespace,
664-
):
696+
) -> None:
665697
"""Entry point to run QuestionAnswering Evaluation of encoder model.
666698
667699
Processing based on pytorch example:
@@ -677,10 +709,10 @@ def run_encoder_eval_qa(
677709

678710

679711
def run_encoder_eval_mlm(
680-
model: nn.Module,
712+
model: HFModelArchitecture, # model wrapped by to_hf_api
681713
tokenizer: BaseTokenizer,
682714
args: argparse.Namespace,
683-
):
715+
) -> None:
684716
"""Entry point to run evaluation of encoder models."""
685717

686718
encoder_mlm_infer = EncoderMLMInfer(model, tokenizer, args)

0 commit comments

Comments
 (0)