Commit 8d703be

Merge pull request #72 from andrea-fasoli/fp8_qa
Add RoBERTa FP8 support with refactoring
2 parents ae77c06 + 2fc4c12 commit 8d703be

8 files changed: 1560 additions, 189 deletions


aiu_fms_testing_utils/utils/__init__.py

Lines changed: 60 additions & 29 deletions
@@ -1,15 +1,26 @@
-import torch
-import torch.nn as nn
-import time
-from fms.utils.tokenizers import BaseTokenizer
-from aiu_fms_testing_utils.utils.aiu_setup import dprint
+# Standard
 from typing import Optional, List, Tuple
-import os
-import requests
 import json
+import os
 import random
+import requests
+import time
+
+# Third Party
+from aiu_fms_testing_utils.utils.aiu_setup import dprint
+from fms.utils.tokenizers import BaseTokenizer
+import torch
+import torch.nn as nn
 
-def warmup_model(model: nn.Module, input_ids: torch.Tensor, max_new_tokens: int, compile_dynamic_sendnn = False, use_cache: bool = True, **extra_kwargs):
+
+def warmup_model(
+    model: nn.Module,
+    input_ids: torch.Tensor,
+    max_new_tokens: int,
+    compile_dynamic_sendnn: bool = False,
+    use_cache: bool = True,
+    **extra_kwargs
+):
     import torch_sendnn
     attention_specific_kwargs = {}
     attn_name = extra_kwargs["attn_name"]
@@ -19,7 +30,7 @@ def warmup_model(model: nn.Module, input_ids: torch.Tensor, max_new_tokens: int,
         # TODO: Add a unified generation dependent on attn_type
         from fms.utils.generation import generate
         attention_specific_kwargs["contiguous_cache"] = True
-
+
     dprint("AIU warmup")
     pt_compile_model_time = time.time()
 
@@ -31,12 +42,23 @@ def warmup_model(model: nn.Module, input_ids: torch.Tensor, max_new_tokens: int,
         _max_new_tokens = 2
     # always warmup with batch size 2 when using attn_type=paged
     if "paged" in attn_name:
-        _warmup_input_ids, _extra_kwargs = adjust_inputs_to_batch(input_ids, **extra_kwargs)
+        _warmup_input_ids, _extra_kwargs = adjust_inputs_to_batch(
+            input_ids,
+            **extra_kwargs,
+        )
 
     extra_kwargs = {**_extra_kwargs, "only_last_token": "paged" not in attn_name}
 
     with torch_sendnn.warmup_mode():
-        generate(model, _warmup_input_ids, max_new_tokens=_max_new_tokens, do_sample=False, use_cache=use_cache, extra_kwargs=extra_kwargs, **attention_specific_kwargs)
+        generate(
+            model,
+            _warmup_input_ids,
+            max_new_tokens=_max_new_tokens,
+            do_sample=False,
+            use_cache=use_cache,
+            extra_kwargs=extra_kwargs,
+            **attention_specific_kwargs,
+        )
     pt_compile_model_time = time.time() - pt_compile_model_time
     dprint(f"PT compile complete, took {pt_compile_model_time:.3f}s")
 
@@ -52,17 +74,17 @@ def __download_file(url, filename):
     try:
         response = requests.get(url, stream=True)
         response.raise_for_status()
-
+
         with open(filename, 'wb') as file:
             for chunk in response.iter_content(chunk_size=8192):
                 file.write(chunk)
         print(f"Successfully downloaded {filename}")
-
+
     except requests.exceptions.RequestException as e:
         print(f"An error occurred: {e}")
 
 def __sample_requests(
-    prompt_list: List[str],
+    prompt_list: List[str],
     num_requests: int,
     tokenizer: BaseTokenizer,
     prompt_length_min: int = 32,
@@ -82,16 +104,14 @@ def __sample_requests(
         # Tokenize the prompts and completions.
         prompt = prompt_list[i]
         prompt_token_ids = ids_for_prompt(prompt, tokenizer)
-
+
         prompt_len = len(prompt_token_ids)
         if prompt_len < prompt_length_min or prompt_len > prompt_length_max:
             # Prune too short or too long sequences.
            continue
         filtered_dataset.append((prompt, prompt_len))
-
-    return filtered_dataset
-
 
+    return filtered_dataset
 
 def sample_sharegpt_requests(
     dataset_path: str,
@@ -111,15 +131,22 @@ def sample_sharegpt_requests(
     # Filter out the conversations with less than 2 turns.
     dataset = [data for data in dataset if len(data["conversations"]) >= 2]
     dataset = [data["conversations"][0]["value"] for data in dataset]
-
-    return __sample_requests(dataset, num_requests, tokenizer, prompt_length_min, prompt_length_max, seed)
+
+    return __sample_requests(
+        dataset,
+        num_requests,
+        tokenizer,
+        prompt_length_min,
+        prompt_length_max,
+        seed,
+    )
 
 def sample_squad_v2_qa_requests(
     dataset_path: str,
-    num_requests: int,
-    tokenizer: BaseTokenizer,
-    prompt_length_min: int = 32,
-    prompt_length_max: int = 64,
+    num_requests: int,
+    tokenizer: BaseTokenizer,
+    prompt_length_min: int = 32,
+    prompt_length_max: int = 64,
     seed: Optional[int] = None
 ) -> List[Tuple[str, int]]:
     from datasets import load_dataset
@@ -128,10 +155,14 @@ def sample_squad_v2_qa_requests(
         ds = load_dataset(dataset_path)['train']
     else:
         ds = load_dataset("rajpurkar/squad_v2", cache_dir=dataset_path)['train']
-
-
-    ds = [f"{data['context']}\n{data['question']}" for data in ds]
 
-    return __sample_requests(ds, num_requests, tokenizer, prompt_length_min, prompt_length_max, seed)
-
+    ds = [f"{data['context']}\n{data['question']}" for data in ds]
 
+    return __sample_requests(
+        ds,
+        num_requests,
+        tokenizer,
+        prompt_length_min,
+        prompt_length_max,
+        seed,
+    )
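
A minimal usage sketch for the refactored sampling helpers above; the tokenizer path, dataset cache directory, and length bounds are illustrative placeholders rather than values taken from this commit.

# Sketch only: assumes fms and aiu_fms_testing_utils are installed, and that
# fms.utils.tokenizers.get_tokenizer can resolve the (placeholder) tokenizer path.
from aiu_fms_testing_utils.utils import sample_squad_v2_qa_requests
from fms.utils import tokenizers

tokenizer = tokenizers.get_tokenizer("/path/to/tokenizer")  # placeholder path

# Returns (prompt, prompt_length) pairs whose token counts fall within [32, 64];
# the SQuAD v2 split is loaded from dataset_path or cached there on first use.
prompts = sample_squad_v2_qa_requests(
    dataset_path="/tmp/squad_v2_cache",  # placeholder cache directory
    num_requests=4,
    tokenizer=tokenizer,
    prompt_length_min=32,
    prompt_length_max=64,
    seed=0,
)
for prompt, prompt_len in prompts:
    print(prompt_len, prompt[:40])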

aiu_fms_testing_utils/utils/aiu_setup.py

Lines changed: 54 additions & 0 deletions
@@ -1,4 +1,6 @@
+import argparse
 import os
+import torch
 
 # ==============================================================
 # Common utilities
@@ -67,3 +69,55 @@ def aiu_dist_setup(rank, world_size, local_rank=-0, local_size=-1, verbose=False
         dprint(f"Detected running via torchrun")
 
     aiu_setup(rank, world_size)
+
+
+# ==============================================================
+# Environment variables utilities
+# ==============================================================
+def set_aiu_env_vars(args: argparse.Namespace) -> None:
+    """Set necessary environment variables for AIU"""
+
+    if not args.compile_dynamic:
+        _target_cache_size = max(
+            int(args.max_new_tokens * 2),
+            int(args.min_pad_length * 2.5),
+            int(args.fixed_prompt_length * 2.5),
+        )
+        _prompt_size = max(int(args.min_pad_length), int(args.fixed_prompt_length))
+        if hasattr(torch._dynamo.config, "accumulated_cache_size_limit"):
+            if _target_cache_size > torch._dynamo.config.accumulated_cache_size_limit:
+                _prev = torch._dynamo.config.accumulated_cache_size_limit
+                torch._dynamo.config.accumulated_cache_size_limit = _target_cache_size
+                dprint(
+                    "NOTICE: Adjusting torch._dynamo.config.accumulated_cache_size_limit "
+                    f"from {_prev} to {torch._dynamo.config.accumulated_cache_size_limit} "
+                    f"to accomodate prompt size of {_prompt_size} and decode tokens of "
+                    f"{args.max_new_tokens}"
+                )
+
+        if _target_cache_size > torch._dynamo.config.cache_size_limit:
+            _prev = torch._dynamo.config.cache_size_limit
+            torch._dynamo.config.cache_size_limit = _target_cache_size
+            dprint(
+                f"NOTICE: Adjusting torch._dynamo.config.cache_size_limit from {_prev} to "
+                f"{torch._dynamo.config.cache_size_limit} to accomodate prompt size of "
+                f"{_prompt_size} and decode tokens of {args.max_new_tokens}"
+            )
+
+        torch._dynamo.config.assume_static_by_default = True
+        torch._dynamo.config.automatic_dynamic_shapes = False
+
+    # os.environ.setdefault("DTCOMPILER_KEEP_EXPORT", "true") # CONFIRM IF THIS IS NEEDE
+
+    if not args.is_encoder:
+        os.environ.setdefault("COMPILATION_MODE", "offline_decoder")
+
+    if args.device_type == "aiu-senulator":
+        os.environ["FLEX_COMPUTE"] = "SENULATOR"
+        os.environ["FLEX_DEVICE"] = "MOCK"
+    else:
+        if "AIU_WORLD_RANK_0" not in os.environ:
+            print("must set AIU_WORLD_RANK_0")
+            exit()
+        os.environ.setdefault("FLEX_COMPUTE", "SENTIENT")
+        os.environ.setdefault("FLEX_DEVICE", "PF") # will use VF eventually
