[Optimization] Add Fused Triton Kernel for GPT-OSS Router #29237
base: main
Conversation
Code Review
This pull request introduces a fused Triton kernel for MoE routing in GPT-OSS models to optimize performance. The changes include a new Triton kernel and its integration into the model. My review identified:
- A critical correctness issue: the fused kernel implementation and its usage in `gpt_oss.py` completely ignore the bias term of the router's linear layer, which will lead to incorrect model outputs.
- Two high-severity issues in the new Triton kernel: the block sizes are hardcoded, which can lead to suboptimal performance, and the comments explaining the GEMM logic are confusing and contain inaccuracies, which hurts maintainability.
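As context for the bias issue flagged above, here is a minimal NumPy reference of the routing math with the bias included (a hedged sketch of the renormalized top-k-softmax variant; the shapes, seed, and variable names are illustrative assumptions, not the PR's kernel):

```python
import numpy as np

# Illustrative shapes: 8 tokens, 32 experts (as in GPT-OSS-20B), top-4.
rng = np.random.default_rng(0)
M, N, top_k = 8, 32, 4
x = rng.standard_normal((M, 64)).astype(np.float32)   # hidden states
W = rng.standard_normal((N, 64)).astype(np.float32)   # router.weight
b = rng.standard_normal(N).astype(np.float32)         # router.bias

# The bias must be part of the logits; dropping `+ b` changes which
# experts win the top-k and therefore the model output.
logits = x @ W.T + b

# Top-k over logits, then softmax restricted to the selected experts.
topk_idx = np.argsort(-logits, axis=1)[:, :top_k]
topk_logits = np.take_along_axis(logits, topk_idx, axis=1)
e = np.exp(topk_logits - topk_logits.max(axis=1, keepdims=True))
topk_weights = e / e.sum(axis=1, keepdims=True)
```

A fused kernel would need to fold the `+ b` into its GEMM epilogue to match this reference.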
💡 Codex Review
Here are some automated review suggestions for this pull request.
Force-pushed from 4ef122b to f3c61a9.
```diff
 else:
     g = self.router(x)
     x = self.experts(hidden_states=x, router_logits=g)
+topk_weights, topk_indices = fused_router(
+    hidden_states=x,
+    router_weights=self.router.weight,
+    router_bias=self.router.bias,
+    top_k=self.experts_per_token,
+)
```
Could you keep the router outside and just write a fused top-k first? Then you won't need to change the interface to self.experts or even the model code
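As a sketch of that split, a standalone top-k-plus-softmax that only consumes precomputed logits, so the router `Linear` (bias included) stays in the model code (a NumPy reference; the function name and the renormalization behavior are assumptions):

```python
import numpy as np

def topk_softmax_ref(router_logits: np.ndarray, top_k: int):
    """Select top_k experts per token and renormalize their logits
    with a softmax; router_logits already includes any bias."""
    idx = np.argsort(-router_logits, axis=1)[:, :top_k]
    vals = np.take_along_axis(router_logits, idx, axis=1)
    e = np.exp(vals - vals.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True), idx

weights, indices = topk_softmax_ref(
    np.array([[0.1, 2.0, -1.0, 0.5]], dtype=np.float32), top_k=2
)
# indices -> [[1, 3]]; weights sum to 1 per token
```

With this interface, only the top-k selection needs a fused kernel, and neither `self.experts` nor the model code changes.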
GPT-OSS-20B has 32 experts; this model is supposed to hit a CUDA kernel rather than the default branch, AFAIK. The intention of this issue is to replace topkGatingSoftmax, right? If that's the case, I think we should first evaluate the roofline of that kernel.
ZJY0516 left a comment:
Could you please add a test for this?
Definitely.
Force-pushed from f3c61a9 to dc3f820.
Signed-off-by: ijpq <509634578tk@gmail.com>
Force-pushed from dc3f820 to fca484b.
```diff
@@ -174,6 +176,11 @@ def __init__(
     has_bias=True,
     activation="swigluoai",
     is_sequence_parallel=self.is_sequence_parallel,
+    custom_routing_function=(
+        gpt_oss_custom_routing_function
+        if not current_platform.is_rocm()
```
nit: just check cuda to help other platforms
```diff
-if not current_platform.is_rocm()
+if current_platform.is_cuda()
```
nit: could you consolidate these tests into one file?
ack
```python
topk_padded = triton.next_power_of_2(topk)

grid = (M,)
```
Don't you need to tune the kernel at all? I haven't seen a benchmark reporting perf yet
ack. report updated.

I updated the benchmark test. Any advice? Since there were some changes to my local hardware, the roofline analysis will take a few days. But intuitively speaking, a compiler-generated kernel is unlikely to outperform a hand-optimized CUDA kernel.
```python
@pytest.mark.parametrize("M", [1, 32, 128, 2048])
@pytest.mark.parametrize("N", [32, 65, 128])
@pytest.mark.parametrize("topk", [1, 2, 3, 4, 5])
def test_fused_router(M, N, topk):
```
skip if current_platform is not cuda?
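One way to express that guard, sketched with a stand-in flag (in the real test this would presumably come from `current_platform.is_cuda()` in `vllm.platforms`, which is an assumption here):

```python
import pytest

# Hypothetical stand-in for vLLM's current_platform.is_cuda().
CUDA_AVAILABLE = False

requires_cuda = pytest.mark.skipif(
    not CUDA_AVAILABLE, reason="fused router kernel requires CUDA"
)

@requires_cuda
def test_fused_router_smoke():
    # body elided; only meaningful on CUDA platforms
    pass
```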
ack
```python
@pytest.mark.parametrize("num_tokens", [10, 128, 1024])
@pytest.mark.parametrize("num_experts", [32, 65, 128])
@pytest.mark.parametrize("topk", [1, 2, 3, 4, 5])
def test_routing_consistency(num_tokens, num_experts, topk):
```
ditto
ack
```python
):
    token_idx = tl.program_id(0)

    offs = tl.arange(0, BLOCK_N)
```
Performance-wise, it could be better to use tl.range instead of tl.arange (for better pipelining), with num_stages as a parameter whose best value can be found via autotune. There's an example using num_stages=num_stages here: https://triton-lang.org/main/getting-started/tutorials/02-fused-softmax.html. You could even try setting warp_specialize=True with autotune to see if that impacts perf further.
Thank you for the correction, @shaginhekvs. I just did the roofline report. This kernel still has a lot of room for optimization. I'll take another look at Triton, and I should be able to provide the optimized results today.
- split into two kernels, for the renorm and non-renorm cases
- add online softmax
- unroll along M

Signed-off-by: ijpq <509634578tk@gmail.com>
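The online-softmax change listed in the commit message can be illustrated with a small NumPy model of the streaming pattern (a sketch of the technique, not the PR's Triton kernel; the block size and function name are assumptions):

```python
import numpy as np

def online_softmax(row: np.ndarray, block: int = 4) -> np.ndarray:
    """One-pass softmax: keep a running max m and rescaled sum s over
    blocks, as a kernel would while streaming the expert logits."""
    m, s = -np.inf, 0.0
    for start in range(0, row.size, block):
        blk = row[start:start + block]
        m_new = max(m, float(blk.max()))
        # Rescale the old partial sum to the new max, then accumulate.
        s = s * np.exp(m - m_new) + float(np.exp(blk - m_new).sum())
        m = m_new
    return np.exp(row - m) / s

x = np.random.default_rng(1).standard_normal(32)
ref = np.exp(x - x.max())
ref /= ref.sum()
assert np.allclose(online_softmax(x), ref)
```

The payoff is that the logits need only one pass over memory instead of the two passes a standard max-then-sum softmax requires.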
Force-pushed from 5af77ad to 66e6711.
TL;DR: achieved better FLOP/s and eliminated the memory-bound behavior in the roofline analysis (renorm enabled), in commit 66e6711.

I'm still working on a few remaining items and will try to get them done by tomorrow.


The output of `python collect_env.py`

Purpose
Resolves: #28986
Test Plan
Compare results with the torch implementation and the original (unfused) implementation.
Test Result
compare results
The output of `pytest test_gpt_oss_fused_router.py`

The output of `pytest test_routing_consistency.py`

benchmark:
baseline:
fused:
Essential Elements of an Effective PR Description Checklist
`supported_models.md` and `examples` for a new model.