Skip to content

Commit c142ec1

Browse files
committed
Add dtype argument to precompute_freqs_cis
1 parent c955dac commit c142ec1

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

model.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -222,14 +222,15 @@ def forward(self, x: Tensor) -> Tensor:
222222

223223

224224
def precompute_freqs_cis(
225-
seq_len: int, n_elem: int, base: int = 10000
225+
seq_len: int, n_elem: int, base: int = 10000,
226+
dtype: torch.dtype = torch.bfloat16
226227
) -> Tensor:
227228
freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem))
228229
t = torch.arange(seq_len, device=freqs.device)
229230
freqs = torch.outer(t, freqs)
230231
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
231232
cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
232-
return cache.to(dtype=torch.bfloat16)
233+
return cache.to(dtype=dtype)
233234

234235

235236
def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:

0 commit comments

Comments
 (0)