Commit b882a64

Merge pull request #36 from foundation-model-stack/jni/dev
set warmup iterations to 2 if dynamic
2 parents: 73bcb07 + b2ea5e3

File tree: 3 files changed (+19 −18 lines)

aiu_fms_testing_utils/utils/__init__.py

Lines changed: 5 additions & 2 deletions
```diff
@@ -10,12 +10,15 @@
 import json
 import random

-def warmup_model(model: nn.Module, input_ids: torch.Tensor, max_new_tokens: int, **padding_kwargs):
+def warmup_model(model: nn.Module, input_ids: torch.Tensor, max_new_tokens: int, compile_dynamic_sendnn = False, **padding_kwargs):
     from torch_sendnn import torch_sendnn
     dprint("AIU warmup")
     pt_compile_model_time = time.time()
     extra_kwargs = {**padding_kwargs, "only_last_token": True}
-    generate(model, input_ids, max_new_tokens=max_new_tokens, max_seq_len=model.config.max_expected_seq_len, use_cache=True, do_sample=False, contiguous_cache=True, extra_kwargs=extra_kwargs)
+    max_new_tokens_warmup = max_new_tokens
+    if compile_dynamic_sendnn:
+        max_new_tokens_warmup = 2
+    generate(model, input_ids, max_new_tokens=max_new_tokens_warmup, max_seq_len=model.config.max_expected_seq_len, use_cache=True, do_sample=False, contiguous_cache=True, extra_kwargs=extra_kwargs)
     pt_compile_model_time = time.time() - pt_compile_model_time
     dprint(f"PT compile complete, took {pt_compile_model_time:.3f}s")

```
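
Why two tokens are enough: with dynamic sendnn compilation the graphs are compiled for symbolic shapes, so the warmup presumably only needs to exercise each graph once, one prefill pass plus one decode step, rather than generating all `max_new_tokens`. A standalone sketch of the policy (the helper name is ours for illustration; in the repo the logic is inline in `warmup_model`, as shown above):

```python
# Hypothetical helper mirroring the inline cap in warmup_model above.
def pick_warmup_tokens(max_new_tokens: int, compile_dynamic_sendnn: bool = False) -> int:
    # 2 tokens = one prefill pass + one decode step: enough to trace both
    # compiled graphs when shapes are dynamic; otherwise warm up fully.
    return 2 if compile_dynamic_sendnn else max_new_tokens

assert pick_warmup_tokens(256, compile_dynamic_sendnn=True) == 2
assert pick_warmup_tokens(256) == 256
```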

scripts/inference.py

Lines changed: 8 additions & 15 deletions
```diff
@@ -9,7 +9,7 @@
 import time

 # Third Party
-from aiu_fms_testing_utils.utils import aiu_setup
+from aiu_fms_testing_utils.utils import aiu_setup, warmup_model
 from aiu_fms_testing_utils.utils.aiu_setup import dprint, rank, local_rank, world_size
 import numpy as np
 import torch
@@ -129,6 +129,11 @@
     action="store_true",
     help="Use dynamic shapes with torch.compile",
 )
+parser.add_argument(
+    "--compile_dynamic_sendnn",
+    action="store_true",
+    help="Use dynamic shapes with aiu compile",
+)
 parser.add_argument(
     "--deterministic",
     action="store_true",
@@ -464,7 +469,7 @@ def select_int8_module(
 if args.compile:
     dprint("compiling model")
     if is_aiu_backend:
-        model.compile(backend="sendnn_decoder")
+        model.compile(backend="sendnn_decoder", options={'sendnn.dynamic': args.compile_dynamic_sendnn})
     else:
         # compiling can make first inference pass slow
         model.compile(mode=args.compile_mode, backend=args.compile_backend)
@@ -686,19 +691,7 @@ def infer(use_cache, do_sample, warmup):
 ] # True/False are identical with greedy iff `torch.use_deterministic_algorithms(True)`

 if args.compile:
-    dprint(f"compilation warmup")
-    pt_compile_model_time = time.time()
-    for sample, cache in itertools.product(do_sample, use_cache):
-        infer(cache, sample, True)
-    pt_compile_model_time = time.time() - pt_compile_model_time
-    dprint(f"PT compile complete, took {pt_compile_model_time:.3f}s")
-
-    if is_aiu_backend:
-        dprint("executing update_lazyhandle and compiling for AIU")
-        update_lh_time = time.time()
-        torch_sendnn.update_lazyhandle()
-        update_lh_time = time.time() - update_lh_time
-        dprint(f"update_lazyhandle complete, took {update_lh_time:.3f}s")
+    warmup_model(model, ids, args.max_new_tokens, args.compile_dynamic_sendnn, **extra_generation_kwargs)

 if args.device_type == "aiu": # only run warmup for AIU, no need for senulator
     aiu_warmup_time = time.time()
```
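
The new flag flows straight through to the backend: `args.compile_dynamic_sendnn` becomes the `sendnn.dynamic` entry in the compile options. A rough sketch of that wiring, assuming the argparse setup from the diff (the `sendnn_decoder` backend requires `torch_sendnn` to be installed and registered, so the stand-in module and the call here are illustrative only):

```python
import argparse
import torch.nn as nn

parser = argparse.ArgumentParser()
parser.add_argument("--compile_dynamic_sendnn", action="store_true",
                    help="Use dynamic shapes with aiu compile")
args = parser.parse_args(["--compile_dynamic_sendnn"])

model = nn.Linear(8, 8)  # stand-in for the FMS model

# Same shape as the diff: the flag lands in the backend options unchanged.
# Only runs where torch_sendnn registers the "sendnn_decoder" backend.
model.compile(backend="sendnn_decoder",
              options={"sendnn.dynamic": args.compile_dynamic_sendnn})
```

From the shell this would look something like `python scripts/inference.py --compile --compile_dynamic_sendnn ...` (the `--compile` spelling is an assumption inferred from `args.compile`).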

scripts/validation.py

Lines changed: 6 additions & 1 deletion
```diff
@@ -106,6 +106,11 @@
     action="store_true",
     help="Use dynamic shapes with torch.compile",
 )
+parser.add_argument(
+    "--compile_dynamic_sendnn",
+    action="store_true",
+    help="Use dynamic shapes with aiu compile",
+)
 parser.add_argument(
     "--deterministic",
     action="store_true",
@@ -680,7 +685,7 @@ def print_result(result, result_idx: int = 0, file_prefix: str = ""):
     **padding_kwargs
 )

-warmup_model(model, ids, args.max_new_tokens, **padding_kwargs)
+warmup_model(model, ids, args.max_new_tokens, args.compile_dynamic_sendnn, **padding_kwargs)

 ### AIU generation loop
 static_tokens = validation_info.get_info("tokens")
```
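
Note that both call sites pass the flag positionally, so it must sit between `max_new_tokens` and the `**kwargs` expansion. A self-contained sketch of the calling convention (stub body and placeholder kwarg, not the real implementation):

```python
from types import SimpleNamespace

# Stub with the same signature as warmup_model in the first diff.
def warmup_model(model, input_ids, max_new_tokens, compile_dynamic_sendnn=False, **padding_kwargs):
    tokens = 2 if compile_dynamic_sendnn else max_new_tokens
    print(f"warmup: tokens={tokens}, extra={padding_kwargs}")

model, ids = object(), [[1, 2, 3]]  # placeholders for the model and input ids
args = SimpleNamespace(max_new_tokens=128, compile_dynamic_sendnn=True)

# The fourth positional argument binds to compile_dynamic_sendnn.
warmup_model(model, ids, args.max_new_tokens, args.compile_dynamic_sendnn,
             mask=None)  # mask: placeholder padding kwarg
```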
