
Commit 0cd0d7b

Add 16b forced casting
Signed-off-by: Andrea Fasoli <andrea.fasoli@ibm.com>
1 parent e989a23 commit 0cd0d7b

5 files changed (+90 -3 lines changed)

aiu_fms_testing_utils/utils/args_parsing.py

Lines changed: 17 additions & 0 deletions

@@ -50,6 +50,22 @@ def get_args(parser: argparse.ArgumentParser) -> argparse.Namespace:
             "weight format by setting the default pytorch format"
         ),
     )
+    parser.add_argument(
+        "--cast_bf16_to_fp16",
+        action="store_true",
+        help=(
+            "If set, cast any bf16 weights in the model to fp16 for AIU compiler. "
+            "Doesn't touch fp32 or quantized"
+        )
+    )
+    parser.add_argument(
+        "--cast_fp16_to_bf16",
+        action="store_true",
+        help=(
+            "If set, cast any fp16 weights in the model to bf16 for GPU. "
+            "Doesn't touch fp32 or quantized"
+        )
+    )
 
     # Quantization arguments
     args_quantization = parser.add_argument_group("Model quantization")

@@ -260,6 +276,7 @@ def get_args(parser: argparse.ArgumentParser) -> argparse.Namespace:
     args.is_aiu_backend = "aiu" in args.device_type
     args.dynamo_backend = "sendnn" if args.is_aiu_backend else "inductor"
     args.fused_weights = not args.unfuse_weights
+    args.force_16b_dtype = args.cast_bf16_to_fp16 or args.cast_fp16_to_bf16
 
     if args.verbose:
         dprint("=" * 60)

aiu_fms_testing_utils/utils/encoders_utils.py

Lines changed: 2 additions & 2 deletions

@@ -597,8 +597,8 @@ def run_evaluation(self) -> None:
             batch = self.convert_batch_to_fms_style(batch)
             batch = move_to_device(batch, args.device)
             start_logits, end_logits = self.model(**batch)
-            all_start_logits.append(start_logits.cpu().numpy())
-            all_end_logits.append(end_logits.cpu().numpy())
+            all_start_logits.append(start_logits.to(torch.float16).cpu().numpy())
+            all_end_logits.append(end_logits.to(torch.float16).cpu().numpy())
         eval_duration = time.time() - start_time
         dprint(
             f"Runtime: {eval_duration:.0f} s | "

aiu_fms_testing_utils/utils/model_setup.py

Lines changed: 27 additions & 0 deletions

@@ -120,6 +120,33 @@ def setup_model(args: argparse.Namespace) -> tuple[str | None, torch.device, str
     return default_dtype, device, dist_strat
 
 
+def recast_16b(model: nn.Module, args: argparse.Namespace) -> None:
+    """Cast 16-bit model parameters to selected datatype."""
+
+    if args.cast_bf16_to_fp16:
+        dprint(
+            "Casting all BF16 model parameters to FP16 "
+            "(--cast_bf16_to_fp16 flag is enabled)"
+        )
+        for name, param in model.named_parameters():
+            if param.dtype == torch.bfloat16:
+                if param.max() > torch.finfo(torch.float16).max:
+                    dprint(
+                        f"[WARNING] Casting param {name} to fp16 will truncate the "
+                        "tensor. This may cause accuracy loss. Ignore this warning if "
+                        "this is intended."
+                    )
+                param.data = param.data.to(dtype=torch.float16)
+    elif args.cast_fp16_to_bf16:
+        dprint(
+            "Casting all FP16 model parameters to BF16 "
+            "(--cast_fp16_to_bf16 flag is enabled)"
+        )
+        for param in model.parameters():
+            if param.dtype == torch.float16:
+                param.data = param.data.to(dtype=torch.bfloat16)
+
+
 def print_model_params(model: nn.Module, args: argparse.Namespace) -> None:
     """Printout model and list of model parameters with related statistics."""
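The overflow check matters because bf16 and fp16 trade range for precision in opposite directions: bf16 represents magnitudes up to roughly 3.4e38, while fp16 tops out at 65504, so bf16 values above that become inf after the cast. A self-contained rehearsal of the bf16-to-fp16 branch on a toy layer (plain torch, not the project's real setup path):

import torch
from torch import nn

model = nn.Linear(4, 4).to(torch.bfloat16)
fp16_max = torch.finfo(torch.float16).max  # 65504.0

with torch.no_grad():
    model.weight.fill_(1e5)  # representable in bf16, above the fp16 range

for name, param in model.named_parameters():
    if param.dtype == torch.bfloat16:
        if param.max() > fp16_max:
            print(f"{name}: values above fp16 max, cast will overflow to inf")
        param.data = param.data.to(dtype=torch.float16)

print(model.weight.dtype)   # torch.float16
print(model.weight.max())   # inf

Note that only param.max() is checked, so a large negative bf16 value would overflow to -inf without triggering the warning.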

aiu_fms_testing_utils/utils/quantization_setup.py

Lines changed: 33 additions & 0 deletions

@@ -6,6 +6,7 @@
 import os
 
 # Third Party
+import torch
 from torch import nn
 
 # Local Packages

@@ -133,3 +134,35 @@ def select_int8_module(
     else:
         linear_config = {"linear_type": "torch_linear"}
     return linear_config
+
+
+def validate_quantization(model: nn.Module, args: argparse.Namespace) -> None:
+    """Ensure compatibility of FP8 models with device-specific operations."""
+
+    has_fp8_weights = False
+    has_bf16_weights = False
+    has_fp16_weights = False
+    for param in model.parameters():
+        if param.dtype == torch.float8_e4m3fn:
+            has_fp8_weights = True
+        elif param.dtype == torch.bfloat16:
+            has_bf16_weights = True
+        elif param.dtype == torch.float16:
+            has_fp16_weights = True
+
+    if has_fp8_weights:
+        if args.is_aiu_backend and has_bf16_weights and not args.cast_bf16_to_fp16:
+            raise ValueError(
+                "FP8 checkpoints on AIU with bf16 weights require casting to fp16 "
+                "using --cast_bf16_to_fp16. Do not use --default_dtype!"
+            )
+        elif (
+            args.device.type == "cuda"
+            and has_fp16_weights
+            and not args.cast_fp16_to_bf16
+        ):
+            raise ValueError(
+                "FP8 checkpoints on GPU with fp16 weights require casting to bf16 "
+                "using --cast_fp16_to_bf16. Do not use --default_dtype!"
+            )
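A hypothetical failure-mode sketch: a toy module carrying one fp8 weight and one leftover bf16 parameter, with a hand-built args namespace standing in for the script's parsed arguments. Assuming aiu_fms_testing_utils and its dependencies are importable, validate_quantization should raise when such a checkpoint targets AIU without --cast_bf16_to_fp16:

import argparse
import torch
from torch import nn

from aiu_fms_testing_utils.utils.quantization_setup import validate_quantization


class ToyFP8(nn.Module):
    def __init__(self):
        super().__init__()
        # fp8 "quantized" weight plus a bf16 parameter left unconverted
        self.w8 = nn.Parameter(
            torch.zeros(2, 2).to(torch.float8_e4m3fn), requires_grad=False
        )
        self.head = nn.Parameter(torch.zeros(2, dtype=torch.bfloat16))


args = argparse.Namespace(
    is_aiu_backend=True,
    device=torch.device("cpu"),
    cast_bf16_to_fp16=False,
    cast_fp16_to_bf16=False,
)

try:
    validate_quantization(ToyFP8(), args)
except ValueError as err:
    print(err)  # "FP8 checkpoints on AIU with bf16 weights require casting to fp16 ..."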

scripts/run_encoders.py

Lines changed: 11 additions & 1 deletion

@@ -16,10 +16,15 @@
     run_encoder_eval_qa,
     run_encoder_eval_mlm,
 )
-from aiu_fms_testing_utils.utils.model_setup import setup_model, print_model_params
+from aiu_fms_testing_utils.utils.model_setup import (
+    setup_model,
+    print_model_params,
+    recast_16b
+)
 from aiu_fms_testing_utils.utils.quantization_setup import (
     import_addons,
     get_linear_config,
+    validate_quantization,
 )
 
 parser = argparse.ArgumentParser(

@@ -61,9 +66,14 @@
     group=distributed.group.WORLD,
     linear_config=linear_config,
     fused_weights=args.fused_weights,
+    attn_name="math_fp8",
 )
 
+if args.force_16b_dtype:
+    recast_16b(model, args)
+
 if args.is_quantized:
+    validate_quantization(model, args)
     print_model_params(model, args)
 
 tokenizer = tokenizers.get_tokenizer(args.tokenizer)
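Ordering note: recast_16b runs before validate_quantization, so with --cast_bf16_to_fp16 an FP8 checkpoint's leftover bf16 weights have already been converted by the time the validation scan looks at dtypes. A toy stand-in (plain nn.Linear, no actual fp8 weights; names and shapes are illustrative only) showing why the scan then passes:

import torch
from torch import nn

# Toy stand-in for a checkpoint that still carries bf16 parameters.
model = nn.Linear(4, 4).to(torch.bfloat16)

# Step 1 -- what recast_16b does under --cast_bf16_to_fp16:
for param in model.parameters():
    if param.dtype == torch.bfloat16:
        param.data = param.data.to(dtype=torch.float16)

# Step 2 -- the dtype scan validate_quantization performs afterwards:
has_bf16 = any(p.dtype == torch.bfloat16 for p in model.parameters())
print(has_bf16)  # False -> the AIU/bf16 error path is never reached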

0 commit comments
