Skip to content

Commit 4bc400f

Browse files
[CI/Testing] Add basic single node dual batch overlap test (#27235)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
1 parent cac4c10 commit 4bc400f

File tree

2 files changed

+91
-0
lines changed

2 files changed

+91
-0
lines changed

.buildkite/test-pipeline.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1223,6 +1223,7 @@ steps:
12231223
- pytest -v -s tests/compile/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm
12241224
- pytest -v -s tests/distributed/test_context_parallel.py
12251225
- CUDA_VISIBLE_DEVICES=1,2 VLLM_ALL2ALL_BACKEND=deepep_high_throughput VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model Qwen/Qwen1.5-MoE-A2.7B --tp-size=1 --dp-size=2 --max-model-len 2048
1226+
- pytest -v -s tests/v1/distributed/test_dbo.py
12261227

12271228
##### B200 test #####
12281229
- label: Distributed Tests (B200) # optional
@@ -1233,6 +1234,7 @@ steps:
12331234
commands:
12341235
- pytest -v -s tests/distributed/test_context_parallel.py
12351236
- pytest -v -s tests/distributed/test_nccl_symm_mem_allreduce.py
1237+
- pytest -v -s tests/v1/distributed/test_dbo.py
12361238

12371239
##### RL Integration Tests #####
12381240
- label: Prime-RL Integration Test # 15min

tests/v1/distributed/test_dbo.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
"""
4+
Test Dual Batch Overlap (DBO) with Data Parallelism + Expert Parallelism.
5+
6+
DBO is specifically designed for DP+EP scenarios to hide communication latency
7+
by overlapping computation of two batches. This test validates that DBO works
8+
correctly with the DeepSeek-V2-Lite-Chat model using GSM8K evaluation.
9+
"""
10+
11+
import pytest
12+
13+
from tests.evals.gsm8k.gsm8k_eval import evaluate_gsm8k
14+
from tests.utils import RemoteOpenAIServer
15+
16+
MODEL_NAME = "deepseek-ai/DeepSeek-V2-Lite-Chat"
17+
DP_SIZE = 2
18+
19+
# GSM8K eval configuration
20+
NUM_QUESTIONS = 256 # Fast eval for CI, but must be large enough to hit the DBO thresholds
21+
NUM_SHOTS = 5 # Few-shot examples
22+
MIN_ACCURACY = 0.62 # Expected 0.64 with 2% buffer (based on vLLM test data)
23+
24+
# Increase max_num_seqs to trigger DBO for decode batches
25+
# With 64 seqs, decode batches should exceed the decode token threshold (set to 16 via --dbo-decode-token-threshold below)
26+
MAX_NUM_SEQS = 64 # Increased from 16 to trigger decode DBO
27+
28+
# DeepEP backends to test
29+
DEEPEP_BACKENDS = [
30+
"deepep_low_latency",
31+
"deepep_high_throughput",
32+
]
33+
34+
35+
@pytest.mark.parametrize("all2all_backend", DEEPEP_BACKENDS)
36+
def test_dbo_dp_ep_gsm8k(all2all_backend: str, num_gpus_available):
37+
"""
38+
Test DBO with DP+EP using GSM8K evaluation.
39+
"""
40+
required_gpus = DP_SIZE
41+
42+
if num_gpus_available < required_gpus:
43+
pytest.skip(f"Need at least {required_gpus} GPUs (DP={DP_SIZE})")
44+
45+
# Server arguments for DBO + DP + EP
46+
server_args = [
47+
"--max-model-len",
48+
"4096",
49+
"--max-num-seqs",
50+
str(MAX_NUM_SEQS), # Use larger batch to trigger decode DBO
51+
"--trust-remote-code",
52+
# Note: --enforce-eager is deliberately omitted so that DBO's alternate CUDA graph dispatching is exercised
53+
"--data-parallel-size",
54+
str(DP_SIZE),
55+
"--enable-expert-parallel",
56+
"--enable-dbo",
57+
# Pin the decode and prefill thresholds to fixed values so we know the test triggers DBO
58+
"--dbo-decode-token-threshold",
59+
"16",
60+
"--dbo-prefill-token-threshold",
61+
"256",
62+
"--all2all-backend",
63+
all2all_backend,
64+
]
65+
66+
with RemoteOpenAIServer(
67+
MODEL_NAME,
68+
server_args,
69+
max_wait_seconds=600, # Allow time for model loading with DP+EP
70+
) as remote_server:
71+
# Use host and port directly from RemoteOpenAIServer
72+
host = f"http://{remote_server.host}"
73+
port = remote_server.port
74+
75+
# Run GSM8K evaluation
76+
results = evaluate_gsm8k(
77+
num_questions=NUM_QUESTIONS,
78+
num_shots=NUM_SHOTS,
79+
host=host,
80+
port=port,
81+
)
82+
83+
# Validate accuracy is reasonable
84+
accuracy = results["accuracy"]
85+
assert accuracy >= MIN_ACCURACY, (
86+
f"DBO+DP+EP accuracy too low ({all2all_backend}): "
87+
f"{accuracy:.3f} < {MIN_ACCURACY:.3f} "
88+
f"(correct: {results['num_correct']}/{results['num_questions']})"
89+
)

0 commit comments

Comments
 (0)