LongCat MLA #868
Conversation
# Pre-scale Q: the absorbed nope/pe split routes scale
# through different precisions in the Steel SDPA kernel,
# causing drift on long sequences. Pre-scaling avoids this.
k = self.embed_q(kv_latent, transpose=False)
v = self.unembed_out(kv_latent)
q_nope = q_nope * self.scale
output = scaled_dot_product_attention(
    q_nope, k, v, cache=cache, scale=1.0, mask=pe_scores
)
Btw, rather than the workaround, would the proper fix be to address the precision mismatch in mlx?
Yes indeed, this is quite curious. Do you notice it is strictly better to scale q_nope prior to the attention?
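For concreteness, a minimal sketch of the two variants being compared, using random stand-in tensors with made-up shapes (names mirror the diff above, but this is not the PR's actual module code):

import mlx.core as mx

# Illustrative shapes for an absorbed-MLA prefill step.
B, H, Lq, Lk, D = 1, 8, 4, 128, 512
scale = D ** -0.5  # placeholder; the model defines its own softmax scale

q_nope = mx.random.normal((B, H, Lq, D)).astype(mx.bfloat16)
k = mx.random.normal((B, H, Lk, D)).astype(mx.bfloat16)
v = mx.random.normal((B, H, Lk, D)).astype(mx.bfloat16)
pe_scores = mx.random.normal((B, H, Lq, Lk)).astype(mx.bfloat16)  # additive RoPE score bias

# Variant A: let the fused SDPA kernel apply the scale to the q.k scores.
out_a = mx.fast.scaled_dot_product_attention(
    q_nope, k, v, scale=scale, mask=pe_scores
)

# Variant B (as in the diff above): pre-scale the query in bf16 up front and
# pass scale=1.0, so no scaling is left to happen inside the kernel.
out_b = mx.fast.scaled_dot_product_attention(
    q_nope * scale, k, v, scale=1.0, mask=pe_scores
)

Mathematically the two are identical (softmax(scale * qk + pe_scores)); the difference under discussion is purely where and at what precision the scaling happens.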
@kernelpool I think we can remove that change in favor of an improvement in MLX like you suggested: ml-explore/mlx#3119
Scaling q_nope prior to attention seems to be critical. I'm still getting the same result with ml-explore/mlx#3119 (the M3 Ultra is using the NAX kernel). I think the issue is that pe_scores (in Python) and q_nope (in kernel) are still scaled at different precisions (the pe_scores mask is also bf16). The reason L == 1 is not a problem in this case is that the fallback is used instead of an optimized kernel (head_dim=512), but that path should probably also use pre-scaling for consistency. I think it's likely that there are similar issues (of varying degree) in the other MLA implementations, since they use the same pattern.

EDIT: It's also possible to use a concatenation approach (this is what llama.cpp does, for example), where nope and pe are concatenated along head_dim, but I ran some tests and it's about 3% slower (at 4k tokens).
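For reference, a rough sketch of the concatenation variant mentioned in the EDIT, with made-up shapes and random stand-ins (not the actual LongCat code). It assumes the fused SDPA accepts a value head dim different from the query/key head dim and a single KV head:

import mlx.core as mx

# Made-up dimensions: the absorbed "nope" part lives in the latent space
# (D_latent), the RoPE ("pe") part has its own small head dim (D_pe).
B, H, L, D_latent, D_pe = 1, 8, 4, 512, 64
scale = (128 + D_pe) ** -0.5  # placeholder; the model defines its own softmax scale

q_nope = mx.random.normal((B, H, L, D_latent)).astype(mx.bfloat16)      # already absorbed
q_pe = mx.random.normal((B, H, L, D_pe)).astype(mx.bfloat16)
kv_latent = mx.random.normal((B, 1, L, D_latent)).astype(mx.bfloat16)   # shared KV latent
k_pe = mx.random.normal((B, 1, L, D_pe)).astype(mx.bfloat16)

# Concatenate nope and pe along head_dim so one SDPA call scores both parts,
# instead of computing the pe scores separately and feeding them in as a mask.
q = mx.concatenate([q_nope, q_pe], axis=-1)      # (B, H, L, D_latent + D_pe)
k = mx.concatenate([kv_latent, k_pe], axis=-1)   # (B, 1, L, D_latent + D_pe)
v = kv_latent                                    # (B, 1, L, D_latent)

out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask="causal")

The upside is that all scaling goes through a single code path; the downside, per the timing above, is the extra concatenation work.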
(the M3 Ultra is using the NAX kernel).
No, it shouldn't be; that would only be for M5 and up.
I'm still getting the same result
Can you share a prompt that reproduces the bad result when you don't use pre-scaling?
With prescale:
python test_server_multiturn.py
Waiting for server to start...
=== Turn 1: hello ===
Response: Hello! How can I help you today?
=== Turn 2: how are you? ===
Response: I'm doing great, thank you for asking! How are you doing today?
=== Turn 3: what is 2+2? ===
Response: 2 + 2 equals 4.
=== Analysis ===
OK: Turn 3 correctly answered 2+2=4
Without prescale:
mlx-lm % python test_server_multiturn.py
Waiting for server to start...
=== Turn 1: hello ===
Response: Hello! How can I help you today?
=== Turn 2: how are you? ===
Response: I'm an AI assistant - a language model developed by BAAI. How can I help you today?
=== Turn 3: what is 2+2? ===
Response: I'm an AI assistant developed by BAAI (Beijing Academy of Artificial Intelligence). How can I help you today?
=== Analysis ===
BUG: Turn 3 did not answer 2+2=4
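The script itself isn't included in this thread; a rough sketch of what such a multi-turn check could look like, assuming mlx_lm.server is running locally with its OpenAI-compatible /v1/chat/completions endpoint (the port, fields, and pass/fail check below are assumptions, not the actual test_server_multiturn.py):

import json
import urllib.request

URL = "http://localhost:8080/v1/chat/completions"  # mlx_lm.server default port

def chat(messages):
    # POST the running conversation and return the assistant's reply text.
    req = urllib.request.Request(
        URL,
        data=json.dumps({"messages": messages, "max_tokens": 128}).encode(),
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(req) as resp:
        return json.load(resp)["choices"][0]["message"]["content"]

messages = []
for turn in ["hello", "how are you?", "what is 2+2?"]:
    messages.append({"role": "user", "content": turn})
    reply = chat(messages)
    messages.append({"role": "assistant", "content": reply})
    print(f"=== Turn: {turn} ===\nResponse: {reply}")

print("OK" if "4" in messages[-1]["content"] else "BUG: Turn 3 did not answer 2+2=4")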
This seems like a bug to be honest, but I don't know where it is or what it could be.
mx.contiguous(q_nope) also resolves the issue, so it's likely related to the fact that q_nope is not contiguous and something funky happens in that case.
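To illustrate that alternative workaround (not the change in this PR): force the query into a row-contiguous layout right before the SDPA call instead of pre-scaling. The sketch below only shows the mechanics on a stand-in tensor:

import mlx.core as mx

# Stand-in for q_nope: a transposed view, so it is not row-contiguous.
q_nope = mx.random.normal((1, 4, 8, 512)).astype(mx.bfloat16).swapaxes(1, 2)

# Materialize a contiguous copy before handing it to the attention kernel;
# per the comment above, this also avoids the bad output.
q_nope = mx.contiguous(q_nope)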
This adds MLA absorption to the LongCat models; a rough sketch of the absorption trick is included below for reference.
Note: pre-scaling Q for prefill is necessary, otherwise the model (LongCat-Flash-Lite-4bit in particular) ends up with weird behavior. For example, when doing a simple "hello", "how are you?", "list the current directory" sequence in OpenCode (in an empty directory), the model would sometimes repeat the response to the previous question (e.g. respond to "how are you?" again after the user prompts "list the current directory").
mlx-community/LongCat-Flash-Thinking-2601-4bit
mlx-community/LongCat-Flash-Lite-4bit
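For readers unfamiliar with the trick: "absorption" folds the key up-projection into the query side, so attention runs directly over the cached KV latent instead of decompressing per-head keys. A rough sketch with made-up shapes and weight names (not the actual LongCat module), showing that the two orderings give the same scores:

import mlx.core as mx

B, L, H, D_head, D_latent = 1, 4, 8, 128, 512

q = mx.random.normal((B, L, H, D_head))            # per-head query (nope part)
kv_latent = mx.random.normal((B, L, D_latent))     # compressed latent, what the cache stores
W_UK = mx.random.normal((H, D_head, D_latent))     # key up-projection: latent -> per-head key

# Unabsorbed: decompress per-head keys from the latent, then take q . k.
k = mx.einsum("bld,hed->blhe", kv_latent, W_UK)
scores_ref = mx.einsum("blhe,bmhe->bhlm", q, k)

# Absorbed: fold W_UK into the query once, then attend over the latent directly,
# so the cache never has to be expanded into per-head keys.
q_abs = mx.einsum("blhe,hed->blhd", q, W_UK)
scores_abs = mx.einsum("blhd,bmd->bhlm", q_abs, kv_latent)

print(mx.allclose(scores_ref, scores_abs, rtol=1e-3, atol=1e-3))  # True, up to fp error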