Commit bec5419

Merge branch 'upstream_main' into v1_encoder_only
2 parents 837e51b + fe56180

7 files changed: +261 −31 lines changed

.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh

Lines changed: 1 addition & 1 deletion

@@ -135,7 +135,7 @@ run_and_track_test 1 "test_compilation.py" \
 run_and_track_test 2 "test_basic.py" \
   "python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_basic.py"
 run_and_track_test 3 "test_accuracy.py::test_lm_eval_accuracy_v1_engine" \
-  "python3 -m pytest -s -v /workspace/vllm/tests/entrypoints/llm/test_accuracy.py::test_lm_eval_accuracy_v1_engine"
+  "HF_HUB_DISABLE_XET=1 python3 -m pytest -s -v /workspace/vllm/tests/entrypoints/llm/test_accuracy.py::test_lm_eval_accuracy_v1_engine"
 run_and_track_test 4 "test_quantization_accuracy.py" \
   "python3 -m pytest -s -v /workspace/vllm/tests/tpu/test_quantization_accuracy.py"
 run_and_track_test 5 "examples/offline_inference/tpu.py" \

docker/Dockerfile

Lines changed: 1 addition & 1 deletion

@@ -386,7 +386,7 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist

 # Install FlashInfer from source
 ARG FLASHINFER_GIT_REPO="https://github.com/flashinfer-ai/flashinfer.git"
-ARG FLASHINFER_GIT_REF="v0.2.8"
+ARG FLASHINFER_GIT_REF="v0.2.9rc1"
 RUN --mount=type=cache,target=/root/.cache/uv bash - <<'BASH'
 . /etc/environment
 git clone --depth 1 --recursive --shallow-submodules \
Lines changed: 244 additions & 0 deletions

@@ -0,0 +1,244 @@

# Expert Parallel Deployment

vLLM supports Expert Parallelism (EP), which allows experts in Mixture-of-Experts (MoE) models to be deployed on separate GPUs, increasing locality, efficiency, and overall throughput.

EP is typically coupled with Data Parallelism (DP). While DP can be used independently of EP, EP is more efficient when used in conjunction with DP. You can read more about data parallelism [here](data_parallel_deployment.md).

## Prerequisites

Before using EP, you need to install the necessary dependencies. We are actively working on making this easier in the future:

1. **Install DeepEP and pplx-kernels**: Set up the host environment following vLLM's guide for EP kernels [here](gh-file:tools/ep_kernels).
2. **Install DeepGEMM library**: Follow the [official instructions](https://github.com/deepseek-ai/DeepGEMM#installation).
3. **For disaggregated serving**: Install UCX and NIXL following the [script](gh-file:tools/install_nixl.sh).
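As a quick sanity check after installation, you can verify that the kernel packages are importable. This is only a sketch: the module names below (`deep_ep`, `pplx_kernels`, `deep_gemm`, `nixl`) are assumptions about how each package is exposed, so adjust them to match your installation.

```python
import importlib.util

# Module names are assumptions about how each package is exposed; adjust as needed.
packages = {
    "DeepEP": "deep_ep",
    "pplx-kernels": "pplx_kernels",
    "DeepGEMM": "deep_gemm",
    "NIXL (disaggregated serving only)": "nixl",
}

for name, module in packages.items():
    status = "OK" if importlib.util.find_spec(module) else "MISSING"
    print(f"{name:35s} {status}")
```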
### Backend Selection Guide

vLLM provides three communication backends for EP:

| Backend | Use Case | Features | Best For |
|---------|----------|----------|----------|
| `pplx` | Single node | Chunked prefill support | Development, best for intra-node deployments |
| `deepep_high_throughput` | Multi-node prefill | Grouped GEMM with contiguous layout | High-throughput scenarios, prefill-dominated workloads |
| `deepep_low_latency` | Multi-node decode | CUDA graph support, masked layout | Low-latency scenarios, decode-dominated workloads |
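As a rough illustration of the table above, here is a minimal, hypothetical helper (not part of vLLM) that picks a backend string and exports it through the `VLLM_ALL2ALL_BACKEND` environment variable before a server is launched:

```python
import os

def pick_all2all_backend(multi_node: bool, decode_dominated: bool) -> str:
    """Choose a VLLM_ALL2ALL_BACKEND value following the guide above."""
    if not multi_node:
        return "pplx"  # single-node / development deployments
    # Multi-node: high-throughput for prefill-heavy, low-latency for decode-heavy.
    return "deepep_low_latency" if decode_dominated else "deepep_high_throughput"

# Example: a multi-node, decode-dominated deployment.
os.environ["VLLM_ALL2ALL_BACKEND"] = pick_all2all_backend(
    multi_node=True, decode_dominated=True)
os.environ["VLLM_USE_DEEP_GEMM"] = "1"
print(os.environ["VLLM_ALL2ALL_BACKEND"])  # -> deepep_low_latency
```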
## Single Node Deployment

!!! warning
    EP is an experimental feature. Argument names and default values may change in the future.

### Configuration

Enable EP by setting the `--enable-expert-parallel` flag. The EP size is automatically calculated as:

```
EP_SIZE = TP_SIZE × DP_SIZE
```

Where:

- `TP_SIZE`: Tensor parallel size (always 1 for now)
- `DP_SIZE`: Data parallel size
- `EP_SIZE`: Expert parallel size (computed automatically)
### Example Command

The following command serves the `DeepSeek-V3-0324` model with 1-way tensor parallelism, 8-way (attention) data parallelism, and 8-way expert parallelism. The attention weights are replicated across all GPUs, while the expert weights are split across GPUs. It will work on an H200 (or H20) node with 8 GPUs. For H100, you can try serving a smaller model or refer to the multi-node deployment section.

```bash
# Single-node EP deployment with the pplx backend:
#   --tensor-parallel-size 1   tensor parallelism across 1 GPU
#   --data-parallel-size 8     data parallelism across 8 processes
#   --enable-expert-parallel   enable expert parallelism
VLLM_ALL2ALL_BACKEND=pplx VLLM_USE_DEEP_GEMM=1 \
  vllm serve deepseek-ai/DeepSeek-V3-0324 \
  --tensor-parallel-size 1 \
  --data-parallel-size 8 \
  --enable-expert-parallel
```
## Multi-Node Deployment

For multi-node deployment, use the DeepEP communication kernel in one of its two modes (see the [Backend Selection Guide](#backend-selection-guide) above).

### Deployment Steps

1. **Run one command per node** - Each node requires its own launch command
2. **Configure networking** - Ensure proper IP addresses and port configurations
3. **Set node roles** - The first node handles requests; additional nodes run in headless mode

### Example: 2-Node Deployment

The following example deploys `DeepSeek-V3-0324` across 2 nodes using `deepep_low_latency` mode:

```bash
# Node 1 (primary - handles incoming requests):
#   --tensor-parallel-size 1        TP size per node
#   --enable-expert-parallel        enable EP
#   --data-parallel-size 16         total DP size across all nodes
#   --data-parallel-size-local 8    local DP size on this node (8 GPUs per node)
#   --data-parallel-address ...     IP address of Node 1 (replace with the actual IP)
#   --data-parallel-rpc-port ...    RPC communication port; any port reachable by all nodes
#   --api-server-count 8            number of API servers for load handling
#                                   (scaling this toward the total number of ranks is recommended)
VLLM_ALL2ALL_BACKEND=deepep_low_latency VLLM_USE_DEEP_GEMM=1 \
  vllm serve deepseek-ai/DeepSeek-V3-0324 \
  --tensor-parallel-size 1 \
  --enable-expert-parallel \
  --data-parallel-size 16 \
  --data-parallel-size-local 8 \
  --data-parallel-address 192.168.1.100 \
  --data-parallel-rpc-port 13345 \
  --api-server-count=8

# Node 2 (secondary - headless mode, no API server):
#   --data-parallel-start-rank 8    starting rank offset for this node
#   --data-parallel-address ...     IP of the primary node (Node 1)
#   --data-parallel-rpc-port ...    same RPC port as the primary node
#   --headless                      no API server, worker only
VLLM_ALL2ALL_BACKEND=deepep_low_latency VLLM_USE_DEEP_GEMM=1 \
  vllm serve deepseek-ai/DeepSeek-V3-0324 \
  --tensor-parallel-size 1 \
  --enable-expert-parallel \
  --data-parallel-size 16 \
  --data-parallel-size-local 8 \
  --data-parallel-start-rank 8 \
  --data-parallel-address 192.168.1.100 \
  --data-parallel-rpc-port 13345 \
  --headless
```
### Key Configuration Notes

- **Headless mode**: Secondary nodes run with the `--headless` flag, meaning all client requests are handled by the primary node
- **Rank calculation**: `--data-parallel-start-rank` should equal the cumulative local DP size of the previous nodes (see the sketch below)
- **Load scaling**: Adjust `--api-server-count` on the primary node to handle higher request loads
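To make the rank calculation concrete, here is a small, hypothetical helper (not part of vLLM) that derives `--data-parallel-start-rank` for each node from the per-node local DP sizes:

```python
from itertools import accumulate

def start_ranks(local_dp_sizes: list[int]) -> list[int]:
    """--data-parallel-start-rank per node: cumulative local DP size of the nodes before it."""
    return [0] + list(accumulate(local_dp_sizes))[:-1]

# Two nodes with 8 GPUs each (the example above): total DP size 16.
local_sizes = [8, 8]
print(sum(local_sizes))          # --data-parallel-size 16
print(start_ranks(local_sizes))  # [0, 8] -> Node 2 uses --data-parallel-start-rank 8
```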
### Network Configuration

!!! important "InfiniBand Clusters"
    On InfiniBand networked clusters, set this environment variable to prevent initialization hangs:

    ```bash
    export GLOO_SOCKET_IFNAME=eth0
    ```

    This ensures torch distributed group discovery uses Ethernet instead of InfiniBand for initial setup.
## Expert Parallel Load Balancer (EPLB)

While MoE models are typically trained so that each expert receives a similar number of tokens, in practice the distribution of tokens across experts can be highly skewed. vLLM provides an Expert Parallel Load Balancer (EPLB) to redistribute expert mappings across EP ranks, evening out the load across experts.

### Configuration

Enable EPLB with the `--enable-eplb` flag.

!!! note "Model Support"
    Currently only the DeepSeek V3 architecture is supported.

When enabled, vLLM collects load statistics with every forward pass and periodically rebalances the expert distribution.

### EPLB Parameters

| Parameter | Description | Default |
|-----------|-------------|---------|
| `--eplb-window-size` | Number of engine steps to track for rebalancing decisions | - |
| `--eplb-step-interval` | Frequency of rebalancing (every N engine steps) | - |
| `--eplb-log-balancedness` | Log balancedness metrics (avg tokens per expert ÷ max tokens per expert) | `false` |
| `--num-redundant-experts` | Additional global experts per EP rank beyond equal distribution | `0` |
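To illustrate the balancedness metric that `--eplb-log-balancedness` reports (average tokens per expert divided by the maximum tokens per expert), here is a small standalone sketch; the token counts are made up for illustration:

```python
def balancedness(tokens_per_expert: list[int]) -> float:
    """avg tokens per expert / max tokens per expert; 1.0 means perfectly balanced."""
    avg = sum(tokens_per_expert) / len(tokens_per_expert)
    return avg / max(tokens_per_expert)

# A heavily skewed load: one "hot" expert receives most of the tokens.
skewed = [900, 30, 25, 20, 10, 10, 3, 2]
print(round(balancedness(skewed), 3))    # 0.139 -> poorly balanced

# After rebalancing, the load is much more even.
balanced = [130, 128, 126, 125, 124, 123, 122, 122]
print(round(balancedness(balanced), 3))  # 0.962 -> close to 1.0
```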
### Expert Distribution Formula

- **Default**: Each EP rank has `NUM_TOTAL_EXPERTS ÷ NUM_EP_RANKS` experts
- **With redundancy**: Each EP rank has `(NUM_TOTAL_EXPERTS + NUM_REDUNDANT_EXPERTS) ÷ NUM_EP_RANKS` experts
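A quick worked example of these formulas. The numbers below (256 routed experts spread across 16 EP ranks) are illustrative, not tied to a specific checkpoint:

```python
def experts_per_rank(num_total_experts: int, num_ep_ranks: int,
                     num_redundant_experts: int = 0) -> int:
    """Experts hosted on each EP rank, following the formulas above."""
    return (num_total_experts + num_redundant_experts) // num_ep_ranks

# 256 experts over 16 EP ranks (e.g. 2 nodes x 8 GPUs, TP=1).
print(experts_per_rank(256, 16))      # 16 experts per rank (default)
print(experts_per_rank(256, 16, 32))  # 18 experts per rank with --num-redundant-experts 32
```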
### Example Command

Single-node deployment with EPLB enabled:

```bash
# Single node with EPLB load balancing:
#   --enable-eplb              enable the load balancer
#   --eplb-log-balancedness    log balancing metrics
#   --eplb-window-size 1000    track the last 1000 engine steps
#   --eplb-step-interval 3000  rebalance every 3000 steps
VLLM_ALL2ALL_BACKEND=pplx VLLM_USE_DEEP_GEMM=1 vllm serve deepseek-ai/DeepSeek-V3-0324 \
  --tensor-parallel-size 1 \
  --data-parallel-size 8 \
  --enable-expert-parallel \
  --enable-eplb \
  --eplb-log-balancedness \
  --eplb-window-size 1000 \
  --eplb-step-interval 3000
```

For multi-node deployment, add these EPLB flags to each node's command. In large-scale use cases, we recommend setting `--num-redundant-experts` to 32 so that the most popular experts are always available.
## Disaggregated Serving (Prefill/Decode Split)

For production deployments requiring strict SLA guarantees on time-to-first-token and inter-token latency, disaggregated serving allows independent scaling of the prefill and decode operations.

### Architecture Overview

- **Prefill Instance**: Uses the `deepep_high_throughput` backend for optimal prefill performance
- **Decode Instance**: Uses the `deepep_low_latency` backend for minimal decode latency
- **KV Cache Transfer**: Connects instances via NIXL or other KV connectors

### Setup Steps

1. **Install KV Connector**: Install NIXL using the [installation script](gh-file:tools/install_nixl.sh)

2. **Configure Both Instances**: Add this flag to both the prefill and decode instances: `--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}'`

3. **Client Orchestration**: Use the client-side script below to coordinate prefill/decode operations. We are actively working on routing solutions.
### Client Orchestration Example

```python
from openai import OpenAI
import uuid

try:
    # 1: Set up clients for the prefill and decode instances
    openai_api_key = "EMPTY"  # vLLM doesn't require a real API key

    # Replace these IP addresses with your actual instance addresses
    prefill_client = OpenAI(
        api_key=openai_api_key,
        base_url="http://192.168.1.100:8000/v1",  # Prefill instance URL
    )
    decode_client = OpenAI(
        api_key=openai_api_key,
        base_url="http://192.168.1.101:8001/v1",  # Decode instance URL
    )

    # Get the model name from the prefill instance
    models = prefill_client.models.list()
    model = models.data[0].id
    print(f"Using model: {model}")

    # 2: Prefill Phase
    # Generate a unique request ID to link the prefill and decode operations
    request_id = str(uuid.uuid4())
    print(f"Request ID: {request_id}")

    prefill_response = prefill_client.completions.create(
        model=model,
        # Prompt must exceed vLLM's block size (16 tokens) for PD to work
        prompt="Write a detailed explanation of how Paged Attention works for Transformers, including the management of the KV cache for multi-turn conversations",
        max_tokens=1,  # Force a prefill-only operation
        extra_body={
            "kv_transfer_params": {
                "do_remote_decode": True,    # Enable remote decode
                "do_remote_prefill": False,  # This is the prefill instance
                "remote_engine_id": None,    # Will be populated by vLLM
                "remote_block_ids": None,    # Will be populated by vLLM
                "remote_host": None,         # Will be populated by vLLM
                "remote_port": None          # Will be populated by vLLM
            }
        },
        extra_headers={"X-Request-Id": request_id}
    )

    print("-" * 50)
    print("✓ Prefill completed successfully")
    print(f"Prefill response: {prefill_response.choices[0].text}")

    # 3: Decode Phase
    # Transfer the KV cache parameters from the prefill to the decode instance
    decode_response = decode_client.completions.create(
        model=model,
        prompt="This prompt is ignored during decode",  # Original prompt not needed
        max_tokens=150,  # Generate up to 150 tokens
        extra_body={
            "kv_transfer_params": prefill_response.kv_transfer_params  # Pass the KV cache info
        },
        extra_headers={"X-Request-Id": request_id}  # Same request ID
    )

    print("-" * 50)
    print("✓ Decode completed successfully")
    print(f"Final response: {decode_response.choices[0].text}")

except Exception as e:
    print(f"❌ Error during disaggregated serving: {e}")
    print("Check that both the prefill and decode instances are running and accessible")
```

tests/entrypoints/llm/test_accuracy.py

Lines changed: 0 additions & 3 deletions

@@ -73,9 +73,6 @@ def test_lm_eval_accuracy_v1_engine(model, monkeypatch: pytest.MonkeyPatch):
     if current_platform.is_tpu():
         # Limit compilation time for TPU V1

-        # xet doesn't work well for both Qwen/Qwen3-1.7B and
-        # google/gemma-3-1b-it
-        m.setenv("HF_HUB_DISABLE_XET", "1")
         more_args = "max_model_len=2048,max_num_seqs=64"

     # Add TP test (if provided)

vllm/attention/backends/flashinfer.py

Lines changed: 3 additions & 7 deletions

@@ -1169,16 +1169,12 @@ def forward(
                 query=decode_query,
                 kv_cache=kv_cache.permute(*stride_order),
                 workspace_buffer=workspace_buffer,
-                num_heads=num_heads,
-                num_kv_heads=num_kv_heads,
-                scale=softmax_scale,
                 block_tables=attn_metadata.block_tables,
                 seq_lens=decode_meta.seq_lens_tensor,
-                block_size=attn_metadata.page_size,
                 max_seq_len=attn_metadata.max_decode_seq_len,
-                kv_cache_dtype=kv_cache_dtype,
-                k_scale=layer._k_scale_float,
-                v_scale=layer._v_scale_float)
+                bmm1_scale=layer._k_scale_float * softmax_scale,
+                bmm2_scale=layer._v_scale_float,
+            )

             if prefill_output is None and decode_output is not None:
                 # Decode only batch.
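The change above folds the separate `k_scale` and `scale` (softmax) arguments into a single `bmm1_scale`, and passes `v_scale` as `bmm2_scale`. Because these are scalar multipliers, scaling K (or V) before a matmul is equivalent to scaling the matmul output, as this small standalone sketch with toy tensors (not the FlashInfer API) illustrates:

```python
import torch

torch.manual_seed(0)
softmax_scale, k_scale, v_scale = 0.125, 0.5, 2.0
q = torch.randn(4, 64)   # [num_tokens, head_dim], toy shapes
k = torch.randn(16, 64)  # K that would be dequantized by k_scale
v = torch.randn(16, 64)  # V that would be dequantized by v_scale

# Old formulation: dequantize K/V with k_scale/v_scale, then apply softmax_scale.
logits_old = (q @ (k * k_scale).T) * softmax_scale
out_old = torch.softmax(logits_old, dim=-1) @ (v * v_scale)

# New formulation: fold the scales into the two batched matmuls.
bmm1_scale = k_scale * softmax_scale  # applied to the QK^T matmul
bmm2_scale = v_scale                  # applied to the attention @ V matmul
logits_new = (q @ k.T) * bmm1_scale
out_new = (torch.softmax(logits_new, dim=-1) @ v) * bmm2_scale

print(torch.allclose(out_old, out_new, atol=1e-6))  # True
```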

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 10 additions & 12 deletions

@@ -591,22 +591,20 @@ def determine_expert_map(
     if ep_size == 1:
         return (global_num_experts, None)

-    local_num_experts = global_num_experts // ep_size
+    # Distribute experts as evenly as possible to each rank.
+    base_experts = global_num_experts // ep_size
+    remainder = global_num_experts % ep_size
+    if ep_rank < remainder:
+        local_num_experts = base_experts + 1
+    else:
+        local_num_experts = base_experts

     # Create a tensor of size num_experts filled with -1
     expert_map = torch.full((global_num_experts, ), -1, dtype=torch.int32)
     # Create a expert map for the local experts
-    if ep_rank < (ep_size - 1):
-        # Each non-last rank gets local_num_experts experts.
-        expert_map[ep_rank * local_num_experts:
-                   (ep_rank + 1) * local_num_experts] = \
-            torch.arange(0, local_num_experts, dtype=torch.int32)
-    else:
-        # All remaining experts are assigned to the last rank.
-        local_num_experts = (global_num_experts - ep_rank * local_num_experts)
-
-        expert_map[-local_num_experts:] = \
-            torch.arange(0, local_num_experts, dtype=torch.int32)
+    start_idx = ep_rank * base_experts + min(ep_rank, remainder)
+    expert_map[start_idx:start_idx + local_num_experts] = torch.arange(
+        0, local_num_experts, dtype=torch.int32)
     return (local_num_experts, expert_map)
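This change distributes experts as evenly as possible across EP ranks instead of assigning the entire remainder to the last rank. A standalone sketch of the new scheme (a plain-Python reimplementation for illustration, not the vLLM function itself):

```python
def local_expert_range(global_num_experts: int, ep_size: int, ep_rank: int):
    """Return (start, count): the contiguous slice of experts owned by ep_rank."""
    base = global_num_experts // ep_size
    remainder = global_num_experts % ep_size
    count = base + 1 if ep_rank < remainder else base
    start = ep_rank * base + min(ep_rank, remainder)
    return start, count

# 10 experts over 4 EP ranks: the first two ranks take one extra expert each,
# instead of rank 3 ending up with 4 experts as under the old scheme.
for rank in range(4):
    print(rank, local_expert_range(10, 4, rank))
# 0 (0, 3)
# 1 (3, 3)
# 2 (6, 2)
# 3 (8, 2)
```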

vllm/v1/attention/backends/flashinfer.py

Lines changed: 2 additions & 7 deletions

@@ -678,15 +678,10 @@ def forward(
                 query=decode_query,
                 kv_cache=kv_cache_permute,
                 workspace_buffer=attn_metadata.workspace_buffer,
-                num_heads=self.num_heads,
-                num_kv_heads=self.num_kv_heads,
-                scale=self.scale,
                 block_tables=block_tables_decode,
                 seq_lens=seq_lens_decode,
-                block_size=attn_metadata.page_size,
                 max_seq_len=attn_metadata.max_seq_len,
-                kv_cache_dtype=self.kv_cache_dtype,
-                k_scale=layer._k_scale_float,
-                v_scale=layer._v_scale_float,
+                bmm1_scale=layer._k_scale_float * self.scale,
+                bmm2_scale=layer._v_scale_float,
             ))
         return output_padded
