
Conversation

@kernelpool (Contributor) commented Feb 10, 2026:

This adds MLA absorption to the LongCat models.

Note: Pre-scaling Q for prefill is necessary; otherwise the model (LongCat-Flash-Lite-4bit in particular) exhibits odd behavior. For example, when running a simple "hello", "how are you?", "list the current directory" sequence in OpenCode (in an empty directory), the model would sometimes repeat its response to the previous question (e.g. answer "how are you?" again after the user asks "list the current directory").

mlx-community/LongCat-Flash-Thinking-2601-4bit

| Context | Original Prompt (tok/s) | MLA Prompt (tok/s) | Prompt Δ | Original Gen (tok/s) | MLA Gen (tok/s) | Gen Δ | Original Mem | MLA Mem | Mem Δ |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 1k | 152 | 168 | +10% | 26.9 | 26.6 | -1% | 318.6 GB | 316.5 GB | -1% |
| 2k | 226 | 232 | +3% | 23.5 | 26.7 | +14% | 322.0 GB | 317.7 GB | -1% |
| 4k | 249 | 255 | +2% | 19.8 | 26.3 | +33% | 327.2 GB | 319.0 GB | -3% |
| 8k | 248 | 254 | +2% | 15.9 | 25.2 | +58% | 337.3 GB | 321.0 GB | -5% |
| 16k | 222 | 228 | +3% | 11.2 | 23.2 | +107% | 356.9 GB | 325.4 GB | -9% |
| 32k | 171 | 177 | +3% | 6.7 | 19.6 | +193% | 399.5 GB | 335.6 GB | -16% |
| 64k | - | 119 | - | - | 15.1 | - | - | 356.1 GB | - |
| 128k | - | 70 | - | - | 10.2 | - | - | 395.5 GB | - |

mlx-community/LongCat-Flash-Lite-4bit

| Context | Original Prompt (tok/s) | MLA Prompt (tok/s) | Prompt Δ | Original Gen (tok/s) | MLA Gen (tok/s) | Gen Δ | Original Mem | MLA Mem | Mem Δ |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 1k | 1,192 | 1,243 | +4% | 99.7 | 94.3 | -5% | 39.8 GB | 39.3 GB | -1% |
| 2k | 1,601 | 1,670 | +4% | 89.4 | 92.8 | +4% | 41.1 GB | 40.0 GB | -3% |
| 4k | 1,689 | 1,755 | +4% | 74.5 | 92.3 | +24% | 42.7 GB | 40.6 GB | -5% |
| 8k | 1,594 | 1,636 | +3% | 56.3 | 84.2 | +50% | 45.9 GB | 41.6 GB | -9% |
| 16k | 1,312 | 1,347 | +3% | 39.1 | 80.2 | +105% | 51.9 GB | 44.0 GB | -15% |
| 32k | 926 | 963 | +4% | 23.2 | 67.8 | +192% | 65.4 GB | 49.3 GB | -25% |
| 64k | 546 | 580 | +6% | 12.9 | 47.1 | +265% | 93.4 GB | 60.1 GB | -36% |
| 128k | 291 | 315 | +8% | 6.8 | 31.3 | +360% | 147.9 GB | 80.1 GB | -46% |
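
(For reference, throughput and peak-memory numbers of this kind can be gathered with mlx-lm's high-level API; the sketch below is illustrative only and is not the benchmark script used for the tables above. The model path is real, but the prompt construction and lengths are placeholders.)

```python
# Illustrative only: measure prompt/generation throughput and peak memory
# with mlx-lm's high-level API (not the script used for the tables above).
from mlx_lm import load, generate

model, tokenizer = load("mlx-community/LongCat-Flash-Lite-4bit")

# Placeholder prompt sized to roughly hit a target context length.
prompt = "hello " * 4000  # ~4k tokens, give or take tokenization

# verbose=True prints prompt tok/s, generation tok/s, and peak memory.
generate(model, tokenizer, prompt=prompt, max_tokens=256, verbose=True)
```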

Comment on lines +173 to +181:

```python
# Pre-scale Q: the absorbed nope/pe split routes scale
# through different precisions in the Steel SDPA kernel,
# causing drift on long sequences. Pre-scaling avoids this.
k = self.embed_q(kv_latent, transpose=False)
v = self.unembed_out(kv_latent)
q_nope = q_nope * self.scale
output = scaled_dot_product_attention(
    q_nope, k, v, cache=cache, scale=1.0, mask=pe_scores
)
```
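
For contrast, without pre-scaling the same call would presumably leave q_nope untouched and pass the scale into the kernel, as in the usual MLA pattern (a sketch, not the exact removed code):

```python
# Prior pattern (sketch): scale applied inside the SDPA kernel,
# with the pe scores supplied as an additive mask.
output = scaled_dot_product_attention(
    q_nope, k, v, cache=cache, scale=self.scale, mask=pe_scores
)
```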
@kernelpool (Contributor, Author) commented Feb 10, 2026:

Btw, rather than the workaround, would the proper fix be to address the precision mismatch in mlx?

Member commented:

Yes indeed... this is quite curious. Do you find it is strictly better to scale q_nope prior to the attention?

Member commented:

@kernelpool I think we can remove that change in favor of an improvement in MLX like you suggested: ml-explore/mlx#3119

@kernelpool (Contributor, Author) commented Feb 11, 2026:

Scaling q_nope prior to attention seems to be critical. I'm still getting the same result with ml-explore/mlx#3119 (the M3 Ultra is using the NAX kernel). I think the issue is that pe_scores (in Python) and q_nope (in the kernel) are still scaled at different precisions (the pe_scores mask is also bf16). The reason L == 1 is not a problem in this case is that the fallback is used instead of an optimized kernel (head_dim=512), but that path should probably also use pre-scaling for consistency. Similar issues (of varying degree) likely exist in the other MLA implementations, since they use the same pattern.

EDIT: It's also possible to use a concatenation approach (this is what llama.cpp does, for example), where nope and pe are concatenated along head_dim, but in my tests it's about 3% slower (at 4k tokens).
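
A minimal sketch of that concatenation variant, with illustrative shapes/names (not the actual model dimensions) and the attention written out explicitly so the example is self-contained; the real code would go through scaled_dot_product_attention and the KV cache, with a causal mask:

```python
# Sketch of the concatenation alternative (llama.cpp-style): join the
# nope and pe components along head_dim and compute one set of scores,
# instead of folding the pe scores into the SDPA mask.
import mlx.core as mx

B, H, L = 1, 8, 128            # illustrative batch/heads/length
d_latent, d_pe = 512, 64       # illustrative latent and rotary dims

q_nope = mx.random.normal((B, H, L, d_latent))     # absorbed "nope" queries
q_pe = mx.random.normal((B, H, L, d_pe))           # rotary ("pe") queries
k_latent = mx.random.normal((B, 1, L, d_latent))   # compressed KV latent, shared across heads
k_pe = mx.random.normal((B, 1, L, d_pe))           # rotary keys, shared across heads
v = mx.broadcast_to(k_latent, (B, H, L, d_latent)) # values come from the same latent

scale = (d_latent + d_pe) ** -0.5  # placeholder; the real scale uses the original head_dim

q_cat = mx.concatenate([q_nope, q_pe], axis=-1)
k_cat = mx.concatenate(
    [mx.broadcast_to(k_latent, (B, H, L, d_latent)),
     mx.broadcast_to(k_pe, (B, H, L, d_pe))],
    axis=-1,
)

# Explicit attention so the whole query is scaled once, in one precision
# (causal mask omitted for brevity).
scores = (q_cat * scale) @ k_cat.transpose(0, 1, 3, 2)
weights = mx.softmax(scores, axis=-1)
output = weights @ v
```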

Member commented:

> (the M3 Ultra is using the NAX kernel)

No, it shouldn't be... that would only be for M5 and up.

> I'm still getting the same result

Can you share a prompt that reproduces the bad result when you don't use pre-scaling?

@kernelpool (Contributor, Author) commented:

test_server_multiturn.py

With prescale:

```
python test_server_multiturn.py
Waiting for server to start...
=== Turn 1: hello ===
Response: Hello! How can I help you today?

=== Turn 2: how are you? ===
Response: I'm doing great, thank you for asking! How are you doing today?

=== Turn 3: what is 2+2? ===
Response: 2 + 2 equals 4.

=== Analysis ===
OK: Turn 3 correctly answered 2+2=4
```

Without prescale:

```
mlx-lm % python test_server_multiturn.py
Waiting for server to start...
=== Turn 1: hello ===
Response: Hello! How can I help you today?

=== Turn 2: how are you? ===
Response: I'm an AI assistant - a language model developed by BAAI. How can I help you today?

=== Turn 3: what is 2+2? ===
Response: I'm an AI assistant developed by BAAI (Beijing Academy of Artificial Intelligence). How can I help you today?

=== Analysis ===
BUG: Turn 3 did not answer 2+2=4
```
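
(The attached script itself isn't included in this excerpt. A minimal multi-turn test against the mlx_lm server's OpenAI-compatible chat endpoint might look roughly like the sketch below; this is a hypothetical reconstruction, not the actual test_server_multiturn.py, and the URL, port, request fields, and pass/fail check are all assumptions.)

```python
# Hypothetical sketch of a multi-turn server test (not the attached script).
# Assumes an mlx_lm server is already running locally and exposing its
# OpenAI-compatible /v1/chat/completions endpoint, e.g.:
#   mlx_lm.server --model <model-path> --port 8080
# Depending on the server version, a "model" field may also be required.
import requests

URL = "http://localhost:8080/v1/chat/completions"  # assumed host/port
turns = ["hello", "how are you?", "what is 2+2?"]
messages = []

for i, user_msg in enumerate(turns, 1):
    messages.append({"role": "user", "content": user_msg})
    resp = requests.post(URL, json={"messages": messages, "max_tokens": 256})
    reply = resp.json()["choices"][0]["message"]["content"]
    messages.append({"role": "assistant", "content": reply})
    print(f"=== Turn {i}: {user_msg} ===\nResponse: {reply}\n")

# Simple check mirroring the output above: the last turn should answer 2+2.
print("=== Analysis ===")
print("OK: Turn 3 correctly answered 2+2=4"
      if "4" in messages[-1]["content"]
      else "BUG: Turn 3 did not answer 2+2=4")
```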

Member commented:

This seems like a bug to be honest... but I don't know where it is or what it could be.

Member commented:

mx.contiguous(q_nope) also resolves the issue, so it's likely related to the fact that q_nope is not contiguous and something funky happens in that case.
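
(That alternative workaround would look roughly like the following sketch, mirroring the snippet above; it is not a committed change.)

```python
# Alternative workaround (sketch): force q_nope into a contiguous layout
# before the kernel sees it, instead of pre-scaling.
q_nope = mx.contiguous(q_nope)
output = scaled_dot_product_attention(
    q_nope, k, v, cache=cache, scale=self.scale, mask=pe_scores
)
```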
