@@ -92,12 +92,10 @@ def __init__(
92
92
* attn_factor
93
93
)
94
94
95
- cos_cache , sin_cache = self ._compute_cos_sin_cache ()
95
+ cache = self ._compute_cos_sin_cache ()
96
96
97
- self .cos_cache : paddle .Tensor
98
- self .register_buffer ("cos_cache" , cos_cache , persistable = True )
99
- self .sin_cache : paddle .Tensor
100
- self .register_buffer ("sin_cache" , sin_cache , persistable = True )
97
+ self .cos_sin_cache : paddle .Tensor
98
+ self .register_buffer ("cos_sin_cache" , cache , persistable = True )
101
99
102
100
def _compute_inv_freq (self , scaling_factor : float ) -> paddle .Tensor :
103
101
pos_freqs = self .base ** (paddle .arange (0 , self .rotary_dim , 2 , dtype = paddle .float32 ) / self .rotary_dim )
@@ -116,37 +114,25 @@ def _compute_inv_freq(self, scaling_factor: float) -> paddle.Tensor:
116
114
def _compute_cos_sin_cache(self) -> paddle.Tensor:
    """Build the fused cosine/sine lookup table used by ``forward``.

    Returns:
        A single 2-D tensor cast to ``self._dtype``. Row ``p`` holds the
        ``mscale``-scaled cosines of position ``p``'s rotation angles
        followed, along the last axis, by the matching sines.
    """
    inv_freq = self._compute_inv_freq(self.scaling_factor)
    # One entry per position, covering max_position_embeddings * scaling_factor.
    t = paddle.arange(self.max_position_embeddings * self.scaling_factor, dtype=paddle.float32)

    # Outer product: freqs[i, j] = t[i] * inv_freq[j].
    freqs = paddle.einsum("i,j->ij", t, inv_freq)
    cos = freqs.cos() * self.mscale
    sin = freqs.sin() * self.mscale

    # Cosines and sines share one buffer, laid out as [cos | sin] on the last axis.
    cache = paddle.concat((cos, sin), axis=-1)
    return cache.cast(self._dtype)
126
122
127
123
def forward(
    self,
    position_ids: paddle.Tensor,
    query: paddle.Tensor,
    key: paddle.Tensor,
) -> Tuple[paddle.Tensor, paddle.Tensor]:
    """Apply rotary position encoding to ``query`` and ``key``.

    The rotation is delegated to the ``fused_rotary_position_encoding``
    custom op, which is fed ``self.cos_sin_cache`` and ``self.rotary_dim``.

    Args:
        position_ids: Positions passed to the fused op.
        query: Query tensor; passed to the fused op and returned.
        key: Key tensor; passed to the fused op and returned.

    Returns:
        The ``(query, key)`` pair, i.e. the same tensor objects that were
        passed in.
    """
    # Function-scope imports, kept local as in the source.
    # NOTE(review): presumably so this module can be imported without the
    # custom-op package installed - confirm.
    import os

    from paddlenlp_ops import fused_rotary_position_encoding

    # In-place operations that update the query and key tensors.
    # NOTE(review): the flag name suggests it relaxes a stride check in
    # dynamic-to-static conversion - confirm against Paddle's flag docs.
    os.environ["stride_in_no_check_dy2st_diff"] = "1"
    # NOTE(review): the trailing positional False is presumably the
    # neox-style switch - confirm against the op's signature.
    fused_rotary_position_encoding(query, key, position_ids, self.cos_sin_cache, self.rotary_dim, False)

    return query, key
152
138
0 commit comments