Refactor inference.py for LLM and RoBERTa support #34

Draft · wants to merge 30 commits into main
Commits (30) · changes shown from 8 commits
6e5c0d2
Refactor argument parsing
andrea-fasoli Apr 23, 2025
5c0d9ec
Refactor model setup
andrea-fasoli Apr 23, 2025
d7730c8
Refactor setup of quantization (addons, linear_config)
andrea-fasoli Apr 23, 2025
819b147
Refactor LLM handling
andrea-fasoli Apr 23, 2025
13c0917
Refactor RoBERTa handling
andrea-fasoli Apr 23, 2025
5895831
Refactor Direct Quantization (wip)
andrea-fasoli Apr 23, 2025
8b1d37e
Refactor inference entry point for LLM and RoBERTa
andrea-fasoli Apr 23, 2025
238b05d
Refactor AIU setup (relocate env vars setup)
andrea-fasoli Apr 23, 2025
38fef7a
Remove deprecated local_size
andrea-fasoli Jun 18, 2025
3b94a36
Remove env vars already set in e2e_stable image
andrea-fasoli Jun 18, 2025
effb27b
Group and update parser arguments
andrea-fasoli Jun 18, 2025
600ba67
Rename enc/dec utils
andrea-fasoli Jun 18, 2025
031abde
Gating some AIU settings
andrea-fasoli Jun 18, 2025
8657a11
Split inference into decoder/encoder scripts (wip)
andrea-fasoli Jun 18, 2025
0d042c6
Fix tokenizer; add some dec/enc args validation
andrea-fasoli Jun 19, 2025
3f13729
Update AIU env var
andrea-fasoli Jun 19, 2025
4e731d8
Minor args update
andrea-fasoli Jun 19, 2025
fd70377
Relocate print_model_params function
andrea-fasoli Jun 19, 2025
7a5c9df
Gate transformers import
andrea-fasoli Jun 19, 2025
3ad7050
Bring recent updates to inference.py into run_decoder.py
andrea-fasoli Jun 19, 2025
92d05ef
Add new sendnn compile arg
andrea-fasoli Jun 19, 2025
011ec33
Remove unified inference.py
andrea-fasoli Jun 19, 2025
5292128
Small fixes
andrea-fasoli Jun 19, 2025
17be9a7
Remove deprecated torch dynamo config option
andrea-fasoli Jun 20, 2025
6780770
Update type hints
andrea-fasoli Jun 20, 2025
0437c50
Update skip compile message
andrea-fasoli Jun 20, 2025
dfd6758
Adjust extra_generation_kwargs handling
andrea-fasoli Jun 20, 2025
f7c458e
Remove INT8 DQ
andrea-fasoli Jun 20, 2025
3434641
Update import of ids_for_prompt and fix some formatting
andrea-fasoli Jun 20, 2025
a543198
Minor changes
andrea-fasoli Jun 20, 2025
61 changes: 61 additions & 0 deletions aiu_fms_testing_utils/utils/aiu_setup.py
@@ -1,4 +1,6 @@
import argparse
import os
import torch

# ==============================================================
# Common utilities
@@ -57,6 +59,8 @@ def aiu_setup(rank=0, world_size=1, local_rank=0, local_size=1, verbose=False):
def aiu_dist_setup(rank, world_size, local_rank=-1, local_size=-1, verbose=False):
if local_rank < 0:
local_rank = rank

# FIXME: local_size not in use?
[Contributor] Maybe it's no longer needed? If you can't find any reference to it, feel free to delete it.

if local_size < 0:
local_size = world_size

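If no other reference to local_size turns up, the cleanup suggested above would simply drop the parameter. A sketch of the resulting signature (an assumption about the final shape, not code from this PR; any caller passing local_size would need updating):

def aiu_dist_setup(rank, world_size, local_rank=-1, verbose=False):
    if local_rank < 0:
        local_rank = rank
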
@@ -67,3 +71,60 @@ def aiu_dist_setup(rank, world_size, local_rank=-1, local_size=-1, verbose=False
dprint("Detected running via torchrun")

aiu_setup(rank, world_size)


# ==============================================================
# Environment variables utilities
# ==============================================================
def set_aiu_env_vars(args: argparse.Namespace) -> None:
"""Set necessary environment variables for AIU"""

_target_cache_size = max(
int(args.max_new_tokens * 2),
int(args.min_pad_length * 2.5),
int(args.fixed_prompt_length * 2.5),
)
_prompt_size = max(int(args.min_pad_length), int(args.fixed_prompt_length))
if hasattr(torch._dynamo.config, "accumulated_cache_size_limit"):
if _target_cache_size > torch._dynamo.config.accumulated_cache_size_limit:
_prev = torch._dynamo.config.accumulated_cache_size_limit
torch._dynamo.config.accumulated_cache_size_limit = _target_cache_size
dprint(
"NOTICE: Adjusting torch._dynamo.config.accumulated_cache_size_limit "
f"from {_prev} to {torch._dynamo.config.accumulated_cache_size_limit} "
f"to accomodate prompt size of {_prompt_size} and decode tokens of "
f"{args.max_new_tokens}"
)

if _target_cache_size > torch._dynamo.config.cache_size_limit:
_prev = torch._dynamo.config.cache_size_limit
torch._dynamo.config.cache_size_limit = _target_cache_size
dprint(
f"NOTICE: Adjusting torch._dynamo.config.cache_size_limit from {_prev} to "
f"{torch._dynamo.config.cache_size_limit} to accomodate prompt size of "
f"{_prompt_size} and decode tokens of {args.max_new_tokens}"
)
[Contributor] This is only needed if compile_dynamic is disabled; can we gate it?

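A minimal sketch of that gating, reusing names from this diff (args.compile_dynamic, _target_cache_size) and eliding the dprint notices; a suggestion, not code from the PR:

if not args.compile_dynamic:
    # Cache-size bumps only matter when shapes are compiled statically.
    if hasattr(torch._dynamo.config, "accumulated_cache_size_limit"):
        if _target_cache_size > torch._dynamo.config.accumulated_cache_size_limit:
            torch._dynamo.config.accumulated_cache_size_limit = _target_cache_size
    if _target_cache_size > torch._dynamo.config.cache_size_limit:
        torch._dynamo.config.cache_size_limit = _target_cache_size
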

if not args.compile_dynamic:
torch._dynamo.config.assume_static_by_default = True
torch._dynamo.config.dynamic_shapes = False
[Contributor] This is now deprecated in PyTorch; only assume_static_by_default and automatic_dynamic_shapes are needed.

torch._dynamo.config.automatic_dynamic_shapes = False

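Folding in the deprecation note above, the static-shape block would reduce to this sketch (same args.compile_dynamic flag assumed):

if not args.compile_dynamic:
    torch._dynamo.config.assume_static_by_default = True
    # dynamic_shapes is deprecated in recent PyTorch, so it is omitted here.
    torch._dynamo.config.automatic_dynamic_shapes = False
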
# This should be set outside!!!
os.environ.setdefault("SENCORES", "32")
os.environ.setdefault("SENCORELETS", "2")
os.environ.setdefault("DATA_PREC", "fp16")
os.environ.setdefault("FLEX_OVERWRITE_NMB_FRAME", "1")
[Contributor] I think some of these are already set by default on the e2e_stable image; can we check and remove the ones we don't need anymore?

[Contributor] Confirmed: these are no longer needed.

[Collaborator, author] Is os.environ.setdefault("DTCOMPILER_KEEP_EXPORT", "true") still needed or not? It's the env var that was set after these ones.

os.environ.setdefault("DTCOMPILER_KEEP_EXPORT", "true")

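One quick way to check which of these variables the e2e_stable image already sets is to print them before any setdefault call runs; a throwaway diagnostic, not part of the PR:

import os

for var in (
    "SENCORES",
    "SENCORELETS",
    "DATA_PREC",
    "FLEX_OVERWRITE_NMB_FRAME",
    "DTCOMPILER_KEEP_EXPORT",
):
    print(f"{var}={os.environ.get(var, '<unset>')}")
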
os.environ.setdefault("COMPILATION_MODE", "offline_decoder")
[Contributor] This one is only needed for decoder models; for RoBERTa it will probably make it not work.

[Contributor] Confirmed: it will not work for RoBERTa; we need to set it depending on the kind of model.

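A sketch of gating the variable on the model kind, as the thread requests; is_decoder is a hypothetical flag, and the PR ultimately sidesteps this by splitting inference into decoder and encoder scripts (commit 8657a11):

if is_decoder:  # hypothetical flag distinguishing decoder LLMs from encoders
    os.environ.setdefault("COMPILATION_MODE", "offline_decoder")
# Encoder models such as RoBERTa must not get "offline_decoder"; the correct
# value (or leaving the variable unset) is left open in this thread.
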

if args.device_type == "aiu-senulator":
os.environ["FLEX_COMPUTE"] = "SENULATOR"
os.environ["FLEX_DEVICE"] = "MOCK"
else:
if "AIU_WORLD_RANK_0" not in os.environ:
print("must set AIU_WORLD_RANK_0")
exit()
os.environ.setdefault("FLEX_COMPUTE", "SENTIENT")
os.environ.setdefault("FLEX_DEVICE", "VFIO")
[Contributor] I think VFIO is now PF or VF, depending on the AIU setup; we can set PF by default, as most cards are running in PF mode.

[Contributor] Confirmed: it is now PF for all clusters we have access to, and will eventually be VF.

[Collaborator, author] I set it to PF, but do we need an argument for this? One using choices=["VF", "PF"], default="PF".

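A sketch of the argument the author floats, with the proposed choices and default; the flag name --flex_device is an assumption, not settled in this thread:

import argparse
import os

parser = argparse.ArgumentParser()
parser.add_argument(
    "--flex_device",  # hypothetical flag name
    choices=["VF", "PF"],
    default="PF",
    help="FLEX_DEVICE mode for AIU cards (PF on current clusters, eventually VF)",
)
args = parser.parse_args()
os.environ.setdefault("FLEX_DEVICE", args.flex_device)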