[Feature] Extend batch invariant torch.compile to B200 #27856
Conversation
Signed-off-by: PaulZhang12 <paulzhan@fb.com>
LGTM, thanks for the work!
Purpose
This PR resolves batch invariance issues with torch.compile + cudagraphs on B200; namely, trtllm_attention on B200 with cudagraphs causes problems. We also extend all the unit tests to use torch.compile when evaluating batch invariance, and disable the GEMM custom operator overriding on PyTorch 2.10+, since it now ships the batch-invariant CUDA overrides upstream (e.g. pytorch/pytorch#166735). On PyTorch 2.9, B200 still requires the custom Triton GEMM kernels.
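For illustration, a minimal sketch of how such a version gate could look; the helper names are assumptions made up for this example, not the PR's actual code:

```python
# Hypothetical sketch (not the PR's actual code): only install the custom
# Triton batch-invariant GEMM overrides when PyTorch does not already
# provide them.
import torch
from packaging.version import Version


def needs_custom_gemm_overrides() -> bool:
    # PyTorch 2.10+ ships batch-invariant CUDA GEMM overrides upstream
    # (e.g. pytorch/pytorch#166735), so the custom Triton kernels are only
    # needed on older releases such as 2.9.
    return Version(torch.__version__).release < (2, 10)


def register_batch_invariant_gemm_overrides() -> None:
    # Placeholder for the hook that would install the custom Triton GEMM
    # kernels used for batch invariance on B200 with torch 2.9.
    pass


if needs_custom_gemm_overrides():
    register_batch_invariant_gemm_overrides()
```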
Test Plan
The following tests are run on both H100 and B200.
pytest tests/v1/generation/test_batch_invariance.py
VLLM_TEST_MODEL="deepseek-ai/DeepSeek-V3" VLLM_TEST_TP_SIZE=8 VLLM_ATTENTION_BACKEND="FLASH_ATTN_MLA" pytest tests/v1/generation/test_batch_invariance.py -k "test_logprobs_bitwise_batch_invariance_bs1_vs_bsN[FLASH_ATTN_MLA]"
Test Result
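For context, a minimal sketch of what the bs=1 vs bsN bitwise logprob comparison amounts to; the model, prompts, and comparison logic here are illustrative assumptions rather than the test's actual code:

```python
# Illustrative sketch of a bs=1 vs bs=N bitwise logprob check.
# Assumes VLLM_BATCH_INVARIANT=1 is set in the environment and that the
# model/TP setup below is available; these are assumptions for the example.
from vllm import LLM, SamplingParams

llm = LLM(model="deepseek-ai/DeepSeek-V3", tensor_parallel_size=8)
params = SamplingParams(temperature=0.0, max_tokens=32, logprobs=1)
prompts = ["The capital of France is", "1 + 1 =", "vLLM is"]

# Run each prompt alone (batch size 1) and then all prompts together (batch size N).
solo = [llm.generate([p], params)[0] for p in prompts]
batched = llm.generate(prompts, params)

for single, grouped in zip(solo, batched):
    solo_lps = [lps[tid].logprob for lps, tid in
                zip(single.outputs[0].logprobs, single.outputs[0].token_ids)]
    batch_lps = [lps[tid].logprob for lps, tid in
                 zip(grouped.outputs[0].logprobs, grouped.outputs[0].token_ids)]
    # Batch invariance requires bitwise-identical logprobs, not merely "close".
    assert solo_lps == batch_lps
```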
Performance
VLLM_BATCH_INVARIANT=1 VLLM_ATTENTION_BACKEND="TRITON_MLA" vllm serve deepseek-ai/DeepSeek-R1 -tp 8 --enable-expert-parallel --port 9256 --gpu_memory_utilization 0.95 --max_model_len 40960
vllm bench serve --model deepseek-ai/DeepSeek-R1 --dataset-name random --host 127.0.0.1 --port 9256 --random-input-len 4 --random-output-len 64 --request-rate inf --num-prompts 256
B200 torch 2.9 with batch invariant overrides, 15% throughput gains
B200 torch 2.9 eager
Essential Elements of an Effective PR Description Checklist
supported_models.md and examples for a new model.