@@ -1,8 +1,9 @@
 use crate::flash_attn::flash_attn_varlen;
-use crate::layers::{HiddenAct, LayerNorm, Linear};
-use crate::models::{GTEConfig, Model, NTKScaling, PositionEmbeddingType, RopeScaling};
+use crate::layers::{get_cos_sin, get_inv_freqs, LayerNorm, Linear};
+use crate::models::{GTEClassificationHead, GTEConfig, Model, PositionEmbeddingType, GTEMLP};
 use candle::{DType, Device, IndexOp, Result, Tensor};
 use candle_nn::{Embedding, Module, VarBuilder};
+use candle_rotary::apply_rotary_inplace;
 use text_embeddings_backend_core::{Batch, ModelType, Pool};

 struct GTEAttention {
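The import changes summarize the whole commit: the RoPE helpers (`get_inv_freqs`, `get_cos_sin`) now come from the shared `crate::layers` module, `GTEMLP` and `GTEClassificationHead` move into `crate::models`, and `apply_rotary_inplace` is imported directly from `candle_rotary`. The hunks below delete the now-duplicated local definitions and call the shared versions instead; a sketch of the NTK scaling folded into `get_inv_freqs` follows the last hunk.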
@@ -72,7 +73,7 @@ impl GTEAttention {
         let k = qkv.narrow(1, self.num_attention_heads, self.num_attention_heads)?;
         let v = qkv.narrow(1, self.num_attention_heads * 2, self.num_attention_heads)?;

-        candle_rotary::apply_rotary_inplace(&q, &k, &cos, &sin, true)?;
+        apply_rotary_inplace(&q, &k, &cos, &sin, true)?;

         let attention = flash_attn_varlen(
             &q,
@@ -93,60 +94,7 @@ impl GTEAttention {
     }
 }

-struct GTEMLP {
-    up_gate_proj: Linear,
-    down_proj: Linear,
-
-    act: HiddenAct,
-    intermediate_size: usize,
-
-    span: tracing::Span,
-}
-
-impl GTEMLP {
-    pub fn load(vb: VarBuilder, config: &GTEConfig) -> Result<Self> {
-        let intermediate_size = config.intermediate_size;
-
-        let up_gate_proj_weight = vb
-            .pp("up_gate_proj")
-            .get((intermediate_size * 2, config.hidden_size), "weight")?;
-
-        let up_gate_proj = Linear::new(up_gate_proj_weight, None, None);
-
-        let down_proj_weight = vb
-            .pp("down_proj")
-            .get((config.hidden_size, intermediate_size), "weight")?;
-        let down_proj_bias = vb.pp("down_proj").get(config.hidden_size, "bias")?;
-        let down_proj = Linear::new(down_proj_weight, Some(down_proj_bias), None);
-
-        Ok(Self {
-            up_gate_proj,
-            down_proj,
-            intermediate_size,
-            act: config.hidden_act.clone(),
-            span: tracing::span!(tracing::Level::TRACE, "mlp"),
-        })
-    }
-
-    pub fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
-        let _enter = self.span.enter();
-
-        let up_gate_states = self.up_gate_proj.forward(hidden_states)?;
-        let up_states = up_gate_states.narrow(1, 0, self.intermediate_size)?;
-        let gate_states =
-            up_gate_states.narrow(1, self.intermediate_size, self.intermediate_size)?;
-
-        let gate_states = match self.act {
-            HiddenAct::Gelu => gate_states.gelu(),
-            HiddenAct::Relu => gate_states.relu(),
-            HiddenAct::Swiglu => gate_states.silu(),
-        }?;
-        let r = self.down_proj.forward(&(gate_states * up_states)?);
-        r
-    }
-}
-
-struct GTELayer {
+pub struct GTELayer {
     attention: GTEAttention,
     mlp: GTEMLP,
     attention_layer_norm: LayerNorm,
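The GTEMLP deleted above is not gone: per the new import, it now lives in `crate::models` and is shared across GTE implementations. For reference, a minimal, dependency-free sketch of the gated-MLP pattern the removed block implemented (the function name and slice-based types here are illustrative, not the crate's API):

// Illustrative sketch: the fused up_gate projection output is split in half,
// the activation gates the second half (SiLU shown, matching the Swiglu arm),
// and the elementwise product is what gets projected back down.
fn gated_mlp_combine(up_gate: &[f32], intermediate_size: usize) -> Vec<f32> {
    let (up, gate) = up_gate.split_at(intermediate_size);
    up.iter()
        .zip(gate)
        .map(|(u, g)| u * (g / (1.0 + (-g).exp()))) // up * SiLU(gate)
        .collect()
}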
@@ -198,58 +146,6 @@ impl GTELayer {
     }
 }

-pub struct GTEClassificationHead {
-    pooler: Option<Linear>,
-    classifier: Linear,
-    span: tracing::Span,
-}
-
-impl GTEClassificationHead {
-    #[allow(dead_code)]
-    pub(crate) fn load(vb: VarBuilder, config: &GTEConfig) -> Result<Self> {
-        let n_classes = match &config.id2label {
-            None => candle::bail!("`id2label` must be set for classifier models"),
-            Some(id2label) => id2label.len(),
-        };
-
-        let pooler = if let Ok(pooler_weight) = vb
-            .pp("pooler.dense")
-            .get((config.hidden_size, config.hidden_size), "weight")
-        {
-            let pooler_bias = vb.pp("pooler.dense").get(config.hidden_size, "bias")?;
-            Some(Linear::new(pooler_weight, Some(pooler_bias), None))
-        } else {
-            None
-        };
-
-        let classifier_weight = vb
-            .pp("classifier")
-            .get((n_classes, config.hidden_size), "weight")?;
-        let classifier_bias = vb.pp("classifier").get(n_classes, "bias")?;
-        let classifier = Linear::new(classifier_weight, Some(classifier_bias), None);
-
-        Ok(Self {
-            classifier,
-            pooler,
-            span: tracing::span!(tracing::Level::TRACE, "classifier"),
-        })
-    }
-
-    pub(crate) fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
-        let _enter = self.span.enter();
-
-        let mut hidden_states = hidden_states.unsqueeze(1)?;
-        if let Some(pooler) = self.pooler.as_ref() {
-            hidden_states = pooler.forward(&hidden_states)?;
-            hidden_states = hidden_states.tanh()?;
-        }
-
-        let hidden_states = self.classifier.forward(&hidden_states)?;
-        let hidden_states = hidden_states.squeeze(1)?;
-        Ok(hidden_states)
-    }
-}
-
 pub struct FlashGTEModel {
     word_embeddings: Embedding,
     token_type_embeddings: Option<Embedding>,
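Likewise, `GTEClassificationHead` moves to the shared `crate::models` module (see the updated import at the top of the file). Its behavior as removed here is an optional pooler (linear projection followed by tanh) and then a linear classifier over the hidden state; the shared version is presumably identical, though this diff only shows the deletion.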
@@ -322,24 +218,19 @@ impl FlashGTEModel {
             config.layer_norm_eps,
         )?;

-        let inv_freqs = if let Some(RopeScaling::Ntk(NTKScaling { factor })) = config.rope_scaling {
-            let inv_freqs = candle_rotary::inv_freqs(
-                layers[0].attention.attention_head_size,
-                config.rope_theta * factor,
-                vb.device(),
-            )?;
-            let s = factor.powf(2.0 / layers[0].attention.attention_head_size as f32) as f64;
-            inv_freqs / s
-        } else {
-            candle_rotary::inv_freqs(
-                layers[0].attention.attention_head_size,
-                config.rope_theta,
-                vb.device(),
-            )
-        }?;
-
-        let (cos_cache, sin_cache) =
-            candle_rotary::cos_sin(config.max_position_embeddings, &inv_freqs, vb.dtype())?;
+        let inv_freqs = get_inv_freqs(
+            layers[0].attention.attention_head_size,
+            config.rope_theta,
+            vb.device(),
+            config.rope_scaling.as_ref(),
+        )?;
+
+        let (cos_cache, sin_cache) = get_cos_sin(
+            config.max_position_embeddings,
+            &inv_freqs,
+            vb.dtype(),
+            false,
+        )?;

         Ok(Self {
             word_embeddings,
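This last hunk collapses the inline NTK-scaling branch into the shared `get_inv_freqs` helper, which now takes `config.rope_scaling` directly (note also that `get_cos_sin` gains a trailing boolean flag that `candle_rotary::cos_sin` did not have). Based on the removed branch, here is a self-contained sketch of the math the helper presumably performs for the NTK case; the function name and signature are hypothetical:

// Hypothetical stand-in for the NTK branch now inside get_inv_freqs:
// scale the RoPE base theta by `factor`, then renormalize every inverse
// frequency by factor^(2 / head_dim), exactly as the removed code did.
fn ntk_inv_freqs(head_dim: usize, theta: f32, factor: f32) -> Vec<f32> {
    let scaled_theta = theta * factor;
    let s = factor.powf(2.0 / head_dim as f32);
    (0..head_dim / 2)
        .map(|i| 1.0 / scaled_theta.powf(2.0 * i as f32 / head_dim as f32) / s)
        .collect()
}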