Commit 14dfb4a

Consolidate stagger code into two support functions: stagger_enter, stagger_leave
Signed-off-by: Joshua Hursey <jhursey@us.ibm.com>
1 parent 4bacd2c commit 14dfb4a

File tree

2 files changed: +21 −25 lines changed

aiu_fms_testing_utils/utils/__init__.py

Lines changed: 18 additions & 12 deletions
@@ -11,6 +11,22 @@
 import random
 import math
 
+def stagger_enter(limit: int):
+    if limit > 0 and limit != world_size:
+        for _set in range( math.ceil(world_size / float(limit)) ):
+            if rank < (_set+1)*limit:
+                break
+            torch.distributed.barrier()
+        dprint(f"Stagger: Enter (Set: {_set+1} of {math.ceil(world_size / float(limit))})")
+
+def stagger_leave(limit: int):
+    if limit > 0 and limit != world_size:
+        for _set in range( math.ceil(world_size / float(limit)) ):
+            if rank >= (_set+1)*limit:
+                continue
+            torch.distributed.barrier()
+        dprint(f"Stagger: All Complete")
+
 def warmup_model(model: nn.Module, input_ids: torch.Tensor, max_new_tokens: int, compile_dynamic_sendnn = False, stagger_update_lazyhandle = 0, **padding_kwargs):
     import torch_sendnn
     dprint("AIU warmup")
@@ -19,25 +35,15 @@ def warmup_model(model: nn.Module, input_ids: torch.Tensor, max_new_tokens: int,
     if compile_dynamic_sendnn:
         max_new_tokens_warmup = 2
 
-    if stagger_update_lazyhandle > 0 and stagger_update_lazyhandle != world_size:
-        for _set in range( math.ceil(world_size / float(stagger_update_lazyhandle)) ):
-            if rank < (_set+1)*stagger_update_lazyhandle:
-                break
-            torch.distributed.barrier()
-        dprint(f"Stagger update_lazyhandle: Begin (Set: {_set+1} of {math.ceil(world_size / float(stagger_update_lazyhandle))})")
+    stagger_enter(stagger_update_lazyhandle)
 
     pt_compile_model_time = time.time()
     with torch_sendnn.warmup_mode():
         generate(model, input_ids, max_new_tokens=max_new_tokens_warmup, max_seq_len=model.config.max_expected_seq_len, use_cache=True, do_sample=False, contiguous_cache=True, extra_kwargs=extra_kwargs)
     pt_compile_model_time = time.time() - pt_compile_model_time
     dprint(f"PT compile complete, took {pt_compile_model_time:.3f}s")
 
-    if stagger_update_lazyhandle > 0 and stagger_update_lazyhandle != world_size:
-        for _set in range( math.ceil(world_size / float(stagger_update_lazyhandle)) ):
-            if rank >= (_set+1)*stagger_update_lazyhandle:
-                continue
-            torch.distributed.barrier()
-        dprint(f"Stagger update_lazyhandle: All Complete")
+    stagger_leave(stagger_update_lazyhandle)
 
 def ids_for_prompt(prompt, tokenizer):
     tokens = tokenizer.tokenize(prompt)

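Together the two helpers implement one barrier schedule: the ranks are split into ceil(world_size / limit) sets of at most `limit` ranks, a rank in set s sits through s barriers in stagger_enter before starting, then joins the remaining barriers in stagger_leave on the way out, so barrier s releases set s+1 only after sets 0..s have left the bracketed region. The net effect is that at most `limit` ranks execute the bracketed work concurrently. A minimal runnable sketch of the pattern (assumptions not in the commit: gloo backend on CPU, rank/world_size passed explicitly rather than read from the aiu_setup globals, dprint reporting omitted):

import math
import os

import torch.distributed as dist
import torch.multiprocessing as mp

def stagger_enter(rank: int, world_size: int, limit: int):
    # A rank in set s waits on s barriers before entering, so set 0
    # starts immediately and set s starts only after s releases.
    if limit > 0 and limit != world_size:
        for _set in range(math.ceil(world_size / float(limit))):
            if rank < (_set + 1) * limit:
                break
            dist.barrier()

def stagger_leave(rank: int, world_size: int, limit: int):
    # A rank in set s joins every barrier from s onward; barrier s is the
    # same collective that releases set s+1 inside stagger_enter.
    if limit > 0 and limit != world_size:
        for _set in range(math.ceil(world_size / float(limit))):
            if rank >= (_set + 1) * limit:
                continue
            dist.barrier()

def worker(rank: int, world_size: int, limit: int):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    stagger_enter(rank, world_size, limit)
    print(f"rank {rank} inside the staggered region")  # at most `limit` ranks here at once
    stagger_leave(rank, world_size, limit)
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(worker, args=(4, 2), nprocs=4)  # world_size=4, 2 ranks per set

With world_size=4 and limit=2, every rank crosses exactly two barriers, and ranks 2-3 enter the region only after ranks 0-1 have left it.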
scripts/inference.py

Lines changed: 3 additions & 13 deletions
@@ -12,7 +12,7 @@
 import math
 
 # Third Party
-from aiu_fms_testing_utils.utils import aiu_setup, warmup_model
+from aiu_fms_testing_utils.utils import aiu_setup, warmup_model, stagger_enter, stagger_leave
 from aiu_fms_testing_utils.utils.aiu_setup import dprint, rank, local_rank, world_size
 import numpy as np
 import torch
@@ -464,12 +464,7 @@ def select_int8_module(
 dprint(f"data_type={default_dtype}")
 dprint("="*60 + "\n")
 
-if args.stagger_load > 0 and args.stagger_load != world_size:
-    for _set in range( math.ceil(world_size / float(args.stagger_load)) ):
-        if rank < (_set+1)*args.stagger_load:
-            break
-        torch.distributed.barrier()
-    dprint(f"Stagger Model Load: Begin (Set: {_set+1} of {math.ceil(world_size / float(args.stagger_load))})")
+stagger_enter(args.stagger_load)
 
 model = get_model(
     args.architecture,
@@ -500,12 +495,7 @@ def select_int8_module(
 loading_model_time = time.time() - loading_model_time
 dprint(f"loading complete, took {loading_model_time:.3f}s")
 
-if args.stagger_load > 0 and args.stagger_load != world_size:
-    for _set in range( math.ceil(world_size / float(args.stagger_load)) ):
-        if rank >= (_set+1)*args.stagger_load:
-            continue
-        torch.distributed.barrier()
-    dprint(f"Stagger Model Load: All Complete")
+stagger_leave(args.stagger_load)
 
 if args.compile:
     dprint("compiling model")

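After the refactor, every staggered critical section reads as a symmetric bracket around the protected work. A hypothetical new call site following the same shape as the model load above (load_checkpoint_shard is made up for illustration; args.stagger_load caps concurrency, and a value of 0 or world_size leaves the bracket as a no-op):

stagger_enter(args.stagger_load)
load_checkpoint_shard()  # hypothetical resource-heavy per-rank step
stagger_leave(args.stagger_load)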