Conversation

@hust17yixuan (Contributor) commented on Aug 26, 2025

What this PR does / why we need it?

Move the torchair-related rotary ops into the torchair directory to make the code clearer. As a next step we'll remove all torchair-related code that still lives outside the torchair rotary ops.

Does this PR introduce any user-facing change?

No

How was this patch tested?

vLLM version: main
vLLM main: vllm-project/vllm@ab9f2cf

@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request refactors the torchair-related rotary embedding operations by moving them into a dedicated torchair/ops directory, which improves code organization. The changes also involve monkey-patching RotaryEmbedding and DeepseekScalingRotaryEmbedding for torchair-specific implementations. I've identified a critical bug in the rotary embedding cache resizing logic where the new sequence length was ignored. This could lead to incorrect computations for sequences longer than the pre-allocated cache size. A fix is suggested to address this issue.
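
As background for the monkey-patching mentioned above, the sketch below shows one common way torchair-specific forwards get swapped onto vLLM's rotary classes. The class names come from the summary and the module path from the Codecov report further down; the imported function names, the attribute being replaced, and the registration helper are assumptions for illustration, not the PR's actual wiring.

# Minimal monkey-patching sketch (assumed wiring, for illustration only).
# vLLM's stock rotary embedding classes, as named in the review summary.
from vllm.model_executor.layers.rotary_embedding import (
    DeepseekScalingRotaryEmbedding, RotaryEmbedding)

# Hypothetical imports: the module path matches the Codecov report below,
# but the exported names are assumptions.
from vllm_ascend.torchair.ops.torchair_rotary_embedding import (
    native_rope_deepseek_forward, rope_forward)


def register_torchair_rotary_patches() -> None:
    # Replace the stock forward implementations so models built under
    # torchair graph mode pick up the NPU-friendly versions transparently.
    RotaryEmbedding.forward = rope_forward  # assumed patch point
    DeepseekScalingRotaryEmbedding.forward = native_rope_deepseek_forward

Registering such a patch once, inside the torchair package, keeps the torchair-specific behaviour confined there, which is the stated goal of the refactor.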

Comment on lines +252 to +266
def __set_cos_sin_cache(self, seq_len, device, dtype):
    inv_freq = 1.0 / (self.base**(torch.arange(
        0, self.rotary_dim, 2, device=device, dtype=torch.float32) *
        (1 / self.rotary_dim)))
    self.register_buffer("inv_freq", inv_freq)

    t = torch.arange(self.max_position_embeddings,
                     device=self.inv_freq.device,
                     dtype=torch.float32)
    freqs = torch.einsum("i,j->ij", t, self.inv_freq)

    emb = torch.cat((freqs, freqs), dim=-1)
    self.register_buffer("cos", emb.cos().to(dtype=dtype), persistent=False)
    self.register_buffer("sin", emb.sin().to(dtype=dtype), persistent=False)
    self.embed = F.embedding

critical

The seq_len parameter in __set_cos_sin_cache is ignored. At line 258, torch.arange uses self.max_position_embeddings instead of the provided seq_len to create the t tensor. This means that when this function is called from rope_forward to extend the cache for a sequence length greater than self.max_position_embeddings, the cache is not actually resized. This will lead to incorrect rotary embeddings for longer sequences.

To fix this, seq_len should be used to create the t tensor, and self.max_position_embeddings should be updated to reflect the new cache size.

Suggested change

Current:

def __set_cos_sin_cache(self, seq_len, device, dtype):
    inv_freq = 1.0 / (self.base**(torch.arange(
        0, self.rotary_dim, 2, device=device, dtype=torch.float32) *
        (1 / self.rotary_dim)))
    self.register_buffer("inv_freq", inv_freq)
    t = torch.arange(self.max_position_embeddings,
                     device=self.inv_freq.device,
                     dtype=torch.float32)
    freqs = torch.einsum("i,j->ij", t, self.inv_freq)
    emb = torch.cat((freqs, freqs), dim=-1)
    self.register_buffer("cos", emb.cos().to(dtype=dtype), persistent=False)
    self.register_buffer("sin", emb.sin().to(dtype=dtype), persistent=False)
    self.embed = F.embedding

Suggested:

def __set_cos_sin_cache(self, seq_len, device, dtype):
    self.max_position_embeddings = seq_len
    inv_freq = 1.0 / (self.base**(torch.arange(
        0, self.rotary_dim, 2, device=device, dtype=torch.float32) *
        (1 / self.rotary_dim)))
    self.register_buffer("inv_freq", inv_freq)
    t = torch.arange(seq_len,
                     device=self.inv_freq.device,
                     dtype=torch.float32)
    freqs = torch.einsum("i,j->ij", t, self.inv_freq)
    emb = torch.cat((freqs, freqs), dim=-1)
    self.register_buffer("cos", emb.cos().to(dtype=dtype), persistent=False)
    self.register_buffer("sin", emb.sin().to(dtype=dtype), persistent=False)
    self.embed = F.embedding
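
To make the effect of the suggested fix concrete, here is a hedged caller-side sketch (not part of the PR): it assumes, as described in the comment above, that rope_forward extends the cache when the incoming positions exceed the pre-computed length, and that the fixed __set_cos_sin_cache both rebuilds the cos/sin buffers and updates max_position_embeddings. The helper name and signature are hypothetical.

import torch


def maybe_extend_rope_cache(rope, positions: torch.Tensor,
                            set_cos_sin_cache) -> None:
    # Hypothetical helper: `set_cos_sin_cache` stands in for the private
    # __set_cos_sin_cache method shown above. With the suggested fix, one
    # call rebuilds cos/sin for `required_len` rows and bumps
    # rope.max_position_embeddings, so the guard stays cheap afterwards.
    required_len = int(positions.max().item()) + 1
    if required_len > rope.max_position_embeddings:
        set_cos_sin_cache(required_len,
                          device=positions.device,
                          dtype=rope.cos.dtype)

Without the fix, the buffers never grow, so positions beyond the original max_position_embeddings would index a stale cache and yield incorrect rotary embeddings, which is exactly the failure the comment describes.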


👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:

  • A PR should do only one thing; smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests to ensure it works and is not broken by future PRs.
  • Write the commit message by fulfilling the PR description to help reviewers and future developers understand.

If CI fails, you can run linting and testing checks locally according to Contributing and Testing.

@hust17yixuan force-pushed the rope_ops_torchair branch 3 times, most recently from 6b2ecaf to 7a5ec68, on August 27, 2025 01:17

This pull request has conflicts, please resolve those before we can evaluate the pull request.

@hust17yixuan force-pushed the rope_ops_torchair branch 5 times, most recently from 3ca0ff2 to fb891e0, on August 27, 2025 13:46
@wangxiyuan added the ready (read for review) label on Aug 28, 2025
@hust17yixuan force-pushed the rope_ops_torchair branch 2 times, most recently from 92b9dfe to f48178b, on August 28, 2025 09:26
@hust17yixuan (Contributor, Author) commented:

The CI failed at isort, but pre-commit fixes it locally and the CI log shows no error message. This PR likely has no lint problems and is instead blocked by other CI issues.

@hust17yixuan force-pushed the rope_ops_torchair branch 2 times, most recently from aeadd7d to 1dfec63, on August 29, 2025 02:40

codecov bot commented Aug 29, 2025

Codecov Report

❌ Patch coverage is 72.39264% with 90 lines in your changes missing coverage. Please review.
✅ Project coverage is 73.02%. Comparing base (3a5fc5e) to head (c4323f4).
⚠️ Report is 1 commits behind head on main.

Files with missing lines | Patch % | Lines
...m_ascend/torchair/ops/torchair_rotary_embedding.py | 42.06% | 84 Missing ⚠️
vllm_ascend/torchair/utils.py | 14.28% | 6 Missing ⚠️

❌ Your patch status has failed because the patch coverage (72.39%) is below the target coverage (80.00%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2559      +/-   ##
==========================================
- Coverage   73.03%   73.02%   -0.01%     
==========================================
  Files         149      151       +2     
  Lines       21515    21841     +326     
==========================================
+ Hits        15714    15950     +236     
- Misses       5801     5891      +90     
Flag | Coverage Δ
unittests | 73.02% <72.39%> (-0.01%) ⬇️

@hust17yixuan force-pushed the rope_ops_torchair branch 5 times, most recently from f5b38da to bb1400e, on August 30, 2025 07:58
Signed-off-by: hust17yixuan <303660421@qq.com>
@wangxiyuan merged commit c2c97f3 into vllm-project:main on Sep 1, 2025
26 of 28 checks passed
wenba0 pushed a commit to wenba0/vllm-ascend that referenced this pull request Sep 5, 2025
wangxiaoteng888 pushed a commit to LCAIZJ/vllm-ascend that referenced this pull request Sep 25, 2025
chopper0126 pushed a commit to chopper0126/vllm-ascend that referenced this pull request Sep 26, 2025
Labels: module:tests, ready (read for review)
2 participants