Commit ae77c06

Merge pull request #54 from foundation-model-stack/fp8
Add fp8 attention support for AIU
2 parents 4619e62 + 073fd75 commit ae77c06

File tree

5 files changed, +122 -45 lines

  aiu_fms_testing_utils/testing/validation.py
  aiu_fms_testing_utils/utils/__init__.py
  aiu_fms_testing_utils/utils/paged.py
  scripts/inference.py
  tests/models/test_decoders.py

aiu_fms_testing_utils/testing/validation.py

Lines changed: 3 additions & 3 deletions

@@ -187,10 +187,10 @@ def load_validation_information(validation_path, validation_files_type, batch_si

     return ValidationInfo(validation_info)

-def extract_validation_information(model, input_ids, max_new_tokens, post_iteration_hook, attn_algorithm=None, eos_token_id = None, only_last_token=False, timing="", attn_type="sdpa", **padding_kwargs):
+def extract_validation_information(model, input_ids, max_new_tokens, post_iteration_hook, attn_algorithm=None, eos_token_id = None, only_last_token=False, timing="", **extra_kwargs):
     max_seq_len = model.config.max_expected_seq_len
     attention_specific_kwargs = {}
-    if attn_type == "paged":
+    if "paged" in extra_kwargs["attn_name"]:
         from aiu_fms_testing_utils.utils.paged import generate
     else:
         # TODO: Add a unified generation dependent on attn_type
@@ -199,7 +199,7 @@ def extract_validation_information(model, input_ids, max_new_tokens, post_iterat
         attention_specific_kwargs["max_seq_len"] = max_seq_len

     # Add only_last_token optimization
-    extra_generation_kwargs = {**padding_kwargs}
+    extra_generation_kwargs = {**extra_kwargs}
     if only_last_token:
         extra_generation_kwargs["only_last_token"] = only_last_token
     if attn_algorithm is not None:
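
To make the new contract concrete, here is a minimal, self-contained sketch (not part of the diff) of the dispatch that extract_validation_information now performs: the generator is chosen from the attn_name entry carried in **extra_kwargs, so any kernel name containing "paged" (including the fp8 variant) routes to the paged generator.

# Illustrative only: mirrors the substring check added above, with no dependency on FMS.
def pick_generate_path(**extra_kwargs) -> str:
    if "paged" in extra_kwargs["attn_name"]:
        return "aiu_fms_testing_utils.utils.paged.generate"
    return "fms.utils.generation.generate"

assert pick_generate_path(attn_name="spyre_paged_attn_fp8") == "aiu_fms_testing_utils.utils.paged.generate"
assert pick_generate_path(attn_name="math_fp8") == "fms.utils.generation.generate"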

aiu_fms_testing_utils/utils/__init__.py

Lines changed: 8 additions & 7 deletions

@@ -9,10 +9,11 @@
 import json
 import random

-def warmup_model(model: nn.Module, input_ids: torch.Tensor, max_new_tokens: int, compile_dynamic_sendnn = False, attn_type="sdpa", **padding_kwargs):
+def warmup_model(model: nn.Module, input_ids: torch.Tensor, max_new_tokens: int, compile_dynamic_sendnn = False, use_cache: bool = True, **extra_kwargs):
     import torch_sendnn
     attention_specific_kwargs = {}
-    if attn_type == "paged":
+    attn_name = extra_kwargs["attn_name"]
+    if "paged" in attn_name:
         from aiu_fms_testing_utils.utils.paged import generate, adjust_inputs_to_batch
     else:
         # TODO: Add a unified generation dependent on attn_type
@@ -24,18 +25,18 @@ def warmup_model(model: nn.Module, input_ids: torch.Tensor, max_new_tokens: int,

     # adjust inputs depending on attn_type and dynamic shapes
     _warmup_input_ids = input_ids
-    _padding_kwargs = padding_kwargs
+    _extra_kwargs = extra_kwargs
     _max_new_tokens = max_new_tokens
     if compile_dynamic_sendnn:
         _max_new_tokens = 2
         # always warmup with batch size 2 when using attn_type=paged
-        if attn_type == "paged":
-            _warmup_input_ids, _padding_kwargs = adjust_inputs_to_batch(input_ids, **padding_kwargs)
+        if "paged" in attn_name:
+            _warmup_input_ids, _extra_kwargs = adjust_inputs_to_batch(input_ids, **extra_kwargs)

-    extra_kwargs = {**_padding_kwargs, "only_last_token": attn_type != "paged"}
+    extra_kwargs = {**_extra_kwargs, "only_last_token": "paged" not in attn_name}

     with torch_sendnn.warmup_mode():
-        generate(model, _warmup_input_ids, max_new_tokens=_max_new_tokens, use_cache=True, do_sample=False, extra_kwargs=extra_kwargs, **attention_specific_kwargs)
+        generate(model, _warmup_input_ids, max_new_tokens=_max_new_tokens, do_sample=False, use_cache=use_cache, extra_kwargs=extra_kwargs, **attention_specific_kwargs)
     pt_compile_model_time = time.time() - pt_compile_model_time
     dprint(f"PT compile complete, took {pt_compile_model_time:.3f}s")
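
A small sketch, with assumptions flagged, of how the reworked warmup_model treats its kwargs: attn_name is now a required entry of **extra_kwargs, dynamic-shape paged warmups use batch size 2, and only_last_token is derived from the kernel name. The helper below (plan_warmup is a hypothetical name) only mirrors that kwargs handling; the real function additionally compiles and runs generate under torch_sendnn.warmup_mode().

import torch

def plan_warmup(input_ids: torch.Tensor, compile_dynamic_sendnn: bool, **extra_kwargs):
    attn_name = extra_kwargs["attn_name"]          # required key after this change
    max_new_tokens = 2 if compile_dynamic_sendnn else 128
    if compile_dynamic_sendnn and "paged" in attn_name:
        # paged attention always warms up with batch size 2 (see adjust_inputs_to_batch)
        input_ids = input_ids[0].repeat(2, 1)
    gen_kwargs = {**extra_kwargs, "only_last_token": "paged" not in attn_name}
    return input_ids, max_new_tokens, gen_kwargs

ids = torch.randint(0, 100, (1, 16))
warm_ids, toks, kwargs = plan_warmup(ids, True, attn_name="spyre_paged_attn")
assert warm_ids.shape[0] == 2 and kwargs["only_last_token"] is False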

aiu_fms_testing_utils/utils/paged.py

Lines changed: 27 additions & 12 deletions

@@ -5,7 +5,7 @@
 import torch
 import fms.utils.spyre.paged

-def adjust_inputs_to_batch(input_ids: torch.Tensor, **padding_kwargs):
+def adjust_inputs_to_batch(input_ids: torch.Tensor, **extra_kwargs):
     """
     Adjusts the inputs to a batch. Batch size 1 cannot be handled since we want a symbolic shape for the batch
     and pytorch automatically sets size 1 dimensions as static
@@ -14,11 +14,11 @@ def adjust_inputs_to_batch(input_ids: torch.Tensor, **padding_kwargs):
     """
     input_ids = input_ids[0].repeat(2, 1)
     # ensure we pass along other kwargs
-    kwargs = {**padding_kwargs}
-    mask = padding_kwargs.get("mask", None)
+    kwargs = {**extra_kwargs}
+    mask = extra_kwargs.get("mask", None)
     if mask is not None:
         kwargs["mask"] = torch.stack((mask[0], mask[0]))
-    position_ids = padding_kwargs.get("position_ids", None)
+    position_ids = extra_kwargs.get("position_ids", None)
     if position_ids is not None:
         kwargs["position_ids"] = position_ids[0].repeat(2, 1)
     return input_ids, kwargs
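
Since adjust_inputs_to_batch is shown in full above, a short usage sketch may help (tensor shapes are illustrative, and it assumes aiu_fms_testing_utils is importable): a batch of one is duplicated to a batch of two so the compiler sees a non-static batch dimension, and any mask or position_ids passed through **extra_kwargs are duplicated the same way.

import torch
from aiu_fms_testing_utils.utils.paged import adjust_inputs_to_batch

input_ids = torch.randint(0, 100, (1, 8))
mask = torch.ones(1, 8, 8, dtype=torch.bool)
position_ids = torch.arange(8).unsqueeze(0)

# attn_name simply rides along in the returned kwargs; only mask/position_ids are rewritten
batched_ids, kwargs = adjust_inputs_to_batch(
    input_ids, attn_name="spyre_paged_attn", mask=mask, position_ids=position_ids
)
assert batched_ids.shape == (2, 8)
assert kwargs["mask"].shape[0] == 2 and kwargs["position_ids"].shape == (2, 8)
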
@@ -137,14 +137,23 @@ def generate(

     kvheads = kvheads // tensor_parallel_size if kvheads > 1 else kvheads
     head_size = model.config.emb_dim // nheads
-    kwargs["attn_name"] = "spyre_paged_attn"
-    kwargs["past_key_value_states"] = [
-        (
-            torch.zeros(NUM_BLOCKS, BLOCK_SIZE, kvheads, head_size, dtype=model_dtype),
-            torch.zeros(NUM_BLOCKS, BLOCK_SIZE, kvheads, head_size, dtype=model_dtype),
-        )
-        for _ in range(model.config.nlayers)
-    ]
+    if "fp8" in kwargs["attn_name"]:
+        from fms_mo.aiu_addons.fp8.fp8_utils import ScaledTensor
+        kwargs["past_key_value_states"] = [
+            (
+                ScaledTensor(torch.zeros(NUM_BLOCKS, BLOCK_SIZE, kvheads, head_size, dtype=torch.float8_e4m3fn), torch.tensor(1.0), False),
+                ScaledTensor(torch.zeros(NUM_BLOCKS, BLOCK_SIZE, kvheads, head_size, dtype=torch.float8_e4m3fn), torch.tensor(1.0), False),
+            )
+            for _ in range(model.config.nlayers)
+        ]
+    else:
+        kwargs["past_key_value_states"] = [
+            (
+                torch.zeros(NUM_BLOCKS, BLOCK_SIZE, kvheads, head_size, dtype=model_dtype),
+                torch.zeros(NUM_BLOCKS, BLOCK_SIZE, kvheads, head_size, dtype=model_dtype),
+            )
+            for _ in range(model.config.nlayers)
+        ]
     kwargs["block_table"] = None
     block_numbers = [i for i in range(NUM_BLOCKS)]
     # this will ensure we don't have contiguous blocks
@@ -240,6 +249,12 @@ def generate(
             attn_name=kwargs["attn_name"],
         )

+        # TODO: Figure out how to do this cleanly
+        if "fp8" in kwargs["attn_name"] and seq_i != input_ids.size(0) - 1:
+            for layer_cache in current_kv_cache:
+                layer_cache[0]._scaled = False
+                layer_cache[1]._scaled = False
+
         outputs_list.append(output[0].squeeze(0))

     output = (torch.stack(outputs_list), current_kv_cache)
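
The allocation decision above can be summarized in a standalone sketch. The dimensions are placeholders for the values derived from the model config, and ScaledTensor is the fms_mo wrapper used in the diff (a quantized buffer plus its scale and a flag recording whether the scale has been set), so the fp8 branch only runs where fms_mo is installed. The second hunk then resets the per-layer _scaled flags between prompts so each new sequence re-derives its scales during prefill.

import torch

NUM_BLOCKS, BLOCK_SIZE, kvheads, head_size, nlayers = 4, 64, 8, 128, 2  # placeholder sizes

def make_kv_cache(attn_name: str, model_dtype=torch.float16):
    if "fp8" in attn_name:
        from fms_mo.aiu_addons.fp8.fp8_utils import ScaledTensor
        def block():
            data = torch.zeros(NUM_BLOCKS, BLOCK_SIZE, kvheads, head_size, dtype=torch.float8_e4m3fn)
            return ScaledTensor(data, torch.tensor(1.0), False)  # scale=1.0, not yet scaled
        return [(block(), block()) for _ in range(nlayers)]
    def block():
        return torch.zeros(NUM_BLOCKS, BLOCK_SIZE, kvheads, head_size, dtype=model_dtype)
    return [(block(), block()) for _ in range(nlayers)]

cache = make_kv_cache("spyre_paged_attn")        # plain half-precision cache
# cache = make_kv_cache("spyre_paged_attn_fp8")  # fp8 cache, requires fms_mo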

scripts/inference.py

Lines changed: 65 additions & 13 deletions

@@ -7,7 +7,6 @@
 from pathlib import Path
 import random
 import time
-import contextlib

 # Third Party
 from aiu_fms_testing_utils.utils import aiu_setup, warmup_model
@@ -104,7 +103,17 @@
     type=str,
     default=None,
     choices=["bf16", "fp16", "fp32"],
-    help="If set to one of the choices, overrides the model checkpoint weight format by setting the default pytorch format",
+    help="If set to one of the choices, overrides the model checkpoint weight format by setting the default pytorch format. This will break quantized checkpoints.",
+)
+parser.add_argument(
+    "--cast_bf16_to_fp16",
+    action="store_true",
+    help="If set, cast any bf16 weights in the model to fp16 for AIU compiler. Doesn't touch fp32 or quantized",
+)
+parser.add_argument(
+    "--cast_fp16_to_bf16",
+    action="store_true",
+    help="If set, cast any fp16 weights in the model to bf16 for GPU. Doesn't touch fp32 or quantized",
 )
 parser.add_argument(
     "--compile",
@@ -221,17 +230,29 @@
 parser.add_argument(
     "--attention_type",
     type=str,
-    choices=["sdpa", "paged"],
+    choices=["sdpa", "paged", "math_fp8", "paged_fp8"],
     default="sdpa",
     help="which backend attention to use in mha",
 )
 args = parser.parse_args()

-if args.attention_type == "paged":
+attention_map = {
+    "sdpa": "sdpa_causal",
+    "paged": "spyre_paged_attn",
+    "math_fp8": "math_fp8",
+    "paged_fp8": "spyre_paged_attn_fp8",
+}
+
+attn_name = attention_map[args.attention_type]
+
+if "paged" in attn_name:
     from aiu_fms_testing_utils.utils.paged import generate
 else:
     from fms.utils.generation import generate

+if "fp8" in attn_name:
+    import fms_mo.aiu_addons.fp8.fp8_attn
+
 if args.quantization == "gptq":
     if "aiu" in args.device_type:
         try:
@@ -329,7 +350,7 @@
         print("must set AIU_WORLD_RANK_0")
         exit()
     os.environ.setdefault("FLEX_COMPUTE", "SENTIENT")
-    os.environ.setdefault("FLEX_DEVICE", "VFIO")
+    os.environ.setdefault("FLEX_DEVICE", "PF")

     device = torch.device("cpu")
 else:
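
As a quick standalone illustration (no FMS imports, not part of the diff), the new --attention_type choices map onto kernel names, and both the generator import and the optional fms_mo fp8-attention import key off substrings of that name:

attention_map = {
    "sdpa": "sdpa_causal",
    "paged": "spyre_paged_attn",
    "math_fp8": "math_fp8",
    "paged_fp8": "spyre_paged_attn_fp8",
}

attn_name = attention_map["paged_fp8"]
use_paged_generator = "paged" in attn_name   # -> aiu_fms_testing_utils.utils.paged.generate
needs_fp8_ops = "fp8" in attn_name           # -> import fms_mo.aiu_addons.fp8.fp8_attn
assert use_paged_generator and needs_fp8_ops
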
@@ -463,6 +484,38 @@ def select_int8_module(
         fused_weights=fused_weights,
     )

+### Quantization
+
+# FP8 model checks
+has_fp8_weights = False
+has_bf16_weights = False
+has_fp16_weights = False
+for param in model.parameters():
+    if param.dtype == torch.float8_e4m3fn:
+        has_fp8_weights = True
+    elif param.dtype == torch.bfloat16:
+        has_bf16_weights = True
+    elif param.dtype == torch.float16:
+        has_fp16_weights = True
+
+if has_fp8_weights:
+    if is_aiu_backend and has_bf16_weights and not args.cast_bf16_to_fp16:
+        raise ValueError("FP8 checkpoints on AIU with bf16 weights require casting to fp16 using --cast_bf16_to_fp16. Do not use --default_dtype!")
+    elif device.type == "cuda" and has_fp16_weights and not args.cast_fp16_to_bf16:
+        raise ValueError("FP8 checkpoints on GPU with fp16 weights require casting to bf16 using --cast_fp16_to_bf16. Do not use --default_dtype!")
+
+if args.cast_bf16_to_fp16:
+    for name, param in model.named_parameters():
+        if param.dtype == torch.bfloat16:
+            if param.max() > torch.finfo(torch.float16).max:
+                dprint(f"[WARNING] You are casting param {name} to fp16, which will cause loss of accuracy. You can ignore this warning if this is intended.")
+            param.data = param.data.to(dtype=torch.float16)
+
+if args.cast_fp16_to_bf16:
+    for param in model.parameters():
+        if param.dtype == torch.float16:
+            param.data = param.data.to(dtype=torch.bfloat16)
+
 if args.quantization in ["gptq", "int8"]:
     if rank == 0 and args.verbose > 0:
         dprint("PARAMS:\n" + "\n".join(f"{k:60} {str(v.dtype):15} {str(v.device):10} {list(v.size())}" for k,v in model.named_parameters()))
@@ -606,7 +659,9 @@ def truncate_prompts_to_max_length(prompts, max_len, max_allowed_length):
 ids = prompts
 if isinstance(ids, list) and len(ids) == 1:
     ids = ids[0].unsqueeze(0)
-extra_generation_kwargs = None
+extra_generation_kwargs = {}
+
+extra_generation_kwargs["attn_name"] = attn_name


 def print_result(result, result_idx: int):
@@ -648,19 +703,15 @@ def infer(use_cache, do_sample, warmup):
     global extra_generation_kwargs
     if extra_generation_kwargs is None:
         extra_generation_kwargs = {}
-    extra_generation_kwargs["only_last_token"] = args.attention_type != "paged"
-
-    if args.device_type == "cpu":
-        # Bug in 2.3.1 fixed in 2.4.1 for SDPA flash cpu impl when padding too much
-        extra_generation_kwargs["attn_algorithm"] = "math"
+    extra_generation_kwargs["only_last_token"] = "paged" not in attn_name

     if not args.no_early_termination and not warmup:
         eos_token_id = tokenizer.eos_token_id
     else:
         eos_token_id = None

     attention_specific_kwargs = {}
-    if args.attention_type == "sdpa":
+    if attn_name == "sdpa_causal":
         attention_specific_kwargs["contiguous_cache"] = True

     result = generate(
@@ -706,7 +757,8 @@ def infer(use_cache, do_sample, warmup):
 dprint(f"compilation warmup")
 pt_compile_model_time = time.time()
 if args.device_type == "aiu":  # only run warmup for AIU, no need for senulator
-    warmup_model(model, ids, args.max_new_tokens, args.compile_dynamic_sendnn, attn_type=args.attention_type, **extra_generation_kwargs)
+    for cache in use_cache:
+        warmup_model(model, ids, args.max_new_tokens, args.compile_dynamic_sendnn, **extra_generation_kwargs)
     aiu_warmup_time = time.time()
     for sample, cache in itertools.product(do_sample, use_cache):
         infer(cache, sample, True)
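
The dtype policy introduced above (for FP8 checkpoints, bf16 weights must be cast to fp16 on AIU and fp16 weights to bf16 on GPU) can be exercised on a toy module. This is only a sketch of the --cast_bf16_to_fp16 pass; the real script runs it on the loaded FMS model and logs through dprint rather than print.

import torch
import torch.nn as nn

toy = nn.Linear(4, 4).to(torch.bfloat16)   # stand-in for the non-fp8 weights of a checkpoint

# equivalent of --cast_bf16_to_fp16
for name, param in toy.named_parameters():
    if param.dtype == torch.bfloat16:
        if param.max() > torch.finfo(torch.float16).max:
            print(f"[WARNING] casting {name} to fp16 may lose accuracy")
        param.data = param.data.to(dtype=torch.float16)

assert all(p.dtype == torch.float16 for p in toy.parameters())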

tests/models/test_decoders.py

Lines changed: 19 additions & 10 deletions

@@ -58,6 +58,14 @@
 USE_DISTRIBUTED = os.environ.get("FMS_TEST_SHAPES_DISTRIBUTED", "0") == "1"

 ATTN_TYPE = os.environ.get("FMS_TEST_SHAPES_ATTN_TYPE", "sdpa")
+attention_map = {
+    "sdpa": "sdpa_causal",
+    "paged": "spyre_paged_attn",
+    "math_fp8": "math_fp8",
+    "paged_fp8": "spyre_paged_attn_fp8",
+}
+ATTN_NAME = attention_map[ATTN_TYPE]
+
 FORCE_VALIDATION_LEVEL_1 = (
     os.environ.get("FMS_TEST_SHAPES_FORCE_VALIDATION_LEVEL_1", "0") == "1"
 )
@@ -246,8 +254,8 @@ def __prepare_inputs(batch_size, seq_length, tokenizer, seed=0):
     for prompt, _ in prompts_and_sizes:
         prompt_list.append(ids_for_prompt(prompt, tokenizer))

-    input_ids, padding_kwargs = pad_input_ids(prompt_list, min_pad_length=seq_length)
-    return input_ids, padding_kwargs
+    input_ids, extra_kwargs = pad_input_ids(prompt_list, min_pad_length=seq_length)
+    return input_ids, extra_kwargs


 def __find_eos_index(reference_tokens, eos_token_id, seq_length, max_new_tokens):
@@ -413,10 +421,11 @@ def test_common_shapes(model_path, batch_size, seq_length, max_new_tokens, persi
     )

     # prepare input_ids
-    input_ids, padding_kwargs = __prepare_inputs(batch_size, seq_length, tokenizer)
+    input_ids, extra_kwargs = __prepare_inputs(batch_size, seq_length, tokenizer)
+    extra_kwargs["attn_name"] = ATTN_NAME

     # warmup aiu model
-    warmup_model(model, input_ids, max_new_tokens, compile_dynamic_sendnn, attn_type=ATTN_TYPE, **padding_kwargs)
+    warmup_model(model, input_ids, max_new_tokens, compile_dynamic_sendnn, **extra_kwargs)

     # generate cpu validation info
     cpu_validation_info = __load_validation_info(
@@ -429,7 +438,7 @@
         max_new_tokens,
         LogitsExtractorHook(),
         attn_algorithm="math",
-        **padding_kwargs,
+        **extra_kwargs,
     )

     if save_validation_info_outputs:
@@ -448,7 +457,7 @@ def test_common_shapes(model_path, batch_size, seq_length, max_new_tokens, persi

     # first test validation level 0
     aiu_validation_info = extract_validation_information(
-        model, input_ids, max_new_tokens, None, only_last_token=ATTN_TYPE != "paged", attn_type=ATTN_TYPE, **padding_kwargs
+        model, input_ids, max_new_tokens, None, only_last_token="paged" not in ATTN_NAME, **extra_kwargs
     )
     dprint("aiu validation info extracted for validation level 0")

@@ -487,9 +496,10 @@ def _metric_calculator(r: torch.Tensor, t: torch.Tensor):
     for i in range(iters):
         # for iteration 0, we have computed the cpu validation info in the prior step for seed=0, so skip
         if i != 0:
-            input_ids, padding_kwargs = __prepare_inputs(
+            input_ids, extra_kwargs = __prepare_inputs(
                 batch_size, seq_length, tokenizer, seed=i
             )
+            extra_kwargs["attn_name"] = ATTN_NAME
             cpu_validation_info = __load_validation_info(
                 model_path, batch_size, seq_length, max_new_tokens, tokenizer, i
             )
@@ -500,7 +510,7 @@ def _metric_calculator(r: torch.Tensor, t: torch.Tensor):
                 max_new_tokens,
                 LogitsExtractorHook(),
                 attn_algorithm="math",
-                **padding_kwargs,
+                **extra_kwargs,
             )
             dprint(
                 f"cpu validation info extracted for validation level 1 - iter={i}"
@@ -526,8 +536,7 @@ def _metric_calculator(r: torch.Tensor, t: torch.Tensor):
                 max_new_tokens,
                 GoldenTokenHook(cpu_static_tokens),
                 only_last_token=ATTN_TYPE != "paged",
-                attn_type=ATTN_TYPE,
-                **padding_kwargs,
+                **extra_kwargs,
             )
             dprint(f"aiu validation info extracted for validation level 1 - iter={i}")
             if save_validation_info_outputs:
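
For reference, a minimal sketch of how the test module now derives the kernel name from the FMS_TEST_SHAPES_ATTN_TYPE environment variable and threads it through the kwargs returned by __prepare_inputs; only the mapping and flag logic are reproduced here, with an example value set explicitly.

import os

os.environ["FMS_TEST_SHAPES_ATTN_TYPE"] = "paged_fp8"   # example value for this sketch
ATTN_TYPE = os.environ.get("FMS_TEST_SHAPES_ATTN_TYPE", "sdpa")
attention_map = {
    "sdpa": "sdpa_causal",
    "paged": "spyre_paged_attn",
    "math_fp8": "math_fp8",
    "paged_fp8": "spyre_paged_attn_fp8",
}
ATTN_NAME = attention_map[ATTN_TYPE]

extra_kwargs = {"attn_name": ATTN_NAME}        # merged with mask/position_ids from pad_input_ids
only_last_token = "paged" not in ATTN_NAME     # False here, since a paged kernel was selected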
