
Commit 4ead158

【Sync develop】support vl model name_mapping and ori_vocab_size (#2915)
* support vl ori_vocab_size
* support trainer_degree in name_mapping
* fix
1 parent f941124 commit 4ead158

File tree

6 files changed: +167 -134 lines


fastdeploy/config.py

Lines changed: 21 additions & 1 deletion
@@ -37,6 +37,25 @@ class MoEPhase(Enum):
     PREFILL = 1
     DECODER = 2
 
+class ErnieArchitectures:
+    """Helper class for ERNIE architecture check."""
+
+    ARCHITECTURES = {
+        "Ernie4_5_ForCausalLM",
+        "Ernie4_5_MoeForCausalLM",
+        "Ernie4_5_VLMoeForConditionalGeneration"
+    }
+
+    @classmethod
+    def contains_ernie_arch(cls, architectures):
+        """Check if any ERNIE architecture is present in the given architectures."""
+        return any(arch in architectures for arch in cls.ARCHITECTURES)
+
+    @classmethod
+    def is_ernie_arch(cls, architecture):
+        """Check if the given architecture is an ERNIE architecture."""
+        return architecture in cls.ARCHITECTURES
+
 
 PRETRAINED_INIT_CONFIGURATION = {
     "rope_theta" : 10000.0,
     "num_key_value_heads" : -1,
@@ -108,9 +127,10 @@ def __init__(
         self.vision_config = PretrainedConfig.from_dict(self.vision_config)
 
         self.ori_vocab_size = self.vocab_size
-        if "Ernie4_5_ForCausalLM" in self.architectures or "Ernie4_5_MoeForCausalLM" in self.architectures:
+        if ErnieArchitectures.contains_ernie_arch(self.architectures):
            self.ori_vocab_size = args["ori_vocab_size"]
 
+
 class ParallelConfig:
     """Configuration for the distributed execution."""
     def __init__(
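
Exercised on its own, the new helper behaves as below. This is a minimal sketch: the architecture names come from the diff above, while the surrounding script (the architectures list and the prints) is illustrative only.

# Minimal usage sketch; only ErnieArchitectures itself is from this commit.
from fastdeploy.config import ErnieArchitectures

# Hypothetical architectures list, e.g. as read from a model's config.json.
architectures = ["Ernie4_5_VLMoeForConditionalGeneration"]

if ErnieArchitectures.contains_ernie_arch(architectures):
    # Mirrors the config.py change: ERNIE models take ori_vocab_size from args.
    print("ERNIE architecture detected: use args['ori_vocab_size']")

print(ErnieArchitectures.is_ernie_arch("Ernie4_5_MoeForCausalLM"))  # True
print(ErnieArchitectures.is_ernie_arch("LlamaForCausalLM"))         # False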

fastdeploy/input/preprocess.py

Lines changed: 2 additions & 2 deletions
@@ -17,6 +17,7 @@
 
 from fastdeploy.engine.config import ModelConfig
 from fastdeploy.reasoning import ReasoningParserManager
+from fastdeploy.config import ErnieArchitectures
 
 
 class InputPreprocessor:
@@ -71,8 +72,7 @@ def create_processor(self):
            self.reasoning_parser)
        architectures = ModelConfig(self.model_name_or_path).architectures
        if not self.enable_mm:
-            if "Ernie4_5_MoeForCausalLM" not in architectures \
-                and "Ernie4_5_ForCausalLM" not in architectures:
+            if not ErnieArchitectures.contains_ernie_arch(architectures):
                from fastdeploy.input.text_processor import DataProcessor
                self.processor = DataProcessor(
                    model_name_or_path=self.model_name_or_path, reasoning_parser_obj=reasoning_parser_obj)

fastdeploy/model_executor/guided_decoding/base_guided_decoding.py

Lines changed: 2 additions & 3 deletions
@@ -17,7 +17,7 @@
 import os
 from concurrent.futures import ThreadPoolExecutor
 
-from fastdeploy.config import FDConfig
+from fastdeploy.config import FDConfig, ErnieArchitectures
 from fastdeploy.engine.request import Request
 from fastdeploy.utils import llm_logger
 
@@ -268,8 +268,7 @@ def _get_tokenizer_hf(self):
        """
        try:
            architectures = self.fd_config.model_config.architectures
-            if "Ernie4_5_MoeForCausalLM" not in architectures \
-                and "Ernie4_5_ForCausalLM" not in architectures:
+            if not ErnieArchitectures.contains_ernie_arch(architectures):
 
                from transformers import AutoTokenizer, PreTrainedTokenizerFast
                tokenizer = AutoTokenizer.from_pretrained(
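
For reference, the fallback branch can be sketched standalone roughly as follows. Only the ErnieArchitectures call and the Hugging Face AutoTokenizer import come from this change; the model directory and architecture list are placeholders, not values from this commit.

# Hedged sketch of the non-ERNIE fallback used in _get_tokenizer_hf; the path
# and architecture names below are hypothetical.
from fastdeploy.config import ErnieArchitectures
from transformers import AutoTokenizer

model_dir = "./my-hf-checkpoint"        # hypothetical local HF-format model
architectures = ["LlamaForCausalLM"]    # hypothetical non-ERNIE architecture

if not ErnieArchitectures.contains_ernie_arch(architectures):
    # Non-ERNIE models load a Hugging Face tokenizer instead of the ERNIE one.
    tokenizer = AutoTokenizer.from_pretrained(model_dir)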

fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py

Lines changed: 5 additions & 5 deletions
@@ -161,7 +161,7 @@ def __init__(self, fd_config: FDConfig, layer_id: int,
 
        self.num_shared_experts = fd_config.model_config.moe_num_shared_experts
        if self.num_shared_experts > 0:
-            self.share_experts = Ernie4_5_VLMLP(
+            self.shared_experts = Ernie4_5_VLMLP(
                fd_config=fd_config,
                intermediate_size=self.num_shared_experts *
                fd_config.model_config.moe_intermediate_size[0],
@@ -193,11 +193,11 @@ def load_state_dict(self, state_dict):
        if self.text_fused_moe.moe_use_gate_correction_bias:
            state_dict.pop(self.text_fused_moe.gate_correction_bias_key)
        if self.num_shared_experts > 0:
-            self.share_experts.load_state_dict(state_dict)
+            self.shared_experts.load_state_dict(state_dict)
 
    def forward(self, hidden_states: paddle.Tensor, vl_moe_meta: VLMoEMeta):
        if self.num_shared_experts > 0:
-            share_experts_out = self.share_experts(hidden_states)
+            shared_experts_out = self.shared_experts(hidden_states)
        if vl_moe_meta.image_input is not None:
            text_image_gather_scatter(
                hidden_states,
@@ -222,7 +222,7 @@ def forward(self, hidden_states: paddle.Tensor, vl_moe_meta: VLMoEMeta):
        else:
            hidden_states = self.text_fused_moe(hidden_states)
        if self.num_shared_experts > 0:
-            hidden_states += share_experts_out
+            hidden_states += shared_experts_out
        if self.tp_size > 1:
            tensor_model_parallel_all_reduce(hidden_states)
        return hidden_states
@@ -759,4 +759,4 @@ def get_vison_parallel_split_mappings(num_layers: int):
            config.vision_config.get("depth")
        )
 
-        return {**mappings, **vision_mappings}
+        return {**mappings, **vision_mappings}
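
The renamed attribute implements the usual shared-expert pattern: the shared expert runs on every token, and its output is added to the routed MoE output. Below is a small self-contained sketch of that pattern in plain Paddle; all names in it are illustrative, not the FastDeploy layer itself.

# Toy shared-expert MoE; an assumption-level illustration of the pattern only.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F


class TinySharedExpertMoE(nn.Layer):
    def __init__(self, hidden_size=64, num_experts=4):
        super().__init__()
        self.gate = nn.Linear(hidden_size, num_experts)
        self.experts = nn.LayerList(
            [nn.Linear(hidden_size, hidden_size) for _ in range(num_experts)])
        # Always-on shared expert, analogous to `self.shared_experts` above.
        self.shared_experts = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states):
        shared_experts_out = self.shared_experts(hidden_states)
        # Top-1 routing for brevity; the real layer uses fused MoE kernels.
        scores = F.softmax(self.gate(hidden_states), axis=-1)
        top1 = paddle.argmax(scores, axis=-1)
        routed = paddle.zeros_like(hidden_states)
        for i, expert in enumerate(self.experts):
            mask = (top1 == i).astype(hidden_states.dtype).unsqueeze(-1)
            routed += mask * expert(hidden_states)
        # Same residual pattern as the diff: routed output + shared-expert output.
        return routed + shared_experts_out


x = paddle.randn([2, 8, 64])
print(TinySharedExpertMoE()(x).shape)  # [2, 8, 64]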
