7 | 7 | from pathlib import Path
8 | 8 | import random
9 | 9 | import time
| 10 | +import math
10 | 11 |
11 | 12 | # Third Party
12 | 13 | from aiu_fms_testing_utils.utils import aiu_setup

212 | 213 |     default=0,
213 | 214 |     help="Set verbosity level (pass flag as `-v`, `-vv`, `-vvv`)"
214 | 215 | )
| 216 | +parser.add_argument(
| 217 | +    "--stagger_load",
| 218 | +    type=int,
| 219 | +    default=0,
| 220 | +    help="Stagger model loading to avoid OOM issues on the host"
| 221 | +)
| 222 | +parser.add_argument(
| 223 | +    "--stagger_update_lazyhandle",
| 224 | +    type=int,
| 225 | +    default=0,
| 226 | +    help="Stagger update_lazyhandle to avoid OOM issues on the host"
| 227 | +)
215 | 228 | args = parser.parse_args()
216 | 229 |
217 | 230 | if args.quantization == "gptq":
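
Both new flags are disabled by their default of 0. A non-zero value bounds how many ranks run the guarded step at once: for example, `--stagger_load 8` would let the ranks call get_model in groups of at most 8, and `--stagger_update_lazyhandle` applies the same throttling to the torch_sendnn.update_lazyhandle() step (see the barrier loops added further down).
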
@@ -432,6 +445,13 @@ def select_int8_module(
432 | 445 | dprint(f"data_type={default_dtype}")
433 | 446 | dprint("="*60 + "\n")
434 | 447 |
| 448 | +if args.stagger_load > 0:
| 449 | +    for _set in range( math.ceil(world_size / float(args.stagger_load)) ):
| 450 | +        if rank < (_set+1)*args.stagger_load:
| 451 | +            break
| 452 | +        torch.distributed.barrier()
| 453 | +    dprint(f"Stagger Model Load: Begin (Set: {_set+1} of {math.ceil(world_size / float(args.stagger_load))})")
| 454 | +
435 | 455 | model = get_model(
436 | 456 |     args.architecture,
437 | 457 |     args.variant,

@@ -461,6 +481,13 @@ def select_int8_module(
461 | 481 | loading_model_time = time.time() - loading_model_time
462 | 482 | dprint(f"loading complete, took {loading_model_time:.3f}s")
463 | 483 |
| 484 | +if args.stagger_load > 0:
| 485 | +    for _set in range( math.ceil(world_size / float(args.stagger_load)) ):
| 486 | +        if rank >= (_set+1)*args.stagger_load:
| 487 | +            continue
| 488 | +        torch.distributed.barrier()
| 489 | +    dprint(f"Stagger Model Load: All Complete")
| 490 | +
464 | 491 | if args.compile:
465 | 492 |     dprint("compiling model")
466 | 493 |     if is_aiu_backend:
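
To see how the two loops above cooperate, here is a small standalone sketch (not part of the change; the helper name and example sizes are made up for illustration). Each rank waits at one barrier per set that loads before it and, once its own load has finished, at one barrier per remaining set, so at most --stagger_load ranks are inside get_model at any time:

import math

def stagger_schedule(world_size: int, stagger: int) -> None:
    # Count the barriers each rank would hit in the "begin" and "complete"
    # loops above, without needing an initialized torch.distributed group.
    num_sets = math.ceil(world_size / float(stagger))
    for rank in range(world_size):
        own_set = rank // stagger            # the set this rank loads in
        barriers_before = own_set            # "begin" loop: one per earlier set
        barriers_after = num_sets - own_set  # "complete" loop: own set + later sets
        print(f"rank {rank}: set {own_set + 1} of {num_sets}, "
              f"{barriers_before} barrier(s) before load, {barriers_after} after")

stagger_schedule(world_size=8, stagger=4)  # two sets of four ranks each

Every rank passes the same total number of barriers (one per set), which is what keeps the collective calls matched across ranks while still serializing the sets.
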
@@ -693,13 +720,27 @@ def infer(use_cache, do_sample, warmup):
693 | 720 |     pt_compile_model_time = time.time() - pt_compile_model_time
694 | 721 |     dprint(f"PT compile complete, took {pt_compile_model_time:.3f}s")
695 | 722 |
| 723 | +    if args.stagger_update_lazyhandle > 0:
| 724 | +        for _set in range( math.ceil(world_size / float(args.stagger_update_lazyhandle)) ):
| 725 | +            if rank < (_set+1)*args.stagger_update_lazyhandle:
| 726 | +                break
| 727 | +            torch.distributed.barrier()
| 728 | +        dprint(f"Stagger update_lazyhandle: Begin (Set: {_set+1} of {math.ceil(world_size / float(args.stagger_update_lazyhandle))})")
| 729 | +
696 | 730 |     if is_aiu_backend:
697 | 731 |         dprint("executing update_lazyhandle and compiling for AIU")
698 | 732 |         update_lh_time = time.time()
699 | 733 |         torch_sendnn.update_lazyhandle()
700 | 734 |         update_lh_time = time.time() - update_lh_time
701 | 735 |         dprint(f"update_lazyhandle complete, took {update_lh_time:.3f}s")
702 | 736 |
| 737 | +    if args.stagger_update_lazyhandle > 0:
| 738 | +        for _set in range( math.ceil(world_size / float(args.stagger_update_lazyhandle)) ):
| 739 | +            if rank >= (_set+1)*args.stagger_update_lazyhandle:
| 740 | +                continue
| 741 | +            torch.distributed.barrier()
| 742 | +        dprint(f"Stagger update_lazyhandle: All Complete")
| 743 | +
703 | 744 |     if args.device_type == "aiu": # only run warmup for AIU, no need for senulator
704 | 745 |         aiu_warmup_time = time.time()
705 | 746 |         for sample, cache in itertools.product(do_sample, use_cache):
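
The --stagger_update_lazyhandle guards above reuse the same set-and-barrier pattern as the model-load stagger, this time bracketing the torch_sendnn.update_lazyhandle() call, so the two flags can be tuned independently (for instance, loading in larger groups while running update_lazyhandle in smaller ones if that step is the host-memory bottleneck).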