@ijpq ijpq commented Nov 22, 2025

The output of python collect_env.py
Collecting environment information...
==============================
        System Info
==============================
OS                           : Ubuntu 24.04.3 LTS (x86_64)
GCC version                  : (Ubuntu 13.3.0-6ubuntu2~24.04) 13.3.0
Clang version                : Could not collect
CMake version                : version 3.28.3
Libc version                 : glibc-2.39

==============================
       PyTorch Info
==============================
PyTorch version              : 2.9.0+cu128
Is debug build               : False
CUDA used to build PyTorch   : 12.8
ROCM used to build PyTorch   : N/A

==============================
      Python Environment
==============================
Python version               : 3.12.11 | packaged by conda-forge | (main, Jun  4 2025, 14:45:31) [GCC 13.3.0] (64-bit runtime)
Python platform              : Linux-5.4.0-216-generic-x86_64-with-glibc2.39

==============================
       CUDA / GPU Info
==============================
Is CUDA available            : True
CUDA runtime version         : 12.8.93
CUDA_MODULE_LOADING set to   : 
GPU models and configuration : GPU 0: NVIDIA GeForce RTX 3090
Nvidia driver version        : 570.153.02
cuDNN version                : Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.9.8.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv.so.9.8.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn.so.9.8.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_precompiled.so.10.8.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_runtime_compiled.so.9.8.0
/usr/lib/x86_64-linux-gnu/libcudnn_graph.so.9.8.0
/usr/lib/x86_64-linux-gnu/libcudnn_heuristic.so.9.8.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops.so.9.8.0
HIP runtime version          : N/A
MIOpen runtime version       : N/A
Is XNNPACK available         : True

==============================
          CPU Info
==============================
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Address sizes:                      48 bits physical, 48 bits virtual
Byte Order:                         Little Endian
CPU(s):                             64
On-line CPU(s) list:                0-63
Vendor ID:                          AuthenticAMD
Model name:                         AMD EPYC 7343 16-Core Processor
CPU family:                         25
Model:                              1
Thread(s) per core:                 2
Core(s) per socket:                 16
Socket(s):                          2
Stepping:                           1
Frequency boost:                    enabled
CPU(s) scaling MHz:                 86%
CPU max MHz:                        3200.0000
CPU min MHz:                        1500.0000
BogoMIPS:                           6387.85
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 invpcid_single hw_pstate ssbd mba ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 invpcid cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr wbnoinvd arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold v_vmsave_vmload vgif umip pku ospke vaes vpclmulqdq rdpid overflow_recov succor smca
Virtualization:                     AMD-V
L1d cache:                          1 MiB (32 instances)
L1i cache:                          1 MiB (32 instances)
L2 cache:                           16 MiB (32 instances)
L3 cache:                           256 MiB (8 instances)
NUMA node(s):                       2
NUMA node0 CPU(s):                  0-15,32-47
NUMA node1 CPU(s):                  16-31,48-63
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit:        Not affected
Vulnerability L1tf:                 Not affected
Vulnerability Mds:                  Not affected
Vulnerability Meltdown:             Not affected
Vulnerability Mmio stale data:      Not affected
Vulnerability Retbleed:             Not affected
Vulnerability Spec store bypass:    Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:           Mitigation; Retpolines; IBPB conditional; IBRS_FW; STIBP always-on; RSB filling; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Not affected

==============================
Versions of relevant libraries
==============================
[pip3] flashinfer-python==0.5.2
[pip3] numpy==2.2.6
[pip3] nvidia-cublas-cu12==12.8.4.1
[pip3] nvidia-cuda-cupti-cu12==12.8.90
[pip3] nvidia-cuda-nvrtc-cu12==12.8.93
[pip3] nvidia-cuda-runtime-cu12==12.8.90
[pip3] nvidia-cudnn-cu12==9.10.2.21
[pip3] nvidia-cudnn-frontend==1.16.0
[pip3] nvidia-cufft-cu12==11.3.3.83
[pip3] nvidia-cufile-cu12==1.13.1.3
[pip3] nvidia-curand-cu12==10.3.9.90
[pip3] nvidia-cusolver-cu12==11.7.3.90
[pip3] nvidia-cusparse-cu12==12.5.8.93
[pip3] nvidia-cusparselt-cu12==0.7.1
[pip3] nvidia-cutlass-dsl==4.3.0
[pip3] nvidia-ml-py==13.580.82
[pip3] nvidia-nccl-cu12==2.27.5
[pip3] nvidia-nvjitlink-cu12==12.8.93
[pip3] nvidia-nvshmem-cu12==3.3.20
[pip3] nvidia-nvtx-cu12==12.8.90
[pip3] pyzmq==27.1.0
[pip3] torch==2.9.0
[pip3] torchaudio==2.9.0
[pip3] torchvision==0.24.0
[pip3] transformers==4.57.1
[pip3] triton==3.5.0
[conda] flashinfer-python         0.5.2                    pypi_0    pypi
[conda] numpy                     2.2.6                    pypi_0    pypi
[conda] nvidia-cublas-cu12        12.8.4.1                 pypi_0    pypi
[conda] nvidia-cuda-cupti-cu12    12.8.90                  pypi_0    pypi
[conda] nvidia-cuda-nvrtc-cu12    12.8.93                  pypi_0    pypi
[conda] nvidia-cuda-runtime-cu12  12.8.90                  pypi_0    pypi
[conda] nvidia-cudnn-cu12         9.10.2.21                pypi_0    pypi
[conda] nvidia-cudnn-frontend     1.16.0                   pypi_0    pypi
[conda] nvidia-cufft-cu12         11.3.3.83                pypi_0    pypi
[conda] nvidia-cufile-cu12        1.13.1.3                 pypi_0    pypi
[conda] nvidia-curand-cu12        10.3.9.90                pypi_0    pypi
[conda] nvidia-cusolver-cu12      11.7.3.90                pypi_0    pypi
[conda] nvidia-cusparse-cu12      12.5.8.93                pypi_0    pypi
[conda] nvidia-cusparselt-cu12    0.7.1                    pypi_0    pypi
[conda] nvidia-cutlass-dsl        4.3.0                    pypi_0    pypi
[conda] nvidia-ml-py              13.580.82                pypi_0    pypi
[conda] nvidia-nccl-cu12          2.27.5                   pypi_0    pypi
[conda] nvidia-nvjitlink-cu12     12.8.93                  pypi_0    pypi
[conda] nvidia-nvshmem-cu12       3.3.20                   pypi_0    pypi
[conda] nvidia-nvtx-cu12          12.8.90                  pypi_0    pypi
[conda] pyzmq                     27.1.0                   pypi_0    pypi
[conda] torch                     2.9.0                    pypi_0    pypi
[conda] torchaudio                2.9.0                    pypi_0    pypi
[conda] torchvision               0.24.0                   pypi_0    pypi
[conda] transformers              4.57.1                   pypi_0    pypi
[conda] triton                    3.5.0                    pypi_0    pypi

==============================
         vLLM Info
==============================
ROCM Version                 : Could not collect
vLLM Version                 : 0.1.dev11359+gf3c61a92a.d20251124 (git sha: f3c61a92a, date: 20251124)
vLLM Build Flags:
  CUDA Archs: Not Set; ROCm: Disabled
GPU Topology:
  	GPU0	NIC0	CPU Affinity	NUMA Affinity	GPU NUMA ID
GPU0	 X 	SYS	16-31,48-63	1		N/A
NIC0	SYS	 X 				

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

NIC Legend:

  NIC0: mlx4_0

==============================
     Environment Variables
==============================
NVIDIA_VISIBLE_DEVICES=D.cca8b812a3ffc614142fc7b8d7c8b2bc17a7a24d02ca5110609b5ae90135abe8/gpu=7
NVIDIA_REQUIRE_CUDA=cuda>=12.8 brand=unknown,driver>=470,driver<471 brand=grid,driver>=470,driver<471 brand=tesla,driver>=470,driver<471 brand=nvidia,driver>=470,driver<471 brand=quadro,driver>=470,driver<471 brand=quadrortx,driver>=470,driver<471 brand=nvidiartx,driver>=470,driver<471 brand=vapps,driver>=470,driver<471 brand=vpc,driver>=470,driver<471 brand=vcs,driver>=470,driver<471 brand=vws,driver>=470,driver<471 brand=cloudgaming,driver>=470,driver<471 brand=unknown,driver>=535,driver<536 brand=grid,driver>=535,driver<536 brand=tesla,driver>=535,driver<536 brand=nvidia,driver>=535,driver<536 brand=quadro,driver>=535,driver<536 brand=quadrortx,driver>=535,driver<536 brand=nvidiartx,driver>=535,driver<536 brand=vapps,driver>=535,driver<536 brand=vpc,driver>=535,driver<536 brand=vcs,driver>=535,driver<536 brand=vws,driver>=535,driver<536 brand=cloudgaming,driver>=535,driver<536 brand=unknown,driver>=550,driver<551 brand=grid,driver>=550,driver<551 brand=tesla,driver>=550,driver<551 brand=nvidia,driver>=550,driver<551 brand=quadro,driver>=550,driver<551 brand=quadrortx,driver>=550,driver<551 brand=nvidiartx,driver>=550,driver<551 brand=vapps,driver>=550,driver<551 brand=vpc,driver>=550,driver<551 brand=vcs,driver>=550,driver<551 brand=vws,driver>=550,driver<551 brand=cloudgaming,driver>=550,driver<551 brand=unknown,driver>=560,driver<561 brand=grid,driver>=560,driver<561 brand=tesla,driver>=560,driver<561 brand=nvidia,driver>=560,driver<561 brand=quadro,driver>=560,driver<561 brand=quadrortx,driver>=560,driver<561 brand=nvidiartx,driver>=560,driver<561 brand=vapps,driver>=560,driver<561 brand=vpc,driver>=560,driver<561 brand=vcs,driver>=560,driver<561 brand=vws,driver>=560,driver<561 brand=cloudgaming,driver>=560,driver<561 brand=unknown,driver>=565,driver<566 brand=grid,driver>=565,driver<566 brand=tesla,driver>=565,driver<566 brand=nvidia,driver>=565,driver<566 brand=quadro,driver>=565,driver<566 brand=quadrortx,driver>=565,driver<566 
brand=nvidiartx,driver>=565,driver<566 brand=vapps,driver>=565,driver<566 brand=vpc,driver>=565,driver<566 brand=vcs,driver>=565,driver<566 brand=vws,driver>=565,driver<566 brand=cloudgaming,driver>=565,driver<566
NCCL_VERSION=2.25.1-1
NVIDIA_DRIVER_CAPABILITIES=all
NVIDIA_PRODUCT_NAME=CUDA
CUDA_VERSION=12.8.1
LD_LIBRARY_PATH=/usr/local/cuda/lib64
PYTORCH_NVML_BASED_CUDA_CHECK=1
TORCHINDUCTOR_COMPILE_THREADS=1


Purpose

Resolves: #28986

  1. Add a fused top-k + softmax Triton kernel for GPT-OSS and other models.
  2. Minimize changes to the model's interface by passing the kernel as custom_routing_function.

Test Plan

Compare the results against a PyTorch reference and the existing, unmodified implementation.

  • A roofline analysis is still needed to show that performance is better than the existing CUDA kernel, both when the number of experts is a power of 2 and when it is not.
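
As an illustration of the kind of reference comparison the test plan describes, the routing math can be sketched in pure Python. This is not the PR's Triton kernel or its test code; the function names below are hypothetical. The sketch checks that softmax over all experts, followed by top-k selection and renormalization, gives the same weights as softmax over only the top-k logits, and it uses 65 experts to show that the reference does not care whether the expert count is a power of 2.

```python
import math


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def topk_indices(xs, k):
    """Indices of the k largest values, ties broken by lower index."""
    return sorted(range(len(xs)), key=lambda i: (-xs[i], i))[:k]


def route_softmax_then_topk(logits, k, renormalize=True):
    """Softmax over all experts, keep the top-k, optionally renormalize."""
    probs = softmax(logits)
    idx = topk_indices(probs, k)
    weights = [probs[i] for i in idx]
    if renormalize:
        s = sum(weights)
        weights = [w / s for w in weights]
    return weights, idx


def route_topk_then_softmax(logits, k):
    """Pick the top-k logits first, then softmax over just those k values."""
    idx = topk_indices(logits, k)
    return softmax([logits[i] for i in idx]), idx


if __name__ == "__main__":
    # 65 experts: deliberately not a power of 2 (illustrative input only).
    logits = [math.sin(0.37 * i) * 3.0 for i in range(65)]
    w_a, i_a = route_softmax_then_topk(logits, k=4)
    w_b, i_b = route_topk_then_softmax(logits, k=4)
    assert i_a == i_b
    assert all(abs(a - b) < 1e-12 for a, b in zip(w_a, w_b))
    print(i_a, [round(w, 4) for w in w_a])
```

Without renormalization the two orderings differ: the selected weights then sum to less than 1, which is presumably why the renorm and non-renorm cases need to be tested separately.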

Test Result

compare results
The output of pytest test_gpt_oss_fused_router.py
============================= test session starts ==============================
platform linux -- Python 3.12.11, pytest-9.0.1, pluggy-1.6.0 -- /workspace/vllm/.venv/bin/python3
cachedir: .pytest_cache
rootdir: /workspace/vllm
configfile: pyproject.toml
plugins: anyio-4.11.0
collecting ... collected 60 items

tests/kernels/moe/test_gpt_oss_fused_router.py::test_fused_router[1-32-1] PASSED [  1%]
tests/kernels/moe/test_gpt_oss_fused_router.py::test_fused_router[1-32-32] PASSED [  3%]
tests/kernels/moe/test_gpt_oss_fused_router.py::test_fused_router[1-32-128] PASSED [  5%]
tests/kernels/moe/test_gpt_oss_fused_router.py::test_fused_router[1-32-2048] PASSED [  6%]
tests/kernels/moe/test_gpt_oss_fused_router.py::test_fused_router[1-65-1] PASSED [  8%]
tests/kernels/moe/test_gpt_oss_fused_router.py::test_fused_router[1-65-32] PASSED [ 10%]
tests/kernels/moe/test_gpt_oss_fused_router.py::test_fused_router[1-65-128] PASSED [ 11%]
tests/kernels/moe/test_gpt_oss_fused_router.py::test_fused_router[1-65-2048] PASSED [ 13%]
tests/kernels/moe/test_gpt_oss_fused_router.py::test_fused_router[1-128-1] PASSED [ 15%]
tests/kernels/moe/test_gpt_oss_fused_router.py::test_fused_router[1-128-32] PASSED [ 16%]
tests/kernels/moe/test_gpt_oss_fused_router.py::test_fused_router[1-128-128] PASSED [ 18%]
tests/kernels/moe/test_gpt_oss_fused_router.py::test_fused_router[1-128-2048] PASSED [ 20%]
tests/kernels/moe/test_gpt_oss_fused_router.py::test_fused_router[2-32-1] PASSED [ 21%]
tests/kernels/moe/test_gpt_oss_fused_router.py::test_fused_router[2-32-32] PASSED [ 23%]
tests/kernels/moe/test_gpt_oss_fused_router.py::test_fused_router[2-32-128] PASSED [ 25%]
tests/kernels/moe/test_gpt_oss_fused_router.py::test_fused_router[2-32-2048] PASSED [ 26%]
tests/kernels/moe/test_gpt_oss_fused_router.py::test_fused_router[2-65-1] PASSED [ 28%]
tests/kernels/moe/test_gpt_oss_fused_router.py::test_fused_router[2-65-32] PASSED [ 30%]
tests/kernels/moe/test_gpt_oss_fused_router.py::test_fused_router[2-65-128] PASSED [ 31%]
tests/kernels/moe/test_gpt_oss_fused_router.py::test_fused_router[2-65-2048] PASSED [ 33%]
tests/kernels/moe/test_gpt_oss_fused_router.py::test_fused_router[2-128-1] PASSED [ 35%]
tests/kernels/moe/test_gpt_oss_fused_router.py::test_fused_router[2-128-32] PASSED [ 36%]
tests/kernels/moe/test_gpt_oss_fused_router.py::test_fused_router[2-128-128] PASSED [ 38%]
tests/kernels/moe/test_gpt_oss_fused_router.py::test_fused_router[2-128-2048] PASSED [ 40%]
tests/kernels/moe/test_gpt_oss_fused_router.py::test_fused_router[3-32-1] PASSED [ 41%]
tests/kernels/moe/test_gpt_oss_fused_router.py::test_fused_router[3-32-32] PASSED [ 43%]
tests/kernels/moe/test_gpt_oss_fused_router.py::test_fused_router[3-32-128] PASSED [ 45%]
tests/kernels/moe/test_gpt_oss_fused_router.py::test_fused_router[3-32-2048] PASSED [ 46%]
tests/kernels/moe/test_gpt_oss_fused_router.py::test_fused_router[3-65-1] PASSED [ 48%]
tests/kernels/moe/test_gpt_oss_fused_router.py::test_fused_router[3-65-32] PASSED [ 50%]
tests/kernels/moe/test_gpt_oss_fused_router.py::test_fused_router[3-65-128] PASSED [ 51%]
tests/kernels/moe/test_gpt_oss_fused_router.py::test_fused_router[3-65-2048] PASSED [ 53%]
tests/kernels/moe/test_gpt_oss_fused_router.py::test_fused_router[3-128-1] PASSED [ 55%]
tests/kernels/moe/test_gpt_oss_fused_router.py::test_fused_router[3-128-32] PASSED [ 56%]
tests/kernels/moe/test_gpt_oss_fused_router.py::test_fused_router[3-128-128] PASSED [ 58%]
tests/kernels/moe/test_gpt_oss_fused_router.py::test_fused_router[3-128-2048] PASSED [ 60%]
tests/kernels/moe/test_gpt_oss_fused_router.py::test_fused_router[4-32-1] PASSED [ 61%]
tests/kernels/moe/test_gpt_oss_fused_router.py::test_fused_router[4-32-32] PASSED [ 63%]
tests/kernels/moe/test_gpt_oss_fused_router.py::test_fused_router[4-32-128] PASSED [ 65%]
tests/kernels/moe/test_gpt_oss_fused_router.py::test_fused_router[4-32-2048] PASSED [ 66%]
tests/kernels/moe/test_gpt_oss_fused_router.py::test_fused_router[4-65-1] PASSED [ 68%]
tests/kernels/moe/test_gpt_oss_fused_router.py::test_fused_router[4-65-32] PASSED [ 70%]
tests/kernels/moe/test_gpt_oss_fused_router.py::test_fused_router[4-65-128] PASSED [ 71%]
tests/kernels/moe/test_gpt_oss_fused_router.py::test_fused_router[4-65-2048] PASSED [ 73%]
tests/kernels/moe/test_gpt_oss_fused_router.py::test_fused_router[4-128-1] PASSED [ 75%]
tests/kernels/moe/test_gpt_oss_fused_router.py::test_fused_router[4-128-32] PASSED [ 76%]
tests/kernels/moe/test_gpt_oss_fused_router.py::test_fused_router[4-128-128] PASSED [ 78%]
tests/kernels/moe/test_gpt_oss_fused_router.py::test_fused_router[4-128-2048] PASSED [ 80%]
tests/kernels/moe/test_gpt_oss_fused_router.py::test_fused_router[5-32-1] PASSED [ 81%]
tests/kernels/moe/test_gpt_oss_fused_router.py::test_fused_router[5-32-32] PASSED [ 83%]
tests/kernels/moe/test_gpt_oss_fused_router.py::test_fused_router[5-32-128] PASSED [ 85%]
tests/kernels/moe/test_gpt_oss_fused_router.py::test_fused_router[5-32-2048] PASSED [ 86%]
tests/kernels/moe/test_gpt_oss_fused_router.py::test_fused_router[5-65-1] PASSED [ 88%]
tests/kernels/moe/test_gpt_oss_fused_router.py::test_fused_router[5-65-32] PASSED [ 90%]
tests/kernels/moe/test_gpt_oss_fused_router.py::test_fused_router[5-65-128] PASSED [ 91%]
tests/kernels/moe/test_gpt_oss_fused_router.py::test_fused_router[5-65-2048] PASSED [ 93%]
tests/kernels/moe/test_gpt_oss_fused_router.py::test_fused_router[5-128-1] PASSED [ 95%]
tests/kernels/moe/test_gpt_oss_fused_router.py::test_fused_router[5-128-32] PASSED [ 96%]
tests/kernels/moe/test_gpt_oss_fused_router.py::test_fused_router[5-128-128] PASSED [ 98%]
tests/kernels/moe/test_gpt_oss_fused_router.py::test_fused_router[5-128-2048] PASSED [100%]

=============================== warnings summary ===============================
<frozen importlib._bootstrap>:488
  <frozen importlib._bootstrap>:488: DeprecationWarning: builtin type SwigPyPacked has no __module__ attribute

<frozen importlib._bootstrap>:488
  <frozen importlib._bootstrap>:488: DeprecationWarning: builtin type SwigPyObject has no __module__ attribute

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
======================= 60 passed, 2 warnings in 14.76s ========================

The output of pytest test_routing_consistency.py
============================= test session starts ==============================
platform linux -- Python 3.12.11, pytest-9.0.1, pluggy-1.6.0 -- /workspace/vllm/.venv/bin/python3
cachedir: .pytest_cache
rootdir: /workspace/vllm
configfile: pyproject.toml
plugins: anyio-4.11.0
collecting ... collected 45 items

tests/kernels/moe/test_gpt_oss_routing_consistency.py::test_routing_consistency[1-32-10] PASSED [  2%]
tests/kernels/moe/test_gpt_oss_routing_consistency.py::test_routing_consistency[1-32-128] PASSED [  4%]
tests/kernels/moe/test_gpt_oss_routing_consistency.py::test_routing_consistency[1-32-1024] PASSED [  6%]
tests/kernels/moe/test_gpt_oss_routing_consistency.py::test_routing_consistency[1-65-10] PASSED [  8%]
tests/kernels/moe/test_gpt_oss_routing_consistency.py::test_routing_consistency[1-65-128] PASSED [ 11%]
tests/kernels/moe/test_gpt_oss_routing_consistency.py::test_routing_consistency[1-65-1024] PASSED [ 13%]
tests/kernels/moe/test_gpt_oss_routing_consistency.py::test_routing_consistency[1-128-10] PASSED [ 15%]
tests/kernels/moe/test_gpt_oss_routing_consistency.py::test_routing_consistency[1-128-128] PASSED [ 17%]
tests/kernels/moe/test_gpt_oss_routing_consistency.py::test_routing_consistency[1-128-1024] PASSED [ 20%]
tests/kernels/moe/test_gpt_oss_routing_consistency.py::test_routing_consistency[2-32-10] PASSED [ 22%]
tests/kernels/moe/test_gpt_oss_routing_consistency.py::test_routing_consistency[2-32-128] PASSED [ 24%]
tests/kernels/moe/test_gpt_oss_routing_consistency.py::test_routing_consistency[2-32-1024] PASSED [ 26%]
tests/kernels/moe/test_gpt_oss_routing_consistency.py::test_routing_consistency[2-65-10] PASSED [ 28%]
tests/kernels/moe/test_gpt_oss_routing_consistency.py::test_routing_consistency[2-65-128] PASSED [ 31%]
tests/kernels/moe/test_gpt_oss_routing_consistency.py::test_routing_consistency[2-65-1024] PASSED [ 33%]
tests/kernels/moe/test_gpt_oss_routing_consistency.py::test_routing_consistency[2-128-10] PASSED [ 35%]
tests/kernels/moe/test_gpt_oss_routing_consistency.py::test_routing_consistency[2-128-128] PASSED [ 37%]
tests/kernels/moe/test_gpt_oss_routing_consistency.py::test_routing_consistency[2-128-1024] PASSED [ 40%]
tests/kernels/moe/test_gpt_oss_routing_consistency.py::test_routing_consistency[3-32-10] PASSED [ 42%]
tests/kernels/moe/test_gpt_oss_routing_consistency.py::test_routing_consistency[3-32-128] PASSED [ 44%]
tests/kernels/moe/test_gpt_oss_routing_consistency.py::test_routing_consistency[3-32-1024] PASSED [ 46%]
tests/kernels/moe/test_gpt_oss_routing_consistency.py::test_routing_consistency[3-65-10] PASSED [ 48%]
tests/kernels/moe/test_gpt_oss_routing_consistency.py::test_routing_consistency[3-65-128] PASSED [ 51%]
tests/kernels/moe/test_gpt_oss_routing_consistency.py::test_routing_consistency[3-65-1024] PASSED [ 53%]
tests/kernels/moe/test_gpt_oss_routing_consistency.py::test_routing_consistency[3-128-10] PASSED [ 55%]
tests/kernels/moe/test_gpt_oss_routing_consistency.py::test_routing_consistency[3-128-128] PASSED [ 57%]
tests/kernels/moe/test_gpt_oss_routing_consistency.py::test_routing_consistency[3-128-1024] PASSED [ 60%]
tests/kernels/moe/test_gpt_oss_routing_consistency.py::test_routing_consistency[4-32-10] PASSED [ 62%]
tests/kernels/moe/test_gpt_oss_routing_consistency.py::test_routing_consistency[4-32-128] PASSED [ 64%]
tests/kernels/moe/test_gpt_oss_routing_consistency.py::test_routing_consistency[4-32-1024] PASSED [ 66%]
tests/kernels/moe/test_gpt_oss_routing_consistency.py::test_routing_consistency[4-65-10] PASSED [ 68%]
tests/kernels/moe/test_gpt_oss_routing_consistency.py::test_routing_consistency[4-65-128] PASSED [ 71%]
tests/kernels/moe/test_gpt_oss_routing_consistency.py::test_routing_consistency[4-65-1024] PASSED [ 73%]
tests/kernels/moe/test_gpt_oss_routing_consistency.py::test_routing_consistency[4-128-10] PASSED [ 75%]
tests/kernels/moe/test_gpt_oss_routing_consistency.py::test_routing_consistency[4-128-128] PASSED [ 77%]
tests/kernels/moe/test_gpt_oss_routing_consistency.py::test_routing_consistency[4-128-1024] PASSED [ 80%]
tests/kernels/moe/test_gpt_oss_routing_consistency.py::test_routing_consistency[5-32-10] PASSED [ 82%]
tests/kernels/moe/test_gpt_oss_routing_consistency.py::test_routing_consistency[5-32-128] PASSED [ 84%]
tests/kernels/moe/test_gpt_oss_routing_consistency.py::test_routing_consistency[5-32-1024] PASSED [ 86%]
tests/kernels/moe/test_gpt_oss_routing_consistency.py::test_routing_consistency[5-65-10] PASSED [ 88%]
tests/kernels/moe/test_gpt_oss_routing_consistency.py::test_routing_consistency[5-65-128] PASSED [ 91%]
tests/kernels/moe/test_gpt_oss_routing_consistency.py::test_routing_consistency[5-65-1024] PASSED [ 93%]
tests/kernels/moe/test_gpt_oss_routing_consistency.py::test_routing_consistency[5-128-10] PASSED [ 95%]
tests/kernels/moe/test_gpt_oss_routing_consistency.py::test_routing_consistency[5-128-128] PASSED [ 97%]
tests/kernels/moe/test_gpt_oss_routing_consistency.py::test_routing_consistency[5-128-1024] PASSED [100%]

=============================== warnings summary ===============================
<frozen importlib._bootstrap>:488
  <frozen importlib._bootstrap>:488: DeprecationWarning: builtin type SwigPyPacked has no __module__ attribute

<frozen importlib._bootstrap>:488
  <frozen importlib._bootstrap>:488: DeprecationWarning: builtin type SwigPyObject has no __module__ attribute

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
======================= 45 passed, 2 warnings in 10.73s ========================


benchmark:

 VLLM_TORCH_PROFILER_DIR=./traces nsys profile  \
    --trace-fork-before-exec=true \
    --cuda-graph-trace=node \
vllm bench latency \
    --model openai/gpt-oss-20b \
    --num-iters-warmup 30 \
    --num-iters 10 \
    --batch-size 16 \
    --input-len 1024 \
    --output-len 256

baseline:

Avg prompt throughput: 1820.3 tokens/s, Avg generation throughput: 742.0 tokens/s, Running: 3 reqs, Waiting: 13 reqs, GPU KV cache usage: 1.8%, Prefix cache hit rate: 0.0%
Time	Total Time	Instances	Avg	Med	Min	Max	StdDev	Name
0.3%	492.517 ms	198912	2.476 μs	2.432 μs	2.368 μs	3.776 μs	180 ns	void vllm::moe::topkGatingSoftmax<(int)8, (int)32, (int)4, (int)16, (int)32, int, __nv_bfloat16>(const T7 *, const bool *, float *, int, T6 *, int *, int, int, int, bool)

fused:

Avg prompt throughput: 2232.8 tokens/s, Avg generation throughput: 699.7 tokens/s, Running: 7 reqs, Waiting: 9 reqs, GPU KV cache usage: 3.5%, Prefix cache hit rate: 0.0%
Time	Total Time	Instances	Avg	Med	Min	Max	StdDev	Name
0.2%	390.512 ms	200936	1.943 μs	1.760 μs	1.696 μs	27.872 μs	1.407 μs	_topk_softmax_kernel
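
The two nsys rows can be compared directly. The short calculation below uses only the numbers reported above (total time, instance count, average and median per launch); the helper names are hypothetical. It is a per-launch kernel comparison only and says nothing about end-to-end latency.

```python
# Per-launch kernel statistics copied from the nsys rows above.
baseline = {"total_ms": 492.517, "instances": 198912, "avg_us": 2.476, "med_us": 2.432}
fused = {"total_ms": 390.512, "instances": 200936, "avg_us": 1.943, "med_us": 1.760}


def avg_from_total(row):
    """Recompute the average per-launch time (in microseconds) from total time and count."""
    return row["total_ms"] * 1000.0 / row["instances"]


def speedup(old_us, new_us):
    """A ratio above 1 means the new kernel is faster per launch."""
    return old_us / new_us


if __name__ == "__main__":
    # Sanity check: the reported averages agree with total / instances.
    assert abs(avg_from_total(baseline) - baseline["avg_us"]) < 1e-3
    assert abs(avg_from_total(fused) - fused["avg_us"]) < 1e-3
    print(f"avg speedup:    {speedup(baseline['avg_us'], fused['avg_us']):.2f}x")
    print(f"median speedup: {speedup(baseline['med_us'], fused['med_us']):.2f}x")
```

On these figures the fused kernel is roughly 1.27x faster on average and 1.38x faster at the median. Its maximum (27.872 μs) and standard deviation (1.407 μs) are noticeably higher than the baseline's, so the average alone understates the variance.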

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces a fused Triton kernel for MoE routing in GPT-OSS models to optimize performance. The changes include a new Triton kernel and its integration into the model. My review identified a critical correctness issue: the fused kernel implementation and its usage in gpt_oss.py completely ignore the bias term of the router's linear layer, which will lead to incorrect model outputs. Additionally, I've identified two high-severity issues in the new Triton kernel: the block sizes for the kernel are hardcoded, which can lead to suboptimal performance, and the comments explaining the GEMM logic within the kernel are confusing and contain inaccuracies, which impacts maintainability.

@chatgpt-codex-connector chatgpt-codex-connector bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

@ijpq ijpq force-pushed the ijpq/fused_router_gptoss branch from 4ef122b to f3c61a9 Compare November 22, 2025 12:08
Comment on lines 188 to 195
else:
    g = self.router(x)
    x = self.experts(hidden_states=x, router_logits=g)
    topk_weights, topk_indices = fused_router(
        hidden_states=x,
        router_weights=self.router.weight,
        router_bias=self.router.bias,
        top_k=self.experts_per_token,
    )
Member

Could you keep the router outside and just write a fused top-k first? Then you won't need to change the interface to self.experts or even the model code

Author

@ijpq ijpq Nov 23, 2025

GPT-OSS-20B has 32 experts, so AFAIK this model should take the CUDA kernel path rather than the default branch. The intention of this issue is to replace topkGatingSoftmax, right? If that's the case, I think we should first evaluate the roofline of that kernel.

Contributor

@ZJY0516 ZJY0516 left a comment

Could you please add a test for this?

Author

ijpq commented Nov 23, 2025

Could you please add a test for this?

Definitely.

Signed-off-by: ijpq <509634578tk@gmail.com>
@ijpq ijpq force-pushed the ijpq/fused_router_gptoss branch from dc3f820 to fca484b Compare November 24, 2025 16:28
@ijpq ijpq requested review from ZJY0516 and mgoin November 25, 2025 04:46
@@ -174,6 +176,11 @@ def __init__(
has_bias=True,
activation="swigluoai",
is_sequence_parallel=self.is_sequence_parallel,
custom_routing_function=(
gpt_oss_custom_routing_function
if not current_platform.is_rocm()
Member

nit: just check cuda to help other platforms

Suggested change
- if not current_platform.is_rocm()
+ if current_platform.is_cuda()

Member

nit: could you consolidate these tests into one file?

Author

ack


topk_padded = triton.next_power_of_2(topk)

grid = (M,)
Member

Don't you need to tune the kernel at all? I haven't seen a benchmark reporting perf yet

Author

ack. report updated.

Author

ijpq commented Nov 27, 2025

I intend to reorganize my thoughts, check the latest updates in the MoE code, and push again once I understand it better. Thanks a lot for your review and hints @mgoin @ZJY0516

Author

ijpq commented Nov 27, 2025

I intend to reorganize my thoughts, check the latest updates in the MoE code, and push again once I understand it better. Thanks a lot for your review and hints @mgoin @ZJY0516

I have updated the benchmark results. Any advice? Because of some changes to my local hardware, the roofline analysis will take a few days. Intuitively, though, I would not expect compiler-generated code to outperform a hand-optimized CUDA kernel.

@pytest.mark.parametrize("M", [1, 32, 128, 2048])
@pytest.mark.parametrize("N", [32, 65, 128])
@pytest.mark.parametrize("topk", [1, 2, 3, 4, 5])
def test_fused_router(M, N, topk):
Contributor

@ElizaWszola ElizaWszola Nov 27, 2025

skip if current_platform is not cuda?

Author

ack

@pytest.mark.parametrize("num_tokens", [10, 128, 1024])
@pytest.mark.parametrize("num_experts", [32, 65, 128])
@pytest.mark.parametrize("topk", [1, 2, 3, 4, 5])
def test_routing_consistency(num_tokens, num_experts, topk):
Contributor

ditto

Author

ack

):
token_idx = tl.program_id(0)

offs = tl.arange(0, BLOCK_N)

Performance-wise, it could be better to use tl.range instead of tl.arange (for better pipelining), with num_stages as a parameter whose best value can be found via autotune.
Example using num_stages=num_stages: https://triton-lang.org/main/getting-started/tutorials/02-fused-softmax.html
You could even try setting warp_specialize=True with autotune to see whether that improves perf further...

Author

ijpq commented Nov 30, 2025

Thank you for the correction. @shaginhekvs

I just ran a roofline report. This kernel still has a lot of room for optimization. I'll take another look at Triton, and I should be able to provide optimized results today.

[image: roofline report]

- split two kernels, in case renorm or not
- add online softmax
- unroll along M

Signed-off-by: ijpq <509634578tk@gmail.com>
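
For context on the "online softmax" item: the term usually refers to computing the running maximum and the normalizer in a single pass, rescaling the partial sum whenever the maximum changes. The pure-Python sketch below illustrates that general technique only. It is not the Triton code from this commit, and the names `online_softmax` and `two_pass_softmax` are hypothetical.

```python
import math


def online_softmax(xs):
    """Softmax with max and normalizer computed in one pass (finite inputs assumed)."""
    running_max = float("-inf")
    running_sum = 0.0
    for x in xs:
        new_max = max(running_max, x)
        # Rescale the sum accumulated so far to the new maximum, then add x's term.
        running_sum = running_sum * math.exp(running_max - new_max) + math.exp(x - new_max)
        running_max = new_max
    return [math.exp(x - running_max) / running_sum for x in xs]


def two_pass_softmax(xs):
    """Conventional version: one pass for the max, one for the sum."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    s = sum(exps)
    return [e / s for e in exps]


if __name__ == "__main__":
    logits = [1.0, 3.0, -2.0, 3.5, 0.25]
    a = online_softmax(logits)
    b = two_pass_softmax(logits)
    assert all(abs(p - q) < 1e-12 for p, q in zip(a, b))
    print([round(p, 4) for p in a])
```

The benefit on a GPU is that the logits are read once for the reduction instead of twice, which matters for a kernel that is limited by memory traffic.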
@ijpq ijpq force-pushed the ijpq/fused_router_gptoss branch from 5af77ad to 66e6711 Compare November 30, 2025 15:48
Author

ijpq commented Nov 30, 2025

TL;DR
The achieved point on the roofline plot moved to the upper right.

As of commit 66e6711, the kernel achieves higher FLOP/s and is no longer memory-bound in the roofline analysis (renorm enabled).

[image: roofline plot]

Collected with:

ncu -f --set full --target-processes all -o triton_topk_profile \
    --launch-skip 100 --launch-count 5 \
    --kernel-name _topk_softmax_renorm_kernel \
vllm bench latency \
    --model openai/gpt-oss-20b \
    --num-iters-warmup 5 \
    --num-iters 1 \
    --batch-size 16 \
    --input-len 1024 \
    --output-len 256
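
For readers less familiar with the roofline terms used here: a kernel is memory-bound when its arithmetic intensity (FLOPs per byte moved) lies below the machine's ridge point (peak FLOP/s divided by peak bandwidth) and compute-bound above it. Moving "to the upper right" means higher intensity and higher achieved FLOP/s. The sketch below shows that classification. All numbers in it are placeholders chosen for illustration, not measurements from this PR or specifications of the GPU used, and the function names are hypothetical.

```python
def attainable_flops(intensity, peak_flops, peak_bw):
    """Roofline bound: the lower of the compute roof and bandwidth * intensity."""
    return min(peak_flops, peak_bw * intensity)


def classify(intensity, peak_flops, peak_bw):
    """Label a kernel by which side of the ridge point its intensity falls on."""
    ridge = peak_flops / peak_bw  # FLOPs per byte where the two roofs meet
    return "memory-bound" if intensity < ridge else "compute-bound"


if __name__ == "__main__":
    PEAK_FLOPS = 30e12  # placeholder peak FLOP/s, not a real device spec
    PEAK_BW = 900e9     # placeholder peak bytes/s, not a real device spec
    for ai in (0.5, 10.0, 100.0):
        bound = attainable_flops(ai, PEAK_FLOPS, PEAK_BW)
        print(f"AI={ai:6.1f} FLOP/B -> {classify(ai, PEAK_FLOPS, PEAK_BW)}, "
              f"attainable {bound / 1e12:.2f} TFLOP/s")
```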

I'm working on these things:

  • Unroll along the M axis within each program.
  • Find a better memory access pattern.
  • Test Triton autotune.

I’ll try to get it done by tomorrow.

Labels

gpt-oss Related to GPT-OSS models

Projects

Status: To Triage

Development

Successfully merging this pull request may close these issues.

[Feature]: Fused Kernel for GPT-OSS Router

5 participants