Commit 2dc0d4e

Yarn (#145)

Authored by RaymondLi0 and oleksost
Co-authored-by: oleksost <ostapy2@gmail.com>

1 parent de7b2d8 commit 2dc0d4e

File tree

4 files changed (+100, -8 lines)

fast_llm/data/dataset/config.py

Lines changed: 2 additions & 1 deletion
@@ -8,7 +8,7 @@
 from fast_llm.config import Config, Field, FieldHint, FieldVerboseLevel, check_field, config_class
 from fast_llm.data.dataset.abstract import SamplableDataset, SampledDataset
 from fast_llm.engine.distributed.config import PhaseType
-from fast_llm.utils import Assert
+from fast_llm.utils import Assert, normalize_probabilities

 if typing.TYPE_CHECKING:
     from fast_llm.data.dataset.indexed import ConcatenatedDataset, DatasetSlice, IndexedDataset
@@ -204,6 +204,7 @@ class BlendedDatasetConfig(SampledDatasetConfig):
     )

     def _validate(self) -> None:
+        self.weights = normalize_probabilities(self.weights)
         super()._validate()
         Assert.geq(len(self.datasets), 2)
         Assert.eq(len(self.datasets), len(self.weights))
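
normalize_probabilities itself is not part of this diff; a minimal sketch of the behavior the new _validate line presumably relies on (the real fast_llm.utils helper may differ in details):

def normalize_probabilities(weights: list[float]) -> list[float]:
    # Hypothetical stand-in: scale the weights so they sum to 1.
    total = sum(weights)
    assert total > 0, "weights must sum to a positive value"
    return [w / total for w in weights]

# A BlendedDatasetConfig with weights [2.0, 6.0] would then blend its two
# datasets with probabilities [0.25, 0.75].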

fast_llm/layers/transformer/config.py

Lines changed: 17 additions & 1 deletion
@@ -80,6 +80,7 @@ class RotaryEmbeddingType(str, enum.Enum):
     none = "none"
     default = "default"
     llama3 = "llama3"
+    yarn = "yarn"


 @config_class()
@@ -110,7 +111,22 @@ class RotaryArchitectureConfig(BaseModelArchitectureConfig):
         default=4.0, desc="High frequency factor for llama3-type scaling.", hint=FieldHint.feature
     )
     original_context_length: int = Field(
-        default=8192, desc="Original context length for llama3-type scaling.", hint=FieldHint.feature
+        default=8192, desc="Original context length for llama3/yarn-type scaling.", hint=FieldHint.feature
+    )
+    attention_factor: None | float = Field(
+        default=None,
+        desc="Attention factor for yarn-type scaling.",
+        hint=FieldHint.feature,
+    )
+    beta_fast: float = Field(
+        default=32.0,
+        desc="Beta-fast for yarn-type scaling.",
+        hint=FieldHint.feature,
+    )
+    beta_slow: float = Field(
+        default=1.0,
+        desc="Beta-slow for yarn-type scaling.",
+        hint=FieldHint.feature,
     )

     @property
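
A hedged usage sketch of the new fields. The direct keyword instantiation below is an assumption about the @config_class machinery; real configs may be built through Fast-LLM's config loader instead:

from fast_llm.layers.transformer.config import RotaryArchitectureConfig, RotaryEmbeddingType

rotary = RotaryArchitectureConfig(
    type=RotaryEmbeddingType.yarn,
    scale_factor=8.0,  # assumed existing field; see config.scale_factor in preprocessing.py below
    original_context_length=8192,  # pre-training context window
    attention_factor=None,  # None -> 0.1 * math.log(scale_factor) + 1.0 at preprocessing time
    beta_fast=32.0,
    beta_slow=1.0,
)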

fast_llm/layers/transformer/preprocessing.py

Lines changed: 68 additions & 2 deletions
@@ -37,7 +37,67 @@ def apply_llama3_scaling(config: RotaryConfig, frequencies: torch.Tensor) -> torch.Tensor:
             config.high_frequency_factor - config.low_frequency_factor
         )
         new_frequencies.append((1 - smooth) * frequency / config.scale_factor + smooth * frequency)
-    return torch.tensor(new_frequencies, dtype=frequencies.dtype, device=frequencies.device)
+    return torch.tensor(new_frequencies, dtype=frequencies.dtype, device=frequencies.device), 1.0
+
+
+def apply_yarn_scaling(config: RotaryConfig, frequencies: torch.Tensor, kv_channels: int, sequence_length: int) -> tuple[torch.Tensor, float]:
+    """
+    Yarn scaling:
+    https://github.com/huggingface/transformers/blob/006d9249ec0270ff6c4d3840979d23fe94bdc763/src/transformers/modeling_rope_utils.py#L163
+    [original paper](https://arxiv.org/abs/2309.00071)
+    """
+    base = config.theta
+    partial_rotary_factor = 1.0
+    dim = int(kv_channels * partial_rotary_factor)
+    max_position_embeddings = sequence_length
+    factor = config.scale_factor
+
+    attention_factor = config.attention_factor
+    if attention_factor is None:
+        attention_factor = 0.1 * math.log(factor) + 1.0
+
+    # Compute the inverse frequencies
+    def find_correction_dim(num_rotations, dim, base, max_position_embeddings):
+        """Inverse dimension formula to find the dimension based on the number of rotations"""
+        return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))
+
+    def find_correction_range(low_rot, high_rot, dim, base, max_position_embeddings):
+        """Find dimension range bounds based on rotations"""
+        low = math.floor(find_correction_dim(low_rot, dim, base, max_position_embeddings))
+        high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings))
+        return max(low, 0), min(high, dim - 1)
+
+    def linear_ramp_factor(min, max, dim):
+        if min == max:
+            max += 0.001  # Prevent singularity
+
+        linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
+        ramp_func = torch.clamp(linear_func, 0, 1)
+        return ramp_func
+
+    # Note on variable naming: "interpolation" comes from the original technique, where we interpolate the position IDs
+    # to expand the possible context length. In other words, interpolation = apply scaling factor.
+    # pos_freqs = base ** (torch.arange(0, dim, 2).float().to(frequencies.device) / dim)
+    # inv_freq_extrapolation = 1.0 / pos_freqs
+    # inv_freq_interpolation = 1.0 / (factor * pos_freqs)
+
+    inv_freq_extrapolation = frequencies
+    inv_freq_interpolation = frequencies / factor
+
+    # TODO: max_position_embeddings or original_context_length?
+    # see https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/modeling_deepseek.py#L304
+    low, high = find_correction_range(config.beta_fast, config.beta_slow, dim, base, config.original_context_length)
+
+    # Get n-dimensional rotational scaling corrected for extrapolation
+    inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).float().to(frequencies.device)
+    inv_freq = (
+        inv_freq_interpolation * (1 - inv_freq_extrapolation_factor)
+        + inv_freq_extrapolation * inv_freq_extrapolation_factor
+    )

+    return inv_freq, attention_factor
+

 def get_rotary_frequencies(
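
As a sanity check on the correction-range math in apply_yarn_scaling above, a standalone worked example; the values (dim=128 head size, theta=10000, plus the new fields' defaults) are assumed for illustration, not taken from the commit:

import math

dim, base, original_context_length = 128, 10000.0, 8192

def find_correction_dim(num_rotations):
    return (dim * math.log(original_context_length / (num_rotations * 2 * math.pi))) / (2 * math.log(base))

low = max(math.floor(find_correction_dim(32.0)), 0)       # beta_fast -> 25
high = min(math.ceil(find_correction_dim(1.0)), dim - 1)  # beta_slow -> 50
print(low, high)  # 25 50

# Frequency indices below 25 keep their original frequencies (pure
# extrapolation); indices above 50 are fully divided by the scale factor
# (pure interpolation); the linear ramp blends the band in between.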
@@ -56,13 +116,19 @@ def get_rotary_frequencies(
     frequencies = config.theta ** -torch.arange(0, 1, 2 / kv_channels, device=device, dtype=torch.float64)
     # Apply scaling
     if config.type == RotaryEmbeddingType.llama3:
-        frequencies = apply_llama3_scaling(config, frequencies)
+        frequencies, attention_scaling = apply_llama3_scaling(config, frequencies)
+    elif config.type == RotaryEmbeddingType.yarn:
+        frequencies, attention_scaling = apply_yarn_scaling(config, frequencies, kv_channels, sequence_length)
+    else:
+        attention_scaling = 1.0
     angles = torch.outer(positions, frequencies)
     frequencies = torch.polar(torch.ones_like(angles), angles)[None, :, None, :].to(torch.complex64)
     if not config.complex_format:
         frequencies = convert_rotary_complex_to_real(
             torch.view_as_real(frequencies).flatten(-2), kv_channels, 3
         ).contiguous()
+    # Advanced Rope types like yarn apply a post-processing scaling factor, equivalent to scaling attention.
+    frequencies = frequencies * attention_scaling
     return frequencies
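
The comment above says the post-multiplication is "equivalent to scaling attention"; a short sketch of why, under the assumption that both q and k are rotated with this table: a scalar folded into the table is picked up by each of q and k, so it appears squared in the q @ k logits, matching the YaRN paper's logit temperature:

import math

scale_factor = 8.0  # assumed value, not from the commit
attention_scaling = 0.1 * math.log(scale_factor) + 1.0  # default from apply_yarn_scaling, ~1.208
logit_gain = attention_scaling**2  # factor picked up by q @ k.T, ~1.459
print(attention_scaling, logit_gain)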

fast_llm/models/gpt/conversion.py

Lines changed: 13 additions & 4 deletions
@@ -294,10 +294,13 @@ class RopeScalingParamConverter(ParamConverter):
         "low_freq_factor",
         "high_freq_factor",
         "original_max_position_embeddings",
+        "attention_factor",
+        "beta_fast",
+        "beta_slow",
     )

     def __post_init__(self):
-        Assert.eq(len(self.fast_llm_names), 5)
+        Assert.eq(len(self.fast_llm_names), 8)
         Assert.eq(len(self.export_names), 1)

     def export_params(self, fast_llm_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]:
@@ -306,16 +309,19 @@ def export_params(self, fast_llm_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]:
             return (None,)
         elif rope_type == RotaryEmbeddingType.llama3:
             return ({key: value for key, value in zip(self._HUGGINGFACE_NAMES, ("llama3", *parameters), strict=True)},)
+        elif rope_type == RotaryEmbeddingType.yarn:
+            return ({key: value for key, value in zip(self._HUGGINGFACE_NAMES, ("yarn", *parameters), strict=True)},)
         else:
             raise ValueError(f"Unsupported rotary scaling type: {rope_type}")

     def import_params(self, export_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]:
         (export_value,) = export_values
         if export_value is None or (rope_type := export_value[self._HUGGINGFACE_NAMES[0]]) == "default":
-            return (RotaryEmbeddingType.default,) + (DEFAULT,) * 4
+            return (RotaryEmbeddingType.default,) + (DEFAULT,) * 7
         elif rope_type == RotaryEmbeddingType.llama3:
-            # TODO: Is it safe to assume all values are provided?
-            return ("llama3", *[export_value[key] for key in self._HUGGINGFACE_NAMES[1:]])
+            return ("llama3", *[export_value.get(key, DEFAULT) for key in self._HUGGINGFACE_NAMES[1:]])
+        elif rope_type == RotaryEmbeddingType.yarn:
+            return ("yarn", *[export_value.get(key, DEFAULT) for key in self._HUGGINGFACE_NAMES[1:]])
         else:
             raise ValueError(f"Unsupported rotary scaling type: {rope_type}")

@@ -337,6 +343,9 @@ def _create_config_converters(cls) -> list[ParamConverter]:
                 ("transformer", "rotary", "low_frequency_factor"),
                 ("transformer", "rotary", "high_frequency_factor"),
                 ("transformer", "rotary", "original_context_length"),
+                ("transformer", "rotary", "attention_factor"),
+                ("transformer", "rotary", "beta_fast"),
+                ("transformer", "rotary", "beta_slow"),
             ),
             export_names=(("rope_scaling",),),
         ),
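
For illustration, a hedged sketch of the Hugging Face-side dict this converter reads and writes. The first two key names ("rope_type", "factor") are assumptions, since only the tail of _HUGGINGFACE_NAMES is visible in this diff:

# Illustrative rope_scaling entry for a yarn checkpoint; values are made up.
rope_scaling = {
    "rope_type": "yarn",  # assumed key; not shown in the diff above
    "factor": 8.0,  # assumed key for the scale factor
    "low_freq_factor": 1.0,
    "high_freq_factor": 4.0,
    "original_max_position_embeddings": 8192,
    "attention_factor": None,
    "beta_fast": 32.0,
    "beta_slow": 1.0,
}
# import_params reads these keys back with export_value.get(key, DEFAULT),
# so a checkpoint that omits, say, "beta_slow" falls back to the Fast-LLM default.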
