Skip to content
34 changes: 29 additions & 5 deletions tests/v1/attention/test_attention_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,7 @@ def _test_backend_correctness(
block_size: int = 16,
atol: float = 1e-2,
rtol: float = 1e-2,
tensor_parallel_size: int = 1,
):
"""
Test that all backends produce similar outputs to a reference implementation
Expand All @@ -314,6 +315,7 @@ def _test_backend_correctness(
current_platform.seed_everything(42)
vllm_config = create_vllm_config(
model_name=model,
tensor_parallel_size=tensor_parallel_size,
max_model_len=max(batch_spec.seq_lens),
block_size=block_size,
num_gpu_blocks=8192,
Expand Down Expand Up @@ -503,7 +505,10 @@ def error_msg(msg: str, backend_name: str):
],
)
@pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"])
def test_causal_backend_correctness(batch_spec_name: str, model: str):
@pytest.mark.parametrize("tensor_parallel_size", [1, 2, 4])
def test_causal_backend_correctness(
batch_spec_name: str, model: str, tensor_parallel_size: int
):
"""Test backend's correctness with causal attention."""

def causal_mask_mod(
Expand All @@ -523,12 +528,23 @@ def causal_mask_mod(
SMALL_BLOCK_BACKENDS = [
x for x in BACKENDS_TO_TEST if x not in LARGE_BLOCK_BACKENDS
]
_test_backend_correctness(batch_spec, model, SMALL_BLOCK_BACKENDS, causal_mask_mod)
_test_backend_correctness(
batch_spec,
model,
SMALL_BLOCK_BACKENDS,
causal_mask_mod,
tensor_parallel_size=tensor_parallel_size,
)

# Fast FlexAttention needs to run with block_size=128
if LARGE_BLOCK_BACKENDS:
_test_backend_correctness(
batch_spec, model, LARGE_BLOCK_BACKENDS, causal_mask_mod, block_size=128
batch_spec,
model,
LARGE_BLOCK_BACKENDS,
causal_mask_mod,
block_size=128,
tensor_parallel_size=tensor_parallel_size,
)


Expand All @@ -545,7 +561,10 @@ def causal_mask_mod(
["small_decode", "small_prefill", "mixed_medium", "large_decode", "large_prefill"],
)
@pytest.mark.parametrize("model", ["microsoft/Phi-tiny-MoE-instruct"])
def test_sliding_window_backend_correctness(batch_spec_name: str, model: str):
@pytest.mark.parametrize("tensor_parallel_size", [1, 2, 4])
def test_sliding_window_backend_correctness(
batch_spec_name: str, model: str, tensor_parallel_size: int
):
"""Test backend's correctness with sliding window attention."""

def sliding_window_mask_mod(
Expand Down Expand Up @@ -575,7 +594,11 @@ def sliding_window_mask_mod(
x for x in SLIDING_WINDOW_BACKENDS_TO_TEST if x not in LARGE_BLOCK_BACKENDS
]
_test_backend_correctness(
batch_spec, model, SMALL_BLOCK_BACKENDS, sliding_window_mask_mod_fn
batch_spec,
model,
SMALL_BLOCK_BACKENDS,
sliding_window_mask_mod_fn,
tensor_parallel_size=tensor_parallel_size,
)

# Fast FlexAttention needs to run with block_size=128
Expand All @@ -586,4 +609,5 @@ def sliding_window_mask_mod(
LARGE_BLOCK_BACKENDS,
sliding_window_mask_mod_fn,
block_size=128,
tensor_parallel_size=tensor_parallel_size,
)
6 changes: 5 additions & 1 deletion tests/v1/attention/test_mla_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,10 @@ def run_attention_backend(
],
)
@pytest.mark.parametrize("model", ["deepseek-ai/DeepSeek-V2-Lite-Chat"])
def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
@pytest.mark.parametrize("tensor_parallel_size", [1, 2, 4])
def test_backend_correctness(
dist_init, batch_spec_name: str, model: str, tensor_parallel_size: int
):
"""
Test that all backends produce similar outputs to a reference implementation
using torch.nn.functional.scaled_dot_product_attention.
Expand Down Expand Up @@ -368,6 +371,7 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):

vllm_config = create_vllm_config(
model_name=model,
tensor_parallel_size=tensor_parallel_size,
max_model_len=max(batch_spec.seq_lens),
num_gpu_blocks=num_gpu_blocks,
block_size=block_size,
Expand Down
9 changes: 7 additions & 2 deletions tests/v1/attention/test_sparse_mla_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,10 @@ def _quantize_dequantize_fp8_ds_mla(

@pytest.mark.parametrize("batch_name", list(SPARSE_BACKEND_BATCH_SPECS.keys()))
@pytest.mark.parametrize("kv_cache_dtype", ["fp8_ds_mla", "auto"])
def test_sparse_backend_decode_correctness(dist_init, batch_name, kv_cache_dtype):
@pytest.mark.parametrize("tensor_parallel_size", [1, 2, 4])
def test_sparse_backend_decode_correctness(
dist_init, batch_name, kv_cache_dtype, tensor_parallel_size
):
if not torch.cuda.is_available():
pytest.skip("CUDA is required for sparse MLA decode test")

Expand All @@ -137,6 +140,7 @@ def test_sparse_backend_decode_correctness(dist_init, batch_name, kv_cache_dtype

vllm_config = create_vllm_config(
model_name="deepseek-ai/DeepSeek-V2-Lite-Chat",
tensor_parallel_size=tensor_parallel_size,
max_model_len=max_seqlen,
num_gpu_blocks=max(2048, cdiv(total_cache_tokens, block_size) + 1),
block_size=block_size,
Expand All @@ -156,7 +160,8 @@ def test_sparse_backend_decode_correctness(dist_init, batch_name, kv_cache_dtype
)
model_config.dtype = dtype
model_config.get_num_attention_heads = MethodType(
lambda self, parallel_config: num_heads, model_config
lambda self, parallel_config: num_heads // parallel_config.tensor_parallel_size,
model_config,
)
model_config.get_num_kv_heads = MethodType(
lambda self, parallel_config: 1, model_config
Expand Down