@@ -63,9 +63,9 @@ def _init_model(self) -> nn.Layer:
         model.eval()
         return model

-    def get_name_mappings_to_training(self) -> Dict[str, str]:
+    def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
         """Get parameter name mappings between rollout and training models."""
-        return getattr(self.rollout_model, "get_name_mappings_to_training", lambda: {})()
+        return getattr(self.rollout_model, "get_name_mappings_to_training", lambda: {})(trainer_degree)

     def get_quantization_infer_keys(self) -> Dict[str, str]:
         """Get parameter name mappings between rollout and training models."""
@@ -108,9 +108,6 @@ def _complete_missing_mappings(self) -> None:
                 # Skip weight scale parameters in mapping. Train and infer have same key.
                 self.infer_to_train_mapping[key] = key

-        if getattr(self.fd_config.model_config, "tie_word_embeddings", False):
-            self.infer_to_train_mapping.pop("lm_head.linear.weight")
-
     def get_quantization_infer_keys(self) -> list[str]:
         """Get quantization infer keys"""
         quant_weight_key = []
@@ -143,7 +140,7 @@ def name(self) -> str:
         """name"""
         return "Ernie4_5_MoeForCausalLMRL"

-    def get_name_mappings_to_training(self) -> Dict[str, str]:
+    def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
         """Generate mapping between inference and training parameter for RL(donot delete!)."""
         # Prepare placeholders
         place_holders = ["weight"]
@@ -187,8 +184,7 @@ def _add_layer_mappings(layer_idx: int):
         assert isinstance(self.fd_config.model_config.moe_layer_start_index, int)
         # Process MoE layers
         for layer_idx in range(
-            self.fd_config.model_config.moe_layer_start_index,
-            self.fd_config.model_config.num_hidden_layers,
+            self.fd_config.model_config.moe_layer_start_index, self.fd_config.model_config.num_hidden_layers
         ):
             _add_layer_mappings(layer_idx)

@@ -216,7 +212,7 @@ def name(self) -> str:
         """name"""
         return "Ernie4_5_VLMoeForConditionalGenerationRL"

-    def get_name_mappings_to_training(self) -> Dict[str, str]:
+    def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
         """Generate mapping between inference and training parameter for RL(donot delete!)."""
         # Prepare placeholders
         place_holders = ["weight"]
@@ -249,10 +245,7 @@ def _generate_ranges(start, end, step=16, take=8):

             expert_mappings = defaultdict(list)
             for expert_idx in _generate_ranges(
-                expert_start,
-                total_moe_num,
-                expert_num_per_rank * 2,
-                expert_num_per_rank,
+                expert_start, total_moe_num, expert_num_per_rank * 2, expert_num_per_rank
             ):
                 for ph in place_holders:
                     expert_mappings[f"{base_name}.{layer_idx}.mlp.{moe_tag}_fused_moe.up_gate_proj_weight"].append(
@@ -284,9 +277,9 @@ def _generate_ranges(start, end, step=16, take=8):

         assert isinstance(self.fd_config.model_config.moe_num_experts, list)
         total_moe_num = sum(self.fd_config.model_config.moe_num_experts)
-        rollout_model_degree = self.fd_config.parallel_config.tensor_parallel_size
-        expert_num_per_rank = self.fd_config.model_config.moe_num_experts[0] // rollout_model_degree
-
+        if not trainer_degree:
+            trainer_degree = self.fd_config.parallel_config.tensor_parallel_size
+        expert_num_per_rank = self.fd_config.model_config.moe_num_experts[0] // trainer_degree
         # Process MoE layers
         for layer_idx in range(text_moe_layer_start_index, text_moe_layer_end_index):
             _add_expert_mappings(layer_idx, "text", expert_start=0)
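To make the expert-sharding arithmetic above concrete: `expert_num_per_rank` is now derived from the trainer's degree, falling back to the rollout `tensor_parallel_size` when `trainer_degree` is not given. The sketch below is illustrative only; the body of `_generate_ranges` is not part of this diff (only its signature `(start, end, step=16, take=8)` appears), so its reconstruction, the example numbers, and the per-rank text/image interleaving are assumptions.

```python
# Illustrative sketch of the per-rank expert arithmetic; all values are assumed.
def _generate_ranges(start, end, step=16, take=8):
    # Hypothetical reconstruction: from every `step`-sized block in [start, end),
    # keep the first `take` indices.
    return [i for block in range(start, end, step) for i in range(block, min(block + take, end))]


moe_num_experts = [64, 64]  # assumed: 64 text + 64 image experts per layer
trainer_degree = 8          # e.g. the training job's tensor-parallel degree
total_moe_num = sum(moe_num_experts)                        # 128
expert_num_per_rank = moe_num_experts[0] // trainer_degree  # 64 // 8 = 8

# Mirrors the call in the hunk above for the text experts (expert_start=0).
text_expert_ids = _generate_ranges(0, total_moe_num, expert_num_per_rank * 2, expert_num_per_rank)
print(text_expert_ids[:16])  # [0..7, 16..23]: 8 experts per trainer rank, skipping the
                             # block assumed to hold the same rank's image experts
```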
@@ -317,7 +310,7 @@ def name(self) -> str:
         """name"""
         return "Qwen2ForCausalLMRL"

-    def get_name_mappings_to_training(self) -> Dict[str, str]:
+    def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
         """Generate mapping between inference and training parameter for RL(donot delete!)."""
         # Prepare placeholders
         place_holders = ["weight"]
@@ -361,7 +354,7 @@ def name(self) -> str:
         """name"""
         return "Qwen3MoeForCausalLMRL"

-    def get_name_mappings_to_training(self) -> Dict[str, str]:
+    def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
         """Generate mapping between inference and training parameter for RL(donot delete!)."""
         # Prepare placeholders
         place_holders = ["weight"]
@@ -431,5 +424,5 @@ def name(self) -> str:
         """name"""
         return "Qwen3ForCausalLMRL"

-    def get_name_mappings_to_training(self) -> Dict[str, str]:
+    def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
         pass