
Commit b2ea5e3

add --compile_dynamic_sendnn option
Signed-off-by: Jiamin Ni <jiamin.ni@ibm.com>
1 parent: 0e1d496

3 files changed: +17, -8 lines

aiu_fms_testing_utils/utils/__init__.py

Lines changed: 4 additions & 5 deletions
@@ -10,15 +10,14 @@
 import json
 import random
 
-def warmup_model(model: nn.Module, input_ids: torch.Tensor, max_new_tokens: int, **padding_kwargs):
+def warmup_model(model: nn.Module, input_ids: torch.Tensor, max_new_tokens: int, compile_dynamic_sendnn = False, **padding_kwargs):
     from torch_sendnn import torch_sendnn
     dprint("AIU warmup")
     pt_compile_model_time = time.time()
     extra_kwargs = {**padding_kwargs, "only_last_token": True}
-    max_new_tokens_warmup = 2
-    is_dynamic_value = os.getenv("TORCH_SENDNN_DYNAMIC")
-    if is_dynamic_value is None or is_dynamic_value.lower() in {"0", "false"}:
-        max_new_tokens_warmup = max_new_tokens
+    max_new_tokens_warmup = max_new_tokens
+    if compile_dynamic_sendnn:
+        max_new_tokens_warmup = 2
     generate(model, input_ids, max_new_tokens=max_new_tokens_warmup, max_seq_len=model.config.max_expected_seq_len, use_cache=True, do_sample=False, contiguous_cache=True, extra_kwargs=extra_kwargs)
     pt_compile_model_time = time.time() - pt_compile_model_time
     dprint(f"PT compile complete, took {pt_compile_model_time:.3f}s")

scripts/inference.py

Lines changed: 7 additions & 2 deletions
@@ -129,6 +129,11 @@
     action="store_true",
     help="Use dynamic shapes with torch.compile",
 )
+parser.add_argument(
+    "--compile_dynamic_sendnn",
+    action="store_true",
+    help="Use dynamic shapes with aiu compile",
+)
 parser.add_argument(
     "--deterministic",
     action="store_true",
@@ -464,7 +469,7 @@ def select_int8_module(
 if args.compile:
     dprint("compiling model")
     if is_aiu_backend:
-        model.compile(backend="sendnn_decoder")
+        model.compile(backend="sendnn_decoder", options={'sendnn.dynamic': args.compile_dynamic_sendnn})
     else:
         # compiling can make first inference pass slow
         model.compile(mode=args.compile_mode, backend=args.compile_backend)
@@ -686,7 +691,7 @@ def infer(use_cache, do_sample, warmup):
 ] # True/False are identical with greedy iff `torch.use_deterministic_algorithms(True)`
 
 if args.compile:
-    warmup_model(model, ids, args.max_new_tokens, **extra_generation_kwargs)
+    warmup_model(model, ids, args.max_new_tokens, args.compile_dynamic_sendnn, **extra_generation_kwargs)
 
 if args.device_type == "aiu": # only run warmup for AIU, no need for senulator
     aiu_warmup_time = time.time()
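With this wiring, the flag flows from argparse through model.compile (as the sendnn.dynamic backend option) and into warmup_model. An illustrative invocation, assuming the --device_type and --compile flag names match the args attributes referenced in this diff, and omitting the other required arguments (model path, tokenizer, etc.), which vary by setup:

python scripts/inference.py --device_type aiu --compile --compile_dynamic_sendnn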

scripts/validation.py

Lines changed: 6 additions & 1 deletion
@@ -106,6 +106,11 @@
     action="store_true",
     help="Use dynamic shapes with torch.compile",
 )
+parser.add_argument(
+    "--compile_dynamic_sendnn",
+    action="store_true",
+    help="Use dynamic shapes with aiu compile",
+)
 parser.add_argument(
     "--deterministic",
     action="store_true",
@@ -680,7 +685,7 @@ def print_result(result, result_idx: int = 0, file_prefix: str = ""):
     **padding_kwargs
 )
 
-warmup_model(model, ids, args.max_new_tokens, **padding_kwargs)
+warmup_model(model, ids, args.max_new_tokens, args.compile_dynamic_sendnn, **padding_kwargs)
 
 ### AIU generation loop
 static_tokens = validation_info.get_info("tokens")
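scripts/validation.py gains the identical flag. Because action="store_true" defaults to False, args.compile_dynamic_sendnn stays False unless the flag is passed, so existing invocations of both scripts keep the full-length static-shape warmup. An illustrative run, again omitting the other required arguments:

python scripts/validation.py --compile --compile_dynamic_sendnn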
