@@ -295,6 +295,7 @@ def _test_backend_correctness(
     block_size: int = 16,
     atol: float = 1e-2,
     rtol: float = 1e-2,
+    tensor_parallel_size: int = 1,
 ):
     """
     Test that all backends produce similar outputs to a reference implementation
@@ -310,13 +311,38 @@ def _test_backend_correctness(
     4. Running each vLLM attention backend with the new queries and the
        simulated paged KV cache.
     5. Comparing the vLLM backend's output to the ground-truth SDPA output.
+
+    Note: When tensor_parallel_size > 1, we simulate the head partitioning
+    by overriding the model config to use fewer heads, without requiring
+    multiple GPUs. This tests that backends work correctly with different
+    head counts.
     """
     current_platform.seed_everything(42)
+
+    hf_config_override = None
+    if tensor_parallel_size > 1:
+        from vllm.config import ModelConfig
+
+        temp_config = ModelConfig(model=model, max_model_len=1)
+        original_num_heads = temp_config.hf_text_config.num_attention_heads
+        original_num_kv_heads = getattr(
+            temp_config.hf_text_config, "num_key_value_heads", None
+        )
+        hf_config_override = {
+            "num_attention_heads": original_num_heads // tensor_parallel_size,
+        }
+        if original_num_kv_heads is not None:
+            hf_config_override["num_key_value_heads"] = max(
+                1, original_num_kv_heads // tensor_parallel_size
+            )
+
     vllm_config = create_vllm_config(
         model_name=model,
+        tensor_parallel_size=1,  # Always use TP=1 to avoid multi-GPU requirements
        max_model_len=max(batch_spec.seq_lens),
         block_size=block_size,
         num_gpu_blocks=8192,
+        hf_config_override=hf_config_override,
     )
     device = torch.device("cuda:0")
 
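Aside (not part of the diff): the override above just divides the configured head counts by the simulated TP degree, clamping GQA models to at least one KV head. A minimal sketch of that arithmetic, assuming Meta-Llama-3-8B's usual 32 attention heads and 8 KV heads:

```python
# Standalone illustration of the head-partitioning override. The helper name
# and the 32/8 head counts are assumptions for the example, not part of the test file.
def simulated_tp_override(num_heads: int, num_kv_heads, tp: int) -> dict:
    override = {"num_attention_heads": num_heads // tp}
    if num_kv_heads is not None:
        # GQA/MQA models keep at least one KV head per simulated rank.
        override["num_key_value_heads"] = max(1, num_kv_heads // tp)
    return override

print(simulated_tp_override(32, 8, 4))
# -> {'num_attention_heads': 8, 'num_key_value_heads': 2}
```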
@@ -503,7 +529,10 @@ def error_msg(msg: str, backend_name: str):
     ],
 )
 @pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"])
-def test_causal_backend_correctness(batch_spec_name: str, model: str):
+@pytest.mark.parametrize("tensor_parallel_size", [1, 2, 4])
+def test_causal_backend_correctness(
+    batch_spec_name: str, model: str, tensor_parallel_size: int
+):
     """Test backend's correctness with causal attention."""
 
     def causal_mask_mod(
@@ -523,12 +552,23 @@ def causal_mask_mod(
     SMALL_BLOCK_BACKENDS = [
         x for x in BACKENDS_TO_TEST if x not in LARGE_BLOCK_BACKENDS
     ]
-    _test_backend_correctness(batch_spec, model, SMALL_BLOCK_BACKENDS, causal_mask_mod)
+    _test_backend_correctness(
+        batch_spec,
+        model,
+        SMALL_BLOCK_BACKENDS,
+        causal_mask_mod,
+        tensor_parallel_size=tensor_parallel_size,
+    )
 
     # Fast FlexAttention needs to run with block_size=128
     if LARGE_BLOCK_BACKENDS:
         _test_backend_correctness(
-            batch_spec, model, LARGE_BLOCK_BACKENDS, causal_mask_mod, block_size=128
+            batch_spec,
+            model,
+            LARGE_BLOCK_BACKENDS,
+            causal_mask_mod,
+            block_size=128,
+            tensor_parallel_size=tensor_parallel_size,
         )
 
 
@@ -545,7 +585,10 @@ def causal_mask_mod(
     ["small_decode", "small_prefill", "mixed_medium", "large_decode", "large_prefill"],
 )
 @pytest.mark.parametrize("model", ["microsoft/Phi-tiny-MoE-instruct"])
-def test_sliding_window_backend_correctness(batch_spec_name: str, model: str):
+@pytest.mark.parametrize("tensor_parallel_size", [1, 2, 4])
+def test_sliding_window_backend_correctness(
+    batch_spec_name: str, model: str, tensor_parallel_size: int
+):
     """Test backend's correctness with sliding window attention."""
 
     def sliding_window_mask_mod(
@@ -575,7 +618,11 @@ def sliding_window_mask_mod(
         x for x in SLIDING_WINDOW_BACKENDS_TO_TEST if x not in LARGE_BLOCK_BACKENDS
     ]
     _test_backend_correctness(
-        batch_spec, model, SMALL_BLOCK_BACKENDS, sliding_window_mask_mod_fn
+        batch_spec,
+        model,
+        SMALL_BLOCK_BACKENDS,
+        sliding_window_mask_mod_fn,
+        tensor_parallel_size=tensor_parallel_size,
     )
 
     # Fast FlexAttention needs to run with block_size=128
@@ -586,4 +633,5 @@ def sliding_window_mask_mod(
         LARGE_BLOCK_BACKENDS,
         sliding_window_mask_mod_fn,
         block_size=128,
+        tensor_parallel_size=tensor_parallel_size,
     )
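For context (not part of the diff): stacking the new `tensor_parallel_size` parametrize on top of the existing `batch_spec_name` and `model` ones makes pytest collect the cross product of the parameters, so every batch spec is exercised with simulated TP of 1, 2, and 4 while still needing only a single GPU. A rough sketch of the expansion; the printed ids are illustrative stand-ins, not pytest's exact test ids:

```python
# Illustrative only: shows how the stacked parametrize decorators multiply test cases.
from itertools import product

batch_specs = ["small_decode", "small_prefill", "mixed_medium",
               "large_decode", "large_prefill"]
tp_sizes = [1, 2, 4]

for spec, tp in product(batch_specs, tp_sizes):
    print(f"test_sliding_window_backend_correctness[{spec}-{tp}]")
# 5 batch specs x 3 TP sizes -> 15 collected cases per model.
```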