Options for Stagger model loading for low memory systems #47

Open · wants to merge 2 commits into base: main
41 changes: 34 additions & 7 deletions aiu_fms_testing_utils/utils/__init__.py
@@ -3,25 +3,52 @@
import time
from fms.utils.tokenizers import BaseTokenizer
from fms.utils.generation import generate
from aiu_fms_testing_utils.utils.aiu_setup import dprint
from aiu_fms_testing_utils.utils.aiu_setup import dprint, rank, local_rank, world_size
from typing import Optional, List, Tuple
import os
import requests
import json
import random
import math
import contextlib

def warmup_model(model: nn.Module, input_ids: torch.Tensor, max_new_tokens: int, compile_dynamic_sendnn = False, **padding_kwargs):
@contextlib.contextmanager
def stagger_region(limit: int):
    """
    Limit the number of processes concurrently executing this region of code.
    A process yields from this context manager once it is allowed to enter the region,
    and it returns only after all processes have completed the region.

    :param limit: Number of concurrent processes allowed in the code region if > 0.
                  A value of 0 (or equal to world_size) disables staggering.
    """
    if limit > 0 and limit != world_size:
        for _set in range(math.ceil(world_size / float(limit))):
            if rank < (_set + 1) * limit:
                break
            torch.distributed.barrier()
        dprint(f"Stagger: Enter (Set: {_set + 1} of {math.ceil(world_size / float(limit))})")
    yield {}
    if limit > 0 and limit != world_size:
        for _set in range(math.ceil(world_size / float(limit))):
            if rank >= (_set + 1) * limit:
                continue
            torch.distributed.barrier()
        dprint("Stagger: All Complete")

def warmup_model(model: nn.Module, input_ids: torch.Tensor, max_new_tokens: int, compile_dynamic_sendnn = False, stagger_update_lazyhandle = 0, **padding_kwargs):
import torch_sendnn
dprint("AIU warmup")
pt_compile_model_time = time.time()
extra_kwargs = {**padding_kwargs, "only_last_token": True}
max_new_tokens_warmup = max_new_tokens
if compile_dynamic_sendnn:
max_new_tokens_warmup = 2
with torch_sendnn.warmup_mode():
generate(model, input_ids, max_new_tokens=max_new_tokens_warmup, max_seq_len=model.config.max_expected_seq_len, use_cache=True, do_sample=False, contiguous_cache=True, extra_kwargs=extra_kwargs)
pt_compile_model_time = time.time() - pt_compile_model_time
dprint(f"PT compile complete, took {pt_compile_model_time:.3f}s")

with stagger_region(stagger_update_lazyhandle) as _s:
pt_compile_model_time = time.time()
with torch_sendnn.warmup_mode():
generate(model, input_ids, max_new_tokens=max_new_tokens_warmup, max_seq_len=model.config.max_expected_seq_len, use_cache=True, do_sample=False, contiguous_cache=True, extra_kwargs=extra_kwargs)
pt_compile_model_time = time.time() - pt_compile_model_time
dprint(f"PT compile complete, took {pt_compile_model_time:.3f}s")

def ids_for_prompt(prompt, tokenizer):
tokens = tokenizer.tokenize(prompt)
85 changes: 56 additions & 29 deletions scripts/inference.py
@@ -1,5 +1,6 @@
# Standard
import argparse
import datetime
from functools import partial
import itertools
import json
@@ -8,9 +9,10 @@
import random
import time
import contextlib
import math

# Third Party
from aiu_fms_testing_utils.utils import aiu_setup, warmup_model
from aiu_fms_testing_utils.utils import aiu_setup, warmup_model, stagger_region
from aiu_fms_testing_utils.utils.aiu_setup import dprint, rank, local_rank, world_size
import numpy as np
import torch
@@ -218,6 +220,24 @@
    default=0,
    help="Set verbosity level (pass flag as `-v`, `-vv`, `-vvv`)"
)
parser.add_argument(
    "--stagger_load",
    type=int,
    default=0,
    help="Limit the number of processes concurrently executing the model loading phase. Set to 0 to allow all processes at once"
)
parser.add_argument(
    "--stagger_update_lazyhandle",
    type=int,
    default=0,
    help="Limit the number of processes concurrently executing the AIU update_lazyhandle phase. Set to 0 to allow all processes at once"
)
parser.add_argument(
    "--dist_timeout",
    type=int,
    default=0,
    help="Timeout, in minutes, for torch.distributed collectives. Set to 0 to use the PyTorch dist.init_process_group default"
)
args = parser.parse_args()

if args.quantization == "gptq":
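
Taken together, a hypothetical launch that staggers both new phases and extends the collective timeout could look like the following; the torchrun launcher, the process count, and every argument other than the three new flags are illustrative assumptions rather than part of this change:

torchrun --nproc_per_node=8 scripts/inference.py \
    --architecture=llama --variant=7b --tokenizer=/path/to/tokenizer \
    --device_type=aiu --distributed \
    --stagger_load=2 --stagger_update_lazyhandle=2 --dist_timeout=60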
@@ -260,7 +280,13 @@
is_aiu_backend = "aiu" in args.device_type

if args.distributed:
dist.init_process_group()
if args.dist_timeout > 0:
# Default timeout:
# https://docs.pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group
dist.init_process_group(timeout=datetime.timedelta(minutes=args.dist_timeout))
dprint(f"NOTICE: init_process_group timeout set to {args.dist_timeout} minutes")
else:
dist.init_process_group()
# Fix until PT 2.3
torch._C._distributed_c10d._register_process_group("default", dist.group.WORLD)
aiu_setup.aiu_dist_setup(dist.get_rank(), dist.get_world_size())
@@ -438,34 +464,35 @@ def select_int8_module(
dprint(f"data_type={default_dtype}")
dprint("="*60 + "\n")

model = get_model(
args.architecture,
args.variant,
model_path=args.model_path,
device_type="cpu" if is_aiu_backend else args.device_type,
data_type=default_dtype,
source=args.model_source,
distributed_strategy=distr_param,
group=dist.group.WORLD,
linear_config=linear_config,
fused_weights=fused_weights,
)

if args.quantization in ["gptq", "int8"]:
if rank == 0 and args.verbose > 0:
dprint("PARAMS:\n" + "\n".join(f"{k:60} {str(v.dtype):15} {str(v.device):10} {list(v.size())}" for k,v in model.named_parameters()))
dprint("BUFFERS:\n" + "\n".join(f"{k:60} {str(v.dtype):15} {str(v.device):10} {list(v.size())}" for k,v in model.named_buffers()))
with stagger_region(args.stagger_load) as _s:
model = get_model(
args.architecture,
args.variant,
model_path=args.model_path,
device_type="cpu" if is_aiu_backend else args.device_type,
data_type=default_dtype,
source=args.model_source,
distributed_strategy=distr_param,
group=dist.group.WORLD,
linear_config=linear_config,
fused_weights=fused_weights,
)

if args.quantization in ["gptq", "int8"]:
if rank == 0 and args.verbose > 0:
dprint("PARAMS:\n" + "\n".join(f"{k:60} {str(v.dtype):15} {str(v.device):10} {list(v.size())}" for k,v in model.named_parameters()))
dprint("BUFFERS:\n" + "\n".join(f"{k:60} {str(v.dtype):15} {str(v.device):10} {list(v.size())}" for k,v in model.named_buffers()))
dprint("="*60 + "\n")
if args.architecture == "llama":
dprint("[NOTE] In Llama models, it's OK for bias and rotary embeddings to be marked as unused keys.")
dprint(model)
dprint("="*60 + "\n")
if args.architecture == "llama":
dprint("[NOTE] In Llama models, it's OK for bias and rotary embeddings to be marked as unused keys.")
dprint(model)
dprint("="*60 + "\n")

tokenizer = tokenizers.get_tokenizer(args.tokenizer)
model.eval()
torch.set_grad_enabled(False)
loading_model_time = time.time() - loading_model_time
dprint(f"loading complete, took {loading_model_time:.3f}s")
tokenizer = tokenizers.get_tokenizer(args.tokenizer)
model.eval()
torch.set_grad_enabled(False)
loading_model_time = time.time() - loading_model_time
dprint(f"loading complete, took {loading_model_time:.3f}s")

if args.compile:
dprint("compiling model")
@@ -696,7 +723,7 @@ def infer(use_cache, do_sample, warmup):
dprint(f"compilation warmup")
pt_compile_model_time = time.time()
if args.device_type == "aiu": # only run warmup for AIU, no need for senulator
warmup_model(model, ids, args.max_new_tokens, args.compile_dynamic_sendnn, **extra_generation_kwargs)
warmup_model(model, ids, args.max_new_tokens, args.compile_dynamic_sendnn, args.stagger_update_lazyhandle, **extra_generation_kwargs)
aiu_warmup_time = time.time()
for sample, cache in itertools.product(do_sample, use_cache):
infer(cache, sample, True)