@@ -37,7 +37,67 @@ def apply_llama3_scaling(config: RotaryConfig, frequencies: torch.Tensor) -> tor
                config.high_frequency_factor - config.low_frequency_factor
            )
            new_frequencies.append((1 - smooth) * frequency / config.scale_factor + smooth * frequency)
-    return torch.tensor(new_frequencies, dtype=frequencies.dtype, device=frequencies.device)
+    # Llama3 scaling only rescales frequencies; return a neutral attention factor of 1.0 so all
+    # scaling types share the same return contract.
+    return torch.tensor(new_frequencies, dtype=frequencies.dtype, device=frequencies.device), 1.0
+
+
+def apply_yarn_scaling(
+    config: RotaryConfig, frequencies: torch.Tensor, kv_channels: int, sequence_length: int
+) -> tuple[torch.Tensor, float]:
+    """
+    YaRN scaling:
+    https://github.com/huggingface/transformers/blob/006d9249ec0270ff6c4d3840979d23fe94bdc763/src/transformers/modeling_rope_utils.py#L163
+    [original paper](https://arxiv.org/abs/2309.00071)
+    Returns the scaled inverse frequencies together with the attention scaling factor.
+    """
+    base = config.theta
+    partial_rotary_factor = 1.0
+    dim = int(kv_channels * partial_rotary_factor)
+    max_position_embeddings = sequence_length
+    factor = config.scale_factor
+
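+    # YaRN also rescales the attention logits (the "temperature" t in the paper); when no explicit
+    # factor is configured, fall back to the paper's recommended default sqrt(1/t) = 0.1 * ln(s) + 1.0,
+    # where s is the context scale factor.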
+    attention_factor = config.attention_factor
+    if attention_factor is None:
+        attention_factor = 0.1 * math.log(factor) + 1.0
+
+    # Compute the inverse frequencies
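+    # Derivation note: rotary dimension pair d has inverse frequency base**(-2d / dim), so it completes
+    # max_position_embeddings / (2 * pi * base**(2d / dim)) full rotations over the context; solving that
+    # expression for d at a target rotation count gives the closed form below.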
+    def find_correction_dim(num_rotations, dim, base, max_position_embeddings):
+        """Inverse dimension formula to find the dimension based on the number of rotations"""
+        return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))
+
+    def find_correction_range(low_rot, high_rot, dim, base, max_position_embeddings):
+        """Find dimension range bounds based on rotations"""
+        low = math.floor(find_correction_dim(low_rot, dim, base, max_position_embeddings))
+        high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings))
+        return max(low, 0), min(high, dim - 1)
+
+    def linear_ramp_factor(min, max, dim):
+        if min == max:
+            max += 0.001  # Prevent singularity
+
+        linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
+        ramp_func = torch.clamp(linear_func, 0, 1)
+        return ramp_func
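+    # The ramp is 0 for dimensions below `min`, 1 above `max`, and linear in between; it becomes the
+    # per-dimension blend weight between interpolated and extrapolated frequencies below.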
+
+
+    # Note on variable naming: "interpolation" comes from the original technique, where we interpolate
+    # the position IDs to expand the possible context length. In other words, interpolation = apply
+    # the scaling factor.
+    # pos_freqs = base ** (torch.arange(0, dim, 2).float().to(frequencies.device) / dim)
+    # inv_freq_extrapolation = 1.0 / pos_freqs
+    # inv_freq_interpolation = 1.0 / (factor * pos_freqs)
+
+    inv_freq_extrapolation = frequencies
+    inv_freq_interpolation = frequencies / factor
+
+    # TODO: max_position_embeddings or original_context_length?
+    # see https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/modeling_deepseek.py#L304
+    low, high = find_correction_range(config.beta_fast, config.beta_slow, dim, base, config.original_context_length)
+
+    # Get n-dimensional rotational scaling corrected for extrapolation
+    inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).float().to(frequencies.device)
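+    # This factor is 1 for high-frequency dimensions (below `low`), which keep their original
+    # frequencies (extrapolation), and 0 for low-frequency dimensions (above `high`), which are
+    # fully interpolated, i.e. divided by the scale factor.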
+    inv_freq = (
+        inv_freq_interpolation * (1 - inv_freq_extrapolation_factor)
+        + inv_freq_extrapolation * inv_freq_extrapolation_factor
+    )
+
+    return inv_freq, attention_factor
+

def get_rotary_frequencies(
@@ -56,13 +116,19 @@ def get_rotary_frequencies(
    frequencies = config.theta ** -torch.arange(0, 1, 2 / kv_channels, device=device, dtype=torch.float64)
    # Apply scaling
    if config.type == RotaryEmbeddingType.llama3:
-        frequencies = apply_llama3_scaling(config, frequencies)
+        frequencies, attention_scaling = apply_llama3_scaling(config, frequencies)
+    elif config.type == RotaryEmbeddingType.yarn:
+        frequencies, attention_scaling = apply_yarn_scaling(config, frequencies, kv_channels, sequence_length)
+    else:
+        attention_scaling = 1.0
    angles = torch.outer(positions, frequencies)
    frequencies = torch.polar(torch.ones_like(angles), angles)[None, :, None, :].to(torch.complex64)
    if not config.complex_format:
        frequencies = convert_rotary_complex_to_real(
            torch.view_as_real(frequencies).flatten(-2), kv_channels, 3
        ).contiguous()
+    # Advanced RoPE types like YaRN apply a post-processing scaling factor, equivalent to scaling attention.
+    frequencies = frequencies * attention_scaling
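+    # Multiplying the rotation table (complex or real format) scales its cos/sin components, matching
+    # how Hugging Face applies `attention_scaling` to cos/sin in modeling_rope_utils.py.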
    return frequencies
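For reference, a minimal standalone sketch of the YaRN path added above, with hypothetical values standing in for the `RotaryConfig` fields (`theta`, `scale_factor`, `beta_fast`, `beta_slow`, `original_context_length`); this is not the repo's API, just the same arithmetic end to end:

```python
import math

import torch

# Hypothetical config values for illustration only.
theta, scale_factor, beta_fast, beta_slow = 10000.0, 4.0, 32.0, 1.0
kv_channels, original_context_length = 128, 8192

dim = kv_channels
# Inverse frequencies theta**(-2d / dim), as in get_rotary_frequencies.
frequencies = theta ** -torch.arange(0, 1, 2 / kv_channels, dtype=torch.float64)

def find_correction_dim(num_rotations: float) -> float:
    # Dimension at which a pair completes `num_rotations` turns over the original context.
    return (dim * math.log(original_context_length / (num_rotations * 2 * math.pi))) / (2 * math.log(theta))

low = max(math.floor(find_correction_dim(beta_fast)), 0)
high = min(math.ceil(find_correction_dim(beta_slow)), dim - 1)

# Blend weight: keep original frequencies below `low`, interpolate above `high`.
ramp = torch.clamp((torch.arange(dim // 2, dtype=torch.float32) - low) / (high - low), 0, 1)
extrapolation_factor = 1 - ramp
inv_freq = (frequencies / scale_factor) * (1 - extrapolation_factor) + frequencies * extrapolation_factor
attention_factor = 0.1 * math.log(scale_factor) + 1.0

# High-frequency dims are unchanged; low-frequency dims are divided by the scale factor.
print(inv_freq[:4], inv_freq[-4:], attention_factor)
```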