From 4d57dd06c70a382c19e9e18836b2f8fdeef08c78 Mon Sep 17 00:00:00 2001 From: Andrea Fasoli Date: Thu, 26 Jun 2025 21:59:11 -0400 Subject: [PATCH 01/22] initial encoder refactoring (wip) Signed-off-by: Andrea Fasoli --- aiu_fms_testing_utils/utils/aiu_setup.py | 60 +- aiu_fms_testing_utils/utils/args_parsing.py | 329 ++++++++ aiu_fms_testing_utils/utils/encoders_utils.py | 723 ++++++++++++++++++ aiu_fms_testing_utils/utils/model_setup.py | 135 ++++ .../utils/quantization_setup.py | 126 +++ scripts/encoders_inference.py | 98 +++ 6 files changed, 1467 insertions(+), 4 deletions(-) create mode 100644 aiu_fms_testing_utils/utils/args_parsing.py create mode 100644 aiu_fms_testing_utils/utils/encoders_utils.py create mode 100644 aiu_fms_testing_utils/utils/model_setup.py create mode 100644 aiu_fms_testing_utils/utils/quantization_setup.py create mode 100644 scripts/encoders_inference.py diff --git a/aiu_fms_testing_utils/utils/aiu_setup.py b/aiu_fms_testing_utils/utils/aiu_setup.py index fb9a3df..d7ee71f 100644 --- a/aiu_fms_testing_utils/utils/aiu_setup.py +++ b/aiu_fms_testing_utils/utils/aiu_setup.py @@ -1,4 +1,6 @@ +import argparse import os +import torch # ============================================================== # Common utilities @@ -21,7 +23,7 @@ def dprint(text): # ============================================================== # Common setup # ============================================================== -def aiu_setup(rank=0, world_size=1, local_rank=0, local_size=1, verbose=False): +def aiu_setup(rank=0, world_size=1, local_rank=0, verbose=False): # ------------- # Envar setup for Sentient backend # ------------- @@ -54,11 +56,9 @@ def aiu_setup(rank=0, world_size=1, local_rank=0, local_size=1, verbose=False): # ============================================================== # Distributed setup # ============================================================== -def aiu_dist_setup(rank, world_size, local_rank=-0, local_size=-1, verbose=False): +def aiu_dist_setup(rank, world_size, local_rank=-0, verbose=False): if local_rank < 0: local_rank = rank - if local_size < 0: - local_size = world_size if os.getenv("TORCHELASTIC_RUN_ID") is None: os.environ["MASTER_ADDR"] = "localhost" @@ -67,3 +67,55 @@ def aiu_dist_setup(rank, world_size, local_rank=-0, local_size=-1, verbose=False dprint(f"Detected running via torchrun") aiu_setup(rank, world_size) + + +# ============================================================== +# Environment variables utilities +# ============================================================== +def set_aiu_env_vars(args: argparse.Namespace) -> None: + """Set necessary environment variables for AIU""" + + if not args.compile_dynamic: + _target_cache_size = max( + int(args.max_new_tokens * 2), + int(args.min_pad_length * 2.5), + int(args.fixed_prompt_length * 2.5), + ) + _prompt_size = max(int(args.min_pad_length), int(args.fixed_prompt_length)) + if hasattr(torch._dynamo.config, "accumulated_cache_size_limit"): + if _target_cache_size > torch._dynamo.config.accumulated_cache_size_limit: + _prev = torch._dynamo.config.accumulated_cache_size_limit + torch._dynamo.config.accumulated_cache_size_limit = _target_cache_size + dprint( + "NOTICE: Adjusting torch._dynamo.config.accumulated_cache_size_limit " + f"from {_prev} to {torch._dynamo.config.accumulated_cache_size_limit} " + f"to accomodate prompt size of {_prompt_size} and decode tokens of " + f"{args.max_new_tokens}" + ) + + if _target_cache_size > torch._dynamo.config.cache_size_limit: + _prev = 
torch._dynamo.config.cache_size_limit + torch._dynamo.config.cache_size_limit = _target_cache_size + dprint( + f"NOTICE: Adjusting torch._dynamo.config.cache_size_limit from {_prev} to " + f"{torch._dynamo.config.cache_size_limit} to accomodate prompt size of " + f"{_prompt_size} and decode tokens of {args.max_new_tokens}" + ) + + torch._dynamo.config.assume_static_by_default = True + torch._dynamo.config.automatic_dynamic_shapes = False + + # os.environ.setdefault("DTCOMPILER_KEEP_EXPORT", "true") # CONFIRM IF THIS IS NEEDE + + if not args.is_encoder: + os.environ.setdefault("COMPILATION_MODE", "offline_decoder") + + if args.device_type == "aiu-senulator": + os.environ["FLEX_COMPUTE"] = "SENULATOR" + os.environ["FLEX_DEVICE"] = "MOCK" + else: + if "AIU_WORLD_RANK_0" not in os.environ: + print("must set AIU_WORLD_RANK_0") + exit() + os.environ.setdefault("FLEX_COMPUTE", "SENTIENT") + os.environ.setdefault("FLEX_DEVICE", "PF") # will use VF eventually diff --git a/aiu_fms_testing_utils/utils/args_parsing.py b/aiu_fms_testing_utils/utils/args_parsing.py new file mode 100644 index 0000000..59278a5 --- /dev/null +++ b/aiu_fms_testing_utils/utils/args_parsing.py @@ -0,0 +1,329 @@ +# Standard +import argparse + +# Local Packages +from aiu_fms_testing_utils.utils.aiu_setup import dprint + + +def get_args(parser: argparse.ArgumentParser) -> argparse.Namespace: + + # Arguments for FMS model loading + args_model_loading = parser.add_argument_group("FMS model loading") + args_model_loading.add_argument( + "--architecture", + type=str, + help="The model architecture to benchmark", + ) + args_model_loading.add_argument( + "--variant", + type=str, + default=None, + help="The model variant (configuration) to benchmark. E.g. 7b, 13b, 70b.", + ) + args_model_loading.add_argument( + "--model_path", + type=str, + help=( + "Path to the directory containing LLaMa weights " + "(.pth files sharded by tensor parallel rank, not HF weights)" + ), + ) + args_model_loading.add_argument( + "--model_source", + type=str, + help="Source of the checkpoint. E.g. 'meta', 'hf', None", + ) + args_model_loading.add_argument( + "--unfuse_weights", + action="store_true", + help=( + "If set to True, this will unfuse any fused weight modules" + ), + ) + args_model_loading.add_argument( + "--default_dtype", + type=str, + default=None, + choices=["bf16", "fp16", "fp32"], + help=( + "If set to one of the choices, overrides the model checkpoint " + "weight format by setting the default pytorch format" + ), + ) + + # Quantization arguments + args_quantization = parser.add_argument_group("Model quantization") + args_quantization.add_argument( + "--quantization", + type=str, + choices=["gptq", "fp8"], # TODO: add "fp8" when available in FMS + default=None, + help="Type of quantization of the model checkpoint", + ) + + # General run settings + args_run_settings = parser.add_argument_group("Run settings") + args_run_settings.add_argument( + "--device_type", + type=str, + choices=["cuda", "cpu", "aiu", "aiu-senulator"], + default="cuda", + help="The device to run the model on" + ) + args_run_settings.add_argument( + "--seed", + type=int, + default=81072, + help="Fix run seed for reproducibility", + ) + args_run_settings.add_argument( + "--output_path", + type=str, + default="", + help="path of folder to save outputs to, if empty don't save", + ) + args_run_settings.add_argument( + "--tokenizer", + type=str, + required=True, + help="Path to the tokenizer (e.g. 
~/tokenizer.model)", + ) + args_run_settings.add_argument( + "--no_use_cache", + action="store_false", + help="Disable the kv-cache (on by default)", + ) + args_run_settings.add_argument( + "--deterministic", + action="store_true", + help=( + "`deterministic` requires env variable `CUBLAS_WORKSPACE_CONFIG=:4096:8`" + " when running on CPU or GPU. This flag is ignored on AIU." + ), + ) + args_run_settings.add_argument( + "--distributed", + action="store_true", + help="This is a distributed job (multiple instances run with RANK+WORLD_SIZE)", + ) + args_run_settings.add_argument( # could be a bool / flag + '-v', '--verbose', + action='count', + default=0, + help="Set verbosity level (pass flag as `-v`, `-vv`, `-vvv`)" + ) + + # Arguments for compilation + args_compile = parser.add_argument_group("Compiler") + args_compile.add_argument( + "--compile", + action="store_true", + help="Use torch.compile (slow for first inference pass)", + ) + args_compile.add_argument( + "--compile_mode", + type=str, + help="Mode for compilation (only valid for inductor backend)", + default="default", + choices=["default", "reduce-overhead"], + ) + args_compile.add_argument( + "--compile_backend", + type=str, + help="Backend for compilation (only when not running on AIU)", + default="inductor", + choices=["inductor", "eager", "aot_eager"], + ) + args_compile.add_argument( + "--compile_dynamic", + action="store_true", + help="Use dynamic shapes with torch.compile", + ) + args_compile.add_argument( + "--compile_dynamic_sendnn", + action="store_true", + help="Use dynamic shapes with aiu compile", + ) + + # Arguments shared between Decoder and Encoder models + args_dec_enc = parser.add_argument_group("Decoders or Encoders (shared args)") + args_dec_enc.add_argument( + "--batch_size", + type=int, + default=1, + help="size of input batch", + ) + args_dec_enc.add_argument( + "--max_prompt_length", + type=int, + default=None, + help=( + "Cap the number of tokens per prompt to a maximum length prior to padding. " + "If None, prompts to decoder models will have no cap, while prompts to " + "encoder models will be capped to a default of 384 tokens." + ), + ) + + # Decoder model arguments + # args_decoder = parser.add_argument_group("Decoders") + # args_decoder.add_argument( + # "--min_pad_length", + # type=int, + # default=0, + # help=( + # "Pad inputs to a minimum specified length. If any prompt is larger than " + # "the specified length, padding will be determined by the largest prompt" + # ), + # ) + # args_decoder.add_argument( + # "--fixed_prompt_length", + # type=int, + # default=0, + # help=( + # "If defined, overrides both min_pad_length and max_prompt_length. " + # "Pads input to fixed_prompt_length, fails if any input needs truncation." + # ), + # ) + # args_decoder.add_argument( + # "--max_new_tokens", + # type=int, + # help="max number of generated tokens", + # default=100, + # ) + # args_decoder.add_argument( + # "--no_early_termination", + # action="store_true", + # help="disable early termination on generation", + # ) + # args_decoder.add_argument( + # "--prompt_type", + # type=str, + # choices=["chat", "code"], + # default="chat", + # help="type of prompts to be used, either chat or code", + # ) + # args_decoder.add_argument( + # "--prompt_path", + # type=str, + # default="", + # help=( + # "If set, load the prompts from file(s) instead of the local examples. 
" + # "Supports glob-style patterns" + # ), + # ) + # args_decoder.add_argument( + # "--timing", + # type=str, + # choices=["e2e", "per-token"], + # default="", + # help="if set, how to time the generation of tokens, e2e or per-token", + # ) + # args_decoder.add_argument( + # "--iters", + # type=int, + # default=1, + # help=( + # "Number of iterations of inference to perform. Used for variance " + # "performance capture." + # ), + # ) + + # Encoder model arguments + args_encoder = parser.add_argument_group("Encoders") + args_encoder.add_argument( + "--dataset_name", + type=str, + default="squad_v2", + help="The name of the dataset to use (via the datasets library).", + ) + args_encoder.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The configuration name of the dataset to use (via the datasets library).", + ) + args_encoder.add_argument( + "--n_best_size", + type=int, + default=20, + help="Total number of n-best predictions to generate.", + ) + args_encoder.add_argument( + "--null_score_diff_threshold", + type=float, + default=0.0, + help=( + "The threshold used to select the null answer: if the best answer has a " + "score that is less than the score of the null answer minus this threshold, " + "the null answer is selected for this example. Only useful when " + "`version_2_with_negative=True`." + ), + ) + args_encoder.add_argument( + "--version_2_with_negative", + type=bool, + default=True, + help="If true, some of the examples do not have an answer.", + ) + args_encoder.add_argument( + "--max_answer_length", + type=int, + default=30, + help=( + "The maximum length of an answer that can be generated. This is needed " + "because the start and end predictions are not conditioned on one another." + ), + ) + args_encoder.add_argument( + "--validation_file", + type=str, + default=None, + help="A csv or a json file containing the validation data.", + ) + args_encoder.add_argument( + "--pad_to_max_length", + action="store_true", + help=( + "If passed, pad all samples to `max_seq_length`. " + "Otherwise, dynamic padding is used." + ), + ) + args_encoder.add_argument( + "--max_eval_samples", + type=int, + default=None, + help=( + "For debugging purposes or quicker training, truncate the number of " + "evaluation examples to this value if set." + ), + ) + args_encoder.add_argument( + "--preprocessing_num_workers", type=int, default=1, help="" + ) + args_encoder.add_argument( + "--overwrite_cache", + action="store_true", + help="Overwrite the cached training and evaluation sets", + ) + args_encoder.add_argument( + "--doc_stride", + type=int, + default=128, + help=( + "When splitting up a long document into chunks how much stride " + "to take between chunks." 
+ ), + ) + args = parser.parse_args() + + # Add convenient arguments to parser + args.is_quantized = args.quantization is not None + args.is_aiu_backend = "aiu" in args.device_type + args.dynamo_backend = "sendnn" if args.is_aiu_backend else "inductor" + args.fused_weights = not args.unfuse_weights + + if args.verbose: + dprint("=" * 60) + dprint(args) + dprint("=" * 60) + return args diff --git a/aiu_fms_testing_utils/utils/encoders_utils.py b/aiu_fms_testing_utils/utils/encoders_utils.py new file mode 100644 index 0000000..52c700d --- /dev/null +++ b/aiu_fms_testing_utils/utils/encoders_utils.py @@ -0,0 +1,723 @@ +# Standard +from tqdm import tqdm +import argparse +import collections +import json +import os +import time + +# Third Party +from datasets import Dataset, load_dataset +from fms.models.hf import to_hf_api +from fms.models.hf.modeling_hf_adapter import HFModelArchitecture +from fms.utils import has_package +from fms.utils.tokenizers import BaseTokenizer +from torch import nn +from torch.utils.data import DataLoader +import evaluate +import numpy as np +import torch + +# Local Packages +from aiu_fms_testing_utils.utils.aiu_setup import dprint, rank + + +# Optional imports (required for QA) +has_hf = has_package("transformers") +if has_hf: + from transformers import ( + default_data_collator, + DataCollatorWithPadding, + EvalPrediction, + pipeline, + ) + + +def wrap_encoder(model: nn.Module) -> HFModelArchitecture: + """Add config info and wrapper to run pipeline for RoBERTa MaskedLM.""" + + if not has_hf: + raise ImportError( + "MaskedLM Encoder requires transformer package but import " + "was unsuccessful." + ) + + model.config.linear_config.pop("linear_type", None) + return to_hf_api(model, task_specific_params=None) + + +class EncoderQAInfer(): + """Run QuestionAnswering task with encoder models.""" + + def __init__( + self, + model: nn.Module, + tokenizer: BaseTokenizer, + args: argparse.Namespace, + ) -> None: + self.model = model + self.tokenizer = tokenizer.tokenizer # extract original HF tokenizer + self.args = args + + self.question_column_name = "" + self.context_column_name = "" + self.answer_column_name = "" + self.pad_on_right = True + + self.validate_encoder_arguments() + + def validate_encoder_arguments(self) -> None: + """Ensure arguments compatibility with Encoder models.""" + + args = self.args + if not getattr(args, "is_encoder", False): + raise ValueError( + "Running encoder model but is_encoder argument is not set to True. " + "Verify your launch script." + ) + if args.min_pad_length != 0: + raise ValueError( + "Argument min_pad_length should not be provided to encoders. " + "To pad the input sequence, use --pad_to_max_length flag instead." + ) + if args.fixed_prompt_length != 0: + raise ValueError( + "Argument fixed_prompt_length should not be provided to encoders. " + "To pad the input sequence, use --pad_to_max_length flag instead." + ) + if args.max_new_tokens != 100: # default value for decoder models + raise ValueError( + "Argument max_new_token should not be provided to encoders. " + "To define the max length of a generated answer in QuestionAnswering " + "use --max_answer_length instead." 
+ ) + + + def prepare_validation_features( + self, + examples: dict[str, list[str | dict]], + ) -> dict[str, list]: + """Validation preprocessing""" + + args = self.args + + q_col_name = self.question_column_name + c_col_name = self.context_column_name + pad_on_right = self.pad_on_right + max_prompt_length = ( + args.max_prompt_length + if args.max_prompt_length is not None + else 384 + ) + + # Some of the questions have lots of whitespace on the left, which is not useful + # and will make the truncation of the context fail (the tokenized question will + # take a lots of space). So we remove that left whitespace + examples[q_col_name] = [ + q.lstrip() for q in examples[q_col_name] + ] + + # Tokenize our examples with truncation and maybe padding, but keep the overflows + # using a stride. This results in one example possible giving several features + # when a context is long, each of those features having a context that overlaps + # a bit the context of the previous feature. + tokenized_examples = self.tokenizer( + examples[q_col_name if pad_on_right else c_col_name], + examples[c_col_name if pad_on_right else q_col_name], + truncation="only_second" if pad_on_right else "only_first", + max_length=max_prompt_length, + stride=min(args.doc_stride, max_prompt_length // 2), + return_overflowing_tokens=True, + return_offsets_mapping=True, + padding="max_length" if args.pad_to_max_length else False, + ) + + # Since one example might give us several features if it has a long context, we + # need a map from a feature to its corresponding example. + sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping") + + # For evaluation, we will need to convert our predictions to substrings of the + # context, so we keep the corresponding example_id and we will store the offset + # mappings. + tokenized_examples["example_id"] = [] + + for i in range(len(tokenized_examples["input_ids"])): + # Grab the sequence corresponding to that example (to know what is the + # context and what is the question). + sequence_ids = tokenized_examples.sequence_ids(i) + context_index = 1 if pad_on_right else 0 + + # One example can give several spans, this is the index of the example + # containing this span of text. + sample_index = sample_mapping[i] + tokenized_examples["example_id"].append(examples["id"][sample_index]) + + # Set to None the offset_mapping that are not part of the context so + # it's easy to determine if a token position is part of the context or not. + tokenized_examples["offset_mapping"][i] = [ + (o if sequence_ids[k] == context_index else None) + for k, o in enumerate(tokenized_examples["offset_mapping"][i]) + ] + + return tokenized_examples + + def convert_batch_to_fms_style( + self, + batch: dict[str, torch.Tensor], + ) -> dict[str, torch.Tensor]: + """FMS uses a different standard than HF for encoder inputs.""" + + return {'x': batch['input_ids'], 'mask': batch['attention_mask']} + + def process_eval_set(self) -> None: + """Pre-process evaluation dataset for QuestionAnswering task.""" + + if not has_hf: + raise ImportError( + "QuestionAnswering Encoder requires transformer package but import " + "was unsuccessful." 
+ ) + + args = self.args + if args.dataset_name is not None: + # Downloading and loading a dataset from the hub + raw_datasets = load_dataset( + args.dataset_name, + args.dataset_config_name, + trust_remote_code=False, + ) + else: + data_files = {} + if args.validation_file is not None: + data_files["validation"] = args.validation_file + extension = args.validation_file.split(".")[-1] + else: + raise ValueError( + "Could not determine evaluation dataset to load. Pass `dataset_name` " + "or `validation_file` argument." + ) + raw_datasets = load_dataset(extension, data_files=data_files, field="data") + + column_names = raw_datasets["train"].column_names + + self.question_column_name = "question" if "question" in column_names else column_names[0] + self.context_column_name = "context" if "context" in column_names else column_names[1] + self.answer_column_name = "answers" if "answers" in column_names else column_names[2] + + # Padding side determines if we do (question|context) or (context|question) + self.pad_on_right = self.tokenizer.padding_side == "right" + + model_max_length = self.tokenizer.model_max_length + if args.max_prompt_length > model_max_length: + dprint( + f"max_prompt_length ({args.max_prompt_length}) is larger than the " + f"maximum length supported ({model_max_length}). " + f"Using max_prompt_length={model_max_length} instead." + ) + self.max_prompt_length = min( + args.max_seq_length, + model_max_length, + ) + + eval_examples = raw_datasets["validation"] + if args.max_eval_samples is not None: + # We will select sample from whole data + eval_examples = eval_examples.select(range(args.max_eval_samples)) + self.eval_examples = eval_examples + + eval_dataset = eval_examples.map( + self.prepare_validation_features, + batched=True, + num_proc=args.preprocessing_num_workers, + remove_columns=column_names, + load_from_cache_file=not args.overwrite_cache, + desc="Running tokenizer on validation dataset", + ) + + if args.max_eval_samples is not None: + # During Feature creation dataset samples might increase, we will select + # required samples again + eval_dataset = eval_dataset.select(range(args.max_eval_samples)) + + # store evaluation dataset prior dropping + self.eval_dataset = eval_dataset + + # DataLoaders creation: + if args.pad_to_max_length: + # If padding was already done ot max length, we use the default data collator + # that will just convert everything to tensors. + self.data_collator = default_data_collator + else: + # Otherwise, `DataCollatorWithPadding` will apply dynamic padding for us + # (by padding to the maximum length of the samples passed). 
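+            # pad_to_multiple_of is left as None here; backends that prefer
+            # aligned shapes could set it (e.g. to 8 or 64) so that dynamically
+            # padded batches fall on a small set of sequence lengths.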
+ pad_to_multiple_of = None + self.data_collator = DataCollatorWithPadding( + self.tokenizer.tokenizer, + pad_to_multiple_of=pad_to_multiple_of, + ) + + self.eval_dataset_for_model = eval_dataset.remove_columns( + ["example_id", "offset_mapping"] + ) + self.eval_dataloader = DataLoader( + self.eval_dataset_for_model, + shuffle=False, + collate_fn=self.data_collator, + batch_size=args.batch_size, + ) + dprint("Dataloader initialized.") + + self.metric = evaluate.load( + "squad_v2" if args.version_2_with_negative else "squad" + ) + dprint("Evaluation metric initialized.") + + def postprocess_qa_predictions( + self, + examples: Dataset, + features: Dataset, + predictions: tuple[np.ndarray, np.ndarray], + version_2_with_negative: bool = False, + n_best_size: int = 20, + max_answer_length: int = 30, + null_score_diff_threshold: float = 0.0, + output_dir: str | None = None, + prefix: str | None = None, + ) -> None: + """ + Post-processes the predictions of a question-answering model to convert them to answers that are substrings of the + original contexts. This is the base postprocessing functions for models that only return start and end logits. + + Args: + examples: The non-preprocessed dataset (see the main script for more information). + features: The processed dataset (see the main script for more information). + predictions (:obj:`Tuple[np.ndarray, np.ndarray]`): + The predictions of the model: two arrays containing the start logits and the end logits respectively. Its + first dimension must match the number of elements of :obj:`features`. + version_2_with_negative (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not the underlying dataset contains examples with no answers. + n_best_size (:obj:`int`, `optional`, defaults to 20): + The total number of n-best predictions to generate when looking for an answer. + max_answer_length (:obj:`int`, `optional`, defaults to 30): + The maximum length of an answer that can be generated. This is needed because the start and end predictions + are not conditioned on one another. + null_score_diff_threshold (:obj:`float`, `optional`, defaults to 0): + The threshold used to select the null answer: if the best answer has a score that is less than the score of + the null answer minus this threshold, the null answer is selected for this example (note that the score of + the null answer for an example giving several features is the minimum of the scores for the null answer on + each feature: all features must be aligned on the fact they `want` to predict a null answer). + + Only useful when :obj:`version_2_with_negative` is :obj:`True`. + output_dir (:obj:`str`, `optional`): + If provided, the dictionaries of predictions, n_best predictions (with their scores and logits) and, if + :obj:`version_2_with_negative=True`, the dictionary of the scores differences between best and null + answers, are saved in `output_dir`. + prefix (:obj:`str`, `optional`): + If provided, the dictionaries mentioned above are saved with `prefix` added to their names. + log_level (:obj:`int`, `optional`, defaults to ``logging.WARNING``): + ``logging`` log level (e.g., ``logging.WARNING``) + """ + + if len(predictions) != 2: + raise ValueError("`predictions` should be a tuple with two elements (start_logits, end_logits).") + all_start_logits, all_end_logits = predictions + + if len(predictions[0]) != len(features): + raise ValueError(f"Got {len(predictions[0])} predictions and {len(features)} features.") + + # Build a map example to its corresponding features. 
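+        # Each example id maps to the indices of every feature generated from it
+        # (a long context split with a stride yields several features per example).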
+ example_id_to_index = {k: i for i, k in enumerate(examples["id"])} + features_per_example = collections.defaultdict(list) + for i, feature in enumerate(features): + features_per_example[example_id_to_index[feature["example_id"]]].append(i) + + # The dictionaries we have to fill. + all_predictions = collections.OrderedDict() + all_nbest_json = collections.OrderedDict() + if version_2_with_negative: + scores_diff_json = collections.OrderedDict() + + # Logging. + dprint(f"Post-processing {len(examples)} example predictions split into {len(features)} features.") + + # Let's loop over all the examples! + for example_index, example in enumerate(tqdm(examples)): + # Those are the indices of the features associated to the current example. + feature_indices = features_per_example[example_index] + + min_null_prediction = None + prelim_predictions = [] + + # Looping through all the features associated to the current example. + for feature_index in feature_indices: + # We grab the predictions of the model for this feature. + start_logits = all_start_logits[feature_index] + end_logits = all_end_logits[feature_index] + # This is what will allow us to map some the positions in our logits to span of texts in the original + # context. + offset_mapping = features[feature_index]["offset_mapping"] + # Optional `token_is_max_context`, if provided we will remove answers that do not have the maximum context + # available in the current feature. + token_is_max_context = features[feature_index].get("token_is_max_context", None) + + # Update minimum null prediction. + feature_null_score = start_logits[0] + end_logits[0] + if min_null_prediction is None or min_null_prediction["score"] > feature_null_score: + min_null_prediction = { + "offsets": (0, 0), + "score": feature_null_score, + "start_logit": start_logits[0], + "end_logit": end_logits[0], + } + + # Go through all possibilities for the `n_best_size` greater start and end logits. + start_indexes = np.argsort(start_logits)[-1 : -n_best_size - 1 : -1].tolist() + end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist() + for start_index in start_indexes: + for end_index in end_indexes: + # Don't consider out-of-scope answers, either because the indices are out of bounds or correspond + # to part of the input_ids that are not in the context. + if ( + start_index >= len(offset_mapping) + or end_index >= len(offset_mapping) + or offset_mapping[start_index] is None + or len(offset_mapping[start_index]) < 2 + or offset_mapping[end_index] is None + or len(offset_mapping[end_index]) < 2 + ): + continue + # Don't consider answers with a length that is either < 0 or > max_answer_length. + if end_index < start_index or end_index - start_index + 1 > max_answer_length: + continue + # Don't consider answer that don't have the maximum context available (if such information is + # provided). + if token_is_max_context is not None and not token_is_max_context.get(str(start_index), False): + continue + + prelim_predictions.append( + { + "offsets": (offset_mapping[start_index][0], offset_mapping[end_index][1]), + "score": start_logits[start_index] + end_logits[end_index], + "start_logit": start_logits[start_index], + "end_logit": end_logits[end_index], + } + ) + if version_2_with_negative and min_null_prediction is not None: + # Add the minimum null prediction + prelim_predictions.append(min_null_prediction) + null_score = min_null_prediction["score"] + + # Only keep the best `n_best_size` predictions. 
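+            # Candidates are ranked by their score, i.e. start_logit + end_logit.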
+ predictions = sorted(prelim_predictions, key=lambda x: x["score"], reverse=True)[:n_best_size] + + # Add back the minimum null prediction if it was removed because of its low score. + if ( + version_2_with_negative + and min_null_prediction is not None + and not any(p["offsets"] == (0, 0) for p in predictions) + ): + predictions.append(min_null_prediction) + + # Use the offsets to gather the answer text in the original context. + context = example["context"] + for pred in predictions: + offsets = pred.pop("offsets") + pred["text"] = context[offsets[0] : offsets[1]] + + # In the very rare edge case we have not a single non-null prediction, we create a fake prediction to avoid + # failure. + if len(predictions) == 0 or (len(predictions) == 1 and predictions[0]["text"] == ""): + predictions.insert(0, {"text": "empty", "start_logit": 0.0, "end_logit": 0.0, "score": 0.0}) + + # Compute the softmax of all scores + scores = np.array([pred.pop("score") for pred in predictions]) + exp_scores = np.exp(scores - np.max(scores)) + probs = exp_scores / exp_scores.sum() + + # Include the probabilities in our predictions. + for prob, pred in zip(probs, predictions): + pred["probability"] = prob + + # Pick the best prediction. If the null answer is not possible, this is easy. + if not version_2_with_negative: + all_predictions[example["id"]] = predictions[0]["text"] + else: + # Otherwise we first need to find the best non-empty prediction. + i = 0 + while predictions[i]["text"] == "": + i += 1 + best_non_null_pred = predictions[i] + + # Then we compare to the null prediction using the threshold. + score_diff = null_score - best_non_null_pred["start_logit"] - best_non_null_pred["end_logit"] + scores_diff_json[example["id"]] = float(score_diff) # To be JSON-serializable. + if score_diff > null_score_diff_threshold: + all_predictions[example["id"]] = "" + else: + all_predictions[example["id"]] = best_non_null_pred["text"] + + # Make `predictions` JSON-serializable by casting np.float back to float. + all_nbest_json[example["id"]] = [ + {k: (float(v) if isinstance(v, (np.float16, np.float32, np.float64)) else v) for k, v in pred.items()} + for pred in predictions + ] + + # If we have an output_dir, let's save all those dicts. 
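+        # Note: post_processing_function calls this method with output_dir=None,
+        # so these files are only written when an output directory is provided.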
+ if output_dir is not None: + if not os.path.isdir(output_dir): + raise EnvironmentError(f"{output_dir} is not a directory.") + + prediction_file = os.path.join( + output_dir, "predictions.json" if prefix is None else f"{prefix}_predictions.json" + ) + nbest_file = os.path.join( + output_dir, "nbest_predictions.json" if prefix is None else f"{prefix}_nbest_predictions.json" + ) + if version_2_with_negative: + null_odds_file = os.path.join( + output_dir, "null_odds.json" if prefix is None else f"{prefix}_null_odds.json" + ) + + dprint(f"Saving predictions to {prediction_file}.") + with open(prediction_file, "w") as writer: + writer.write(json.dumps(all_predictions, indent=4) + "\n") + dprint(f"Saving nbest_preds to {nbest_file}.") + with open(nbest_file, "w") as writer: + writer.write(json.dumps(all_nbest_json, indent=4) + "\n") + if version_2_with_negative: + dprint(f"Saving null_odds to {null_odds_file}.") + with open(null_odds_file, "w") as writer: + writer.write(json.dumps(scores_diff_json, indent=4) + "\n") + + return all_predictions + + def post_processing_function( + self, + examples: Dataset, + features: Dataset, + predictions: list[np.ndarray], + stage: str = "eval", + ) -> dict[list[str, str]]: + """Post-processing: we match the start logits and end logits to answers in + the original context.""" + + args = self.args + predictions = self.postprocess_qa_predictions( + examples=examples, + features=features, + predictions=predictions, + version_2_with_negative=args.version_2_with_negative, + n_best_size=args.n_best_size, + max_answer_length=args.max_answer_length, + null_score_diff_threshold=args.null_score_diff_threshold, + output_dir=None, + prefix=stage, + ) + + # Format the result to the format the metric expects. + if args.version_2_with_negative: + formatted_predictions = [ + {"id": k, "prediction_text": v, "no_answer_probability": 0.0} + for k, v in predictions.items() + ] + else: + formatted_predictions = [ + {"id": k, "prediction_text": v} for k, v in predictions.items() + ] + + references = [ + {"id": ex["id"], "answers": ex[self.answer_column_name]} for ex in examples + ] + return EvalPrediction(predictions=formatted_predictions, label_ids=references) + + def create_and_fill_np_array( + self, + start_or_end_logits: list[np.ndarray], + dataset: Dataset, + max_len: int, + ) -> np.ndarray: + """ + Create and fill numpy array of size + len_of_validation_data * max_length_of_output_tensor + + Args: + start_or_end_logits(:obj:`tensor`): + This is the output predictions of the model. We can only enter either + start or end logits. + eval_dataset: Evaluation dataset + max_len(:obj:`int`): + The maximum length of the output tensor. ( See the model.eval() part + for more details ) + """ + + step = 0 + # create a numpy array and fill it with -100. 
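+        # -100 acts as a filler for positions beyond each batch's logit width;
+        # real logits overwrite the [:cols] slice of each row in the loop below.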
+ logits_concat = np.full((len(dataset), max_len), -100, dtype=np.float64) + # Now since we have create an array now we will populate it with the outputs gathered using accelerator.gather_for_metrics + for i, output_logit in enumerate(start_or_end_logits): # populate columns + # We have to fill it such that we have to take the whole tensor and replace it on the newly created array + # And after every iteration we have to change the step + + batch_size = output_logit.shape[0] + cols = output_logit.shape[1] + + if step + batch_size < len(dataset): + logits_concat[step : step + batch_size, :cols] = output_logit + else: + logits_concat[step:, :cols] = output_logit[: len(dataset) - step] + + step += batch_size + + return logits_concat + + def run_warmup(self) -> None: + """Run warmup cycle of compiled encoder model set for QuestionAnswering task.""" + + dprint(f"Starting warm-up...") + warmup_start_time = time.time() + dataloader_for_compile = DataLoader( + self.eval_dataset_for_model, + shuffle=False, + collate_fn=self.data_collator, + batch_size=1, + ) + first_batch = self.convert_batch_to_fms_style(next(iter(dataloader_for_compile))) + self.model(**first_batch) + if rank == 0: + dprint(f"Warmup completed in {time.time() - warmup_start_time:.1f} s\n---") + + def run_evaluation(self) -> None: + """Run QuestionAnswering evaluation.""" + + args = self.args + eval_dataloader = self.eval_dataloader + + if rank == 0: + dprint(f"Running evaluation ({len(eval_dataloader)} samples)...") + start_time = time.time() + + all_start_logits = [] + all_end_logits = [] + for step, batch in enumerate(eval_dataloader): + with torch.no_grad(): + dprint(f"Step {step + 1} / {len(eval_dataloader)}") + batch = self.convert_batch_to_fms_style(batch) + start_logits, end_logits = self.model(**batch) + all_start_logits.append(start_logits.cpu().numpy()) + all_end_logits.append(end_logits.cpu().numpy()) + eval_duration = time.time() - start_time + dprint( + f"Runtime: {eval_duration:.0f} s | " + f"{eval_duration / len(eval_dataloader):.2f} s/batch | " + f"{eval_duration / (len(eval_dataloader) * args.batch_size):.2f}" + " s/sample " + f"(tot = {len(eval_dataloader) * args.batch_size}, " + f"bs = {args.batch_size})" + ) + + # concatenate the numpy array + max_len = max([x.shape[1] for x in all_start_logits]) + start_logits_concat = self.create_and_fill_np_array( + all_start_logits, + self.eval_dataset, + max_len, + ) + end_logits_concat = self.create_and_fill_np_array( + all_end_logits, + self.eval_dataset, + max_len, + ) + + del all_start_logits + del all_end_logits + + outputs_numpy = (start_logits_concat, end_logits_concat) + prediction = self.post_processing_function( + self.eval_examples, + self.eval_dataset, + outputs_numpy, + ) + eval_metric = self.metric.compute( + predictions=prediction.predictions, + references=prediction.label_ids, + ) + dprint(f"Evaluation metrics: {eval_metric}") + + +class EncoderMLMInfer(): + """Run MaskedLM task with encoder models.""" + + def __init__( + self, + model: HFModelArchitecture, + tokenizer: BaseTokenizer, + args: argparse.Namespace, + ) -> None: + self.model = model + self.tokenizer = tokenizer + self.args = args + + + def process_eval_set(self) -> None: + """Barebone function that sets up a single example prompt (for now).""" + + if not has_hf: + raise ImportError( + "MaskedLM Encoder requires transformer package but import " + "was unsuccessful." 
+ ) + + self.prompt = "the dog chased the cat while aggressively" + + def run_evaluation(self, warmup: bool = False) -> None: + """Run evaluation cycle of compiled encoder model set for MaskedLM task. + No output printout if warmup is True. + """ + + dprint(f"Starting evaluation ({warmup=})...") + warmup_start_time = time.time() + unmasker = pipeline( + "fill-mask", + model=self.model, + tokenizer=self.tokenizer.tokenizer, + ) + output = unmasker(self.prompt) + if rank == 0: + dprint(f"Run completed in {time.time() - warmup_start_time:.1f} s\n---") + if not warmup: + dprint(f"{self.prompt}\nAnswers:") + for ans in output: + dprint(f"{ans['token_str']:10} | {ans['score']:6.4f}") + + +def run_encoder_eval_qa( + model: nn.Module, # FMS-style model + tokenizer: BaseTokenizer, + args: argparse.Namespace, +) -> None: + """Entry point to run QuestionAnswering Evaluation of encoder model. + + Processing based on pytorch example: + https://github.com/huggingface/transformers/blob/main/examples/pytorch/... + ...question-answering/run_qa_no_trainer.py + """ + + encoder_qa_infer = EncoderQAInfer(model, tokenizer, args) + encoder_qa_infer.process_eval_set() + if args.compile: + encoder_qa_infer.run_warmup() + encoder_qa_infer.run_evaluation() + + +def run_encoder_eval_mlm( + model: HFModelArchitecture, # model wrapped by to_hf_api + tokenizer: BaseTokenizer, + args: argparse.Namespace, +) -> None: + """Entry point to run evaluation of encoder models.""" + + encoder_mlm_infer = EncoderMLMInfer(model, tokenizer, args) + encoder_mlm_infer.process_eval_set() + if args.compile: + encoder_mlm_infer.run_evaluation(warmup=True) + encoder_mlm_infer.run_evaluation() diff --git a/aiu_fms_testing_utils/utils/model_setup.py b/aiu_fms_testing_utils/utils/model_setup.py new file mode 100644 index 0000000..3816da8 --- /dev/null +++ b/aiu_fms_testing_utils/utils/model_setup.py @@ -0,0 +1,135 @@ +# Standard +import argparse +import os +import sys + +# Third party +import numpy as np +import random +import torch +from torch import nn, distributed + +# Local +from aiu_fms_testing_utils.utils.aiu_setup import dprint, rank, local_rank, world_size +from aiu_fms_testing_utils.utils import aiu_setup + + +def get_default_dtype(args: argparse.Namespace) -> torch.dtype | None: + """Return default_dtype for non-quantized models, otherwise None. + If default_dtype is provided, it is set as torch default for non-quantized models. + """ + + default_dtype = None + if not args.is_quantized: + dtypes_map = { + "fp16": torch.float16, + "bf16": torch.bfloat16, + "fp32": torch.float32, + } + if args.default_dtype is not None: + default_dtype = dtypes_map[args.default_dtype] + if default_dtype is not None: + torch.set_default_dtype(default_dtype) + return default_dtype + + +def get_device(args: argparse.Namespace) -> torch.device: + """Return torch device and, if needed, set up AIU and its env variables. + NOTE: args.device_type is str, but this function returns torch.device. 
+ """ + + if args.device_type == "cuda": + device = torch.device(args.device_type, local_rank) + torch.cuda.set_device(device) + elif args.is_aiu_backend: + from torch_sendnn import torch_sendnn + + if args.distributed: + aiu_setup.aiu_dist_setup( + distributed.get_rank(), + distributed.get_world_size(), + ) + else: + aiu_setup.aiu_setup(rank, world_size) + aiu_setup.set_aiu_env_vars(args) + device = torch.device("cpu") + else: + device = torch.device(args.device_type) + return device + + +def print_system_setup(args: argparse.Namespace) -> None: + """Display system info (rank 0 only).""" + + if rank == 0 and args.verbose: + dprint("-"*60) + dprint( + f"Python Version : {sys.version_info.major}." + f"{sys.version_info.minor}.{sys.version_info.micro}" + ) + dprint(f"PyTorch Version : {torch.__version__}") + dprint(f"Dynamo Backend : {args.device_type} -> {args.dynamo_backend}") + dprint(f"Distributed : {args.distributed}") + if args.device_type == 'aiu': + for peer_rank in range(aiu_setup.world_size): + pcie_env_str="AIU_WORLD_RANK_"+str(peer_rank) + dprint(f"PCI Addr. for Rank {peer_rank} : {os.environ[pcie_env_str]}") + print("-"*60) + + +def set_determinism(args: argparse.Namespace) -> None: + """Set determinism. + NOTE: torch determinism requires env variable: `CUBLAS_WORKSPACE_CONFIG=:4096:8` + """ + + if args.deterministic: + random.seed(args.seed) + torch.manual_seed(args.seed) + np.random.seed(args.seed) + torch.use_deterministic_algorithms(True) + + +def get_distributed_strategy(args: argparse.Namespace) -> str | None: + """Return distributed strategy.""" + + if args.distributed: + dist_strat = "tp" + else: + if torch.cuda.device_count() > 1 and world_size == 1: + dist_strat = "mp" + else: + dist_strat = None + return dist_strat + + +def setup_model(args: argparse.Namespace) -> tuple[str | None, torch.device, str]: + """Entry point for model setup.""" + + default_dtype = get_default_dtype(args) + device = get_device(args) + print_system_setup(args) + set_determinism(args) + dist_strat = get_distributed_strategy(args) + + return default_dtype, device, dist_strat + + +def print_model_params(model: nn.Module, args: argparse.Namespace) -> None: + """Printout model and list of model parameters with related statistics.""" + + if rank == 0 and args.verbose: + dprint("="*60 + "\n") + dprint("\n".join( + f"{k:80} {str(list(v.size())):15} {str(v.dtype):18} {str(v.device):10} " + f"{v.min().item():12.4f} {v.max().item():12.4f}" + for k,v in model.state_dict().items() + )) + dprint("="*60 + "\n") + if args.architecture == "llama": + dprint( + "[NOTE] In Llama models, it's OK for bias and rotary embeddings to be " + "marked as unused keys because of different architectural choices between " + "FMS and HF models (but model output is preserved)." + ) + dprint(model) + dprint("="*60 + "\n") diff --git a/aiu_fms_testing_utils/utils/quantization_setup.py b/aiu_fms_testing_utils/utils/quantization_setup.py new file mode 100644 index 0000000..264d5f2 --- /dev/null +++ b/aiu_fms_testing_utils/utils/quantization_setup.py @@ -0,0 +1,126 @@ +# Standard +from functools import partial +from typing import Any +import argparse +import json +import os + +# Third Party +from torch import nn + +# Local Packages +from aiu_fms_testing_utils.utils.aiu_setup import dprint, rank + + +def import_addons(args: argparse.Namespace) -> None: + """Import addons from FMS-MO. The import operation will register the selected + quantization addon (comprising adapter, linear module, and custom op) with FMS. 
+ """ + + try: + if args.quantization == "gptq" and "aiu" in args.device_type: + from fms_mo.aiu_addons.gptq import gptq_aiu_adapter, gptq_aiu_linear + elif args.quantization == "int8": + from fms_mo.aiu_addons.i8i8 import i8i8_aiu_adapter, i8i8_aiu_linear + dprint("Loaded `aiu_addons` functionalities") + except: + raise ImportError(f"Failed to import {args.quantization} addons from FMS-MO.") + + +def get_linear_config(args: argparse.Namespace) -> dict[str, Any]: + """Return a linear_config dictionary to be used to instantiate quantized modules + by FMS get_model + """ + + fused_weights = not args.unfuse_weights + if args.quantization == "gptq": + if fused_weights and args.is_aiu_backend: + raise ValueError( + "GPTQ checkpoints on AIU must always run with --unfuse_weights" + ) + if args.default_dtype is not None: + raise ValueError( + "GPTQ default_dtype must be None to preserve the checkpoint data types." + ) + + if "aiu" in args.device_type: + linear_type = "gptq_aiu" + elif args.device_type == "cpu": + linear_type = "gptq_cpu" + elif args.device_type == "cuda": + linear_type = "gptq" # GPTQ support on GPU is FMS-native + else: + raise ValueError(f"Unsupported device {args.device} for GPTQ") + + qconfig_path = args.model_path + "/quantize_config.json" + if os.path.exists(qconfig_path): + with open(qconfig_path, 'r') as f: + dprint(f"loading quantization config from {qconfig_path}") + qconfig = json.load(f) + group_size = qconfig["group_size"] + desc_act = qconfig["desc_act"] + if desc_act: + raise NotImplementedError( + "Activation reordering not supported at this time." + ) + else: + dprint( + "[WARNING] Could not locate quantization config file. " + "Default configuration will be used." + ) + group_size = 128 + desc_act = False + + linear_config = { + "linear_type": linear_type, + "group_size": group_size, + "desc_act": desc_act, + } + elif args.quantization == "int8": + if fused_weights and args.is_aiu_backend: + raise ValueError("INT8 checkpoints on AIU must always run with --unfuse_weights") + if args.default_dtype is not None: + raise ValueError( + "INT8 default_dtype must be None to preserve the checkpoint data types." + ) + + def select_int8_module( + module_name: str | None = None, + smoothquant: bool = True, + smoothquant_layers: list[str] | None = None, + ): + if module_name is None: + return "int8_aiu" + smoothquant_on_module = ( + any([m in module_name for m in smoothquant_layers]) + if smoothquant_layers is not None + else True + ) + use_smoothquant = smoothquant and smoothquant_on_module + return "int8_smoothquant_aiu" if use_smoothquant else "int8_aiu" + + if args.int8_smoothquant: + # TODO: load info from config saved during quantization + if any("granite" in p.lower() for p in [args.model_path, args.architecture]): + smoothquant_layers = ["key", "value", "w1", "wg"] + elif any("roberta" in p.lower() for p in [args.model_path, args.architecture]): + smoothquant_layers = ["query", "key", "value", "w1"] + else: + raise NotImplementedError( + "INT8 architecture does not support smoothquant." 
+ ) + else: + smoothquant_layers = [] + + linear_config = { + "linear_type": partial( + select_int8_module, + smoothquant = args.int8_smoothquant, + smoothquant_layers = smoothquant_layers, + ), + "weight_per_channel": args.int8_weight_per_channel, + "activ_quant_type": args.int8_activ_quant_type, + } + else: + linear_config = {"linear_type": "torch_linear"} + return linear_config diff --git a/scripts/encoders_inference.py b/scripts/encoders_inference.py new file mode 100644 index 0000000..8001468 --- /dev/null +++ b/scripts/encoders_inference.py @@ -0,0 +1,98 @@ +# Standard +import argparse +import time + +# Third Party +from fms.models import get_model +from fms.utils import tokenizers +from torch import distributed, set_grad_enabled + +# Local Packages +from aiu_fms_testing_utils.utils.aiu_setup import dprint, rank, world_size +from aiu_fms_testing_utils.utils.args_parsing import get_args +from aiu_fms_testing_utils.utils.encoders_utils import ( + wrap_encoder, + run_encoder_eval_qa, + run_encoder_eval_mlm, +) +from aiu_fms_testing_utils.utils.model_setup import setup_model, print_model_params +from aiu_fms_testing_utils.utils.quantization_setup import ( + import_addons, + get_linear_config, +) + + +parser = argparse.ArgumentParser( + description="Entry point for AIU inference of encoder models." +) +args = get_args(parser) +args.is_encoder = True # add argument directly into Namespace + +if args.is_quantized: + import_addons(args) + +if args.distributed: + distributed.init_process_group(backend="gloo", rank=rank, world_size=world_size) + +# Main model setup +default_dtype, device, dist_strat = setup_model(args) + +# Retrieve linear configuration (quantized or not) to instantiate FMS model +linear_config = get_linear_config(args) + +if rank == 0: + dprint("="*60) + dprint(f"model_path={args.model_path}") + dprint(f"{linear_config=}") + dprint(f"fused_weights={args.fused_weights}") + dprint(f"data_type={default_dtype}") + dprint("="*60 + "\n") + +dprint("Loading model...") +loading_model_start = time.time() +model = get_model( + args.architecture, + args.variant, + model_path=args.model_path, + device_type="cpu" if args.is_aiu_backend else args.device_type, + data_type=default_dtype, + source=args.model_source, + distributed_strategy=dist_strat, + group=distributed.group.WORLD, + linear_config=linear_config, + fused_weights=args.fused_weights, +) + +if args.is_quantized: + print_model_params(model, args) + +tokenizer = tokenizers.get_tokenizer(args.tokenizer) + +model.eval() +set_grad_enabled(False) +if args.distributed: + distributed.barrier() +dprint(f"Loading model completed in {time.time() - loading_model_start:.2f} s.") + +if args.architecture == "roberta": + model = wrap_encoder(model) + +if args.compile: + dprint("Compiling model...") + if args.is_aiu_backend: + model.compile(backend="sendnn_decoder") + else: + # compiling can make first inference pass slow + model.compile(mode=args.compile_mode, backend=args.compile_backend) + dprint("Model compiled.") +else: + dprint("Skip model compiling. 
Only for debug purpose.") + +if args.architecture == "roberta_question_answering": + run_encoder_eval_qa(model, tokenizer, args) +elif args.architecture == "roberta": # basic MaskedLM downstream task + run_encoder_eval_mlm(model, tokenizer, args) + +if args.distributed: + distributed.barrier() + distributed.destroy_process_group() From 6de2a311d7edc9c89a258aec659dee11a98c7ec7 Mon Sep 17 00:00:00 2001 From: Andrea Fasoli Date: Mon, 30 Jun 2025 18:39:07 -0400 Subject: [PATCH 02/22] fp8 encoder support Signed-off-by: Andrea Fasoli --- aiu_fms_testing_utils/utils/__init__.py | 90 +++++++++++++------ aiu_fms_testing_utils/utils/args_parsing.py | 2 +- aiu_fms_testing_utils/utils/encoders_utils.py | 32 +++---- aiu_fms_testing_utils/utils/model_setup.py | 2 +- .../utils/quantization_setup.py | 5 ++ scripts/encoders_inference.py | 14 ++- 6 files changed, 94 insertions(+), 51 deletions(-) diff --git a/aiu_fms_testing_utils/utils/__init__.py b/aiu_fms_testing_utils/utils/__init__.py index 99cac86..eed35d8 100644 --- a/aiu_fms_testing_utils/utils/__init__.py +++ b/aiu_fms_testing_utils/utils/__init__.py @@ -1,15 +1,27 @@ -import torch -import torch.nn as nn -import time -from fms.utils.tokenizers import BaseTokenizer -from aiu_fms_testing_utils.utils.aiu_setup import dprint +# STandard from typing import Optional, List, Tuple -import os -import requests import json +import os import random +import requests +import time + +# Third Party +from aiu_fms_testing_utils.utils.aiu_setup import dprint +from fms.utils.tokenizers import BaseTokenizer +import torch +import torch.nn as nn -def warmup_model(model: nn.Module, input_ids: torch.Tensor, max_new_tokens: int, compile_dynamic_sendnn = False, attn_type="sdpa", **padding_kwargs): + +def warmup_model( + model: nn.Module, + input_ids: torch.Tensor, + max_new_tokens: int, + compile_dynamic_sendnn: bool = False, + attn_type="sdpa", + use_cache: bool = True, + **padding_kwargs +): import torch_sendnn attention_specific_kwargs = {} if attn_type == "paged": @@ -18,7 +30,7 @@ def warmup_model(model: nn.Module, input_ids: torch.Tensor, max_new_tokens: int, # TODO: Add a unified generation dependent on attn_type from fms.utils.generation import generate attention_specific_kwargs["contiguous_cache"] = True - + dprint("AIU warmup") pt_compile_model_time = time.time() @@ -30,12 +42,23 @@ def warmup_model(model: nn.Module, input_ids: torch.Tensor, max_new_tokens: int, _max_new_tokens = 2 # always warmup with batch size 2 when using attn_type=paged if attn_type == "paged": - _warmup_input_ids, _padding_kwargs = adjust_inputs_to_batch(input_ids, **padding_kwargs) + _warmup_input_ids, _padding_kwargs = adjust_inputs_to_batch( + input_ids, + **padding_kwargs, + ) extra_kwargs = {**_padding_kwargs, "only_last_token": attn_type != "paged"} with torch_sendnn.warmup_mode(): - generate(model, _warmup_input_ids, max_new_tokens=_max_new_tokens, use_cache=True, do_sample=False, extra_kwargs=extra_kwargs, **attention_specific_kwargs) + generate( + model, + _warmup_input_ids, + max_new_tokens=_max_new_tokens, + use_cache=use_cache, + do_sample=False, + extra_kwargs=extra_kwargs, + **attention_specific_kwargs, + ) pt_compile_model_time = time.time() - pt_compile_model_time dprint(f"PT compile complete, took {pt_compile_model_time:.3f}s") @@ -51,17 +74,17 @@ def __download_file(url, filename): try: response = requests.get(url, stream=True) response.raise_for_status() - + with open(filename, 'wb') as file: for chunk in response.iter_content(chunk_size=8192): file.write(chunk) 
print(f"Successfully downloaded {filename}") - + except requests.exceptions.RequestException as e: print(f"An error occurred: {e}") def __sample_requests( - prompt_list: List[str], + prompt_list: List[str], num_requests: int, tokenizer: BaseTokenizer, prompt_length_min: int = 32, @@ -81,16 +104,14 @@ def __sample_requests( # Tokenize the prompts and completions. prompt = prompt_list[i] prompt_token_ids = ids_for_prompt(prompt, tokenizer) - + prompt_len = len(prompt_token_ids) if prompt_len < prompt_length_min or prompt_len > prompt_length_max: # Prune too short or too long sequences. continue filtered_dataset.append((prompt, prompt_len)) - - return filtered_dataset - + return filtered_dataset def sample_sharegpt_requests( dataset_path: str, @@ -110,15 +131,22 @@ def sample_sharegpt_requests( # Filter out the conversations with less than 2 turns. dataset = [data for data in dataset if len(data["conversations"]) >= 2] dataset = [data["conversations"][0]["value"] for data in dataset] - - return __sample_requests(dataset, num_requests, tokenizer, prompt_length_min, prompt_length_max, seed) + + return __sample_requests( + dataset, + num_requests, + tokenizer, + prompt_length_min, + prompt_length_max, + seed, + ) def sample_squad_v2_qa_requests( dataset_path: str, - num_requests: int, - tokenizer: BaseTokenizer, - prompt_length_min: int = 32, - prompt_length_max: int = 64, + num_requests: int, + tokenizer: BaseTokenizer, + prompt_length_min: int = 32, + prompt_length_max: int = 64, seed: Optional[int] = None ) -> List[Tuple[str, int]]: from datasets import load_dataset @@ -127,10 +155,14 @@ def sample_squad_v2_qa_requests( ds = load_dataset(dataset_path)['train'] else: ds = load_dataset("rajpurkar/squad_v2", cache_dir=dataset_path)['train'] - - - ds = [f"{data['context']}\n{data['question']}" for data in ds] - return __sample_requests(ds, num_requests, tokenizer, prompt_length_min, prompt_length_max, seed) - + ds = [f"{data['context']}\n{data['question']}" for data in ds] + return __sample_requests( + ds, + num_requests, + tokenizer, + prompt_length_min, + prompt_length_max, + seed, + ) diff --git a/aiu_fms_testing_utils/utils/args_parsing.py b/aiu_fms_testing_utils/utils/args_parsing.py index 59278a5..87de21b 100644 --- a/aiu_fms_testing_utils/utils/args_parsing.py +++ b/aiu_fms_testing_utils/utils/args_parsing.py @@ -145,7 +145,7 @@ def get_args(parser: argparse.ArgumentParser) -> argparse.Namespace: help="Use dynamic shapes with aiu compile", ) - # Arguments shared between Decoder and Encoder models + # Arguments shared between Decoder (future support) and Encoder models args_dec_enc = parser.add_argument_group("Decoders or Encoders (shared args)") args_dec_enc.add_argument( "--batch_size", diff --git a/aiu_fms_testing_utils/utils/encoders_utils.py b/aiu_fms_testing_utils/utils/encoders_utils.py index 52c700d..1ff7d41 100644 --- a/aiu_fms_testing_utils/utils/encoders_utils.py +++ b/aiu_fms_testing_utils/utils/encoders_utils.py @@ -75,22 +75,22 @@ def validate_encoder_arguments(self) -> None: "Running encoder model but is_encoder argument is not set to True. " "Verify your launch script." ) - if args.min_pad_length != 0: - raise ValueError( - "Argument min_pad_length should not be provided to encoders. " - "To pad the input sequence, use --pad_to_max_length flag instead." - ) - if args.fixed_prompt_length != 0: - raise ValueError( - "Argument fixed_prompt_length should not be provided to encoders. " - "To pad the input sequence, use --pad_to_max_length flag instead." 
- ) - if args.max_new_tokens != 100: # default value for decoder models - raise ValueError( - "Argument max_new_token should not be provided to encoders. " - "To define the max length of a generated answer in QuestionAnswering " - "use --max_answer_length instead." - ) + # if args.min_pad_length != 0: + # raise ValueError( + # "Argument min_pad_length should not be provided to encoders. " + # "To pad the input sequence, use --pad_to_max_length flag instead." + # ) + # if args.fixed_prompt_length != 0: + # raise ValueError( + # "Argument fixed_prompt_length should not be provided to encoders. " + # "To pad the input sequence, use --pad_to_max_length flag instead." + # ) + # if args.max_new_tokens != 100: # default value for decoder models + # raise ValueError( + # "Argument max_new_token should not be provided to encoders. " + # "To define the max length of a generated answer in QuestionAnswering " + # "use --max_answer_length instead." + # ) def prepare_validation_features( diff --git a/aiu_fms_testing_utils/utils/model_setup.py b/aiu_fms_testing_utils/utils/model_setup.py index 3816da8..8659f53 100644 --- a/aiu_fms_testing_utils/utils/model_setup.py +++ b/aiu_fms_testing_utils/utils/model_setup.py @@ -121,7 +121,7 @@ def print_model_params(model: nn.Module, args: argparse.Namespace) -> None: dprint("="*60 + "\n") dprint("\n".join( f"{k:80} {str(list(v.size())):15} {str(v.dtype):18} {str(v.device):10} " - f"{v.min().item():12.4f} {v.max().item():12.4f}" + f"{v.float().min().item():12.4f} {v.float().max().item():12.4f}" for k,v in model.state_dict().items() )) dprint("="*60 + "\n") diff --git a/aiu_fms_testing_utils/utils/quantization_setup.py b/aiu_fms_testing_utils/utils/quantization_setup.py index 264d5f2..8fb70af 100644 --- a/aiu_fms_testing_utils/utils/quantization_setup.py +++ b/aiu_fms_testing_utils/utils/quantization_setup.py @@ -20,6 +20,8 @@ def import_addons(args: argparse.Namespace) -> None: try: if args.quantization == "gptq" and "aiu" in args.device_type: from fms_mo.aiu_addons.gptq import gptq_aiu_adapter, gptq_aiu_linear + elif args.quantization == "fp8": + from fms_mo.aiu_addons.fp8 import fp8_adapter, fp8_attn, fp8_linear elif args.quantization == "int8": from fms_mo.aiu_addons.i8i8 import i8i8_aiu_adapter, i8i8_aiu_linear dprint("Loaded `aiu_addons` functionalities") @@ -76,6 +78,9 @@ def get_linear_config(args: argparse.Namespace) -> dict[str, Any]: "group_size": group_size, "desc_act": desc_act, } + elif args.quantization == "fp8": + dprint("fp8 config is inferred from HF checkpoint via FMS / FMS-MO functions") + return None elif args.quantization == "int8": if fused_weights and args.is_aiu_backend: raise ValueError("INT8 checkpoints on AIU must always run with --unfuse_weights") diff --git a/scripts/encoders_inference.py b/scripts/encoders_inference.py index 8001468..4906c23 100644 --- a/scripts/encoders_inference.py +++ b/scripts/encoders_inference.py @@ -4,6 +4,7 @@ # Third Party from fms.models import get_model +from fms.models.roberta import RoBERTaForQuestionAnswering, RoBERTa from fms.utils import tokenizers from torch import distributed, set_grad_enabled @@ -21,7 +22,6 @@ get_linear_config, ) - parser = argparse.ArgumentParser( description="Entry point for AIU inference of encoder models." 
) @@ -75,7 +75,7 @@ dprint(f"Loading model completed in {time.time() - loading_model_start:.2f} s.") if args.architecture == "roberta": - model = wrap_encoder(model) + model = wrap_encoder(model) # enable using pipeline to eval RoBERTa MaskedLM if args.compile: dprint("Compiling model...") @@ -88,9 +88,15 @@ else: dprint("Skip model compiling. Only for debug purpose.") -if args.architecture == "roberta_question_answering": +if ( + args.architecture == "roberta_question_answering" + or isinstance(model, RoBERTaForQuestionAnswering) +): run_encoder_eval_qa(model, tokenizer, args) -elif args.architecture == "roberta": # basic MaskedLM downstream task +elif ( + args.architecture == "roberta" + or isinstance(model, RoBERTa) +): # basic MaskedLM downstream task run_encoder_eval_mlm(model, tokenizer, args) if args.distributed: From 0c6a36b96c3670fb6a5d8ec9286aeaa1e2a1d620 Mon Sep 17 00:00:00 2001 From: Andrea Fasoli Date: Mon, 30 Jun 2025 18:41:08 -0400 Subject: [PATCH 03/22] Update detection of RoBERTa architecture Signed-off-by: Andrea Fasoli --- aiu_fms_testing_utils/utils/__init__.py | 2 +- scripts/encoders_inference.py | 12 +++--------- 2 files changed, 4 insertions(+), 10 deletions(-) diff --git a/aiu_fms_testing_utils/utils/__init__.py b/aiu_fms_testing_utils/utils/__init__.py index eed35d8..2199641 100644 --- a/aiu_fms_testing_utils/utils/__init__.py +++ b/aiu_fms_testing_utils/utils/__init__.py @@ -1,4 +1,4 @@ -# STandard +# Standard from typing import Optional, List, Tuple import json import os diff --git a/scripts/encoders_inference.py b/scripts/encoders_inference.py index 4906c23..6f3c423 100644 --- a/scripts/encoders_inference.py +++ b/scripts/encoders_inference.py @@ -74,7 +74,7 @@ distributed.barrier() dprint(f"Loading model completed in {time.time() - loading_model_start:.2f} s.") -if args.architecture == "roberta": +if isinstance(model, RoBERTa): model = wrap_encoder(model) # enable using pipeline to eval RoBERTa MaskedLM if args.compile: @@ -88,15 +88,9 @@ else: dprint("Skip model compiling. Only for debug purpose.") -if ( - args.architecture == "roberta_question_answering" - or isinstance(model, RoBERTaForQuestionAnswering) -): +if isinstance(model, RoBERTaForQuestionAnswering): run_encoder_eval_qa(model, tokenizer, args) -elif ( - args.architecture == "roberta" - or isinstance(model, RoBERTa) -): # basic MaskedLM downstream task +elif isinstance(model, RoBERTa): # basic MaskedLM downstream task run_encoder_eval_mlm(model, tokenizer, args) if args.distributed: From abc01e5aaac74d3b0d7c127c3d7a604090ce4fae Mon Sep 17 00:00:00 2001 From: Andrea Fasoli Date: Tue, 1 Jul 2025 17:11:11 -0400 Subject: [PATCH 04/22] Remove commented decoder args Signed-off-by: Andrea Fasoli --- aiu_fms_testing_utils/utils/args_parsing.py | 64 --------------------- 1 file changed, 64 deletions(-) diff --git a/aiu_fms_testing_utils/utils/args_parsing.py b/aiu_fms_testing_utils/utils/args_parsing.py index 87de21b..cb16903 100644 --- a/aiu_fms_testing_utils/utils/args_parsing.py +++ b/aiu_fms_testing_utils/utils/args_parsing.py @@ -164,70 +164,6 @@ def get_args(parser: argparse.ArgumentParser) -> argparse.Namespace: ), ) - # Decoder model arguments - # args_decoder = parser.add_argument_group("Decoders") - # args_decoder.add_argument( - # "--min_pad_length", - # type=int, - # default=0, - # help=( - # "Pad inputs to a minimum specified length. 
If any prompt is larger than " - # "the specified length, padding will be determined by the largest prompt" - # ), - # ) - # args_decoder.add_argument( - # "--fixed_prompt_length", - # type=int, - # default=0, - # help=( - # "If defined, overrides both min_pad_length and max_prompt_length. " - # "Pads input to fixed_prompt_length, fails if any input needs truncation." - # ), - # ) - # args_decoder.add_argument( - # "--max_new_tokens", - # type=int, - # help="max number of generated tokens", - # default=100, - # ) - # args_decoder.add_argument( - # "--no_early_termination", - # action="store_true", - # help="disable early termination on generation", - # ) - # args_decoder.add_argument( - # "--prompt_type", - # type=str, - # choices=["chat", "code"], - # default="chat", - # help="type of prompts to be used, either chat or code", - # ) - # args_decoder.add_argument( - # "--prompt_path", - # type=str, - # default="", - # help=( - # "If set, load the prompts from file(s) instead of the local examples. " - # "Supports glob-style patterns" - # ), - # ) - # args_decoder.add_argument( - # "--timing", - # type=str, - # choices=["e2e", "per-token"], - # default="", - # help="if set, how to time the generation of tokens, e2e or per-token", - # ) - # args_decoder.add_argument( - # "--iters", - # type=int, - # default=1, - # help=( - # "Number of iterations of inference to perform. Used for variance " - # "performance capture." - # ), - # ) - # Encoder model arguments args_encoder = parser.add_argument_group("Encoders") args_encoder.add_argument( From 0c84b8be178a215c981566ec4617dd0bbbc906f6 Mon Sep 17 00:00:00 2001 From: Andrea Fasoli Date: Tue, 1 Jul 2025 17:17:25 -0400 Subject: [PATCH 05/22] Make verbose a flag Signed-off-by: Andrea Fasoli --- aiu_fms_testing_utils/utils/args_parsing.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/aiu_fms_testing_utils/utils/args_parsing.py b/aiu_fms_testing_utils/utils/args_parsing.py index cb16903..07d5fe5 100644 --- a/aiu_fms_testing_utils/utils/args_parsing.py +++ b/aiu_fms_testing_utils/utils/args_parsing.py @@ -106,11 +106,11 @@ def get_args(parser: argparse.ArgumentParser) -> argparse.Namespace: action="store_true", help="This is a distributed job (multiple instances run with RANK+WORLD_SIZE)", ) - args_run_settings.add_argument( # could be a bool / flag + args_run_settings.add_argument( '-v', '--verbose', - action='count', + action='store_true', default=0, - help="Set verbosity level (pass flag as `-v`, `-vv`, `-vvv`)" + help="Enable verbose output" ) # Arguments for compilation From 66780e3d0847068629e5b4e7b159de2dd47e3b49 Mon Sep 17 00:00:00 2001 From: Andrea Fasoli Date: Tue, 1 Jul 2025 17:24:37 -0400 Subject: [PATCH 06/22] Remove TODO in FP8 quantization argument Signed-off-by: Andrea Fasoli --- aiu_fms_testing_utils/utils/args_parsing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aiu_fms_testing_utils/utils/args_parsing.py b/aiu_fms_testing_utils/utils/args_parsing.py index 07d5fe5..0b8d9d1 100644 --- a/aiu_fms_testing_utils/utils/args_parsing.py +++ b/aiu_fms_testing_utils/utils/args_parsing.py @@ -56,7 +56,7 @@ def get_args(parser: argparse.ArgumentParser) -> argparse.Namespace: args_quantization.add_argument( "--quantization", type=str, - choices=["gptq", "fp8"], # TODO: add "fp8" when available in FMS + choices=["gptq", "fp8"], default=None, help="Type of quantization of the model checkpoint", ) From bde2c605134f48f70883f3d3e4b94c35309eb62c Mon Sep 17 00:00:00 2001 From: Andrea Fasoli Date: Tue, 1 Jul 
2025 17:29:58 -0400 Subject: [PATCH 07/22] Remove decoder arguments from argument validation Signed-off-by: Andrea Fasoli --- aiu_fms_testing_utils/utils/encoders_utils.py | 23 ++++--------------- 1 file changed, 5 insertions(+), 18 deletions(-) diff --git a/aiu_fms_testing_utils/utils/encoders_utils.py b/aiu_fms_testing_utils/utils/encoders_utils.py index 1ff7d41..c29420e 100644 --- a/aiu_fms_testing_utils/utils/encoders_utils.py +++ b/aiu_fms_testing_utils/utils/encoders_utils.py @@ -67,7 +67,11 @@ def __init__( self.validate_encoder_arguments() def validate_encoder_arguments(self) -> None: - """Ensure arguments compatibility with Encoder models.""" + """Ensure arguments compatibility with Encoder models. + + NOTE: when Decoder models are refactored, this function will be expanded to + ensure decoder arguments are not being provided to the encoder script. + """ args = self.args if not getattr(args, "is_encoder", False): @@ -75,23 +79,6 @@ def validate_encoder_arguments(self) -> None: "Running encoder model but is_encoder argument is not set to True. " "Verify your launch script." ) - # if args.min_pad_length != 0: - # raise ValueError( - # "Argument min_pad_length should not be provided to encoders. " - # "To pad the input sequence, use --pad_to_max_length flag instead." - # ) - # if args.fixed_prompt_length != 0: - # raise ValueError( - # "Argument fixed_prompt_length should not be provided to encoders. " - # "To pad the input sequence, use --pad_to_max_length flag instead." - # ) - # if args.max_new_tokens != 100: # default value for decoder models - # raise ValueError( - # "Argument max_new_token should not be provided to encoders. " - # "To define the max length of a generated answer in QuestionAnswering " - # "use --max_answer_length instead." 
- # ) - def prepare_validation_features( self, From e8c2cd57da94ca8c5e5e8372ce780c65260ef92a Mon Sep 17 00:00:00 2001 From: Andrea Fasoli Date: Tue, 1 Jul 2025 17:32:23 -0400 Subject: [PATCH 08/22] Update argument help Signed-off-by: Andrea Fasoli --- aiu_fms_testing_utils/utils/args_parsing.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/aiu_fms_testing_utils/utils/args_parsing.py b/aiu_fms_testing_utils/utils/args_parsing.py index 0b8d9d1..1ad521f 100644 --- a/aiu_fms_testing_utils/utils/args_parsing.py +++ b/aiu_fms_testing_utils/utils/args_parsing.py @@ -234,12 +234,15 @@ def get_args(parser: argparse.ArgumentParser) -> argparse.Namespace: ), ) args_encoder.add_argument( - "--preprocessing_num_workers", type=int, default=1, help="" + "--preprocessing_num_workers", + type=int, + default=1, + help="Number of workers used during preprocessing of validation set (QA only).", ) args_encoder.add_argument( "--overwrite_cache", action="store_true", - help="Overwrite the cached training and evaluation sets", + help="Overwrite the cached training and evaluation sets.", ) args_encoder.add_argument( "--doc_stride", From e66cd1c239f658fdd728a516d3d71cbb704f2695 Mon Sep 17 00:00:00 2001 From: Andrea Fasoli Date: Tue, 1 Jul 2025 17:41:53 -0400 Subject: [PATCH 09/22] Update padding explanation Signed-off-by: Andrea Fasoli --- aiu_fms_testing_utils/utils/args_parsing.py | 2 +- aiu_fms_testing_utils/utils/encoders_utils.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/aiu_fms_testing_utils/utils/args_parsing.py b/aiu_fms_testing_utils/utils/args_parsing.py index 1ad521f..c2f3103 100644 --- a/aiu_fms_testing_utils/utils/args_parsing.py +++ b/aiu_fms_testing_utils/utils/args_parsing.py @@ -221,7 +221,7 @@ def get_args(parser: argparse.ArgumentParser) -> argparse.Namespace: action="store_true", help=( "If passed, pad all samples to `max_seq_length`. " - "Otherwise, dynamic padding is used." + "Otherwise, pad each batch individually to the longest sequence." ), ) args_encoder.add_argument( diff --git a/aiu_fms_testing_utils/utils/encoders_utils.py b/aiu_fms_testing_utils/utils/encoders_utils.py index c29420e..b016836 100644 --- a/aiu_fms_testing_utils/utils/encoders_utils.py +++ b/aiu_fms_testing_utils/utils/encoders_utils.py @@ -235,8 +235,8 @@ def process_eval_set(self) -> None: # that will just convert everything to tensors. self.data_collator = default_data_collator else: - # Otherwise, `DataCollatorWithPadding` will apply dynamic padding for us - # (by padding to the maximum length of the samples passed). + # Otherwise, `DataCollatorWithPadding` will pad to the maximum length + # of the samples passed. 
pad_to_multiple_of = None self.data_collator = DataCollatorWithPadding( self.tokenizer.tokenizer, From e5555c8e708cdc2e19a2b149e459bf38736a2dba Mon Sep 17 00:00:00 2001 From: Andrea Fasoli Date: Tue, 1 Jul 2025 17:52:27 -0400 Subject: [PATCH 10/22] Update linear config message for FP8 Signed-off-by: Andrea Fasoli --- aiu_fms_testing_utils/utils/quantization_setup.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/aiu_fms_testing_utils/utils/quantization_setup.py b/aiu_fms_testing_utils/utils/quantization_setup.py index 8fb70af..c7eac38 100644 --- a/aiu_fms_testing_utils/utils/quantization_setup.py +++ b/aiu_fms_testing_utils/utils/quantization_setup.py @@ -79,11 +79,15 @@ def get_linear_config(args: argparse.Namespace) -> dict[str, Any]: "desc_act": desc_act, } elif args.quantization == "fp8": - dprint("fp8 config is inferred from HF checkpoint via FMS / FMS-MO functions") + dprint( + "[INFO] fp8 config is inferred from HF checkpoint via FMS / FMS-MO functions" + ) return None elif args.quantization == "int8": if fused_weights and args.is_aiu_backend: - raise ValueError("INT8 checkpoints on AIU must always run with --unfuse_weights") + raise ValueError( + "INT8 checkpoints on AIU must always run with --unfuse_weights" + ) if args.default_dtype is not None: raise ValueError( "INT8 default_dtype must be None to preserve the checkpoint data types." From af3d681cca985b11804f9fa71a60ea6d93bbe27a Mon Sep 17 00:00:00 2001 From: Andrea Fasoli Date: Tue, 1 Jul 2025 18:01:48 -0400 Subject: [PATCH 11/22] raise error for default_dtype + quantization Signed-off-by: Andrea Fasoli --- aiu_fms_testing_utils/utils/model_setup.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/aiu_fms_testing_utils/utils/model_setup.py b/aiu_fms_testing_utils/utils/model_setup.py index 8659f53..a839e47 100644 --- a/aiu_fms_testing_utils/utils/model_setup.py +++ b/aiu_fms_testing_utils/utils/model_setup.py @@ -30,6 +30,11 @@ def get_default_dtype(args: argparse.Namespace) -> torch.dtype | None: default_dtype = dtypes_map[args.default_dtype] if default_dtype is not None: torch.set_default_dtype(default_dtype) + elif args.default_dtype is not None: + raise ValueError( + f"default_dtype (currently set to {args.default_dtype}) must be unset " + "when running a quantized model." + ) return default_dtype From d8b73ee56c5f3c11a326edf1a0fdd6e2f1ecbb62 Mon Sep 17 00:00:00 2001 From: Andrea Fasoli Date: Tue, 1 Jul 2025 18:03:43 -0400 Subject: [PATCH 12/22] Update printouts Signed-off-by: Andrea Fasoli --- aiu_fms_testing_utils/utils/model_setup.py | 6 +++--- scripts/encoders_inference.py | 13 ++++++------- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/aiu_fms_testing_utils/utils/model_setup.py b/aiu_fms_testing_utils/utils/model_setup.py index a839e47..705362a 100644 --- a/aiu_fms_testing_utils/utils/model_setup.py +++ b/aiu_fms_testing_utils/utils/model_setup.py @@ -66,7 +66,7 @@ def get_device(args: argparse.Namespace) -> torch.device: def print_system_setup(args: argparse.Namespace) -> None: """Display system info (rank 0 only).""" - if rank == 0 and args.verbose: + if args.verbose: dprint("-"*60) dprint( f"Python Version : {sys.version_info.major}." @@ -79,7 +79,7 @@ def print_system_setup(args: argparse.Namespace) -> None: for peer_rank in range(aiu_setup.world_size): pcie_env_str="AIU_WORLD_RANK_"+str(peer_rank) dprint(f"PCI Addr. 
for Rank {peer_rank} : {os.environ[pcie_env_str]}") - print("-"*60) + dprint("-"*60) def set_determinism(args: argparse.Namespace) -> None: @@ -122,7 +122,7 @@ def setup_model(args: argparse.Namespace) -> tuple[str | None, torch.device, str def print_model_params(model: nn.Module, args: argparse.Namespace) -> None: """Printout model and list of model parameters with related statistics.""" - if rank == 0 and args.verbose: + if args.verbose: dprint("="*60 + "\n") dprint("\n".join( f"{k:80} {str(list(v.size())):15} {str(v.dtype):18} {str(v.device):10} " diff --git a/scripts/encoders_inference.py b/scripts/encoders_inference.py index 6f3c423..b9d2eda 100644 --- a/scripts/encoders_inference.py +++ b/scripts/encoders_inference.py @@ -40,13 +40,12 @@ # Retrieve linear configuration (quantized or not) to instantiate FMS model linear_config = get_linear_config(args) -if rank == 0: - dprint("="*60) - dprint(f"model_path={args.model_path}") - dprint(f"{linear_config=}") - dprint(f"fused_weights={args.fused_weights}") - dprint(f"data_type={default_dtype}") - dprint("="*60 + "\n") +dprint("="*60) +dprint(f"model_path={args.model_path}") +dprint(f"{linear_config=}") +dprint(f"fused_weights={args.fused_weights}") +dprint(f"data_type={default_dtype}") +dprint("="*60 + "\n") dprint("Loading model...") loading_model_start = time.time() From 55f19d8e30aabf67eb260160b52623f54bed1c92 Mon Sep 17 00:00:00 2001 From: Andrea Fasoli Date: Tue, 1 Jul 2025 18:08:05 -0400 Subject: [PATCH 13/22] Update determinism docstring Signed-off-by: Andrea Fasoli --- aiu_fms_testing_utils/utils/args_parsing.py | 2 +- aiu_fms_testing_utils/utils/model_setup.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/aiu_fms_testing_utils/utils/args_parsing.py b/aiu_fms_testing_utils/utils/args_parsing.py index c2f3103..67e736d 100644 --- a/aiu_fms_testing_utils/utils/args_parsing.py +++ b/aiu_fms_testing_utils/utils/args_parsing.py @@ -98,7 +98,7 @@ def get_args(parser: argparse.ArgumentParser) -> argparse.Namespace: action="store_true", help=( "`deterministic` requires env variable `CUBLAS_WORKSPACE_CONFIG=:4096:8`" - " when running on CPU or GPU. This flag is ignored on AIU." + " when running on GPU. This flag is ignored on AIU." ), ) args_run_settings.add_argument( diff --git a/aiu_fms_testing_utils/utils/model_setup.py b/aiu_fms_testing_utils/utils/model_setup.py index 705362a..9199ec9 100644 --- a/aiu_fms_testing_utils/utils/model_setup.py +++ b/aiu_fms_testing_utils/utils/model_setup.py @@ -85,6 +85,7 @@ def print_system_setup(args: argparse.Namespace) -> None: def set_determinism(args: argparse.Namespace) -> None: """Set determinism. NOTE: torch determinism requires env variable: `CUBLAS_WORKSPACE_CONFIG=:4096:8` + when running on GPU. This env variable is ignored on AIU. 
""" if args.deterministic: From 2d8b88d117e29b4c64331d84dfc2822fddba346e Mon Sep 17 00:00:00 2001 From: Andrea Fasoli Date: Tue, 1 Jul 2025 18:13:35 -0400 Subject: [PATCH 14/22] Fix typos Signed-off-by: Andrea Fasoli --- aiu_fms_testing_utils/utils/encoders_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/aiu_fms_testing_utils/utils/encoders_utils.py b/aiu_fms_testing_utils/utils/encoders_utils.py index b016836..14147d5 100644 --- a/aiu_fms_testing_utils/utils/encoders_utils.py +++ b/aiu_fms_testing_utils/utils/encoders_utils.py @@ -38,7 +38,7 @@ def wrap_encoder(model: nn.Module) -> HFModelArchitecture: if not has_hf: raise ImportError( - "MaskedLM Encoder requires transformer package but import " + "MaskedLM Encoder requires transformers package but import " "was unsuccessful." ) @@ -161,7 +161,7 @@ def process_eval_set(self) -> None: if not has_hf: raise ImportError( - "QuestionAnswering Encoder requires transformer package but import " + "QuestionAnswering Encoder requires transformers package but import " "was unsuccessful." ) @@ -650,7 +650,7 @@ def process_eval_set(self) -> None: if not has_hf: raise ImportError( - "MaskedLM Encoder requires transformer package but import " + "MaskedLM Encoder requires transformers package but import " "was unsuccessful." ) From 429c57facd3c45d8ca5569fe3d6745a8b08c6307 Mon Sep 17 00:00:00 2001 From: Andrea Fasoli Date: Tue, 1 Jul 2025 18:27:13 -0400 Subject: [PATCH 15/22] Update rank-based printouts Signed-off-by: Andrea Fasoli --- aiu_fms_testing_utils/utils/encoders_utils.py | 23 ++++++++----------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/aiu_fms_testing_utils/utils/encoders_utils.py b/aiu_fms_testing_utils/utils/encoders_utils.py index 14147d5..dcd3500 100644 --- a/aiu_fms_testing_utils/utils/encoders_utils.py +++ b/aiu_fms_testing_utils/utils/encoders_utils.py @@ -19,7 +19,7 @@ import torch # Local Packages -from aiu_fms_testing_utils.utils.aiu_setup import dprint, rank +from aiu_fms_testing_utils.utils.aiu_setup import dprint # Optional imports (required for QA) @@ -231,7 +231,7 @@ def process_eval_set(self) -> None: # DataLoaders creation: if args.pad_to_max_length: - # If padding was already done ot max length, we use the default data collator + # If padding was already done to max length, we use the default data collator # that will just convert everything to tensors. 
self.data_collator = default_data_collator else: @@ -570,8 +570,7 @@ def run_warmup(self) -> None: ) first_batch = self.convert_batch_to_fms_style(next(iter(dataloader_for_compile))) self.model(**first_batch) - if rank == 0: - dprint(f"Warmup completed in {time.time() - warmup_start_time:.1f} s\n---") + dprint(f"Warmup completed in {time.time() - warmup_start_time:.1f} s\n---") def run_evaluation(self) -> None: """Run QuestionAnswering evaluation.""" @@ -579,9 +578,8 @@ def run_evaluation(self) -> None: args = self.args eval_dataloader = self.eval_dataloader - if rank == 0: - dprint(f"Running evaluation ({len(eval_dataloader)} samples)...") - start_time = time.time() + dprint(f"Running evaluation ({len(eval_dataloader)} samples)...") + start_time = time.time() all_start_logits = [] all_end_logits = [] @@ -669,12 +667,11 @@ def run_evaluation(self, warmup: bool = False) -> None: tokenizer=self.tokenizer.tokenizer, ) output = unmasker(self.prompt) - if rank == 0: - dprint(f"Run completed in {time.time() - warmup_start_time:.1f} s\n---") - if not warmup: - dprint(f"{self.prompt}\nAnswers:") - for ans in output: - dprint(f"{ans['token_str']:10} | {ans['score']:6.4f}") + dprint(f"Run completed in {time.time() - warmup_start_time:.1f} s\n---") + if not warmup: + dprint(f"{self.prompt}\nAnswers:") + for ans in output: + dprint(f"{ans['token_str']:10} | {ans['score']:6.4f}") def run_encoder_eval_qa( From 041f39a9815d3505eeaff493cde006db25094084 Mon Sep 17 00:00:00 2001 From: Andrea Fasoli Date: Tue, 1 Jul 2025 18:29:51 -0400 Subject: [PATCH 16/22] Remove superseeded roberta.py script Signed-off-by: Andrea Fasoli --- scripts/roberta.py | 160 --------------------------------------------- 1 file changed, 160 deletions(-) delete mode 100644 scripts/roberta.py diff --git a/scripts/roberta.py b/scripts/roberta.py deleted file mode 100644 index 124b09f..0000000 --- a/scripts/roberta.py +++ /dev/null @@ -1,160 +0,0 @@ -import os -import sys -import argparse -import tempfile - -from aiu_fms_testing_utils.utils import aiu_setup -from aiu_fms_testing_utils.utils.aiu_setup import dprint, world_rank, world_size - - -# PyTorch -import torch -import torch.distributed - -# HuggingFace Transformers -from transformers import ( - AutoModelForMaskedLM, - RobertaTokenizerFast, - pipeline, -) - -# TPEmbedding in FMS uses the torch.ops._c10d_functional.all_gather_into_tensor funciton -# which is not supported by GLOO. Eventhough we don't use GLOO in AIU execution, PyTorch -# doesn't know that and throws an error. -# This should be addressed in a future version of PyTorch, but for now disable it. 
-os.environ.setdefault("DISTRIBUTED_STRATEGY_IGNORE_MODULES", "WordEmbedding,Embedding") - -# Foundation Model Stack -from fms.models import get_model -from fms.models.hf import to_hf_api - -# Import AIU Libraries -from torch_sendnn import torch_sendnn - -# ============================================================== -# Main -# ============================================================== -if __name__ == "__main__": - # Number of batches to create - NUM_BATCHES=1 - - #------------- - # Command line argument parsing - #------------- - parser = argparse.ArgumentParser(description="PyTorch Small Toy Tensor Parallel Example") - parser.add_argument( "--backend", help="PyTorch Dynamo compiler backend", default='cpu', choices=['cpu', 'aiu']) - pargs = parser.parse_args() - - if pargs.backend == 'aiu': - dynamo_backend = 'sendnn' - else: - dynamo_backend = 'inductor' - - is_distributed = world_size > 1 - if is_distributed: - # Initialize the process group - torch.distributed.init_process_group(backend="gloo", rank=world_rank, world_size=world_size) - # Looks like a string compare, but is actually comparing the components - # https://github.com/pytorch/pytorch/blob/b5be4d8c053e22672719b9a33386b071daf9860d/torch/torch_version.py#L10-L16 - if torch.__version__ < '2.3.0': - # Fix until PyTorch 2.3 - torch._C._distributed_c10d._register_process_group("default", torch.distributed.group.WORLD) - - #------------- - # Setup AIU specific environment variables - #------------- - if "sendnn" in dynamo_backend: - aiu_setup.aiu_dist_setup(world_rank, world_size) - - #------------- - # Display some diagnostics - #------------- - if 0 == world_rank: - dprint("-"*60) - dprint(f"Python Version : {sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}") - dprint(f"PyTorch Version : {torch.__version__}") - dprint(f"Dynamo Backend : {pargs.backend} -> {dynamo_backend}") - if pargs.backend == 'aiu': - for peer_rank in range(world_size): - pcie_env_str="AIU_WORLD_RANK_"+str(peer_rank) - dprint(f"PCI Addr. for Rank {peer_rank} : {os.environ[pcie_env_str]}") - print("-"*60) - if is_distributed: - torch.distributed.barrier() - - #------------- - # Create the model - #------------- - if 0 == world_rank: - dprint(f"Creating the model...") - # model_name = "roberta-base" - # model_name = "deepset/roberta-base-squad2-distilled" - model_name = "FacebookAI/roberta-base" - hf_model = AutoModelForMaskedLM.from_pretrained(model_name) - with tempfile.TemporaryDirectory() as workdir: - hf_model.save_pretrained( - f"{workdir}/roberta-base-masked_lm", safe_serialization=False - ) - model = get_model( - "roberta", - "base", - f"{workdir}/roberta-base-masked_lm", - "hf", - norm_eps=1e-5, - tie_heads=True, - ) - hf_model_fms = to_hf_api( - model, task_specific_params=hf_model.config.task_specific_params - ) - # hf_model_fms = get_model( - # architecture="hf_pretrained", - # variant=model_name - # ) - - #------------- - # Compile the model - #------------- - if 0 == world_rank: - dprint(f"Compiling the model...") - the_compiled_model = torch.compile(hf_model_fms, backend=dynamo_backend) - the_compiled_model.eval() # inference only mode - torch.set_grad_enabled(False) - - #------------- - # Run the model - # - First run the compiler will activate to create the artifacts - # - Second run there is no compiler involved - #------------- - if is_distributed: - torch.distributed.barrier() - - torch.manual_seed(42) - tokenizer = RobertaTokenizerFast.from_pretrained(model_name) - # prompt = "Hello I'm a model." 
- # prompt = "Kermit the frog is a ." - prompt = "Miss Piggy is a ." - - # First run will create compiled artifacts - if 0 == world_rank: - dprint(f"Running model: First Time...") - unmasker = pipeline("fill-mask", model=the_compiled_model, tokenizer=tokenizer) - the_output = unmasker(prompt) - if 0 == world_rank: - dprint(f"Answer: ({the_output[0]['score']:6.5f}) {the_output[0]['sequence']}") - - # Second run will be faster as it uses the cached artifacts - if 0 == world_rank: - dprint(f"Running model: Second Time...") - unmasker = pipeline("fill-mask", model=the_compiled_model, tokenizer=tokenizer) - the_output = unmasker(prompt) - if 0 == world_rank: - dprint(f"Answer: ({the_output[0]['score']:6.5f}) {the_output[0]['sequence']}") - - #------------- - # Cleanup - #------------- - if 0 == world_rank: - dprint(f"Done") - if is_distributed: - torch.distributed.barrier() - torch.distributed.destroy_process_group() From 0dfa47266046300aa4f9eba59c83658d0f941054 Mon Sep 17 00:00:00 2001 From: Andrea Fasoli Date: Tue, 1 Jul 2025 18:42:24 -0400 Subject: [PATCH 17/22] Gate post processing to rank 0 only Signed-off-by: Andrea Fasoli --- aiu_fms_testing_utils/utils/encoders_utils.py | 53 ++++++++++--------- 1 file changed, 27 insertions(+), 26 deletions(-) diff --git a/aiu_fms_testing_utils/utils/encoders_utils.py b/aiu_fms_testing_utils/utils/encoders_utils.py index dcd3500..8445a30 100644 --- a/aiu_fms_testing_utils/utils/encoders_utils.py +++ b/aiu_fms_testing_utils/utils/encoders_utils.py @@ -19,7 +19,7 @@ import torch # Local Packages -from aiu_fms_testing_utils.utils.aiu_setup import dprint +from aiu_fms_testing_utils.utils.aiu_setup import dprint, rank # Optional imports (required for QA) @@ -600,33 +600,34 @@ def run_evaluation(self) -> None: f"bs = {args.batch_size})" ) - # concatenate the numpy array - max_len = max([x.shape[1] for x in all_start_logits]) - start_logits_concat = self.create_and_fill_np_array( - all_start_logits, - self.eval_dataset, - max_len, - ) - end_logits_concat = self.create_and_fill_np_array( - all_end_logits, - self.eval_dataset, - max_len, - ) + if rank == 0: + # concatenate the numpy array + max_len = max([x.shape[1] for x in all_start_logits]) + start_logits_concat = self.create_and_fill_np_array( + all_start_logits, + self.eval_dataset, + max_len, + ) + end_logits_concat = self.create_and_fill_np_array( + all_end_logits, + self.eval_dataset, + max_len, + ) - del all_start_logits - del all_end_logits + del all_start_logits + del all_end_logits - outputs_numpy = (start_logits_concat, end_logits_concat) - prediction = self.post_processing_function( - self.eval_examples, - self.eval_dataset, - outputs_numpy, - ) - eval_metric = self.metric.compute( - predictions=prediction.predictions, - references=prediction.label_ids, - ) - dprint(f"Evaluation metrics: {eval_metric}") + outputs_numpy = (start_logits_concat, end_logits_concat) + prediction = self.post_processing_function( + self.eval_examples, + self.eval_dataset, + outputs_numpy, + ) + eval_metric = self.metric.compute( + predictions=prediction.predictions, + references=prediction.label_ids, + ) + dprint(f"Evaluation metrics: {eval_metric}") class EncoderMLMInfer(): From bca7a39a7e00b7b34b3b64794dec44f33ba02131 Mon Sep 17 00:00:00 2001 From: Andrea Fasoli Date: Wed, 2 Jul 2025 15:03:16 -0400 Subject: [PATCH 18/22] Rename encoder inference entry point script Signed-off-by: Andrea Fasoli --- scripts/{encoders_inference.py => run_encoders.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename 
scripts/{encoders_inference.py => run_encoders.py} (100%) diff --git a/scripts/encoders_inference.py b/scripts/run_encoders.py similarity index 100% rename from scripts/encoders_inference.py rename to scripts/run_encoders.py From e989a233e971170b9de730aa399aedd917100a17 Mon Sep 17 00:00:00 2001 From: Andrea Fasoli Date: Wed, 2 Jul 2025 17:53:03 -0400 Subject: [PATCH 19/22] Move batch to correct device at eval Signed-off-by: Andrea Fasoli --- aiu_fms_testing_utils/utils/encoders_utils.py | 9 +++++++++ scripts/run_encoders.py | 1 + 2 files changed, 10 insertions(+) diff --git a/aiu_fms_testing_utils/utils/encoders_utils.py b/aiu_fms_testing_utils/utils/encoders_utils.py index 8445a30..30cebea 100644 --- a/aiu_fms_testing_utils/utils/encoders_utils.py +++ b/aiu_fms_testing_utils/utils/encoders_utils.py @@ -45,6 +45,14 @@ def wrap_encoder(model: nn.Module) -> HFModelArchitecture: model.config.linear_config.pop("linear_type", None) return to_hf_api(model, task_specific_params=None) +def move_to_device(batch: dict, device: torch.device) -> dict: + """Move batch to selected device.""" + + batch_on_device = {} + for k, v in batch.items(): + batch_on_device[k] = v.to(device) + return batch_on_device + class EncoderQAInfer(): """Run QuestionAnswering task with encoder models.""" @@ -587,6 +595,7 @@ def run_evaluation(self) -> None: with torch.no_grad(): dprint(f"Step {step + 1} / {len(eval_dataloader)}") batch = self.convert_batch_to_fms_style(batch) + batch = move_to_device(batch, args.device) start_logits, end_logits = self.model(**batch) all_start_logits.append(start_logits.cpu().numpy()) all_end_logits.append(end_logits.cpu().numpy()) diff --git a/scripts/run_encoders.py b/scripts/run_encoders.py index b9d2eda..47a268b 100644 --- a/scripts/run_encoders.py +++ b/scripts/run_encoders.py @@ -36,6 +36,7 @@ # Main model setup default_dtype, device, dist_strat = setup_model(args) +args.device = device # Retrieve linear configuration (quantized or not) to instantiate FMS model linear_config = get_linear_config(args) From 0cd0d7bdefede13700beb432e27fb7d2b5541642 Mon Sep 17 00:00:00 2001 From: Andrea Fasoli Date: Wed, 2 Jul 2025 18:56:19 -0400 Subject: [PATCH 20/22] Add 16b forced casting Signed-off-by: Andrea Fasoli --- aiu_fms_testing_utils/utils/args_parsing.py | 17 ++++++++++ aiu_fms_testing_utils/utils/encoders_utils.py | 4 +-- aiu_fms_testing_utils/utils/model_setup.py | 27 +++++++++++++++ .../utils/quantization_setup.py | 33 +++++++++++++++++++ scripts/run_encoders.py | 12 ++++++- 5 files changed, 90 insertions(+), 3 deletions(-) diff --git a/aiu_fms_testing_utils/utils/args_parsing.py b/aiu_fms_testing_utils/utils/args_parsing.py index 67e736d..bd546bc 100644 --- a/aiu_fms_testing_utils/utils/args_parsing.py +++ b/aiu_fms_testing_utils/utils/args_parsing.py @@ -50,6 +50,22 @@ def get_args(parser: argparse.ArgumentParser) -> argparse.Namespace: "weight format by setting the default pytorch format" ), ) + parser.add_argument( + "--cast_bf16_to_fp16", + action="store_true", + help=( + "If set, cast any bf16 weights in the model to fp16 for AIU compiler. " + "Doesn't touch fp32 or quantized" + ) + ) + parser.add_argument( + "--cast_fp16_to_bf16", + action="store_true", + help=( + "If set, cast any fp16 weights in the model to bf16 for GPU. 
" + "Doesn't touch fp32 or quantized" + ) + ) # Quantization arguments args_quantization = parser.add_argument_group("Model quantization") @@ -260,6 +276,7 @@ def get_args(parser: argparse.ArgumentParser) -> argparse.Namespace: args.is_aiu_backend = "aiu" in args.device_type args.dynamo_backend = "sendnn" if args.is_aiu_backend else "inductor" args.fused_weights = not args.unfuse_weights + args.force_16b_dtype = args.cast_bf16_to_fp16 or args.cast_fp16_to_bf16 if args.verbose: dprint("=" * 60) diff --git a/aiu_fms_testing_utils/utils/encoders_utils.py b/aiu_fms_testing_utils/utils/encoders_utils.py index 30cebea..3ead135 100644 --- a/aiu_fms_testing_utils/utils/encoders_utils.py +++ b/aiu_fms_testing_utils/utils/encoders_utils.py @@ -597,8 +597,8 @@ def run_evaluation(self) -> None: batch = self.convert_batch_to_fms_style(batch) batch = move_to_device(batch, args.device) start_logits, end_logits = self.model(**batch) - all_start_logits.append(start_logits.cpu().numpy()) - all_end_logits.append(end_logits.cpu().numpy()) + all_start_logits.append(start_logits.to(torch.float16).cpu().numpy()) + all_end_logits.append(end_logits.to(torch.float16).cpu().numpy()) eval_duration = time.time() - start_time dprint( f"Runtime: {eval_duration:.0f} s | " diff --git a/aiu_fms_testing_utils/utils/model_setup.py b/aiu_fms_testing_utils/utils/model_setup.py index 9199ec9..562556b 100644 --- a/aiu_fms_testing_utils/utils/model_setup.py +++ b/aiu_fms_testing_utils/utils/model_setup.py @@ -120,6 +120,33 @@ def setup_model(args: argparse.Namespace) -> tuple[str | None, torch.device, str return default_dtype, device, dist_strat +def recast_16b(model: nn.Module, args: argparse.Namespace) -> None: + """Cast 16-bit model parameters to selected datatype.""" + + if args.cast_bf16_to_fp16: + dprint( + "Casting all BF16 model parameters to FP16 " + "(--cast_bf16_to_fp16 flag is enabled)" + ) + for name, param in model.named_parameters(): + if param.dtype == torch.bfloat16: + if param.max() > torch.finfo(torch.float16).max: + dprint( + f"[WARNING] Casting param {name} to fp16 will truncate the " + "tensor. This may cause accuracy loss. Ignore this warning if " + "this is intended." 
+ ) + param.data = param.data.to(dtype=torch.float16) + elif args.cast_fp16_to_bf16: + dprint( + "Casting all FP16 model parameters to BF16 " + "(--cast_fp16_to_bf16 flag is enabled)" + ) + for param in model.parameters(): + if param.dtype == torch.float16: + param.data = param.data.to(dtype=torch.bfloat16) + + def print_model_params(model: nn.Module, args: argparse.Namespace) -> None: """Printout model and list of model parameters with related statistics.""" diff --git a/aiu_fms_testing_utils/utils/quantization_setup.py b/aiu_fms_testing_utils/utils/quantization_setup.py index c7eac38..8b77dba 100644 --- a/aiu_fms_testing_utils/utils/quantization_setup.py +++ b/aiu_fms_testing_utils/utils/quantization_setup.py @@ -6,6 +6,7 @@ import os # Third Party +import torch from torch import nn # Local Packages @@ -133,3 +134,35 @@ def select_int8_module( else: linear_config = {"linear_type": "torch_linear"} return linear_config + + +def validate_quantization(model: nn.Module, args: argparse.Namespace) -> None: + """Ensure compatibility of FP8 models with device-specific operations.""" + + has_fp8_weights = False + has_bf16_weights = False + has_fp16_weights = False + for param in model.parameters(): + if param.dtype == torch.float8_e4m3fn: + has_fp8_weights = True + elif param.dtype == torch.bfloat16: + has_bf16_weights = True + elif param.dtype == torch.float16: + has_fp16_weights = True + + if has_fp8_weights: + if args.is_aiu_backend and has_bf16_weights and not args.cast_bf16_to_fp16: + raise ValueError( + "FP8 checkpoints on AIU with bf16 weights require casting to fp16 " + "using --cast_bf16_to_fp16. Do not use --default_dtype!" + ) + elif ( + args.device.type == "cuda" + and has_fp16_weights + and not args.cast_fp16_to_bf16 + ): + raise ValueError( + "FP8 checkpoints on GPU with fp16 weights require casting to bf16 " + "using --cast_fp16_to_bf16. Do not use --default_dtype!" 
+ ) + diff --git a/scripts/run_encoders.py b/scripts/run_encoders.py index 47a268b..76a08aa 100644 --- a/scripts/run_encoders.py +++ b/scripts/run_encoders.py @@ -16,10 +16,15 @@ run_encoder_eval_qa, run_encoder_eval_mlm, ) -from aiu_fms_testing_utils.utils.model_setup import setup_model, print_model_params +from aiu_fms_testing_utils.utils.model_setup import ( + setup_model, + print_model_params, + recast_16b +) from aiu_fms_testing_utils.utils.quantization_setup import ( import_addons, get_linear_config, + validate_quantization, ) parser = argparse.ArgumentParser( @@ -61,9 +66,14 @@ group=distributed.group.WORLD, linear_config=linear_config, fused_weights=args.fused_weights, + attn_name="math_fp8", ) +if args.force_16b_dtype: + recast_16b(model, args) + if args.is_quantized: + validate_quantization(model, args) print_model_params(model, args) tokenizer = tokenizers.get_tokenizer(args.tokenizer) From 18f4230c5bf8f33ef793ec91d5ab9e608f6be99b Mon Sep 17 00:00:00 2001 From: Andrea Fasoli Date: Thu, 3 Jul 2025 12:47:16 -0400 Subject: [PATCH 21/22] Reinstate local_size in aiu_setup (for future use) Signed-off-by: Andrea Fasoli --- aiu_fms_testing_utils/utils/aiu_setup.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/aiu_fms_testing_utils/utils/aiu_setup.py b/aiu_fms_testing_utils/utils/aiu_setup.py index d7ee71f..6a449c8 100644 --- a/aiu_fms_testing_utils/utils/aiu_setup.py +++ b/aiu_fms_testing_utils/utils/aiu_setup.py @@ -23,7 +23,7 @@ def dprint(text): # ============================================================== # Common setup # ============================================================== -def aiu_setup(rank=0, world_size=1, local_rank=0, verbose=False): +def aiu_setup(rank=0, world_size=1, local_rank=0, local_size=1, verbose=False): # ------------- # Envar setup for Sentient backend # ------------- @@ -56,9 +56,11 @@ def aiu_setup(rank=0, world_size=1, local_rank=0, verbose=False): # ============================================================== # Distributed setup # ============================================================== -def aiu_dist_setup(rank, world_size, local_rank=-0, verbose=False): +def aiu_dist_setup(rank, world_size, local_rank=-0, local_size=-1, verbose=False): if local_rank < 0: local_rank = rank + if local_size < 0: + local_size = world_size if os.getenv("TORCHELASTIC_RUN_ID") is None: os.environ["MASTER_ADDR"] = "localhost" From 2fc4c12aceb22680ceec3718a8d3a4eb3ea1111f Mon Sep 17 00:00:00 2001 From: Andrea Fasoli Date: Thu, 3 Jul 2025 12:55:48 -0400 Subject: [PATCH 22/22] Add notes about 384 default max_prompt_length Signed-off-by: Andrea Fasoli --- aiu_fms_testing_utils/utils/args_parsing.py | 2 +- aiu_fms_testing_utils/utils/encoders_utils.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/aiu_fms_testing_utils/utils/args_parsing.py b/aiu_fms_testing_utils/utils/args_parsing.py index bd546bc..dab0271 100644 --- a/aiu_fms_testing_utils/utils/args_parsing.py +++ b/aiu_fms_testing_utils/utils/args_parsing.py @@ -176,7 +176,7 @@ def get_args(parser: argparse.ArgumentParser) -> argparse.Namespace: help=( "Cap the number of tokens per prompt to a maximum length prior to padding. " "If None, prompts to decoder models will have no cap, while prompts to " - "encoder models will be capped to a default of 384 tokens." + "encoder models will be capped to a default of 384 tokens (for QA task)." 
), ) diff --git a/aiu_fms_testing_utils/utils/encoders_utils.py b/aiu_fms_testing_utils/utils/encoders_utils.py index 3ead135..d3ef417 100644 --- a/aiu_fms_testing_utils/utils/encoders_utils.py +++ b/aiu_fms_testing_utils/utils/encoders_utils.py @@ -102,7 +102,7 @@ def prepare_validation_features( max_prompt_length = ( args.max_prompt_length if args.max_prompt_length is not None - else 384 + else 384 # this default is targeted at QA task (not a model limitation) ) # Some of the questions have lots of whitespace on the left, which is not useful
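For context on the 384-token default noted in the final hunk: prepare_validation_features follows the standard Hugging Face sliding-window preprocessing for extractive QA, in which long contexts are split into overlapping windows of at most max_prompt_length tokens rather than truncated outright. The sketch below only illustrates that underlying pattern; the tokenizer checkpoint, function name, and stride value are assumptions for illustration, not code taken from these patches.

# Illustrative sketch (not part of the patches above): the standard
# sliding-window QA tokenization that the 384-token default targets.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("roberta-base")  # assumed checkpoint

def tokenize_qa_features(questions, contexts, max_prompt_length=384, doc_stride=128):
    # Long contexts are split into overlapping windows of at most
    # max_prompt_length tokens; doc_stride controls how much adjacent
    # windows overlap. Only the context ("only_second") is ever truncated.
    return tokenizer(
        questions,
        contexts,
        truncation="only_second",
        max_length=max_prompt_length,    # QA default of 384 tokens noted above
        stride=doc_stride,
        return_overflowing_tokens=True,  # one entry per window
        return_offsets_mapping=True,     # needed to map answers back to text
        padding="max_length",
    )

# Example: a single long context yields several overlapping feature windows.
features = tokenize_qa_features(
    ["What does the default cap?"],
    ["The default max_prompt_length caps QA prompts at 384 tokens. " * 40],
)
print(len(features["input_ids"]))  # > 1 when the context exceeds one window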