
Conversation

@huachenheli
Contributor

@huachenheli huachenheli commented Oct 29, 2025

Purpose

After #23207, Qwen2.5 VL's vision model hits a dynamic-slicing error on CUDA with torch.compile. Temporarily disabling compilation for the vision block for now.
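
In effect, this just comments out the @support_torch_compile decorator on Qwen2_5_VisionBlock so the vision block runs eagerly again. Roughly (illustrative sketch, not the exact diff):

import torch.nn as nn

# @support_torch_compile  # temporarily disabled by this PR; re-enable once the
#                         # dynamic-slicing graph break below is fixed upstream
class Qwen2_5_VisionBlock(nn.Module):
    ...

The failure this works around: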

(EngineCore_DP0 pid=1671620) Process EngineCore_DP0:
(EngineCore_DP0 pid=1671620) Traceback (most recent call last):
(EngineCore_DP0 pid=1671620)   File "/usr/lib64/python3.12/multiprocessing/process.py", line 314, in _bootstrap
(EngineCore_DP0 pid=1671620)     self.run()
(EngineCore_DP0 pid=1671620)   File "/usr/lib64/python3.12/multiprocessing/process.py", line 108, in run
(EngineCore_DP0 pid=1671620)     self._target(*self._args, **self._kwargs)
(EngineCore_DP0 pid=1671620)   File "/home/huachenheli/github/vllm/vllm/v1/engine/core.py", line 783, in run_engine_core
(EngineCore_DP0 pid=1671620)     raise e
(EngineCore_DP0 pid=1671620)   File "/home/huachenheli/github/vllm/vllm/v1/engine/core.py", line 770, in run_engine_core
(EngineCore_DP0 pid=1671620)     engine_core = EngineCoreProc(*args, **kwargs)
(EngineCore_DP0 pid=1671620)                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=1671620)   File "/home/huachenheli/github/vllm/vllm/v1/engine/core.py", line 538, in __init__
(EngineCore_DP0 pid=1671620)     super().__init__(
(EngineCore_DP0 pid=1671620)   File "/home/huachenheli/github/vllm/vllm/v1/engine/core.py", line 109, in __init__
(EngineCore_DP0 pid=1671620)     num_gpu_blocks, num_cpu_blocks, kv_cache_config = self._initialize_kv_caches(
(EngineCore_DP0 pid=1671620)                                                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=1671620)   File "/home/huachenheli/github/vllm/vllm/v1/engine/core.py", line 218, in _initialize_kv_caches
(EngineCore_DP0 pid=1671620)     available_gpu_memory = self.model_executor.determine_available_memory()
(EngineCore_DP0 pid=1671620)                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=1671620)   File "/home/huachenheli/github/vllm/vllm/v1/executor/abstract.py", line 123, in determine_available_memory
(EngineCore_DP0 pid=1671620)     return self.collective_rpc("determine_available_memory")
(EngineCore_DP0 pid=1671620)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=1671620)   File "/home/huachenheli/github/vllm/vllm/v1/executor/uniproc_executor.py", line 73, in collective_rpc
(EngineCore_DP0 pid=1671620)     return [run_method(self.driver_worker, method, args, kwargs)]
(EngineCore_DP0 pid=1671620)             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=1671620)   File "/home/huachenheli/github/vllm/vllm/v1/serial_utils.py", line 459, in run_method
(EngineCore_DP0 pid=1671620)     return func(*args, **kwargs)
(EngineCore_DP0 pid=1671620)            ^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=1671620)   File "/home/huachenheli/uv_env/vllm/lib64/python3.12/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
(EngineCore_DP0 pid=1671620)     return func(*args, **kwargs)
(EngineCore_DP0 pid=1671620)            ^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=1671620)   File "/home/huachenheli/github/vllm/vllm/v1/worker/gpu_worker.py", line 284, in determine_available_memory
(EngineCore_DP0 pid=1671620)     self.model_runner.profile_run()
(EngineCore_DP0 pid=1671620)   File "/home/huachenheli/github/vllm/vllm/v1/worker/gpu_model_runner.py", line 3713, in profile_run
(EngineCore_DP0 pid=1671620)     dummy_encoder_outputs = self.model.get_multimodal_embeddings(
(EngineCore_DP0 pid=1671620)                             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=1671620)   File "/home/huachenheli/github/vllm/vllm/model_executor/models/qwen2_5_vl.py", line 1589, in get_multimodal_embeddings
(EngineCore_DP0 pid=1671620)     video_embeddings = self._process_video_input(multimodal_input)
(EngineCore_DP0 pid=1671620)                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=1671620)   File "/home/huachenheli/github/vllm/vllm/model_executor/models/qwen2_5_vl.py", line 1420, in _process_video_input
(EngineCore_DP0 pid=1671620)     video_embeds = self.visual(
(EngineCore_DP0 pid=1671620)                    ^^^^^^^^^^^^
(EngineCore_DP0 pid=1671620)   File "/home/huachenheli/uv_env/vllm/lib64/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
(EngineCore_DP0 pid=1671620)     return self._call_impl(*args, **kwargs)
(EngineCore_DP0 pid=1671620)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=1671620)   File "/home/huachenheli/uv_env/vllm/lib64/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
(EngineCore_DP0 pid=1671620)     return forward_call(*args, **kwargs)
(EngineCore_DP0 pid=1671620)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=1671620)   File "/home/huachenheli/github/vllm/vllm/model_executor/models/qwen2_5_vl.py", line 942, in forward
(EngineCore_DP0 pid=1671620)     hidden_states = blk(
(EngineCore_DP0 pid=1671620)                     ^^^^
(EngineCore_DP0 pid=1671620)   File "/home/huachenheli/github/vllm/vllm/compilation/decorators.py", line 470, in __call__
(EngineCore_DP0 pid=1671620)     output = self.compiled_callable(*args, **kwargs)
(EngineCore_DP0 pid=1671620)              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=1671620)   File "/home/huachenheli/uv_env/vllm/lib64/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 841, in compile_wrapper
(EngineCore_DP0 pid=1671620)     raise e.with_traceback(None) from e.__cause__  # User compiler error
(EngineCore_DP0 pid=1671620)     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=1671620) torch._dynamo.exc.Unsupported: Dynamic slicing with Tensor arguments
(EngineCore_DP0 pid=1671620)   Explanation: Creating slices with Tensor arguments is not supported. e.g. `l[:x]`, where `x` is a 1-element tensor.
(EngineCore_DP0 pid=1671620)   Hint: It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues.
(EngineCore_DP0 pid=1671620) 
(EngineCore_DP0 pid=1671620)   Developer debug context: SliceVariable start: TensorVariable(), stop: TensorVariable(), step: ConstantVariable(NoneType: None)
(EngineCore_DP0 pid=1671620) 
(EngineCore_DP0 pid=1671620)  For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0038.html
(EngineCore_DP0 pid=1671620) 
(EngineCore_DP0 pid=1671620) from user code:
(EngineCore_DP0 pid=1671620)    File "/home/huachenheli/github/vllm/vllm/model_executor/models/qwen2_5_vl.py", line 511, in forward
(EngineCore_DP0 pid=1671620)     x_attn = self.attn(
(EngineCore_DP0 pid=1671620)   File "/home/huachenheli/github/vllm/vllm/model_executor/models/qwen2_5_vl.py", line 435, in forward
(EngineCore_DP0 pid=1671620)     q_i = q[:, start_idx:end_idx]
(EngineCore_DP0 pid=1671620) 
(EngineCore_DP0 pid=1671620) Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
(EngineCore_DP0 pid=1671620) 

cc. @Lucaskabela

Repro:
Command:

python -m vllm.entrypoints.openai.api_server --model Qwen/Qwen2.5-VL-7B-Instruct --port 8001 --host 0.0.0.0 --dtype bfloat16 --limit-mm-per-prompt '{"image": 1, "video":1}'

with forced SDPA backend in layer.py:

# local patch: force the SDPA attention backend on CUDA
elif current_platform.is_cuda():
    return _Backend.TORCH_SDPA, None
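
For reference, the underlying graph break can also be reproduced outside vLLM with a small standalone script (untested sketch; shapes and names are made up, it only mimics the per-window slicing from the vision attention forward in the traceback above):

import torch

def slice_like_vision_attn(q: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
    # The slice bounds are 0-dim tensors taken from cu_seqlens, which Dynamo
    # cannot turn into a Python slice.
    outputs = []
    for i in range(1, cu_seqlens.numel()):
        start_idx, end_idx = cu_seqlens[i - 1], cu_seqlens[i]
        outputs.append(q[:, start_idx:end_idx])  # tensor-valued slice bounds -> graph break
    return torch.cat(outputs, dim=1)

compiled = torch.compile(slice_like_vision_attn, fullgraph=True)
q = torch.randn(2, 16, 64)
cu_seqlens = torch.tensor([0, 4, 16])
compiled(q, cu_seqlens)  # expected to raise torch._dynamo.exc.Unsupported on torch 2.9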

vllm & torch versions:

(vllm) [huachenheli@devgpu039.dkl2 ~/github/vllm (disableqwen25compile)]$ uv pip show vllm
Using Python 3.12.11 environment at: /home/huachenheli/uv_env/vllm
Name: vllm
Version: 0.10.2rc3.dev1414+gd3ab240f3.d20251029.precompiled
Location: /home/huachenheli/uv_env/vllm/lib/python3.12/site-packages

(vllm) [huachenheli@devgpu039.dkl2 ~/github/vllm (disableqwen25compile)]$ uv pip show torch
Using Python 3.12.11 environment at: /home/huachenheli/uv_env/vllm
Name: torch
Version: 2.9.0
Location: /home/huachenheli/uv_env/vllm/lib/python3.12/site-packages

Test Plan

local vllm

Test Result


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Signed-off-by: Chenheli Hua <huachenheli@outlook.com>
@mergify mergify bot added the qwen Related to Qwen models label Oct 29, 2025
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request provides a temporary fix for a torch.compile issue with the Qwen2.5 VL vision model. The change involves commenting out the @support_torch_compile decorator for the Qwen2_5_VisionBlock, which effectively disables compilation for this block and avoids the Unsupported: Dynamic slicing with Tensor arguments error. This is a reasonable and effective short-term solution to unblock users while a more permanent fix for the underlying issue is investigated. The change is correct and I approve it.

@huachenheli huachenheli changed the title [Temp fix] Disable torch.compile for Qwen2.5 VL temporarily. [Temp fix] Disable torch.compile for Qwen2.5 VL's VisionBlock temporarily. Oct 29, 2025
@ywang96 ywang96 added the ready ONLY add when PR is ready to merge/full CI is needed label Oct 29, 2025
Contributor

@Lucaskabela Lucaskabela left a comment


cc @huachenheli I tried testing this pretty extensively, but it is my first major feature work in vLLM, so I am not surprised I missed something.

That said, I never observed this error in my testing, so I think more context is needed on this PR.

Specifically:

  1. What versions (torch/vllm primarily) are you using?
  2. What command are you running to get this error?

Once those are provided on the PR, please re-ping me so I can get to work on fixing :) Thanks!

@huachenheli
Contributor Author


Updated my PR description with more details. PTAL.

@ywang96
Member

ywang96 commented Oct 29, 2025

with forced SDPA backend in layer.py:

You should also be able to do this without modifying the code by passing --mm-encoder-attn-backend TORCH_SDPA.
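
For example, the repro command above with the flag instead of the code edit:

python -m vllm.entrypoints.openai.api_server --model Qwen/Qwen2.5-VL-7B-Instruct --port 8001 --host 0.0.0.0 --dtype bfloat16 --limit-mm-per-prompt '{"image": 1, "video":1}' --mm-encoder-attn-backend TORCH_SDPA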

@Lucaskabela
Contributor

Lucaskabela commented Oct 29, 2025

I should have a fix ready pretty soon (within the hour). The issue here is that torch.compile doesn't support slices with Tensor arguments yet (but @laithsakka has a PR supporting this on nightly; see pytorch/pytorch#165074).

So for now, we can move this into a custom op, and once we upgrade the torch version to include Laith's fix we can move it back out of the custom op :)
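
Roughly the idea, sketched with torch.library.custom_op (names, signature, and shapes are assumed here; see #27764 for the actual change). Because the registered op is opaque to Dynamo, the tensor-valued slicing inside it runs in eager mode and never hits the graph break:

import torch
import torch.nn.functional as F

@torch.library.custom_op("demo::varlen_sdpa", mutates_args=())
def varlen_sdpa(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                cu_seqlens: torch.Tensor) -> torch.Tensor:
    # Per-window SDPA; slicing with tensor bounds is fine in eager mode.
    outputs = []
    for i in range(1, cu_seqlens.numel()):
        s, e = cu_seqlens[i - 1], cu_seqlens[i]
        outputs.append(F.scaled_dot_product_attention(q[:, s:e], k[:, s:e], v[:, s:e]))
    return torch.cat(outputs, dim=1)

@varlen_sdpa.register_fake
def _(q, k, v, cu_seqlens):
    # Output layout matches q in this sketch.
    return torch.empty_like(q)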

@Lucaskabela
Contributor

Please see #27764 @huachenheli @ywang96

@ywang96 ywang96 added this to the v0.11.1 milestone Oct 29, 2025
@Lucaskabela
Contributor

That said, we should land and cherry-pick this PR into the release: the compile integration on the VisionBlock specifically needs more hardening before we push it to a general release.

@ywang96 ywang96 enabled auto-merge (squash) October 29, 2025 19:37
@ywang96 ywang96 merged commit 48eb8eb into vllm-project:main Oct 29, 2025
53 checks passed
MatthewBonanni pushed a commit to MatthewBonanni/vllm that referenced this pull request Oct 30, 2025
[Temp fix] Disable torch.compile for Qwen2.5 VL's VisionBlock temporarily. (vllm-project#27760)

Signed-off-by: Chenheli Hua <huachenheli@outlook.com>
Signed-off-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Roger Wang <hey@rogerw.io>
@tjtanaa
Contributor

tjtanaa commented Oct 30, 2025

@Lucaskabela @ywang96 I think we also need a way to safeguard this, since it requires a very new version of PyTorch (one containing that specific PR) to handle the dynamic slicing.
Will we be able to implement a condition that falls back to direct_register_custom_ops if the installed PyTorch does not contain that PR? This would be very useful, since not all hardware platforms can upgrade their PyTorch version quickly enough to pick up that bugfix PR.
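
A hypothetical guard along those lines (purely illustrative; the version cutoff and variable names are assumptions, since the exact torch release containing pytorch/pytorch#165074 isn't known here):

import torch
from packaging.version import Version

# Placeholder cutoff: whichever torch release first ships pytorch/pytorch#165074.
_TORCH_HAS_TENSOR_SLICE_FIX = Version(torch.__version__.split("+")[0]) >= Version("2.10.0")

if _TORCH_HAS_TENSOR_SLICE_FIX:
    # New enough torch: let torch.compile trace the tensor-bound slicing directly.
    use_custom_op_fallback = False
else:
    # Older torch: keep routing the slicing through an opaque custom op.
    use_custom_op_fallback = True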

