@@ -40,7 +40,7 @@
 
 
 class AscendQwen3_VisionPatchEmbed(Qwen3_VisionPatchEmbed):
-
+
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = x.matmul(
             self.proj.weight.data.view(self.hidden_size, -1).transpose(0, 1))
@@ -71,14 +71,16 @@ def __init__(
             use_data_parallel=use_data_parallel)
 
     def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor,
-                cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
+                cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
         x = x + self.attn(
             self.norm1(x), cu_seqlens=cu_seqlens, cos=cos, sin=sin)
 
         x = x + self.mlp(self.norm2(x))
         return x
 
+
 class AscendQwen3_VisionTransformer(Qwen3_VisionTransformer):
+
     def __init__(
         self,
         vision_config: Qwen3VLVisionConfig,
@@ -198,4 +200,4 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             norm_eps=getattr(config, "rms_norm_eps", 1e-6),
             quant_config=self._maybe_ignore_quant_config(quant_config),
             prefix=maybe_prefix(prefix, "visual"),
-            use_data_parallel=self.use_data_parallel)
+            use_data_parallel=self.use_data_parallel)
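
Context on the first hunk: `Qwen3_VisionPatchEmbed.proj` is a `Conv3d` whose kernel size equals its stride, so each flattened patch is covered by exactly one kernel window and the convolution reduces to a single matrix multiply against the reshaped weight, which is what the overridden `forward` computes directly. The sketch below is a minimal, self-contained illustration of that equivalence, not the actual vllm-ascend code; the shape values are made-up stand-ins for numbers that really come from `Qwen3VLVisionConfig`.

    import torch
    import torch.nn as nn

    # Illustrative stand-in shapes; the real values come from Qwen3VLVisionConfig.
    in_channels, temporal_patch_size, patch_size, hidden_size = 3, 2, 14, 64
    num_patches = 4

    # Mirrors the upstream patch-embed projection: kernel size == stride,
    # so windows never overlap and each patch maps to one output vector.
    proj = nn.Conv3d(in_channels,
                     hidden_size,
                     kernel_size=(temporal_patch_size, patch_size, patch_size),
                     stride=(temporal_patch_size, patch_size, patch_size),
                     bias=False)

    # Patches arrive pre-flattened, one row per patch.
    x = torch.randn(
        num_patches, in_channels * temporal_patch_size * patch_size * patch_size)

    # Reference path: reshape to 5D and run the convolution.
    ref = proj(
        x.view(-1, in_channels, temporal_patch_size, patch_size,
               patch_size)).view(-1, hidden_size)

    # Matmul path, as in the Ascend override: the conv weight collapses to a
    # (hidden_size, in_features) matrix because each output channel is one
    # dot product over the full kernel window.
    out = x.matmul(proj.weight.data.view(hidden_size, -1).transpose(0, 1))

    torch.testing.assert_close(ref, out, rtol=1e-4, atol=1e-4)

Expressing the projection as a plain matmul avoids dispatching a 3D convolution on the NPU; since the windows never overlap, both paths compute the same values up to floating-point tolerance.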