Skip to content

Commit 7e92f13

Browse files
committed
fix
1 parent e82fd62 commit 7e92f13

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

fastdeploy/rl/rollout_model.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def _init_model(self) -> nn.Layer:
5050
model.eval()
5151
return model
5252

53-
def get_name_mappings_to_training(self, trainer_degree=1) -> Dict[str, str]:
53+
def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
5454
"""Get parameter name mappings between rollout and training models."""
5555
return getattr(self.rollout_model, "get_name_mappings_to_training", lambda: {})(trainer_degree)
5656

@@ -125,7 +125,7 @@ def name(self) -> str:
125125
"""name"""
126126
return "Ernie4_5_MoeForCausalLMRL"
127127

128-
def get_name_mappings_to_training(self, trainer_degree=1) -> Dict[str, str]:
128+
def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
129129
"""Generate mapping between inference and training parameter for RL(donot delete!)."""
130130
# Prepare placeholders
131131
place_holders = ["weight"]
@@ -285,7 +285,7 @@ def name(self) -> str:
285285
"""name"""
286286
return "Qwen2ForCausalLMRL"
287287

288-
def get_name_mappings_to_training(self, trainer_degree=1) -> Dict[str, str]:
288+
def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
289289
"""Generate mapping between inference and training parameter for RL(donot delete!)."""
290290
# Prepare placeholders
291291
place_holders = ["weight"]
@@ -327,7 +327,7 @@ def name(self) -> str:
327327
"""name"""
328328
return "Qwen3MoeForCausalLMRL"
329329

330-
def get_name_mappings_to_training(self, trainer_degree=1) -> Dict[str, str]:
330+
def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
331331
"""Generate mapping between inference and training parameter for RL(donot delete!)."""
332332
# Prepare placeholders
333333
place_holders = ["weight"]
@@ -394,5 +394,5 @@ def name(self) -> str:
394394
"""name"""
395395
return "Qwen3ForCausalLMRL"
396396

397-
def get_name_mappings_to_training(self, trainer_degree=1) -> Dict[str, str]:
397+
def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
398398
pass

0 commit comments

Comments
 (0)