|
9 | 9 | import time
|
10 | 10 |
|
11 | 11 | # Third Party
|
12 |
| -from aiu_fms_testing_utils.utils import aiu_setup |
| 12 | +from aiu_fms_testing_utils.utils import aiu_setup, warmup_model |
13 | 13 | from aiu_fms_testing_utils.utils.aiu_setup import dprint, rank, local_rank, world_size
|
14 | 14 | import numpy as np
|
15 | 15 | import torch
|
|
129 | 129 | action="store_true",
|
130 | 130 | help="Use dynamic shapes with torch.compile",
|
131 | 131 | )
|
| 132 | +parser.add_argument( |
| 133 | + "--compile_dynamic_sendnn", |
| 134 | + action="store_true", |
| 135 | + help="Use dynamic shapes with aiu compile", |
| 136 | +) |
132 | 137 | parser.add_argument(
|
133 | 138 | "--deterministic",
|
134 | 139 | action="store_true",
|
@@ -464,7 +469,7 @@ def select_int8_module(
|
464 | 469 | if args.compile:
|
465 | 470 | dprint("compiling model")
|
466 | 471 | if is_aiu_backend:
|
467 |
| - model.compile(backend="sendnn_decoder") |
| 472 | + model.compile(backend="sendnn_decoder", options={'sendnn.dynamic': args.compile_dynamic_sendnn}) |
468 | 473 | else:
|
469 | 474 | # compiling can make first inference pass slow
|
470 | 475 | model.compile(mode=args.compile_mode, backend=args.compile_backend)
|
@@ -686,19 +691,7 @@ def infer(use_cache, do_sample, warmup):
|
686 | 691 | ] # True/False are identical with greedy iff `torch.use_deterministic_algorithms(True)`
|
687 | 692 |
|
688 | 693 | if args.compile:
|
689 |
| - dprint(f"compilation warmup") |
690 |
| - pt_compile_model_time = time.time() |
691 |
| - for sample, cache in itertools.product(do_sample, use_cache): |
692 |
| - infer(cache, sample, True) |
693 |
| - pt_compile_model_time = time.time() - pt_compile_model_time |
694 |
| - dprint(f"PT compile complete, took {pt_compile_model_time:.3f}s") |
695 |
| - |
696 |
| - if is_aiu_backend: |
697 |
| - dprint("executing update_lazyhandle and compiling for AIU") |
698 |
| - update_lh_time = time.time() |
699 |
| - torch_sendnn.update_lazyhandle() |
700 |
| - update_lh_time = time.time() - update_lh_time |
701 |
| - dprint(f"update_lazyhandle complete, took {update_lh_time:.3f}s") |
| 694 | + warmup_model(model, ids, args.max_new_tokens, args.compile_dynamic_sendnn, **extra_generation_kwargs) |
702 | 695 |
|
703 | 696 | if args.device_type == "aiu": # only run warmup for AIU, no need for senulator
|
704 | 697 | aiu_warmup_time = time.time()
|
|
0 commit comments