
Commit 95a214a

support trainer_degree in name_mapping (#2935)
1 parent bce2c6c commit 95a214a

File tree

2 files changed: +13 -19 lines changed

fastdeploy/rl/rollout_model.py

Lines changed: 12 additions & 19 deletions
@@ -63,9 +63,9 @@ def _init_model(self) -> nn.Layer:
         model.eval()
         return model

-    def get_name_mappings_to_training(self) -> Dict[str, str]:
+    def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
         """Get parameter name mappings between rollout and training models."""
-        return getattr(self.rollout_model, "get_name_mappings_to_training", lambda: {})()
+        return getattr(self.rollout_model, "get_name_mappings_to_training", lambda: {})(trainer_degree)

     def get_quantization_infer_keys(self) -> Dict[str, str]:
         """Get parameter name mappings between rollout and training models."""
@@ -108,9 +108,6 @@ def _complete_missing_mappings(self) -> None:
                 # Skip weight scale parameters in mapping. Train and infer have same key.
                 self.infer_to_train_mapping[key] = key

-        if getattr(self.fd_config.model_config, "tie_word_embeddings", False):
-            self.infer_to_train_mapping.pop("lm_head.linear.weight")
-
     def get_quantization_infer_keys(self) -> list[str]:
         """Get quantization infer keys"""
         quant_weight_key = []
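Removing the tie_word_embeddings special case means lm_head.linear.weight is no longer popped from infer_to_train_mapping; the updated baseline below now expects it to map to lm_head.weight instead. A rough sketch of the resulting entry (surrounding keys elided):

    # Before: with tie_word_embeddings=True the key was dropped entirely.
    # After: it stays in the mapping, per the new baseline line
    # "lm_head.linear.weight:lm_head.weight".
    infer_to_train_mapping = {
        # ...
        "lm_head.linear.weight": "lm_head.weight",
    }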
@@ -143,7 +140,7 @@ def name(self) -> str:
         """name"""
         return "Ernie4_5_MoeForCausalLMRL"

-    def get_name_mappings_to_training(self) -> Dict[str, str]:
+    def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
         """Generate mapping between inference and training parameter for RL(donot delete!)."""
         # Prepare placeholders
         place_holders = ["weight"]
@@ -187,8 +184,7 @@ def _add_layer_mappings(layer_idx: int):
         assert isinstance(self.fd_config.model_config.moe_layer_start_index, int)
         # Process MoE layers
         for layer_idx in range(
-            self.fd_config.model_config.moe_layer_start_index,
-            self.fd_config.model_config.num_hidden_layers,
+            self.fd_config.model_config.moe_layer_start_index, self.fd_config.model_config.num_hidden_layers
         ):
             _add_layer_mappings(layer_idx)

@@ -216,7 +212,7 @@ def name(self) -> str:
         """name"""
         return "Ernie4_5_VLMoeForConditionalGenerationRL"

-    def get_name_mappings_to_training(self) -> Dict[str, str]:
+    def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
         """Generate mapping between inference and training parameter for RL(donot delete!)."""
         # Prepare placeholders
         place_holders = ["weight"]
@@ -249,10 +245,7 @@ def _generate_ranges(start, end, step=16, take=8):

             expert_mappings = defaultdict(list)
             for expert_idx in _generate_ranges(
-                expert_start,
-                total_moe_num,
-                expert_num_per_rank * 2,
-                expert_num_per_rank,
+                expert_start, total_moe_num, expert_num_per_rank * 2, expert_num_per_rank
             ):
                 for ph in place_holders:
                     expert_mappings[f"{base_name}.{layer_idx}.mlp.{moe_tag}_fused_moe.up_gate_proj_weight"].append(
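The call site reads as: starting at expert_start, walk the flattened expert index space up to total_moe_num in blocks of expert_num_per_rank * 2 and take expert_num_per_rank indices from each block. _generate_ranges itself is neither changed nor shown in this diff, so the following is only a plausible sketch of that behavior under assumed values, not the repository's implementation:

    def _generate_ranges(start, end, step=16, take=8):
        # Yield `take` consecutive indices from every `step`-wide block in [start, end).
        for block_start in range(start, end, step):
            yield from range(block_start, min(block_start + take, end))

    # With expert_start=0, total_moe_num=128 and expert_num_per_rank=32 (assumed),
    # this yields 0..31 and 64..95 -- the same expert blocks listed for
    # ernie.layers.1.mlp.text_fused_moe.up_gate_proj_weight in the baseline below.
    print(list(_generate_ranges(0, 128, step=64, take=32)))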
@@ -284,9 +277,9 @@ def _generate_ranges(start, end, step=16, take=8):

         assert isinstance(self.fd_config.model_config.moe_num_experts, list)
         total_moe_num = sum(self.fd_config.model_config.moe_num_experts)
-        rollout_model_degree = self.fd_config.parallel_config.tensor_parallel_size
-        expert_num_per_rank = self.fd_config.model_config.moe_num_experts[0] // rollout_model_degree
-
+        if not trainer_degree:
+            trainer_degree = self.fd_config.parallel_config.tensor_parallel_size
+        expert_num_per_rank = self.fd_config.model_config.moe_num_experts[0] // trainer_degree
         # Process MoE layers
         for layer_idx in range(text_moe_layer_start_index, text_moe_layer_end_index):
             _add_expert_mappings(layer_idx, "text", expert_start=0)
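The substantive change in this hunk: expert_num_per_rank is now derived from the trainer's parallel degree when one is passed in, and only falls back to the rollout engine's tensor_parallel_size when it is not. A worked example with assumed numbers:

    # Assumed configuration, for illustration only.
    moe_num_experts = [64, 64]   # text experts, image experts
    rollout_tp_size = 4          # rollout engine tensor_parallel_size
    trainer_degree = 2           # supplied by the training side

    # Old behavior: expert sharding always followed the rollout degree.
    old_expert_num_per_rank = moe_num_experts[0] // rollout_tp_size  # 16

    # New behavior: the trainer's degree wins when provided.
    new_expert_num_per_rank = moe_num_experts[0] // (trainer_degree or rollout_tp_size)  # 32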
@@ -317,7 +310,7 @@ def name(self) -> str:
         """name"""
         return "Qwen2ForCausalLMRL"

-    def get_name_mappings_to_training(self) -> Dict[str, str]:
+    def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
         """Generate mapping between inference and training parameter for RL(donot delete!)."""
         # Prepare placeholders
         place_holders = ["weight"]
@@ -361,7 +354,7 @@ def name(self) -> str:
         """name"""
         return "Qwen3MoeForCausalLMRL"

-    def get_name_mappings_to_training(self) -> Dict[str, str]:
+    def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
         """Generate mapping between inference and training parameter for RL(donot delete!)."""
         # Prepare placeholders
         place_holders = ["weight"]
@@ -431,5 +424,5 @@ def name(self) -> str:
         """name"""
         return "Qwen3ForCausalLMRL"

-    def get_name_mappings_to_training(self) -> Dict[str, str]:
+    def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
         pass

test/ci_use/EB_VL_Lite/baseline.txt

Lines changed: 1 addition & 0 deletions
@@ -1008,6 +1008,7 @@ ernie.layers.27.post_attention_layernorm.weight
 ernie.norm.weight
 lm_head.linear.weight
 ernie.embed_tokens.embeddings.weight:ernie.embed_tokens.weight
+lm_head.linear.weight:lm_head.weight
 ernie.layers.1.mlp.text_fused_moe.gate_weight:ernie.layers.1.mlp.gate.weight
 ernie.layers.1.mlp.text_fused_moe.gate_correction_bias:ernie.layers.1.mlp.moe_statics.e_score_correction_bias
ernie.layers.1.mlp.text_fused_moe.up_gate_proj_weight:['ernie.layers.1.mlp.experts.0.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.1.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.2.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.3.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.4.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.5.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.6.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.7.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.8.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.9.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.10.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.11.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.12.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.13.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.14.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.15.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.16.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.17.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.18.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.19.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.20.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.21.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.22.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.23.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.24.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.25.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.26.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.27.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.28.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.29.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.30.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.31.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.64.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.65.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.66.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.67.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.68.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.69.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.70.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.71.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.72.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.73.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.74.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.75.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.76.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.77.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.78.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.79.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.80.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.81.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.82.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.83.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.84.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.85.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.86.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.87.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.88.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.89.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.90.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.91.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.92.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.93.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.94.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.95.up_gate_proj.weight']
