diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index a020b0d276be..f823ddd128dd 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -347,8 +347,7 @@ steps:
     - vllm/v1/attention
     - tests/v1/attention
   commands:
-    - export VLLM_DISABLE_FLASHINFER_PREFILL=1 # TODO: FI prefill is bugged and causes incorrectness, fix this
-    - pytest -v -s v1/attention
+    - VLLM_DISABLE_FLASHINFER_PREFILL=1 pytest -v -s v1/attention # TODO: FI prefill is bugged and causes incorrectness, fix this
 
 - label: V1 Test others (CPU) # 5 mins
   source_file_dependencies:
diff --git a/tests/v1/attention/test_attention_backends.py b/tests/v1/attention/test_attention_backends.py
index 6659b3eb1e98..08aeb6f298f6 100644
--- a/tests/v1/attention/test_attention_backends.py
+++ b/tests/v1/attention/test_attention_backends.py
@@ -295,6 +295,7 @@ def _test_backend_correctness(
     block_size: int = 16,
     atol: float = 1e-2,
     rtol: float = 1e-2,
+    tensor_parallel_size: int = 1,
 ):
     """
     Test that all backends produce similar outputs to a reference implementation
@@ -310,13 +311,38 @@ def _test_backend_correctness(
     4. Running each vLLM attention backend with the new queries and the
        simulated paged KV cache.
     5. Comparing the vLLM backend's output to the ground-truth SDPA output.
+
+    Note: When tensor_parallel_size > 1, we simulate the head partitioning
+    by overriding the model config to use fewer heads, without requiring
+    multiple GPUs. This tests that backends work correctly with different
+    head counts.
     """
     current_platform.seed_everything(42)
+
+    hf_config_override = None
+    if tensor_parallel_size > 1:
+        from vllm.config import ModelConfig
+
+        temp_config = ModelConfig(model=model, max_model_len=1)
+        original_num_heads = temp_config.hf_text_config.num_attention_heads
+        original_num_kv_heads = getattr(
+            temp_config.hf_text_config, "num_key_value_heads", None
+        )
+        hf_config_override = {
+            "num_attention_heads": original_num_heads // tensor_parallel_size,
+        }
+        if original_num_kv_heads is not None:
+            hf_config_override["num_key_value_heads"] = max(
+                1, original_num_kv_heads // tensor_parallel_size
+            )
+
     vllm_config = create_vllm_config(
         model_name=model,
+        tensor_parallel_size=1,  # Always use TP=1 to avoid multi-GPU requirements
         max_model_len=max(batch_spec.seq_lens),
         block_size=block_size,
         num_gpu_blocks=8192,
+        hf_config_override=hf_config_override,
     )
     device = torch.device("cuda:0")
@@ -503,7 +529,10 @@ def error_msg(msg: str, backend_name: str):
     ],
 )
 @pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"])
-def test_causal_backend_correctness(batch_spec_name: str, model: str):
+@pytest.mark.parametrize("tensor_parallel_size", [1, 2, 4])
+def test_causal_backend_correctness(
+    batch_spec_name: str, model: str, tensor_parallel_size: int
+):
     """Test backend's correctness with causal attention."""
 
     def causal_mask_mod(
@@ -523,12 +552,23 @@ def causal_mask_mod(
     SMALL_BLOCK_BACKENDS = [
         x for x in BACKENDS_TO_TEST if x not in LARGE_BLOCK_BACKENDS
     ]
-    _test_backend_correctness(batch_spec, model, SMALL_BLOCK_BACKENDS, causal_mask_mod)
+    _test_backend_correctness(
+        batch_spec,
+        model,
+        SMALL_BLOCK_BACKENDS,
+        causal_mask_mod,
+        tensor_parallel_size=tensor_parallel_size,
+    )
 
     # Fast FlexAttention needs to run with block_size=128
     if LARGE_BLOCK_BACKENDS:
         _test_backend_correctness(
-            batch_spec, model, LARGE_BLOCK_BACKENDS, causal_mask_mod, block_size=128
+            batch_spec,
+            model,
+            LARGE_BLOCK_BACKENDS,
+            causal_mask_mod,
+            block_size=128,
+            tensor_parallel_size=tensor_parallel_size,
         )
 
 
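Aside (not part of the patch): the hf_config_override added above simulates tensor-parallel head partitioning purely through the HF config, on a single GPU. The standalone sketch below shows only that arithmetic; simulate_tp_heads is a hypothetical helper introduced here for illustration, and the clamp to at least one KV head mirrors the max(1, ...) in the patch rather than any particular vLLM API.

    # Hypothetical helper mirroring the hf_config_override computation in the patch.
    def simulate_tp_heads(num_attention_heads, num_key_value_heads, tensor_parallel_size):
        override = {"num_attention_heads": num_attention_heads // tensor_parallel_size}
        if num_key_value_heads is not None:
            # Clamp to 1 so GQA/MQA models stay valid when TP exceeds the KV head count.
            override["num_key_value_heads"] = max(
                1, num_key_value_heads // tensor_parallel_size
            )
        return override

    # Meta-Llama-3-8B has 32 query heads and 8 KV heads, so a simulated TP=4
    # "rank" sees 8 query heads and 2 KV heads.
    assert simulate_tp_heads(32, 8, 4) == {
        "num_attention_heads": 8,
        "num_key_value_heads": 2,
    }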
@@ -545,7 +585,10 @@ def causal_mask_mod(
     ["small_decode", "small_prefill", "mixed_medium", "large_decode", "large_prefill"],
 )
 @pytest.mark.parametrize("model", ["microsoft/Phi-tiny-MoE-instruct"])
-def test_sliding_window_backend_correctness(batch_spec_name: str, model: str):
+@pytest.mark.parametrize("tensor_parallel_size", [1, 2, 4])
+def test_sliding_window_backend_correctness(
+    batch_spec_name: str, model: str, tensor_parallel_size: int
+):
     """Test backend's correctness with sliding window attention."""
 
     def sliding_window_mask_mod(
@@ -575,7 +618,11 @@ def sliding_window_mask_mod(
         x for x in SLIDING_WINDOW_BACKENDS_TO_TEST if x not in LARGE_BLOCK_BACKENDS
     ]
     _test_backend_correctness(
-        batch_spec, model, SMALL_BLOCK_BACKENDS, sliding_window_mask_mod_fn
+        batch_spec,
+        model,
+        SMALL_BLOCK_BACKENDS,
+        sliding_window_mask_mod_fn,
+        tensor_parallel_size=tensor_parallel_size,
     )
 
     # Fast FlexAttention needs to run with block_size=128
@@ -586,4 +633,5 @@ def sliding_window_mask_mod(
             LARGE_BLOCK_BACKENDS,
             sliding_window_mask_mod_fn,
             block_size=128,
+            tensor_parallel_size=tensor_parallel_size,
         )
diff --git a/tests/v1/attention/test_mla_backends.py b/tests/v1/attention/test_mla_backends.py
index cda4fb11c096..5679fafe63ee 100644
--- a/tests/v1/attention/test_mla_backends.py
+++ b/tests/v1/attention/test_mla_backends.py
@@ -394,8 +394,11 @@ def run_attention_backend(
         "spec_decode_medium",
     ],
 )
-@pytest.mark.parametrize("model", ["deepseek-ai/DeepSeek-V2-Lite-Chat"])
-def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
+@pytest.mark.parametrize("model", ["deepseek-ai/DeepSeek-R1"])
+@pytest.mark.parametrize("tensor_parallel_size", [1, 4, 8, 16])
+def test_backend_correctness(
+    dist_init, batch_spec_name: str, model: str, tensor_parallel_size: int
+):
     """
     Test that all backends produce similar outputs to a reference implementation
     using torch.nn.functional.scaled_dot_product_attention.
@@ -410,6 +413,11 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
     4. Running each vLLM attention backend with the new queries and the
        simulated paged KV cache.
     5. Comparing the vLLM backend's output to the ground-truth SDPA output.
+
+    Note: When tensor_parallel_size > 1, we simulate the head partitioning
+    by overriding the model config to use fewer heads, without requiring
+    multiple GPUs. This tests that backends work correctly with different
+    head counts.
""" batch_spec = BATCH_SPECS[batch_spec_name] @@ -423,11 +431,30 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): # Add 1 for null block at index 0, and some buffer num_gpu_blocks = required_blocks + 1 + 100 + hf_config_override = None + if tensor_parallel_size > 1: + from vllm.config import ModelConfig + + temp_config = ModelConfig(model=model, max_model_len=1) + original_num_heads = temp_config.hf_text_config.num_attention_heads + original_num_kv_heads = getattr( + temp_config.hf_text_config, "num_key_value_heads", None + ) + hf_config_override = { + "num_attention_heads": original_num_heads // tensor_parallel_size, + } + if original_num_kv_heads is not None: + hf_config_override["num_key_value_heads"] = max( + 1, original_num_kv_heads // tensor_parallel_size + ) + vllm_config = create_vllm_config( model_name=model, + tensor_parallel_size=1, # Always use TP=1 to avoid multi-GPU requirements max_model_len=max(batch_spec.seq_lens), num_gpu_blocks=num_gpu_blocks, block_size=default_block_size, + hf_config_override=hf_config_override, ) # For spec decode tests, add a speculative_config to set the reorder_batch_threshold diff --git a/tests/v1/attention/test_sparse_mla_backends.py b/tests/v1/attention/test_sparse_mla_backends.py index 02324d2aca6e..b34d587eb362 100644 --- a/tests/v1/attention/test_sparse_mla_backends.py +++ b/tests/v1/attention/test_sparse_mla_backends.py @@ -113,7 +113,10 @@ def _quantize_dequantize_fp8_ds_mla( @pytest.mark.parametrize("batch_name", list(SPARSE_BACKEND_BATCH_SPECS.keys())) @pytest.mark.parametrize("kv_cache_dtype", ["fp8_ds_mla", "auto"]) -def test_sparse_backend_decode_correctness(dist_init, batch_name, kv_cache_dtype): +@pytest.mark.parametrize("tensor_parallel_size", [1, 2, 4]) +def test_sparse_backend_decode_correctness( + dist_init, batch_name, kv_cache_dtype, tensor_parallel_size +): if not torch.cuda.is_available(): pytest.skip("CUDA is required for sparse MLA decode test") @@ -135,8 +138,11 @@ def test_sparse_backend_decode_correctness(dist_init, batch_name, kv_cache_dtype total_cache_tokens = sum(batch_spec.seq_lens) block_size = 64 + # Note: We use TP=1 to avoid multi-GPU requirements in CI. + # The test simulates head partitioning via mocked methods below. vllm_config = create_vllm_config( model_name="deepseek-ai/DeepSeek-V2-Lite-Chat", + tensor_parallel_size=1, max_model_len=max_seqlen, num_gpu_blocks=max(2048, cdiv(total_cache_tokens, block_size) + 1), block_size=block_size, @@ -156,7 +162,8 @@ def test_sparse_backend_decode_correctness(dist_init, batch_name, kv_cache_dtype ) model_config.dtype = dtype model_config.get_num_attention_heads = MethodType( - lambda self, parallel_config: num_heads, model_config + lambda self, parallel_config: max(1, num_heads // tensor_parallel_size), + model_config, ) model_config.get_num_kv_heads = MethodType( lambda self, parallel_config: 1, model_config