
Add RoBERTa FP8 support with refactoring #72


Merged
merged 23 commits on Jul 3, 2025
Changes from all commits
4d57dd0
initial encoder refactoring (wip)
andrea-fasoli Jun 27, 2025
6de2a31
fp8 encoder support
andrea-fasoli Jun 30, 2025
0c6a36b
Update detection of RoBERTa architecture
andrea-fasoli Jun 30, 2025
abc01e5
Remove commented decoder args
andrea-fasoli Jul 1, 2025
0c84b8b
Make verbose a flag
andrea-fasoli Jul 1, 2025
66780e3
Remove TODO in FP8 quantization argument
andrea-fasoli Jul 1, 2025
bde2c60
Remove decoder arguments from argument validation
andrea-fasoli Jul 1, 2025
e8c2cd5
Update argument help
andrea-fasoli Jul 1, 2025
e66cd1c
Update padding explanation
andrea-fasoli Jul 1, 2025
e5555c8
Update linear config message for FP8
andrea-fasoli Jul 1, 2025
af3d681
raise error for default_dtype + quantization
andrea-fasoli Jul 1, 2025
d8b73ee
Update printouts
andrea-fasoli Jul 1, 2025
55f19d8
Update determinism docstring
andrea-fasoli Jul 1, 2025
2d8b88d
Fix typos
andrea-fasoli Jul 1, 2025
429c57f
Update rank-based printouts
andrea-fasoli Jul 1, 2025
041f39a
Remove superseded roberta.py script
andrea-fasoli Jul 1, 2025
0dfa472
Gate post processing to rank 0 only
andrea-fasoli Jul 1, 2025
bca7a39
Rename encoder inference entry point script
andrea-fasoli Jul 2, 2025
fb3f224
merge from upstream/main
andrea-fasoli Jul 2, 2025
e989a23
Move batch to correct device at eval
andrea-fasoli Jul 2, 2025
0cd0d7b
Add 16b forced casting
andrea-fasoli Jul 2, 2025
18f4230
Reinstate local_size in aiu_setup (for future use)
andrea-fasoli Jul 3, 2025
2fc4c12
Add notes about 384 default max_prompt_length
andrea-fasoli Jul 3, 2025
89 changes: 60 additions & 29 deletions aiu_fms_testing_utils/utils/__init__.py
@@ -1,15 +1,26 @@
import torch
import torch.nn as nn
import time
from fms.utils.tokenizers import BaseTokenizer
from aiu_fms_testing_utils.utils.aiu_setup import dprint
# Standard
from typing import Optional, List, Tuple
import os
import requests
import json
import os
import random
import requests
import time

# Third Party
from aiu_fms_testing_utils.utils.aiu_setup import dprint
from fms.utils.tokenizers import BaseTokenizer
import torch
import torch.nn as nn

def warmup_model(model: nn.Module, input_ids: torch.Tensor, max_new_tokens: int, compile_dynamic_sendnn = False, use_cache: bool = True, **extra_kwargs):

def warmup_model(
model: nn.Module,
input_ids: torch.Tensor,
max_new_tokens: int,
compile_dynamic_sendnn: bool = False,
use_cache: bool = True,
**extra_kwargs
):
import torch_sendnn
attention_specific_kwargs = {}
attn_name = extra_kwargs["attn_name"]
@@ -19,7 +30,7 @@ def warmup_model(model: nn.Module, input_ids: torch.Tensor, max_new_tokens: int,
# TODO: Add a unified generation dependent on attn_type
from fms.utils.generation import generate
attention_specific_kwargs["contiguous_cache"] = True

dprint("AIU warmup")
pt_compile_model_time = time.time()

@@ -31,12 +42,23 @@ def warmup_model(model: nn.Module, input_ids: torch.Tensor, max_new_tokens: int,
_max_new_tokens = 2
# always warmup with batch size 2 when using attn_type=paged
if "paged" in attn_name:
_warmup_input_ids, _extra_kwargs = adjust_inputs_to_batch(input_ids, **extra_kwargs)
_warmup_input_ids, _extra_kwargs = adjust_inputs_to_batch(
input_ids,
**extra_kwargs,
)

extra_kwargs = {**_extra_kwargs, "only_last_token": "paged" not in attn_name}

with torch_sendnn.warmup_mode():
generate(model, _warmup_input_ids, max_new_tokens=_max_new_tokens, do_sample=False, use_cache=use_cache, extra_kwargs=extra_kwargs, **attention_specific_kwargs)
generate(
model,
_warmup_input_ids,
max_new_tokens=_max_new_tokens,
do_sample=False,
use_cache=use_cache,
extra_kwargs=extra_kwargs,
**attention_specific_kwargs,
)
pt_compile_model_time = time.time() - pt_compile_model_time
dprint(f"PT compile complete, took {pt_compile_model_time:.3f}s")

@@ -52,17 +74,17 @@ def __download_file(url, filename):
try:
response = requests.get(url, stream=True)
response.raise_for_status()

with open(filename, 'wb') as file:
for chunk in response.iter_content(chunk_size=8192):
file.write(chunk)
print(f"Successfully downloaded {filename}")

except requests.exceptions.RequestException as e:
print(f"An error occurred: {e}")

def __sample_requests(
prompt_list: List[str],
prompt_list: List[str],
num_requests: int,
tokenizer: BaseTokenizer,
prompt_length_min: int = 32,
@@ -82,16 +104,14 @@ def __sample_requests(
# Tokenize the prompts and completions.
prompt = prompt_list[i]
prompt_token_ids = ids_for_prompt(prompt, tokenizer)

prompt_len = len(prompt_token_ids)
if prompt_len < prompt_length_min or prompt_len > prompt_length_max:
# Prune too short or too long sequences.
continue
filtered_dataset.append((prompt, prompt_len))

return filtered_dataset


return filtered_dataset

def sample_sharegpt_requests(
dataset_path: str,
@@ -111,15 +131,22 @@ def sample_sharegpt_requests(
# Filter out the conversations with less than 2 turns.
dataset = [data for data in dataset if len(data["conversations"]) >= 2]
dataset = [data["conversations"][0]["value"] for data in dataset]

return __sample_requests(dataset, num_requests, tokenizer, prompt_length_min, prompt_length_max, seed)

return __sample_requests(
dataset,
num_requests,
tokenizer,
prompt_length_min,
prompt_length_max,
seed,
)

def sample_squad_v2_qa_requests(
dataset_path: str,
num_requests: int,
tokenizer: BaseTokenizer,
prompt_length_min: int = 32,
prompt_length_max: int = 64,
num_requests: int,
tokenizer: BaseTokenizer,
prompt_length_min: int = 32,
prompt_length_max: int = 64,
seed: Optional[int] = None
) -> List[Tuple[str, int]]:
from datasets import load_dataset
@@ -128,10 +155,14 @@ def sample_squad_v2_qa_requests(
ds = load_dataset(dataset_path)['train']
else:
ds = load_dataset("rajpurkar/squad_v2", cache_dir=dataset_path)['train']


ds = [f"{data['context']}\n{data['question']}" for data in ds]

return __sample_requests(ds, num_requests, tokenizer, prompt_length_min, prompt_length_max, seed)

ds = [f"{data['context']}\n{data['question']}" for data in ds]

return __sample_requests(
ds,
num_requests,
tokenizer,
prompt_length_min,
prompt_length_max,
seed,
)
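
The refactored sampling helpers keep their original behavior; only the call formatting changes. A minimal usage sketch of sample_squad_v2_qa_requests follows (the dataset path and tokenizer name are placeholders, and the get_tokenizer helper is assumed from fms.utils.tokenizers as used by the existing entry-point scripts):

# Sketch only: dataset path and tokenizer name are placeholders, not values from this PR.
from fms.utils import tokenizers
from aiu_fms_testing_utils.utils import sample_squad_v2_qa_requests

tokenizer = tokenizers.get_tokenizer("roberta-base")  # assumed helper; any BaseTokenizer works
sampled = sample_squad_v2_qa_requests(
    dataset_path="/tmp/squad_v2_cache",  # hypothetical cache directory
    num_requests=8,
    tokenizer=tokenizer,
    prompt_length_min=32,
    prompt_length_max=64,
    seed=0,
)
for prompt, prompt_len in sampled:
    print(prompt_len, prompt[:60])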
54 changes: 54 additions & 0 deletions aiu_fms_testing_utils/utils/aiu_setup.py
@@ -1,4 +1,6 @@
import argparse
import os
import torch

# ==============================================================
# Common utilities
@@ -67,3 +69,55 @@ def aiu_dist_setup(rank, world_size, local_rank=-0, local_size=-1, verbose=False
dprint(f"Detected running via torchrun")

aiu_setup(rank, world_size)


# ==============================================================
# Environment variables utilities
# ==============================================================
def set_aiu_env_vars(args: argparse.Namespace) -> None:
"""Set necessary environment variables for AIU"""

if not args.compile_dynamic:
_target_cache_size = max(
int(args.max_new_tokens * 2),
int(args.min_pad_length * 2.5),
int(args.fixed_prompt_length * 2.5),
)
_prompt_size = max(int(args.min_pad_length), int(args.fixed_prompt_length))
if hasattr(torch._dynamo.config, "accumulated_cache_size_limit"):
if _target_cache_size > torch._dynamo.config.accumulated_cache_size_limit:
_prev = torch._dynamo.config.accumulated_cache_size_limit
torch._dynamo.config.accumulated_cache_size_limit = _target_cache_size
dprint(
"NOTICE: Adjusting torch._dynamo.config.accumulated_cache_size_limit "
f"from {_prev} to {torch._dynamo.config.accumulated_cache_size_limit} "
f"to accomodate prompt size of {_prompt_size} and decode tokens of "
f"{args.max_new_tokens}"
)

if _target_cache_size > torch._dynamo.config.cache_size_limit:
_prev = torch._dynamo.config.cache_size_limit
torch._dynamo.config.cache_size_limit = _target_cache_size
dprint(
f"NOTICE: Adjusting torch._dynamo.config.cache_size_limit from {_prev} to "
f"{torch._dynamo.config.cache_size_limit} to accomodate prompt size of "
f"{_prompt_size} and decode tokens of {args.max_new_tokens}"
)

torch._dynamo.config.assume_static_by_default = True
torch._dynamo.config.automatic_dynamic_shapes = False

# os.environ.setdefault("DTCOMPILER_KEEP_EXPORT", "true")  # CONFIRM IF THIS IS NEEDED

if not args.is_encoder:
os.environ.setdefault("COMPILATION_MODE", "offline_decoder")

if args.device_type == "aiu-senulator":
os.environ["FLEX_COMPUTE"] = "SENULATOR"
os.environ["FLEX_DEVICE"] = "MOCK"
else:
if "AIU_WORLD_RANK_0" not in os.environ:
print("must set AIU_WORLD_RANK_0")
exit()
os.environ.setdefault("FLEX_COMPUTE", "SENTIENT")
os.environ.setdefault("FLEX_DEVICE", "PF") # will use VF eventually