
Commit 529b221 (parent: e82fd62)

fix

1 file changed (+8, -9 lines)

fastdeploy/rl/rollout_model.py

@@ -50,7 +50,7 @@ def _init_model(self) -> nn.Layer:
         model.eval()
         return model
 
-    def get_name_mappings_to_training(self, trainer_degree=1) -> Dict[str, str]:
+    def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
         """Get parameter name mappings between rollout and training models."""
         return getattr(self.rollout_model, "get_name_mappings_to_training", lambda: {})(trainer_degree)
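Note on this hunk: the base class no longer hard-codes trainer_degree=1; it forwards whatever the caller passes (now None by default) to the wrapped model, which decides the fallback. A minimal runnable sketch of that delegation pattern; ToyRolloutModel and ToyInnerModel are hypothetical stand-ins, not FastDeploy's real classes, and the fallback degree of 2 is invented:

from typing import Dict, Optional


class ToyInnerModel:
    """Hypothetical stand-in for the wrapped rollout model."""

    def get_name_mappings_to_training(self, trainer_degree: Optional[int] = None) -> Dict[str, str]:
        degree = trainer_degree or 2  # callee-side fallback, as a later hunk adds
        return {f"infer.expert_{i}": f"train.expert_{i % degree}" for i in range(4)}


class ToyRolloutModel:
    """Hypothetical stand-in mirroring the getattr-based delegation above."""

    def __init__(self, inner) -> None:
        self.rollout_model = inner

    def get_name_mappings_to_training(self, trainer_degree: Optional[int] = None) -> Dict[str, str]:
        # The fallback callable must accept the argument it is called with;
        # a zero-argument lambda would raise TypeError at the call site.
        fn = getattr(self.rollout_model, "get_name_mappings_to_training", lambda _: {})
        return fn(trainer_degree)


print(ToyRolloutModel(ToyInnerModel()).get_name_mappings_to_training())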

@@ -92,9 +92,6 @@ def _complete_missing_mappings(self) -> None:
                 # Skip weight scale parameters in mapping. Train and infer have same key.
                 self.infer_to_train_mapping[key] = key
 
-        if getattr(self.fd_config.model_config, "tie_word_embeddings", False):
-            self.infer_to_train_mapping.pop("lm_head.linear.weight")
-
     def get_quantization_infer_keys(self) -> list[str]:
         """Get quantization infer keys"""
         quant_weight_key = []
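One hedged reading of the deleted tie_word_embeddings branch: dict.pop without a default raises KeyError when the key is absent, so the unconditional pop would be fragile for any model whose mapping never contained "lm_head.linear.weight". An illustrative snippet; the mapping contents below are assumed, not taken from the repo:

# Illustrative only; the mapping contents are made up.
mapping = {"embed_tokens.weight": "embed_tokens.weight"}

try:
    mapping.pop("lm_head.linear.weight")  # no default: raises if key is missing
except KeyError:
    print("lm_head.linear.weight was not in the mapping")

mapping.pop("lm_head.linear.weight", None)  # safe variant: returns None instead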
@@ -125,7 +122,7 @@ def name(self) -> str:
         """name"""
         return "Ernie4_5_MoeForCausalLMRL"
 
-    def get_name_mappings_to_training(self, trainer_degree=1) -> Dict[str, str]:
+    def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
         """Generate mapping between inference and training parameter for RL(donot delete!)."""
         # Prepare placeholders
         place_holders = ["weight"]
@@ -192,7 +189,7 @@ def name(self) -> str:
         """name"""
         return "Ernie4_5_VLMoeForConditionalGenerationRL"
 
-    def get_name_mappings_to_training(self, trainer_degree=2) -> Dict[str, str]:
+    def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
         """Generate mapping between inference and training parameter for RL(donot delete!)."""
         # Prepare placeholders
         place_holders = ["weight"]
@@ -255,6 +252,8 @@ def _generate_ranges(start, end, step=16, take=8):
 
         assert isinstance(self.fd_config.model_config.moe_num_experts, list)
         total_moe_num = sum(self.fd_config.model_config.moe_num_experts)
+        if not trainer_degree:
+            trainer_degree = self.fd_config.parallel_config.tensor_parallel_size
         expert_num_per_rank = self.fd_config.model_config.moe_num_experts[0] // trainer_degree
         # Process MoE layers
         for layer_idx in range(text_moe_layer_start_index, text_moe_layer_end_index):
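The new fallback makes expert sharding follow the inference tensor-parallel size whenever no trainer_degree is supplied. A small self-contained sketch of the arithmetic; experts_per_rank is a hypothetical helper and the example numbers are invented, not repo values:

from typing import List, Optional


def experts_per_rank(moe_num_experts: List[int], trainer_degree: Optional[int],
                     tensor_parallel_size: int) -> int:
    """Mirror the fallback added above: None (the new default) or 0
    falls back to the tensor-parallel size before sharding experts."""
    if not trainer_degree:
        trainer_degree = tensor_parallel_size
    return moe_num_experts[0] // trainer_degree


assert experts_per_rank([64, 64], None, 8) == 8   # fallback path (new default)
assert experts_per_rank([64, 64], 2, 8) == 32     # explicit trainer_degree wins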
@@ -285,7 +284,7 @@ def name(self) -> str:
         """name"""
         return "Qwen2ForCausalLMRL"
 
-    def get_name_mappings_to_training(self, trainer_degree=1) -> Dict[str, str]:
+    def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
         """Generate mapping between inference and training parameter for RL(donot delete!)."""
         # Prepare placeholders
         place_holders = ["weight"]
@@ -327,7 +326,7 @@ def name(self) -> str:
         """name"""
         return "Qwen3MoeForCausalLMRL"
 
-    def get_name_mappings_to_training(self, trainer_degree=1) -> Dict[str, str]:
+    def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
         """Generate mapping between inference and training parameter for RL(donot delete!)."""
         # Prepare placeholders
         place_holders = ["weight"]
@@ -394,5 +393,5 @@ def name(self) -> str:
         """name"""
         return "Qwen3ForCausalLMRL"
 
-    def get_name_mappings_to_training(self, trainer_degree=1) -> Dict[str, str]:
+    def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
         pass
