Commit f60d30f

Convert the stagger enter/leave functions into a proper contextlib context manager
Signed-off-by: Joshua Hursey <jhursey@us.ibm.com>
1 parent 130a407 commit f60d30f
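
In outline, the change collapses a pair of module-level functions into one generator-based context manager: code before the yield runs when the with-block is entered, code after the yield runs when it exits. A minimal sketch of the pattern (the guarded_region name and the prints below are illustrative, not code from this commit):

import contextlib

@contextlib.contextmanager
def guarded_region():
    print("enter")  # formerly the separate enter-style call
    yield           # the caller's with-block body runs here
    print("leave")  # formerly the separate leave-style call

# call sites replace paired enter()/leave() calls with a single with-block:
with guarded_region():
    print("inside the region")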

File tree

  aiu_fms_testing_utils/utils/__init__.py
  scripts/inference.py

2 files changed: +43 -55 lines

aiu_fms_testing_utils/utils/__init__.py
Lines changed: 14 additions & 23 deletions

@@ -10,30 +10,24 @@
 import json
 import random
 import math
+import contextlib

-def stagger_enter(limit: int):
+@contextlib.contextmanager
+def stagger_region(limit: int):
     """
-    Limit the number of concurrent processes into this section of code.
-    Processes return from this funciton when they are allowed to enter the section of code.
-    Must be paired with another call to stagger_leave() when exiting the section of code.
+    Limit the number of concurrent processes into this region of code.
+    Processes yield from this function when they are allowed to enter the region of code.
+    Processes return from this function when all of the processes have completed the region of code.

-    :param limit: Number of concurrent processes allowed in the code section if > 0.
+    :param limit: Number of concurrent processes allowed in the code region if > 0.
     """
     if limit > 0 and limit != world_size:
         for _set in range( math.ceil(world_size / float(limit)) ):
             if rank < (_set+1)*limit:
                 break
             torch.distributed.barrier()
         dprint(f"Stagger: Enter (Set: {_set+1} of {math.ceil(world_size / float(limit))})")
-
-def stagger_leave(limit: int):
-    """
-    Leave a section code where the number of concurrent processes is limited.
-    Processes return from this funciton when all of the processes have completed the section of code.
-    Must be paired with another call to stagger_enter() when entering the section of code.
-
-    :param limit: Number of concurrent processes allowed in the code section if > 0.
-    """
+    yield {}
     if limit > 0 and limit != world_size:
         for _set in range( math.ceil(world_size / float(limit)) ):
             if rank >= (_set+1)*limit:

@@ -49,15 +43,12 @@ def warmup_model(model: nn.Module, input_ids: torch.Tensor, max_new_tokens: int,
     if compile_dynamic_sendnn:
         max_new_tokens_warmup = 2

-    stagger_enter(stagger_update_lazyhandle)
-
-    pt_compile_model_time = time.time()
-    with torch_sendnn.warmup_mode():
-        generate(model, input_ids, max_new_tokens=max_new_tokens_warmup, max_seq_len=model.config.max_expected_seq_len, use_cache=True, do_sample=False, contiguous_cache=True, extra_kwargs=extra_kwargs)
-    pt_compile_model_time = time.time() - pt_compile_model_time
-    dprint(f"PT compile complete, took {pt_compile_model_time:.3f}s")
-
-    stagger_leave(stagger_update_lazyhandle)
+    with stagger_region(stagger_update_lazyhandle) as _s:
+        pt_compile_model_time = time.time()
+        with torch_sendnn.warmup_mode():
+            generate(model, input_ids, max_new_tokens=max_new_tokens_warmup, max_seq_len=model.config.max_expected_seq_len, use_cache=True, do_sample=False, contiguous_cache=True, extra_kwargs=extra_kwargs)
+        pt_compile_model_time = time.time() - pt_compile_model_time
+        dprint(f"PT compile complete, took {pt_compile_model_time:.3f}s")

 def ids_for_prompt(prompt, tokenizer):
     tokens = tokenizer.tokenize(prompt)
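
For reference, the rank-to-set assignment that stagger_region performs can be checked in isolation. The sketch below reproduces the arithmetic from the function above without torch.distributed; the simulate helper, its prints, and the world_size=4 / limit=2 example values are illustrative additions, not part of the patch.

import math

def simulate(world_size: int, limit: int):
    # mirror the loop in stagger_region: a rank belongs to the first set
    # whose upper bound (_set + 1) * limit exceeds its rank index
    num_sets = math.ceil(world_size / float(limit))
    for rank in range(world_size):
        for _set in range(num_sets):
            if rank < (_set + 1) * limit:
                break
        print(f"rank {rank}: enters in set {_set + 1} of {num_sets}")

simulate(world_size=4, limit=2)
# rank 0: enters in set 1 of 2
# rank 1: enters in set 1 of 2
# rank 2: enters in set 2 of 2
# rank 3: enters in set 2 of 2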

scripts/inference.py
Lines changed: 29 additions & 32 deletions

@@ -12,7 +12,7 @@
 import math

 # Third Party
-from aiu_fms_testing_utils.utils import aiu_setup, warmup_model, stagger_enter, stagger_leave
+from aiu_fms_testing_utils.utils import aiu_setup, warmup_model, stagger_region
 from aiu_fms_testing_utils.utils.aiu_setup import dprint, rank, local_rank, world_size
 import numpy as np
 import torch

@@ -464,38 +464,35 @@ def select_int8_module(
     dprint(f"data_type={default_dtype}")
     dprint("="*60 + "\n")

-stagger_enter(args.stagger_load)
-
-model = get_model(
-    args.architecture,
-    args.variant,
-    model_path=args.model_path,
-    device_type="cpu" if is_aiu_backend else args.device_type,
-    data_type=default_dtype,
-    source=args.model_source,
-    distributed_strategy=distr_param,
-    group=dist.group.WORLD,
-    linear_config=linear_config,
-    fused_weights=fused_weights,
-)
-
-if args.quantization in ["gptq", "int8"]:
-    if rank == 0 and args.verbose > 0:
-        dprint("PARAMS:\n" + "\n".join(f"{k:60} {str(v.dtype):15} {str(v.device):10} {list(v.size())}" for k,v in model.named_parameters()))
-        dprint("BUFFERS:\n" + "\n".join(f"{k:60} {str(v.dtype):15} {str(v.device):10} {list(v.size())}" for k,v in model.named_buffers()))
+with stagger_region(args.stagger_load) as _s:
+    model = get_model(
+        args.architecture,
+        args.variant,
+        model_path=args.model_path,
+        device_type="cpu" if is_aiu_backend else args.device_type,
+        data_type=default_dtype,
+        source=args.model_source,
+        distributed_strategy=distr_param,
+        group=dist.group.WORLD,
+        linear_config=linear_config,
+        fused_weights=fused_weights,
+    )
+
+    if args.quantization in ["gptq", "int8"]:
+        if rank == 0 and args.verbose > 0:
+            dprint("PARAMS:\n" + "\n".join(f"{k:60} {str(v.dtype):15} {str(v.device):10} {list(v.size())}" for k,v in model.named_parameters()))
+            dprint("BUFFERS:\n" + "\n".join(f"{k:60} {str(v.dtype):15} {str(v.device):10} {list(v.size())}" for k,v in model.named_buffers()))
+            dprint("="*60 + "\n")
+        if args.architecture == "llama":
+            dprint("[NOTE] In Llama models, it's OK for bias and rotary embeddings to be marked as unused keys.")
+        dprint(model)
         dprint("="*60 + "\n")
-    if args.architecture == "llama":
-        dprint("[NOTE] In Llama models, it's OK for bias and rotary embeddings to be marked as unused keys.")
-    dprint(model)
-    dprint("="*60 + "\n")
-
-tokenizer = tokenizers.get_tokenizer(args.tokenizer)
-model.eval()
-torch.set_grad_enabled(False)
-loading_model_time = time.time() - loading_model_time
-dprint(f"loading complete, took {loading_model_time:.3f}s")
-
-stagger_leave(args.stagger_load)
+
+    tokenizer = tokenizers.get_tokenizer(args.tokenizer)
+    model.eval()
+    torch.set_grad_enabled(False)
+    loading_model_time = time.time() - loading_model_time
+    dprint(f"loading complete, took {loading_model_time:.3f}s")

 if args.compile:
     dprint("compiling model")
