File tree Expand file tree Collapse file tree 1 file changed +3
-2
lines changed Expand file tree Collapse file tree 1 file changed +3
-2
lines changed Original file line number Diff line number Diff line change @@ -222,14 +222,15 @@ def forward(self, x: Tensor) -> Tensor:
222
222
223
223
224
224
def precompute_freqs_cis(
    seq_len: int,
    n_elem: int,
    base: int = 10000,
    dtype: torch.dtype = torch.bfloat16,
) -> Tensor:
    """Precompute the cos/sin table used for rotary position embeddings.

    Args:
        seq_len: Number of positions (rows) to generate.
        n_elem: Number of elements being rotated; ``n_elem // 2``
            frequencies are produced.
        base: Base of the geometric progression of frequencies.
        dtype: Dtype of the returned cache. Defaults to ``torch.bfloat16``,
            so callers that do not pass it get the same result as before
            the parameter existed.

    Returns:
        Tensor of shape ``(seq_len, n_elem // 2, 2)`` where ``[..., 0]``
        holds ``cos(position * freq)`` and ``[..., 1]`` holds
        ``sin(position * freq)``.
    """
    # Inverse frequencies base ** (-2i / n_elem); the slice caps the count at
    # n_elem // 2, which matters when n_elem is odd.
    exponents = torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem
    freqs = 1.0 / (base ** exponents)

    # Angle for every (position, frequency) pair.
    t = torch.arange(seq_len, device=freqs.device)
    freqs = torch.outer(t, freqs)

    # Unit-magnitude complex numbers e^{i * angle}: real part is cos, imaginary is sin.
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)

    # Store as a real tensor with a trailing (cos, sin) axis, then cast once at the end
    # so the trigonometry itself is done in float32.
    cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
    return cache.to(dtype=dtype)
233
234
234
235
235
236
def apply_rotary_emb (x : Tensor , freqs_cis : Tensor ) -> Tensor :
You can’t perform that action at this time.
0 commit comments