#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
- # Adapted from vllm/model_executor/models/qwen2_5_vl.py
# Copyright 2023 The vLLM team.
#
# This file is a part of the vllm-ascend project.
from einops import rearrange
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
    Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig)
+
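+ # Qwen3-VL config classes may be missing from older transformers releases;
+ # guard the import so this module still loads without them.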
+ try:
+     from transformers.models.qwen3_vl.configuration_qwen3_vl import \
+         Qwen3VLConfig
+     from transformers.models.qwen3_vl_moe.configuration_qwen3_vl_moe import \
+         Qwen3VLMoeConfig
+ except ImportError:
+     pass
from vllm.config import VllmConfig
from vllm.distributed import parallel_state
from vllm.distributed import utils as dist_utils
- from vllm.model_executor.layers.activation import get_act_and_mul_fn
+ from vllm.model_executor.layers.activation import (_ACTIVATION_REGISTRY,
+                                                     get_act_and_mul_fn)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.models.qwen2_5_vl import (
    Qwen2_5_VisionAttention, Qwen2_5_VisionBlock, Qwen2_5_VisionPatchEmbed,
    Qwen2_5_VisionTransformer, Qwen2_5_VLDummyInputsBuilder,
    Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLMultiModalProcessor,
    Qwen2_5_VLProcessingInfo)
- from vllm.model_executor.models.utils import maybe_prefix
+
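+ # The Qwen3-VL model classes may not be present in older vLLM releases; fall
+ # back to bare `object` placeholders so the Ascend subclasses below can still
+ # be defined.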
+ try:
+     from vllm.model_executor.models.qwen3_vl import (
+         Qwen3_VisionBlock, Qwen3_VisionPatchEmbed, Qwen3_VisionTransformer,
+         Qwen3VLDummyInputsBuilder, Qwen3VLForConditionalGeneration,
+         Qwen3VLMultiModalProcessor, Qwen3VLProcessingInfo)
+     from vllm.model_executor.models.qwen3_vl_moe import (
+         Qwen3VLMoeForConditionalGeneration, Qwen3VLMoeProcessingInfo)
+ except ImportError:
+     Qwen3_VisionBlock = object
+     Qwen3_VisionPatchEmbed = object
+     Qwen3_VisionTransformer = object
+     Qwen3VLDummyInputsBuilder = object
+     Qwen3VLForConditionalGeneration = object
+     Qwen3VLMultiModalProcessor = object
+     Qwen3VLProcessingInfo = object
+     Qwen3VLMoeForConditionalGeneration = object
+     Qwen3VLMoeProcessingInfo = object
+ from vllm.model_executor.models.utils import WeightsMapper, maybe_prefix
from vllm.multimodal import MULTIMODAL_REGISTRY

from vllm_ascend.models.qwen2_5_vl import AscendQwen2_5_VisionRotaryEmbedding
@@ -112,16 +138,14 @@ def forward(

class AscendQwen2_5_VisionBlock_Without_Padding(Qwen2_5_VisionBlock):

-     def __init__(
-         self,
-         dim: int,
-         num_heads: int,
-         mlp_hidden_dim: int,
-         act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
-         norm_layer: Optional[Callable[[int], nn.Module]] = None,
-         quant_config: Optional[QuantizationConfig] = None,
-         prefix: str = "",
-     ) -> None:
+     def __init__(self,
+                  dim: int,
+                  num_heads: int,
+                  mlp_hidden_dim: int,
+                  act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
+                  norm_layer: Optional[Callable[[int], nn.Module]] = None,
+                  quant_config: Optional[QuantizationConfig] = None,
+                  prefix: str = "") -> None:
        super().__init__(dim, num_heads, mlp_hidden_dim, act_fn, norm_layer,
                         quant_config, prefix)
        self.attn = AscendQwen2_5_VisionAttention_Without_Padding(
@@ -321,6 +345,133 @@ def forward(
        return x


+ class AscendQwen3_VisionPatchEmbed(Qwen3_VisionPatchEmbed):
+
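+     # Compute the patch projection as a plain matmul against the flattened
+     # `proj` weight (plus bias) rather than calling `self.proj` directly.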
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = x.matmul(
+             self.proj.weight.data.view(self.hidden_size, -1).transpose(0, 1))
+         x = x + self.proj.bias
+         return x
+
+
+ class AscendQwen3_VisionBlock(Qwen3_VisionBlock):
+
+     def __init__(
+         self,
+         dim: int,
+         num_heads: int,
+         mlp_hidden_dim: int,
+         act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
+         norm_layer: Optional[Callable[[int], nn.Module]] = None,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+         use_data_parallel: bool = False,
+     ) -> None:
+         super().__init__(dim, num_heads, mlp_hidden_dim, act_fn, norm_layer,
+                          quant_config, prefix, use_data_parallel)
+         self.attn = AscendQwen2_5_VisionAttention_Without_Padding(
+             embed_dim=dim,
+             num_heads=num_heads,
+             projection_size=dim,
+             quant_config=quant_config,
+             prefix=f"{prefix}.attn")
+
+     def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor,
+                 cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
+         x = x + self.attn(
+             self.norm1(x), cu_seqlens=cu_seqlens, cos=cos, sin=sin)
+
+         x = x + self.mlp(self.norm2(x))
+         return x
+
+
+ class AscendQwen3_VisionTransformer(Qwen3_VisionTransformer):
+
+     def __init__(
+         self,
+         vision_config,
+         norm_eps: float = 1e-6,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+         use_data_parallel: bool = False,
+     ) -> None:
+         super().__init__(vision_config, norm_eps, quant_config, prefix,
+                          use_data_parallel)
+         norm_layer = partial(nn.LayerNorm, eps=norm_eps)
+         self.patch_embed = AscendQwen3_VisionPatchEmbed(
+             patch_size=self.patch_size,
+             temporal_patch_size=self.temporal_patch_size,
+             in_channels=vision_config.in_channels,
+             hidden_size=self.hidden_size,
+         )
+         self.blocks = nn.ModuleList([
+             AscendQwen3_VisionBlock(
+                 dim=self.hidden_size,
+                 num_heads=self.num_heads,
+                 mlp_hidden_dim=vision_config.intermediate_size,
+                 act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act],
+                 norm_layer=norm_layer,
+                 quant_config=quant_config,
+                 prefix=f"{prefix}.blocks.{layer_idx}")
+             for layer_idx in range(vision_config.depth)
+         ])
+         self.hidden_size_per_attention_head = dist_utils.divide(
+             self.hidden_size, self.num_heads)
+
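+     # Duplicate the half-dimension rotary table so cos/sin span the full head
+     # dim, then reshape to [1, seq_len, 1, head_dim] for the attention blocks.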
+     def cal_cos_sin(self, rotary_pos_emb):
+         cos = rotary_pos_emb.cos()  # [seqlen, rotary_dim / 2]
+         sin = rotary_pos_emb.sin()
+         cos_new = torch.cat((cos, cos), dim=-1)
+         sin_new = torch.cat((sin, sin), dim=-1)
+         cos_new = cos_new.reshape(1, -1, 1,
+                                   self.hidden_size_per_attention_head)
+         sin_new = sin_new.reshape(1, -1, 1,
+                                   self.hidden_size_per_attention_head)
+         return cos_new, sin_new
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         grid_thw: list[list[int]],
+     ) -> torch.Tensor:
+         hidden_states = x.to(device=self.device, dtype=self.dtype)
+         hidden_states = self.patch_embed(hidden_states)
+
+         pos_embeds = self.fast_pos_embed_interpolate(grid_thw)
+         hidden_states = hidden_states + pos_embeds
+         rotary_pos_emb = self.rot_pos_emb(grid_thw)
+         grid_thw_tensor = torch.tensor(grid_thw,
+                                        device=self.device,
+                                        dtype=torch.int32)
+         cu_seqlens = torch.repeat_interleave(
+             grid_thw_tensor[:, 1] * grid_thw_tensor[:, 2],
+             grid_thw_tensor[:, 0]).cpu().to(torch.int32)
+         cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
+
+         hidden_states = hidden_states.unsqueeze(1)
+         rotary_pos_emb = rotary_pos_emb.to(hidden_states.device)
+
+         cos, sin = self.cal_cos_sin(rotary_pos_emb)
+
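+         # Collect intermediate "deepstack" features from the layers listed in
+         # deepstack_visual_indexes; after passing through the matching merger
+         # they are concatenated to the final output along the hidden dimension.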
+         deepstack_feature_lists = []
+         for layer_num, blk in enumerate(self.blocks):
+             hidden_states = blk(hidden_states,
+                                 cu_seqlens=cu_seqlens,
+                                 cos=cos,
+                                 sin=sin)
+             if layer_num in self.deepstack_visual_indexes:
+                 deepstack_merger_idx = self.deepstack_visual_indexes.index(
+                     layer_num)
+                 deepstack_feature = self.deepstack_merger_list[
+                     deepstack_merger_idx](hidden_states)
+                 deepstack_feature_lists.append(deepstack_feature)
+         hidden_states = self.merger(hidden_states)
+         hidden_states = torch.cat(
+             [hidden_states] + deepstack_feature_lists,
+             dim=1)  # [seq_len, hidden_size * (1 + depth_of_deepstack)]
+         return hidden_states
+
+
@MULTIMODAL_REGISTRY.register_processor(
    Qwen2_5_VLMultiModalProcessor,
    info=Qwen2_5_VLProcessingInfo,
@@ -371,3 +522,85 @@ def _process_video_input(self, video_input) -> tuple[torch.Tensor, ...]:
        merge_size = self.visual.spatial_merge_size
        sizes = grid_thw.prod(-1) // merge_size // merge_size
        return video_embeds.split(sizes.tolist())
+
+
+ @MULTIMODAL_REGISTRY.register_processor(Qwen3VLMultiModalProcessor,
+                                         info=Qwen3VLProcessingInfo,
+                                         dummy_inputs=Qwen3VLDummyInputsBuilder)
+ class AscendQwen3VLForConditionalGeneration(Qwen3VLForConditionalGeneration):
+     packed_modules_mapping = {
+         "qkv_proj": [
+             "q_proj",
+             "k_proj",
+             "v_proj",
+         ],
+         "gate_up_proj": [
+             "gate_proj",
+             "up_proj",
+         ],
+     }
+
+     supports_encoder_tp_data = True
+
+     # To ensure correct weight loading and mapping.
+     hf_to_vllm_mapper = WeightsMapper(
+         orig_to_new_prefix={
+             "model.visual.": "visual.",
+             "lm_head.": "language_model.lm_head.",
+             "model.language_model.": "language_model.model.",
+         })
+
+     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+         super().__init__(vllm_config=vllm_config, prefix=prefix)
+         config: Qwen3VLConfig = vllm_config.model_config.hf_config
+         quant_config = vllm_config.quant_config
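+         # Use the Ascend vision transformer defined above as the vision tower.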
+         self.visual = AscendQwen3_VisionTransformer(
+             config.vision_config,
+             norm_eps=getattr(config, "rms_norm_eps", 1e-6),
+             quant_config=self._maybe_ignore_quant_config(quant_config),
+             prefix=maybe_prefix(prefix, "visual"),
+             use_data_parallel=self.use_data_parallel)
+
+
+ @MULTIMODAL_REGISTRY.register_processor(Qwen3VLMultiModalProcessor,
+                                         info=Qwen3VLMoeProcessingInfo,
+                                         dummy_inputs=Qwen3VLDummyInputsBuilder)
+ class AscendQwen3VLMoeForConditionalGeneration(
+         Qwen3VLMoeForConditionalGeneration):
+     packed_modules_mapping = {
+         "qkv_proj": [
+             "q_proj",
+             "k_proj",
+             "v_proj",
+         ],
+         "gate_up_proj": [
+             "gate_proj",
+             "up_proj",
+         ],
+     }
+
+     supports_encoder_tp_data = True
+
+     # To ensure correct weight loading and mapping.
+     hf_to_vllm_mapper = WeightsMapper(
+         orig_to_new_prefix={
+             "model.visual.": "visual.",
+             "lm_head.": "language_model.lm_head.",
+             "model.language_model.": "language_model.model.",
+         })
+
+     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+         super().__init__(vllm_config=vllm_config, prefix=prefix)
+         config: Qwen3VLMoeConfig = vllm_config.model_config.hf_config
+         quant_config = vllm_config.quant_config
+         multimodal_config = vllm_config.model_config.multimodal_config
+         self.multimodal_config = multimodal_config
+         self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
+
+         self.visual = AscendQwen3_VisionTransformer(
+             config.vision_config,
+             norm_eps=getattr(config, "rms_norm_eps", 1e-6),
+             quant_config=self._maybe_ignore_quant_config(quant_config),
+             prefix=maybe_prefix(prefix, "visual"),
+             use_data_parallel=self.use_data_parallel,
+         )