Skip to content

Commit 30df8b6

Browse files
authored
[DeepSeek] Fix some bugs for dsk-v3 (#9874)
* use fused rope * fix import
1 parent 0b26a02 commit 30df8b6

File tree

1 file changed

+13
-27
lines changed
  • paddlenlp/experimental/transformers/deepseek_v2

1 file changed

+13
-27
lines changed

paddlenlp/experimental/transformers/deepseek_v2/modeling.py

Lines changed: 13 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -92,12 +92,10 @@ def __init__(
9292
* attn_factor
9393
)
9494

95-
cos_cache, sin_cache = self._compute_cos_sin_cache()
95+
cache = self._compute_cos_sin_cache()
9696

97-
self.cos_cache: paddle.Tensor
98-
self.register_buffer("cos_cache", cos_cache, persistable=True)
99-
self.sin_cache: paddle.Tensor
100-
self.register_buffer("sin_cache", sin_cache, persistable=True)
97+
self.cos_sin_cache: paddle.Tensor
98+
self.register_buffer("cos_sin_cache", cache, persistable=True)
10199

102100
def _compute_inv_freq(self, scaling_factor: float) -> paddle.Tensor:
103101
pos_freqs = self.base ** (paddle.arange(0, self.rotary_dim, 2, dtype=paddle.float32) / self.rotary_dim)
@@ -116,37 +114,25 @@ def _compute_inv_freq(self, scaling_factor: float) -> paddle.Tensor:
116114
def _compute_cos_sin_cache(self) -> paddle.Tensor:
117115
inv_freq = self._compute_inv_freq(self.scaling_factor)
118116
t = paddle.arange(self.max_position_embeddings * self.scaling_factor, dtype=paddle.float32)
119-
120-
freqs = paddle.outer(t, inv_freq)
121-
emb = paddle.concat((freqs, freqs), axis=-1)
122-
cos = emb.cos() * self.mscale
123-
sin = emb.sin() * self.mscale
124-
125-
return cos.cast(self._dtype), sin.cast(self._dtype)
117+
freqs = paddle.einsum("i,j->ij", t, inv_freq)
118+
cos = freqs.cos() * self.mscale
119+
sin = freqs.sin() * self.mscale
120+
cache = paddle.concat((cos, sin), axis=-1)
121+
return cache.cast(self._dtype)
126122

127123
def forward(
128124
self,
129125
position_ids: paddle.Tensor,
130126
query: paddle.Tensor,
131127
key: paddle.Tensor,
132128
) -> Tuple[paddle.Tensor, paddle.Tensor]:
133-
cos = self.cos_cache[position_ids].unsqueeze(1)
134-
sin = self.sin_cache[position_ids].unsqueeze(1)
135-
136-
def rotate_half(x):
137-
"""Rotates half the hidden axiss of the input."""
138-
x1 = x[..., : x.shape[-1] // 2]
139-
x2 = x[..., x.shape[-1] // 2 :]
140-
return paddle.concat([-x2, x1], axis=-1) # shape is the same as x
141-
142-
s, h, d = query.shape
143-
query = query.reshape([s, h, d // 2, 2]).transpose([0, 1, 3, 2]).reshape([s, h, d])
129+
import os
144130

145-
s, h, d = key.shape
146-
key = key.reshape([s, h, d // 2, 2]).transpose([0, 1, 3, 2]).reshape([s, h, d])
131+
from paddlenlp_ops import fused_rotary_position_encoding
147132

148-
query = (query * cos) + (rotate_half(query) * sin)
149-
key = (key * cos) + (rotate_half(key) * sin)
133+
# In-place operations that update the query and key tensors.
134+
os.environ["stride_in_no_check_dy2st_diff"] = "1"
135+
fused_rotary_position_encoding(query, key, position_ids, self.cos_sin_cache, self.rotary_dim, False)
150136

151137
return query, key
152138

0 commit comments

Comments
 (0)