Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 27 additions & 6 deletions aiu_fms_testing_utils/utils/paged.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,8 @@ def generate(
from fms_mo.aiu_addons.fp8.fp8_utils import ScaledTensor
kwargs["past_key_value_states"] = [
(
ScaledTensor(torch.zeros(NUM_BLOCKS, BLOCK_SIZE, kvheads, head_size, dtype=torch.float8_e4m3fn), torch.tensor(1.0), False),
ScaledTensor(torch.zeros(NUM_BLOCKS, BLOCK_SIZE, kvheads, head_size, dtype=torch.float8_e4m3fn), torch.tensor(1.0), False),
ScaledTensor(torch.zeros(NUM_BLOCKS, BLOCK_SIZE, kvheads, head_size, dtype=torch.float8_e4m3fn), torch.tensor([1.0] * input_ids.shape[0], dtype=torch.float32), False),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is the dtype torch.float32 here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FP8 scales are stored in fp32 always

ScaledTensor(torch.zeros(NUM_BLOCKS, BLOCK_SIZE, kvheads, head_size, dtype=torch.float8_e4m3fn), torch.tensor([1.0] * input_ids.shape[0], dtype=torch.float32), False),
)
for _ in range(model.config.nlayers)
]
Expand Down Expand Up @@ -216,6 +216,8 @@ def generate(

outputs_list = []
current_kv_cache = kwargs["past_key_value_states"]
if "fp8" in kwargs["attn_name"]:
current_kv_scales = [(t1._scale, t2._scale) for t1, t2 in kwargs["past_key_value_states"]]
for seq_i in range(input_ids.size(0)):
input_ids_i = input_ids[seq_i].unsqueeze(0)
slot_mapping_i = kwargs["slot_mapping"][seq_i].unsqueeze(0)
Expand All @@ -235,6 +237,12 @@ def generate(
torch._dynamo.mark_dynamic(mask_i, 2)
torch._dynamo.mark_dynamic(mask_i, 3)

# FP8 per-sentence scale handling
if "fp8" in kwargs["attn_name"]:
for layer_idx, (t1, t2) in enumerate(current_kv_cache):
t1._scale = current_kv_scales[layer_idx][0][seq_i].reshape(-1)
t2._scale = current_kv_scales[layer_idx][1][seq_i].reshape(-1)

only_last_token = kwargs.get("only_last_token", False)

output, current_kv_cache = model(
Expand All @@ -249,10 +257,19 @@ def generate(
)

# TODO: Figure out how to do this cleanly
if "fp8" in kwargs["attn_name"] and seq_i != input_ids.size(0) - 1:
for layer_cache in current_kv_cache:
layer_cache[0]._scaled = False
layer_cache[1]._scaled = False
if "fp8" in kwargs["attn_name"]:
for layer_idx, (t1, t2) in enumerate(current_kv_cache):
current_kv_scales[layer_idx][0][seq_i] = t1._scale
current_kv_scales[layer_idx][1][seq_i] = t2._scale

if seq_i != input_ids.size(0) - 1:
for layer_cache in current_kv_cache:
layer_cache[0]._scaled = False
layer_cache[1]._scaled = False
else:
for layer_idx, (t1, t2) in enumerate(current_kv_cache):
t1._scale = current_kv_scales[layer_idx][0]
t2._scale = current_kv_scales[layer_idx][1]

outputs_list.append(output[0].squeeze(0))

Expand All @@ -269,6 +286,10 @@ def generate(
torch._dynamo.mark_dynamic(kwargs["position_ids"], 0)
torch._dynamo.mark_dynamic(kwargs["current_tkv_mask"], 0)
torch._dynamo.mark_dynamic(kwargs["left_padded_prompt_mask"], 0)
if "fp8" in kwargs["attn_name"]:
for k_cache, v_cache in kwargs["past_key_value_states"]:
torch._dynamo.mark_dynamic(k_cache._scale, 0)
torch._dynamo.mark_dynamic(v_cache._scale, 0)

# seq
torch._dynamo.mark_static(input_ids, 1) # always 1
Expand Down