[5/N][refactor]add torchair rotary ops #2559
Conversation
Code Review

This pull request refactors the torchair-related rotary embedding operations by moving them into a dedicated `torchair/ops` directory, which improves code organization. The changes also monkey-patch `RotaryEmbedding` and `DeepseekScalingRotaryEmbedding` with torchair-specific implementations. I've identified a critical bug in the rotary embedding cache resizing logic: the new sequence length is ignored, which could lead to incorrect computations for sequences longer than the pre-allocated cache size. A fix is suggested below.
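As a minimal sketch of the monkey-patching approach mentioned above (the patch function name and the exact patch site are illustrative assumptions, not the actual vllm-ascend implementation):

```python
# Sketch only: swap the forward methods of vLLM's rotary embedding
# classes for torchair-specific variants. `torchair_rope_forward` is a
# hypothetical placeholder, not code from this PR.
from vllm.model_executor.layers.rotary_embedding import (
    DeepseekScalingRotaryEmbedding, RotaryEmbedding)


def torchair_rope_forward(self, positions, query, key, offsets=None):
    """Graph-friendly rotary forward used under torchair (placeholder)."""
    # A torchair-specific cos/sin gather + rotation would go here.
    raise NotImplementedError


# Models constructed after this point pick up the torchair variants
# without any change to the model code itself.
RotaryEmbedding.forward = torchair_rope_forward
DeepseekScalingRotaryEmbedding.forward = torchair_rope_forward
```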
```python
def __set_cos_sin_cache(self, seq_len, device, dtype):
    inv_freq = 1.0 / (self.base**(torch.arange(
        0, self.rotary_dim, 2, device=device, dtype=torch.float32) *
        (1 / self.rotary_dim)))
    self.register_buffer("inv_freq", inv_freq)

    t = torch.arange(self.max_position_embeddings,
                     device=self.inv_freq.device,
                     dtype=torch.float32)
    freqs = torch.einsum("i,j->ij", t, self.inv_freq)

    emb = torch.cat((freqs, freqs), dim=-1)
    self.register_buffer("cos", emb.cos().to(dtype=dtype), persistent=False)
    self.register_buffer("sin", emb.sin().to(dtype=dtype), persistent=False)
    self.embed = F.embedding
```
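The function ends by storing a handle to `F.embedding`; a hypothetical consumer of the cache could gather per-token cos/sin rows with it, roughly as below (method and argument names are illustrative, not taken from the PR):

```python
import torch


def gather_cos_sin(self, positions: torch.Tensor):
    """Look up cached cos/sin rows for each token position."""
    # positions: [num_tokens] int64 indices into the precomputed tables.
    cos = self.embed(positions, self.cos)  # [num_tokens, rotary_dim]
    sin = self.embed(positions, self.sin)  # [num_tokens, rotary_dim]
    return cos, sin
```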
The `seq_len` parameter in `__set_cos_sin_cache` is ignored. At line 258, `torch.arange` uses `self.max_position_embeddings` instead of the provided `seq_len` to create the `t` tensor. This means that when this function is called from `rope_forward` to extend the cache for a sequence length greater than `self.max_position_embeddings`, the cache is not actually resized. This will lead to incorrect rotary embeddings for longer sequences.

To fix this, `seq_len` should be used to create the `t` tensor, and `self.max_position_embeddings` should be updated to reflect the new cache size.
Suggested change:

```python
def __set_cos_sin_cache(self, seq_len, device, dtype):
    self.max_position_embeddings = seq_len
    inv_freq = 1.0 / (self.base**(torch.arange(
        0, self.rotary_dim, 2, device=device, dtype=torch.float32) *
        (1 / self.rotary_dim)))
    self.register_buffer("inv_freq", inv_freq)
    t = torch.arange(seq_len,
                     device=self.inv_freq.device,
                     dtype=torch.float32)
    freqs = torch.einsum("i,j->ij", t, self.inv_freq)
    emb = torch.cat((freqs, freqs), dim=-1)
    self.register_buffer("cos", emb.cos().to(dtype=dtype), persistent=False)
    self.register_buffer("sin", emb.sin().to(dtype=dtype), persistent=False)
    self.embed = F.embedding
```
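For completeness, the cache-extension call path this review assumes in `rope_forward` might look roughly like the skeleton below (illustrative only; the class name, guard condition, and method bodies are assumptions, not the PR's actual code):

```python
class _RotaryEmbeddingSketch:
    """Illustrative skeleton of the lazy cache-extension path."""

    def __set_cos_sin_cache(self, seq_len, device, dtype):
        ...  # rebuilds the cos/sin tables (see the suggested fix above)

    def rope_forward(self, positions, query, key):
        # Grow the cache lazily when a longer sequence shows up; with the
        # suggested fix applied, this also bumps max_position_embeddings.
        seq_len = int(positions.max().item()) + 1
        if seq_len > self.max_position_embeddings:
            self.__set_cos_sin_cache(seq_len, query.device, query.dtype)
        # ... apply rotary embedding using the (possibly extended) cache ...
```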
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:
If CI fails, you can run linting and testing checks locally according to Contributing and Testing.
Force-pushed from ea23daf to 582103b.
Force-pushed from 6b2ecaf to 7a5ec68.
This pull request has conflicts; please resolve them before we can evaluate the pull request.
Force-pushed from 7a5ec68 to 0c38a54.
Force-pushed from 3ca0ff2 to fb891e0.
Force-pushed from 92b9dfe to f48178b.
The CI failed at isort, but I ran pre-commit locally to fix it and the CI shows no error message. This PR likely has no lint problems; it seems to be blocked by some CI issue.
Force-pushed from aeadd7d to 1dfec63.
Codecov Report

❌ Patch coverage is 72.39%. Your patch status has failed because the patch coverage (72.39%) is below the target coverage (80.00%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files:

```
@@            Coverage Diff             @@
##             main    #2559      +/-   ##
==========================================
- Coverage   73.03%   73.02%   -0.01%
==========================================
  Files         149      151       +2
  Lines       21515    21841     +326
==========================================
+ Hits        15714    15950     +236
- Misses       5801     5891      +90
==========================================
```

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
Force-pushed from f5b38da to bb1400e.
Signed-off-by: hust17yixuan <303660421@qq.com>
Force-pushed from bb1400e to c4323f4.
What this PR does / why we need it?
Move torchair-related rotary ops into the torchair directory to make the code clearer. As a next step, we'll remove all torchair-related code outside of the torchair rotary ops.
Does this PR introduce any user-facing change?
No
How was this patch tested?
vLLM version: main
vLLM main: vllm-project/vllm@ab9f2cf