
Conversation


@PaulZhang12 PaulZhang12 commented Oct 31, 2025

Purpose

This PR resolves batch invariance issues with torch.compile + CUDA graphs on B200. Specifically, trtllm_attention on B200 with CUDA graphs breaks batch invariance.

We also extend all the unit tests to use torch.compile when evaluating batch invariance, and disable GEMM custom operator overriding on PyTorch 2.10+, which now ships the batch invariant CUDA overrides upstream (e.g. pytorch/pytorch#166735). On PyTorch 2.9, B200 still requires the custom Triton GEMM kernels.
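As a rough illustration of the version gate described above (a minimal sketch, not vLLM's actual code; the helper names here are assumptions for illustration):

```python
# Hedged sketch: gate custom Triton GEMM overrides on the PyTorch version.
# Function names are illustrative, not vLLM's real API.

def _parse_version(v: str) -> tuple[int, int]:
    # Extract (major, minor) from a version string like "2.9.0+cu128".
    major, minor = v.split(".")[:2]
    return int(major), int(minor)

def should_override_gemm(torch_version: str) -> bool:
    # PyTorch 2.10+ ships batch invariant CUDA GEMM overrides upstream
    # (pytorch/pytorch#166735); the custom Triton GEMM kernels remain
    # necessary only on 2.9 and earlier (e.g. for B200).
    return _parse_version(torch_version) < (2, 10)

print(should_override_gemm("2.9.0+cu128"))  # keep Triton overrides
print(should_override_gemm("2.10.0"))       # upstream handles it
```

Comparing numeric (major, minor) tuples rather than raw strings matters here: lexicographically, "2.10" sorts before "2.9", which is exactly the trap a 2.10 version gate must avoid.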

Test Plan

The following tests are run on both H100 and B200.

pytest tests/v1/generation/test_batch_invariance.py
VLLM_TEST_MODEL="deepseek-ai/DeepSeek-V3" VLLM_TEST_TP_SIZE=8 VLLM_ATTENTION_BACKEND="FLASH_ATTN_MLA" pytest tests/v1/generation/test_batch_invariance.py -k "test_logprobs_bitwise_batch_invariance_bs1_vs_bsN[FLASH_ATTN_MLA]"
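The bs=1 vs bs=N check these tests perform can be sketched in a vLLM-free toy (run_model below is a deterministic stand-in, not vLLM's API): run the same prompts as one batch and as N single-prompt batches, and require bitwise-identical logprobs.

```python
# Toy sketch of a bitwise batch-invariance check. `run_model` is an
# illustrative stand-in for a model forward pass, not vLLM's API.
import math
import struct

def run_model(prompts: list[str]) -> list[float]:
    # Returns one "logprob" per prompt. A batch-invariant implementation
    # must produce identical bits for a prompt regardless of which batch
    # it appears in.
    return [math.log(1.0 + len(p)) for p in prompts]

def float_bits(x: float) -> int:
    # Reinterpret the float64 as its raw bits: stricter than `==`,
    # which would treat +0.0 and -0.0 as equal.
    return struct.unpack("<Q", struct.pack("<d", x))[0]

prompts = ["a", "bb", "ccc"]
bs_n = run_model(prompts)                    # one batch of N
bs_1 = [run_model([p])[0] for p in prompts]  # N batches of 1
assert all(float_bits(a) == float_bits(b) for a, b in zip(bs_n, bs_1))
print("bitwise batch-invariant")
```

In the real tests the forward pass runs under torch.compile with CUDA graphs, which is exactly where the B200 trtllm_attention path previously diverged.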

Test Result

Performance

VLLM_BATCH_INVARIANT=1 VLLM_ATTENTION_BACKEND="TRITON_MLA" vllm serve deepseek-ai/DeepSeek-R1 -tp 8 --enable-expert-parallel --port 9256 --gpu_memory_utilization 0.95 --max_model_len 40960

vllm bench serve --model deepseek-ai/DeepSeek-R1 --dataset-name random --host 127.0.0.1 --port 9256 --random-input-len 4 --random-output-len 64 --request-rate inf --num-prompts 256

B200, torch 2.9, with batch invariant overrides: ~15% output throughput gain over eager

============ Serving Benchmark Result ============
Successful requests:                     256       
Failed requests:                         0         
Benchmark duration (s):                  13.91     
Total input tokens:                      768       
Total generated tokens:                  16210     
Request throughput (req/s):              18.41     
Output token throughput (tok/s):         1165.62   
Peak output token throughput (tok/s):    1276.00   
Peak concurrent requests:                256.00    
Total Token throughput (tok/s):          1220.85   
---------------Time to First Token----------------
Mean TTFT (ms):                          679.74    
Median TTFT (ms):                        687.39    
P99 TTFT (ms):                           698.39    
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          209.57    
Median TPOT (ms):                        209.51    
P99 TPOT (ms):                           212.54    
---------------Inter-token Latency----------------
Mean ITL (ms):                           209.56    
Median ITL (ms):                         209.37    
P99 ITL (ms):                            216.25    
==================================================

B200, torch 2.9, eager mode

============ Serving Benchmark Result ============
Successful requests:                     256       
Failed requests:                         0         
Benchmark duration (s):                  16.02     
Total input tokens:                      768       
Total generated tokens:                  16201     
Request throughput (req/s):              15.98     
Output token throughput (tok/s):         1011.35   
Peak output token throughput (tok/s):    1272.00   
Peak concurrent requests:                256.00    
Total Token throughput (tok/s):          1059.29   
---------------Time to First Token----------------
Mean TTFT (ms):                          670.45    
Median TTFT (ms):                        671.56    
P99 TTFT (ms):                           680.53    
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          243.35    
Median TPOT (ms):                        243.33    
P99 TPOT (ms):                           243.64    
---------------Inter-token Latency----------------
Mean ITL (ms):                           243.34    
Median ITL (ms):                         243.29    
P99 ITL (ms):                            250.30    
==================================================


@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.


@PaulZhang12 force-pushed the batch_invariant_b200 branch 2 times, most recently from 47832c6 to 3878d3c on November 3, 2025 16:32
@mergify mergify bot added the v1 label Nov 3, 2025
@PaulZhang12 force-pushed the batch_invariant_b200 branch 5 times, most recently from 504d64e to 3ac6ae4 on November 4, 2025 14:49
@PaulZhang12 PaulZhang12 changed the title Override inductor default mm with batch invariant one for B200 [Feature] Extend batch invariant torch.compile to B200 Nov 4, 2025
@PaulZhang12 force-pushed the batch_invariant_b200 branch 2 times, most recently from 479a283 to 37beba9 on November 4, 2025 14:53
Member

@yewentao256 yewentao256 left a comment


LGTM, thanks for the work!

@yewentao256 yewentao256 added the ready ONLY add when PR is ready to merge/full CI is needed label Nov 4, 2025