
Refactor inference.py for LLM and RoBERTa support #34


Draft
wants to merge 30 commits into main

Conversation

andrea-fasoli
Collaborator

@andrea-fasoli andrea-fasoli commented Apr 23, 2025

This PR implements a substantial refactoring of inference.py, which becomes the single entry point for LLMs and RoBERTa models. Support covers non-quantized, GPTQ W4A16, and INT8 models.

The inference.py code has been streamlined. It is now structured into the following sections:

args_parsing            define script arguments across all model configurations
aiu_setup               set up AIU environment variables
model_setup             define model dtype, device, and distributed strategy
quantization_setup      import FMS-MO addons and define linear_config for FMS
direct_quantization     quantize a non-quantized model to INT8 (WIP)
decoders                run token generation task with LLMs
encoders                run QA or MLM task with RoBERTa

Extensive code validation is needed prior to merging.

Signed-off-by: Andrea Fasoli <andrea.fasoli@ibm.com>
@andrea-fasoli
Collaborator Author

cc for review: @ani300 @JRosenkranz

@@ -57,6 +59,8 @@ def aiu_setup(rank=0, world_size=1, local_rank=0, local_size=1, verbose=False):
def aiu_dist_setup(rank, world_size, local_rank=-0, local_size=-1, verbose=False):
if local_rank < 0:
local_rank = rank

# FIXME: local_size not in use ?
Contributor

maybe it's no longer needed? if you can't find any reference to it feel free to delete

Comment on lines 82 to 106
_target_cache_size = max(
int(args.max_new_tokens * 2),
int(args.min_pad_length * 2.5),
int(args.fixed_prompt_length * 2.5),
)
_prompt_size = max(int(args.min_pad_length), int(args.fixed_prompt_length))
if hasattr(torch._dynamo.config, "accumulated_cache_size_limit"):
if _target_cache_size > torch._dynamo.config.accumulated_cache_size_limit:
_prev = torch._dynamo.config.accumulated_cache_size_limit
torch._dynamo.config.accumulated_cache_size_limit = _target_cache_size
dprint(
"NOTICE: Adjusting torch._dynamo.config.accumulated_cache_size_limit "
f"from {_prev} to {torch._dynamo.config.accumulated_cache_size_limit} "
f"to accomodate prompt size of {_prompt_size} and decode tokens of "
f"{args.max_new_tokens}"
)

if _target_cache_size > torch._dynamo.config.cache_size_limit:
_prev = torch._dynamo.config.cache_size_limit
torch._dynamo.config.cache_size_limit = _target_cache_size
dprint(
f"NOTICE: Adjusting torch._dynamo.config.cache_size_limit from {_prev} to "
f"{torch._dynamo.config.cache_size_limit} to accomodate prompt size of "
f"{_prompt_size} and decode tokens of {args.max_new_tokens}"
)
Contributor

this is only needed if compile_dynamic is disabled, can we gate it?
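
A minimal sketch of how that gating might look, assuming the same `args` fields used in the snippet above (`compile_dynamic`, `max_new_tokens`, `min_pad_length`, `fixed_prompt_length`):

```python
import torch

# Sketch only: skip the dynamo cache-limit bump entirely when dynamic compilation is on,
# since static shapes are what cause one graph per sequence length to accumulate.
def maybe_raise_dynamo_cache_limits(args) -> None:
    if args.compile_dynamic:
        return
    target = max(
        int(args.max_new_tokens * 2),
        int(args.min_pad_length * 2.5),
        int(args.fixed_prompt_length * 2.5),
    )
    if target > torch._dynamo.config.cache_size_limit:
        torch._dynamo.config.cache_size_limit = target
    if hasattr(torch._dynamo.config, "accumulated_cache_size_limit"):
        if target > torch._dynamo.config.accumulated_cache_size_limit:
            torch._dynamo.config.accumulated_cache_size_limit = target
```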

Comment on lines 114 to 117
os.environ.setdefault("SENCORES", "32")
os.environ.setdefault("SENCORELETS", "2")
os.environ.setdefault("DATA_PREC", "fp16")
os.environ.setdefault("FLEX_OVERWRITE_NMB_FRAME", "1")
Contributor

I think some of these are already set by default on the e2e_stable image, can we check and remove the ones we don't need anymore?

Contributor

confirmed these are no longer needed

Collaborator Author

is os.environ.setdefault("DTCOMPILER_KEEP_EXPORT", "true") still needed or not? It's the env var that was set after these ones

os.environ.setdefault("FLEX_OVERWRITE_NMB_FRAME", "1")
os.environ.setdefault("DTCOMPILER_KEEP_EXPORT", "true")

os.environ.setdefault("COMPILATION_MODE", "offline_decoder")
Contributor

this one is only needed for decoder models, for roberta it will probably make it not work

Contributor

confirmed it will not work for roberta, we need to set it depending on the kind of model
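
A sketch of what that gating could look like (the decoder value comes from the snippet above; whether encoders need a different value, or none at all, is left open):

```python
import os

def set_compilation_mode(is_encoder: bool) -> None:
    # Decoders keep the existing value; for encoders (RoBERTa) we simply don't set it here.
    if not is_encoder:
        os.environ.setdefault("COMPILATION_MODE", "offline_decoder")
```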

print("must set AIU_WORLD_RANK_0")
exit()
os.environ.setdefault("FLEX_COMPUTE", "SENTIENT")
os.environ.setdefault("FLEX_DEVICE", "VFIO")
Contributor

I think VFIO is now PF or VF, depending on the AIU setup; we can set PF by default, as most cards are running in PF mode

Contributor

confirmed it is now PF for all clusters we have access to and will eventually be VF

Collaborator Author

I set it to PF but do we need an argument for this?
One using: choices=["VF", "PF"], default="PF"
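
Something along these lines, perhaps (the flag name is only a suggestion; the choices and default are the ones mentioned above):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--flex_device",
    type=str,
    choices=["PF", "VF"],
    default="PF",
    help="Value to export as FLEX_DEVICE for the AIU setup",
)
```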

action="store_true",
help=(
"If set to True, this will unfuse any fused weight modules that "
"support the unfuse_weights method"
Contributor

this comment can be upgraded to "If set to True, this will unfuse any fused weights in the model" as the way it's done doesn't involve "unfuse_weights" anymore

"--seed",
type=int,
default=81072,
help="Run seed (only needed if eval dataset is shuffled)",
Contributor

run seed can also be relevant for randomly initialized models, which we sometimes use

parser.add_argument(
"--deterministic",
action="store_true",
help="`deterministic` requires env variable `CUBLAS_WORKSPACE_CONFIG=:4096:8`",
Contributor

this only applies to cpu/cuda; on aiu it doesn't change anything
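
A possible way to gate it, as a sketch; `device_type` is a placeholder for however the script ends up tracking the target device:

```python
import torch

def maybe_enable_determinism(deterministic: bool, device_type: str) -> None:
    # Deterministic algorithms only matter on cpu/cuda; on AIU this is intentionally a no-op.
    if deterministic and device_type in ("cpu", "cuda"):
        torch.use_deterministic_algorithms(True)
```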

'-v', '--verbose',
action='count',
default=0,
help="Set verbosity level (pass flag as `-v`, `-vv`, `-vvv`)"
Contributor

where is this used?

Collaborator Author

mostly in the quantization functions of int8 roberta for now (printing out model parameters if so desired - it's crucial for debugging), but I suppose it could be a useful flag for decoders too. I am not using the count functionality at this time; it could be just a True/False flag.

Comment on lines 294 to 311
parser.add_argument(
"--max_seq_length",
type=int,
default=384,
help=(
"The maximum total input sequence length after tokenization. "
"Sequences longer than this will be truncated, "
"sequences shorter will be padded if `--pad_to_max_length` is passed."
),
)
parser.add_argument(
"--pad_to_max_length",
action="store_true",
help=(
"If passed, pad all samples to `max_seq_length`. "
"Otherwise, dynamic padding is used."
),
)
Contributor

Isn't this repeated, or at least very similar to the decoder arguments? It might be worth using argument groups (https://docs.python.org/3/library/argparse.html#argument-groups) and making the RoBERTa- and decoder-specific arguments mutually exclusive based on what you pass to some other argument.
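
For illustration, a small sketch of the argument-group idea (flag names follow the ones already in the script; defaults are illustrative). Note that argument groups only organize the --help output, so mutual exclusivity would still need explicit validation:

```python
import argparse

parser = argparse.ArgumentParser()

args_decoder = parser.add_argument_group("decoder", "LLM generation arguments")
args_decoder.add_argument("--max_new_tokens", type=int, default=100)
args_decoder.add_argument("--min_pad_length", type=int, default=0)

args_encoder = parser.add_argument_group("encoder", "RoBERTa QA/MLM arguments")
args_encoder.add_argument("--max_seq_length", type=int, default=384)
args_encoder.add_argument("--pad_to_max_length", action="store_true")
```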

Collaborator Author

there is some overlap indeed.
I combined max_seq_length (enc) with max_prompt_len (dec), as they share the same meaning (although the default values differed).
pad_to_max_length (enc) conceptually overlaps with min_pad_length and fixed_prompt_length (dec), but the first is a boolean while the others are int. I couldn't find a clean way to implement a 3-way exclusivity but I'll add some argument validation at the time of loading encoder vs. decoder.

args = parser.parse_args()

# Add convenient arguments to parser
args.is_encoder = "bert" in args.architecture.lower() # TODO: improve this check
Contributor

maybe add a real "is_encoder" argument; that would also help with the mutually exclusive argument groups
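
For example, something like this instead of the substring check (sketch only):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--is_encoder",
    action="store_true",
    help="Treat the architecture as an encoder (e.g., RoBERTa) rather than a decoder LLM",
)
```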

import os

# Third Party
from transformers import PreTrainedModel
Contributor

make this import gated and fail only if args.is_quantized is True in inference.py
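
A sketch of how the gated import might look (the check function name and error message are placeholders):

```python
# Defer the hard failure until quantization is actually requested.
try:
    from transformers import PreTrainedModel  # noqa: F401
except ImportError:
    PreTrainedModel = None

def check_quantization_requirements(args) -> None:
    if args.is_quantized and PreTrainedModel is None:
        raise ImportError("transformers must be installed to run quantized models")
```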


if not args.compile_dynamic:
torch._dynamo.config.assume_static_by_default = True
torch._dynamo.config.dynamic_shapes = False
Contributor

this is now deprecated in pytorch, only assume and automatic are needed
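
If "assume and automatic" refers to the two current torch._dynamo.config knobs, the replacement would look roughly like this (a sketch, not verified against the PR):

```python
import torch

def force_static_shapes() -> None:
    # torch._dynamo.config.dynamic_shapes is deprecated; these two settings cover it.
    torch._dynamo.config.assume_static_by_default = True
    torch._dynamo.config.automatic_dynamic_shapes = False
```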

parser.add_argument(
"--quantization",
type=str,
choices=["gptq", "int8"],
Contributor

please add a TODO to add FP8 inference once that lands too

help="Enable smoothquant in INT8 quantized model",
)
parser.add_argument( # NOTE: roberta only so far but should expand to LLM
"--direct_quantization",
Contributor

add the int8 prefix

help="Train INT8 model with Direct Quantization",
)
parser.add_argument(
"--num_dq_samples",
Contributor

add the int8 prefix

"--compile_dynamic",
action="store_true",
help="Use dynamic shapes with torch.compile",
)
Contributor

there is a new --compile_dynamic_sendnn that needs to be added here

Contributor

Now that we are cleaning up the args, I wonder if it would make sense to combine compile_dynamic and compile_dynamic_sendnn. What would happen if a user does compile_dynamic_sendnn and not compile_dynamic? We may want to make this something like --compile_dynamic=<static_inputs, symbolic_inputs>
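
A sketch of the combined flag floated here; the choice names come from the comment above and the default is only a guess:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--compile_dynamic",
    type=str,
    choices=["static_inputs", "symbolic_inputs"],
    default="static_inputs",
    help="Shape handling for torch.compile (static_inputs or symbolic_inputs)",
)
```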

),
)

# RoBERTa-specific evaluation arguments
Contributor

should we add a prefix to mark these as encoder-specific?

Comment on lines 25 to 26
model: PreTrainedModel,
tokenizer: PreTrainedTokenizerBase,
Contributor

These are the wrong types.

Contributor

@ani300 ani300 left a comment

There are plenty of things to change, and I'd like to think more on the general architecture... Given how different the parameters and general flow are for encoder tasks and decoders, it might be worth splitting inference.py into encoders.py and decoders.py. Most of the code can be reused anyway if the arguments are grouped using the argparse API for this and each script then just picks the relevant groups.

Of course, documentation (README.md) also needs to be updated for this new structure.

Signed-off-by: Andrea Fasoli <andrea.fasoli@ibm.com>
help="A csv or a json file containing the validation data.",
)
args_encoder.add_argument(
"--pad_to_max_length",
Contributor

should we use min_pad_length here? If the min_pad_length is not specified, then it will implicitly pad to max length. If min_pad_length is specified and is larger than the largest sequence, this will add extra pads (if we want to simulate a different sequence length)

Collaborator Author

I put some thought into this and it is surprisingly tricky, because for encoders the tokenization is performed under the hood of a transformers PreTrainedTokenizer, which does not handle truncation and padding the same way as FMS (i.e., with explicit calls to the FMS truncate_prompts_to_max_length and pad_input_ids).

PreTrainedTokenizer receives a max_length argument for truncation AND padding, and a padding argument which can be True (pad to the longest sequence in the batch), False (do not pad), or the string "max_length" (which will pad to the max_length argument).

To make it behave like our decoder tokenization, we would need to adjust max_length and padding based on the tokenized sequence length... which is not known yet at the time of the PreTrainedTokenizer call.

We could eventually change the whole tokenization and feature preparation process (which right now is mostly based on a pytorch example and would be nice to rework), but for the time being I would keep the pad_to_max_length argument for encoders only.
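
For reference, a sketch of the transformers call being described; the feature columns ("question"/"context") are placeholders from a typical QA preprocessing setup:

```python
def tokenize_eval_features(tokenizer, examples, args):
    # padding=True pads to the longest sequence in the batch, "max_length" pads to
    # max_length, and False disables padding; truncation caps sequences at max_length.
    return tokenizer(
        examples["question"],
        examples["context"],
        truncation=True,
        max_length=args.max_seq_length,
        padding="max_length" if args.pad_to_max_length else False,
    )
```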


tokens = self.tokenizer.tokenize(prompt)
ids = self.tokenizer.convert_tokens_to_ids(tokens)
if self.add_special_tokens:
Contributor

Is this handling the case where tokenizer.bos_token_id != tokenizer.eos_token_id?

Collaborator Author

I reproduced this from inference.py. The only difference is where self.add_special_tokens is defined. Now it's updated at the beginning of process_eval_set (before any ids_for_prompt call) as:

        self.add_special_tokens = (
            self.tokenizer.bos_token_id != self.tokenizer.eos_token_id
        )

Should be correct.

f"Architecture {args.architecture} should be run as an encoder model."
)

def ids_for_prompt(self, prompt):
Contributor

We may be able to reuse the function ids_for_prompt in utils/__init__.py

Collaborator Author

good catch, it's the same function. It is also duplicated in the current inference.py...

Collaborator Author

On second thought, this function may belong to the DecoderInfer class, as the encoders use a different approach. I could remove the duplicate from utils/__init__.py

Collaborator Author

never mind, ids_for_prompt in utils/__init__.py is also used to generate ids for testing. inference.py and validation.py duplicate this function, though. We should sort this out.

raise ValueError(
"Running encoder model but is_encoder argument is either not set or False."
)
if args.min_pad_length != 0:
Contributor

can we use the same arguments for this as decoder rather than introducing pad_to_max_length?

Collaborator Author

see answer above regarding tokenization in encoders

Signed-off-by: Andrea Fasoli <andrea.fasoli@ibm.com>