From f05ee448391cde57cb925158f788c3e0c9014cf2 Mon Sep 17 00:00:00 2001
From: Joshua Hursey
Date: Tue, 6 May 2025 15:08:20 -0400
Subject: [PATCH 1/2] Options to stagger model loading for low-memory systems

* `--stagger_load` : (default: `0` off) Stagger model loading to avoid OOM issues on the host
* `--stagger_update_lazyhandle` : (default: `0` off) Stagger update_lazyhandle to avoid OOM issues on the host
* `--dist_timeout` : (default: either `10` for NCCL or `30` for others, set by PyTorch) torch distributed timeout in minutes

Signed-off-by: Joshua Hursey
---
 aiu_fms_testing_utils/utils/__init__.py | 42 +++++++++++++++++++++++--
 scripts/inference.py                    | 36 +++++++++++++++++++--
 2 files changed, 72 insertions(+), 6 deletions(-)
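Notes (for review only, not part of the commit message): with `--stagger_load N`
or `--stagger_update_lazyhandle N`, ranks are grouped into sets of at most N
processes, and a set is only admitted to the protected phase after every earlier
set has finished it, so at most N ranks hold the extra host memory at any time.
A rough sketch of how the sets are formed (the helper below is illustrative
only, it is not code from this patch):

    import math

    def stagger_sets(world_size: int, limit: int):
        # Rank r belongs to set r // limit; sets are admitted one after another.
        num_sets = math.ceil(world_size / float(limit))
        return [[r for r in range(world_size) if r // limit == s] for s in range(num_sets)]

    # e.g. an 8-rank job run with --stagger_load 2 loads the model in 4 waves:
    print(stagger_sets(8, 2))   # [[0, 1], [2, 3], [4, 5], [6, 7]]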
+ """ + if limit > 0 and limit != world_size: + for _set in range( math.ceil(world_size / float(limit)) ): + if rank >= (_set+1)*limit: + continue + torch.distributed.barrier() + dprint(f"Stagger: All Complete") + +def warmup_model(model: nn.Module, input_ids: torch.Tensor, max_new_tokens: int, compile_dynamic_sendnn = False, stagger_update_lazyhandle = 0, **padding_kwargs): import torch_sendnn dprint("AIU warmup") - pt_compile_model_time = time.time() extra_kwargs = {**padding_kwargs, "only_last_token": True} max_new_tokens_warmup = max_new_tokens if compile_dynamic_sendnn: max_new_tokens_warmup = 2 + + stagger_enter(stagger_update_lazyhandle) + + pt_compile_model_time = time.time() with torch_sendnn.warmup_mode(): generate(model, input_ids, max_new_tokens=max_new_tokens_warmup, max_seq_len=model.config.max_expected_seq_len, use_cache=True, do_sample=False, contiguous_cache=True, extra_kwargs=extra_kwargs) pt_compile_model_time = time.time() - pt_compile_model_time dprint(f"PT compile complete, took {pt_compile_model_time:.3f}s") + stagger_leave(stagger_update_lazyhandle) + def ids_for_prompt(prompt, tokenizer): tokens = tokenizer.tokenize(prompt) ids = tokenizer.convert_tokens_to_ids(tokens) diff --git a/scripts/inference.py b/scripts/inference.py index 9316b8b..77bd7b7 100644 --- a/scripts/inference.py +++ b/scripts/inference.py @@ -1,5 +1,6 @@ # Standard import argparse +import datetime from functools import partial import itertools import json @@ -8,9 +9,10 @@ import random import time import contextlib +import math # Third Party -from aiu_fms_testing_utils.utils import aiu_setup, warmup_model +from aiu_fms_testing_utils.utils import aiu_setup, warmup_model, stagger_enter, stagger_leave from aiu_fms_testing_utils.utils.aiu_setup import dprint, rank, local_rank, world_size import numpy as np import torch @@ -218,6 +220,24 @@ default=0, help="Set verbosity level (pass flag as `-v`, `-vv`, `-vvv`)" ) +parser.add_argument( + "--stagger_load", + type=int, + default=0, + help="Limit the number of concurrent processes executing the model loading phase. Set to 0 to allow all processes" +) +parser.add_argument( + "--stagger_update_lazyhandle", + type=int, + default=0, + help="Limit the number of concurrent processes executing the AIU update_lazyhandle phase. Set to 0 to allow all processes" +) +parser.add_argument( + "--dist_timeout", + type=int, + default=0, + help="Timeout to use for messaging in minutes. 
+    help="Timeout to use for messaging in minutes. Default set by PyTorch dist.init_process_group"
+)
 args = parser.parse_args()

 if args.quantization == "gptq":
@@ -260,7 +280,13 @@
 is_aiu_backend = "aiu" in args.device_type

 if args.distributed:
-    dist.init_process_group()
+    if args.dist_timeout > 0:
+        # Default timeout:
+        #   https://docs.pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group
+        dist.init_process_group(timeout=datetime.timedelta(minutes=args.dist_timeout))
+        dprint(f"NOTICE: init_process_group timeout set to {args.dist_timeout} minutes")
+    else:
+        dist.init_process_group()
     # Fix until PT 2.3
     torch._C._distributed_c10d._register_process_group("default", dist.group.WORLD)
     aiu_setup.aiu_dist_setup(dist.get_rank(), dist.get_world_size())
@@ -438,6 +464,8 @@ def select_int8_module(
 dprint(f"data_type={default_dtype}")
 dprint("="*60 + "\n")

+stagger_enter(args.stagger_load)
+
 model = get_model(
     args.architecture,
     args.variant,
@@ -467,6 +495,8 @@ def select_int8_module(
 loading_model_time = time.time() - loading_model_time
 dprint(f"loading complete, took {loading_model_time:.3f}s")

+stagger_leave(args.stagger_load)
+
 if args.compile:
     dprint("compiling model")
     if is_aiu_backend:
@@ -696,7 +726,7 @@ def infer(use_cache, do_sample, warmup):
 dprint(f"compilation warmup")
 pt_compile_model_time = time.time()
 if args.device_type == "aiu": # only run warmup for AIU, no need for senulator
-    warmup_model(model, ids, args.max_new_tokens, args.compile_dynamic_sendnn, **extra_generation_kwargs)
+    warmup_model(model, ids, args.max_new_tokens, args.compile_dynamic_sendnn, args.stagger_update_lazyhandle, **extra_generation_kwargs)
 aiu_warmup_time = time.time()
 for sample, cache in itertools.product(do_sample, use_cache):
     infer(cache, sample, True)

From b7c22e088a88a776930ce04a4562ee0edde27cdd Mon Sep 17 00:00:00 2001
From: Joshua Hursey
Date: Thu, 12 Jun 2025 16:43:33 -0400
Subject: [PATCH 2/2] Convert the stagger enter/leave into a proper contextlib function

Signed-off-by: Joshua Hursey
---
 aiu_fms_testing_utils/utils/__init__.py | 37 ++++++---------
 scripts/inference.py                    | 61 ++++++++++++-------------
 2 files changed, 43 insertions(+), 55 deletions(-)
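Notes (for review only, not part of the commit message): the context manager
keeps the same barrier scheme but ties the enter and leave halves together, so
a call site can no longer forget the matching leave call. Roughly, call sites
change as sketched below (illustrative only; `...` stands in for the real
arguments used in inference.py):

    # Before (patch 1): explicit pairing at every call site
    stagger_enter(args.stagger_load)
    model = get_model(...)            # memory-heavy work
    stagger_leave(args.stagger_load)

    # After (this patch): a single with-block; the generator yields once the
    # caller's set is admitted and runs the trailing barriers on exit
    with stagger_region(args.stagger_load):
        model = get_model(...)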
""" if limit > 0 and limit != world_size: for _set in range( math.ceil(world_size / float(limit)) ): @@ -25,15 +27,7 @@ def stagger_enter(limit: int): break torch.distributed.barrier() dprint(f"Stagger: Enter (Set: {_set+1} of {math.ceil(world_size / float(limit))})") - -def stagger_leave(limit: int): - """ - Leave a section code where the number of concurrent processes is limited. - Processes return from this funciton when all of the processes have completed the section of code. - Must be paired with another call to stagger_enter() when entering the section of code. - - :param limit: Number of concurrent processes allowed in the code section if > 0. - """ + yield {} if limit > 0 and limit != world_size: for _set in range( math.ceil(world_size / float(limit)) ): if rank >= (_set+1)*limit: @@ -49,15 +43,12 @@ def warmup_model(model: nn.Module, input_ids: torch.Tensor, max_new_tokens: int, if compile_dynamic_sendnn: max_new_tokens_warmup = 2 - stagger_enter(stagger_update_lazyhandle) - - pt_compile_model_time = time.time() - with torch_sendnn.warmup_mode(): - generate(model, input_ids, max_new_tokens=max_new_tokens_warmup, max_seq_len=model.config.max_expected_seq_len, use_cache=True, do_sample=False, contiguous_cache=True, extra_kwargs=extra_kwargs) - pt_compile_model_time = time.time() - pt_compile_model_time - dprint(f"PT compile complete, took {pt_compile_model_time:.3f}s") - - stagger_leave(stagger_update_lazyhandle) + with stagger_region(stagger_update_lazyhandle) as _s: + pt_compile_model_time = time.time() + with torch_sendnn.warmup_mode(): + generate(model, input_ids, max_new_tokens=max_new_tokens_warmup, max_seq_len=model.config.max_expected_seq_len, use_cache=True, do_sample=False, contiguous_cache=True, extra_kwargs=extra_kwargs) + pt_compile_model_time = time.time() - pt_compile_model_time + dprint(f"PT compile complete, took {pt_compile_model_time:.3f}s") def ids_for_prompt(prompt, tokenizer): tokens = tokenizer.tokenize(prompt) diff --git a/scripts/inference.py b/scripts/inference.py index 77bd7b7..b5aba8b 100644 --- a/scripts/inference.py +++ b/scripts/inference.py @@ -12,7 +12,7 @@ import math # Third Party -from aiu_fms_testing_utils.utils import aiu_setup, warmup_model, stagger_enter, stagger_leave +from aiu_fms_testing_utils.utils import aiu_setup, warmup_model, stagger_region from aiu_fms_testing_utils.utils.aiu_setup import dprint, rank, local_rank, world_size import numpy as np import torch @@ -464,38 +464,35 @@ def select_int8_module( dprint(f"data_type={default_dtype}") dprint("="*60 + "\n") -stagger_enter(args.stagger_load) - -model = get_model( - args.architecture, - args.variant, - model_path=args.model_path, - device_type="cpu" if is_aiu_backend else args.device_type, - data_type=default_dtype, - source=args.model_source, - distributed_strategy=distr_param, - group=dist.group.WORLD, - linear_config=linear_config, - fused_weights=fused_weights, -) - -if args.quantization in ["gptq", "int8"]: - if rank == 0 and args.verbose > 0: - dprint("PARAMS:\n" + "\n".join(f"{k:60} {str(v.dtype):15} {str(v.device):10} {list(v.size())}" for k,v in model.named_parameters())) - dprint("BUFFERS:\n" + "\n".join(f"{k:60} {str(v.dtype):15} {str(v.device):10} {list(v.size())}" for k,v in model.named_buffers())) +with stagger_region(args.stagger_load) as _s: + model = get_model( + args.architecture, + args.variant, + model_path=args.model_path, + device_type="cpu" if is_aiu_backend else args.device_type, + data_type=default_dtype, + source=args.model_source, + 
+        distributed_strategy=distr_param,
+        group=dist.group.WORLD,
+        linear_config=linear_config,
+        fused_weights=fused_weights,
+    )
+
+    if args.quantization in ["gptq", "int8"]:
+        if rank == 0 and args.verbose > 0:
+            dprint("PARAMS:\n" + "\n".join(f"{k:60} {str(v.dtype):15} {str(v.device):10} {list(v.size())}" for k,v in model.named_parameters()))
+            dprint("BUFFERS:\n" + "\n".join(f"{k:60} {str(v.dtype):15} {str(v.device):10} {list(v.size())}" for k,v in model.named_buffers()))
+            dprint("="*60 + "\n")
+        if args.architecture == "llama":
+            dprint("[NOTE] In Llama models, it's OK for bias and rotary embeddings to be marked as unused keys.")
+        dprint(model)
         dprint("="*60 + "\n")
-    if args.architecture == "llama":
-        dprint("[NOTE] In Llama models, it's OK for bias and rotary embeddings to be marked as unused keys.")
-    dprint(model)
-    dprint("="*60 + "\n")
-
-tokenizer = tokenizers.get_tokenizer(args.tokenizer)
-model.eval()
-torch.set_grad_enabled(False)
-loading_model_time = time.time() - loading_model_time
-dprint(f"loading complete, took {loading_model_time:.3f}s")
-
-stagger_leave(args.stagger_load)
+
+    tokenizer = tokenizers.get_tokenizer(args.tokenizer)
+    model.eval()
+    torch.set_grad_enabled(False)
+    loading_model_time = time.time() - loading_model_time
+    dprint(f"loading complete, took {loading_model_time:.3f}s")

 if args.compile:
     dprint("compiling model")