@@ -50,7 +50,7 @@ def _init_model(self) -> nn.Layer:
50
50
model .eval ()
51
51
return model
52
52
53
- def get_name_mappings_to_training (self , trainer_degree = 1 ) -> Dict [str , str ]:
53
+ def get_name_mappings_to_training (self , trainer_degree = None ) -> Dict [str , str ]:
54
54
"""Get parameter name mappings between rollout and training models."""
55
55
return getattr (self .rollout_model , "get_name_mappings_to_training" , lambda : {})(trainer_degree )
56
56
@@ -125,7 +125,7 @@ def name(self) -> str:
125
125
"""name"""
126
126
return "Ernie4_5_MoeForCausalLMRL"
127
127
128
- def get_name_mappings_to_training (self , trainer_degree = 1 ) -> Dict [str , str ]:
128
+ def get_name_mappings_to_training (self , trainer_degree = None ) -> Dict [str , str ]:
129
129
"""Generate mapping between inference and training parameter for RL(donot delete!)."""
130
130
# Prepare placeholders
131
131
place_holders = ["weight" ]
@@ -285,7 +285,7 @@ def name(self) -> str:
285
285
"""name"""
286
286
return "Qwen2ForCausalLMRL"
287
287
288
- def get_name_mappings_to_training (self , trainer_degree = 1 ) -> Dict [str , str ]:
288
+ def get_name_mappings_to_training (self , trainer_degree = None ) -> Dict [str , str ]:
289
289
"""Generate mapping between inference and training parameter for RL(donot delete!)."""
290
290
# Prepare placeholders
291
291
place_holders = ["weight" ]
@@ -327,7 +327,7 @@ def name(self) -> str:
327
327
"""name"""
328
328
return "Qwen3MoeForCausalLMRL"
329
329
330
- def get_name_mappings_to_training (self , trainer_degree = 1 ) -> Dict [str , str ]:
330
+ def get_name_mappings_to_training (self , trainer_degree = None ) -> Dict [str , str ]:
331
331
"""Generate mapping between inference and training parameter for RL(donot delete!)."""
332
332
# Prepare placeholders
333
333
place_holders = ["weight" ]
@@ -394,5 +394,5 @@ def name(self) -> str:
394
394
"""name"""
395
395
return "Qwen3ForCausalLMRL"
396
396
397
- def get_name_mappings_to_training (self , trainer_degree = 1 ) -> Dict [str , str ]:
397
+ def get_name_mappings_to_training (self , trainer_degree = None ) -> Dict [str , str ]:
398
398
pass
0 commit comments