
Conversation

@riverlijunjie
Contributor

@riverlijunjie riverlijunjie commented Feb 2, 2026

Details:

  • Apply a persistent thread pool that dynamically fetches tasks, which avoids the original workload imbalance caused by large numbers of idle threads
  • Apply it to both the qwen3_moe and gpt_oss models
  • E2E test result: [image]
  • Qwen3 (128 experts) micro_gemm performance profiling: [image]
  • GPT_OSS (32 experts) micro_gemm performance profiling: [image]

Note: The test data shows that this optimization benefits models with a large number of experts where each individual expert is small, such as qwen3_moe, while models with a small number of experts (gpt_oss) see little benefit.
So we enable this optimization only when expert_num >= 64.

Tickets:

To optimize this, a Persistent Threads strategy with flattened dispatch is implemented:

Flattened Grid:
    Instead of a 3D grid (M_tiles, N_tiles, Num_Experts), the kernel now uses a 2D grid (M_tiles, Total_Compute_Units * Overprovision_Factor). This decouples the number of GPU threads from the specific number of experts or tokens.
Prefix Scan & Dynamic Mapping:
    Inside the kernel, calculate the exact number of work-tiles needed for each expert from its assigned token count in n_array.
    Perform a fast on-chip prefix scan to map each expert to a range in a linearized "global tile space".
    The persistent threads loop over this global tile space. If thread T picks tile X, it binary-searches the prefix-sum array to identify which expert owns tile X and what the tile's relative index within that expert is (see the sketch after this list).
No Host Synchronization:
    This logic happens entirely on the GPU, avoiding expensive readbacks or host-side scheduling.
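
The mapping step can be sketched roughly as below. This is a minimal illustration, not the PR's actual moe_gemm.cl code: the kernel name, TILE_M, out_map, and prefix are hypothetical names, the prefix sum is built serially for clarity (a parallel work-group scan can replace it), and the real kernel fuses this logic into the micro-GEMM instead of writing a map out. Only n_array comes from the description above.

```c
#define TILE_M 64                                              // assumed M-tile size

__kernel void map_global_tiles(__global const int* n_array,   // tokens per expert
                               int num_experts,
                               __global int2* out_map,         // (expert, local_tile) per global tile
                               __local int* prefix)            // num_experts + 1 ints
{
    // Exclusive prefix sum of per-expert tile counts, built on-chip by one work-item.
    if (get_local_id(0) == 0) {
        prefix[0] = 0;
        for (int e = 0; e < num_experts; ++e) {
            int tiles = (n_array[e] + TILE_M - 1) / TILE_M;    // ceil-div: tiles needed by expert e
            prefix[e + 1] = prefix[e] + tiles;
        }
    }
    barrier(CLK_LOCAL_MEM_FENCE);

    int total_tiles = prefix[num_experts];

    // Persistent-style loop: each work-item strides over the linearized global tile space.
    for (int tile = get_global_id(0); tile < total_tiles; tile += get_global_size(0)) {
        // Binary search for the largest e with prefix[e] <= tile, i.e. the expert owning this tile.
        int lo = 0, hi = num_experts - 1;
        while (lo < hi) {
            int mid = (lo + hi + 1) >> 1;
            if (prefix[mid] <= tile) lo = mid; else hi = mid - 1;
        }
        int expert     = lo;
        int local_tile = tile - prefix[expert];                // relative tile index within the expert
        out_map[tile]  = (int2)(expert, local_tile);
    }
}
```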
  Expand the task space to Total_Tiles * M_Blocks.
  Within the persistent loop, tile_idx (expert tile) and m_block_idx (row block) are derived dynamically from task_idx, as shown in the sketch below.
  All M-block tasks are mixed in one large task pool.
  All persistent groups on the GPU continuously process these small tasks like a pipeline, without any intermediate wave-synchronization loss.
  The vast majority of stalls are eliminated, leaving only an extremely small tail (<0.1%).
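
The flattened persistent loop itself can be sketched as follows. The kernel name, the out_tasks buffer, and the assumption of a uniform m_blocks per tile are illustrative, not taken from the PR; in the real kernel the micro-GEMM for (tile_idx, m_block_idx) runs inside the loop instead of recording the schedule.

```c
__kernel void persistent_task_loop(__global int2* out_tasks,  // (tile_idx, m_block_idx) per task
                                   int total_tiles,           // expert tiles after the prefix scan
                                   int m_blocks)              // row blocks per tile (assumed uniform)
{
    int total_tasks = total_tiles * m_blocks;
    int group_id    = get_group_id(0);
    int num_groups  = get_num_groups(0);   // Total_Compute_Units * Overprovision_Factor

    // Grid-stride ("persistent") loop: each group keeps pulling tasks until the
    // pool is drained, with no wave-level synchronization between tiles.
    for (int task_idx = group_id; task_idx < total_tasks; task_idx += num_groups) {
        int tile_idx    = task_idx / m_blocks;   // which expert tile
        int m_block_idx = task_idx % m_blocks;   // which row block inside that tile

        if (get_local_id(0) == 0)
            out_tasks[task_idx] = (int2)(tile_idx, m_block_idx);
    }
}
```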
@github-actions github-actions bot added the category: GPU OpenVINO GPU plugin label Feb 2, 2026
@riverlijunjie riverlijunjie changed the title from [GPU] optimize moe workload balance to [GPU] optimize moe imbalance workload of micro_gemm Feb 2, 2026
@riverlijunjie riverlijunjie requested a review from Copilot February 2, 2026 04:41
Contributor

Copilot AI left a comment

Pull request overview

This PR addresses the MoE (Mixture of Experts) workload imbalance in the micro_gemm kernel by implementing a persistent thread pool strategy. The optimization uses dynamic task fetching to avoid idle threads caused by imbalanced workloads, targeting both the qwen3_moe and gpt_oss models.

Changes:

  • Introduces persistent thread pool with dynamic task distribution to balance workload across GPU threads
  • Implements the optimization conditionally for the prefill stage using the ENABLE_WORKLOAD_BALANCE flag (see the sketch after this list)
  • Refactors debug logging to use configurable trace macros
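
A rough sketch of how the prefill-only gate and the new num_experts scalar might look inside the kernel source. The flag name ENABLE_WORKLOAD_BALANCE and the n_array/num_experts inputs come from the PR description, but the kernel name, signature, and macro plumbing here are assumptions; the expert_num >= 64 and prefill checks live on the host side in the .cpp files.

```c
// Illustrative only: shows the build-time gate, not the PR's actual kernel body.
__kernel void moe_gemm(__global const int* n_array,   // tokens per expert
                       int num_experts)               // new scalar argument
{
#ifdef ENABLE_WORKLOAD_BALANCE
    // Prefill path: persistent groups pull (tile, m-block) tasks from the
    // flattened pool derived from n_array (see the sketches above).
#else
    // Original path: one work-group per (m_tile, n_tile, expert).
#endif
}
```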

Reviewed changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 5 comments.

File / Description:
  • src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe_gemm.cl: Adds a new persistent thread pool implementation for balanced workload distribution
  • src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe/moe_gemm_gen_micro.cpp: Enables workload balancing for the prefill stage with persistent groups and adds a num_experts scalar argument
  • src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe/moe_3gemm_swiglu_opt.cpp: Refactors debug logging macros for better configurability
  • src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe/moe_3gemm_gen_micro.cpp: Enables workload balancing with persistent groups and adds a num_experts scalar argument for all three GEMM types


@riverlijunjie riverlijunjie marked this pull request as ready for review February 2, 2026 06:33
@riverlijunjie riverlijunjie requested review from a team as code owners February 2, 2026 06:33
@riverlijunjie riverlijunjie force-pushed the river/moe_balance_workload branch from d7edb25 to bd9c140 Compare February 2, 2026 08:30
…ed state

The expert num affects the workload balance state: if the expert num is small, the average number of tokens allocated to each expert is large, and the load imbalance rate caused by random allocation naturally decreases.
@riverlijunjie riverlijunjie force-pushed the river/moe_balance_workload branch from 7bb8e1a to 8259e2d Compare February 3, 2026 02:35