
Conversation


@hsliuustc0106 commented Sep 2, 2025

What this PR does / why we need it?

This PR is associated with #2607, which enables data parallelism (DP) for the ViT in Qwen-2.5-VL.

There are several reasons to run the ViT with DP rather than TP (see the sketch after this list):

- The ViT is a small model; the all-reduce that TP requires costs more than the speedup TP provides.
- The ViT is not captured in CUDA graphs or the torch.compile graph, so its kernel-launch and all-reduce overheads are higher.
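The sketch below is a single-process illustration of the idea, not the PR's actual code: with --mm-encoder-tp-mode data, every rank keeps a full, unsharded copy of the ViT and processes a slice of the image batch, so no per-layer all-reduce is needed; in a real multi-rank run the per-rank outputs would be combined with an all_gather.

import torch
import torch.nn as nn

tp_size = 4                                # TP world size of the LLM backbone
images = torch.randn(8, 3, 224, 224)       # a batch of 8 images

# Stand-in for the ViT: any module whose weights stay whole on every rank.
vit = nn.Sequential(nn.Flatten(), nn.Linear(3 * 224 * 224, 64))

# DP path: shard the *batch*, not the weights; each "rank" handles one shard.
shards = torch.chunk(images, tp_size, dim=0)
embeddings = torch.cat([vit(s) for s in shards])   # all_gather in the multi-rank run

print(embeddings.shape)   # torch.Size([8, 64])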

Does this PR introduce any user-facing change?

Adds the --mm-encoder-tp-mode argument, whose data option selects data parallelism for the multimodal encoder. Below is an example that runs the ViT with DP while keeping TP=4 for the LLM backbone:

vllm serve \
    /workspace/models/Qwen2.5-VL-3B-Instruct \
    --port 5580 --host 0.0.0.0 \
    --max-num-seqs 128 --dtype bfloat16 --max-model-len=8192 \
    --no-enable-prefix-caching --trust-remote-code -tp 4 \
    --allowed-local-media-path /workspace/l00807937/ \
    --gpu-memory-utilization=0.93 \
    --enforce-eager \
    --mm-encoder-tp-mode data
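
Once the server is up, nothing changes on the client side: requests go through the usual OpenAI-compatible endpoint. A minimal client example (the image URL and prompt are placeholders, not from this PR):

from openai import OpenAI

client = OpenAI(base_url="http://localhost:5580/v1", api_key="EMPTY")
resp = client.chat.completions.create(
    model="/workspace/models/Qwen2.5-VL-3B-Instruct",
    messages=[{
        "role": "user",
        "content": [
            {"type": "image_url",
             "image_url": {"url": "https://example.com/demo.jpg"}},
            {"type": "text", "text": "Describe this image."},
        ],
    }],
)
print(resp.choices[0].message.content)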

How was this patch tested?

vllm: 0.10.0RC1
vllm-ascend: 0.10.0RC1

Benchmark test


1. TP=4 Case
**Test Plan**
export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3
vllm serve \
    /workspace/models/Qwen2.5-VL-3B-Instruct \
    --port 5580 --host 0.0.0.0 \
    --max-num-seqs 128 --dtype bfloat16 --max-model-len=8192 \
    --no-enable-prefix-caching --trust-remote-code -tp 4 \
    --allowed-local-media-path /workspace/l00807937/ \
    --gpu-memory-utilization=0.93 \
    --enforce-eager \
    --mm-encoder-tp-mode data

**Test Result**
baseline: without --mm-encoder-tp-mode data
============ Serving Benchmark Result ============
Successful requests:                     99        
Benchmark duration (s):                  28.79     
Total input tokens:                      9959      
Total generated tokens:                  10707     
Request throughput (req/s):              3.44      
Output token throughput (tok/s):         371.96    
Total Token throughput (tok/s):          717.94    
---------------Time to First Token----------------
Mean TTFT (ms):                          7711.37   
Median TTFT (ms):                        6832.34   
P99 TTFT (ms):                           17305.82  
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          177.03    
Median TPOT (ms):                        161.73    
P99 TPOT (ms):                           413.63    
---------------Inter-token Latency----------------
Mean ITL (ms):                           157.30    
Median ITL (ms):                         90.89     
P99 ITL (ms):                            640.97    
==================================================
DP4: with --mm-encoder-tp-mode data
============ Serving Benchmark Result ============
Successful requests:                     99        
Benchmark duration (s):                  25.67     
Total input tokens:                      9959      
Total generated tokens:                  10749     
Request throughput (req/s):              3.86      
Output token throughput (tok/s):         418.82    
Total Token throughput (tok/s):          806.85    
---------------Time to First Token----------------
Mean TTFT (ms):                          6393.85   
Median TTFT (ms):                        5437.94   
P99 TTFT (ms):                           14115.35  
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          158.26    
Median TPOT (ms):                        150.12    
P99 TPOT (ms):                           346.36    
---------------Inter-token Latency----------------
Mean ITL (ms):                           140.90    
Median ITL (ms):                         90.94     
P99 ITL (ms):                            439.49    
==================================================

2. TP=2 Case
**Test Plan**
export ASCEND_RT_VISIBLE_DEVICES=0,1
vllm serve \
    /workspace/models/Qwen2.5-VL-3B-Instruct \
    --port 5580 --host 0.0.0.0 \
    --max-num-seqs 128 --dtype bfloat16 --max-model-len=8192 \
    --no-enable-prefix-caching --trust-remote-code -tp 2 \
    --allowed-local-media-path /workspace/l00807937/ \
    --gpu-memory-utilization=0.93 \
    --enforce-eager \
    --mm-encoder-tp-mode data

**Test Result**
baseline: without --mm-encoder-tp-mode data
============ Serving Benchmark Result ============
Successful requests:                     99        
Benchmark duration (s):                  31.23     
Total input tokens:                      9959      
Total generated tokens:                  10732     
Request throughput (req/s):              3.17      
Output token throughput (tok/s):         343.69    
Total Token throughput (tok/s):          662.63    
---------------Time to First Token----------------
Mean TTFT (ms):                          8679.98   
Median TTFT (ms):                        7558.25   
P99 TTFT (ms):                           19444.89  
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          188.91    
Median TPOT (ms):                        180.24    
P99 TPOT (ms):                           464.39    
---------------Inter-token Latency----------------
Mean ITL (ms):                           168.80    
Median ITL (ms):                         92.44     
P99 ITL (ms):                            725.62    
==================================================
DP2: with --mm-encoder-tp-mode data
============ Serving Benchmark Result ============
Successful requests:                     99        
Benchmark duration (s):                  27.18     
Total input tokens:                      9959      
Total generated tokens:                  10707     
Request throughput (req/s):              3.64      
Output token throughput (tok/s):         393.87    
Total Token throughput (tok/s):          760.23    
---------------Time to First Token----------------
Mean TTFT (ms):                          6903.44   
Median TTFT (ms):                        5630.95   
P99 TTFT (ms):                           15328.67  
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          168.99    
Median TPOT (ms):                        158.29    
P99 TPOT (ms):                           372.06    
---------------Inter-token Latency----------------
Mean ITL (ms):                           150.38    
Median ITL (ms):                         94.54     
P99 ITL (ms):                            471.63    
==================================================
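
Net effect: with --mm-encoder-tp-mode data, total token throughput improves by ~12% at TP=4 (717.94 → 806.85 tok/s) and ~15% at TP=2 (662.63 → 760.23 tok/s), while mean TTFT drops by ~17% and ~20% respectively.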


- vLLM version: main
- vLLM main: https://github.com/vllm-project/vllm/commit/267c80d31f6b77092a5d5903da64556ac15c4d4d

Junhong and others added 5 commits September 2, 2025 11:09
Signed-off-by: Junhong <liujunhong11@huawei.com>

github-actions bot commented Sep 2, 2025

👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:

  • A PR should do only one thing; smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests to ensure it works and is not broken by future PRs.
  • Write the commit message by filling in the PR description to help reviewers and future developers understand.

If CI fails, you can run linting and testing checks locally according to Contributing and Testing.


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request enables Data Parallelism (DP) for the Vision Transformer (ViT) in Qwen-2.5-VL, which can improve performance for smaller models by avoiding tensor parallelism overhead. The changes introduce a use_data_parallel flag and a new execution path for DP. My review found a critical issue in the implementation where an incorrect attribute access would lead to a runtime error. I've provided a code suggestion to fix this.

Comment on lines 398 to 403
def _normalize_grid_thw(self, grid_thw: Union[torch.Tensor, list[list[int]]]) -> torch.Tensor:
    if isinstance(grid_thw, list):
        grid_thw = torch.tensor(grid_thw, device=self.device)
    elif not isinstance(grid_thw, torch.Tensor):
        raise TypeError(f"Expected input type is torch.Tensor or list of lists, got {type(grid_thw)}")
    return grid_thw

critical

torch.nn.Module does not have a .device attribute, so calling self.device will raise an AttributeError at runtime. A more robust approach to get the module's device is to inspect one of its parameters, for example, by using next(self.parameters()).device.

Suggested change
def _normalize_grid_thw(self, grid_thw: Union[torch.Tensor, list[list[int]]]) -> torch.Tensor:
    if isinstance(grid_thw, list):
        grid_thw = torch.tensor(grid_thw, device=self.device)
    elif not isinstance(grid_thw, torch.Tensor):
        raise TypeError(f"Expected input type is torch.Tensor or list of lists, got {type(grid_thw)}")
    return grid_thw

def _normalize_grid_thw(self, grid_thw: Union[torch.Tensor, list[list[int]]]) -> torch.Tensor:
    if isinstance(grid_thw, list):
        device = next(self.parameters()).device
        grid_thw = torch.tensor(grid_thw, device=device)
    elif not isinstance(grid_thw, torch.Tensor):
        raise TypeError(f"Expected input type is torch.Tensor or list of lists, got {type(grid_thw)}")
    return grid_thw
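
A quick way to verify the reviewer's point in isolation (illustrative, not part of the PR): nn.Module defines no .device attribute, while its parameters always carry one.

import torch
import torch.nn as nn

m = nn.Linear(4, 4)
try:
    m.device                                  # __getattr__ raises AttributeError
except AttributeError as e:
    print(e)                                  # 'Linear' object has no attribute 'device'

device = next(m.parameters()).device          # robust way to get the module's device
grid_thw = torch.tensor([[1, 16, 16]], device=device)
print(grid_thw.device)                        # cpu here; npu/cuda once the module is moved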

Junhong and others added 6 commits September 3, 2025 09:31
Signed-off-by: Junhong <liujunhong11@huawei.com>
Issue 2607 fix bug in test

This pull request has conflicts, please resolve those before we can evaluate the pull request.
