Commit 0306e36

reset attention_factor to old behaviour
1 parent e0a7c80 commit 0306e36

2 files changed: +7 −4 lines changed

fast_llm/layers/transformer/rotary/config.py

Lines changed: 3 additions & 3 deletions

@@ -127,9 +127,9 @@ class YarnRotaryConfig(DefaultRotaryConfig):
     original_context_length: int = Field(default=8192, hint=FieldHint.feature)
 
     def _validate(self) -> None:
-        if self.attention_factor is None:
-            # with self._set_implicit_default():
-            self.attention_factor = 0.1 * math.log(self.scale_factor) + 1.0
+        # if self.attention_factor is None:
+        #     # with self._set_implicit_default():
+        #     self.attention_factor = 0.1 * math.log(self.scale_factor) + 1.0
         super()._validate()
 
     def _get_configurable_class(self) -> "type[YarnRotary]":

fast_llm/layers/transformer/rotary/rotary.py

Lines changed: 4 additions & 1 deletion

@@ -181,7 +181,10 @@ class YarnRotary[ConfigType: YarnRotaryConfig](DefaultRotary[YarnRotaryConfig]):
     """
 
     def _get_frequencies(self, sequence_length: int, kv_channels: int, device="cuda") -> torch.Tensor:
-        return super()._get_frequencies(sequence_length, kv_channels, device) * self._config.attention_factor
+        attention_factor = self._config.attention_factor
+        if attention_factor is None:
+            attention_factor = 0.1 * math.log(self._config.scale_factor) + 1.0
+        return super()._get_frequencies(sequence_length, kv_channels, device) * attention_factor
 
     def _get_angle_scales(self, kv_channels: int, device="cuda") -> torch.Tensor:
         scales = super()._get_angle_scales(kv_channels, device)
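
The net effect of the commit is that the YaRN attention-factor default (0.1 * ln(scale_factor) + 1.0) is no longer written back into the config during validation; it is resolved lazily in _get_frequencies whenever attention_factor is left unset. A minimal standalone sketch of that fallback, assuming a hypothetical scale_factor of 16.0 and a helper name (resolve_attention_factor) that is not part of the repository:

import math

def resolve_attention_factor(attention_factor: float | None, scale_factor: float) -> float:
    # YaRN default when no explicit factor is configured: 0.1 * ln(scale_factor) + 1.0.
    if attention_factor is None:
        return 0.1 * math.log(scale_factor) + 1.0
    return attention_factor

# Example (hypothetical values): with scale_factor=16.0 and no override,
# the factor is 0.1 * ln(16) + 1.0 ≈ 1.277; an explicit value passes through unchanged.
print(resolve_attention_factor(None, 16.0))
print(resolve_attention_factor(1.0, 16.0))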
