Commit 1af2202

Options to stagger model loading for low-memory systems
* `--stagger_load` : (default: `0`, off) Stagger model loading to avoid OOM issues on the host
* `--stagger_update_lazyhandle` : (default: `0`, off) Stagger update_lazyhandle to avoid OOM issues on the host

Signed-off-by: Joshua Hursey <jhursey@us.ibm.com>
1 parent 73bcb07 commit 1af2202
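
Both flags take an integer: the number of ranks allowed to run the staggered step concurrently, with `0` (the default) disabling the staggering entirely. As a quick illustration (example values, not defaults), a 16-rank job run with `--stagger_load 4` loads the model in ceil(16/4) == 4 consecutive sets of 4 ranks:

# Illustration of the flag semantics only; 16 and 4 are example values.
import math

world_size, stagger_load = 16, 4
print(math.ceil(world_size / float(stagger_load)))  # -> 4 sets of ranks load in turn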

1 file changed (+41, -0)

scripts/inference.py

Lines changed: 41 additions & 0 deletions
@@ -7,6 +7,7 @@
 from pathlib import Path
 import random
 import time
+import math
 
 # Third Party
 from aiu_fms_testing_utils.utils import aiu_setup
@@ -212,6 +213,18 @@
     default=0,
     help="Set verbosity level (pass flag as `-v`, `-vv`, `-vvv`)"
 )
+parser.add_argument(
+    "--stagger_load",
+    type=int,
+    default=0,
+    help="Stagger model loading to avoid OOM issues on the host"
+)
+parser.add_argument(
+    "--stagger_update_lazyhandle",
+    type=int,
+    default=0,
+    help="Stagger update_lazyhandle to avoid OOM issues on the host"
+)
 args = parser.parse_args()
 
 if args.quantization == "gptq":
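
The two options are plain integer arguments that default to off; a minimal parsing check outside the full script (illustration only, the real parser defines many more arguments) behaves as expected:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--stagger_load", type=int, default=0,
                    help="Stagger model loading to avoid OOM issues on the host")
parser.add_argument("--stagger_update_lazyhandle", type=int, default=0,
                    help="Stagger update_lazyhandle to avoid OOM issues on the host")

args = parser.parse_args(["--stagger_load", "4"])
assert args.stagger_load == 4               # four ranks load the model at a time
assert args.stagger_update_lazyhandle == 0  # default keeps staggering disabled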
@@ -432,6 +445,13 @@ def select_int8_module(
 dprint(f"data_type={default_dtype}")
 dprint("="*60 + "\n")
 
+if args.stagger_load > 0:
+    for _set in range( math.ceil(world_size / float(args.stagger_load)) ):
+        if rank < (_set+1)*args.stagger_load:
+            break
+        torch.distributed.barrier()
+    dprint(f"Stagger Model Load: Begin (Set: {_set+1} of {math.ceil(world_size / float(args.stagger_load))})")
+
 model = get_model(
     args.architecture,
     args.variant,
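
In this entry gate, a rank calls `torch.distributed.barrier()` once per set scheduled ahead of its own and breaks out as soon as its own set is reached, so set 1 begins loading immediately and set k begins only after k-1 barriers have released. A standalone trace of which set each rank joins (illustration only, mirroring the arithmetic above):

import math

def stagger_set(rank: int, stagger: int) -> int:
    # Ranks 0..stagger-1 form set 1, the next `stagger` ranks form set 2, and so on.
    return rank // stagger + 1

world_size, stagger = 8, 3  # example values
for r in range(world_size):
    print(f"rank {r}: set {stagger_set(r, stagger)} of {math.ceil(world_size / stagger)}")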
@@ -461,6 +481,13 @@ def select_int8_module(
 loading_model_time = time.time() - loading_model_time
 dprint(f"loading complete, took {loading_model_time:.3f}s")
 
+if args.stagger_load > 0:
+    for _set in range( math.ceil(world_size / float(args.stagger_load)) ):
+        if rank >= (_set+1)*args.stagger_load:
+            continue
+        torch.distributed.barrier()
+    dprint(f"Stagger Model Load: All Complete")
+
 if args.compile:
     dprint("compiling model")
     if is_aiu_backend:
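
The release gate mirrors the entry gate: a rank skips (`continue`) the barriers belonging to sets ahead of it and participates in the rest, so every rank calls `torch.distributed.barrier()` exactly ceil(world_size / stagger_load) times in total and no collective is left short of participants. A small counting sketch (illustration only, no process group required):

import math

def barrier_counts(world_size: int, stagger: int) -> None:
    # Entry-gate barriers come from the break-style loop, release-gate barriers
    # from the continue-style loop; their sum is the same for every rank.
    n_sets = math.ceil(world_size / float(stagger))
    for rank in range(world_size):
        entry = sum(1 for s in range(n_sets) if rank >= (s + 1) * stagger)
        release = sum(1 for s in range(n_sets) if rank < (s + 1) * stagger)
        print(f"rank {rank}: entry={entry}, release={release}, total={entry + release}")

barrier_counts(world_size=8, stagger=3)  # every total equals ceil(8/3) == 3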
@@ -693,13 +720,27 @@ def infer(use_cache, do_sample, warmup):
 pt_compile_model_time = time.time() - pt_compile_model_time
 dprint(f"PT compile complete, took {pt_compile_model_time:.3f}s")
 
+if args.stagger_update_lazyhandle > 0:
+    for _set in range( math.ceil(world_size / float(args.stagger_update_lazyhandle)) ):
+        if rank < (_set+1)*args.stagger_update_lazyhandle:
+            break
+        torch.distributed.barrier()
+    dprint(f"Stagger update_lazyhandle: Begin (Set: {_set+1} of {math.ceil(world_size / float(args.stagger_update_lazyhandle))})")
+
 if is_aiu_backend:
     dprint("executing update_lazyhandle and compiling for AIU")
     update_lh_time = time.time()
     torch_sendnn.update_lazyhandle()
     update_lh_time = time.time() - update_lh_time
     dprint(f"update_lazyhandle complete, took {update_lh_time:.3f}s")
 
+if args.stagger_update_lazyhandle > 0:
+    for _set in range( math.ceil(world_size / float(args.stagger_update_lazyhandle)) ):
+        if rank >= (_set+1)*args.stagger_update_lazyhandle:
+            continue
+        torch.distributed.barrier()
+    dprint(f"Stagger update_lazyhandle: All Complete")
+
 if args.device_type == "aiu": # only run warmup for AIU, no need for senulator
     aiu_warmup_time = time.time()
     for sample, cache in itertools.product(do_sample, use_cache):
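
The `update_lazyhandle` gates repeat the same entry/release shape used for model loading. A hypothetical, repository-external generalization of that pattern (the helper name `run_staggered` is not part of this commit, and an initialized `torch.distributed` process group is assumed) could look like:

import math
import torch.distributed as dist

def run_staggered(fn, rank: int, world_size: int, stagger: int):
    # Run `fn` in groups of `stagger` ranks; later groups wait at barriers
    # until earlier groups have finished. stagger <= 0 disables staggering.
    if stagger <= 0:
        return fn()
    n_sets = math.ceil(world_size / float(stagger))
    for _set in range(n_sets):          # entry gate: wait for earlier sets
        if rank < (_set + 1) * stagger:
            break
        dist.barrier()
    result = fn()
    for _set in range(n_sets):          # release gate: let later sets proceed
        if rank >= (_set + 1) * stagger:
            continue
        dist.barrier()
    return result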
