Commit c5218b1

Options for Stagger model loading for low memory systems
* `--stagger_load` : (default: `0` off) Stagger model loading to avoid OOM issues on the host
* `--stagger_update_lazyhandle` : (default: `0` off) Stagger update_lazyhandle to avoid OOM issues on the host

Signed-off-by: Joshua Hursey <jhursey@us.ibm.com>
1 parent b882a64 commit c5218b1
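
Both options take an integer count of ranks per set. When the value is greater than 0 and different from the world size, the ranks are split into ceil(world_size / N) sets, and each set only begins the expensive step (model load or update_lazyhandle) once the previous set has finished, trading startup time for lower peak host memory. A quick worked example of the set arithmetic (plain Python; the values are chosen purely for illustration):

    import math

    world_size = 8    # assumed 8 ranks, purely for illustration
    stagger_load = 2  # i.e. the value passed to --stagger_load

    num_sets = math.ceil(world_size / float(stagger_load))  # -> 4 sets of 2 ranks
    for rank in range(world_size):
        print(f"rank {rank} loads the model in set {rank // stagger_load + 1} of {num_sets}")

With these numbers, ranks 0-1 load first, then 2-3, and so on, matching the `Set: X of Y` messages printed by the patch.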

File tree

2 files changed (+45, −3 lines)


aiu_fms_testing_utils/utils/__init__.py

Lines changed: 17 additions & 2 deletions
@@ -3,14 +3,15 @@
 import time
 from fms.utils.tokenizers import BaseTokenizer
 from fms.utils.generation import generate
-from aiu_fms_testing_utils.utils.aiu_setup import dprint
+from aiu_fms_testing_utils.utils.aiu_setup import dprint, rank, local_rank, world_size
 from typing import Optional, List, Tuple
 import os
 import requests
 import json
 import random
+import math

-def warmup_model(model: nn.Module, input_ids: torch.Tensor, max_new_tokens: int, compile_dynamic_sendnn = False, **padding_kwargs):
+def warmup_model(model: nn.Module, input_ids: torch.Tensor, max_new_tokens: int, compile_dynamic_sendnn = False, stagger_update_lazyhandle = 0, **padding_kwargs):
     from torch_sendnn import torch_sendnn
     dprint("AIU warmup")
     pt_compile_model_time = time.time()
@@ -22,12 +23,26 @@ def warmup_model(model: nn.Module, input_ids: torch.Tensor, max_new_tokens: int,
     pt_compile_model_time = time.time() - pt_compile_model_time
     dprint(f"PT compile complete, took {pt_compile_model_time:.3f}s")

+    if stagger_update_lazyhandle > 0 and stagger_update_lazyhandle != world_size:
+        for _set in range( math.ceil(world_size / float(stagger_update_lazyhandle)) ):
+            if rank < (_set+1)*stagger_update_lazyhandle:
+                break
+            torch.distributed.barrier()
+        dprint(f"Stagger update_lazyhandle: Begin (Set: {_set+1} of {math.ceil(world_size / float(stagger_update_lazyhandle))})")
+
     dprint("executing update_lazyhandle and performing validation")
     update_lh_time = time.time()
     torch_sendnn.update_lazyhandle()
     update_lh_time = time.time() - update_lh_time
     dprint(f"update_lazyhandle complete, took {update_lh_time:.3f}s")

+    if stagger_update_lazyhandle > 0 and stagger_update_lazyhandle != world_size:
+        for _set in range( math.ceil(world_size / float(stagger_update_lazyhandle)) ):
+            if rank >= (_set+1)*stagger_update_lazyhandle:
+                continue
+            torch.distributed.barrier()
+        dprint(f"Stagger update_lazyhandle: All Complete")
+
 def ids_for_prompt(prompt, tokenizer):
     tokens = tokenizer.tokenize(prompt)
     ids = tokenizer.convert_tokens_to_ids(tokens)
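
The staggering in `warmup_model` is a pair of loops around the expensive call: before it, a rank waits on one `torch.distributed.barrier()` per set scheduled ahead of its own; after it, the rank joins the barrier for its own set and for every later set, so each barrier completes only when the preceding sets are done and the next set may start. As a non-authoritative sketch, the same pattern could be factored into a context manager (the name `stagger_region` is invented here; `torch.distributed` is assumed to be initialized):

    import math
    from contextlib import contextmanager

    import torch.distributed as dist


    @contextmanager
    def stagger_region(rank: int, world_size: int, per_set: int):
        """Run the enclosed block in sets of at most `per_set` ranks, one set at a time."""
        active = per_set > 0 and per_set != world_size
        if active:
            num_sets = math.ceil(world_size / float(per_set))
            my_set = rank // per_set  # index of the set this rank belongs to
            # wait behind one barrier per set scheduled before ours
            for _ in range(my_set):
                dist.barrier()
        yield
        if active:
            # join the barrier for our own set and every later set,
            # releasing the next set to start its work
            for _ in range(my_set, num_sets):
                dist.barrier()

A hypothetical call site would read `with stagger_region(rank, world_size, args.stagger_load): model = get_model(...)`; the commit itself keeps the two loops inline at each call site.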

scripts/inference.py

Lines changed: 28 additions & 1 deletion
@@ -7,6 +7,7 @@
 from pathlib import Path
 import random
 import time
+import math

 # Third Party
 from aiu_fms_testing_utils.utils import aiu_setup, warmup_model
@@ -217,6 +218,18 @@
     default=0,
     help="Set verbosity level (pass flag as `-v`, `-vv`, `-vvv`)"
 )
+parser.add_argument(
+    "--stagger_load",
+    type=int,
+    default=0,
+    help="Stagger model loading to avoid OOM issues on the host"
+)
+parser.add_argument(
+    "--stagger_update_lazyhandle",
+    type=int,
+    default=0,
+    help="Stagger update_lazyhandle to avoid OOM issues on the host"
+)
 args = parser.parse_args()

 if args.quantization == "gptq":
@@ -437,6 +450,13 @@ def select_int8_module(
 dprint(f"data_type={default_dtype}")
 dprint("="*60 + "\n")

+if args.stagger_load > 0 and args.stagger_load != world_size:
+    for _set in range( math.ceil(world_size / float(args.stagger_load)) ):
+        if rank < (_set+1)*args.stagger_load:
+            break
+        torch.distributed.barrier()
+    dprint(f"Stagger Model Load: Begin (Set: {_set+1} of {math.ceil(world_size / float(args.stagger_load))})")
+
 model = get_model(
     args.architecture,
     args.variant,
@@ -466,6 +486,13 @@ def select_int8_module(
 loading_model_time = time.time() - loading_model_time
 dprint(f"loading complete, took {loading_model_time:.3f}s")

+if args.stagger_load > 0 and args.stagger_load != world_size:
+    for _set in range( math.ceil(world_size / float(args.stagger_load)) ):
+        if rank >= (_set+1)*args.stagger_load:
+            continue
+        torch.distributed.barrier()
+    dprint(f"Stagger Model Load: All Complete")
+
 if args.compile:
     dprint("compiling model")
     if is_aiu_backend:
@@ -691,7 +718,7 @@ def infer(use_cache, do_sample, warmup):
 ] # True/False are identical with greedy iff `torch.use_deterministic_algorithms(True)`

 if args.compile:
-    warmup_model(model, ids, args.max_new_tokens, args.compile_dynamic_sendnn, **extra_generation_kwargs)
+    warmup_model(model, ids, args.max_new_tokens, args.compile_dynamic_sendnn, args.stagger_update_lazyhandle, **extra_generation_kwargs)

 if args.device_type == "aiu": # only run warmup for AIU, no need for senulator
     aiu_warmup_time = time.time()
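
In the last hunk, `args.stagger_update_lazyhandle` is passed positionally; this works because the new parameter sits just before `**padding_kwargs` in the updated `warmup_model` signature. The same call with each argument annotated (comments added here for readability only; the values are those already used in the script):

    if args.compile:
        warmup_model(
            model,                           # the compiled model to warm up
            ids,                             # prepared input_ids tensor
            args.max_new_tokens,
            args.compile_dynamic_sendnn,     # compile_dynamic_sendnn flag
            args.stagger_update_lazyhandle,  # new: ranks per set; 0 (default) disables staggering
            **extra_generation_kwargs,
        )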
