Refactor inference.py for LLM and RoBERTa support #34

Draft
andrea-fasoli wants to merge 8 commits into main

Conversation

@andrea-fasoli (Contributor) commented on Apr 23, 2025

This PR implements a substantial refactoring of inference.py, which becomes the single entry point for LLM and RoBERTa models. Support covers non-quantized, GPTQ W4A16, and INT8 models.

The inference.py code has been streamlined and is now structured into the following sections:

args_parsing            define script arguments across all model configurations
aiu_setup               set up AIU environment variables
model_setup             define model dtype, device, and distributed strategy
quantization_setup      import FMS-MO addons and define linear_config for FMS
direct_quantization     quantize a non-quantized model to INT8 (WIP)
decoders                run token generation task with LLMs
encoders                run QA or MLM task with RoBERTa

Extensive code validation is needed prior to merging.
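A minimal sketch of how the refactored entry point could hang together, purely for orientation; every name, signature, and flag below is illustrative and not necessarily the actual API of this PR:

# Illustrative skeleton only: real names, signatures, and flags may differ.
import argparse


def parse_args() -> argparse.Namespace:  # args_parsing
    parser = argparse.ArgumentParser()
    parser.add_argument("--architecture", default="llama")
    parser.add_argument("--is_encoder", action="store_true")
    parser.add_argument("--is_quantized", action="store_true")
    return parser.parse_args()


def aiu_setup(args): ...                    # aiu_setup: AIU environment variables
def model_setup(args): ...                  # model_setup: dtype, device, distributed strategy
def quantization_setup(args): ...           # quantization_setup: FMS-MO addons, linear_config
def direct_quantization(model, args): ...   # direct_quantization: quantize to INT8 (WIP)
def run_decoders(model, args): ...          # decoders: token generation with LLMs
def run_encoders(model, args): ...          # encoders: QA or MLM with RoBERTa


def main():
    args = parse_args()
    aiu_setup(args)
    model = model_setup(args)
    if args.is_quantized:
        quantization_setup(args)
    if args.is_encoder:
        run_encoders(model, args)
    else:
        run_decoders(model, args)


if __name__ == "__main__":
    main()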

Signed-off-by: Andrea Fasoli <andrea.fasoli@ibm.com>
@andrea-fasoli (Contributor, Author)

cc for review: @ani300 @JRosenkranz

@@ -57,6 +59,8 @@ def aiu_setup(rank=0, world_size=1, local_rank=0, local_size=1, verbose=False):
def aiu_dist_setup(rank, world_size, local_rank=-0, local_size=-1, verbose=False):
if local_rank < 0:
local_rank = rank

# FIXME: local_size not in use ?
Contributor

Maybe it's no longer needed? If you can't find any reference to it, feel free to delete it.

Comment on lines +82 to +106
_target_cache_size = max(
int(args.max_new_tokens * 2),
int(args.min_pad_length * 2.5),
int(args.fixed_prompt_length * 2.5),
)
_prompt_size = max(int(args.min_pad_length), int(args.fixed_prompt_length))
if hasattr(torch._dynamo.config, "accumulated_cache_size_limit"):
if _target_cache_size > torch._dynamo.config.accumulated_cache_size_limit:
_prev = torch._dynamo.config.accumulated_cache_size_limit
torch._dynamo.config.accumulated_cache_size_limit = _target_cache_size
dprint(
"NOTICE: Adjusting torch._dynamo.config.accumulated_cache_size_limit "
f"from {_prev} to {torch._dynamo.config.accumulated_cache_size_limit} "
f"to accomodate prompt size of {_prompt_size} and decode tokens of "
f"{args.max_new_tokens}"
)

if _target_cache_size > torch._dynamo.config.cache_size_limit:
_prev = torch._dynamo.config.cache_size_limit
torch._dynamo.config.cache_size_limit = _target_cache_size
dprint(
f"NOTICE: Adjusting torch._dynamo.config.cache_size_limit from {_prev} to "
f"{torch._dynamo.config.cache_size_limit} to accomodate prompt size of "
f"{_prompt_size} and decode tokens of {args.max_new_tokens}"
)
Contributor

This is only needed if compile_dynamic is disabled; can we gate it?
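One possible shape for the gating, assuming a compile_dynamic argument exists (hypothetical name here) and that torch and args are in scope as in the hunk above:

# Sketch: skip the cache-size bump entirely when dynamic compilation is enabled.
if not args.compile_dynamic:
    _target_cache_size = max(
        int(args.max_new_tokens * 2),
        int(args.min_pad_length * 2.5),
        int(args.fixed_prompt_length * 2.5),
    )
    if hasattr(torch._dynamo.config, "accumulated_cache_size_limit"):
        torch._dynamo.config.accumulated_cache_size_limit = max(
            torch._dynamo.config.accumulated_cache_size_limit, _target_cache_size
        )
    torch._dynamo.config.cache_size_limit = max(
        torch._dynamo.config.cache_size_limit, _target_cache_size
    )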

Comment on lines +114 to +117
os.environ.setdefault("SENCORES", "32")
os.environ.setdefault("SENCORELETS", "2")
os.environ.setdefault("DATA_PREC", "fp16")
os.environ.setdefault("FLEX_OVERWRITE_NMB_FRAME", "1")
Contributor

I think some of these are already set by default on the e2e_stable image; can we check and remove the ones we no longer need?

os.environ.setdefault("FLEX_OVERWRITE_NMB_FRAME", "1")
os.environ.setdefault("DTCOMPILER_KEEP_EXPORT", "true")

os.environ.setdefault("COMPILATION_MODE", "offline_decoder")
Contributor

This one is only needed for decoder models; for RoBERTa it will probably prevent it from working.
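One way to gate it, assuming the env setup can tell encoders from decoders at this point (the is_encoder flag is hypothetical):

# Sketch: only decoder models get the offline_decoder compilation mode.
if not is_encoder:
    os.environ.setdefault("COMPILATION_MODE", "offline_decoder")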

print("must set AIU_WORLD_RANK_0")
exit()
os.environ.setdefault("FLEX_COMPUTE", "SENTIENT")
os.environ.setdefault("FLEX_DEVICE", "VFIO")
Contributor

I think VFIO is now PF or VF, depending on the AIU setup; we can set PF by default since most cards are running in PF mode.
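For example, assuming PF is the right default for the current fleet:

# Sketch: default to PF; VF setups would override via the environment.
os.environ.setdefault("FLEX_DEVICE", "PF")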

action="store_true",
help=(
"If set to True, this will unfuse any fused weight modules that "
"support the unfuse_weights method"
Contributor

This help text can be updated to "If set to True, this will unfuse any fused weights in the model", since the implementation no longer goes through unfuse_weights.

"--seed",
type=int,
default=81072,
help="Run seed (only needed if eval dataset is shuffled)",
Contributor

The run seed can also be relevant for randomly initialized models, which we sometimes use.

parser.add_argument(
"--deterministic",
action="store_true",
help="`deterministic` requires env variable `CUBLAS_WORKSPACE_CONFIG=:4096:8`",
Contributor

This only applies to cpu/cuda; it makes no difference on AIU.
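A sketch of keeping the determinism knobs scoped to cpu/cuda; the device_type attribute and the exact calls are assumptions, not the script's current code:

# Sketch: determinism settings only matter on cpu/cuda backends.
if args.deterministic and args.device_type in ("cpu", "cuda"):
    torch.manual_seed(args.seed)
    torch.use_deterministic_algorithms(True)  # cuda also needs CUBLAS_WORKSPACE_CONFIG=:4096:8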

'-v', '--verbose',
action='count',
default=0,
help="Set verbosity level (pass flag as `-v`, `-vv`, `-vvv`)"
Contributor

Where is this used?

Comment on lines +294 to +311
parser.add_argument(
"--max_seq_length",
type=int,
default=384,
help=(
"The maximum total input sequence length after tokenization. "
"Sequences longer than this will be truncated, "
"sequences shorter will be padded if `--pad_to_max_length` is passed."
),
)
parser.add_argument(
"--pad_to_max_length",
action="store_true",
help=(
"If passed, pad all samples to `max_seq_length`. "
"Otherwise, dynamic padding is used."
),
)
Contributor

Isn't this repeated, or at least very similar to the decoder arguments? It might be worth using argument groups (https://docs.python.org/3/library/argparse.html#argument-groups) and making the RoBERTa- and decoder-specific arguments mutually exclusive based on what you pass to some other argument.
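A rough sketch of the argument-group idea (all names illustrative):

# Sketch: group encoder- and decoder-specific arguments for clearer --help output.
decoder_args = parser.add_argument_group("decoder (LLM) arguments")
decoder_args.add_argument("--max_new_tokens", type=int, default=100)
decoder_args.add_argument("--min_pad_length", type=int, default=0)

encoder_args = parser.add_argument_group("encoder (RoBERTa) arguments")
encoder_args.add_argument("--max_seq_length", type=int, default=384)
encoder_args.add_argument("--pad_to_max_length", action="store_true")

Note that argument groups only affect the help output; actual mutual exclusivity would still need a manual check after parse_args(), e.g. keyed off an explicit flag like the one suggested in the next comment.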

args = parser.parse_args()

# Add convenient arguments to parser
args.is_encoder = "bert" in args.architecture.lower() # TODO: improve this check
Contributor

Maybe add a real "is_encoder" argument; that would also help with the mutually exclusive argument groups.
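For instance (flag name and validation are just one hypothetical option):

# Sketch: an explicit flag instead of inferring encoder-ness from the architecture string.
parser.add_argument(
    "--is_encoder",
    action="store_true",
    help="Run the encoder (RoBERTa) path instead of the decoder (LLM) path",
)
args = parser.parse_args()
# Decoder-only options can then be rejected for encoder runs:
if args.is_encoder and args.max_new_tokens != parser.get_default("max_new_tokens"):
    parser.error("--max_new_tokens only applies to decoder models")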

import os

# Third Party
from transformers import PreTrainedModel
Contributor

Make this import gated so it only fails when args.is_quantized is True in inference.py.
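A sketch of such a gated import (helper name and error message are illustrative):

# Sketch: make transformers optional and fail only when the quantized path needs it.
try:
    from transformers import PreTrainedModel  # Third Party (optional)
except ImportError:
    PreTrainedModel = None


def require_transformers() -> None:
    # Called from inference.py before loading GPTQ/INT8 models (args.is_quantized is True).
    if PreTrainedModel is None:
        raise ImportError(
            "transformers is required for quantized (GPTQ W4A16 / INT8) models; "
            "install it or run a non-quantized model"
        )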
