diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 0000000..6a3f785 --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,15 @@ +name: Lint + +on: [pull_request] + +jobs: + lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: astral-sh/ruff-action@v3 + with: + src: "." + version: "~= 0.9.5" + - run: ruff check + - run: ruff format --check diff --git a/aiu_fms_testing_utils/testing/validation.py b/aiu_fms_testing_utils/testing/validation.py index 69b6a9b..d421db9 100644 --- a/aiu_fms_testing_utils/testing/validation.py +++ b/aiu_fms_testing_utils/testing/validation.py @@ -6,44 +6,77 @@ from aiu_fms_testing_utils.utils.aiu_setup import dprint import os -class LogitsExtractorHook(Callable[[int, torch.Tensor, torch.Tensor, MutableMapping[str, Any]], Tuple[torch.Tensor, MutableMapping[str, Any]],]): +class LogitsExtractorHook( + Callable[ + [int, torch.Tensor, torch.Tensor, MutableMapping[str, Any]], + Tuple[torch.Tensor, MutableMapping[str, Any]], + ] +): def __init__(self): super().__init__() self.extracted_logits: Optional[torch.Tensor] = None - def __call__(self, token_position: torch.Tensor, logits: torch.Tensor, next_val: torch.Tensor, kwargs): + def __call__( + self, + token_position: torch.Tensor, + logits: torch.Tensor, + next_val: torch.Tensor, + kwargs, + ): if self.extracted_logits is None: self.extracted_logits = logits.unsqueeze(1) else: - self.extracted_logits = torch.cat((self.extracted_logits, logits.unsqueeze(1)), dim=1) + self.extracted_logits = torch.cat( + (self.extracted_logits, logits.unsqueeze(1)), dim=1 + ) return next_val, kwargs -class StaticTokenInjectorHook(Callable[[int, torch.Tensor, torch.Tensor, MutableMapping[str, Any]], Tuple[torch.Tensor, MutableMapping[str, Any]],]): - def __init__(self, static_tokens: List[torch.Tensor], device_type: str="cpu"): +class StaticTokenInjectorHook( + Callable[ + [int, torch.Tensor, torch.Tensor, MutableMapping[str, Any]], + Tuple[torch.Tensor, MutableMapping[str, Any]], + ] +): + def __init__(self, static_tokens: List[torch.Tensor], device_type: str = "cpu"): super().__init__() - self.static_tokens = torch.tensor(static_tokens, device=device_type).t() # transposing so batch tokens per token_position + self.static_tokens = torch.tensor( + static_tokens, device=device_type + ).t() # transposing so batch tokens per token_position - def __call__(self, token_position: int, logits: torch.Tensor, next_val: torch.Tensor, kwargs): + def __call__( + self, token_position: int, logits: torch.Tensor, next_val: torch.Tensor, kwargs + ): next_val.copy_(self.static_tokens[token_position].unsqueeze(1)) return next_val, kwargs -class GoldenTokenHook(Callable[[int, torch.Tensor, torch.Tensor, MutableMapping[str, Any]], Tuple[torch.Tensor, MutableMapping[str, Any]],]): - def __init__(self, static_tokens: torch.Tensor, device_type: str="cpu"): +class GoldenTokenHook( + Callable[ + [int, torch.Tensor, torch.Tensor, MutableMapping[str, Any]], + Tuple[torch.Tensor, MutableMapping[str, Any]], + ] +): + def __init__(self, static_tokens: torch.Tensor, device_type: str = "cpu"): super().__init__() self.logits_extractor = LogitsExtractorHook() self.extracted_logits = None - self.token_injector = StaticTokenInjectorHook(static_tokens, device_type=device_type) + self.token_injector = StaticTokenInjectorHook( + static_tokens, device_type=device_type + ) - def __call__(self, token_position: int, logits: torch.Tensor, next_val: torch.Tensor, kwargs): - next_val, kwargs = 
self.logits_extractor(token_position, logits, next_val, kwargs) + def __call__( + self, token_position: int, logits: torch.Tensor, next_val: torch.Tensor, kwargs + ): + next_val, kwargs = self.logits_extractor( + token_position, logits, next_val, kwargs + ) self.extracted_logits = self.logits_extractor.extracted_logits return self.token_injector(token_position, logits, next_val, kwargs) -class ValidationInfo: +class ValidationInfo: def __init__(self, validation_info_list): super().__init__() @@ -54,7 +87,10 @@ def __iter__(self): yield vi def get_info(self, info_name): - return [[t.unsqueeze(0) for t in sentence[info_name]] for sentence in self._validation_info_list] + return [ + [t.unsqueeze(0) for t in sentence[info_name]] + for sentence in self._validation_info_list + ] def save(self, save_dir_path: str): """Save the validation information into a directory. @@ -86,12 +122,17 @@ def save(self, save_dir_path: str): def __len__(self): return len(self._validation_info_list) - -def get_default_validation_prefix(model_id: str, max_new_tokens: int, batch_size: int, seq_length: int, dtype: str): + + +def get_default_validation_prefix( + model_id: str, max_new_tokens: int, batch_size: int, seq_length: int, dtype: str +): return f"{model_id.replace('/', '--')}_max-new-tokens-{max_new_tokens}_batch-size-{batch_size}_seq-length-{seq_length}_dtype-{dtype}" -def load_validation_information(validation_path, validation_files_type, batch_size, tokenizer=None): +def load_validation_information( + validation_path, validation_files_type, batch_size, tokenizer=None +): """Load the validation information from a directory The files will be assumed to be in the following structure: @@ -107,7 +148,7 @@ def load_validation_information(validation_path, validation_files_type, batch_si if containing only tokens - torch.tensor if containing tokens and logits - dict[tokens -> torch.tensor, logits -> torch.tensor] if containing text - str - + :param validation_path: path to validation info files :param validation_files_type: validation file type to load, one of text, tokens, or logits :param batch_size: the number of prompts to load @@ -115,9 +156,7 @@ def load_validation_information(validation_path, validation_files_type, batch_si :return: a new validation info """ if isinstance(validation_path, str): - validation_files_path, sep, glob_pattern = validation_path.partition( - "*" - ) + validation_files_path, sep, glob_pattern = validation_path.partition("*") else: sep = "" glob_pattern = "" @@ -146,14 +185,14 @@ def load_validation_information(validation_path, validation_files_type, batch_si validation_files_paths = [validation_files_path] # Check if we found some files - assert ( - len(validation_files_paths) > 0 - ), f"Can't find any validation files at {validation_files_path}" + assert len(validation_files_paths) > 0, ( + f"Can't find any validation files at {validation_files_path}" + ) # Check if we have enough files - assert ( - len(validation_files_paths) >= batch_size - ), f"Not enough validation files at {validation_files_path} for a batch size of {batch_size}" + assert len(validation_files_paths) >= batch_size, ( + f"Not enough validation files at {validation_files_path} for a batch size of {batch_size}" + ) validation_info = [] for i, validation_file_path in enumerate(validation_files_paths): @@ -161,7 +200,9 @@ def load_validation_information(validation_path, validation_files_type, batch_si break if validation_files_type == "text": if tokenizer is None: - raise ValueError("must provide a tokenizer when 
validation_files_type=text") + raise ValueError( + "must provide a tokenizer when validation_files_type=text" + ) # Text format will get tokenized validation_info.append( { @@ -187,7 +228,18 @@ def load_validation_information(validation_path, validation_files_type, batch_si return ValidationInfo(validation_info) -def extract_validation_information(model, input_ids, max_new_tokens, post_iteration_hook, attn_algorithm=None, eos_token_id = None, only_last_token=False, timing="", **extra_kwargs): + +def extract_validation_information( + model, + input_ids, + max_new_tokens, + post_iteration_hook, + attn_algorithm=None, + eos_token_id=None, + only_last_token=False, + timing="", + **extra_kwargs, +): max_seq_len = model.config.max_expected_seq_len attention_specific_kwargs = {} if "paged" in extra_kwargs["attn_name"]: @@ -195,6 +247,7 @@ def extract_validation_information(model, input_ids, max_new_tokens, post_iterat else: # TODO: Add a unified generation dependent on attn_type from fms.utils.generation import generate + attention_specific_kwargs["contiguous_cache"] = True attention_specific_kwargs["max_seq_len"] = max_seq_len @@ -215,7 +268,7 @@ def extract_validation_information(model, input_ids, max_new_tokens, post_iterat eos_token_id=eos_token_id, timing=timing, extra_kwargs=extra_generation_kwargs, - **attention_specific_kwargs + **attention_specific_kwargs, ) if timing != "": @@ -226,7 +279,7 @@ def extract_validation_information(model, input_ids, max_new_tokens, post_iterat if timing == "e2e": dprint(f"E2E timing information: {timings[0]:.3f}s") elif timing == "per-token": - timings = [f"{t*1000:.3f}" for t in timings] + timings = [f"{t * 1000:.3f}" for t in timings] dprint(f"Per-token timing information: {', '.join(timings)} ms") if len(result.shape) == 1: @@ -234,27 +287,33 @@ def extract_validation_information(model, input_ids, max_new_tokens, post_iterat if hasattr(post_iteration_hook, "extracted_logits"): validation_info = [ - {"tokens": t.to("cpu"), "logits": l.to("cpu")} - for t, l in zip(torch.unbind(result), torch.unbind(post_iteration_hook.extracted_logits)) + {"tokens": t.to("cpu"), "logits": logits.to("cpu")} + for t, logits in zip( + torch.unbind(result), torch.unbind(post_iteration_hook.extracted_logits) + ) ] else: validation_info = [{"tokens": t.to("cpu")} for t in torch.unbind(result)] return ValidationInfo(validation_info) + def validate_level_0(aiu_tokens_per_sentence, validation_tokens_per_sentence): failed_cases = [] for sentence_idx, (aiu_sentence, validation_sentence) in enumerate( - zip(aiu_tokens_per_sentence, validation_tokens_per_sentence) + zip(aiu_tokens_per_sentence, validation_tokens_per_sentence) ): for token_idx, (aiu_token, validation_token) in enumerate( - zip(aiu_sentence, validation_sentence) + zip(aiu_sentence, validation_sentence) ): if aiu_token != validation_token: failed_cases.append((sentence_idx, token_idx)) return failed_cases -def top_k_loss_calculator(top_k: int, loss_f: Callable[[torch.Tensor, torch.Tensor], float]): + +def top_k_loss_calculator( + top_k: int, loss_f: Callable[[torch.Tensor, torch.Tensor], float] +): """ Function which will take the top_k logits indexes / values from a reference validation info and retrieve the same indexes from the test validation info logits and perform a loss function over the 2 tensors @@ -262,32 +321,38 @@ def top_k_loss_calculator(top_k: int, loss_f: Callable[[torch.Tensor, torch.Tens :param top_k: number of values to take from reference :param loss_f: a loss function between the reference and test 
logits """ + def loss_func(reference_logits, test_logits): reference_logits_prob = reference_logits.to(dtype=torch.float32) test_logits_prob = test_logits.to(dtype=torch.float32) - reference_values, reference_indices = torch.topk(reference_logits_prob, top_k, dim=1) + reference_values, reference_indices = torch.topk( + reference_logits_prob, top_k, dim=1 + ) test_values = test_logits_prob[:, reference_indices.squeeze(0)] return loss_f(reference_values, test_values) + return loss_func -def capture_level_1_metrics(reference_logits_per_sentence, test_logits_per_sentence, metrics_calculator=None): +def capture_level_1_metrics( + reference_logits_per_sentence, test_logits_per_sentence, metrics_calculator=None +): loss_metrics = [] for sentence_idx, (reference_sentence, test_sentence) in enumerate( - zip(reference_logits_per_sentence, test_logits_per_sentence) + zip(reference_logits_per_sentence, test_logits_per_sentence) ): for token_idx, (reference_logits, test_logits) in enumerate( - zip(reference_sentence, test_sentence) + zip(reference_sentence, test_sentence) ): # computing cross entropy loss per token if metrics_calculator is None: loss_fn = torch.nn.CrossEntropyLoss() metrics_value = loss_fn( reference_logits.to(dtype=torch.float32), - test_logits.softmax(dim=1).to(dtype=torch.float32) + test_logits.softmax(dim=1).to(dtype=torch.float32), ) else: metrics_value = metrics_calculator(reference_logits, test_logits) @@ -295,15 +360,16 @@ def capture_level_1_metrics(reference_logits_per_sentence, test_logits_per_sente loss_metrics.append((sentence_idx, token_idx, metrics_value)) return loss_metrics - + + def filter_failed_level_1_cases(level_1_loss_metrics, fail_f, print_failed=False): failed_cases = [] - for (sentence_idx, token_idx, metrics_value) in level_1_loss_metrics: + for sentence_idx, token_idx, metrics_value in level_1_loss_metrics: if fail_f(metrics_value): failed_cases.append((sentence_idx, token_idx, metrics_value)) if print_failed: dprint( - f"In sentence {sentence_idx+1}, the metric for token {token_idx} is {metrics_value}" + f"In sentence {sentence_idx + 1}, the metric for token {token_idx} is {metrics_value}" ) return failed_cases @@ -313,6 +379,12 @@ def print_failed_cases(failed_cases, aiu_tokens, validation_tokens, tokenizer): aiu_token = aiu_tokens[sentence_index][token_index] validation_token = validation_tokens[sentence_index][token_index] - aiu_str = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(aiu_token)) - validation_str = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(validation_token)) - print(f"In sentence {sentence_index+1}/{len(aiu_tokens)}, token {token_index}, AIU outputs {aiu_token} instead of {validation_token} -- AIU val={aiu_str} -- CPU val={validation_str}") \ No newline at end of file + aiu_str = tokenizer.convert_tokens_to_string( + tokenizer.convert_ids_to_tokens(aiu_token) + ) + validation_str = tokenizer.convert_tokens_to_string( + tokenizer.convert_ids_to_tokens(validation_token) + ) + print( + f"In sentence {sentence_index + 1}/{len(aiu_tokens)}, token {token_index}, AIU outputs {aiu_token} instead of {validation_token} -- AIU val={aiu_str} -- CPU val={validation_str}" + ) diff --git a/aiu_fms_testing_utils/utils/__init__.py b/aiu_fms_testing_utils/utils/__init__.py index 9299706..51bbfab 100644 --- a/aiu_fms_testing_utils/utils/__init__.py +++ b/aiu_fms_testing_utils/utils/__init__.py @@ -19,9 +19,10 @@ def warmup_model( max_new_tokens: int, compile_dynamic_sendnn: bool = False, use_cache: bool = True, - 
**extra_kwargs + **extra_kwargs, ): import torch_sendnn + attention_specific_kwargs = {} attn_name = extra_kwargs["attn_name"] if "paged" in attn_name: @@ -29,6 +30,7 @@ def warmup_model( else: # TODO: Add a unified generation dependent on attn_type from fms.utils.generation import generate + attention_specific_kwargs["contiguous_cache"] = True dprint("AIU warmup") @@ -62,6 +64,7 @@ def warmup_model( pt_compile_model_time = time.time() - pt_compile_model_time dprint(f"PT compile complete, took {pt_compile_model_time:.3f}s") + def ids_for_prompt(prompt, tokenizer): tokens = tokenizer.tokenize(prompt) ids = tokenizer.convert_tokens_to_ids(tokens) @@ -70,12 +73,13 @@ def ids_for_prompt(prompt, tokenizer): ids = torch.tensor(ids, dtype=torch.long, device="cpu") return ids + def __download_file(url, filename): try: response = requests.get(url, stream=True) response.raise_for_status() - with open(filename, 'wb') as file: + with open(filename, "wb") as file: for chunk in response.iter_content(chunk_size=8192): file.write(chunk) print(f"Successfully downloaded {filename}") @@ -83,13 +87,14 @@ def __download_file(url, filename): except requests.exceptions.RequestException as e: print(f"An error occurred: {e}") + def __sample_requests( prompt_list: List[str], num_requests: int, tokenizer: BaseTokenizer, prompt_length_min: int = 32, prompt_length_max: int = 64, - seed: Optional[int] = None + seed: Optional[int] = None, ): # Shuffle the dataset. if seed is not None: @@ -113,20 +118,24 @@ def __sample_requests( return filtered_dataset + def sample_sharegpt_requests( dataset_path: str, num_requests: int, tokenizer: BaseTokenizer, prompt_length_min: int = 32, prompt_length_max: int = 64, - seed: Optional[int] = None + seed: Optional[int] = None, ) -> List[Tuple[str, int]]: if not os.path.exists(dataset_path): print("downloading share-gpt dataset as it does not exist") - __download_file("https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json", dataset_path) + __download_file( + "https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json", + dataset_path, + ) # Load the dataset. - with open(dataset_path, encoding='utf-8') as f: + with open(dataset_path, encoding="utf-8") as f: dataset = json.load(f) # Filter out the conversations with less than 2 turns. 
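    # The records that survive this filter are handed to __sample_requests, which shuffles
    # them when a seed is given and (judging by its parameters) keeps only prompts whose
    # tokenized length falls between prompt_length_min and prompt_length_max.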
dataset = [data for data in dataset if len(data["conversations"]) >= 2] @@ -141,20 +150,21 @@ def sample_sharegpt_requests( seed, ) + def sample_squad_v2_qa_requests( dataset_path: str, num_requests: int, tokenizer: BaseTokenizer, prompt_length_min: int = 32, prompt_length_max: int = 64, - seed: Optional[int] = None + seed: Optional[int] = None, ) -> List[Tuple[str, int]]: from datasets import load_dataset if os.path.exists(dataset_path): - ds = load_dataset(dataset_path)['train'] + ds = load_dataset(dataset_path)["train"] else: - ds = load_dataset("rajpurkar/squad_v2", cache_dir=dataset_path)['train'] + ds = load_dataset("rajpurkar/squad_v2", cache_dir=dataset_path)["train"] ds = [f"{data['context']}\n{data['question']}" for data in ds] diff --git a/aiu_fms_testing_utils/utils/aiu_setup.py b/aiu_fms_testing_utils/utils/aiu_setup.py index 6a449c8..2bc55d3 100644 --- a/aiu_fms_testing_utils/utils/aiu_setup.py +++ b/aiu_fms_testing_utils/utils/aiu_setup.py @@ -5,21 +5,24 @@ # ============================================================== # Common utilities # ============================================================== -#------------- +# ------------- # Discover the world size and my rank (envars set by torchrun) # https://pytorch.org/docs/stable/elastic/run.html#environment-variables -#------------- +# ------------- local_rank = int(os.getenv("LOCAL_RANK", 0)) rank = int(os.getenv("RANK", 0)) world_rank = rank world_size = int(os.getenv("WORLD_SIZE", 1)) + def dprint_str(text): return f"[{rank:2d}/{world_size:2d}]: {text}" + def dprint(text): print(dprint_str(text)) + # ============================================================== # Common setup # ============================================================== @@ -48,9 +51,9 @@ def aiu_setup(rank=0, world_size=1, local_rank=0, local_size=1, verbose=False): # ) # directory needs to exist if os.getenv("FLEX_COMPUTE") == "SENTIENT": - dprint(f"Sentient AIU: Enabled") + dprint("Sentient AIU: Enabled") else: - dprint(f"Sentient AIU: Disabled (Senulator)") + dprint("Sentient AIU: Disabled (Senulator)") # ============================================================== @@ -66,7 +69,7 @@ def aiu_dist_setup(rank, world_size, local_rank=-0, local_size=-1, verbose=False os.environ["MASTER_ADDR"] = "localhost" os.environ["MASTER_PORT"] = "12355" elif rank == 0 or verbose: - dprint(f"Detected running via torchrun") + dprint("Detected running via torchrun") aiu_setup(rank, world_size) diff --git a/aiu_fms_testing_utils/utils/args_parsing.py b/aiu_fms_testing_utils/utils/args_parsing.py index dab0271..e526024 100644 --- a/aiu_fms_testing_utils/utils/args_parsing.py +++ b/aiu_fms_testing_utils/utils/args_parsing.py @@ -6,7 +6,6 @@ def get_args(parser: argparse.ArgumentParser) -> argparse.Namespace: - # Arguments for FMS model loading args_model_loading = parser.add_argument_group("FMS model loading") args_model_loading.add_argument( @@ -36,9 +35,7 @@ def get_args(parser: argparse.ArgumentParser) -> argparse.Namespace: args_model_loading.add_argument( "--unfuse_weights", action="store_true", - help=( - "If set to True, this will unfuse any fused weight modules" - ), + help=("If set to True, this will unfuse any fused weight modules"), ) args_model_loading.add_argument( "--default_dtype", @@ -56,7 +53,7 @@ def get_args(parser: argparse.ArgumentParser) -> argparse.Namespace: help=( "If set, cast any bf16 weights in the model to fp16 for AIU compiler. 
" "Doesn't touch fp32 or quantized" - ) + ), ) parser.add_argument( "--cast_fp16_to_bf16", @@ -64,7 +61,7 @@ def get_args(parser: argparse.ArgumentParser) -> argparse.Namespace: help=( "If set, cast any fp16 weights in the model to bf16 for GPU. " "Doesn't touch fp32 or quantized" - ) + ), ) # Quantization arguments @@ -84,7 +81,7 @@ def get_args(parser: argparse.ArgumentParser) -> argparse.Namespace: type=str, choices=["cuda", "cpu", "aiu", "aiu-senulator"], default="cuda", - help="The device to run the model on" + help="The device to run the model on", ) args_run_settings.add_argument( "--seed", @@ -123,10 +120,7 @@ def get_args(parser: argparse.ArgumentParser) -> argparse.Namespace: help="This is a distributed job (multiple instances run with RANK+WORLD_SIZE)", ) args_run_settings.add_argument( - '-v', '--verbose', - action='store_true', - default=0, - help="Enable verbose output" + "-v", "--verbose", action="store_true", default=0, help="Enable verbose output" ) # Arguments for compilation diff --git a/aiu_fms_testing_utils/utils/encoders_utils.py b/aiu_fms_testing_utils/utils/encoders_utils.py index d3ef417..2f47040 100644 --- a/aiu_fms_testing_utils/utils/encoders_utils.py +++ b/aiu_fms_testing_utils/utils/encoders_utils.py @@ -45,6 +45,7 @@ def wrap_encoder(model: nn.Module) -> HFModelArchitecture: model.config.linear_config.pop("linear_type", None) return to_hf_api(model, task_specific_params=None) + def move_to_device(batch: dict, device: torch.device) -> dict: """Move batch to selected device.""" @@ -54,7 +55,7 @@ def move_to_device(batch: dict, device: torch.device) -> dict: return batch_on_device -class EncoderQAInfer(): +class EncoderQAInfer: """Run QuestionAnswering task with encoder models.""" def __init__( @@ -108,9 +109,7 @@ def prepare_validation_features( # Some of the questions have lots of whitespace on the left, which is not useful # and will make the truncation of the context fail (the tokenized question will # take a lots of space). So we remove that left whitespace - examples[q_col_name] = [ - q.lstrip() for q in examples[q_col_name] - ] + examples[q_col_name] = [q.lstrip() for q in examples[q_col_name]] # Tokenize our examples with truncation and maybe padding, but keep the overflows # using a stride. 
This results in one example possible giving several features @@ -162,7 +161,7 @@ def convert_batch_to_fms_style( ) -> dict[str, torch.Tensor]: """FMS uses a different standard than HF for encoder inputs.""" - return {'x': batch['input_ids'], 'mask': batch['attention_mask']} + return {"x": batch["input_ids"], "mask": batch["attention_mask"]} def process_eval_set(self) -> None: """Pre-process evaluation dataset for QuestionAnswering task.""" @@ -195,9 +194,15 @@ def process_eval_set(self) -> None: column_names = raw_datasets["train"].column_names - self.question_column_name = "question" if "question" in column_names else column_names[0] - self.context_column_name = "context" if "context" in column_names else column_names[1] - self.answer_column_name = "answers" if "answers" in column_names else column_names[2] + self.question_column_name = ( + "question" if "question" in column_names else column_names[0] + ) + self.context_column_name = ( + "context" if "context" in column_names else column_names[1] + ) + self.answer_column_name = ( + "answers" if "answers" in column_names else column_names[2] + ) # Padding side determines if we do (question|context) or (context|question) self.pad_on_right = self.tokenizer.padding_side == "right" @@ -314,11 +319,15 @@ def postprocess_qa_predictions( """ if len(predictions) != 2: - raise ValueError("`predictions` should be a tuple with two elements (start_logits, end_logits).") + raise ValueError( + "`predictions` should be a tuple with two elements (start_logits, end_logits)." + ) all_start_logits, all_end_logits = predictions if len(predictions[0]) != len(features): - raise ValueError(f"Got {len(predictions[0])} predictions and {len(features)} features.") + raise ValueError( + f"Got {len(predictions[0])} predictions and {len(features)} features." + ) # Build a map example to its corresponding features. example_id_to_index = {k: i for i, k in enumerate(examples["id"])} @@ -333,7 +342,9 @@ def postprocess_qa_predictions( scores_diff_json = collections.OrderedDict() # Logging. - dprint(f"Post-processing {len(examples)} example predictions split into {len(features)} features.") + dprint( + f"Post-processing {len(examples)} example predictions split into {len(features)} features." + ) # Let's loop over all the examples! for example_index, example in enumerate(tqdm(examples)): @@ -353,11 +364,16 @@ def postprocess_qa_predictions( offset_mapping = features[feature_index]["offset_mapping"] # Optional `token_is_max_context`, if provided we will remove answers that do not have the maximum context # available in the current feature. - token_is_max_context = features[feature_index].get("token_is_max_context", None) + token_is_max_context = features[feature_index].get( + "token_is_max_context", None + ) # Update minimum null prediction. feature_null_score = start_logits[0] + end_logits[0] - if min_null_prediction is None or min_null_prediction["score"] > feature_null_score: + if ( + min_null_prediction is None + or min_null_prediction["score"] > feature_null_score + ): min_null_prediction = { "offsets": (0, 0), "score": feature_null_score, @@ -366,8 +382,12 @@ def postprocess_qa_predictions( } # Go through all possibilities for the `n_best_size` greater start and end logits. 
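            # A note on the slice used just below: np.argsort returns indices in ascending
            # order of logit value, so [-1 : -n_best_size - 1 : -1] walks that order backwards
            # and yields the indices of the n_best_size highest-scoring positions, best first.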
- start_indexes = np.argsort(start_logits)[-1 : -n_best_size - 1 : -1].tolist() - end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist() + start_indexes = np.argsort(start_logits)[ + -1 : -n_best_size - 1 : -1 + ].tolist() + end_indexes = np.argsort(end_logits)[ + -1 : -n_best_size - 1 : -1 + ].tolist() for start_index in start_indexes: for end_index in end_indexes: # Don't consider out-of-scope answers, either because the indices are out of bounds or correspond @@ -382,17 +402,27 @@ def postprocess_qa_predictions( ): continue # Don't consider answers with a length that is either < 0 or > max_answer_length. - if end_index < start_index or end_index - start_index + 1 > max_answer_length: + if ( + end_index < start_index + or end_index - start_index + 1 > max_answer_length + ): continue # Don't consider answer that don't have the maximum context available (if such information is # provided). - if token_is_max_context is not None and not token_is_max_context.get(str(start_index), False): + if ( + token_is_max_context is not None + and not token_is_max_context.get(str(start_index), False) + ): continue prelim_predictions.append( { - "offsets": (offset_mapping[start_index][0], offset_mapping[end_index][1]), - "score": start_logits[start_index] + end_logits[end_index], + "offsets": ( + offset_mapping[start_index][0], + offset_mapping[end_index][1], + ), + "score": start_logits[start_index] + + end_logits[end_index], "start_logit": start_logits[start_index], "end_logit": end_logits[end_index], } @@ -403,7 +433,9 @@ def postprocess_qa_predictions( null_score = min_null_prediction["score"] # Only keep the best `n_best_size` predictions. - predictions = sorted(prelim_predictions, key=lambda x: x["score"], reverse=True)[:n_best_size] + predictions = sorted( + prelim_predictions, key=lambda x: x["score"], reverse=True + )[:n_best_size] # Add back the minimum null prediction if it was removed because of its low score. if ( @@ -421,8 +453,18 @@ def postprocess_qa_predictions( # In the very rare edge case we have not a single non-null prediction, we create a fake prediction to avoid # failure. - if len(predictions) == 0 or (len(predictions) == 1 and predictions[0]["text"] == ""): - predictions.insert(0, {"text": "empty", "start_logit": 0.0, "end_logit": 0.0, "score": 0.0}) + if len(predictions) == 0 or ( + len(predictions) == 1 and predictions[0]["text"] == "" + ): + predictions.insert( + 0, + { + "text": "empty", + "start_logit": 0.0, + "end_logit": 0.0, + "score": 0.0, + }, + ) # Compute the softmax of all scores scores = np.array([pred.pop("score") for pred in predictions]) @@ -444,8 +486,14 @@ def postprocess_qa_predictions( best_non_null_pred = predictions[i] # Then we compare to the null prediction using the threshold. - score_diff = null_score - best_non_null_pred["start_logit"] - best_non_null_pred["end_logit"] - scores_diff_json[example["id"]] = float(score_diff) # To be JSON-serializable. + score_diff = ( + null_score + - best_non_null_pred["start_logit"] + - best_non_null_pred["end_logit"] + ) + scores_diff_json[example["id"]] = float( + score_diff + ) # To be JSON-serializable. if score_diff > null_score_diff_threshold: all_predictions[example["id"]] = "" else: @@ -453,7 +501,14 @@ def postprocess_qa_predictions( # Make `predictions` JSON-serializable by casting np.float back to float. 
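            # (The cast below is needed because the json module cannot serialize numpy scalars:
            # json.dumps(np.float32(0.5)) raises a TypeError, while a builtin float is fine.)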
all_nbest_json[example["id"]] = [ - {k: (float(v) if isinstance(v, (np.float16, np.float32, np.float64)) else v) for k, v in pred.items()} + { + k: ( + float(v) + if isinstance(v, (np.float16, np.float32, np.float64)) + else v + ) + for k, v in pred.items() + } for pred in predictions ] @@ -463,14 +518,19 @@ def postprocess_qa_predictions( raise EnvironmentError(f"{output_dir} is not a directory.") prediction_file = os.path.join( - output_dir, "predictions.json" if prefix is None else f"{prefix}_predictions.json" + output_dir, + "predictions.json" if prefix is None else f"{prefix}_predictions.json", ) nbest_file = os.path.join( - output_dir, "nbest_predictions.json" if prefix is None else f"{prefix}_nbest_predictions.json" + output_dir, + "nbest_predictions.json" + if prefix is None + else f"{prefix}_nbest_predictions.json", ) if version_2_with_negative: null_odds_file = os.path.join( - output_dir, "null_odds.json" if prefix is None else f"{prefix}_null_odds.json" + output_dir, + "null_odds.json" if prefix is None else f"{prefix}_null_odds.json", ) dprint(f"Saving predictions to {prediction_file}.") @@ -568,7 +628,7 @@ def create_and_fill_np_array( def run_warmup(self) -> None: """Run warmup cycle of compiled encoder model set for QuestionAnswering task.""" - dprint(f"Starting warm-up...") + dprint("Starting warm-up...") warmup_start_time = time.time() dataloader_for_compile = DataLoader( self.eval_dataset_for_model, @@ -576,7 +636,9 @@ def run_warmup(self) -> None: collate_fn=self.data_collator, batch_size=1, ) - first_batch = self.convert_batch_to_fms_style(next(iter(dataloader_for_compile))) + first_batch = self.convert_batch_to_fms_style( + next(iter(dataloader_for_compile)) + ) self.model(**first_batch) dprint(f"Warmup completed in {time.time() - warmup_start_time:.1f} s\n---") @@ -639,7 +701,7 @@ def run_evaluation(self) -> None: dprint(f"Evaluation metrics: {eval_metric}") -class EncoderMLMInfer(): +class EncoderMLMInfer: """Run MaskedLM task with encoder models.""" def __init__( @@ -652,7 +714,6 @@ def __init__( self.tokenizer = tokenizer self.args = args - def process_eval_set(self) -> None: """Barebone function that sets up a single example prompt (for now).""" diff --git a/aiu_fms_testing_utils/utils/model_setup.py b/aiu_fms_testing_utils/utils/model_setup.py index 562556b..4912482 100644 --- a/aiu_fms_testing_utils/utils/model_setup.py +++ b/aiu_fms_testing_utils/utils/model_setup.py @@ -47,8 +47,6 @@ def get_device(args: argparse.Namespace) -> torch.device: device = torch.device(args.device_type, local_rank) torch.cuda.set_device(device) elif args.is_aiu_backend: - from torch_sendnn import torch_sendnn - if args.distributed: aiu_setup.aiu_dist_setup( distributed.get_rank(), @@ -67,7 +65,7 @@ def print_system_setup(args: argparse.Namespace) -> None: """Display system info (rank 0 only).""" if args.verbose: - dprint("-"*60) + dprint("-" * 60) dprint( f"Python Version : {sys.version_info.major}." f"{sys.version_info.minor}.{sys.version_info.micro}" @@ -75,11 +73,11 @@ def print_system_setup(args: argparse.Namespace) -> None: dprint(f"PyTorch Version : {torch.__version__}") dprint(f"Dynamo Backend : {args.device_type} -> {args.dynamo_backend}") dprint(f"Distributed : {args.distributed}") - if args.device_type == 'aiu': + if args.device_type == "aiu": for peer_rank in range(aiu_setup.world_size): - pcie_env_str="AIU_WORLD_RANK_"+str(peer_rank) + pcie_env_str = "AIU_WORLD_RANK_" + str(peer_rank) dprint(f"PCI Addr. 
for Rank {peer_rank} : {os.environ[pcie_env_str]}") - dprint("-"*60) + dprint("-" * 60) def set_determinism(args: argparse.Namespace) -> None: @@ -144,20 +142,22 @@ def recast_16b(model: nn.Module, args: argparse.Namespace) -> None: ) for param in model.parameters(): if param.dtype == torch.float16: - param.data = param.data.to(dtype=torch.bfloat16) + param.data = param.data.to(dtype=torch.bfloat16) def print_model_params(model: nn.Module, args: argparse.Namespace) -> None: """Printout model and list of model parameters with related statistics.""" if args.verbose: - dprint("="*60 + "\n") - dprint("\n".join( - f"{k:80} {str(list(v.size())):15} {str(v.dtype):18} {str(v.device):10} " - f"{v.float().min().item():12.4f} {v.float().max().item():12.4f}" - for k,v in model.state_dict().items() - )) - dprint("="*60 + "\n") + dprint("=" * 60 + "\n") + dprint( + "\n".join( + f"{k:80} {str(list(v.size())):15} {str(v.dtype):18} {str(v.device):10} " + f"{v.float().min().item():12.4f} {v.float().max().item():12.4f}" + for k, v in model.state_dict().items() + ) + ) + dprint("=" * 60 + "\n") if args.architecture == "llama": dprint( "[NOTE] In Llama models, it's OK for bias and rotary embeddings to be " @@ -165,4 +165,4 @@ def print_model_params(model: nn.Module, args: argparse.Namespace) -> None: "FMS and HF models (but model output is preserved)." ) dprint(model) - dprint("="*60 + "\n") + dprint("=" * 60 + "\n") diff --git a/aiu_fms_testing_utils/utils/paged.py b/aiu_fms_testing_utils/utils/paged.py index 771eb01..d80794e 100644 --- a/aiu_fms_testing_utils/utils/paged.py +++ b/aiu_fms_testing_utils/utils/paged.py @@ -3,13 +3,14 @@ import time from typing import Any, Callable, List, MutableMapping, Optional, Tuple, Union import torch -import fms.utils.spyre.paged +import fms.utils.spyre.paged # noqa + def adjust_inputs_to_batch(input_ids: torch.Tensor, **extra_kwargs): """ - Adjusts the inputs to a batch. Batch size 1 cannot be handled since we want a symbolic shape for the batch + Adjusts the inputs to a batch. 
Batch size 1 cannot be handled since we want a symbolic shape for the batch and pytorch automatically sets size 1 dimensions as static - + Note: This is fixed in pytorch 2.7 """ input_ids = input_ids[0].repeat(2, 1) @@ -23,6 +24,7 @@ def adjust_inputs_to_batch(input_ids: torch.Tensor, **extra_kwargs): kwargs["position_ids"] = position_ids[0].repeat(2, 1) return input_ids, kwargs + # FIXME: We should use default generate, but that will require a larger re-work of generate def generate( model: Union[Callable, torch.nn.Module], @@ -88,7 +90,7 @@ def generate( if isinstance(input_ids, torch.Tensor): if len(input_ids.shape) == 1: input_ids = input_ids.unsqueeze(0) - + is_batch = input_ids.shape[0] > 1 # our model requires batch dimension if not is_batch: @@ -106,8 +108,18 @@ def generate( result = input_ids next_input = input_ids BLOCK_SIZE = 64 - _MAX_BATCH = int(os.environ.setdefault("VLLM_DT_MAX_BATCH_SIZE", str(input_ids.size(0)))) - _MAX_CONTEXT_LENGTH = int(os.environ.setdefault("VLLM_DT_MAX_CONTEXT_LEN", str((((input_ids.size(1) + max_new_tokens - 1) // BLOCK_SIZE) + 1) * BLOCK_SIZE))) + _MAX_BATCH = int( + os.environ.setdefault("VLLM_DT_MAX_BATCH_SIZE", str(input_ids.size(0))) + ) + _MAX_CONTEXT_LENGTH = int( + os.environ.setdefault( + "VLLM_DT_MAX_CONTEXT_LEN", + str( + (((input_ids.size(1) + max_new_tokens - 1) // BLOCK_SIZE) + 1) + * BLOCK_SIZE + ), + ) + ) NUM_BLOCKS = (_MAX_BATCH * _MAX_CONTEXT_LENGTH) // BLOCK_SIZE max_seq_len = input_ids.size(1) + max_new_tokens if hasattr(model, "head"): @@ -139,18 +151,43 @@ def generate( head_size = model.config.emb_dim // nheads if "fp8" in kwargs["attn_name"]: from fms_mo.aiu_addons.fp8.fp8_utils import ScaledTensor + kwargs["past_key_value_states"] = [ ( - ScaledTensor(torch.zeros(NUM_BLOCKS, BLOCK_SIZE, kvheads, head_size, dtype=torch.float8_e4m3fn), torch.tensor(1.0), False), - ScaledTensor(torch.zeros(NUM_BLOCKS, BLOCK_SIZE, kvheads, head_size, dtype=torch.float8_e4m3fn), torch.tensor(1.0), False), + ScaledTensor( + torch.zeros( + NUM_BLOCKS, + BLOCK_SIZE, + kvheads, + head_size, + dtype=torch.float8_e4m3fn, + ), + torch.tensor(1.0), + False, + ), + ScaledTensor( + torch.zeros( + NUM_BLOCKS, + BLOCK_SIZE, + kvheads, + head_size, + dtype=torch.float8_e4m3fn, + ), + torch.tensor(1.0), + False, + ), ) for _ in range(model.config.nlayers) ] else: kwargs["past_key_value_states"] = [ ( - torch.zeros(NUM_BLOCKS, BLOCK_SIZE, kvheads, head_size, dtype=model_dtype), - torch.zeros(NUM_BLOCKS, BLOCK_SIZE, kvheads, head_size, dtype=model_dtype), + torch.zeros( + NUM_BLOCKS, BLOCK_SIZE, kvheads, head_size, dtype=model_dtype + ), + torch.zeros( + NUM_BLOCKS, BLOCK_SIZE, kvheads, head_size, dtype=model_dtype + ), ) for _ in range(model.config.nlayers) ] @@ -296,7 +333,7 @@ def generate( v, _ = torch.topk(logits, top_k) logits[logits < v[:, [-1]]] = -float("inf") - probs = F.softmax(logits, dim=-1) + probs = F.softmax(logits, dim=-1) # noqa: F821 next_val = torch.multinomial(probs, num_samples=1) else: next_val = torch.argmax(logits, dim=-1).unsqueeze(0).t() @@ -347,4 +384,4 @@ def generate( if timing != "": return result, times - return result \ No newline at end of file + return result diff --git a/aiu_fms_testing_utils/utils/quantization_setup.py b/aiu_fms_testing_utils/utils/quantization_setup.py index 8b77dba..3b2e582 100644 --- a/aiu_fms_testing_utils/utils/quantization_setup.py +++ b/aiu_fms_testing_utils/utils/quantization_setup.py @@ -10,7 +10,7 @@ from torch import nn # Local Packages -from aiu_fms_testing_utils.utils.aiu_setup import dprint, 
rank +from aiu_fms_testing_utils.utils.aiu_setup import dprint def import_addons(args: argparse.Namespace) -> None: @@ -20,13 +20,13 @@ def import_addons(args: argparse.Namespace) -> None: try: if args.quantization == "gptq" and "aiu" in args.device_type: - from fms_mo.aiu_addons.gptq import gptq_aiu_adapter, gptq_aiu_linear + pass elif args.quantization == "fp8": - from fms_mo.aiu_addons.fp8 import fp8_adapter, fp8_attn, fp8_linear + pass elif args.quantization == "int8": - from fms_mo.aiu_addons.i8i8 import i8i8_aiu_adapter, i8i8_aiu_linear + pass dprint("Loaded `aiu_addons` functionalities") - except: + except ImportError: raise ImportError(f"Failed to import {args.quantization} addons from FMS-MO.") @@ -57,7 +57,7 @@ def get_linear_config(args: argparse.Namespace) -> dict[str, Any]: qconfig_path = args.model_path + "/quantize_config.json" if os.path.exists(qconfig_path): - with open(qconfig_path, 'r') as f: + with open(qconfig_path, "r") as f: dprint(f"loading quantization config from {qconfig_path}") qconfig = json.load(f) group_size = qconfig["group_size"] @@ -111,9 +111,13 @@ def select_int8_module( if args.int8_smoothquant: # TODO: load info from config saved during quantization - if any("granite" in p.lower() for p in [args.model_path, args.architecture]): + if any( + "granite" in p.lower() for p in [args.model_path, args.architecture] + ): smoothquant_layers = ["key", "value", "w1", "wg"] - elif any("roberta" in p.lower() for p in [args.model_path, args.architecture]): + elif any( + "roberta" in p.lower() for p in [args.model_path, args.architecture] + ): smoothquant_layers = ["query", "key", "value", "w1"] else: raise NotImplementedError( @@ -125,8 +129,8 @@ def select_int8_module( linear_config = { "linear_type": partial( select_int8_module, - smoothquant = args.int8_smoothquant, - smoothquant_layers = smoothquant_layers, + smoothquant=args.int8_smoothquant, + smoothquant_layers=smoothquant_layers, ), "weight_per_channel": args.int8_weight_per_channel, "activ_quant_type": args.int8_activ_quant_type, @@ -165,4 +169,3 @@ def validate_quantization(model: nn.Module, args: argparse.Namespace) -> None: "FP8 checkpoints on GPU with fp16 weights require casting to bf16 " "using --cast_fp16_to_bf16. Do not use --default_dtype!" 
) - diff --git a/scripts/generate_metrics.py b/scripts/generate_metrics.py index f50ec59..6ecf7a8 100644 --- a/scripts/generate_metrics.py +++ b/scripts/generate_metrics.py @@ -1,14 +1,20 @@ import argparse import ast -import json import os -import random -from typing import List, Optional, Tuple import torch from torch import distributed as dist -from aiu_fms_testing_utils.testing.validation import capture_level_1_metrics, extract_validation_information, LogitsExtractorHook, get_default_validation_prefix, load_validation_information, print_failed_cases, \ - validate_level_0, GoldenTokenHook, top_k_loss_calculator +from aiu_fms_testing_utils.testing.validation import ( + capture_level_1_metrics, + extract_validation_information, + LogitsExtractorHook, + get_default_validation_prefix, + load_validation_information, + print_failed_cases, + validate_level_0, + GoldenTokenHook, + top_k_loss_calculator, +) from aiu_fms_testing_utils.utils import ids_for_prompt, sample_sharegpt_requests from fms.models import get_model from fms.utils import tokenizers @@ -83,19 +89,19 @@ "--topk_per_token", type=int, help="top k values per token to generate loss on", - default=20 + default=20, ) parser.add_argument( "--num_test_tokens_per_sequence", type=int, help="number of tokens in test. For instance, if max_new_tokens=128 and num_test_tokens_per_sequence=256, this means we will generate data over 2 sample prompts. If not set, will be set to max_new_tokens", - default=None + default=None, ) parser.add_argument( "--extra_get_model_kwargs", - nargs='*', + nargs="*", default={}, - help="Use this to override model configuration values to get model. Example: --extra_get_model_kwargs nlayers=2,..." + help="Use this to override model configuration values to get model. Example: --extra_get_model_kwargs nlayers=2,...", ) parser.add_argument( "--distributed", @@ -105,7 +111,7 @@ parser.add_argument( "--skip_computation", action="store_true", - help="Set this if the output is already assumed to be computed and would like to regenerate metrics without model loading or computation" + help="Set this if the output is already assumed to be computed and would like to regenerate metrics without model loading or computation", ) local_rank = int(os.getenv("LOCAL_RANK", 0)) world_size = int(os.getenv("WORLD_SIZE", 1)) @@ -120,14 +126,20 @@ extra_get_model_kwargs = {} for a in args.extra_get_model_kwargs: - a_split = a.split("=") - try: + a_split = a.split("=") + try: extra_get_model_kwargs[a_split[0]] = ast.literal_eval(a_split[1]) - except ValueError: + except ValueError: extra_get_model_kwargs[a_split[0]] = a_split[1] # this follows the same pattern of naming in test_shapes. This way we can save and re-use for quicker shape testing. 
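# For illustration (hypothetical values): calling get_default_validation_prefix with
# variant "ibm-granite/granite-3.0-8b", max_new_tokens=128, batch_size=1,
# min_pad_length=64 and default_dtype "fp16" yields
# "ibm-granite--granite-3.0-8b_max-new-tokens-128_batch-size-1_seq-length-64_dtype-fp16",
# and the generated metrics are written to files such as "<prefix>.ce.csv" and
# "<prefix>.prob_mean.csv" under --output_dir.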
-prefix = get_default_validation_prefix(args.variant, args.max_new_tokens, args.batch_size, args.min_pad_length, args.default_dtype) +prefix = get_default_validation_prefix( + args.variant, + args.max_new_tokens, + args.batch_size, + args.min_pad_length, + args.default_dtype, +) if os.path.exists(os.path.join(args.output_dir, f"{prefix}.prob_mean.csv")): print("skipping metric generation as it has already been done") exit(0) @@ -148,11 +160,12 @@ torch.set_grad_enabled(False) + def find_eos_index(reference_tokens, eos_token_id): result = [] for sentence in reference_tokens: found_eos = False - for token_idx, token in enumerate(sentence[args.min_pad_length:]): + for token_idx, token in enumerate(sentence[args.min_pad_length :]): if token.item() == eos_token_id: found_eos = True result.append(token_idx) @@ -161,13 +174,20 @@ def find_eos_index(reference_tokens, eos_token_id): result.append(args.max_new_tokens) return result -def filter_before_eos(l, filter_indexes): + +def filter_before_eos(metrics, filter_indexes): from itertools import groupby - filtered_results = [list(g)[:filter_indexes[k]] for k, g in groupby(l, key=lambda x: x[0])] + + filtered_results = [ + list(g)[: filter_indexes[k]] for k, g in groupby(metrics, key=lambda x: x[0]) + ] return [item for sublist in filtered_results for item in sublist] + def __prepare_inputs(batch_size, seq_length, tokenizer, seed=0): - prompts_and_sizes = sample_sharegpt_requests(args.sharegpt_path, batch_size, tokenizer, seq_length // 2, seq_length, seed) + prompts_and_sizes = sample_sharegpt_requests( + args.sharegpt_path, batch_size, tokenizer, seq_length // 2, seq_length, seed + ) prompt_list = [] for prompt, _ in prompts_and_sizes: prompt_list.append(ids_for_prompt(prompt, tokenizer)) @@ -175,13 +195,15 @@ def __prepare_inputs(batch_size, seq_length, tokenizer, seed=0): input_ids, padding_kwargs = pad_input_ids(prompt_list, min_pad_length=seq_length) return input_ids, padding_kwargs -def write_csv(l, path, metric): - with open(path, 'w') as f: - f.write(f'{metric}\n') - for t in l: - f.write(f"{t[2].item()}\n") + +def write_csv(metrics, path, metric_name): + with open(path, "w") as f: + f.write(f"{metric_name}\n") + for t in metrics: + f.write(f"{t[2].item()}\n") f.close() + # prepare the cuda model if not args.skip_computation: cuda_model = get_model( @@ -212,7 +234,9 @@ def write_csv(l, path, metric): cpu_model.eval() print("loaded cpu model") - ids, padding_kwargs = __prepare_inputs(args.batch_size, args.min_pad_length, tokenizer) + ids, padding_kwargs = __prepare_inputs( + args.batch_size, args.min_pad_length, tokenizer + ) # first test validation level 0 cpu_validation_info = extract_validation_information( @@ -221,7 +245,7 @@ def write_csv(l, path, metric): args.max_new_tokens, LogitsExtractorHook(), attn_algorithm="math", - **padding_kwargs + **padding_kwargs, ) cpu_static_tokens = cpu_validation_info.get_info("tokens") print("extracted cpu validation information") @@ -236,24 +260,41 @@ def write_csv(l, path, metric): args.max_new_tokens, None, only_last_token=True, - **{k: v.to("cuda") for k,v in padding_kwargs.items()} + **{k: v.to("cuda") for k, v in padding_kwargs.items()}, ) cuda_static_tokens = cuda_validation_info.get_info("tokens") failed_responses = validate_level_0(cpu_static_tokens, cuda_static_tokens) print("extracted cuda validation information level 0") if local_rank == 0: - if len(failed_responses) != 0: - print_failed_cases(failed_responses, cpu_static_tokens, cuda_static_tokens, tokenizer) + if len(failed_responses) != 0: 
+ print_failed_cases( + failed_responses, cpu_static_tokens, cuda_static_tokens, tokenizer + ) num_test_tokens_per_sequence = args.num_test_tokens_per_sequence if num_test_tokens_per_sequence is None: num_test_tokens_per_sequence = args.max_new_tokens -cross_entropy = lambda r, t: torch.nn.CrossEntropyLoss()(r, t.softmax(dim=1).to(dtype=torch.float32)) -prob_mean = lambda r, t: torch.mean((r.softmax(dim=1).to(dtype=torch.float32) / t.softmax(dim=1).to(dtype=torch.float32)) - 1.0) -prob_std = lambda r, t: torch.std(r.softmax(dim=1).to(dtype=torch.float32) / t.softmax(dim=1).to(dtype=torch.float32)) -diff_mean = lambda r, t: torch.mean(torch.abs(r.softmax(dim=1).to(dtype=torch.float32) - t.softmax(dim=1).to(dtype=torch.float32))) +cross_entropy = lambda r, t: torch.nn.CrossEntropyLoss()( # noqa: E731 + r, t.softmax(dim=1).to(dtype=torch.float32) +) +prob_mean = lambda r, t: torch.mean( # noqa: E731 + ( + r.softmax(dim=1).to(dtype=torch.float32) + / t.softmax(dim=1).to(dtype=torch.float32) + ) + - 1.0 +) +prob_std = lambda r, t: torch.std( # noqa: E731 + r.softmax(dim=1).to(dtype=torch.float32) / t.softmax(dim=1).to(dtype=torch.float32) +) +diff_mean = lambda r, t: torch.mean( # noqa: E731 + torch.abs( + r.softmax(dim=1).to(dtype=torch.float32) + - t.softmax(dim=1).to(dtype=torch.float32) + ) +) prob_mean_metrics = [] prob_std_metrics = [] @@ -265,10 +306,16 @@ def write_csv(l, path, metric): cuda_path = os.path.join(args.output_dir, f"{prefix}.cuda_validation_info.{i}.out") if os.path.exists(cpu_path) and os.path.exists(cuda_path): print(f"found the logits at {cpu_path}, reusing") - cpu_validation_info = load_validation_information(cpu_path, "logits", args.batch_size, tokenizer) - cuda_validation_info = load_validation_information(cuda_path, "logits", args.batch_size, tokenizer) + cpu_validation_info = load_validation_information( + cpu_path, "logits", args.batch_size, tokenizer + ) + cuda_validation_info = load_validation_information( + cuda_path, "logits", args.batch_size, tokenizer + ) elif not args.skip_computation: - ids, padding_kwargs = __prepare_inputs(args.batch_size, args.min_pad_length, tokenizer, i) + ids, padding_kwargs = __prepare_inputs( + args.batch_size, args.min_pad_length, tokenizer, i + ) # only need to compute this once if we aren't generating more test data if num_test_tokens_per_sequence > args.max_new_tokens: @@ -278,7 +325,7 @@ def write_csv(l, path, metric): args.max_new_tokens, LogitsExtractorHook(), attn_algorithm="math", - **padding_kwargs + **padding_kwargs, ) # generate aiu validation info @@ -288,7 +335,7 @@ def write_csv(l, path, metric): args.max_new_tokens, GoldenTokenHook(cpu_validation_info.get_info("tokens"), "cuda"), only_last_token=True, - **{k: v.to("cuda") for k,v in padding_kwargs.items()} + **{k: v.to("cuda") for k, v in padding_kwargs.items()}, ) print("extracted cuda validation information level 1") @@ -296,8 +343,10 @@ def write_csv(l, path, metric): if local_rank == 0: cpu_validation_info.save(cpu_path) cuda_validation_info.save(cuda_path) - - eos_indexes = find_eos_index(cpu_validation_info.get_info("tokens"), tokenizer.eos_token_id) + + eos_indexes = find_eos_index( + cpu_validation_info.get_info("tokens"), tokenizer.eos_token_id + ) level_1_metrics = capture_level_1_metrics( cpu_validation_info.get_info("logits"), cuda_validation_info.get_info("logits"), @@ -327,7 +376,21 @@ def write_csv(l, path, metric): prob_diff_metrics.extend(filter_before_eos(level_1_metrics, eos_indexes)) if local_rank == 0: - write_csv(prob_mean_metrics, 
os.path.join(args.output_dir, f"{prefix}.prob_mean.csv"), "prob_mean") - write_csv(prob_std_metrics, os.path.join(args.output_dir, f"{prefix}.prob_std.csv"), "prob_std") - write_csv(prob_ce_loss_metrics, os.path.join(args.output_dir, f"{prefix}.ce.csv"), "ce") - write_csv(prob_diff_metrics, os.path.join(args.output_dir, f"{prefix}.diff_mean.csv"), "diff_mean") + write_csv( + prob_mean_metrics, + os.path.join(args.output_dir, f"{prefix}.prob_mean.csv"), + "prob_mean", + ) + write_csv( + prob_std_metrics, + os.path.join(args.output_dir, f"{prefix}.prob_std.csv"), + "prob_std", + ) + write_csv( + prob_ce_loss_metrics, os.path.join(args.output_dir, f"{prefix}.ce.csv"), "ce" + ) + write_csv( + prob_diff_metrics, + os.path.join(args.output_dir, f"{prefix}.diff_mean.csv"), + "diff_mean", + ) diff --git a/scripts/inference.py b/scripts/inference.py index 28c62aa..b0185c5 100644 --- a/scripts/inference.py +++ b/scripts/inference.py @@ -35,7 +35,7 @@ type=str, choices=["cuda", "cpu", "aiu", "aiu-senulator"], default="cuda", - help="The device to run the model on" + help="The device to run the model on", ) parser.add_argument( "--architecture", @@ -222,10 +222,11 @@ help="Number of iterations of inference to perform. Used for variance performance capture.", ) parser.add_argument( - '-v', '--verbose', - action='count', + "-v", + "--verbose", + action="count", default=0, - help="Set verbosity level (pass flag as `-v`, `-vv`, `-vvv`)" + help="Set verbosity level (pass flag as `-v`, `-vv`, `-vvv`)", ) parser.add_argument( "--attention_type", @@ -251,20 +252,22 @@ from fms.utils.generation import generate if "fp8" in attn_name: - import fms_mo.aiu_addons.fp8.fp8_attn + pass if args.quantization == "gptq": if "aiu" in args.device_type: try: - from fms_mo.aiu_addons.gptq import gptq_aiu_adapter, gptq_aiu_linear + from fms_mo.aiu_addons.gptq import gptq_aiu_adapter, gptq_aiu_linear # noqa + print("Loaded `aiu_addons` functionalities") - except: + except ImportError: raise ImportError("Failed to import GPTQ addons from fms-mo.") elif args.quantization == "int8": try: - from fms_mo.aiu_addons.i8i8 import i8i8_aiu_adapter, i8i8_aiu_linear + from fms_mo.aiu_addons.i8i8 import i8i8_aiu_adapter, i8i8_aiu_linear # noqa + print("Loaded `aiu_addons` functionalities") - except: + except ImportError: raise ImportError("Failed to import INT8 addons from fms-mo.") # this is a test model config @@ -302,7 +305,7 @@ device = torch.device(args.device_type, local_rank) torch.cuda.set_device(device) elif is_aiu_backend: - from torch_sendnn import torch_sendnn + from torch_sendnn import torch_sendnn # noqa if not args.distributed: aiu_setup.aiu_setup(rank, world_size) @@ -377,7 +380,9 @@ fused_weights = not args.unfuse_weights if args.quantization == "gptq": if fused_weights and is_aiu_backend: - raise ValueError("GPTQ checkpoints on AIU must always run with --unfuse_weights") + raise ValueError( + "GPTQ checkpoints on AIU must always run with --unfuse_weights" + ) if default_dtype is not None: raise ValueError( "GPTQ default_dtype must be None to preserve the checkpoint data types." 
@@ -394,7 +399,7 @@ qconfig_path = args.model_path + "/quantize_config.json" if os.path.exists(qconfig_path): - with open(qconfig_path, 'r') as f: + with open(qconfig_path, "r") as f: dprint(f"loading quantization config from {qconfig_path}") qconfig = json.load(f) group_size = qconfig["group_size"] @@ -418,7 +423,9 @@ } elif args.quantization == "int8": if fused_weights and is_aiu_backend: - raise ValueError("INT8 checkpoints on AIU must always run with --unfuse_weights") + raise ValueError( + "INT8 checkpoints on AIU must always run with --unfuse_weights" + ) if default_dtype is not None: raise ValueError( "INT8 default_dtype must be None to preserve the checkpoint data types." @@ -446,17 +453,15 @@ def select_int8_module( elif any("roberta" in p.lower() for p in [args.model_path, args.architecture]): smoothquant_layers = ["query", "key", "value", "w1"] else: - raise NotImplementedError( - "INT8 architecture does not support smoothquant." - ) + raise NotImplementedError("INT8 architecture does not support smoothquant.") else: smoothquant_layers = [] linear_config = { "linear_type": partial( select_int8_module, - smoothquant = args.int8_smoothquant, - smoothquant_layers = smoothquant_layers, + smoothquant=args.int8_smoothquant, + smoothquant_layers=smoothquant_layers, ), "weight_per_channel": args.int8_weight_per_channel, "activ_quant_type": args.int8_activ_quant_type, @@ -464,12 +469,12 @@ def select_int8_module( else: linear_config = {"linear_type": "torch_linear"} -dprint("="*60) +dprint("=" * 60) dprint(f"model_path={args.model_path}") dprint(f"{linear_config=}") dprint(f"{fused_weights=}") dprint(f"data_type={default_dtype}") -dprint("="*60 + "\n") +dprint("=" * 60 + "\n") model = get_model( args.architecture, @@ -500,15 +505,21 @@ def select_int8_module( if has_fp8_weights: if is_aiu_backend and has_bf16_weights and not args.cast_bf16_to_fp16: - raise ValueError("FP8 checkpoints on AIU with bf16 weights require casting to fp16 using --cast_bf16_to_fp16. Do not use --default_dtype!") + raise ValueError( + "FP8 checkpoints on AIU with bf16 weights require casting to fp16 using --cast_bf16_to_fp16. Do not use --default_dtype!" + ) elif device.type == "cuda" and has_fp16_weights and not args.cast_fp16_to_bf16: - raise ValueError("FP8 checkpoints on GPU with fp16 weights require casting to bf16 using --cast_fp16_to_bf16. Do not use --default_dtype!") + raise ValueError( + "FP8 checkpoints on GPU with fp16 weights require casting to bf16 using --cast_fp16_to_bf16. Do not use --default_dtype!" + ) if args.cast_bf16_to_fp16: for name, param in model.named_parameters(): if param.dtype == torch.bfloat16: if param.max() > torch.finfo(torch.float16).max: - dprint(f"[WARNING] You are casting param {name} to fp16, which will cause loss of accuracy. You can ignore this warning if this is intended.") + dprint( + f"[WARNING] You are casting param {name} to fp16, which will cause loss of accuracy. You can ignore this warning if this is intended." 
+ ) param.data = param.data.to(dtype=torch.float16) if args.cast_fp16_to_bf16: @@ -518,13 +529,27 @@ def select_int8_module( if args.quantization in ["gptq", "int8"]: if rank == 0 and args.verbose > 0: - dprint("PARAMS:\n" + "\n".join(f"{k:60} {str(v.dtype):15} {str(v.device):10} {list(v.size())}" for k,v in model.named_parameters())) - dprint("BUFFERS:\n" + "\n".join(f"{k:60} {str(v.dtype):15} {str(v.device):10} {list(v.size())}" for k,v in model.named_buffers())) - dprint("="*60 + "\n") + dprint( + "PARAMS:\n" + + "\n".join( + f"{k:60} {str(v.dtype):15} {str(v.device):10} {list(v.size())}" + for k, v in model.named_parameters() + ) + ) + dprint( + "BUFFERS:\n" + + "\n".join( + f"{k:60} {str(v.dtype):15} {str(v.device):10} {list(v.size())}" + for k, v in model.named_buffers() + ) + ) + dprint("=" * 60 + "\n") if args.architecture == "llama": - dprint("[NOTE] In Llama models, it's OK for bias and rotary embeddings to be marked as unused keys.") + dprint( + "[NOTE] In Llama models, it's OK for bias and rotary embeddings to be marked as unused keys." + ) dprint(model) - dprint("="*60 + "\n") + dprint("=" * 60 + "\n") tokenizer = tokenizers.get_tokenizer(args.tokenizer) model.eval() @@ -535,7 +560,9 @@ def select_int8_module( if args.compile: dprint("compiling model") if is_aiu_backend: - model.compile(backend="sendnn", options={'sendnn.dynamic': args.compile_dynamic_sendnn}) + model.compile( + backend="sendnn", options={"sendnn.dynamic": args.compile_dynamic_sendnn} + ) else: # compiling can make first inference pass slow model.compile(mode=args.compile_mode, backend=args.compile_backend) @@ -591,9 +618,9 @@ def truncate_prompts_to_max_length(prompts, max_len, max_allowed_length): assert len(prompt_file_paths) > 0, f"Can't find any prompt files at {prompt_path}" # Check if we have enough files - assert ( - len(prompt_file_paths) >= args.batch_size - ), f"Not enough prompt files at {prompt_path} for a batch size of {args.batch_size}" + assert len(prompt_file_paths) >= args.batch_size, ( + f"Not enough prompt files at {prompt_path} for a batch size of {args.batch_size}" + ) prompts = [] for i, prompt_file_path in enumerate(prompt_file_paths): @@ -649,7 +676,7 @@ def truncate_prompts_to_max_length(prompts, max_len, max_allowed_length): if args.fixed_prompt_length != 0 and args.fixed_prompt_length < max_len: dprint( - f"One or more prompts require truncation. Truncation has been disabled as fixed_prompt_length has been set." + "One or more prompts require truncation. Truncation has been disabled as fixed_prompt_length has been set." 
         )
         exit(1)
     prompts = truncate_prompts_to_max_length(prompts, max_len, max_allowed_length)
@@ -723,7 +750,7 @@ def infer(use_cache, do_sample, warmup):
         timing=args.timing,
         eos_token_id=eos_token_id,
         extra_kwargs=extra_generation_kwargs,
-        **attention_specific_kwargs
+        **attention_specific_kwargs,
     )
     if args.timing != "":
         result, timings = result
@@ -731,14 +758,24 @@ def infer(use_cache, do_sample, warmup):
             dprint(f"E2E timing information: {timings[0]:.3f}s")
         elif args.timing == "per-token":
             if not warmup:
-                dprint(f"First-token latency: {timings[0]*1000:.3f} ms")
-                dprint(f"Average next-token latency (including first token): {np.mean(timings)*1000:.3f} ms")
+                dprint(f"First-token latency: {timings[0] * 1000:.3f} ms")
+                dprint(
+                    f"Average next-token latency (including first token): {np.mean(timings) * 1000:.3f} ms"
+                )
                 if len(timings) > 1:
-                    dprint(f"Average next-token latency: {np.mean(timings[1:])*1000:.3f} ms")
-                    dprint(f"Max next-token latency: {np.max(timings[1:])*1000:.3f} ms (token #{np.argmax(timings[1:]) + 2})")
-                    dprint(f"Min next-token latency: {np.min(timings[1:])*1000:.3f} ms (token #{np.argmin(timings[1:]) + 2})")
-                    dprint(f"Std deviation of next-token latencies: {np.std(timings[1:])*1000:.3f} ms")
-            timings = [f"{t*1000:.3f}" for t in timings]
+                    dprint(
+                        f"Average next-token latency: {np.mean(timings[1:]) * 1000:.3f} ms"
+                    )
+                    dprint(
+                        f"Max next-token latency: {np.max(timings[1:]) * 1000:.3f} ms (token #{np.argmax(timings[1:]) + 2})"
+                    )
+                    dprint(
+                        f"Min next-token latency: {np.min(timings[1:]) * 1000:.3f} ms (token #{np.argmin(timings[1:]) + 2})"
+                    )
+                    dprint(
+                        f"Std deviation of next-token latencies: {np.std(timings[1:]) * 1000:.3f} ms"
+                    )
+            timings = [f"{t * 1000:.3f}" for t in timings]
             dprint(f"Per-token timing information: {', '.join(timings)} ms")
     if len(result.shape) == 1:
         result = result.unsqueeze(0)
@@ -754,11 +791,17 @@ def infer(use_cache, do_sample, warmup):
 ]  # True/False are identical with greedy iff `torch.use_deterministic_algorithms(True)`

 if args.compile:
-    dprint(f"compilation warmup")
+    dprint("compilation warmup")
     pt_compile_model_time = time.time()
     if args.device_type == "aiu":  # only run warmup for AIU, no need for senulator
         for cache in use_cache:
-            warmup_model(model, ids, args.max_new_tokens, args.compile_dynamic_sendnn, **extra_generation_kwargs)
+            warmup_model(
+                model,
+                ids,
+                args.max_new_tokens,
+                args.compile_dynamic_sendnn,
+                **extra_generation_kwargs,
+            )
             aiu_warmup_time = time.time()
             for sample, cache in itertools.product(do_sample, use_cache):
                 infer(cache, sample, True)
@@ -770,7 +813,7 @@ def infer(use_cache, do_sample, warmup):
     pt_compile_model_time = time.time() - pt_compile_model_time
     dprint(f"PT compile complete, took {pt_compile_model_time:.3f}s")

-dprint(f"generating output")
+dprint("generating output")

 for sample, cache in itertools.product(do_sample, use_cache):
     for _ in range(args.iters):
diff --git a/scripts/roberta.py b/scripts/roberta.py
new file mode 100644
index 0000000..cc2807f
--- /dev/null
+++ b/scripts/roberta.py
@@ -0,0 +1,173 @@
+import os
+import sys
+import argparse
+import tempfile
+
+from aiu_fms_testing_utils.utils import aiu_setup
+from aiu_fms_testing_utils.utils.aiu_setup import dprint, world_rank, world_size
+
+
+# PyTorch
+import torch
+import torch.distributed
+
+# HuggingFace Transformers
+from transformers import (
+    AutoModelForMaskedLM,
+    RobertaTokenizerFast,
+    pipeline,
+)
+
+# TPEmbedding in FMS uses the torch.ops._c10d_functional.all_gather_into_tensor function
+# which is not supported by 
GLOO. Even though we don't use GLOO in AIU execution, PyTorch
+# doesn't know that and throws an error.
+# This should be addressed in a future version of PyTorch, but for now disable it.
+os.environ.setdefault("DISTRIBUTED_STRATEGY_IGNORE_MODULES", "WordEmbedding,Embedding")
+
+# Foundation Model Stack
+from fms.models import get_model
+from fms.models.hf import to_hf_api
+
+# Import AIU Libraries
+from torch_sendnn import torch_sendnn  # noqa
+
+# ==============================================================
+# Main
+# ==============================================================
+if __name__ == "__main__":
+    # Number of batches to create
+    NUM_BATCHES = 1
+
+    # -------------
+    # Command line argument parsing
+    # -------------
+    parser = argparse.ArgumentParser(
+        description="PyTorch RoBERTa Tensor Parallel Example"
+    )
+    parser.add_argument(
+        "--backend",
+        help="PyTorch Dynamo compiler backend",
+        default="cpu",
+        choices=["cpu", "aiu"],
+    )
+    pargs = parser.parse_args()
+
+    if pargs.backend == "aiu":
+        dynamo_backend = "sendnn"
+    else:
+        dynamo_backend = "inductor"
+
+    is_distributed = world_size > 1
+    if is_distributed:
+        # Initialize the process group
+        torch.distributed.init_process_group(
+            backend="gloo", rank=world_rank, world_size=world_size
+        )
+        # Looks like a string compare, but is actually comparing the components
+        # https://github.com/pytorch/pytorch/blob/b5be4d8c053e22672719b9a33386b071daf9860d/torch/torch_version.py#L10-L16
+        if torch.__version__ < "2.3.0":
+            # Fix until PyTorch 2.3
+            torch._C._distributed_c10d._register_process_group(
+                "default", torch.distributed.group.WORLD
+            )
+
+    # -------------
+    # Setup AIU specific environment variables
+    # -------------
+    if "sendnn" in dynamo_backend:
+        aiu_setup.aiu_dist_setup(world_rank, world_size)
+
+    # -------------
+    # Display some diagnostics
+    # -------------
+    if 0 == world_rank:
+        dprint("-" * 60)
+        dprint(
+            f"Python Version : {sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"
+        )
+        dprint(f"PyTorch Version : {torch.__version__}")
+        dprint(f"Dynamo Backend : {pargs.backend} -> {dynamo_backend}")
+        if pargs.backend == "aiu":
+            for peer_rank in range(world_size):
+                pcie_env_str = "AIU_WORLD_RANK_" + str(peer_rank)
+                dprint(f"PCI Addr. 
for Rank {peer_rank} : {os.environ[pcie_env_str]}") + print("-" * 60) + if is_distributed: + torch.distributed.barrier() + + # ------------- + # Create the model + # ------------- + if 0 == world_rank: + dprint("Creating the model...") + # model_name = "roberta-base" + # model_name = "deepset/roberta-base-squad2-distilled" + model_name = "FacebookAI/roberta-base" + hf_model = AutoModelForMaskedLM.from_pretrained(model_name) + with tempfile.TemporaryDirectory() as workdir: + hf_model.save_pretrained( + f"{workdir}/roberta-base-masked_lm", safe_serialization=False + ) + model = get_model( + "roberta", + "base", + f"{workdir}/roberta-base-masked_lm", + "hf", + norm_eps=1e-5, + tie_heads=True, + ) + hf_model_fms = to_hf_api( + model, task_specific_params=hf_model.config.task_specific_params + ) + # hf_model_fms = get_model( + # architecture="hf_pretrained", + # variant=model_name + # ) + + # ------------- + # Compile the model + # ------------- + if 0 == world_rank: + dprint("Compiling the model...") + the_compiled_model = torch.compile(hf_model_fms, backend=dynamo_backend) + the_compiled_model.eval() # inference only mode + torch.set_grad_enabled(False) + + # ------------- + # Run the model + # - First run the compiler will activate to create the artifacts + # - Second run there is no compiler involved + # ------------- + if is_distributed: + torch.distributed.barrier() + + torch.manual_seed(42) + tokenizer = RobertaTokenizerFast.from_pretrained(model_name) + # prompt = "Hello I'm a model." + # prompt = "Kermit the frog is a ." + prompt = "Miss Piggy is a ." + + # First run will create compiled artifacts + if 0 == world_rank: + dprint("Running model: First Time...") + unmasker = pipeline("fill-mask", model=the_compiled_model, tokenizer=tokenizer) + the_output = unmasker(prompt) + if 0 == world_rank: + dprint(f"Answer: ({the_output[0]['score']:6.5f}) {the_output[0]['sequence']}") + + # Second run will be faster as it uses the cached artifacts + if 0 == world_rank: + dprint("Running model: Second Time...") + unmasker = pipeline("fill-mask", model=the_compiled_model, tokenizer=tokenizer) + the_output = unmasker(prompt) + if 0 == world_rank: + dprint(f"Answer: ({the_output[0]['score']:6.5f}) {the_output[0]['sequence']}") + + # ------------- + # Cleanup + # ------------- + if 0 == world_rank: + dprint("Done") + if is_distributed: + torch.distributed.barrier() + torch.distributed.destroy_process_group() diff --git a/scripts/run_encoders.py b/scripts/run_encoders.py index 76a08aa..de4a677 100644 --- a/scripts/run_encoders.py +++ b/scripts/run_encoders.py @@ -19,7 +19,7 @@ from aiu_fms_testing_utils.utils.model_setup import ( setup_model, print_model_params, - recast_16b + recast_16b, ) from aiu_fms_testing_utils.utils.quantization_setup import ( import_addons, @@ -46,12 +46,12 @@ # Retrieve linear configuration (quantized or not) to instantiate FMS model linear_config = get_linear_config(args) -dprint("="*60) +dprint("=" * 60) dprint(f"model_path={args.model_path}") dprint(f"{linear_config=}") dprint(f"fused_weights={args.fused_weights}") dprint(f"data_type={default_dtype}") -dprint("="*60 + "\n") +dprint("=" * 60 + "\n") dprint("Loading model...") loading_model_start = time.time() diff --git a/scripts/small-toy.py b/scripts/small-toy.py index a6965e4..ceed2f0 100644 --- a/scripts/small-toy.py +++ b/scripts/small-toy.py @@ -16,7 +16,8 @@ from fms.utils.tp_wrapping import apply_tp # Import AIU Libraries -from torch_sendnn import torch_sendnn +from torch_sendnn import torch_sendnn # noqa + # 
============================================================== # Toy Encoder Model @@ -33,21 +34,30 @@ def __init__(self): self._linear_nets = torch.nn.ModuleList() for n in range(self.LAYERS_N): torch.manual_seed(42) - block = FeedForwardBlock(self.INPUT_N, hidden_grow_factor=self.HIDDEN_FACTOR, activation_fn=torch.nn.ReLU(), p_dropout=0) + block = FeedForwardBlock( + self.INPUT_N, + hidden_grow_factor=self.HIDDEN_FACTOR, + activation_fn=torch.nn.ReLU(), + p_dropout=0, + ) self._linear_nets.append(block) self._linear_nets.append(torch.nn.ReLU()) def copy_weights(self, par_model, seq_model): self_parent_layer = self if par_model is None else par_model with torch.no_grad(): - for (seq_name, seq_layer), (self_name, self_layer) in zip(seq_model.named_children(), self_parent_layer.named_children()): + for (seq_name, seq_layer), (self_name, self_layer) in zip( + seq_model.named_children(), self_parent_layer.named_children() + ): if hasattr(self_layer, "load_weights"): - self_layer.load_weights( { - "w1.weight": seq_layer.w1.weight, - "w1.bias": seq_layer.w1.bias, - "w2.weight": seq_layer.w2.weight, - "w2.bias": seq_layer.w2.bias, - }) + self_layer.load_weights( + { + "w1.weight": seq_layer.w1.weight, + "w1.bias": seq_layer.w1.bias, + "w2.weight": seq_layer.w2.weight, + "w2.bias": seq_layer.w2.bias, + } + ) else: self.copy_weights(self_layer, seq_layer) @@ -57,81 +67,95 @@ def forward(self, x): _in = net(_in) return _in + # ============================================================== # Main # ============================================================== if __name__ == "__main__": # Number of batches to create - NUM_BATCHES=1 + NUM_BATCHES = 1 - #------------- + # ------------- # Command line argument parsing - #------------- - parser = argparse.ArgumentParser(description="PyTorch Small Toy Tensor Parallel Example") - parser.add_argument( "--backend", help="PyTorch Dynamo compiler backend", default='cpu', choices=['cpu', 'aiu']) + # ------------- + parser = argparse.ArgumentParser( + description="PyTorch Small Toy Tensor Parallel Example" + ) + parser.add_argument( + "--backend", + help="PyTorch Dynamo compiler backend", + default="cpu", + choices=["cpu", "aiu"], + ) pargs = parser.parse_args() - if pargs.backend == 'aiu': - dynamo_backend = 'sendnn' + if pargs.backend == "aiu": + dynamo_backend = "sendnn" else: - dynamo_backend = 'inductor' + dynamo_backend = "inductor" is_distributed = world_size > 1 if is_distributed: # Initialize the process group - torch.distributed.init_process_group(backend="gloo", rank=world_rank, world_size=world_size) + torch.distributed.init_process_group( + backend="gloo", rank=world_rank, world_size=world_size + ) # Looks like a string compare, but is actually comparing the components # https://github.com/pytorch/pytorch/blob/b5be4d8c053e22672719b9a33386b071daf9860d/torch/torch_version.py#L10-L16 - if torch.__version__ < '2.3.0': + if torch.__version__ < "2.3.0": # Fix until PyTorch 2.3 - torch._C._distributed_c10d._register_process_group("default", torch.distributed.group.WORLD) + torch._C._distributed_c10d._register_process_group( + "default", torch.distributed.group.WORLD + ) - #------------- + # ------------- # Setup AIU specific environment variables - #------------- + # ------------- if "sendnn" in dynamo_backend: aiu_setup.aiu_dist_setup(world_rank, world_size) - #------------- + # ------------- # Display some diagnostics - #------------- + # ------------- if 0 == world_rank: - dprint("-"*60) - dprint(f"Python Version : 
{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}") + dprint("-" * 60) + dprint( + f"Python Version : {sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}" + ) dprint(f"PyTorch Version : {torch.__version__}") dprint(f"Dynamo Backend : {pargs.backend} -> {dynamo_backend}") - if pargs.backend == 'aiu': + if pargs.backend == "aiu": for peer_rank in range(world_size): - pcie_env_str="AIU_WORLD_RANK_"+str(peer_rank) + pcie_env_str = "AIU_WORLD_RANK_" + str(peer_rank) dprint(f"PCI Addr. for Rank {peer_rank} : {os.environ[pcie_env_str]}") - print("-"*60) + print("-" * 60) if is_distributed: torch.distributed.barrier() - #------------- + # ------------- # Create the model - #------------- + # ------------- if 0 == world_rank: - dprint(f"Creating the model...") + dprint("Creating the model...") the_model = ToyModelFM() if is_distributed: # Create a Tensor Parallel version of the model apply_tp(the_model, torch.distributed.group.WORLD) - #------------- + # ------------- # Compile the model - #------------- + # ------------- if 0 == world_rank: - dprint(f"Compiling the model...") + dprint("Compiling the model...") the_compiled_model = torch.compile(the_model, backend=dynamo_backend) - the_compiled_model.eval() # inference only mode + the_compiled_model.eval() # inference only mode torch.set_grad_enabled(False) - #------------- + # ------------- # Run the model # - First run the compiler will activate to create the artifacts # - Second run there is no compiler involved - #------------- + # ------------- if is_distributed: torch.distributed.barrier() @@ -140,19 +164,19 @@ def forward(self, x): # First run will create compiled artifacts if 0 == world_rank: - dprint(f"Running model: First Time...") + dprint("Running model: First Time...") the_outputs = the_compiled_model(the_inputs) # Second run will be faster as it uses the cached artifacts if 0 == world_rank: - dprint(f"Running model: Second Time...") + dprint("Running model: Second Time...") the_outputs = the_compiled_model(the_inputs) - #------------- + # ------------- # Cleanup - #------------- + # ------------- if 0 == world_rank: - dprint(f"Done") + dprint("Done") if is_distributed: torch.distributed.barrier() torch.distributed.destroy_process_group() diff --git a/scripts/validation.py b/scripts/validation.py index c5b1449..c6cc898 100644 --- a/scripts/validation.py +++ b/scripts/validation.py @@ -15,7 +15,16 @@ from fms.utils.generation import pad_input_ids from torch import distributed as dist from aiu_fms_testing_utils.utils import warmup_model -from aiu_fms_testing_utils.testing.validation import LogitsExtractorHook, capture_level_1_metrics, extract_validation_information, StaticTokenInjectorHook, GoldenTokenHook, filter_failed_level_1_cases, validate_level_0, load_validation_information, print_failed_cases +from aiu_fms_testing_utils.testing.validation import ( + LogitsExtractorHook, + capture_level_1_metrics, + extract_validation_information, + GoldenTokenHook, + filter_failed_level_1_cases, + validate_level_0, + load_validation_information, + print_failed_cases, +) from aiu_fms_testing_utils.utils import aiu_setup from aiu_fms_testing_utils.utils.aiu_setup import dprint, rank, local_rank, world_size @@ -28,7 +37,7 @@ type=str, choices=["aiu", "aiu-senulator"], default="aiu", - help="The device to run the model on" + help="The device to run the model on", ) parser.add_argument("--validation_device", type=str, default="cpu") parser.add_argument( @@ -212,22 +221,22 @@ 
"--save_validation_info_path", type=str, default=None, - help="If set, will save the validation info into the path specified for later use" + help="If set, will save the validation info into the path specified for later use", ) parser.add_argument( "--extra_get_model_kwargs", - nargs='*', + nargs="*", default={}, - help="Use this to override model configuration values to get model. Example: --extra_get_model_kwargs nlayers=2,..." + help="Use this to override model configuration values to get model. Example: --extra_get_model_kwargs nlayers=2,...", ) args = parser.parse_args() extra_get_model_kwargs = {} for a in args.extra_get_model_kwargs: - a_split = a.split("=") - try: + a_split = a.split("=") + try: extra_get_model_kwargs[a_split[0]] = ast.literal_eval(a_split[1]) - except ValueError: + except ValueError: extra_get_model_kwargs[a_split[0]] = a_split[1] # this is a test model config @@ -243,7 +252,9 @@ needs_validation_generation = args.validation_files_path == "" needs_validation_forward = ( - not needs_validation_generation and args.validation_files_type in ["text", "tokens"] and args.validation_level == 1 + not needs_validation_generation + and args.validation_files_type in ["text", "tokens"] + and args.validation_level == 1 ) needs_validation_run = needs_validation_forward or needs_validation_generation @@ -251,11 +262,10 @@ if args.quantization == "gptq": try: - # validation script always loads AIU addon - from fms_mo.aiu_addons.gptq import gptq_aiu_adapter, gptq_aiu_linear - print("Loaded `aiu_addons` functionalities") + from fms_mo.aiu_addons.gptq import gptq_aiu_adapter, gptq_aiu_linear # noqa: F401 + print("Loaded `aiu_addons` functionalities") except ImportError: print("Failed to import addon packages") @@ -284,7 +294,7 @@ aiu_setup.aiu_dist_setup(dist.get_rank(), dist.get_world_size()) # Always initialize AIU in this script -from torch_sendnn import torch_sendnn +from torch_sendnn import torch_sendnn # noqa if not args.distributed: aiu_setup.aiu_setup(rank, world_size) @@ -354,7 +364,7 @@ if args.quantization == "gptq": qconfig_path = args.model_path + "/quantize_config.json" if os.path.exists(qconfig_path): - with open(qconfig_path, 'r') as f: + with open(qconfig_path, "r") as f: dprint(f"loading quantization config from {qconfig_path}") qconfig = json.load(f) group_size = qconfig["group_size"] @@ -395,8 +405,10 @@ # model, the adapter will take care of converting key/values from # ckpt into the appropriate form for the model if fused_weights: - raise ValueError("GPTQ checkpoints on AIU must always run with --unfuse_weights") - default_dtype=None # GPTQ dtype always comes from ckpt, can't be enforced + raise ValueError( + "GPTQ checkpoints on AIU must always run with --unfuse_weights" + ) + default_dtype = None # GPTQ dtype always comes from ckpt, can't be enforced else: linear_config = {"linear_type": "torch_linear"} linear_config_validation = {"linear_type": "torch_linear"} @@ -412,7 +424,7 @@ group=dist.group.WORLD, linear_config=linear_config, fused_weights=fused_weights, - **extra_get_model_kwargs + **extra_get_model_kwargs, ) if args.quantization == "gptq": @@ -422,14 +434,12 @@ "and rotary embeddings, in GPTQ LLaMA models" ) dprint(model) - dprint("="*60 + "\n") + dprint("=" * 60 + "\n") if needs_validation_run: if args.quantization != "gptq": data_type_validation = ( - torch.float32 - if validation_device == aiu_device - else default_dtype + torch.float32 if validation_device == aiu_device else default_dtype ) else: data_type_validation = default_dtype @@ -444,7 
+454,7 @@ group=dist.group.WORLD, linear_config=linear_config_validation, fused_weights=fused_weights, - **extra_get_model_kwargs + **extra_get_model_kwargs, ) validation_model.load_state_dict(model.state_dict()) if args.quantization == "gptq": @@ -454,7 +464,7 @@ "rotary embeddings, in GPTQ LLaMA models" ) dprint(validation_model) - dprint("="*60 + "\n") + dprint("=" * 60 + "\n") tokenizer = tokenizers.get_tokenizer(args.tokenizer) model.eval() @@ -526,9 +536,9 @@ def truncate_prompts_to_max_length(prompts, max_len, max_allowed_length): assert len(prompt_file_paths) > 0, f"Can't find any prompt files at {prompt_path}" # Check if we have enough files - assert ( - len(prompt_file_paths) >= args.batch_size - ), f"Not enough prompt files at {prompt_path} for a batch size of {args.batch_size}" + assert len(prompt_file_paths) >= args.batch_size, ( + f"Not enough prompt files at {prompt_path} for a batch size of {args.batch_size}" + ) prompts = [] for i, prompt_file_path in enumerate(prompt_file_paths): @@ -584,7 +594,7 @@ def truncate_prompts_to_max_length(prompts, max_len, max_allowed_length): if args.fixed_prompt_length != 0 and args.fixed_prompt_length < max_len: dprint( - f"One or more prompts require truncation. Truncation has been disabled as fixed_prompt_length has been set." + "One or more prompts require truncation. Truncation has been disabled as fixed_prompt_length has been set." ) exit(1) prompts = truncate_prompts_to_max_length(prompts, max_len, max_allowed_length) @@ -594,6 +604,7 @@ def truncate_prompts_to_max_length(prompts, max_len, max_allowed_length): ids = prompts padding_kwargs = {} + def print_result(result, result_idx: int = 0, file_prefix: str = ""): if local_rank != 0: return @@ -633,7 +644,7 @@ def print_result(result, result_idx: int = 0, file_prefix: str = ""): tokenizer, ) - val_tokens = [torch.tensor(l) for l in validation_info.get_info("tokens")] + val_tokens = [torch.tensor(_) for _ in validation_info.get_info("tokens")] max_val_len = max([prompt.size(0) for prompt in val_tokens]) val_num_gen_tokens = int(args.max_new_tokens) if max_allowed_length is not None: @@ -644,7 +655,7 @@ def print_result(result, result_idx: int = 0, file_prefix: str = ""): # Truncate each answer to its prompt length + max_new_tokens for i, prompt in enumerate(prompts): prompt_len = prompt.size(0) - val_tokens[i] = val_tokens[i][:prompt_len+val_num_gen_tokens] + val_tokens[i] = val_tokens[i][: prompt_len + val_num_gen_tokens] if has_padding: val_ids, padding_val_kwargs = pad_input_ids( @@ -683,10 +694,12 @@ def print_result(result, result_idx: int = 0, file_prefix: str = ""): args.max_new_tokens, LogitsExtractorHook(), attn_algorithm="math", - **padding_kwargs + **padding_kwargs, ) -warmup_model(model, ids, args.max_new_tokens, args.compile_dynamic_sendnn, **padding_kwargs) +warmup_model( + model, ids, args.max_new_tokens, args.compile_dynamic_sendnn, **padding_kwargs +) ### AIU generation loop static_tokens = validation_info.get_info("tokens") @@ -699,10 +712,10 @@ def print_result(result, result_idx: int = 0, file_prefix: str = ""): ids, args.max_new_tokens, post_iteration_hook, - eos_token_id = None if args.no_early_termination else tokenizer.eos_token_id, + eos_token_id=None if args.no_early_termination else tokenizer.eos_token_id, only_last_token=True, timing=args.timing, - **padding_kwargs + **padding_kwargs, ) if args.save_validation_info_path is not None: @@ -714,11 +727,12 @@ def print_result(result, result_idx: int = 0, file_prefix: str = ""): failed_cases = 
validate_level_0(aiu_static_tokens, static_tokens) else: level_1_metrics = capture_level_1_metrics( - validation_info.get_info("logits"), - aiu_validation_info.get_info("logits") + validation_info.get_info("logits"), aiu_validation_info.get_info("logits") ) - failed_cases = filter_failed_level_1_cases(level_1_metrics, lambda m: m >= args.logits_loss_threshold) + failed_cases = filter_failed_level_1_cases( + level_1_metrics, lambda m: m >= args.logits_loss_threshold + ) validation_passed = len(failed_cases) == 0 diff --git a/tests/models/conftest.py b/tests/models/conftest.py index e93db8f..bbb612f 100644 --- a/tests/models/conftest.py +++ b/tests/models/conftest.py @@ -2,7 +2,7 @@ from aiu_fms_testing_utils.utils.aiu_setup import aiu_setup, rank, world_size import os -import pytest + def pytest_sessionstart(session): """ @@ -23,6 +23,7 @@ def pytest_sessionstart(session): os.environ.setdefault("DTLOG_LEVEL", "error") os.environ.setdefault("DT_DEEPRT_VERBOSE", "-1") + def pytest_addoption(parser): parser.addoption( "--runslow", action="store_true", default=False, help="run slow tests" @@ -43,4 +44,3 @@ def pytest_generate_tests(metafunc): option_value = metafunc.config.option.capture_expectation if "capture_expectation" in metafunc.fixturenames and option_value is not None: metafunc.parametrize("capture_expectation", [option_value]) - diff --git a/tests/models/test_decoders.py b/tests/models/test_decoders.py index 9ec7b93..bc935c9 100644 --- a/tests/models/test_decoders.py +++ b/tests/models/test_decoders.py @@ -28,13 +28,15 @@ import os try: - from fms_mo.aiu_addons.gptq import gptq_aiu_adapter, gptq_aiu_linear + from fms_mo.aiu_addons.gptq import gptq_aiu_adapter, gptq_aiu_linear # noqa: F401 GPTQ_ENABLED = True except ImportError: GPTQ_ENABLED = False -MICRO_MODELS_HOME = os.environ.get("FMS_TEST_SHAPES_MICRO_MODELS_HOME", "/mnt/home/models/tiny-models") +MICRO_MODELS_HOME = os.environ.get( + "FMS_TEST_SHAPES_MICRO_MODELS_HOME", "/mnt/home/models/tiny-models" +) # Add models to test here LLAMA_3p1_8B_INSTRUCT = "meta-llama/Llama-3.1-8B-Instruct" @@ -44,11 +46,19 @@ LLAMA_3p1_70B_INSTRUCT = "meta-llama/Llama-3.1-70B-Instruct" micro_model_mapping = { - LLAMA_3p1_8B_INSTRUCT: os.path.join(MICRO_MODELS_HOME, "llama-3.1-8b-layers-3-step-24000"), - GRANITE_3p2_8B_INSTRUCT: os.path.join(MICRO_MODELS_HOME, "granite-3.2-8b-layers-3-step-100000"), + LLAMA_3p1_8B_INSTRUCT: os.path.join( + MICRO_MODELS_HOME, "llama-3.1-8b-layers-3-step-24000" + ), + GRANITE_3p2_8B_INSTRUCT: os.path.join( + MICRO_MODELS_HOME, "granite-3.2-8b-layers-3-step-100000" + ), # FIXME: Because this uses the same config as 3.2, re-using here, but should update - GRANITE_3p3_8B_INSTRUCT: os.path.join(MICRO_MODELS_HOME, "granite-3.3-8b-layers-3-step-100000"), - LLAMA_3p1_70B_INSTRUCT: os.path.join(MICRO_MODELS_HOME, "llama-3.1-70b-layers-3-step-24000") + GRANITE_3p3_8B_INSTRUCT: os.path.join( + MICRO_MODELS_HOME, "granite-3.3-8b-layers-3-step-100000" + ), + LLAMA_3p1_70B_INSTRUCT: os.path.join( + MICRO_MODELS_HOME, "llama-3.1-70b-layers-3-step-24000" + ), } SHARE_GPT_DATASET_PATH = os.environ.get( @@ -135,14 +145,18 @@ for metric in skip_assertions.split(","): metric = metric.lower() if metric not in {"ce", "mean_diff"}: - pytest.fail("FMS_TEST_SHAPES_SKIP_ASSERTIONS can only accept metrics ce and mean_diff") + pytest.fail( + "FMS_TEST_SHAPES_SKIP_ASSERTIONS can only accept metrics ce and mean_diff" + ) _skip_assertions.append(metric) skip_assertions = set(_skip_assertions) compile_dynamic_sendnn = ATTN_TYPE == "paged" if 
compile_dynamic_sendnn: - os.environ["VLLM_DT_MAX_CONTEXT_LEN"] = str((((max(common_seq_lengths) + max(common_max_new_tokens)) // 64) + 1) * 64) + os.environ["VLLM_DT_MAX_CONTEXT_LEN"] = str( + (((max(common_seq_lengths) + max(common_max_new_tokens)) // 64) + 1) * 64 + ) os.environ["VLLM_DT_MAX_BATCH_SIZE"] = str(max(common_batch_sizes)) common_shapes = list( @@ -272,11 +286,11 @@ def __find_eos_index(reference_tokens, eos_token_id, seq_length, max_new_tokens) return result -def __filter_before_eos(l, filter_indexes): +def __filter_before_eos(metrics, filter_indexes): from itertools import groupby filtered_results = [ - list(g)[: filter_indexes[k]] for k, g in groupby(l, key=lambda x: x[0]) + list(g)[: filter_indexes[k]] for k, g in groupby(metrics, key=lambda x: x[0]) ] return [item for sublist in filtered_results for item in sublist] @@ -302,8 +316,10 @@ def __load_validation_info( else: return None + class PersistentModel: """This class will either get a model that is pre-compiled (if compile_dynamic_sendnn) or re-create the model for each test""" + def __init__(self): self.model = None @@ -318,15 +334,17 @@ def get_or_create(self, is_gptq, **kwargs): self.__maybe_reset_model(model, is_gptq) model.eval() - model.compile(backend="sendnn", options={'sendnn.dynamic': compile_dynamic_sendnn}) + model.compile( + backend="sendnn", options={"sendnn.dynamic": compile_dynamic_sendnn} + ) if compile_dynamic_sendnn: self.model = model - + return model else: return self.model - + # TODO: This was added as we require a special reset for gptq models. Ideally, we would be able to do something like this reset when calling reset_parameters() on the model # however the gptq modules are yet to support this @staticmethod @@ -352,6 +370,7 @@ def __maybe_reset_model(model, is_gptq): res /= 20.0 param.copy_(res) + @pytest.fixture def persistent_model(): return PersistentModel() @@ -360,7 +379,9 @@ def persistent_model(): @pytest.mark.parametrize( "model_path,batch_size,seq_length,max_new_tokens", common_shapes ) -def test_common_shapes(model_path, batch_size, seq_length, max_new_tokens, persistent_model): +def test_common_shapes( + model_path, batch_size, seq_length, max_new_tokens, persistent_model +): torch.manual_seed(42) torch.set_grad_enabled(False) os.environ["COMPILATION_MODE"] = "offline_decoder" @@ -404,7 +425,9 @@ def test_common_shapes(model_path, batch_size, seq_length, max_new_tokens, persi tokenizer = tokenizers.get_tokenizer(model_path) # prepare the AIU model - model = persistent_model.get_or_create(is_gptq, **gptq_kwargs_aiu, **get_model_kwargs) + model = persistent_model.get_or_create( + is_gptq, **gptq_kwargs_aiu, **get_model_kwargs + ) # prepare the cpu model validation_model = get_model( @@ -425,7 +448,9 @@ def test_common_shapes(model_path, batch_size, seq_length, max_new_tokens, persi extra_kwargs["attn_name"] = ATTN_NAME # warmup aiu model - warmup_model(model, input_ids, max_new_tokens, compile_dynamic_sendnn, **extra_kwargs) + warmup_model( + model, input_ids, max_new_tokens, compile_dynamic_sendnn, **extra_kwargs + ) # generate cpu validation info cpu_validation_info = __load_validation_info( @@ -457,7 +482,12 @@ def test_common_shapes(model_path, batch_size, seq_length, max_new_tokens, persi # first test validation level 0 aiu_validation_info = extract_validation_information( - model, input_ids, max_new_tokens, None, only_last_token="paged" not in ATTN_NAME, **extra_kwargs + model, + input_ids, + max_new_tokens, + None, + only_last_token="paged" not in ATTN_NAME, + **extra_kwargs, 
) dprint("aiu validation info extracted for validation level 0") @@ -470,7 +500,6 @@ def test_common_shapes(model_path, batch_size, seq_length, max_new_tokens, persi # if level 0 fails validation, validate level 1 if FORCE_VALIDATION_LEVEL_1 or failed_validation_level_0: - if failed_validation_level_0: dprint("failed validation level 0, testing validation level 1") else: @@ -563,7 +592,10 @@ def _metric_calculator(r: torch.Tensor, t: torch.Tensor): # if we have a micro model with real weights, but no real thresholds, default to the full model thresholds if USE_MICRO_MODELS: ce_threshold, diff_threshold = fail_thresholds.get( - (model_path, True), fail_thresholds.get((model_path, False), default_metrics_threshold) + (model_path, True), + fail_thresholds.get( + (model_path, False), default_metrics_threshold + ), ) else: ce_threshold, diff_threshold = fail_thresholds.get( diff --git a/tests/models/test_encoders.py b/tests/models/test_encoders.py index 67a032c..f9f743a 100644 --- a/tests/models/test_encoders.py +++ b/tests/models/test_encoders.py @@ -1,4 +1,7 @@ -from fms.testing.comparison import ModelSignatureParams, compare_model_signatures, get_signature +from fms.testing.comparison import ( + ModelSignatureParams, + get_signature, +) from fms.utils import tokenizers import pytest from fms.models import get_model @@ -13,11 +16,17 @@ # Add models to test here ROBERTA_SQUAD_V2 = "deepset/roberta-base-squad2" -SQUAD_V2_DATASET_PATH = os.environ.get("SQUAD_V2_DATASET_PATH", os.path.expanduser("~/squad_v2")) -common_model_paths = os.environ.get("FMS_TEST_SHAPES_COMMON_MODEL_PATHS", [ROBERTA_SQUAD_V2]) +SQUAD_V2_DATASET_PATH = os.environ.get( + "SQUAD_V2_DATASET_PATH", os.path.expanduser("~/squad_v2") +) +common_model_paths = os.environ.get( + "FMS_TEST_SHAPES_COMMON_MODEL_PATHS", [ROBERTA_SQUAD_V2] +) common_batch_sizes = os.environ.get("FMS_TEST_SHAPES_COMMON_BATCH_SIZES", [1, 2, 4, 8]) common_seq_lengths = os.environ.get("FMS_TEST_SHAPES_COMMON_SEQ_LENGTHS", [64, 512]) -validation_diff_threshold = os.environ.get("FMS_TEST_SHAPES_VALIDATION_DIFF_THRESHOLD", .01) +validation_diff_threshold = os.environ.get( + "FMS_TEST_SHAPES_VALIDATION_DIFF_THRESHOLD", 0.01 +) # pass custom model path list for eg: EXPORT FMS_TESTING_COMMON_MODEL_PATHS="/tmp/models/roberta,/tmp/models/roberta-base-squad2" if isinstance(common_model_paths, str): @@ -36,18 +45,30 @@ if isinstance(validation_diff_threshold, str): validation_diff_threshold = float(validation_diff_threshold) -common_shapes = list(itertools.product(common_model_paths, common_batch_sizes, common_seq_lengths)) +common_shapes = list( + itertools.product(common_model_paths, common_batch_sizes, common_seq_lengths) +) def __prepare_inputs(batch_size, seq_length, tokenizer, seed=0): - prompts_and_sizes = sample_squad_v2_qa_requests(SQUAD_V2_DATASET_PATH, batch_size, tokenizer, int(seq_length / 2), seq_length, seed) + prompts_and_sizes = sample_squad_v2_qa_requests( + SQUAD_V2_DATASET_PATH, + batch_size, + tokenizer, + int(seq_length / 2), + seq_length, + seed, + ) prompt_list = [] for prompt, _ in prompts_and_sizes: prompt_list.append(ids_for_prompt(prompt, tokenizer)) - input_ids, padding_kwargs = pad_input_ids(prompt_list, min_pad_length=seq_length, is_causal_mask=False) + input_ids, padding_kwargs = pad_input_ids( + prompt_list, min_pad_length=seq_length, is_causal_mask=False + ) return input_ids, padding_kwargs + def __generate_diffs(model_params_1, model_params_2): model_params_1.model.eval() model_params_2.model.eval() @@ -57,7 +78,7 @@ def 
__generate_diffs(model_params_1, model_params_2): optional_params=model_params_1.other_params, logits_getter_fn=model_params_1.logits_getter_fn, inp=model_params_1.inp, - device=model_params_1.inp.device + device=model_params_1.inp.device, ) signature2 = get_signature( model_params_2.model, @@ -65,7 +86,7 @@ def __generate_diffs(model_params_1, model_params_2): optional_params=model_params_2.other_params, logits_getter_fn=model_params_2.logits_getter_fn, inp=model_params_2.inp, - device=model_params_2.inp.device + device=model_params_2.inp.device, ) signature = np.array(signature) @@ -73,21 +94,25 @@ def __generate_diffs(model_params_1, model_params_2): return np.mean(np.abs(signature2 - signature)) + @pytest.fixture(autouse=True) def reset_compiler(): - yield # run the test + yield # run the test torch.compiler.reset() torch._dynamo.reset() - os.environ.pop('COMPILATION_MODE', None) + os.environ.pop("COMPILATION_MODE", None) + @pytest.mark.parametrize("model_path,batch_size,seq_length", common_shapes) def test_common_shapes(model_path, batch_size, seq_length): os.environ["COMPILATION_MODE"] = "offline" - - dprint(f"testing model={model_path}, batch_size={batch_size}, seq_length={seq_length}") + + dprint( + f"testing model={model_path}, batch_size={batch_size}, seq_length={seq_length}" + ) tokenizer = tokenizers.get_tokenizer(model_path) - + if os.path.exists(model_path): model_path_kwargs = {"model_path": model_path} else: @@ -98,7 +123,7 @@ def test_common_shapes(model_path, batch_size, seq_length): architecture="hf_pretrained", device_type="cpu", fused_weights=False, - **model_path_kwargs + **model_path_kwargs, ) model.eval() @@ -111,34 +136,56 @@ def test_common_shapes(model_path, batch_size, seq_length): device_type="cpu", data_type=torch.float32, fused_weights=False, - **model_path_kwargs + **model_path_kwargs, ) # prepare input_ids input_ids, padding_kwargs = __prepare_inputs(batch_size, seq_length, tokenizer) # warmup model - logits_getter_fn = lambda x: x if isinstance(x, torch.Tensor) else torch.cat(list(x), dim=-1) - aiu_msp = ModelSignatureParams(model, ["x"], logits_getter_fn=logits_getter_fn, inp=input_ids, other_params=padding_kwargs) - get_signature(aiu_msp.model, aiu_msp.params, aiu_msp.inp, aiu_msp.other_params, aiu_msp.logits_getter_fn) + logits_getter_fn = ( # noqa: E731 + lambda x: x if isinstance(x, torch.Tensor) else torch.cat(list(x), dim=-1) + ) + aiu_msp = ModelSignatureParams( + model, + ["x"], + logits_getter_fn=logits_getter_fn, + inp=input_ids, + other_params=padding_kwargs, + ) + get_signature( + aiu_msp.model, + aiu_msp.params, + aiu_msp.inp, + aiu_msp.other_params, + aiu_msp.logits_getter_fn, + ) # get the average diff over multiple samples diffs = [] for i in range(20): # prepare input_ids - input_ids, padding_kwargs = __prepare_inputs(batch_size, seq_length, tokenizer, seed=i) + input_ids, padding_kwargs = __prepare_inputs( + batch_size, seq_length, tokenizer, seed=i + ) aiu_msp = ModelSignatureParams( - model, - ["x"], - logits_getter_fn=logits_getter_fn, - inp=input_ids, - other_params=padding_kwargs + model, + ["x"], + logits_getter_fn=logits_getter_fn, + inp=input_ids, + other_params=padding_kwargs, + ) + cpu_msp = ModelSignatureParams( + validation_model, + ["x"], + logits_getter_fn=logits_getter_fn, + inp=input_ids, + other_params=padding_kwargs, ) - cpu_msp = ModelSignatureParams(validation_model, ["x"], logits_getter_fn=logits_getter_fn, inp=input_ids, other_params=padding_kwargs) diffs.append(__generate_diffs(aiu_msp, cpu_msp)) abs_mean_diff = 
sum(diffs) / len(diffs) print(f"absolute mean diff: {abs_mean_diff}") - assert abs_mean_diff < validation_diff_threshold \ No newline at end of file + assert abs_mean_diff < validation_diff_threshold diff --git a/tests/models/test_model_expectations.py b/tests/models/test_model_expectations.py index 5cfcd57..28f0d46 100644 --- a/tests/models/test_model_expectations.py +++ b/tests/models/test_model_expectations.py @@ -1,8 +1,6 @@ from fms.models import get_model from fms.utils.generation import pad_input_ids -from fms.utils.tokenizers import get_tokenizer import pytest -from aiu_fms_testing_utils.utils import sample_squad_v2_qa_requests import torch from fms.testing._internal.model_test_suite import ( @@ -20,7 +18,12 @@ MISTRAL_7B_INSTRUCT = "mistralai/Mistral-7B-Instruct-v0.3" ROBERTA_SQUAD_v2 = "deepset/roberta-base-squad2" -micro_models = {LLAMA_3p1_8B_INSTRUCT, GRANITE_3p2_8B_INSTRUCT, GRANITE_GUARDIAN_3p1_8B, MISTRAL_7B_INSTRUCT} +micro_models = { + LLAMA_3p1_8B_INSTRUCT, + GRANITE_3p2_8B_INSTRUCT, + GRANITE_GUARDIAN_3p1_8B, + MISTRAL_7B_INSTRUCT, +} class AIUModelFixtureMixin(ModelFixtureMixin): @@ -51,7 +54,12 @@ def model(self, uninitialized_model): return uninitialized_model -decoder_models = [LLAMA_3p1_8B_INSTRUCT, GRANITE_3p2_8B_INSTRUCT, GRANITE_GUARDIAN_3p1_8B, MISTRAL_7B_INSTRUCT] +decoder_models = [ + LLAMA_3p1_8B_INSTRUCT, + GRANITE_3p2_8B_INSTRUCT, + GRANITE_GUARDIAN_3p1_8B, + MISTRAL_7B_INSTRUCT, +] class TestAIUDecoderModels( diff --git a/tests/models/test_scripts.py b/tests/models/test_scripts.py index e0f2ab4..01d57ab 100644 --- a/tests/models/test_scripts.py +++ b/tests/models/test_scripts.py @@ -1,10 +1,12 @@ -import pytest, os +import pytest +import os from subprocess import Popen, PIPE from pathlib import Path import itertools import math + FMS_DIR = Path(__file__).parent -AIU_FMS_DIR = os.path.join(FMS_DIR,"../../../aiu-fms-testing-utils/") +AIU_FMS_DIR = os.path.join(FMS_DIR, "../../../aiu-fms-testing-utils/") VALIDATION_FILE_PATH = os.path.join(AIU_FMS_DIR, "scripts", "validation.py") INFERENCE_FILE_PATH = os.path.join(AIU_FMS_DIR, "scripts", "inference.py") @@ -17,40 +19,68 @@ GRANITE_3_8B_CODE_BASE = f"{model_dir}/granite-3-8b-base" # pass custom model path list for eg: EXPORT FMS_TESTING_COMMON_MODEL_PATHS="/tmp/models/granite-3-8b-base,/tmp/models/granite-7b-base" -if os.environ.get("FMS_TESTING_COMMON_MODEL_PATHS") == None or os.environ.get("FMS_TESTING_COMMON_MODEL_PATHS") == "": +if ( + os.environ.get("FMS_TESTING_COMMON_MODEL_PATHS") is None + or os.environ.get("FMS_TESTING_COMMON_MODEL_PATHS") == "" +): common_model_paths = [LLAMA_194M] else: - common_model_paths = os.environ.get("FMS_TESTING_COMMON_MODEL_PATHS").split(',') + common_model_paths = os.environ.get("FMS_TESTING_COMMON_MODEL_PATHS").split(",") -common_batch_sizes = [1,8] +common_batch_sizes = [1, 8] common_seq_lengths = [64] common_max_new_tokens = [8] -common_params = list(itertools.product(common_model_paths, common_batch_sizes, common_seq_lengths, common_max_new_tokens)) +common_params = list( + itertools.product( + common_model_paths, + common_batch_sizes, + common_seq_lengths, + common_max_new_tokens, + ) +) common_asserts = [ - "### Response: Chicken soup is a popular soup that is", - "### Response: I am sorry, but I am not", - "### Response: I am ignorant of the fact that I", - "### Response: I have just come into a very large", - ] + "### Response: Chicken soup is a popular soup that is", + "### Response: I am sorry, but I am not", + "### Response: I am ignorant of the fact that I", + 
"### Response: I have just come into a very large", +] current_env = os.environ.copy() -current_env["DT_OPT"]="varsub=1,lxopt=1,opfusion=1,arithfold=1,dataopt=1,patchinit=1,patchprog=1,autopilot=1,weipreload=0,kvcacheopt=1,progshareopt=1" +current_env["DT_OPT"] = ( + "varsub=1,lxopt=1,opfusion=1,arithfold=1,dataopt=1,patchinit=1,patchprog=1,autopilot=1,weipreload=0,kvcacheopt=1,progshareopt=1" +) -def execute_script(execute_cmd): - current_env['MAX_SHAREDPROG_ITERS'] = f"{common_max_new_tokens[0]}" - with Popen(execute_cmd, stdin=PIPE, stdout=PIPE, stderr=PIPE, universal_newlines=True, env=current_env) as p: +def execute_script(execute_cmd): + current_env["MAX_SHAREDPROG_ITERS"] = f"{common_max_new_tokens[0]}" + + with Popen( + execute_cmd, + stdin=PIPE, + stdout=PIPE, + stderr=PIPE, + universal_newlines=True, + env=current_env, + ) as p: output, error = p.communicate() if p.returncode == 0: return output else: raise Exception(error) + # we are forcing the number of layers to be 2 to reduce the size of the model as we do not care about output, but just consistency between cpu and aiu -def execute_validation(validation_level, model_path, max_new_tokens, batch_size, seq_length, logits_loss_threshold=0.0): +def execute_validation( + validation_level, + model_path, + max_new_tokens, + batch_size, + seq_length, + logits_loss_threshold=0.0, +): execute_cmd = [ - 'python3', + "python3", VALIDATION_FILE_PATH, "--architecture=hf_pretrained", f"--model_path={model_path}", @@ -62,13 +92,14 @@ def execute_validation(validation_level, model_path, max_new_tokens, batch_size, "--no_early_termination", f"--validation_level={validation_level}", f"--logits_loss_threshold={logits_loss_threshold}", - "--compile_dynamic" + "--compile_dynamic", ] return execute_script(execute_cmd) + def execute_inference(model_path, max_new_tokens, batch_size, seq_length): execute_cmd = [ - 'python3', + "python3", INFERENCE_FILE_PATH, "--architecture=hf_pretrained", f"--model_path={model_path}", @@ -80,23 +111,24 @@ def execute_inference(model_path, max_new_tokens, batch_size, seq_length): "--no_early_termination", "--compile_dynamic", "--compile", - "--device_type=aiu" + "--device_type=aiu", ] return execute_script(execute_cmd) -@pytest.mark.parametrize("model_path,batch_size,seq_length,max_new_tokens", common_params) + +@pytest.mark.parametrize( + "model_path,batch_size,seq_length,max_new_tokens", common_params +) def test_level_1_validation_script(model_path, batch_size, seq_length, max_new_tokens): result_text = execute_validation( - 1, - model_path, - max_new_tokens, - batch_size, - seq_length, - 64.0 + 1, model_path, max_new_tokens, batch_size, seq_length, 64.0 ) assert "The validation has passed!" in result_text -@pytest.mark.parametrize("model_path,batch_size,seq_length,max_new_tokens", common_params) + +@pytest.mark.parametrize( + "model_path,batch_size,seq_length,max_new_tokens", common_params +) def test_level_0_validation_script(model_path, batch_size, seq_length, max_new_tokens): result_text = execute_validation( 0, @@ -107,6 +139,7 @@ def test_level_0_validation_script(model_path, batch_size, seq_length, max_new_t ) assert "The validation has passed!" 
in result_text + common_asserts = [ "### Response: Chicken soup is a popular soup that is", "### Response: I am sorry, but I am not", @@ -114,18 +147,25 @@ def test_level_0_validation_script(model_path, batch_size, seq_length, max_new_t "### Response: I have just come into a very large", ] + def __repeat_batch_asserts(bs: int) -> list[str]: n_repeats = int(math.ceil(bs / len(common_asserts))) return (common_asserts * n_repeats)[:bs] + # add the asserts based on batch size # for batches greater than common_asserts, repeat common_asserts since this follows inference behavior -common_inference_params = [common_param + (__repeat_batch_asserts(common_param[1]),) for common_param in common_params] +common_inference_params = [ + common_param + (__repeat_batch_asserts(common_param[1]),) + for common_param in common_params +] -@pytest.mark.parametrize("model_path,batch_size,seq_length,max_new_tokens,asserts", common_inference_params) +@pytest.mark.parametrize( + "model_path,batch_size,seq_length,max_new_tokens,asserts", common_inference_params +) def test_inference_script(model_path, max_new_tokens, seq_length, batch_size, asserts): result_text = execute_inference(model_path, max_new_tokens, batch_size, seq_length) for common_assert in asserts: - assert common_assert in result_text \ No newline at end of file + assert common_assert in result_text diff --git a/tests/resources/get_thresholds.py b/tests/resources/get_thresholds.py index 7dedb70..b4dee45 100644 --- a/tests/resources/get_thresholds.py +++ b/tests/resources/get_thresholds.py @@ -2,25 +2,22 @@ import os import numpy as np import argparse -import os -parser = argparse.ArgumentParser( - description="Script to get thresholds metrics" -) +parser = argparse.ArgumentParser(description="Script to get thresholds metrics") parser.add_argument( "--models", type=str, default=[], - nargs='+', + nargs="+", required=True, - help="List of models id separated by space. Eg.: ibm-granite/granite-20b-code-instruct-8k /tmp/models/granite-20b-code-cobol-v1" + help="List of models id separated by space. Eg.: ibm-granite/granite-20b-code-instruct-8k /tmp/models/granite-20b-code-cobol-v1", ) parser.add_argument( "--metrics", type=str, default=[], - nargs='+', + nargs="+", required=True, help="List of metrics separated by space. 
Eg.: diff_mean ce", ) @@ -43,7 +40,6 @@ metric_list = [] for metric_file in metric_files: - with open(metric_file, "r") as file: next(file) for line in file: diff --git a/tests/testing/test_validation.py b/tests/testing/test_validation.py index 047590e..90cf2e7 100644 --- a/tests/testing/test_validation.py +++ b/tests/testing/test_validation.py @@ -1,11 +1,19 @@ import tempfile import pytest -from aiu_fms_testing_utils.testing.validation import LogitsExtractorHook, extract_validation_information, load_validation_information +from aiu_fms_testing_utils.testing.validation import ( + LogitsExtractorHook, + extract_validation_information, + load_validation_information, +) from fms.models import get_model from fms.utils.generation import pad_input_ids import torch -@pytest.mark.parametrize("validation_type,post_iteration_hook", [("logits", LogitsExtractorHook()), ("tokens", None)]) + +@pytest.mark.parametrize( + "validation_type,post_iteration_hook", + [("logits", LogitsExtractorHook()), ("tokens", None)], +) def test_validation_info_round_trip(validation_type, post_iteration_hook): # prepare a small cpu model model = get_model( @@ -22,7 +30,11 @@ def test_validation_info_round_trip(validation_type, post_iteration_hook): # prepare input_ids prompt_list = [] for i in range(batch_size): - prompt_list.append(torch.randint(0, model.config.src_vocab_size, (seq_length - 2 * i,), dtype=torch.long)) + prompt_list.append( + torch.randint( + 0, model.config.src_vocab_size, (seq_length - 2 * i,), dtype=torch.long + ) + ) input_ids, padding_kwargs = pad_input_ids(prompt_list, min_pad_length=seq_length) @@ -33,14 +45,16 @@ def test_validation_info_round_trip(validation_type, post_iteration_hook): max_new_tokens, post_iteration_hook, attn_algorithm="math", - **padding_kwargs + **padding_kwargs, ) with tempfile.TemporaryDirectory() as workdir: output_path = f"{workdir}/validation_info" generated_validation_info.save(output_path) - loaded_validation_info = load_validation_information(output_path, validation_type, batch_size) + loaded_validation_info = load_validation_information( + output_path, validation_type, batch_size + ) assert len(generated_validation_info) == len(loaded_validation_info) diff --git a/tests/utils/test_paged.py b/tests/utils/test_paged.py index 519042a..9d5281c 100644 --- a/tests/utils/test_paged.py +++ b/tests/utils/test_paged.py @@ -1,12 +1,9 @@ import torch from fms.models import get_model -from fms.utils.generation import ( - pad_input_ids, - generate -) +from fms.utils.generation import pad_input_ids, generate from aiu_fms_testing_utils.utils.paged import generate as paged_generate from fms.utils.tokenizers import get_tokenizer -import pytest + def test_paged_equivalence(): torch.manual_seed(0)