Closed
Labels: bug (Something isn't working), llama (Related to Llama models)
Description
Your current environment
The output of python collect_env.py
==============================
System Info
==============================
OS : CentOS Stream 9 (x86_64)
GCC version : (GCC) 11.5.0 20240719 (Red Hat 11.5.0-7)
Clang version : Could not collect
CMake version : version 4.0.3
Libc version : glibc-2.34
==============================
PyTorch Info
==============================
PyTorch version : 2.7.1+cu128
Is debug build : False
CUDA used to build PyTorch : 12.8
ROCM used to build PyTorch : N/A
==============================
Python Environment
==============================
Python version : 3.12.10 (main, May 9 2025, 00:00:00) [GCC 11.5.0 20240719 (Red Hat 11.5.0-5)] (64-bit runtime)
Python platform : Linux-6.4.3-0_fbk20_zion_2830_g3e5ab162667d-x86_64-with-glibc2.34
==============================
CUDA / GPU Info
==============================
Is CUDA available : True
CUDA runtime version : 12.8.93
CUDA_MODULE_LOADING set to : LAZY
GPU models and configuration :
GPU 0: NVIDIA H100
GPU 1: NVIDIA H100
GPU 2: NVIDIA H100
GPU 3: NVIDIA H100
GPU 4: NVIDIA H100
GPU 5: NVIDIA H100
GPU 6: NVIDIA H100
GPU 7: NVIDIA H100
Nvidia driver version : 535.154.05
cuDNN version : Probably one of the following:
/usr/lib64/libcudnn.so.9.5.1
/usr/lib64/libcudnn_adv.so.9.5.1
/usr/lib64/libcudnn_cnn.so.9.5.1
/usr/lib64/libcudnn_engines_precompiled.so.9.5.1
/usr/lib64/libcudnn_engines_runtime_compiled.so.9.5.1
/usr/lib64/libcudnn_graph.so.9.5.1
/usr/lib64/libcudnn_heuristic.so.9.5.1
/usr/lib64/libcudnn_ops.so.9.5.1
HIP runtime version : N/A
MIOpen runtime version : N/A
Is XNNPACK available : True
==============================
CPU Info
==============================
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 52 bits physical, 57 bits virtual
Byte Order: Little Endian
CPU(s): 384
On-line CPU(s) list: 0-383
Vendor ID: AuthenticAMD
Model name: AMD EPYC 9654 96-Core Processor
CPU family: 25
Model: 17
Thread(s) per core: 2
Core(s) per socket: 96
Socket(s): 2
Stepping: 1
Frequency boost: enabled
CPU(s) scaling MHz: 83%
CPU max MHz: 3707.8120
CPU min MHz: 1500.0000
BogoMIPS: 4792.60
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good amd_lbr_v2 nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 invpcid_single hw_pstate ssbd mba perfmon_v2 ibrs ibpb stibp ibrs_enhanced vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local avx512_bf16 clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin cppc arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif x2avic v_spec_ctrl vnmi avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq rdpid overflow_recov succor smca fsrm flush_l1d
Virtualization: AMD-V
L1d cache: 6 MiB (192 instances)
L1i cache: 6 MiB (192 instances)
L2 cache: 192 MiB (192 instances)
L3 cache: 768 MiB (24 instances)
NUMA node(s): 2
NUMA node0 CPU(s): 0-95,192-287
NUMA node1 CPU(s): 96-191,288-383
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Vulnerable: eIBRS with unprivileged eBPF
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
==============================
Versions of relevant libraries
==============================
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.26.4
[pip3] nvidia-cublas-cu12==12.8.3.14
[pip3] nvidia-cuda-cupti-cu12==12.8.57
[pip3] nvidia-cuda-nvrtc-cu12==12.8.61
[pip3] nvidia-cuda-runtime-cu12==12.8.57
[pip3] nvidia-cudnn-cu12==9.7.1.26
[pip3] nvidia-cufft-cu12==11.3.3.41
[pip3] nvidia-cufile-cu12==1.13.0.11
[pip3] nvidia-curand-cu12==10.3.9.55
[pip3] nvidia-cusolver-cu12==11.7.2.55
[pip3] nvidia-cusparse-cu12==12.5.7.53
[pip3] nvidia-cusparselt-cu12==0.6.3
[pip3] nvidia-nccl-cu12==2.26.2
[pip3] nvidia-nvjitlink-cu12==12.8.61
[pip3] nvidia-nvtx-cu12==12.8.55
[pip3] pyzmq==27.0.0
[pip3] sentence-transformers==3.2.1
[pip3] torch==2.7.1+cu128
[pip3] torchaudio==2.7.1+cu128
[pip3] torchvision==0.22.1+cu128
[pip3] transformers==4.53.2
[pip3] transformers-stream-generator==0.0.5
[pip3] triton==3.3.1
[pip3] tritonclient==2.51.0
[pip3] vector-quantize-pytorch==1.21.2
[conda] Could not collect
==============================
vLLM Info
==============================
ROCM Version : Could not collect
Neuron SDK Version : N/A
vLLM Version : 0.9.2rc2.dev304+g28a6d5423 (git sha: 28a6d5423)
vLLM Build Flags:
CUDA Archs: Not Set; ROCm: Disabled; Neuron: Disabled
GPU Topology:
      GPU0  GPU1  GPU2  GPU3  GPU4  GPU5  GPU6  GPU7  NIC0  NIC1  NIC2  NIC3  CPU Affinity    NUMA Affinity  GPU NUMA ID
GPU0  X     NV18  NV18  NV18  NV18  NV18  NV18  NV18  SYS   SYS   SYS   SYS   0-95,192-287    0              N/A
GPU1  NV18  X     NV18  NV18  NV18  NV18  NV18  NV18  PHB   PHB   SYS   SYS   0-95,192-287    0              N/A
GPU2  NV18  NV18  X     NV18  NV18  NV18  NV18  NV18  SYS   SYS   SYS   SYS   0-95,192-287    0              N/A
GPU3  NV18  NV18  NV18  X     NV18  NV18  NV18  NV18  SYS   SYS   SYS   SYS   0-95,192-287    0              N/A
GPU4  NV18  NV18  NV18  NV18  X     NV18  NV18  NV18  SYS   SYS   SYS   SYS   96-191,288-383  1              N/A
GPU5  NV18  NV18  NV18  NV18  NV18  X     NV18  NV18  SYS   SYS   SYS   SYS   96-191,288-383  1              N/A
GPU6  NV18  NV18  NV18  NV18  NV18  NV18  X     NV18  SYS   SYS   PHB   PHB   96-191,288-383  1              N/A
GPU7  NV18  NV18  NV18  NV18  NV18  NV18  NV18  X     SYS   SYS   SYS   SYS   96-191,288-383  1              N/A
NIC0  SYS   PHB   SYS   SYS   SYS   SYS   SYS   SYS   X     PIX   SYS   SYS
NIC1  SYS   PHB   SYS   SYS   SYS   SYS   SYS   SYS   PIX   X     SYS   SYS
NIC2  SYS   SYS   SYS   SYS   SYS   SYS   PHB   SYS   SYS   SYS   X     PIX
NIC3  SYS   SYS   SYS   SYS   SYS   SYS   PHB   SYS   SYS   SYS   PIX   X
Legend:
X = Self
SYS = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
PHB = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
PXB = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
PIX = Connection traversing at most a single PCIe bridge
NV# = Connection traversing a bonded set of # NVLinks
NIC Legend:
NIC0: mlx5_0
NIC1: mlx5_1
NIC2: mlx5_2
NIC3: mlx5_3
==============================
Environment Variables
==============================
CUDA_CACHE_PATH=/data/users/yming/.nv/ComputeCache
LD_LIBRARY_PATH=/usr/local/cuda-12.8/lib64/:/usr/local/cuda-12.8/lib64/:
NCCL_CUMEM_ENABLE=0
PYTORCH_NVML_BASED_CUDA_CHECK=1
TORCHINDUCTOR_COMPILE_THREADS=1
CUDA_MODULE_LOADING=LAZY
🐛 Describe the bug
Llama 4 Maverick fails to start because of a RuntimeError raised in shuffle_rows.
This can be reproduced with:
vllm serve meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8 --max_model_len 8192 --kv_cache_dtype fp8 --enable-expert-parallel --tensor-parallel-size 8 --trust-remote-code --gpu-memory-utilization 0.8 --disable-log-requests
This is likely related to #20762 (@ElizaWszola).
It can also be reproduced with pytest -s tests/models/multimodal/generation/test_maverick.py, which runs a dummy version of Maverick and requires only 2x H100.
cc @yeqcharlotte @luccafong @houseroad
(VllmWorker rank=0 pid=938767) ERROR 07-21 11:16:11 [multiproc_executor.py:546] File "/data/users/yming/gitrepos/vllm/vllm/model_executor/layers/fused_moe/layer.py", line 1579, in moe_forward
(VllmWorker rank=0 pid=938767) ERROR 07-21 11:16:11 [multiproc_executor.py:546] return self.forward_impl(hidden_states, router_logits)
(VllmWorker rank=0 pid=938767) ERROR 07-21 11:16:11 [multiproc_executor.py:546] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=938767) ERROR 07-21 11:16:11 [multiproc_executor.py:546] File "/data/users/yming/gitrepos/vllm/vllm/model_executor/layers/fused_moe/layer.py", line 1489, in forward_impl
(VllmWorker rank=0 pid=938767) ERROR 07-21 11:16:11 [multiproc_executor.py:546] final_hidden_states = self.quant_method.apply(
(VllmWorker rank=0 pid=938767) ERROR 07-21 11:16:11 [multiproc_executor.py:546] ^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=938767) ERROR 07-21 11:16:11 [multiproc_executor.py:546] File "/data/users/yming/gitrepos/vllm/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py", line 959, in apply
(VllmWorker rank=0 pid=938767) ERROR 07-21 11:16:11 [multiproc_executor.py:546] return cutlass_moe_fp8(
(VllmWorker rank=0 pid=938767) ERROR 07-21 11:16:11 [multiproc_executor.py:546] ^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=938767) ERROR 07-21 11:16:11 [multiproc_executor.py:546] File "/data/users/yming/gitrepos/vllm/vllm/model_executor/layers/fused_moe/cutlass_moe.py", line 414, in cutlass_moe_fp8
(VllmWorker rank=0 pid=938767) ERROR 07-21 11:16:11 [multiproc_executor.py:546] return fn(
(VllmWorker rank=0 pid=938767) ERROR 07-21 11:16:11 [multiproc_executor.py:546] ^^^
(VllmWorker rank=0 pid=938767) ERROR 07-21 11:16:11 [multiproc_executor.py:546] File "/home/yming/uv_env/vllm/lib64/python3.12/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
(VllmWorker rank=0 pid=938767) ERROR 07-21 11:16:11 [multiproc_executor.py:546] return self._call_impl(*args, **kwargs)
(VllmWorker rank=0 pid=938767) ERROR 07-21 11:16:11 [multiproc_executor.py:546] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=938767) ERROR 07-21 11:16:11 [multiproc_executor.py:546] File "/home/yming/uv_env/vllm/lib64/python3.12/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
(VllmWorker rank=0 pid=938767) ERROR 07-21 11:16:11 [multiproc_executor.py:546] return forward_call(*args, **kwargs)
(VllmWorker rank=0 pid=938767) ERROR 07-21 11:16:11 [multiproc_executor.py:546] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=938767) ERROR 07-21 11:16:11 [multiproc_executor.py:546] File "/data/users/yming/gitrepos/vllm/vllm/model_executor/layers/fused_moe/modular_kernel.py", line 770, in forward
(VllmWorker rank=0 pid=938767) ERROR 07-21 11:16:11 [multiproc_executor.py:546] fused_out = self._maybe_chunk_fused_experts(
(VllmWorker rank=0 pid=938767) ERROR 07-21 11:16:11 [multiproc_executor.py:546] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=938767) ERROR 07-21 11:16:11 [multiproc_executor.py:546] File "/data/users/yming/gitrepos/vllm/vllm/model_executor/layers/fused_moe/modular_kernel.py", line 545, in _maybe_chunk_fused_experts
(VllmWorker rank=0 pid=938767) ERROR 07-21 11:16:11 [multiproc_executor.py:546] return self._do_fused_experts(
(VllmWorker rank=0 pid=938767) ERROR 07-21 11:16:11 [multiproc_executor.py:546] ^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=938767) ERROR 07-21 11:16:11 [multiproc_executor.py:546] File "/data/users/yming/gitrepos/vllm/vllm/model_executor/layers/fused_moe/modular_kernel.py", line 492, in _do_fused_experts
(VllmWorker rank=0 pid=938767) ERROR 07-21 11:16:11 [multiproc_executor.py:546] self.fused_experts.apply(
(VllmWorker rank=0 pid=938767) ERROR 07-21 11:16:11 [multiproc_executor.py:546] File "/data/users/yming/gitrepos/vllm/vllm/model_executor/layers/fused_moe/cutlass_moe.py", line 314, in apply
(VllmWorker rank=0 pid=938767) ERROR 07-21 11:16:11 [multiproc_executor.py:546] run_cutlass_moe_fp8(
(VllmWorker rank=0 pid=938767) ERROR 07-21 11:16:11 [multiproc_executor.py:546] File "/data/users/yming/gitrepos/vllm/vllm/model_executor/layers/fused_moe/cutlass_moe.py", line 160, in run_cutlass_moe_fp8
(VllmWorker rank=0 pid=938767) ERROR 07-21 11:16:11 [multiproc_executor.py:546] a1q_scale = (ops.shuffle_rows(a1q_scale, a_map)
(VllmWorker rank=0 pid=938767) ERROR 07-21 11:16:11 [multiproc_executor.py:546] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=938767) ERROR 07-21 11:16:11 [multiproc_executor.py:546] File "/data/users/yming/gitrepos/vllm/vllm/_custom_ops.py", line 908, in shuffle_rows
(VllmWorker rank=0 pid=938767) ERROR 07-21 11:16:11 [multiproc_executor.py:546] torch.ops._moe_C.shuffle_rows(input_tensor, dst2src_map, output_tensor)
(VllmWorker rank=0 pid=938767) ERROR 07-21 11:16:11 [multiproc_executor.py:546] File "/home/yming/uv_env/vllm/lib64/python3.12/site-packages/torch/_ops.py", line 1158, in __call__
(VllmWorker rank=0 pid=938767) ERROR 07-21 11:16:11 [multiproc_executor.py:546] return self._op(*args, **(kwargs or {}))
(VllmWorker rank=0 pid=938767) ERROR 07-21 11:16:11 [multiproc_executor.py:546] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=938767) ERROR 07-21 11:16:11 [multiproc_executor.py:546] RuntimeError: num_cols must be divisible by 128 / sizeof(input_tensor.scalar_type()) / 8
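The divisibility rule in the error can be checked by hand. The sketch below is a minimal, stdlib-only illustration under several assumptions: that the failing call is the activation-scale shuffle at cutlass_moe.py line 160 (ops.shuffle_rows(a1q_scale, a_map)), that a1q_scale holds one float32 scale per token (a single column), and that the expression is evaluated with C++ integer division. Under those assumptions a float32 input needs a column count divisible by 4, so a one-column scale tensor is rejected while FP8 activations with a typical hidden size pass. The dtype table, the helper names and the hidden size are hypothetical, and shuffle_rows_reference is only an illustrative row gather, not vLLM's kernel.

```python
# Hypothetical sketch of the check quoted in the RuntimeError:
#   num_cols must be divisible by 128 / sizeof(input_tensor.scalar_type()) / 8
# Element sizes are in bytes; names mirror torch dtypes, but torch is not used.
ELEMENT_SIZE_BYTES = {
    "float8_e4m3fn": 1,
    "bfloat16": 2,
    "float16": 2,
    "float32": 4,
}


def required_col_multiple(dtype: str) -> int:
    """Evaluate 128 / sizeof(dtype) / 8 with C++-style integer division.

    For every dtype in the table this equals the number of elements that
    fit in 16 bytes (128 bits).
    """
    return 128 // ELEMENT_SIZE_BYTES[dtype] // 8


def passes_shuffle_rows_check(num_cols: int, dtype: str) -> bool:
    """Return True if num_cols satisfies the divisibility rule above."""
    return num_cols % required_col_multiple(dtype) == 0


def shuffle_rows_reference(rows, dst2src_map):
    """Plain row gather with no alignment requirement (illustrative only).

    Assumes dst2src_map[i] is the source row for output row i, as the
    parameter name suggests.
    """
    return [list(rows[src]) for src in dst2src_map]


if __name__ == "__main__":
    hidden_size = 5120  # hypothetical activation width (FP8, 16 | 5120)
    print(passes_shuffle_rows_check(hidden_size, "float8_e4m3fn"))  # True
    print(passes_shuffle_rows_check(1, "float32"))  # False: 1 % 4 != 0
    scales = [[0.5], [0.25], [2.0]]  # assumed per-token scales, one column
    print(shuffle_rows_reference(scales, [2, 0, 1]))  # [[2.0], [0.5], [0.25]]
```

Running it prints True, False and [[2.0], [0.5], [0.25]]: under the stated assumptions, the single-column float32 scale tensor is exactly the kind of input the check rejects.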