
Commit 1a77f63

Merge pull request #75 from andrea-fasoli/fp8_fixes
Fix mask creation for QA
2 parents 8d703be + 3564da2

4 files changed: +21 −20 lines

aiu_fms_testing_utils/utils/args_parsing.py

Lines changed: 7 additions & 12 deletions
@@ -12,21 +12,18 @@ def get_args(parser: argparse.ArgumentParser) -> argparse.Namespace:
     args_model_loading.add_argument(
         "--architecture",
         type=str,
-        help="The model architecture to benchmark",
+        help="The model architecture to benchmark.",
     )
     args_model_loading.add_argument(
         "--variant",
         type=str,
         default=None,
-        help="The model variant (configuration) to benchmark. E.g. 7b, 13b, 70b.",
+        help="The model variant (configuration) to benchmark (e.g., 7b, 13b, 70b).",
     )
     args_model_loading.add_argument(
         "--model_path",
         type=str,
-        help=(
-            "Path to the directory containing LLaMa weights "
-            "(.pth files sharded by tensor parallel rank, not HF weights)"
-        ),
+        help="Path to the directory containing the model checkpoint(s).",
     )
     args_model_loading.add_argument(
         "--model_source",
@@ -36,9 +33,7 @@ def get_args(parser: argparse.ArgumentParser) -> argparse.Namespace:
     args_model_loading.add_argument(
         "--unfuse_weights",
         action="store_true",
-        help=(
-            "If set to True, this will unfuse any fused weight modules"
-        ),
+        help="If True, this will unfuse any fused weight modules.",
     )
     args_model_loading.add_argument(
         "--default_dtype",
@@ -47,23 +42,23 @@ def get_args(parser: argparse.ArgumentParser) -> argparse.Namespace:
         choices=["bf16", "fp16", "fp32"],
         help=(
             "If set to one of the choices, overrides the model checkpoint "
-            "weight format by setting the default pytorch format"
+            "weight format by setting the default pytorch format."
         ),
     )
     parser.add_argument(
         "--cast_bf16_to_fp16",
         action="store_true",
         help=(
             "If set, cast any bf16 weights in the model to fp16 for AIU compiler. "
-            "Doesn't touch fp32 or quantized"
+            "Doesn't touch fp32 or quantized."
         )
     )
     parser.add_argument(
         "--cast_fp16_to_bf16",
         action="store_true",
         help=(
             "If set, cast any fp16 weights in the model to bf16 for GPU. "
-            "Doesn't touch fp32 or quantized"
+            "Doesn't touch fp32 or quantized."
         )
     )
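
Note on usage: get_args only attaches these options to a caller-supplied parser and returns the parsed namespace. A minimal sketch, assuming the import path from the file header above; the parser description and the example invocation values are illustrative, not from this repository:

# Hedged sketch: wiring get_args into a caller script. Only get_args and the
# flag names come from the diff; the description and example values are made up.
import argparse

from aiu_fms_testing_utils.utils.args_parsing import get_args

parser = argparse.ArgumentParser(description="Encoder benchmarking")
args = get_args(parser)  # registers the flags above and parses them into a Namespace

# Example invocation (illustrative values):
#   python my_script.py --architecture roberta --variant base \
#       --model_path /path/to/checkpoints --unfuse_weights --default_dtype fp32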

aiu_fms_testing_utils/utils/encoders_utils.py

Lines changed: 9 additions & 4 deletions
@@ -160,9 +160,13 @@ def convert_batch_to_fms_style(
         self,
         batch: dict[str, torch.Tensor],
     ) -> dict[str, torch.Tensor]:
-        """FMS uses a different standard than HF for encoder inputs."""
+        """FMS uses a different standard than HF for encoder inputs.
 
-        return {'x': batch['input_ids'], 'mask': batch['attention_mask']}
+        The mask is also handled differently in FMS: it is correctly processed by SDPA
+        only if provided as boolean. A floating binary mask would not be converted.
+        """
+
+        return {'x': batch['input_ids'], 'mask': batch['attention_mask'].to(torch.bool)}
 
     def process_eval_set(self) -> None:
         """Pre-process evaluation dataset for QuestionAnswering task."""
@@ -210,7 +214,7 @@ def process_eval_set(self) -> None:
                 f"Using max_prompt_length={model_max_length} instead."
             )
         self.max_prompt_length = min(
-            args.max_seq_length,
+            args.max_prompt_length,
             model_max_length,
         )
@@ -593,7 +597,8 @@ def run_evaluation(self) -> None:
         all_end_logits = []
         for step, batch in enumerate(eval_dataloader):
             with torch.no_grad():
-                dprint(f"Step {step + 1} / {len(eval_dataloader)}")
+                if args.verbose:
+                    dprint(f"Step {step + 1} / {len(eval_dataloader)}")
                 batch = self.convert_batch_to_fms_style(batch)
                 batch = move_to_device(batch, args.device)
                 start_logits, end_logits = self.model(**batch)
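
The docstring added above is the core of this fix: torch's SDPA (scaled_dot_product_attention) treats a boolean mask as attend/ignore, but simply adds a floating-point mask to the attention scores, so an HF-style 0/1 float mask would leave padded positions visible. A minimal sketch of the difference, using synthetic tensors (not repository code):

# Illustrative only: why the attention_mask must be cast to bool before SDPA.
import torch
import torch.nn.functional as F

q = k = v = torch.randn(1, 2, 4, 8)      # (batch, heads, seq_len, head_dim)
hf_mask = torch.tensor([[1, 1, 1, 0]])   # HF convention: 1 = token, 0 = padding

bool_mask = hf_mask.to(torch.bool)[:, None, None, :]  # True = attend, False = mask out
float_mask = hf_mask.float()[:, None, None, :]        # added to the attention scores

out_bool = F.scaled_dot_product_attention(q, k, v, attn_mask=bool_mask)
out_float = F.scaled_dot_product_attention(q, k, v, attn_mask=float_mask)

# With the float mask the padded position only gets +0/+1 on its score, so it
# still contributes to the output; the two results differ.
print(torch.allclose(out_bool, out_float))  # expected: False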

aiu_fms_testing_utils/utils/model_setup.py

Lines changed: 2 additions & 2 deletions
@@ -152,8 +152,8 @@ def print_model_params(model: nn.Module, args: argparse.Namespace) -> None:
 
     if args.verbose:
         dprint("="*60 + "\n")
-        dprint("\n".join(
-            f"{k:80} {str(list(v.size())):15} {str(v.dtype):18} {str(v.device):10} "
+        dprint("\n" + "\n".join(
+            f"{k:70} {str(list(v.size())):15} {str(v.dtype):20} {str(v.device):10} "
             f"{v.float().min().item():12.4f} {v.float().max().item():12.4f}"
             for k,v in model.state_dict().items()
         ))
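
For reference, this is the per-parameter line the adjusted format string produces, with the new 70/20 column widths; the parameter name and tensor below are synthetic:

# Illustrative reproduction of one line printed by print_model_params.
import torch

k = "roberta.encoder.layer.0.attention.self.query.weight"
v = torch.randn(768, 768)
print(
    f"{k:70} {str(list(v.size())):15} {str(v.dtype):20} {str(v.device):10} "
    f"{v.float().min().item():12.4f} {v.float().max().item():12.4f}"
)
# -> name padded to 70 chars, then shape, dtype, device, and min/max values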

scripts/run_encoders.py

Lines changed: 3 additions & 2 deletions
@@ -5,6 +5,7 @@
 # Third Party
 from fms.models import get_model
 from fms.models.roberta import RoBERTaForQuestionAnswering, RoBERTa
+from fms.models.hf.roberta.modeling_roberta_hf import HFAdaptedRoBERTaForMaskedLM
 from fms.utils import tokenizers
 from torch import distributed, set_grad_enabled
 
@@ -66,7 +67,6 @@
         group=distributed.group.WORLD,
         linear_config=linear_config,
         fused_weights=args.fused_weights,
-        attn_name="math_fp8",
     )
 
     if args.force_16b_dtype:
@@ -100,7 +100,8 @@
 
     if isinstance(model, RoBERTaForQuestionAnswering):
         run_encoder_eval_qa(model, tokenizer, args)
-    elif isinstance(model, RoBERTa): # basic MaskedLM downstream task
+    elif isinstance(model, RoBERTa) or isinstance(model, HFAdaptedRoBERTaForMaskedLM):
+        # basic MaskedLM downstream task
         run_encoder_eval_mlm(model, tokenizer, args)
 
     if args.distributed:
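
The two-way check in the last hunk can equivalently use isinstance with a class tuple; a hedged sketch (run_encoder_eval_qa and run_encoder_eval_mlm are the helpers already used by scripts/run_encoders.py, assumed to be in scope):

# Sketch only: equivalent dispatch using the tuple form of isinstance().
from fms.models.roberta import RoBERTa, RoBERTaForQuestionAnswering
from fms.models.hf.roberta.modeling_roberta_hf import HFAdaptedRoBERTaForMaskedLM

def dispatch_encoder_eval(model, tokenizer, args) -> None:
    if isinstance(model, RoBERTaForQuestionAnswering):
        run_encoder_eval_qa(model, tokenizer, args)
    elif isinstance(model, (RoBERTa, HFAdaptedRoBERTaForMaskedLM)):
        # basic MaskedLM downstream task
        run_encoder_eval_mlm(model, tokenizer, args)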
