added fixes for handling multiple shape warmup #13

Open · wants to merge 2 commits into base: main
3 changes: 2 additions & 1 deletion aiu_fms_testing_utils/testing/validation.py
@@ -3,7 +3,7 @@

import torch
from fms.utils.generation import generate
-from aiu_fms_testing_utils.utils import ids_for_prompt
+from aiu_fms_testing_utils.utils import ids_for_prompt, _prepare_model_inputs_hook
from aiu_fms_testing_utils.utils.aiu_setup import dprint
import os

@@ -206,6 +206,7 @@ def extract_validation_information(model, input_ids, max_new_tokens, post_iterat
timing=timing,
contiguous_cache=True,
extra_kwargs=extra_generation_kwargs,
prepare_model_inputs_hook=_prepare_model_inputs_hook
)

if timing != "":
26 changes: 25 additions & 1 deletion aiu_fms_testing_utils/utils/__init__.py
@@ -10,12 +10,36 @@
import json
import random

def _prepare_model_inputs_hook(i, input_ids, kwargs):
    """To produce like graphs during pre-fill, we mark the prefill batch x seq as static, but relax this for decode for the seq"""
    if i == 0:
        # we always want prefill to be static to produce same-like graph
        torch._dynamo.mark_static(input_ids, 0)
        torch._dynamo.mark_static(input_ids, 1)
        torch._dynamo.mark_static(kwargs["mask"], 0)
        torch._dynamo.mark_static(kwargs["mask"], 1)
        torch._dynamo.mark_static(kwargs["mask"], 2)
        torch._dynamo.mark_static(kwargs["position_ids"], 0)
        torch._dynamo.mark_static(kwargs["position_ids"], 1)
Comment on lines +17 to +23

Contributor:
Do we need to mark all the sequence dimensions as static, or are just the batch dimensions enough?

Contributor (author):
It probably is enough; however, I marked everything as static to ensure we get a static prefill -- I believe symbolic ints can cause changes in the graph in prefill that we may not want to introduce.

    else:
        # we always want the decode to be dynamic on sequence
        torch._dynamo.mark_dynamic(input_ids, 1)
        torch._dynamo.mark_dynamic(kwargs["mask"], 1)
        torch._dynamo.mark_dynamic(kwargs["mask"], 2)
Contributor:
We probably only need to mark dim 2 as dynamic here.


        for layer in kwargs["past_key_value_states"]:
            for tensor in layer:
                torch._dynamo.mark_static(tensor, 0)
Contributor:
We could move the code that marks the kv cache sequence dimension as dynamic here as well.

[An illustrative sketch of these review suggestions follows the function below.]


    return input_ids, kwargs
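For illustration only, not part of this PR's diff: a minimal sketch of what the hook might look like if the decode-side review suggestions above were applied, i.e. marking only the mask's dim 2 as dynamic and moving the kv-cache dynamic-sequence marking into the hook, while keeping prefill fully static per the author's reply. The function name and the assumption that the kv-cache tensors carry the sequence length on dim 2 are hypothetical and not confirmed by this PR.

import torch

def _prepare_model_inputs_hook_sketch(i, input_ids, kwargs):
    # hypothetical variant of the PR's hook, reflecting the review suggestions above
    if i == 0:
        # prefill: keep every dimension static, matching the PR's behavior
        for dim in (0, 1):
            torch._dynamo.mark_static(input_ids, dim)
            torch._dynamo.mark_static(kwargs["position_ids"], dim)
        for dim in (0, 1, 2):
            torch._dynamo.mark_static(kwargs["mask"], dim)
    else:
        # decode: only the mask's key/value-length dimension changes per step
        torch._dynamo.mark_dynamic(kwargs["mask"], 2)
        for layer in kwargs["past_key_value_states"]:
            for tensor in layer:
                torch._dynamo.mark_static(tensor, 0)   # batch stays static
                torch._dynamo.mark_dynamic(tensor, 2)  # assumes sequence is dim 2
    return input_ids, kwargs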


def warmup_model(model: nn.Module, input_ids: torch.Tensor, max_new_tokens: int, **padding_kwargs):
    from torch_sendnn import torch_sendnn
    dprint("AIU warmup")
    pt_compile_model_time = time.time()
    extra_kwargs = {**padding_kwargs, "only_last_token": True}
-    generate(model, input_ids, max_new_tokens=max_new_tokens, max_seq_len=model.config.max_expected_seq_len, use_cache=True, do_sample=False, contiguous_cache=True, extra_kwargs=extra_kwargs)
+    generate(model, input_ids, max_new_tokens=max_new_tokens, max_seq_len=model.config.max_expected_seq_len, use_cache=True, do_sample=False, contiguous_cache=True, extra_kwargs=extra_kwargs, prepare_model_inputs_hook=_prepare_model_inputs_hook)
    pt_compile_model_time = time.time() - pt_compile_model_time
    dprint(f"PT compile complete, took {pt_compile_model_time:.3f}s")

70 changes: 69 additions & 1 deletion tests/models/test_decoders.py
@@ -5,7 +5,7 @@
import itertools
import torch
from aiu_fms_testing_utils.testing.validation import extract_validation_information, LogitsExtractorHook, GoldenTokenHook, capture_level_1_metrics, filter_failed_level_1_cases, load_validation_information, validate_level_0, top_k_loss_calculator
-from aiu_fms_testing_utils.utils import warmup_model, sample_sharegpt_requests, ids_for_prompt
+from aiu_fms_testing_utils.utils import warmup_model, sample_sharegpt_requests, ids_for_prompt, _prepare_model_inputs_hook
from aiu_fms_testing_utils.utils.aiu_setup import dprint
import os

@@ -275,5 +275,73 @@ def _metric_calculator(r: torch.Tensor, t: torch.Tensor):
else:
print("passed validation level 0")

def test_warmup_multiple_shapes():
    shapes = [
        (1, 64, 8),
        (2, 64, 8),
        (1, 128, 24),
    ]

    reference_model = get_model(
        architecture="hf_configured",
        variant=GRANITE_3p2_8B_INSTRUCT,
        device_type="cpu",
        fused_weights=False,
        nlayers=3
    )

    model = get_model(
        architecture="hf_configured",
        variant=GRANITE_3p2_8B_INSTRUCT,
        device_type="cpu",
        fused_weights=False,
        nlayers=3
    )

    reference_model.load_state_dict(model.state_dict())

    model.eval()
    reference_model.eval()

    torch.set_grad_enabled(False)
    model.compile(backend="sendnn_decoder")
    for bs, sl, mnt in shapes:
        # prepare input_ids
        prompt_list = []
        for i in range(bs):
            prompt_list.append(torch.randint(0, model.config.src_vocab_size, (sl - 2 * i,), dtype=torch.long))

        input_ids, padding_kwargs = pad_input_ids(prompt_list, min_pad_length=sl)
        # warmup aiu model
        warmup_model(model, input_ids, mnt, **padding_kwargs)

    # perform 3 inference passes, making sure ordering does not affect results
    for _ in range(3):
        shapes.reverse()
        for bs, sl, mnt in shapes:
            prompt_list = []
            for i in range(bs):
                prompt_list.append(torch.randint(0, model.config.src_vocab_size, (sl - 2 * i,), dtype=torch.long))
            input_ids, padding_kwargs = pad_input_ids(prompt_list, min_pad_length=sl)

            cpu_validation_info = extract_validation_information(
                reference_model,
                input_ids,
                mnt,
                LogitsExtractorHook(),
                attn_algorithm="math",
                **padding_kwargs
            )

            aiu_validation_info = extract_validation_information(
                model,
                input_ids,
                mnt,
                None,
                only_last_token=True,
                **padding_kwargs
            )

            failed_responses = validate_level_0(aiu_validation_info.get_info("tokens"), cpu_validation_info.get_info("tokens"))

            assert len(failed_responses) == 0