Commit 31dda3f
[Model] Add support for qwen3_vl and qwen3_vl_moe (vllm-project#3103)
### What this PR does / why we need it?
This PR adapts and optimizes qwen3_vl and qwen3_vl_moe for the Ascend platform.

### Does this PR introduce _any_ user-facing change?
None.

### How was this patch tested?
- vLLM version: v0.10.2
- vLLM main: vllm-project/vllm@b106890

---------

Signed-off-by: booker123456 <945658361@qq.com>
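As a minimal usage sketch (not part of this commit): once these architectures are registered, a Qwen3-VL checkpoint should load through vLLM's standard entry points on Ascend. The checkpoint ID below is illustrative, not pinned by this PR.

```python
# Hedged sketch: text-only smoke test of a Qwen3-VL checkpoint via vLLM.
# "Qwen/Qwen3-VL-30B-A3B-Instruct" is an assumed checkpoint ID; any
# checkpoint using the Qwen3VL / Qwen3VLMoe architectures applies.
from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen3-VL-30B-A3B-Instruct")
out = llm.generate(["Describe the Ascend platform in one sentence."],
                   SamplingParams(max_tokens=64))
print(out[0].outputs[0].text)
```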
1 parent f7a3815 commit 31dda3f

File tree
2 files changed: +256 -13 lines changed

vllm_ascend/models/__init__.py

Lines changed: 10 additions & 0 deletions
@@ -8,6 +8,16 @@ def register_model():
         "Qwen2VLForConditionalGeneration",
         "vllm_ascend.models.qwen2_vl:AscendQwen2VLForConditionalGeneration")
 
+    ModelRegistry.register_model(
+        "Qwen3VLMoeForConditionalGeneration",
+        "vllm_ascend.models.qwen2_5_vl_without_padding:AscendQwen3VLMoeForConditionalGeneration"
+    )
+
+    ModelRegistry.register_model(
+        "Qwen3VLForConditionalGeneration",
+        "vllm_ascend.models.qwen2_5_vl_without_padding:AscendQwen3VLForConditionalGeneration"
+    )
+
     if envs_ascend.USE_OPTIMIZED_MODEL:
         ModelRegistry.register_model(
             "Qwen2_5_VLForConditionalGeneration",

vllm_ascend/models/qwen2_5_vl_without_padding.py

Lines changed: 246 additions & 13 deletions
@@ -1,6 +1,5 @@
 #
 # Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
-# Adapted from vllm/model_executor/models/qwen2_5_vl.py
 # Copyright 2023 The vLLM team.
 #
 # This file is a part of the vllm-ascend project.
@@ -27,18 +26,45 @@
 from einops import rearrange
 from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
     Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig)
+
+try:
+    from transformers.models.qwen3_vl.configuration_qwen3_vl import \
+        Qwen3VLConfig
+    from transformers.models.qwen3_vl_moe.configuration_qwen3_vl_moe import \
+        Qwen3VLMoeConfig
+except ImportError:
+    pass
 from vllm.config import VllmConfig
 from vllm.distributed import parallel_state
 from vllm.distributed import utils as dist_utils
-from vllm.model_executor.layers.activation import get_act_and_mul_fn
+from vllm.model_executor.layers.activation import (_ACTIVATION_REGISTRY,
+                                                   get_act_and_mul_fn)
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.models.qwen2_5_vl import (
     Qwen2_5_VisionAttention, Qwen2_5_VisionBlock, Qwen2_5_VisionPatchEmbed,
     Qwen2_5_VisionTransformer, Qwen2_5_VLDummyInputsBuilder,
     Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLMultiModalProcessor,
     Qwen2_5_VLProcessingInfo)
-from vllm.model_executor.models.utils import maybe_prefix
+
+try:
+    from vllm.model_executor.models.qwen3_vl import (
+        Qwen3_VisionBlock, Qwen3_VisionPatchEmbed, Qwen3_VisionTransformer,
+        Qwen3VLDummyInputsBuilder, Qwen3VLForConditionalGeneration,
+        Qwen3VLMultiModalProcessor, Qwen3VLProcessingInfo)
+    from vllm.model_executor.models.qwen3_vl_moe import (
+        Qwen3VLMoeForConditionalGeneration, Qwen3VLMoeProcessingInfo)
+except ImportError:
+    Qwen3_VisionBlock = object
+    Qwen3_VisionPatchEmbed = object
+    Qwen3_VisionTransformer = object
+    Qwen3VLDummyInputsBuilder = object
+    Qwen3VLForConditionalGeneration = object
+    Qwen3VLMultiModalProcessor = object
+    Qwen3VLProcessingInfo = object
+    Qwen3VLMoeForConditionalGeneration = object
+    Qwen3VLMoeProcessingInfo = object
+from vllm.model_executor.models.utils import WeightsMapper, maybe_prefix
 from vllm.multimodal import MULTIMODAL_REGISTRY
 
 from vllm_ascend.models.qwen2_5_vl import AscendQwen2_5_VisionRotaryEmbedding
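The two try/except blocks keep this module importable against vLLM and transformers builds that predate Qwen3-VL: the missing base classes fall back to `object`, so the subclass definitions below still parse, and only instantiating them on an old build would fail. The pattern in isolation (hypothetical names):

```python
# Optional-dependency fallback, as used above. "newlib" and FancyBase
# are hypothetical stand-ins for the real imports.
try:
    from newlib import FancyBase
except ImportError:
    FancyBase = object  # class definitions below still parse without newlib

class MyModel(FancyBase):  # usable only when newlib is installed
    pass
```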
@@ -112,16 +138,14 @@ def forward(
 
 class AscendQwen2_5_VisionBlock_Without_Padding(Qwen2_5_VisionBlock):
 
-    def __init__(
-        self,
-        dim: int,
-        num_heads: int,
-        mlp_hidden_dim: int,
-        act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
-        norm_layer: Optional[Callable[[int], nn.Module]] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
-    ) -> None:
+    def __init__(self,
+                 dim: int,
+                 num_heads: int,
+                 mlp_hidden_dim: int,
+                 act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
+                 norm_layer: Optional[Callable[[int], nn.Module]] = None,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = "") -> None:
         super().__init__(dim, num_heads, mlp_hidden_dim, act_fn, norm_layer,
                          quant_config, prefix)
         self.attn = AscendQwen2_5_VisionAttention_Without_Padding(
@@ -321,6 +345,133 @@ def forward(
         return x
 
 
+class AscendQwen3_VisionPatchEmbed(Qwen3_VisionPatchEmbed):
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = x.matmul(
+            self.proj.weight.data.view(self.hidden_size, -1).transpose(0, 1))
+        x = x + self.proj.bias
+        return x
+
+
+class AscendQwen3_VisionBlock(Qwen3_VisionBlock):
+
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int,
+        mlp_hidden_dim: int,
+        act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
+        norm_layer: Optional[Callable[[int], nn.Module]] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        use_data_parallel: bool = False,
+    ) -> None:
+        super().__init__(dim, num_heads, mlp_hidden_dim, act_fn, norm_layer,
+                         quant_config, prefix, use_data_parallel)
+        self.attn = AscendQwen2_5_VisionAttention_Without_Padding(
+            embed_dim=dim,
+            num_heads=num_heads,
+            projection_size=dim,
+            quant_config=quant_config,
+            prefix=f"{prefix}.attn")
+
+    def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor,
+                cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
+        x = x + self.attn(
+            self.norm1(x), cu_seqlens=cu_seqlens, cos=cos, sin=sin)
+
+        x = x + self.mlp(self.norm2(x))
+        return x
+
+
+class AscendQwen3_VisionTransformer(Qwen3_VisionTransformer):
+
+    def __init__(
+        self,
+        vision_config,
+        norm_eps: float = 1e-6,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        use_data_parallel: bool = False,
+    ) -> None:
+        super().__init__(vision_config, norm_eps, quant_config, prefix,
+                         use_data_parallel)
+        norm_layer = partial(nn.LayerNorm, eps=norm_eps)
+        self.patch_embed = AscendQwen3_VisionPatchEmbed(
+            patch_size=self.patch_size,
+            temporal_patch_size=self.temporal_patch_size,
+            in_channels=vision_config.in_channels,
+            hidden_size=self.hidden_size,
+        )
+        self.blocks = nn.ModuleList([
+            AscendQwen3_VisionBlock(
+                dim=self.hidden_size,
+                num_heads=self.num_heads,
+                mlp_hidden_dim=vision_config.intermediate_size,
+                act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act],
+                norm_layer=norm_layer,
+                quant_config=quant_config,
+                prefix=f"{prefix}.blocks.{layer_idx}")
+            for layer_idx in range(vision_config.depth)
+        ])
+        self.hidden_size_per_attention_head = dist_utils.divide(
+            self.hidden_size, self.num_heads)
+
+    def cal_cos_sin(self, rotary_pos_emb):
+        cos = rotary_pos_emb.cos()  # [seqlen, rotary_dim / 2]
+        sin = rotary_pos_emb.sin()
+        cos_new = torch.cat((cos, cos), dim=-1)
+        sin_new = torch.cat((sin, sin), dim=-1)
+        cos_new = cos_new.reshape(1, -1, 1,
+                                  self.hidden_size_per_attention_head)
+        sin_new = sin_new.reshape(1, -1, 1,
+                                  self.hidden_size_per_attention_head)
+        return cos_new, sin_new
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        grid_thw: list[list[int]],
+    ) -> torch.Tensor:
+        hidden_states = x.to(device=self.device, dtype=self.dtype)
+        hidden_states = self.patch_embed(hidden_states)
+
+        pos_embeds = self.fast_pos_embed_interpolate(grid_thw)
+        hidden_states = hidden_states + pos_embeds
+        rotary_pos_emb = self.rot_pos_emb(grid_thw)
+        grid_thw_tensor = torch.tensor(grid_thw,
+                                       device=self.device,
+                                       dtype=torch.int32)
+        cu_seqlens = torch.repeat_interleave(
+            grid_thw_tensor[:, 1] * grid_thw_tensor[:, 2],
+            grid_thw_tensor[:, 0]).cpu().to(torch.int32)
+        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
+
+        hidden_states = hidden_states.unsqueeze(1)
+        rotary_pos_emb = rotary_pos_emb.to(hidden_states.device)
+
+        cos, sin = self.cal_cos_sin(rotary_pos_emb)
+
+        deepstack_feature_lists = []
+        for layer_num, blk in enumerate(self.blocks):
+            hidden_states = blk(hidden_states,
+                                cu_seqlens=cu_seqlens,
+                                cos=cos,
+                                sin=sin)
+            if layer_num in self.deepstack_visual_indexes:
+                deepstack_merger_idx = self.deepstack_visual_indexes.index(
+                    layer_num)
+                deepstack_feature = self.deepstack_merger_list[
+                    deepstack_merger_idx](hidden_states)
+                deepstack_feature_lists.append(deepstack_feature)
+        hidden_states = self.merger(hidden_states)
+        hidden_states = torch.cat(
+            [hidden_states] + deepstack_feature_lists,
+            dim=1)  # [seq_len, hidden_size * (1 + depth_of_deepstack)]
+        return hidden_states
+
+
 @MULTIMODAL_REGISTRY.register_processor(
     Qwen2_5_VLMultiModalProcessor,
     info=Qwen2_5_VLProcessingInfo,
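Two details worth noting in this hunk. `AscendQwen3_VisionPatchEmbed.forward` replaces the upstream Conv3d patch projection with an equivalent flattened matmul plus bias, and `cal_cos_sin` hoists the rotary cos/sin tables out of the per-layer attention so they are computed once per forward pass. A standalone sketch of the shape bookkeeping in `cal_cos_sin` (illustrative sizes):

```python
import torch

# cal_cos_sin in isolation: seqlen=6, head_dim=8, so the rotary table
# arrives as [seqlen, head_dim / 2] and is duplicated to full head_dim.
seqlen, head_dim = 6, 8
rotary_pos_emb = torch.randn(seqlen, head_dim // 2)

cos = rotary_pos_emb.cos()            # [seqlen, head_dim / 2]
sin = rotary_pos_emb.sin()
cos = torch.cat((cos, cos), dim=-1)   # duplicate halves -> [seqlen, head_dim]
sin = torch.cat((sin, sin), dim=-1)
cos = cos.reshape(1, -1, 1, head_dim)  # broadcastable over batch and heads
sin = sin.reshape(1, -1, 1, head_dim)
assert cos.shape == (1, seqlen, 1, head_dim)
```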
@@ -371,3 +522,85 @@ def _process_video_input(self, video_input) -> tuple[torch.Tensor, ...]:
         merge_size = self.visual.spatial_merge_size
         sizes = grid_thw.prod(-1) // merge_size // merge_size
         return video_embeds.split(sizes.tolist())
+
+
+@MULTIMODAL_REGISTRY.register_processor(Qwen3VLMultiModalProcessor,
+                                        info=Qwen3VLProcessingInfo,
+                                        dummy_inputs=Qwen3VLDummyInputsBuilder)
+class AscendQwen3VLForConditionalGeneration(Qwen3VLForConditionalGeneration):
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
+
+    supports_encoder_tp_data = True
+
+    # To ensure correct weight loading and mapping.
+    hf_to_vllm_mapper = WeightsMapper(
+        orig_to_new_prefix={
+            "model.visual.": "visual.",
+            "lm_head.": "language_model.lm_head.",
+            "model.language_model.": "language_model.model.",
+        })
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__(vllm_config=vllm_config, prefix=prefix)
+        config: Qwen3VLConfig = vllm_config.model_config.hf_config
+        quant_config = vllm_config.quant_config
+        self.visual = AscendQwen3_VisionTransformer(
+            config.vision_config,
+            norm_eps=getattr(config, "rms_norm_eps", 1e-6),
+            quant_config=self._maybe_ignore_quant_config(quant_config),
+            prefix=maybe_prefix(prefix, "visual"),
+            use_data_parallel=self.use_data_parallel)
+
+
+@MULTIMODAL_REGISTRY.register_processor(Qwen3VLMultiModalProcessor,
+                                        info=Qwen3VLMoeProcessingInfo,
+                                        dummy_inputs=Qwen3VLDummyInputsBuilder)
+class AscendQwen3VLMoeForConditionalGeneration(
+        Qwen3VLMoeForConditionalGeneration):
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
+
+    supports_encoder_tp_data = True
+
+    # To ensure correct weight loading and mapping.
+    hf_to_vllm_mapper = WeightsMapper(
+        orig_to_new_prefix={
+            "model.visual.": "visual.",
+            "lm_head.": "language_model.lm_head.",
+            "model.language_model.": "language_model.model.",
+        })
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__(vllm_config=vllm_config, prefix=prefix)
+        config: Qwen3VLMoeConfig = vllm_config.model_config.hf_config
+        quant_config = vllm_config.quant_config
+        multimodal_config = vllm_config.model_config.multimodal_config
+        self.multimodal_config = multimodal_config
+        self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
+
+        self.visual = AscendQwen3_VisionTransformer(
+            config.vision_config,
+            norm_eps=getattr(config, "rms_norm_eps", 1e-6),
+            quant_config=self._maybe_ignore_quant_config(quant_config),
+            prefix=maybe_prefix(prefix, "visual"),
+            use_data_parallel=self.use_data_parallel,
+        )
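The `hf_to_vllm_mapper` in both classes only rewrites checkpoint key prefixes so Hugging Face weight names land on the vLLM module tree. A simplified sketch of that remapping (the dict mirrors the one above; the apply logic is condensed):

```python
# Simplified prefix remap in the spirit of WeightsMapper.orig_to_new_prefix.
PREFIXES = {
    "model.visual.": "visual.",
    "lm_head.": "language_model.lm_head.",
    "model.language_model.": "language_model.model.",
}

def remap(name: str) -> str:
    for old, new in PREFIXES.items():
        if name.startswith(old):
            return new + name[len(old):]
    return name

assert remap("model.visual.blocks.0.attn.qkv.weight") == \
    "visual.blocks.0.attn.qkv.weight"
```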
