From c6a3d0036515a47cb3bbb813dccf2828709e3018 Mon Sep 17 00:00:00 2001
From: Antoni Viros i Martin
Date: Fri, 11 Jul 2025 22:04:50 +0000
Subject: [PATCH 1/2] Add support for FP8 scaling per sequence

Signed-off-by: Antoni Viros i Martin
---
 aiu_fms_testing_utils/utils/paged.py | 29 ++++++++++++++++++++++------
 1 file changed, 23 insertions(+), 6 deletions(-)

diff --git a/aiu_fms_testing_utils/utils/paged.py b/aiu_fms_testing_utils/utils/paged.py
index 78ce4a81..3e2dacc4 100644
--- a/aiu_fms_testing_utils/utils/paged.py
+++ b/aiu_fms_testing_utils/utils/paged.py
@@ -142,8 +142,8 @@ def generate(
         from fms_mo.aiu_addons.fp8.fp8_utils import ScaledTensor
         kwargs["past_key_value_states"] = [
             (
-                ScaledTensor(torch.zeros(NUM_BLOCKS, BLOCK_SIZE, kvheads, head_size, dtype=torch.float8_e4m3fn), torch.tensor(1.0), False),
-                ScaledTensor(torch.zeros(NUM_BLOCKS, BLOCK_SIZE, kvheads, head_size, dtype=torch.float8_e4m3fn), torch.tensor(1.0), False),
+                ScaledTensor(torch.zeros(NUM_BLOCKS, BLOCK_SIZE, kvheads, head_size, dtype=torch.float8_e4m3fn), torch.tensor([1.0] * input_ids.shape[0], dtype=torch.float32), False),
+                ScaledTensor(torch.zeros(NUM_BLOCKS, BLOCK_SIZE, kvheads, head_size, dtype=torch.float8_e4m3fn), torch.tensor([1.0] * input_ids.shape[0], dtype=torch.float32), False),
             )
             for _ in range(model.config.nlayers)
         ]
@@ -218,6 +218,8 @@ def generate(

     outputs_list = []
     current_kv_cache = kwargs["past_key_value_states"]
+    if "fp8" in kwargs["attn_name"]:
+        current_kv_scales = [(t1._scale, t2._scale) for t1, t2 in kwargs["past_key_value_states"]]
     for seq_i in range(input_ids.size(0)):
         input_ids_i = input_ids[seq_i].unsqueeze(0)
         slot_mapping_i = kwargs["slot_mapping"][seq_i].unsqueeze(0)
@@ -237,6 +239,12 @@ def generate(
             torch._dynamo.mark_dynamic(mask_i, 2)
             torch._dynamo.mark_dynamic(mask_i, 3)

+        # FP8 per-sentence scale handling
+        if "fp8" in kwargs["attn_name"]:
+            for layer_idx, (t1, t2) in enumerate(current_kv_cache):
+                t1._scale = current_kv_scales[layer_idx][seq_i]
+                t2._scale = current_kv_scales[layer_idx][seq_i]
+
         only_last_token = kwargs.get("only_last_token", False)

         output, current_kv_cache = model(
@@ -251,10 +259,19 @@ def generate(
         )

         # TODO: Figure out how to do this cleanly
-        if "fp8" in kwargs["attn_name"] and seq_i != input_ids.size(0) - 1:
-            for layer_cache in current_kv_cache:
-                layer_cache[0]._scaled = False
-                layer_cache[1]._scaled = False
+        if "fp8" in kwargs["attn_name"]:
+            for layer_idx, (t1, t2) in enumerate(current_kv_cache):
+                current_kv_scales[layer_idx][0][seq_i] = t1._scale
+                current_kv_scales[layer_idx][1][seq_i] = t2._scale
+
+            if seq_i != input_ids.size(0) - 1:
+                for layer_cache in current_kv_cache:
+                    layer_cache[0]._scaled = False
+                    layer_cache[1]._scaled = False
+            else:
+                for layer_idx, (t1, t2) in enumerate(current_kv_cache):
+                    t1._scale = current_kv_scales[layer_idx][0]
+                    t2._scale = current_kv_scales[layer_idx][1]

         outputs_list.append(output[0].squeeze(0))

From 5cd32a7b8a0c4b281cd61ac6a878b26a4a13035d Mon Sep 17 00:00:00 2001
From: Antoni Viros i Martin
Date: Wed, 16 Jul 2025 15:01:59 +0000
Subject: [PATCH 2/2] fix scale dimensions

Signed-off-by: Antoni Viros i Martin
---
 aiu_fms_testing_utils/utils/paged.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/aiu_fms_testing_utils/utils/paged.py b/aiu_fms_testing_utils/utils/paged.py
index 3e2dacc4..500dfb45 100644
--- a/aiu_fms_testing_utils/utils/paged.py
+++ b/aiu_fms_testing_utils/utils/paged.py
@@ -242,8 +242,8 @@ def generate(
         # FP8 per-sentence scale handling
         if "fp8" in kwargs["attn_name"]:
             for layer_idx, (t1, t2) in enumerate(current_kv_cache):
-                t1._scale = current_kv_scales[layer_idx][seq_i]
-                t2._scale = current_kv_scales[layer_idx][seq_i]
+                t1._scale = current_kv_scales[layer_idx][0][seq_i].reshape(-1)
+                t2._scale = current_kv_scales[layer_idx][1][seq_i].reshape(-1)

         only_last_token = kwargs.get("only_last_token", False)

@@ -288,6 +288,10 @@ def generate(
     torch._dynamo.mark_dynamic(kwargs["position_ids"], 0)
     torch._dynamo.mark_dynamic(kwargs["current_tkv_mask"], 0)
     torch._dynamo.mark_dynamic(kwargs["left_padded_prompt_mask"], 0)
+    if "fp8" in kwargs["attn_name"]:
+        for k_cache, v_cache in kwargs["past_key_value_states"]:
+            torch._dynamo.mark_dynamic(k_cache._scale, 0)
+            torch._dynamo.mark_dynamic(v_cache._scale, 0)

     # seq
     torch._dynamo.mark_static(input_ids, 1)  # always 1
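Note on the pattern (illustrative, not part of the series): taken together, the two patches keep one FP8 KV-cache scale slot per sequence in the batch, swap the right scalar in before each per-sequence forward pass (PATCH 2/2 adds the `[0]`/`[1]` key-vs-value indexing and the `.reshape(-1)`), write the possibly updated scalar back afterwards, and restore the full per-sequence vectors after the last sequence. Below is a minimal runnable sketch of that save/swap/restore dance, using a hypothetical `ScaledTensorStub` in place of `fms_mo.aiu_addons.fp8.fp8_utils.ScaledTensor`, whose constructor arguments and `_scale`/`_scaled` fields are assumed from what the diff shows; the model forward is simulated by a fake scale update.

import torch


class ScaledTensorStub:
    """Hypothetical stand-in for fms_mo's ScaledTensor: just the fields the patch touches."""

    def __init__(self, data: torch.Tensor, scale: torch.Tensor, scaled: bool):
        self._data = data
        self._scale = scale
        self._scaled = scaled


batch_size, n_layers = 4, 2

# As in PATCH 1/2: one float32 scale slot per sequence in the batch.
kv_cache = [
    (
        ScaledTensorStub(torch.zeros(8), torch.ones(batch_size), False),
        ScaledTensorStub(torch.zeros(8), torch.ones(batch_size), False),
    )
    for _ in range(n_layers)
]

# Snapshot the full per-sequence scale vectors before the loop overwrites
# each tensor's _scale with a single sequence's scalar.
kv_scales = [(k._scale, v._scale) for k, v in kv_cache]

for seq_i in range(batch_size):
    # Before the forward pass: load this sequence's scale as a 1-D,
    # 1-element tensor (the PATCH 2/2 shape fix).
    for layer_idx, (k, v) in enumerate(kv_cache):
        k._scale = kv_scales[layer_idx][0][seq_i].reshape(-1)
        v._scale = kv_scales[layer_idx][1][seq_i].reshape(-1)

    # Stand-in for the model forward, which may recompute the scale
    # during prefill of this sequence.
    for k, v in kv_cache:
        k._scale = k._scale * (1.0 + 0.1 * seq_i)
        v._scale = v._scale * (1.0 + 0.1 * seq_i)

    # After the forward pass: write the updated scalar back into the
    # per-sequence vector (a 1-element tensor broadcasts into the slot).
    for layer_idx, (k, v) in enumerate(kv_cache):
        kv_scales[layer_idx][0][seq_i] = k._scale
        kv_scales[layer_idx][1][seq_i] = v._scale

# After the last sequence: restore the full vectors so later decode steps
# again see one scale slot per sequence.
for layer_idx, (k, v) in enumerate(kv_cache):
    k._scale = kv_scales[layer_idx][0]
    v._scale = kv_scales[layer_idx][1]

The `kv_scales` snapshot is what makes this work: once a cache tensor's `_scale` has been replaced by a single sequence's scalar, the per-sequence vector would be unreachable without the saved handles, which is why PATCH 1/2 hoists `current_kv_scales` out of the loop.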