Fp8 scaling per-sequence #79

Open · wants to merge 4 commits into main
1 change: 1 addition & 0 deletions aiu_fms_testing_utils/testing/validation.py
@@ -208,6 +208,7 @@ def extract_validation_information(model, input_ids, max_new_tokens, post_iterat
result = generate(
model,
input_ids,
max_seq_len=input_ids.shape[1] + max_new_tokens,
max_new_tokens=max_new_tokens,
use_cache=True,
do_sample=False,
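The added argument encodes one sizing rule at every call site: the paged KV cache must be able to hold the prompt plus every token that will be generated. A minimal sketch of that rule with made-up values (the tensors and numbers below are illustrative, not taken from the repository):

```python
import torch

# Illustrative prompt and generation budget, not taken from a real run.
input_ids = torch.zeros(1, 12, dtype=torch.long)  # a 12-token prompt
max_new_tokens = 5

# Call sites pass prompt length + new tokens instead of relying on
# paged.generate's default of 4096.
max_seq_len = input_ids.shape[1] + max_new_tokens  # 12 + 5 = 17
```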
1 change: 1 addition & 0 deletions aiu_fms_testing_utils/utils/__init__.py
@@ -53,6 +53,7 @@ def warmup_model(
generate(
model,
_warmup_input_ids,
max_seq_len=_warmup_input_ids.shape[1] + max_new_tokens,
max_new_tokens=_max_new_tokens,
do_sample=False,
use_cache=use_cache,
30 changes: 24 additions & 6 deletions aiu_fms_testing_utils/utils/paged.py
@@ -27,6 +27,7 @@ def adjust_inputs_to_batch(input_ids: torch.Tensor, **extra_kwargs):
def generate(
model: Union[Callable, torch.nn.Module],
input_ids: torch.Tensor,
max_seq_len: int = 4096,
max_new_tokens: int = 256,
temperature: float = 1.0,
top_k: int = 10,
@@ -141,8 +142,8 @@ def generate(
from fms_mo.aiu_addons.fp8.fp8_utils import ScaledTensor
kwargs["past_key_value_states"] = [
(
- ScaledTensor(torch.zeros(NUM_BLOCKS, BLOCK_SIZE, kvheads, head_size, dtype=torch.float8_e4m3fn), torch.tensor(1.0), False),
- ScaledTensor(torch.zeros(NUM_BLOCKS, BLOCK_SIZE, kvheads, head_size, dtype=torch.float8_e4m3fn), torch.tensor(1.0), False),
+ ScaledTensor(torch.zeros(NUM_BLOCKS, BLOCK_SIZE, kvheads, head_size, dtype=torch.float8_e4m3fn), torch.tensor([1.0] * input_ids.shape[0], dtype=torch.float32), False),
+ ScaledTensor(torch.zeros(NUM_BLOCKS, BLOCK_SIZE, kvheads, head_size, dtype=torch.float8_e4m3fn), torch.tensor([1.0] * input_ids.shape[0], dtype=torch.float32), False),
)
for _ in range(model.config.nlayers)
]
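The changed pair of lines above swaps the single scalar KV-cache scale for one scale per sequence in the batch. A minimal sketch of the shape change, with an illustrative batch size (nothing below comes from a real run):

```python
import torch

batch_size = 4  # illustrative batch of 4 sequences

# Before: one dequantization scale shared by every sequence in the batch.
old_scale = torch.tensor(1.0)

# After: one scale per sequence, so each sequence's FP8 KV-cache entries can be
# quantized and dequantized with its own factor.
new_scale = torch.tensor([1.0] * batch_size, dtype=torch.float32)
assert old_scale.ndim == 0 and new_scale.shape == (batch_size,)
```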
@@ -217,6 +218,8 @@ def generate(

outputs_list = []
current_kv_cache = kwargs["past_key_value_states"]
if "fp8" in kwargs["attn_name"]:
current_kv_scales = [(t1._scale, t2._scale) for t1, t2 in kwargs["past_key_value_states"]]
for seq_i in range(input_ids.size(0)):
input_ids_i = input_ids[seq_i].unsqueeze(0)
slot_mapping_i = kwargs["slot_mapping"][seq_i].unsqueeze(0)
@@ -236,6 +239,12 @@
torch._dynamo.mark_dynamic(mask_i, 2)
torch._dynamo.mark_dynamic(mask_i, 3)

# FP8 per-sequence scale handling
if "fp8" in kwargs["attn_name"]:
for layer_idx, (t1, t2) in enumerate(current_kv_cache):
t1._scale = current_kv_scales[layer_idx][0][seq_i]
t2._scale = current_kv_scales[layer_idx][1][seq_i]

only_last_token = kwargs.get("only_last_token", False)

output, current_kv_cache = model(
@@ -250,10 +259,19 @@
)

# TODO: Figure out how to do this cleanly
if "fp8" in kwargs["attn_name"] and seq_i != input_ids.size(0) - 1:
for layer_cache in current_kv_cache:
layer_cache[0]._scaled = False
layer_cache[1]._scaled = False
if "fp8" in kwargs["attn_name"]:
for layer_idx, (t1, t2) in enumerate(current_kv_cache):
current_kv_scales[layer_idx][0][seq_i] = t1._scale
current_kv_scales[layer_idx][1][seq_i] = t2._scale

if seq_i != input_ids.size(0) - 1:
for layer_cache in current_kv_cache:
layer_cache[0]._scaled = False
layer_cache[1]._scaled = False
else:
for layer_idx, (t1, t2) in enumerate(current_kv_cache):
t1._scale = current_kv_scales[layer_idx][0]
t2._scale = current_kv_scales[layer_idx][1]

outputs_list.append(output[0].squeeze(0))

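For readers following the per-sequence loop above, here is a minimal standalone sketch of the scale bookkeeping it performs. The `ScaledTensorStub` class, the single entry per layer (the real cache holds a key/value pair per layer), and all values are illustrative; only the `_scale`/`_scaled` attribute names and the load/store/restore pattern are taken from the diff.

```python
import torch


class ScaledTensorStub:
    """Illustrative stand-in for fms_mo's ScaledTensor: FP8 data plus a dequant scale."""

    def __init__(self, scale: torch.Tensor):
        self._scale = scale   # scale(s) for the cached FP8 values
        self._scaled = True   # whether the cache currently holds quantized data


batch, n_layers = 3, 2
# One cache entry per layer, each starting with a per-sequence scale of 1.0.
cache = [ScaledTensorStub(torch.ones(batch)) for _ in range(n_layers)]
# Snapshot of every layer's per-sequence scales, updated as each sequence runs.
saved_scales = [t._scale.clone() for t in cache]

for seq_i in range(batch):
    # Load this sequence's scalar scale into the cache before its forward pass.
    for layer_idx, t in enumerate(cache):
        t._scale = saved_scales[layer_idx][seq_i]

    # ... the model forward pass for sequence seq_i would run here and may
    # overwrite t._scale with an updated scale ...

    # Store the (possibly updated) scale back into the per-sequence snapshot.
    for layer_idx, t in enumerate(cache):
        saved_scales[layer_idx][seq_i] = t._scale
        if seq_i != batch - 1:
            t._scaled = False  # force re-quantization for the next sequence

# After the last sequence, hand back the full per-sequence scale tensors.
for layer_idx, t in enumerate(cache):
    t._scale = saved_scales[layer_idx]
```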
1 change: 1 addition & 0 deletions scripts/inference.py
@@ -717,6 +717,7 @@ def infer(use_cache, do_sample, warmup):
result = generate(
model,
ids,
max_seq_len=ids.shape[1] + args.max_new_tokens,
max_new_tokens=args.max_new_tokens,
use_cache=use_cache,
do_sample=do_sample,
2 changes: 2 additions & 0 deletions tests/utils/test_paged.py
@@ -29,6 +29,7 @@ def test_paged_equivalence():
result = generate(
_model_mock,
ids,
max_seq_len=ids.shape[1] + 5,
max_new_tokens=5,
do_sample=False,
use_cache=True,
@@ -38,6 +39,7 @@
result_paged = paged_generate(
_model_mock,
ids,
max_seq_len=ids.shape[1] + 5,
max_new_tokens=5,
do_sample=False,
use_cache=True,