@@ -50,7 +50,7 @@ def _init_model(self) -> nn.Layer:
50
50
model .eval ()
51
51
return model
52
52
53
- def get_name_mappings_to_training (self , trainer_degree = 1 ) -> Dict [str , str ]:
53
def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
    """Get parameter name mappings between rollout and training models.

    Delegates to ``self.rollout_model.get_name_mappings_to_training`` when
    the rollout model provides it.

    Args:
        trainer_degree: Forwarded unchanged to the rollout model's
            implementation. ``None`` leaves the choice of value to that
            implementation.

    Returns:
        The mapping produced by the rollout model, or an empty dict when the
        rollout model does not define ``get_name_mappings_to_training``.
    """
    mapping_fn = getattr(self.rollout_model, "get_name_mappings_to_training", None)
    if mapping_fn is None:
        # The previous fallback was a zero-argument lambda that was still
        # called with ``trainer_degree``, so the "no mapping" path raised
        # TypeError instead of returning an empty mapping.
        return {}
    return mapping_fn(trainer_degree)
56
56
@@ -125,7 +125,7 @@ def name(self) -> str:
125
125
"""name"""
126
126
return "Ernie4_5_MoeForCausalLMRL"
127
127
128
- def get_name_mappings_to_training (self , trainer_degree = 1 ) -> Dict [str , str ]:
128
+ def get_name_mappings_to_training (self , trainer_degree = None ) -> Dict [str , str ]:
129
129
"""Generate mapping between inference and training parameter for RL(donot delete!)."""
130
130
# Prepare placeholders
131
131
place_holders = ["weight" ]
@@ -192,7 +192,7 @@ def name(self) -> str:
192
192
"""name"""
193
193
return "Ernie4_5_VLMoeForConditionalGenerationRL"
194
194
195
- def get_name_mappings_to_training (self , trainer_degree = 2 ) -> Dict [str , str ]:
195
+ def get_name_mappings_to_training (self , trainer_degree = None ) -> Dict [str , str ]:
196
196
"""Generate mapping between inference and training parameter for RL(donot delete!)."""
197
197
# Prepare placeholders
198
198
place_holders = ["weight" ]
@@ -255,6 +255,8 @@ def _generate_ranges(start, end, step=16, take=8):
255
255
256
256
assert isinstance (self .fd_config .model_config .moe_num_experts , list )
257
257
total_moe_num = sum (self .fd_config .model_config .moe_num_experts )
258
+ if not trainer_degree :
259
+ trainer_degree = self .fd_config .parallel_config .tensor_parallel_size
258
260
expert_num_per_rank = self .fd_config .model_config .moe_num_experts [0 ] // trainer_degree
259
261
# Process MoE layers
260
262
for layer_idx in range (text_moe_layer_start_index , text_moe_layer_end_index ):
@@ -285,7 +287,7 @@ def name(self) -> str:
285
287
"""name"""
286
288
return "Qwen2ForCausalLMRL"
287
289
288
- def get_name_mappings_to_training (self , trainer_degree = 1 ) -> Dict [str , str ]:
290
+ def get_name_mappings_to_training (self , trainer_degree = None ) -> Dict [str , str ]:
289
291
"""Generate mapping between inference and training parameter for RL(donot delete!)."""
290
292
# Prepare placeholders
291
293
place_holders = ["weight" ]
@@ -327,7 +329,7 @@ def name(self) -> str:
327
329
"""name"""
328
330
return "Qwen3MoeForCausalLMRL"
329
331
330
- def get_name_mappings_to_training (self , trainer_degree = 1 ) -> Dict [str , str ]:
332
+ def get_name_mappings_to_training (self , trainer_degree = None ) -> Dict [str , str ]:
331
333
"""Generate mapping between inference and training parameter for RL(donot delete!)."""
332
334
# Prepare placeholders
333
335
place_holders = ["weight" ]
@@ -394,5 +396,5 @@ def name(self) -> str:
394
396
"""name"""
395
397
return "Qwen3ForCausalLMRL"
396
398
397
- def get_name_mappings_to_training (self , trainer_degree = 1 ) -> Dict [str , str ]:
399
def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
    """Stub implementation: builds no mapping and returns ``None``.

    ``trainer_degree`` is accepted so the signature matches the other
    ``get_name_mappings_to_training`` methods, but it is not used.
    NOTE(review): the annotation promises ``Dict[str, str]`` while the body
    yields ``None`` -- confirm whether callers rely on that.
    """
    return None
0 commit comments