mlx_lm/models/ouro.py
```python
for gate in gates[:-1]:
    lambda_i = mx.sigmoid(gate.squeeze(-1))
```
I would take the sigmoid of the gates on the full vector `gates[:-1]` and then do the loop.
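The suggested refactor relies on sigmoid being elementwise: applying it once to the stacked gates and then looping produces the same values as calling it inside the loop. A minimal sketch of that equivalence, using NumPy as a stand-in for the `mx` ops (the gate shapes here are assumed for illustration):

```python
import numpy as np

def sigmoid(x):
    # Elementwise logistic, matching what mx.sigmoid computes
    return 1.0 / (1.0 + np.exp(-x))

# Hypothetical gate stack: (num_steps, batch, seq_len, 1)
gates = np.random.default_rng(0).normal(size=(4, 2, 8, 1))

# Original pattern: one sigmoid call per loop iteration
per_gate = [sigmoid(gate.squeeze(-1)) for gate in gates[:-1]]

# Suggested pattern: sigmoid once on the full stack, then loop
lambdas = sigmoid(gates[:-1].squeeze(-1))

for lambda_i, reference in zip(lambdas, per_gate):
    assert np.allclose(lambda_i, reference)
```

Fusing the sigmoid into one call avoids launching a tiny elementwise kernel per loop step, which is the usual motivation for this kind of change in MLX code.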
It's a pretty interesting model and very nicely implemented. However, I've generally been quite skeptical of models with early exit, as it doesn't play very well with GPUs. In this implementation it's less efficient than just running the model in full for every token, since you have to save and evaluate all the hidden states and then decide which to keep. Ideally you would compute only up to the hidden states you actually need, but that's also quite difficult: every time you have to do control flow based on data (the probabilities), you have to stall the GPU. I think we could merge this, or leave it as an experimental PR; it kind of depends on whether anyone wants to use this model.
Thanks for the feedback, @awni. I did play around a bit with this to see if I could avoid running all the loops for the early exits, but ultimately decided to just mirror the reference implementation to keep it clean. I agree this is experimental, so I'll leave the decision up to you.
This adds support for the Ouro family of models from ByteDance.
Example
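A minimal generation run with one of the converted 4-bit models (the model name is taken from the benchmarks below; the prompt is just an illustration):

```shell
# Generate with the 4-bit Ouro model via the standard mlx-lm CLI
mlx_lm.generate --model mlx-community/Ouro-1.4B-4bit \
    --prompt "Explain looped language models in one paragraph." \
    --max-tokens 100
```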
Configuration
Since Ouro is a looped language model, it has some additional parameters. These can be adjusted as follows. Note that the models run all four UT (universal transformer) steps by default.
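As a sketch, the looped-model parameters can be overridden at load time through the `model_config` argument of `mlx_lm.load`. The parameter name below (`total_ut_steps`) is an assumption standing in for whatever key the Ouro config actually uses; check the model's `config.json` for the real name:

```python
from mlx_lm import load, generate

# Hypothetical override: run 2 of the 4 default UT steps.
# "total_ut_steps" is an assumed key; consult the model's config.json.
model, tokenizer = load(
    "mlx-community/Ouro-1.4B-4bit",
    model_config={"total_ut_steps": 2},
)

text = generate(model, tokenizer, prompt="Hello", max_tokens=50)
```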
Benchmarks
Performance benchmarks on Apple M3 Ultra (80 GPU cores, 512GB RAM).
```shell
python benchmark.py mlx --contexts 2,4,8,16,32,64,128 --max-tokens 200 <model>
```

where `<model>` is one of:

- mlx-community/Ouro-1.4B-4bit
- mlx-community/Ouro-2.6B-4bit