From dcc95c02ec52956ca68e90f66785f886d76fda14 Mon Sep 17 00:00:00 2001
From: Antoni Viros i Martin
Date: Thu, 3 Jul 2025 20:55:31 +0000
Subject: [PATCH 1/5] Fix max length of generation being cut to 4096

Signed-off-by: Antoni Viros i Martin
---
 aiu_fms_testing_utils/utils/paged.py | 1 +
 scripts/inference.py                 | 1 +
 2 files changed, 2 insertions(+)

diff --git a/aiu_fms_testing_utils/utils/paged.py b/aiu_fms_testing_utils/utils/paged.py
index 771eb01a..78ce4a81 100644
--- a/aiu_fms_testing_utils/utils/paged.py
+++ b/aiu_fms_testing_utils/utils/paged.py
@@ -27,6 +27,7 @@ def adjust_inputs_to_batch(input_ids: torch.Tensor, **extra_kwargs):
 def generate(
     model: Union[Callable, torch.nn.Module],
     input_ids: torch.Tensor,
+    max_seq_len: int = 4096,
     max_new_tokens: int = 256,
     temperature: float = 1.0,
     top_k: int = 10,
diff --git a/scripts/inference.py b/scripts/inference.py
index 28c62aa8..bd9ae937 100644
--- a/scripts/inference.py
+++ b/scripts/inference.py
@@ -717,6 +717,7 @@ def infer(use_cache, do_sample, warmup):
     result = generate(
         model,
         ids,
+        max_seq_len=ids.shape[1] + args.max_new_tokens,
         max_new_tokens=args.max_new_tokens,
         use_cache=use_cache,
         do_sample=do_sample,

From 5e0a4a3b56ee0e6f2fae7e816eb40d8391643fed Mon Sep 17 00:00:00 2001
From: Antoni Viros i Martin
Date: Thu, 3 Jul 2025 21:01:09 +0000
Subject: [PATCH 2/5] Correct max_seq_len across aftu

Signed-off-by: Antoni Viros i Martin
---
 aiu_fms_testing_utils/testing/validation.py | 1 +
 aiu_fms_testing_utils/utils/__init__.py     | 1 +
 tests/utils/test_paged.py                   | 2 ++
 3 files changed, 4 insertions(+)

diff --git a/aiu_fms_testing_utils/testing/validation.py b/aiu_fms_testing_utils/testing/validation.py
index 69b6a9b4..ba8e718b 100644
--- a/aiu_fms_testing_utils/testing/validation.py
+++ b/aiu_fms_testing_utils/testing/validation.py
@@ -208,6 +208,7 @@ def extract_validation_information(model, input_ids, max_new_tokens, post_iterat
     result = generate(
         model,
         input_ids,
+        max_seq_len=input_ids.shape[1] + max_new_tokens,
         max_new_tokens=max_new_tokens,
         use_cache=True,
         do_sample=False,
diff --git a/aiu_fms_testing_utils/utils/__init__.py b/aiu_fms_testing_utils/utils/__init__.py
index 92997067..222a3b63 100644
--- a/aiu_fms_testing_utils/utils/__init__.py
+++ b/aiu_fms_testing_utils/utils/__init__.py
@@ -53,6 +53,7 @@ def warmup_model(
         generate(
             model,
             _warmup_input_ids,
+            max_seq_len=_warmup_input_ids.shape[1] + max_new_tokens,
             max_new_tokens=_max_new_tokens,
             do_sample=False,
             use_cache=use_cache,
diff --git a/tests/utils/test_paged.py b/tests/utils/test_paged.py
index 519042af..106a29d9 100644
--- a/tests/utils/test_paged.py
+++ b/tests/utils/test_paged.py
@@ -29,6 +29,7 @@ def test_paged_equivalence():
     result = generate(
         _model_mock,
         ids,
+        max_seq_len=ids.shape[1] + 5,
         max_new_tokens=5,
         do_sample=False,
         use_cache=True,
@@ -38,6 +39,7 @@ def test_paged_equivalence():
     result_paged = paged_generate(
         _model_mock,
         ids,
+        max_seq_len=ids.shape[1] + 5,
         max_new_tokens=5,
         do_sample=False,
         use_cache=True,

From bf06dba191f1c2ce34fb57442a26b263ee2ea664 Mon Sep 17 00:00:00 2001
From: Antoni Viros i Martin
Date: Fri, 11 Jul 2025 22:15:06 +0000
Subject: [PATCH 3/5] fix signatures

Signed-off-by: Antoni Viros i Martin
---
 aiu_fms_testing_utils/utils/__init__.py | 2 +-
 aiu_fms_testing_utils/utils/paged.py    | 2 --
 2 files changed, 1 insertion(+), 3 deletions(-)

diff --git a/aiu_fms_testing_utils/utils/__init__.py b/aiu_fms_testing_utils/utils/__init__.py
index 222a3b63..9cc4dc37 100644
--- a/aiu_fms_testing_utils/utils/__init__.py
+++ b/aiu_fms_testing_utils/utils/__init__.py
@@ -30,6 +30,7 @@ def warmup_model(
         # TODO: Add a unified generation dependent on attn_type
         from fms.utils.generation import generate
         attention_specific_kwargs["contiguous_cache"] = True
+        attention_specific_kwargs["max_seq_len"] = input_ids.shape[1] + max_new_tokens
 
     dprint("AIU warmup")
     pt_compile_model_time = time.time()
@@ -53,7 +54,6 @@ def warmup_model(
         generate(
             model,
             _warmup_input_ids,
-            max_seq_len=_warmup_input_ids.shape[1] + max_new_tokens,
             max_new_tokens=_max_new_tokens,
             do_sample=False,
             use_cache=use_cache,
diff --git a/aiu_fms_testing_utils/utils/paged.py b/aiu_fms_testing_utils/utils/paged.py
index 78ce4a81..5ad16ca6 100644
--- a/aiu_fms_testing_utils/utils/paged.py
+++ b/aiu_fms_testing_utils/utils/paged.py
@@ -27,7 +27,6 @@ def adjust_inputs_to_batch(input_ids: torch.Tensor, **extra_kwargs):
 def generate(
     model: Union[Callable, torch.nn.Module],
     input_ids: torch.Tensor,
-    max_seq_len: int = 4096,
     max_new_tokens: int = 256,
     temperature: float = 1.0,
     top_k: int = 10,
@@ -55,7 +54,6 @@ def generate(
         model: A function or nn.Module that takes a batch of input_ids and
             returns logits
         input_ids: a rectangular tensor of input_ids (batch x seq)
-        max_seq_len: the sequence length of the model
         max_new_tokens: max tokens to generate
         temperature: temperature of softmax when sampling
         top_k: only search among top k tokens

From a6abeffea0cdf2b28ee4cbb93cb30aff2d62c9ae Mon Sep 17 00:00:00 2001
From: Antoni Viros i Martin
Date: Fri, 11 Jul 2025 22:17:32 +0000
Subject: [PATCH 4/5] more fixes

Signed-off-by: Antoni Viros i Martin
---
 scripts/inference.py      | 2 +-
 tests/utils/test_paged.py | 1 -
 2 files changed, 1 insertion(+), 2 deletions(-)

diff --git a/scripts/inference.py b/scripts/inference.py
index bd9ae937..1a8f6a51 100644
--- a/scripts/inference.py
+++ b/scripts/inference.py
@@ -713,11 +713,11 @@ def infer(use_cache, do_sample, warmup):
     attention_specific_kwargs = {}
     if attn_name == "sdpa_causal":
         attention_specific_kwargs["contiguous_cache"] = True
+        attention_specific_kwargs["max_seq_len"] = ids.shape[1] + args.max_new_tokens
 
     result = generate(
         model,
         ids,
-        max_seq_len=ids.shape[1] + args.max_new_tokens,
         max_new_tokens=args.max_new_tokens,
         use_cache=use_cache,
         do_sample=do_sample,
diff --git a/tests/utils/test_paged.py b/tests/utils/test_paged.py
index 106a29d9..ef79e597 100644
--- a/tests/utils/test_paged.py
+++ b/tests/utils/test_paged.py
@@ -39,7 +39,6 @@ def test_paged_equivalence():
     result_paged = paged_generate(
         _model_mock,
         ids,
-        max_seq_len=ids.shape[1] + 5,
         max_new_tokens=5,
         do_sample=False,
         use_cache=True,

From b83650f0b44c6a1cc32c1e66851e70de09059c73 Mon Sep 17 00:00:00 2001
From: Antoni Viros i Martin
Date: Fri, 11 Jul 2025 22:19:48 +0000
Subject: [PATCH 5/5] One last fix?

Signed-off-by: Antoni Viros i Martin
---
 aiu_fms_testing_utils/testing/validation.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/aiu_fms_testing_utils/testing/validation.py b/aiu_fms_testing_utils/testing/validation.py
index ba8e718b..65c1d1bb 100644
--- a/aiu_fms_testing_utils/testing/validation.py
+++ b/aiu_fms_testing_utils/testing/validation.py
@@ -188,7 +188,6 @@ def load_validation_information(validation_path, validation_files_type, batch_si
     return ValidationInfo(validation_info)
 
 def extract_validation_information(model, input_ids, max_new_tokens, post_iteration_hook, attn_algorithm=None, eos_token_id = None, only_last_token=False, timing="", **extra_kwargs):
-    max_seq_len = model.config.max_expected_seq_len
     attention_specific_kwargs = {}
     if "paged" in extra_kwargs["attn_name"]:
         from aiu_fms_testing_utils.utils.paged import generate
@@ -196,7 +195,7 @@ def extract_validation_information(model, input_ids, max_new_tokens, post_iterat
         # TODO: Add a unified generation dependent on attn_type
         from fms.utils.generation import generate
         attention_specific_kwargs["contiguous_cache"] = True
-        attention_specific_kwargs["max_seq_len"] = max_seq_len
+        attention_specific_kwargs["max_seq_len"] = input_ids.shape[1] + max_new_tokens
 
         # Add only_last_token optimization
         extra_generation_kwargs = {**extra_kwargs}
@@ -208,7 +207,6 @@ def extract_validation_information(model, input_ids, max_new_tokens, post_iterat
     result = generate(
         model,
         input_ids,
-        max_seq_len=input_ids.shape[1] + max_new_tokens,
         max_new_tokens=max_new_tokens,
         use_cache=True,
         do_sample=False,
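
Note: taken together, the series replaces the fixed cap (the model's max_expected_seq_len, and briefly the 4096 default added in PATCH 1/5) with a max_seq_len derived from the prompt length plus the requested new tokens, forwarded through attention_specific_kwargs on the non-paged path. A minimal sketch of that sizing logic, with hypothetical values chosen only for illustration (this is not code from the repository):

    import torch

    # Hypothetical prompt: batch of 1, 128 tokens (batch x seq).
    input_ids = torch.randint(0, 32000, (1, 128))
    max_new_tokens = 256

    attention_specific_kwargs = {
        "contiguous_cache": True,
        # Key change in this series: cap generation at prompt length + new tokens
        # instead of a fixed model-wide maximum such as 4096.
        "max_seq_len": input_ids.shape[1] + max_new_tokens,
    }

    assert attention_specific_kwargs["max_seq_len"] == 384  # 128 + 256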