
Refactor inference.py for LLM and RoBERTa support #34


Draft

wants to merge 30 commits into base: main
Commits (30)
6e5c0d2
Refactor argument parsing
andrea-fasoli Apr 23, 2025
5c0d9ec
Refactor model setup
andrea-fasoli Apr 23, 2025
d7730c8
Refactor setup of quantization (addons, linear_config)
andrea-fasoli Apr 23, 2025
819b147
Refactor LLM handling
andrea-fasoli Apr 23, 2025
13c0917
Refactor RoBERTa handling
andrea-fasoli Apr 23, 2025
5895831
Refactor Direct Quantization (wip)
andrea-fasoli Apr 23, 2025
8b1d37e
Refactor inference entry point for LLM and RoBERTa
andrea-fasoli Apr 23, 2025
238b05d
Refactor AIU setup (relocate env vars setup)
andrea-fasoli Apr 23, 2025
38fef7a
Remove deprecated local_size
andrea-fasoli Jun 18, 2025
3b94a36
Remove env vars already set in e2e_stable image
andrea-fasoli Jun 18, 2025
effb27b
Group and update parser arguments
andrea-fasoli Jun 18, 2025
600ba67
Rename enc/dec utils
andrea-fasoli Jun 18, 2025
031abde
Gating some AIU settings
andrea-fasoli Jun 18, 2025
8657a11
Split inference into decoder/encoder scripts (wip)
andrea-fasoli Jun 18, 2025
0d042c6
Fix tokenizer; add some dec/enc args validation
andrea-fasoli Jun 19, 2025
3f13729
Update AIU env var
andrea-fasoli Jun 19, 2025
4e731d8
Minor args update
andrea-fasoli Jun 19, 2025
fd70377
Relocate print_model_params function
andrea-fasoli Jun 19, 2025
7a5c9df
Gate transformers import
andrea-fasoli Jun 19, 2025
3ad7050
Bring recent updates to inference.py into run_decoder.py
andrea-fasoli Jun 19, 2025
92d05ef
Add new sendnn compile arg
andrea-fasoli Jun 19, 2025
011ec33
Remove unified inference.py
andrea-fasoli Jun 19, 2025
5292128
Small fixes
andrea-fasoli Jun 19, 2025
17be9a7
Remove deprecated torch dynamo config option
andrea-fasoli Jun 20, 2025
6780770
Update type hints
andrea-fasoli Jun 20, 2025
0437c50
Update skip compile message
andrea-fasoli Jun 20, 2025
dfd6758
Adjust extra_generation_kwargs handling
andrea-fasoli Jun 20, 2025
f7c458e
Remove INT8 DQ
andrea-fasoli Jun 20, 2025
3434641
Update import of ids_for_prompt and fix some formatting
andrea-fasoli Jun 20, 2025
a543198
Minor changes
andrea-fasoli Jun 20, 2025
28 changes: 13 additions & 15 deletions aiu_fms_testing_utils/utils/__init__.py
@@ -37,17 +37,17 @@ def __download_file(url, filename):
try:
response = requests.get(url, stream=True)
response.raise_for_status()

with open(filename, 'wb') as file:
for chunk in response.iter_content(chunk_size=8192):
file.write(chunk)
print(f"Successfully downloaded {filename}")

except requests.exceptions.RequestException as e:
print(f"An error occurred: {e}")

def __sample_requests(
prompt_list: List[str],
prompt_list: List[str],
num_requests: int,
tokenizer: BaseTokenizer,
prompt_length_min: int = 32,
@@ -67,16 +67,14 @@ def __sample_requests(
# Tokenize the prompts and completions.
prompt = prompt_list[i]
prompt_token_ids = ids_for_prompt(prompt, tokenizer)

prompt_len = len(prompt_token_ids)
if prompt_len < prompt_length_min or prompt_len > prompt_length_max:
# Prune too short or too long sequences.
continue
filtered_dataset.append((prompt, prompt_len))

return filtered_dataset


return filtered_dataset

def sample_sharegpt_requests(
dataset_path: str,
@@ -96,15 +94,15 @@ def sample_sharegpt_requests(
# Filter out the conversations with less than 2 turns.
dataset = [data for data in dataset if len(data["conversations"]) >= 2]
dataset = [data["conversations"][0]["value"] for data in dataset]

return __sample_requests(dataset, num_requests, tokenizer, prompt_length_min, prompt_length_max, seed)

def sample_squad_v2_qa_requests(
dataset_path: str,
num_requests: int,
tokenizer: BaseTokenizer,
prompt_length_min: int = 32,
prompt_length_max: int = 64,
num_requests: int,
tokenizer: BaseTokenizer,
prompt_length_min: int = 32,
prompt_length_max: int = 64,
seed: Optional[int] = None
) -> List[Tuple[str, int]]:
from datasets import load_dataset
@@ -113,10 +111,10 @@ def sample_squad_v2_qa_requests(
ds = load_dataset(dataset_path)['train']
else:
ds = load_dataset("rajpurkar/squad_v2", cache_dir=dataset_path)['train']


ds = [f"{data['context']}\n{data['question']}" for data in ds]

return __sample_requests(ds, num_requests, tokenizer, prompt_length_min, prompt_length_max, seed)
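
A minimal usage sketch for the sampling helpers touched above, assuming the signatures shown in these hunks; the tokenizer factory and all paths are illustrative placeholders, not part of this diff.

# Sketch only (not part of this PR): drive the sampling helpers whose
# signatures appear in the hunks above. Paths and tokenizer setup are
# placeholders; fms.utils.tokenizers.get_tokenizer is assumed available.
from fms.utils import tokenizers
from aiu_fms_testing_utils.utils import (
    sample_sharegpt_requests,
    sample_squad_v2_qa_requests,
)

tokenizer = tokenizers.get_tokenizer("/path/to/tokenizer")  # placeholder

# Both helpers return List[Tuple[str, int]]: (prompt, token_count) pairs whose
# tokenized length lies within [prompt_length_min, prompt_length_max].
sharegpt_prompts = sample_sharegpt_requests(
    dataset_path="/path/to/sharegpt.json",
    num_requests=8,
    tokenizer=tokenizer,
    prompt_length_min=32,
    prompt_length_max=64,
    seed=0,
)
squad_prompts = sample_squad_v2_qa_requests(
    dataset_path="/path/to/squad_cache",
    num_requests=8,
    tokenizer=tokenizer,
)

Both entry points funnel through the shared __sample_requests filter, which prunes prompts that tokenize to fewer than prompt_length_min or more than prompt_length_max tokens.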


60 changes: 56 additions & 4 deletions aiu_fms_testing_utils/utils/aiu_setup.py
@@ -1,4 +1,6 @@
import argparse
import os
import torch

# ==============================================================
# Common utilities
@@ -21,7 +23,7 @@ def dprint(text):
# ==============================================================
# Common setup
# ==============================================================
def aiu_setup(rank=0, world_size=1, local_rank=0, local_size=1, verbose=False):
def aiu_setup(rank=0, world_size=1, local_rank=0, verbose=False):
# -------------
# Envar setup for Sentient backend
# -------------
@@ -54,11 +56,9 @@ def aiu_setup(rank=0, world_size=1, local_rank=0, local_size=1, verbose=False):
# ==============================================================
# Distributed setup
# ==============================================================
def aiu_dist_setup(rank, world_size, local_rank=-0, local_size=-1, verbose=False):
def aiu_dist_setup(rank, world_size, local_rank=-0, verbose=False):
if local_rank < 0:
local_rank = rank
if local_size < 0:
local_size = world_size

if os.getenv("TORCHELASTIC_RUN_ID") is None:
os.environ["MASTER_ADDR"] = "localhost"
@@ -67,3 +67,55 @@ def aiu_dist_setup(rank, world_size, local_rank=-0, local_size=-1, verbose=False):
dprint(f"Detected running via torchrun")

aiu_setup(rank, world_size)


# ==============================================================
# Environment variables utilities
# ==============================================================
def set_aiu_env_vars(args: argparse.Namespace) -> None:
"""Set necessary environment variables for AIU"""

if not args.compile_dynamic:
_target_cache_size = max(
int(args.max_new_tokens * 2),
int(args.min_pad_length * 2.5),
int(args.fixed_prompt_length * 2.5),
)
_prompt_size = max(int(args.min_pad_length), int(args.fixed_prompt_length))
if hasattr(torch._dynamo.config, "accumulated_cache_size_limit"):
if _target_cache_size > torch._dynamo.config.accumulated_cache_size_limit:
_prev = torch._dynamo.config.accumulated_cache_size_limit
torch._dynamo.config.accumulated_cache_size_limit = _target_cache_size
dprint(
"NOTICE: Adjusting torch._dynamo.config.accumulated_cache_size_limit "
f"from {_prev} to {torch._dynamo.config.accumulated_cache_size_limit} "
f"to accomodate prompt size of {_prompt_size} and decode tokens of "
f"{args.max_new_tokens}"
)

if _target_cache_size > torch._dynamo.config.cache_size_limit:
_prev = torch._dynamo.config.cache_size_limit
torch._dynamo.config.cache_size_limit = _target_cache_size
dprint(
f"NOTICE: Adjusting torch._dynamo.config.cache_size_limit from {_prev} to "
f"{torch._dynamo.config.cache_size_limit} to accomodate prompt size of "
f"{_prompt_size} and decode tokens of {args.max_new_tokens}"
)

torch._dynamo.config.assume_static_by_default = True
torch._dynamo.config.automatic_dynamic_shapes = False

# os.environ.setdefault("DTCOMPILER_KEEP_EXPORT", "true") # CONFIRM IF THIS IS NEEDE

if not args.is_encoder:
os.environ.setdefault("COMPILATION_MODE", "offline_decoder")

if args.device_type == "aiu-senulator":
os.environ["FLEX_COMPUTE"] = "SENULATOR"
os.environ["FLEX_DEVICE"] = "MOCK"
else:
if "AIU_WORLD_RANK_0" not in os.environ:
print("must set AIU_WORLD_RANK_0")
exit()
os.environ.setdefault("FLEX_COMPUTE", "SENTIENT")
os.environ.setdefault("FLEX_DEVICE", "PF") # will use VF eventually