diff --git a/aiu_fms_testing_utils/testing/validation.py b/aiu_fms_testing_utils/testing/validation.py
index 69b6a9b..ba8e718 100644
--- a/aiu_fms_testing_utils/testing/validation.py
+++ b/aiu_fms_testing_utils/testing/validation.py
@@ -208,6 +208,7 @@ def extract_validation_information(model, input_ids, max_new_tokens, post_iterat
     result = generate(
         model,
         input_ids,
+        max_seq_len=input_ids.shape[1] + max_new_tokens,
         max_new_tokens=max_new_tokens,
         use_cache=True,
         do_sample=False,
diff --git a/aiu_fms_testing_utils/utils/__init__.py b/aiu_fms_testing_utils/utils/__init__.py
index 9299706..222a3b6 100644
--- a/aiu_fms_testing_utils/utils/__init__.py
+++ b/aiu_fms_testing_utils/utils/__init__.py
@@ -53,6 +53,7 @@ def warmup_model(
     generate(
         model,
         _warmup_input_ids,
+        max_seq_len=_warmup_input_ids.shape[1] + max_new_tokens,
         max_new_tokens=_max_new_tokens,
         do_sample=False,
         use_cache=use_cache,
diff --git a/aiu_fms_testing_utils/utils/paged.py b/aiu_fms_testing_utils/utils/paged.py
index 771eb01..3e2dacc 100644
--- a/aiu_fms_testing_utils/utils/paged.py
+++ b/aiu_fms_testing_utils/utils/paged.py
@@ -27,6 +27,7 @@ def adjust_inputs_to_batch(input_ids: torch.Tensor, **extra_kwargs):
 def generate(
     model: Union[Callable, torch.nn.Module],
     input_ids: torch.Tensor,
+    max_seq_len: int = 4096,
     max_new_tokens: int = 256,
     temperature: float = 1.0,
     top_k: int = 10,
@@ -141,8 +142,8 @@ def generate(
         from fms_mo.aiu_addons.fp8.fp8_utils import ScaledTensor
         kwargs["past_key_value_states"] = [
             (
-                ScaledTensor(torch.zeros(NUM_BLOCKS, BLOCK_SIZE, kvheads, head_size, dtype=torch.float8_e4m3fn), torch.tensor(1.0), False),
-                ScaledTensor(torch.zeros(NUM_BLOCKS, BLOCK_SIZE, kvheads, head_size, dtype=torch.float8_e4m3fn), torch.tensor(1.0), False),
+                ScaledTensor(torch.zeros(NUM_BLOCKS, BLOCK_SIZE, kvheads, head_size, dtype=torch.float8_e4m3fn), torch.tensor([1.0] * input_ids.shape[0], dtype=torch.float32), False),
+                ScaledTensor(torch.zeros(NUM_BLOCKS, BLOCK_SIZE, kvheads, head_size, dtype=torch.float8_e4m3fn), torch.tensor([1.0] * input_ids.shape[0], dtype=torch.float32), False),
             )
             for _ in range(model.config.nlayers)
         ]
@@ -217,6 +218,8 @@ def generate(
         outputs_list = []
         current_kv_cache = kwargs["past_key_value_states"]
+        if "fp8" in kwargs["attn_name"]:
+            current_kv_scales = [(t1._scale, t2._scale) for t1, t2 in kwargs["past_key_value_states"]]
 
         for seq_i in range(input_ids.size(0)):
             input_ids_i = input_ids[seq_i].unsqueeze(0)
             slot_mapping_i = kwargs["slot_mapping"][seq_i].unsqueeze(0)
@@ -236,6 +239,12 @@ def generate(
             torch._dynamo.mark_dynamic(mask_i, 2)
             torch._dynamo.mark_dynamic(mask_i, 3)
 
+            # FP8 per-sentence scale handling
+            if "fp8" in kwargs["attn_name"]:
+                for layer_idx, (t1, t2) in enumerate(current_kv_cache):
+                    t1._scale = current_kv_scales[layer_idx][0][seq_i]
+                    t2._scale = current_kv_scales[layer_idx][1][seq_i]
+
             only_last_token = kwargs.get("only_last_token", False)
 
             output, current_kv_cache = model(
@@ -250,10 +259,19 @@ def generate(
             )
 
             # TODO: Figure out how to do this cleanly
-            if "fp8" in kwargs["attn_name"] and seq_i != input_ids.size(0) - 1:
-                for layer_cache in current_kv_cache:
-                    layer_cache[0]._scaled = False
-                    layer_cache[1]._scaled = False
+            if "fp8" in kwargs["attn_name"]:
+                for layer_idx, (t1, t2) in enumerate(current_kv_cache):
+                    current_kv_scales[layer_idx][0][seq_i] = t1._scale
+                    current_kv_scales[layer_idx][1][seq_i] = t2._scale
+
+                if seq_i != input_ids.size(0) - 1:
+                    for layer_cache in current_kv_cache:
+                        layer_cache[0]._scaled = False
+                        layer_cache[1]._scaled = False
+                else:
+                    for layer_idx, (t1, t2) in enumerate(current_kv_cache):
+                        t1._scale = current_kv_scales[layer_idx][0]
+                        t2._scale = current_kv_scales[layer_idx][1]
 
             outputs_list.append(output[0].squeeze(0))
 
diff --git a/scripts/inference.py b/scripts/inference.py
index 28c62aa..bd9ae93 100644
--- a/scripts/inference.py
+++ b/scripts/inference.py
@@ -717,6 +717,7 @@ def infer(use_cache, do_sample, warmup):
     result = generate(
         model,
         ids,
+        max_seq_len=ids.shape[1] + args.max_new_tokens,
         max_new_tokens=args.max_new_tokens,
         use_cache=use_cache,
         do_sample=do_sample,
diff --git a/tests/utils/test_paged.py b/tests/utils/test_paged.py
index 519042a..106a29d 100644
--- a/tests/utils/test_paged.py
+++ b/tests/utils/test_paged.py
@@ -29,6 +29,7 @@ def test_paged_equivalence():
     result = generate(
         _model_mock,
         ids,
+        max_seq_len=ids.shape[1] + 5,
         max_new_tokens=5,
         do_sample=False,
         use_cache=True,
@@ -38,6 +39,7 @@ def test_paged_equivalence():
     result_paged = paged_generate(
         _model_mock,
         ids,
+        max_seq_len=ids.shape[1] + 5,
         max_new_tokens=5,
         do_sample=False,
         use_cache=True,
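For reviewers reading the paged.py hunks out of context, below is a minimal, self-contained sketch of the per-sequence FP8 scale bookkeeping pattern they introduce. `FakeScaledTensor`, the tensor shapes, and the bare loop are illustrative assumptions standing in for fms_mo's `ScaledTensor` and the real model forward; only the `_scale`/`_scaled` attribute handling mirrors the patch.

```python
import torch


class FakeScaledTensor:
    """Hypothetical stand-in for a scaled KV-cache tensor: raw data plus a _scale."""

    def __init__(self, data: torch.Tensor, scale: torch.Tensor, scaled: bool = False):
        self._data = data
        self._scale = scale
        self._scaled = scaled


batch_size, nlayers = 4, 2

# One (key, value) cache entry per layer, each starting with a per-sequence scale vector.
cache = [
    (
        FakeScaledTensor(torch.zeros(8), torch.ones(batch_size)),
        FakeScaledTensor(torch.zeros(8), torch.ones(batch_size)),
    )
    for _ in range(nlayers)
]

# Snapshot the per-sequence scale vectors before the sequence-by-sequence prefill loop.
kv_scales = [(k._scale, v._scale) for k, v in cache]

for seq_i in range(batch_size):
    # Load this sequence's scalar scales into the shared cache objects.
    for layer_idx, (k, v) in enumerate(cache):
        k._scale = kv_scales[layer_idx][0][seq_i]
        v._scale = kv_scales[layer_idx][1][seq_i]

    # ... the model forward for sequence seq_i would run here and may update k._scale / v._scale ...

    for layer_idx, (k, v) in enumerate(cache):
        # Persist whatever scale this sequence ended up with.
        kv_scales[layer_idx][0][seq_i] = k._scale
        kv_scales[layer_idx][1][seq_i] = v._scale
        if seq_i != batch_size - 1:
            # Force the next sequence to recompute its own scale.
            k._scaled = False
            v._scaled = False
        else:
            # After the last sequence, restore the full per-batch scale vectors.
            k._scale = kv_scales[layer_idx][0]
            v._scale = kv_scales[layer_idx][1]
```

The point of the pattern: the cache objects are shared across all sequences in the batch, so each sequence's scale is parked in a per-batch vector between iterations, and the vectors are handed back on the cache once the last sequence has been processed.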