From f37b5c2056122d9b214520f812e8b8f76d692363 Mon Sep 17 00:00:00 2001
From: Joshua Rosenkranz
Date: Tue, 1 Apr 2025 14:29:15 +0000
Subject: [PATCH 1/2] added fixes for handling multiple shape warmup with
 dynamic shapes; added tests for multiple shape warmup

Signed-off-by: Joshua Rosenkranz
---
 aiu_fms_testing_utils/testing/validation.py |  3 +-
 aiu_fms_testing_utils/utils/__init__.py     | 26 +++++++-
 tests/models/test_decoders.py               | 70 ++++++++++++++++++++-
 3 files changed, 96 insertions(+), 3 deletions(-)

diff --git a/aiu_fms_testing_utils/testing/validation.py b/aiu_fms_testing_utils/testing/validation.py
index 6c3d6300..be7eb94b 100644
--- a/aiu_fms_testing_utils/testing/validation.py
+++ b/aiu_fms_testing_utils/testing/validation.py
@@ -3,7 +3,7 @@
 import torch
 from fms.utils.generation import generate
-from aiu_fms_testing_utils.utils import ids_for_prompt
+from aiu_fms_testing_utils.utils import ids_for_prompt, _prepare_model_inputs_hook
 from aiu_fms_testing_utils.utils.aiu_setup import dprint
 import os
@@ -206,6 +206,7 @@ def extract_validation_information(model, input_ids, max_new_tokens, post_iterat
         timing=timing,
         contiguous_cache=True,
         extra_kwargs=extra_generation_kwargs,
+        prepare_model_inputs_hook=_prepare_model_inputs_hook
     )

     if timing != "":
diff --git a/aiu_fms_testing_utils/utils/__init__.py b/aiu_fms_testing_utils/utils/__init__.py
index 1bc76acf..83bf58e3 100644
--- a/aiu_fms_testing_utils/utils/__init__.py
+++ b/aiu_fms_testing_utils/utils/__init__.py
@@ -10,12 +10,36 @@
 import json
 import random

+def _prepare_model_inputs_hook(i, input_ids, kwargs):
+    """To produce consistent graphs across warmup shapes, mark the prefill batch and sequence dims as static, but relax the sequence dim for decode."""
+    if i == 0:
+        # we always want prefill to be static so it produces the same graph every time
+        torch._dynamo.mark_static(input_ids, 0)
+        torch._dynamo.mark_static(input_ids, 1)
+        torch._dynamo.mark_static(kwargs["mask"], 0)
+        torch._dynamo.mark_static(kwargs["mask"], 1)
+        torch._dynamo.mark_static(kwargs["mask"], 2)
+        torch._dynamo.mark_static(kwargs["position_ids"], 0)
+        torch._dynamo.mark_static(kwargs["position_ids"], 1)
+    else:
+        # we always want the decode to be dynamic on sequence
+        torch._dynamo.mark_dynamic(input_ids, 1)
+        torch._dynamo.mark_dynamic(kwargs["mask"], 1)
+        torch._dynamo.mark_dynamic(kwargs["mask"], 2)
+
+    for layer in kwargs["past_key_value_states"]:
+        for tensor in layer:
+            torch._dynamo.mark_static(tensor, 0)
+
+    return input_ids, kwargs
+
+
 def warmup_model(model: nn.Module, input_ids: torch.Tensor, max_new_tokens: int, **padding_kwargs):
     from torch_sendnn import torch_sendnn
     dprint("AIU warmup")
     pt_compile_model_time = time.time()
     extra_kwargs = {**padding_kwargs, "only_last_token": True}
-    generate(model, input_ids, max_new_tokens=max_new_tokens, max_seq_len=model.config.max_expected_seq_len, use_cache=True, do_sample=False, contiguous_cache=True, extra_kwargs=extra_kwargs)
+    generate(model, input_ids, max_new_tokens=max_new_tokens, max_seq_len=model.config.max_expected_seq_len, use_cache=True, do_sample=False, contiguous_cache=True, extra_kwargs=extra_kwargs, prepare_model_inputs_hook=_prepare_model_inputs_hook)
     pt_compile_model_time = time.time() - pt_compile_model_time
     dprint(f"PT compile complete, took {pt_compile_model_time:.3f}s")
diff --git a/tests/models/test_decoders.py b/tests/models/test_decoders.py
index d1e03c33..9a76cee9 100644
--- a/tests/models/test_decoders.py
+++ b/tests/models/test_decoders.py
@@ -5,7 +5,7 @@
 import itertools
 import torch
 from aiu_fms_testing_utils.testing.validation import extract_validation_information, LogitsExtractorHook, GoldenTokenHook, capture_level_1_metrics, filter_failed_level_1_cases, load_validation_information, validate_level_0, top_k_loss_calculator
-from aiu_fms_testing_utils.utils import warmup_model, sample_sharegpt_requests, ids_for_prompt
+from aiu_fms_testing_utils.utils import warmup_model, sample_sharegpt_requests, ids_for_prompt, _prepare_model_inputs_hook
 from aiu_fms_testing_utils.utils.aiu_setup import dprint
 import os
@@ -275,5 +275,73 @@ def _metric_calculator(r: torch.Tensor, t: torch.Tensor):
     else:
         print("passed validation level 0")

+def test_warmup_multiple_shapes():
+    shapes = [
+        (1, 64, 8),
+        (2, 64, 8),
+        (1, 128, 24),
+    ]
+    reference_model = get_model(
+        architecture="hf_configured",
+        variant=GRANITE_3p2_8B_INSTRUCT,
+        device_type="cpu",
+        fused_weights=False,
+        nlayers=3
+    )
+
+    model = get_model(
+        architecture="hf_configured",
+        variant=GRANITE_3p2_8B_INSTRUCT,
+        device_type="cpu",
+        fused_weights=False,
+        nlayers=3
+    )
+
+    reference_model.load_state_dict(model.state_dict())
+
+    model.eval()
+    reference_model.eval()
+
+    torch.set_grad_enabled(False)
+    model.compile(backend="sendnn_decoder")
+    for bs, sl, mnt in shapes:
+        # prepare input_ids
+        prompt_list = []
+        for i in range(bs):
+            prompt_list.append(torch.randint(0, model.config.src_vocab_size, (sl - 2 * i,), dtype=torch.long))
+
+        input_ids, padding_kwargs = pad_input_ids(prompt_list, min_pad_length=sl)
+        # warmup aiu model
+        warmup_model(model, input_ids, mnt, **padding_kwargs)
+
+    # perform 3 rounds of inference, making sure ordering does not affect things
+    for _ in range(3):
+        shapes.reverse()
+        for bs, sl, mnt in shapes:
+            prompt_list = []
+            for i in range(bs):
+                prompt_list.append(torch.randint(0, model.config.src_vocab_size, (sl - 2 * i,), dtype=torch.long))
+            input_ids, padding_kwargs = pad_input_ids(prompt_list, min_pad_length=sl)
+
+            cpu_validation_info = extract_validation_information(
+                reference_model,
+                input_ids,
+                mnt,
+                LogitsExtractorHook(),
+                attn_algorithm="math",
+                **padding_kwargs
+            )
+
+            aiu_validation_info = extract_validation_information(
+                model,
+                input_ids,
+                mnt,
+                None,
+                only_last_token=True,
+                **padding_kwargs
+            )
+
+            failed_responses = validate_level_0(aiu_validation_info.get_info("tokens"), cpu_validation_info.get_info("tokens"))
+
+            assert len(failed_responses) == 0

From 4648b241d134575b4b508f3ce4b1a6fb7c06387e Mon Sep 17 00:00:00 2001
From: Joshua Rosenkranz
Date: Wed, 2 Apr 2025 00:23:05 +0000
Subject: [PATCH 2/2] added encoder warmup test; added logits validation for
 decoder warmup

Signed-off-by: Joshua Rosenkranz
---
 aiu_fms_testing_utils/utils/__init__.py |   7 +-
 tests/models/test_decoders.py           | 115 ++++++++++++++++--------
 tests/models/test_encoders.py           |  76 +++++++++++++++-
 3 files changed, 156 insertions(+), 42 deletions(-)

diff --git a/aiu_fms_testing_utils/utils/__init__.py b/aiu_fms_testing_utils/utils/__init__.py
index 83bf58e3..f5307418 100644
--- a/aiu_fms_testing_utils/utils/__init__.py
+++ b/aiu_fms_testing_utils/utils/__init__.py
@@ -23,9 +23,10 @@ def _prepare_model_inputs_hook(i, input_ids, kwargs):
         torch._dynamo.mark_static(kwargs["position_ids"], 1)
     else:
         # we always want the decode to be dynamic on sequence
-        torch._dynamo.mark_dynamic(input_ids, 1)
-        torch._dynamo.mark_dynamic(kwargs["mask"], 1)
-        torch._dynamo.mark_dynamic(kwargs["mask"], 2)
+        if torch._dynamo.config.dynamic_shapes:
+            torch._dynamo.mark_dynamic(input_ids, 1)
+            torch._dynamo.mark_dynamic(kwargs["mask"], 1)
+            torch._dynamo.mark_dynamic(kwargs["mask"], 2)

     for layer in kwargs["past_key_value_states"]:
         for tensor in layer:
diff --git a/tests/models/test_decoders.py b/tests/models/test_decoders.py
index 9a76cee9..c155a564 100644
--- a/tests/models/test_decoders.py
+++ b/tests/models/test_decoders.py
@@ -275,13 +275,16 @@ def _metric_calculator(r: torch.Tensor, t: torch.Tensor):
     else:
         print("passed validation level 0")

-def test_warmup_multiple_shapes():
+@pytest.mark.parametrize("use_static_shapes", [False, True], ids=["dynamic_shapes", "static_shapes"])
+def test_warmup_multiple_shapes(use_static_shapes):
     shapes = [
-        (1, 64, 8),
-        (2, 64, 8),
-        (1, 128, 24),
+        (1, 64, 4),
+        (2, 64, 4),
+        (1, 128, 8),
     ]
+
+    tokenizer = tokenizers.get_tokenizer(GRANITE_3p2_8B_INSTRUCT)
+
     reference_model = get_model(
         architecture="hf_configured",
         variant=GRANITE_3p2_8B_INSTRUCT,
         device_type="cpu",
         fused_weights=False,
@@ -305,43 +308,79 @@ def _metric_calculator(r: torch.Tensor, t: torch.Tensor):
     torch.set_grad_enabled(False)
     model.compile(backend="sendnn_decoder")
-    for bs, sl, mnt in shapes:
-        # prepare input_ids
-        prompt_list = []
-        for i in range(bs):
-            prompt_list.append(torch.randint(0, model.config.src_vocab_size, (sl - 2 * i,), dtype=torch.long))
-
-        input_ids, padding_kwargs = pad_input_ids(prompt_list, min_pad_length=sl)
-        # warmup aiu model
-        warmup_model(model, input_ids, mnt, **padding_kwargs)
+
+    if use_static_shapes:
+        ctx = torch._dynamo.config.patch(
+            assume_static_by_default=True,
+            dynamic_shapes=False,
+            automatic_dynamic_shapes=False,
+            cache_size_limit=1000,
+        )
+    else:
+        ctx = torch._dynamo.config.patch(
+            dynamic_shapes=True,
+        )
+
+    # metric calculator based on the cross-entropy and mean diff for each decode step
+    def _metric_calculator(r: torch.Tensor, t: torch.Tensor):
+        cross_entropy = torch.nn.CrossEntropyLoss()(r, t.softmax(dim=1).to(dtype=torch.float32))
+        diff = torch.mean(r.softmax(dim=1).to(dtype=torch.float32) - t.softmax(dim=1).to(dtype=torch.float32))
+        return (cross_entropy, diff)

-    # perform 3 rounds of inference, making sure ordering does not affect things
-    for _ in range(3):
-        shapes.reverse()
+    with ctx:
+        for bs, sl, mnt in shapes:
+            # prepare input_ids
             prompt_list = []
             for i in range(bs):
                 prompt_list.append(torch.randint(0, model.config.src_vocab_size, (sl - 2 * i,), dtype=torch.long))
-            input_ids, padding_kwargs = pad_input_ids(prompt_list, min_pad_length=sl)
-
-            cpu_validation_info = extract_validation_information(
-                reference_model,
-                input_ids,
-                mnt,
-                LogitsExtractorHook(),
-                attn_algorithm="math",
-                **padding_kwargs
-            )
+            input_ids, padding_kwargs = pad_input_ids(prompt_list, min_pad_length=sl)
+            # warmup aiu model
+            warmup_model(model, input_ids, mnt, **padding_kwargs)
+
+            # perform 3 rounds of inference, making sure ordering does not affect things
+            for _ in range(3):
+                shapes.reverse()
+                for bs, sl, mnt in shapes:
+                    prompt_list = []
+                    for i in range(bs):
+                        prompt_list.append(torch.randint(0, model.config.src_vocab_size, (sl - 2 * i,), dtype=torch.long))
+                    input_ids, padding_kwargs = pad_input_ids(prompt_list, min_pad_length=sl)
+
+                    cpu_validation_info = extract_validation_information(
+                        reference_model,
+                        input_ids,
+                        mnt,
+                        LogitsExtractorHook(),
+                        attn_algorithm="math",
+                        **padding_kwargs
+                    )
+                    cpu_static_tokens = cpu_validation_info.get_info("tokens")
+                    eos_indexes = __find_eos_index(cpu_static_tokens, tokenizer.eos_token_id, sl, mnt)
+
+                    aiu_validation_info = extract_validation_information(
+                        model,
+                        input_ids,
+                        mnt,
+                        GoldenTokenHook(cpu_static_tokens),
+                        only_last_token=True,
+                        **padding_kwargs
+                    )
+
+                    level_1_metrics = capture_level_1_metrics(
+                        cpu_validation_info.get_info("logits"),
+                        aiu_validation_info.get_info("logits"),
+                        top_k_loss_calculator(20, _metric_calculator)
+                    )
+
+                    # only consider those metrics captured prior to the eos
+                    level_1_metrics = __filter_before_eos(level_1_metrics, eos_indexes)
+
+                    ce_threshold, diff_thresholds = fail_thresholds.get(GRANITE_3p2_8B_INSTRUCT, default_metrics_threshold)
+
+                    # get all failed responses for each metric
+                    ce_fail_responses = filter_failed_level_1_cases(level_1_metrics, lambda m: m[0] >= ce_threshold)
+                    diff_fail_responses = filter_failed_level_1_cases(level_1_metrics, lambda m: m[1] <= diff_thresholds[0] or m[1] >= diff_thresholds[1])
+                    assert len(ce_fail_responses) == 0
+                    assert len(diff_fail_responses) == 0
diff --git a/tests/models/test_encoders.py b/tests/models/test_encoders.py
index fb045e93..36d62e9f 100644
--- a/tests/models/test_encoders.py
+++ b/tests/models/test_encoders.py
@@ -104,4 +104,78 @@ def test_common_shapes(model_path, batch_size, seq_length):
     cpu_msp = ModelSignatureParams(validation_model, ["x"], logits_getter_fn=logits_getter_fn, inp=input_ids, other_params=padding_kwargs)

     # FIXME: Compute GPU atol/rtol
-    compare_model_signatures(cpu_msp, aiu_msp, atol=0.1, rtol=.05)
\ No newline at end of file
+    compare_model_signatures(cpu_msp, aiu_msp, atol=0.1, rtol=.05)
+
+def test_warmup_multiple_shapes():
+    os.environ["COMPILATION_MODE"] = "offline"
+
+    if "HF_HOME" not in os.environ:
+        os.environ["HF_HOME"] = "/tmp/models/hf_cache"
+
+    model_path = "deepset/roberta-base-squad2"
+    tokenizer = tokenizers.get_tokenizer(model_path)
+
+    if os.path.exists(model_path):
+        model_path_kwargs = {"model_path": model_path}
+    else:
+        model_path_kwargs = {"variant": model_path}
+
+    shapes = [
+        (1, 64),
+        (2, 64),
+        (1, 128),
+    ]
+
+    # prepare the AIU model
+    model = get_model(
+        architecture="hf_pretrained",
+        device_type="cpu",
+        fused_weights=False,
+        **model_path_kwargs
+    )
+
+    model.eval()
+    torch.set_grad_enabled(False)
+    model.compile(backend="sendnn")
+
+    # prepare the CPU model
+    reference_model = get_model(
+        architecture="hf_pretrained",
+        device_type="cpu",
+        data_type=torch.float32,
+        fused_weights=False,
+        **model_path_kwargs
+    )
+
+    # encoders should use static shapes
+    with torch._dynamo.config.patch(
+        assume_static_by_default=True,
+        dynamic_shapes=False,
+        automatic_dynamic_shapes=False,
+        cache_size_limit=1000,
+    ):
+
+        for batch_size, seq_length in shapes:
+
+            # prepare input_ids
+            input_ids, padding_kwargs = __prepare_inputs(batch_size, seq_length, tokenizer)
+
+            # warmup model
+            logits_getter_fn = lambda x: x if isinstance(x, torch.Tensor) else torch.cat(list(x), dim=-1)
+            aiu_msp = ModelSignatureParams(model, ["x"], logits_getter_fn=logits_getter_fn, inp=input_ids, other_params=padding_kwargs)
+            get_signature(aiu_msp.model, aiu_msp.params, aiu_msp.inp, aiu_msp.other_params, aiu_msp.logits_getter_fn)
+
+        for _ in range(3):
+            shapes.reverse()
+
+            for batch_size, seq_length in shapes:
+                input_ids, padding_kwargs = __prepare_inputs(batch_size, seq_length, tokenizer)  # regenerate inputs for this shape
+                aiu_msp = ModelSignatureParams(model, ["x"], logits_getter_fn=logits_getter_fn, inp=input_ids, other_params=padding_kwargs)
+                get_signature(aiu_msp.model, aiu_msp.params, aiu_msp.inp, aiu_msp.other_params, aiu_msp.logits_getter_fn)
+
+                cpu_msp = ModelSignatureParams(reference_model, ["x"], logits_getter_fn=logits_getter_fn, inp=input_ids, other_params=padding_kwargs)
+                compare_model_signatures(cpu_msp, aiu_msp, atol=0.1, rtol=.05)
+
+
+
\ No newline at end of file
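
Illustrative sketch (not part of the patch series above): the idea behind _prepare_model_inputs_hook is to tag tensor dimensions for torch.compile before each generation step, so prefill always traces one fully static graph while decode keeps its sequence dimension dynamic and can serve several warmup shapes without recompiling. The snippet below assumes only stock PyTorch 2.x APIs (torch.compile, torch._dynamo.mark_static, torch._dynamo.mark_dynamic); toy_step, prepare_inputs, and the eager backend are hypothetical stand-ins, not the repo's model, hook signature, or the sendnn backend.

import torch

def toy_step(x: torch.Tensor) -> torch.Tensor:
    # stand-in for a single forward step of a decoder (hypothetical)
    return x.cumsum(dim=1)

# the "eager" backend keeps the sketch runnable without any accelerator stack
compiled_step = torch.compile(toy_step, backend="eager")

def prepare_inputs(i: int, x: torch.Tensor) -> torch.Tensor:
    # stand-in for the per-step hook: step 0 is prefill, later steps are decode
    if i == 0:
        # prefill: keep batch and sequence static so one fixed prefill graph is traced
        torch._dynamo.mark_static(x, 0)
        torch._dynamo.mark_static(x, 1)
    else:
        # decode: let the sequence dimension vary so later lengths reuse the same graph
        torch._dynamo.mark_dynamic(x, 1)
    return x

for i, seq_len in enumerate([64, 65, 66]):
    x = torch.randn(1, seq_len)
    compiled_step(prepare_inputs(i, x))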