Skip to content

Commit 30a5b0b

Browse files
committed
fix
1 parent e82fd62 commit 30a5b0b

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

fastdeploy/rl/rollout_model.py

Lines changed: 8 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -50,7 +50,7 @@ def _init_model(self) -> nn.Layer:
5050
model.eval()
5151
return model
5252

53-
def get_name_mappings_to_training(self, trainer_degree=1) -> Dict[str, str]:
53+
def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
5454
"""Get parameter name mappings between rollout and training models."""
5555
return getattr(self.rollout_model, "get_name_mappings_to_training", lambda: {})(trainer_degree)
5656

@@ -125,7 +125,7 @@ def name(self) -> str:
125125
"""name"""
126126
return "Ernie4_5_MoeForCausalLMRL"
127127

128-
def get_name_mappings_to_training(self, trainer_degree=1) -> Dict[str, str]:
128+
def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
129129
"""Generate mapping between inference and training parameter for RL(donot delete!)."""
130130
# Prepare placeholders
131131
place_holders = ["weight"]
@@ -192,7 +192,7 @@ def name(self) -> str:
192192
"""name"""
193193
return "Ernie4_5_VLMoeForConditionalGenerationRL"
194194

195-
def get_name_mappings_to_training(self, trainer_degree=2) -> Dict[str, str]:
195+
def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
196196
"""Generate mapping between inference and training parameter for RL(donot delete!)."""
197197
# Prepare placeholders
198198
place_holders = ["weight"]
@@ -255,6 +255,8 @@ def _generate_ranges(start, end, step=16, take=8):
255255

256256
assert isinstance(self.fd_config.model_config.moe_num_experts, list)
257257
total_moe_num = sum(self.fd_config.model_config.moe_num_experts)
258+
if not trainer_degree:
259+
trainer_degree = self.fd_config.parallel_config.tensor_parallel_size
258260
expert_num_per_rank = self.fd_config.model_config.moe_num_experts[0] // trainer_degree
259261
# Process MoE layers
260262
for layer_idx in range(text_moe_layer_start_index, text_moe_layer_end_index):
@@ -285,7 +287,7 @@ def name(self) -> str:
285287
"""name"""
286288
return "Qwen2ForCausalLMRL"
287289

288-
def get_name_mappings_to_training(self, trainer_degree=1) -> Dict[str, str]:
290+
def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
289291
"""Generate mapping between inference and training parameter for RL(donot delete!)."""
290292
# Prepare placeholders
291293
place_holders = ["weight"]
@@ -327,7 +329,7 @@ def name(self) -> str:
327329
"""name"""
328330
return "Qwen3MoeForCausalLMRL"
329331

330-
def get_name_mappings_to_training(self, trainer_degree=1) -> Dict[str, str]:
332+
def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
331333
"""Generate mapping between inference and training parameter for RL(donot delete!)."""
332334
# Prepare placeholders
333335
place_holders = ["weight"]
@@ -394,5 +396,5 @@ def name(self) -> str:
394396
"""name"""
395397
return "Qwen3ForCausalLMRL"
396398

397-
def get_name_mappings_to_training(self, trainer_degree=1) -> Dict[str, str]:
399+
def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
398400
pass

0 commit comments

Comments
 (0)