@@ -31,11 +31,16 @@ impl Qwen3Attention {
3131 }
3232
3333 let num_attention_heads = config. num_attention_heads ;
34- let attention_head_size = config. hidden_size / config. num_attention_heads ;
34+ let attention_head_size = config
35+ . head_dim
36+ . unwrap_or ( config. hidden_size / config. num_attention_heads ) ;
3537 let num_key_value_heads = config. num_key_value_heads ;
3638 let hidden_size = config. hidden_size ;
3739
38- let query_weight = vb. pp ( "q_proj" ) . get ( ( hidden_size, hidden_size) , "weight" ) ?;
40+ let query_weight = vb. pp ( "q_proj" ) . get (
41+ ( num_attention_heads * attention_head_size, hidden_size) ,
42+ "weight" ,
43+ ) ?;
3944 let query_bias = vb. pp ( "q_proj" ) . get ( hidden_size, "bias" ) ?;
4045 let q_proj = Linear :: new ( query_weight, Some ( query_bias) , None ) ;
4146
@@ -57,8 +62,10 @@ impl Qwen3Attention {
5762 . get ( num_key_value_heads * attention_head_size, "bias" ) ?;
5863 let v_proj = Linear :: new ( value_weight, Some ( value_bias) , None ) ;
5964
60- let o_proj_weight = vb. pp ( "o_proj" ) . get ( ( hidden_size, hidden_size) , "weight" ) ?;
61-
65+ let o_proj_weight = vb. pp ( "o_proj" ) . get (
66+ ( num_attention_heads * attention_head_size, hidden_size) ,
67+ "weight" ,
68+ ) ?;
6269 let o_proj = Linear :: new ( o_proj_weight, None , None ) ;
6370
6471 let q_norm = RMSNorm :: load ( vb. pp ( "q_norm" ) , attention_head_size, config. rms_norm_eps ) ?;
0 commit comments