
Fix precision in Metal fused attention #3119

Merged
awni merged 1 commit into main from fix_attn_precision on Feb 10, 2026

Conversation

@awni (Member) commented on Feb 10, 2026

In our vector and NAX attention we keep the scale factor in fp32. But in the non-NAX fused attention it gets downcast to bf16, which is made worse by the fact that it is multiplied by another scale as well.
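
As a rough illustration (not from the PR), downcasting a typical 1/sqrt(head_dim) scale to bf16 already perturbs the constant itself; the head dimension below is an assumption:

```python
import math
import mlx.core as mx

head_dim = 128                      # assumed head dimension for illustration
scale = 1.0 / math.sqrt(head_dim)   # typical scaling factor

# bf16 keeps only ~8 bits of significand, so the scale gets rounded
scale_bf16 = mx.array(scale, dtype=mx.bfloat16)
print(scale, scale_bf16.item(), abs(scale - scale_bf16.item()))
```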

It seems to have a real impact on model quality in some cases: ml-explore/mlx-lm#868 (comment)

In terms of performance, I think the regression is acceptable as it's less than 1% (at most 0.5%) for all the cases I tried.

And in terms of accuracy, the difference between the fused attention in bf16 and in fp32, with a scaling factor applied, is noticeably lower:

The maximum absolute difference goes down by a factor of 6. For some random inputs with a typical scaling factor based on the head dimension:

Pre: 0.00585938
Post: 0.000976562
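
For context, here is a minimal sketch of this kind of comparison using the Python API; it is not the PR's actual benchmark, and the shapes and the 1/sqrt(head_dim) scale are assumptions for illustration:

```python
import math
import mlx.core as mx

B, H, L, D = 1, 8, 1024, 128   # assumed batch, heads, sequence length, head dim
scale = 1.0 / math.sqrt(D)     # typical scaling factor based on the head dimension

q = mx.random.normal((B, H, L, D))
k = mx.random.normal((B, H, L, D))
v = mx.random.normal((B, H, L, D))

# fp32 attention as the reference
ref = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale)

# bf16 fused attention, upcast to fp32 for the comparison
out = mx.fast.scaled_dot_product_attention(
    q.astype(mx.bfloat16),
    k.astype(mx.bfloat16),
    v.astype(mx.bfloat16),
    scale=scale,
)
print(mx.abs(out.astype(mx.float32) - ref).max().item())
```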

@awni requested a review from angeloskath on February 10, 2026 at 20:23

@angeloskath (Member) left a comment


Nice! Thanks

@awni merged commit 4c86c1e into main on Feb 10, 2026
16 checks passed
@awni deleted the fix_attn_precision branch on February 10, 2026 at 22:18