 from paddlenlp.transformers.linear_utils import Linear
 from paddlenlp.transformers.model_outputs import BaseModelOutputWithPast, ModelOutput
 from paddlenlp.transformers.model_utils import PretrainedModel
+from paddlenlp.utils.tools import get_env_device

 from paddlemix.models.flash_attn_utils import (
     create_attention_module,
…
 from .bert_padding import index_first_axis, pad_input, unpad_input
 from .configuration_qwen2_vl import Qwen2VLConfig, Qwen2VLVisionConfig

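+# fused_rotary_position_embedding is not shipped by every Paddle build; when it is missing,
+# apply_rotary_pos_emb_vision falls back to the eager rotate_half path below.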
+try:
+    from paddle.incubate.nn.functional import fused_rotary_position_embedding
+except ImportError:
+    fused_rotary_position_embedding = None
+
 logger = logging.get_logger(__name__)

 flash_attn_func, flash_attn_varlen_func = has_flash_attn_func()
@@ -407,7 +413,12 @@ def apply_rotary_pos_emb_vision(tensor: paddle.Tensor, freqs: paddle.Tensor) ->
     sin = freqs.sin()
     cos = cos.unsqueeze(1).tile(repeat_times=[1, 1, 2]).unsqueeze(0).astype(dtype="float32")
     sin = sin.unsqueeze(1).tile(repeat_times=[1, 1, 2]).unsqueeze(0).astype(dtype="float32")
-    output = tensor * cos + rotate_half(tensor) * sin
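+    # On XPU, prefer the fused rotary-embedding kernel when it is available; otherwise keep the eager cos/sin rotation.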
+    if get_env_device() == "xpu" and fused_rotary_position_embedding is not None:
+        output, _, _ = fused_rotary_position_embedding(
+            tensor, sin=sin, cos=cos, use_neox_rotary_style=False
+        )
+    else:
+        output = tensor * cos + rotate_half(tensor) * sin
     output = paddle.cast(output, orig_dtype)
     return output

@@ -463,6 +474,12 @@ def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> N
             nn.GELU(),
             nn.Linear(self.hidden_size, dim),
         )
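+        # On XPU, rebuild the merger MLP with the Linear from paddlenlp.transformers.linear_utils instead of nn.Linear.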
+        if get_env_device() == "xpu":
+            self.mlp = nn.Sequential(
+                Linear(self.hidden_size, self.hidden_size),
+                nn.GELU(),
+                Linear(self.hidden_size, dim),
+            )

     def forward(self, x: paddle.Tensor) -> paddle.Tensor:
         x = self.mlp(self.ln_q(x).reshape([-1, self.hidden_size]))
@@ -475,6 +492,9 @@ def __init__(self, dim: int, hidden_dim: int, hidden_act: str) -> None:
         self.fc1 = nn.Linear(dim, hidden_dim)
         self.act = ACT2FN[hidden_act]
         self.fc2 = nn.Linear(hidden_dim, dim)
+        if get_env_device() == "xpu":
+            self.fc1 = Linear(dim, hidden_dim)
+            self.fc2 = Linear(hidden_dim, dim)

     def forward(self, x) -> paddle.Tensor:
         return self.fc2(self.act(self.fc1(x)))
@@ -486,6 +506,9 @@ def __init__(self, dim: int, num_heads: int = 16) -> None:
         self.num_heads = num_heads
         self.qkv = nn.Linear(dim, dim * 3, bias_attr=True)
         self.proj = nn.Linear(dim, dim)
+        if get_env_device() == "xpu":
+            self.qkv = Linear(dim, dim * 3, bias_attr=True)
+            self.proj = Linear(dim, dim)
         self.head_dim = dim // num_heads  # must added

     def forward(
@@ -525,6 +548,9 @@ def __init__(self, dim: int, num_heads: int = 16) -> None:
         self.num_heads = num_heads
         self.qkv = nn.Linear(dim, dim * 3, bias_attr=True)
         self.proj = nn.Linear(dim, dim)
+        if get_env_device() == "xpu":
+            self.qkv = Linear(dim, dim * 3, bias_attr=True)
+            self.proj = Linear(dim, dim)
         self.head_dim = dim // num_heads  # must added

     def forward(
@@ -657,6 +683,15 @@ def __init__(self, config: Qwen2VLConfig, hidden_size, eps=1e-6):
         self.variance_epsilon = eps

     def forward(self, hidden_states):
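+        # On XPU, delegate to the fused RMSNorm kernel from paddle_xpu_nn; raises NotImplementedError if paddle_xpu is not installed.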
+        if get_env_device() == "xpu":
+            try:
+                import paddle_xpu_nn  # noqa: F821
+
+                return paddle_xpu_nn.xpu_rms_norm(hidden_states, self.weight, self.variance_epsilon)[0]
+            except ImportError:
+                raise NotImplementedError(
+                    "Implementation of fused_rms_norm is not available on xpu. Please install paddle_xpu to use this feature"
+                )
         if paddle.in_dynamic_mode():
             with paddle.amp.auto_cast(False):
                 variance = hidden_states.astype("float32").pow(2).mean(-1, keepdim=True)
@@ -1193,7 +1228,7 @@ class Qwen2VLPreTrainedModel(PretrainedModel):

     def _init_weights(self, layer):
         std = 0.2
-        if isinstance(layer, (nn.Linear, nn.Conv3D)):
+        if isinstance(layer, (nn.Linear, nn.Conv3D, Linear)):
             nn.initializer.Normal(mean=0.0, std=std)(layer.weight)
             if layer.bias is not None:
                 nn.initializer.Constant(0.0)(layer.bias)
@@ -1558,6 +1593,9 @@ def __init__(self, config, embedding_weights=None, transpose_y=False):
                 shape=[config.hidden_size, vocab_size],
                 dtype=paddle.get_default_dtype(),
             )
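+        # On XPU, keep a paddle_xpu parallel_matmul helper for the logits projection performed in forward().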
+        if get_env_device() == "xpu":
+            import paddle_xpu.layers.nn.linear as xpu_linear
+            self.xpu_parallel_matmul = xpu_linear.parallel_matmul()

         # Must set distributed attr for Tensor Parallel !
         self.weight.is_distributed = True if (vocab_size != config.vocab_size) else False
@@ -1573,9 +1611,14 @@ def forward(self, hidden_states, tensor_parallel_output=None):
         if self.weight.dtype != hidden_states.dtype:
             hidden_states = paddle.cast(hidden_states, self.weight.dtype)

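+        # Route the logits matmul through the XPU helper when running on XPU; otherwise fall back to the generic parallel_matmul.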
-        logits = parallel_matmul(
-            hidden_states, self.weight, transpose_y=self.transpose_y, tensor_parallel_output=tensor_parallel_output
-        )
+        if get_env_device() == "xpu":
+            logits = self.xpu_parallel_matmul.forward(
+                hidden_states, self.weight, transpose_y=self.transpose_y, tensor_parallel_output=tensor_parallel_output
+            )
+        else:
+            logits = parallel_matmul(
+                hidden_states, self.weight, transpose_y=self.transpose_y, tensor_parallel_output=tensor_parallel_output
+            )
         return logits
