[GPU] optimize moe imbalance workload of micro_gemm #33922
base: master
Conversation
To optimize this, we implemented a Persistent Threads strategy with flattened dispatch:
Flattened Grid:
Instead of a 3D grid (M_tiles, N_tiles, Num_Experts), we now use a 2D grid (M_tiles, Total_Compute_Units * Overprovision_Factor). This decouples the number of GPU threads from the specific number of experts or tokens, as sketched below.
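As a rough host-side sketch of this flattened dispatch (OpenCL host C API; `m_tiles`, `WG_SIZE`, and `OVERPROVISION_FACTOR` are illustrative names, not the PR's actual identifiers):

```c
#include <CL/cl.h>

/* Hypothetical host-side dispatch: the grid is sized from the device's
 * compute units, not from the number of experts or tokens. */
void enqueue_moe_gemm(cl_command_queue queue, cl_kernel kernel,
                      cl_device_id device, size_t m_tiles) {
    cl_uint compute_units = 0;
    clGetDeviceInfo(device, CL_DEVICE_MAX_COMPUTE_UNITS,
                    sizeof(compute_units), &compute_units, NULL);

    const size_t WG_SIZE = 16;               /* assumed work-group size */
    const size_t OVERPROVISION_FACTOR = 4;   /* assumed tuning knob */

    size_t local_ws[2]  = { WG_SIZE, 1 };
    size_t global_ws[2] = {
        m_tiles * WG_SIZE,                   /* dim 0: M tiles */
        compute_units * OVERPROVISION_FACTOR /* dim 1: persistent groups */
    };
    clEnqueueNDRangeKernel(queue, kernel, 2, NULL,
                           global_ws, local_ws, 0, NULL, NULL);
}
```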
Prefix Scan & Dynamic Mapping:
Inside the kernel, calculate the exact number of work-tiles needed for each expert based on its assigned token count in n_array.
Perform a fast on-chip prefix scan to map each expert to a range in a linearized "global tile space".
The persistent threads loop over this global tile space: if thread T picks tile X, it binary-searches the prefix-sum array to identify which expert owns tile X and what its relative tile index is.
No Host Synchronization:
This logic happens entirely on the GPU, avoiding expensive readbacks or host-side scheduling.
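As a rough illustration, here is a minimal OpenCL-C sketch of the kernel-side mapping, assuming hypothetical names (`n_array` layout, `MAX_EXPERTS`, `M_TILE`, and a serial scan for clarity); the actual kernel in moe_gemm.cl will differ:

```c
#define MAX_EXPERTS 128   /* assumed upper bound on expert count */
#define M_TILE 16         /* assumed rows per work-tile */

__kernel void moe_gemm_balanced(__global const int* n_array, /* tokens per expert */
                                const int num_experts) {
    /* 1) Tiles per expert + prefix scan (serial here for clarity; the real
       kernel would use a parallel work-group scan). */
    __local int tile_prefix[MAX_EXPERTS + 1];
    if (get_local_id(0) == 0) {
        tile_prefix[0] = 0;
        for (int e = 0; e < num_experts; ++e) {
            int tiles = (n_array[e] + M_TILE - 1) / M_TILE; /* ceil-div */
            tile_prefix[e + 1] = tile_prefix[e] + tiles;
        }
    }
    barrier(CLK_LOCAL_MEM_FENCE);
    const int total_tiles = tile_prefix[num_experts];

    /* 2) Persistent loop: each group strides over the global tile space. */
    for (int tile = get_group_id(1); tile < total_tiles;
         tile += get_num_groups(1)) {
        /* Binary search for the owner: largest e with tile_prefix[e] <= tile. */
        int lo = 0, hi = num_experts - 1;
        while (lo < hi) {
            int mid = (lo + hi + 1) / 2;
            if (tile_prefix[mid] <= tile) lo = mid; else hi = mid - 1;
        }
        const int expert   = lo;
        const int rel_tile = tile - tile_prefix[expert]; /* tile index within expert */
        /* ... compute the GEMM tile (expert, rel_tile) here ... */
        (void)rel_tile;
    }
}
```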
Expand the task space to Total_Tiles * M_Blocks. Within the persistent loop, tile_idx (expert) and m_block_idx (row) are dynamically derived from task_idx, so all M-block tasks are mixed into one huge task pool. All persistent groups on the GPU continuously process these tiny tasks like a pipeline, without any intermediate wave-synchronization loss. The vast majority of stalls are eliminated, leaving only an extremely small tail (<0.1%); a sketch of this decoding follows.
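A hedged sketch of that decoding, reusing the assumed names from the kernel sketch above (`total_tiles`, plus an assumed compile-time `M_BLOCKS`); the persistent loop would become:

```c
/* Hypothetical expansion of the task pool: one flat task_idx now encodes
 * both the expert tile and the M-block (row block). */
const int total_tasks = total_tiles * M_BLOCKS;
for (int task_idx = get_group_id(1); task_idx < total_tasks;
     task_idx += get_num_groups(1)) {
    const int tile_idx    = task_idx / M_BLOCKS; /* expert tile (binary-search its owner as above) */
    const int m_block_idx = task_idx % M_BLOCKS; /* row block within that tile */
    /* ... process one small (tile_idx, m_block_idx) task; no wave sync ... */
}
```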
Pull request overview
This PR optimizes the MoE (Mixture of Experts) workload imbalance in the micro_gemm kernel by implementing a persistent thread pool strategy. The optimization applies dynamic task fetching to avoid idle threads caused by workload imbalance, targeting both the qwen3_moe and gpt_oss models.
Changes:
- Introduces persistent thread pool with dynamic task distribution to balance workload across GPU threads
- Implements the optimization conditionally for the prefill stage using the ENABLE_WORKLOAD_BALANCE flag
- Refactors debug logging to use configurable trace macros
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe_gemm.cl | Adds new persistent thread pool implementation for balanced workload distribution |
| src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe/moe_gemm_gen_micro.cpp | Enables workload balancing for prefill stage with persistent groups and adds num_experts scalar argument |
| src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe/moe_3gemm_swiglu_opt.cpp | Refactors debug logging macros for better configurability |
| src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe/moe_3gemm_gen_micro.cpp | Enables workload balancing with persistent groups and adds num_experts scalar argument for all three GEMM types |
force-pushed from d7edb25 to bd9c140
…ed state

The expert num impacts the workload balance: if the expert num is small, the average number of tokens allocated to each expert is large, and the load imbalance caused by random allocation naturally decreases.
force-pushed from 7bb8e1a to 8259e2d
Details:
Note: The test data shows that this optimization benefits models with a large number of experts where each individual expert is small, such as qwen3 moe, while models with a small number of experts (gpt_oss) do not gain enough benefit.
So we enable this optimization only when expert_num >= 64.
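Illustratively, the gating might reduce to a single host-side check (hypothetical variable names, not the PR's actual code):

```c
/* Hypothetical gating: enable the balanced persistent-threads path only
 * for prefill on many-expert models. */
const bool enable_workload_balance = is_prefill && (num_experts >= 64);
```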
Tickets: