3 changes: 1 addition & 2 deletions .buildkite/test-pipeline.yaml
@@ -347,8 +347,7 @@ steps:
- vllm/v1/attention
- tests/v1/attention
commands:
- export VLLM_DISABLE_FLASHINFER_PREFILL=1 # TODO: FI prefill is bugged and causes incorrectness, fix this
- pytest -v -s v1/attention
- VLLM_DISABLE_FLASHINFER_PREFILL=1 pytest -v -s v1/attention # TODO: FI prefill is bugged and causes incorrectness, fix this

- label: V1 Test others (CPU) # 5 mins
source_file_dependencies:
58 changes: 53 additions & 5 deletions tests/v1/attention/test_attention_backends.py
@@ -295,6 +295,7 @@ def _test_backend_correctness(
block_size: int = 16,
atol: float = 1e-2,
rtol: float = 1e-2,
tensor_parallel_size: int = 1,
):
"""
Test that all backends produce similar outputs to a reference implementation
@@ -310,13 +311,38 @@ def _test_backend_correctness(
4. Running each vLLM attention backend with the new queries and the
simulated paged KV cache.
5. Comparing the vLLM backend's output to the ground-truth SDPA output.

    Note: When tensor_parallel_size > 1, we simulate tensor-parallel head
    partitioning by overriding the model config to use proportionally fewer
    heads, so no extra GPUs are required. This exercises each backend with
    the reduced per-rank head counts it would see under real TP.
"""
current_platform.seed_everything(42)

hf_config_override = None
if tensor_parallel_size > 1:
from vllm.config import ModelConfig

temp_config = ModelConfig(model=model, max_model_len=1)
original_num_heads = temp_config.hf_text_config.num_attention_heads
original_num_kv_heads = getattr(
temp_config.hf_text_config, "num_key_value_heads", None
)
hf_config_override = {
"num_attention_heads": original_num_heads // tensor_parallel_size,
}
if original_num_kv_heads is not None:
hf_config_override["num_key_value_heads"] = max(
1, original_num_kv_heads // tensor_parallel_size
)

vllm_config = create_vllm_config(
model_name=model,
tensor_parallel_size=1, # Always use TP=1 to avoid multi-GPU requirements
max_model_len=max(batch_spec.seq_lens),
block_size=block_size,
num_gpu_blocks=8192,
hf_config_override=hf_config_override,
)
device = torch.device("cuda:0")

@@ -503,7 +529,10 @@ def error_msg(msg: str, backend_name: str):
],
)
@pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"])
def test_causal_backend_correctness(batch_spec_name: str, model: str):
@pytest.mark.parametrize("tensor_parallel_size", [1, 2, 4])
def test_causal_backend_correctness(
batch_spec_name: str, model: str, tensor_parallel_size: int
):
"""Test backend's correctness with causal attention."""

def causal_mask_mod(
@@ -523,12 +552,23 @@ def causal_mask_mod(
SMALL_BLOCK_BACKENDS = [
x for x in BACKENDS_TO_TEST if x not in LARGE_BLOCK_BACKENDS
]
_test_backend_correctness(batch_spec, model, SMALL_BLOCK_BACKENDS, causal_mask_mod)
_test_backend_correctness(
batch_spec,
model,
SMALL_BLOCK_BACKENDS,
causal_mask_mod,
tensor_parallel_size=tensor_parallel_size,
)

# Fast FlexAttention needs to run with block_size=128
if LARGE_BLOCK_BACKENDS:
_test_backend_correctness(
batch_spec, model, LARGE_BLOCK_BACKENDS, causal_mask_mod, block_size=128
batch_spec,
model,
LARGE_BLOCK_BACKENDS,
causal_mask_mod,
block_size=128,
tensor_parallel_size=tensor_parallel_size,
)


@@ -545,7 +585,10 @@ def causal_mask_mod(
["small_decode", "small_prefill", "mixed_medium", "large_decode", "large_prefill"],
)
@pytest.mark.parametrize("model", ["microsoft/Phi-tiny-MoE-instruct"])
def test_sliding_window_backend_correctness(batch_spec_name: str, model: str):
@pytest.mark.parametrize("tensor_parallel_size", [1, 2, 4])
def test_sliding_window_backend_correctness(
batch_spec_name: str, model: str, tensor_parallel_size: int
):
"""Test backend's correctness with sliding window attention."""

def sliding_window_mask_mod(
@@ -575,7 +618,11 @@ def sliding_window_mask_mod(
x for x in SLIDING_WINDOW_BACKENDS_TO_TEST if x not in LARGE_BLOCK_BACKENDS
]
_test_backend_correctness(
batch_spec, model, SMALL_BLOCK_BACKENDS, sliding_window_mask_mod_fn
batch_spec,
model,
SMALL_BLOCK_BACKENDS,
sliding_window_mask_mod_fn,
tensor_parallel_size=tensor_parallel_size,
)

# Fast FlexAttention needs to run with block_size=128
@@ -586,4 +633,5 @@ def sliding_window_mask_mod(
LARGE_BLOCK_BACKENDS,
sliding_window_mask_mod_fn,
block_size=128,
tensor_parallel_size=tensor_parallel_size,
)
31 changes: 29 additions & 2 deletions tests/v1/attention/test_mla_backends.py
@@ -394,8 +394,11 @@ def run_attention_backend(
"spec_decode_medium",
],
)
@pytest.mark.parametrize("model", ["deepseek-ai/DeepSeek-V2-Lite-Chat"])
def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
@pytest.mark.parametrize("model", ["deepseek-ai/DeepSeek-R1"])
@pytest.mark.parametrize("tensor_parallel_size", [1, 4, 8, 16])
def test_backend_correctness(
dist_init, batch_spec_name: str, model: str, tensor_parallel_size: int
):
"""
Test that all backends produce similar outputs to a reference implementation
using torch.nn.functional.scaled_dot_product_attention.
@@ -410,6 +413,11 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
4. Running each vLLM attention backend with the new queries and the
simulated paged KV cache.
5. Comparing the vLLM backend's output to the ground-truth SDPA output.

    Note: When tensor_parallel_size > 1, we simulate tensor-parallel head
    partitioning by overriding the model config to use proportionally fewer
    heads, so no extra GPUs are required. This exercises each backend with
    the reduced per-rank head counts it would see under real TP.
"""

batch_spec = BATCH_SPECS[batch_spec_name]
@@ -423,11 +431,30 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
# Add 1 for null block at index 0, and some buffer
num_gpu_blocks = required_blocks + 1 + 100

hf_config_override = None
if tensor_parallel_size > 1:
from vllm.config import ModelConfig

temp_config = ModelConfig(model=model, max_model_len=1)
original_num_heads = temp_config.hf_text_config.num_attention_heads
original_num_kv_heads = getattr(
temp_config.hf_text_config, "num_key_value_heads", None
)
hf_config_override = {
"num_attention_heads": original_num_heads // tensor_parallel_size,
}
if original_num_kv_heads is not None:
hf_config_override["num_key_value_heads"] = max(
1, original_num_kv_heads // tensor_parallel_size
)

vllm_config = create_vllm_config(
model_name=model,
tensor_parallel_size=1, # Always use TP=1 to avoid multi-GPU requirements
max_model_len=max(batch_spec.seq_lens),
num_gpu_blocks=num_gpu_blocks,
block_size=default_block_size,
hf_config_override=hf_config_override,
)

# For spec decode tests, add a speculative_config to set the reorder_batch_threshold
11 changes: 9 additions & 2 deletions tests/v1/attention/test_sparse_mla_backends.py
@@ -113,7 +113,10 @@ def _quantize_dequantize_fp8_ds_mla(

@pytest.mark.parametrize("batch_name", list(SPARSE_BACKEND_BATCH_SPECS.keys()))
@pytest.mark.parametrize("kv_cache_dtype", ["fp8_ds_mla", "auto"])
def test_sparse_backend_decode_correctness(dist_init, batch_name, kv_cache_dtype):
@pytest.mark.parametrize("tensor_parallel_size", [1, 2, 4])
def test_sparse_backend_decode_correctness(
dist_init, batch_name, kv_cache_dtype, tensor_parallel_size
):
if not torch.cuda.is_available():
pytest.skip("CUDA is required for sparse MLA decode test")

@@ -135,8 +138,11 @@ def test_sparse_backend_decode_correctness(dist_init, batch_name, kv_cache_dtype
total_cache_tokens = sum(batch_spec.seq_lens)
block_size = 64

# Note: We use TP=1 to avoid multi-GPU requirements in CI.
# The test simulates head partitioning via mocked methods below.
vllm_config = create_vllm_config(
model_name="deepseek-ai/DeepSeek-V2-Lite-Chat",
tensor_parallel_size=1,
max_model_len=max_seqlen,
num_gpu_blocks=max(2048, cdiv(total_cache_tokens, block_size) + 1),
block_size=block_size,
@@ -156,7 +162,8 @@ def test_sparse_backend_decode_correctness(dist_init, batch_name, kv_cache_dtype
)
model_config.dtype = dtype
model_config.get_num_attention_heads = MethodType(
lambda self, parallel_config: num_heads, model_config
lambda self, parallel_config: max(1, num_heads // tensor_parallel_size),
model_config,
)
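    # For illustration, with hypothetical values num_heads=16 and
    # tensor_parallel_size=4 the patched method above reports 4 query heads
    # per simulated rank, clamping to 1 if the TP degree exceeds the head
    # count; get_num_kv_heads below stays pinned to 1, matching MLA's single
    # latent KV head.
    assert max(1, 16 // 4) == 4
    assert max(1, 16 // 32) == 1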
model_config.get_num_kv_heads = MethodType(
lambda self, parallel_config: 1, model_config