From d2e8f401c5e0b2914c6ea358ec3353b621a8b190 Mon Sep 17 00:00:00 2001
From: gaoziyuan
Date: Fri, 18 Jul 2025 18:53:31 +0800
Subject: [PATCH 1/3] support trainer_degree in name_mapping

---
 fastdeploy/rl/rollout_model.py | 20 +++++++++-----------
 1 file changed, 9 insertions(+), 11 deletions(-)

diff --git a/fastdeploy/rl/rollout_model.py b/fastdeploy/rl/rollout_model.py
index 199ec3a61b..d6ff13389c 100644
--- a/fastdeploy/rl/rollout_model.py
+++ b/fastdeploy/rl/rollout_model.py
@@ -50,9 +50,9 @@ def _init_model(self) -> nn.Layer:
         model.eval()
         return model
 
-    def get_name_mappings_to_training(self) -> Dict[str, str]:
+    def get_name_mappings_to_training(self, trainer_degree=1) -> Dict[str, str]:
         """Get parameter name mappings between rollout and training models."""
-        return getattr(self.rollout_model, "get_name_mappings_to_training", lambda: {})()
+        return getattr(self.rollout_model, "get_name_mappings_to_training", lambda: {})(trainer_degree)
 
     def get_quantization_infer_keys(self) -> Dict[str, str]:
         """Get parameter name mappings between rollout and training models."""
@@ -125,7 +125,7 @@ def name(self) -> str:
         """name"""
         return "Ernie4_5_MoeForCausalLMRL"
 
-    def get_name_mappings_to_training(self) -> Dict[str, str]:
+    def get_name_mappings_to_training(self, trainer_degree=1) -> Dict[str, str]:
         """Generate mapping between inference and training parameter for RL(donot delete!)."""
         # Prepare placeholders
         place_holders = ["weight"]
@@ -192,7 +192,7 @@ def name(self) -> str:
         """name"""
         return "Ernie4_5_VLMoeForConditionalGenerationRL"
 
-    def get_name_mappings_to_training(self) -> Dict[str, str]:
+    def get_name_mappings_to_training(self, trainer_degree=1) -> Dict[str, str]:
         """Generate mapping between inference and training parameter for RL(donot delete!)."""
         # Prepare placeholders
         place_holders = ["weight"]
@@ -255,9 +255,7 @@ def _generate_ranges(start, end, step=16, take=8):
         assert isinstance(self.fd_config.model_config.moe_num_experts, list)
         total_moe_num = sum(self.fd_config.model_config.moe_num_experts)
-        rollout_model_degree = self.fd_config.parallel_config.tensor_parallel_size
-        expert_num_per_rank = self.fd_config.model_config.moe_num_experts[0] // rollout_model_degree
-
+        expert_num_per_rank = self.fd_config.model_config.moe_num_experts[0] // trainer_degree
         # Process MoE layers
         for layer_idx in range(text_moe_layer_start_index, text_moe_layer_end_index):
             _add_expert_mappings(layer_idx, "text", expert_start=0)
 
@@ -287,7 +285,7 @@ def name(self) -> str:
         """name"""
         return "Qwen2ForCausalLMRL"
 
-    def get_name_mappings_to_training(self) -> Dict[str, str]:
+    def get_name_mappings_to_training(self, trainer_degree=1) -> Dict[str, str]:
         """Generate mapping between inference and training parameter for RL(donot delete!)."""
         # Prepare placeholders
         place_holders = ["weight"]
@@ -329,7 +327,7 @@ def name(self) -> str:
         """name"""
         return "Qwen3MoeForCausalLMRL"
 
-    def get_name_mappings_to_training(self) -> Dict[str, str]:
+    def get_name_mappings_to_training(self, trainer_degree=1) -> Dict[str, str]:
         """Generate mapping between inference and training parameter for RL(donot delete!)."""
         # Prepare placeholders
         place_holders = ["weight"]
@@ -396,5 +394,5 @@ def name(self) -> str:
         """name"""
         return "Qwen3ForCausalLMRL"
 
-    def get_name_mappings_to_training(self) -> Dict[str, str]:
-        pass
+    def get_name_mappings_to_training(self, trainer_degree=1) -> Dict[str, str]:
+        pass
\ No newline at end of file
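
A minimal sketch of the dispatch pattern this patch touches (class and helper names here are illustrative stand-ins, not FastDeploy's actual wiring). Note that a zero-argument fallback like `lambda: {}` cannot absorb the new `trainer_degree` argument, so the sketch gives the fallback the same signature as the real hook:

from typing import Dict

class RolloutModelSketch:
    """Hypothetical stand-in for the RolloutModel wrapper in rollout_model.py."""

    def __init__(self, rollout_model):
        self.rollout_model = rollout_model

    def get_name_mappings_to_training(self, trainer_degree=1) -> Dict[str, str]:
        # Dispatch to the wrapped model if it implements the hook; the
        # fallback must accept the same argument as the real method,
        # otherwise calling it with trainer_degree raises TypeError.
        fallback = lambda trainer_degree: {}
        hook = getattr(self.rollout_model, "get_name_mappings_to_training", fallback)
        return hook(trainer_degree)

class _DummyModel:
    def get_name_mappings_to_training(self, trainer_degree):
        return {"infer.lm_head.weight": "train.lm_head.weight"}

print(RolloutModelSketch(_DummyModel()).get_name_mappings_to_training(trainer_degree=2))
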
From e82fd62e9b733e1091449607035968b830086d6c Mon Sep 17 00:00:00 2001
From: gaoziyuan
Date: Mon, 21 Jul 2025 11:33:21 +0800
Subject: [PATCH 2/3] fix

---
 fastdeploy/rl/rollout_model.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/fastdeploy/rl/rollout_model.py b/fastdeploy/rl/rollout_model.py
index d6ff13389c..930b6c39c6 100644
--- a/fastdeploy/rl/rollout_model.py
+++ b/fastdeploy/rl/rollout_model.py
@@ -192,7 +192,7 @@ def name(self) -> str:
         """name"""
         return "Ernie4_5_VLMoeForConditionalGenerationRL"
 
-    def get_name_mappings_to_training(self, trainer_degree=1) -> Dict[str, str]:
+    def get_name_mappings_to_training(self, trainer_degree=2) -> Dict[str, str]:
         """Generate mapping between inference and training parameter for RL(donot delete!)."""
         # Prepare placeholders
         place_holders = ["weight"]
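
The per-rank expert split that `trainer_degree` divides can be checked in isolation. A small worked sketch under assumed values (the expert counts and degree below are made up; the real values come from `fd_config.model_config.moe_num_experts`):

# Assumed: a text tower and an image tower with 64 experts each;
# the real list comes from fd_config.model_config.moe_num_experts.
moe_num_experts = [64, 64]
trainer_degree = 2

total_moe_num = sum(moe_num_experts)                         # 128 experts overall
expert_num_per_rank = moe_num_experts[0] // trainer_degree   # 32 text experts per training rank

# Each training rank then owns a contiguous slice of expert indices.
for rank in range(trainer_degree):
    start = rank * expert_num_per_rank
    print(f"rank {rank}: experts {start}..{start + expert_num_per_rank - 1}")
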
From 529b22147fa49c8a0610cdf77bcd1c188a2a2dcc Mon Sep 17 00:00:00 2001
From: gaoziyuan
Date: Mon, 21 Jul 2025 11:39:09 +0800
Subject: [PATCH 3/3] fix

---
 fastdeploy/rl/rollout_model.py | 17 ++++++++---------
 1 file changed, 8 insertions(+), 9 deletions(-)

diff --git a/fastdeploy/rl/rollout_model.py b/fastdeploy/rl/rollout_model.py
index 930b6c39c6..3f21862ea8 100644
--- a/fastdeploy/rl/rollout_model.py
+++ b/fastdeploy/rl/rollout_model.py
@@ -50,7 +50,7 @@ def _init_model(self) -> nn.Layer:
         model.eval()
         return model
 
-    def get_name_mappings_to_training(self, trainer_degree=1) -> Dict[str, str]:
+    def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
         """Get parameter name mappings between rollout and training models."""
         return getattr(self.rollout_model, "get_name_mappings_to_training", lambda: {})(trainer_degree)
 
@@ -92,9 +92,6 @@ def _complete_missing_mappings(self) -> None:
             # Skip weight scale parameters in mapping. Train and infer have same key.
             self.infer_to_train_mapping[key] = key
 
-        if getattr(self.fd_config.model_config, "tie_word_embeddings", False):
-            self.infer_to_train_mapping.pop("lm_head.linear.weight")
-
     def get_quantization_infer_keys(self) -> list[str]:
         """Get quantization infer keys"""
         quant_weight_key = []
@@ -125,7 +122,7 @@ def name(self) -> str:
         """name"""
         return "Ernie4_5_MoeForCausalLMRL"
 
-    def get_name_mappings_to_training(self, trainer_degree=1) -> Dict[str, str]:
+    def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
         """Generate mapping between inference and training parameter for RL(donot delete!)."""
         # Prepare placeholders
         place_holders = ["weight"]
@@ -192,7 +189,7 @@ def name(self) -> str:
         """name"""
         return "Ernie4_5_VLMoeForConditionalGenerationRL"
 
-    def get_name_mappings_to_training(self, trainer_degree=2) -> Dict[str, str]:
+    def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
         """Generate mapping between inference and training parameter for RL(donot delete!)."""
         # Prepare placeholders
         place_holders = ["weight"]
@@ -255,6 +252,8 @@ def _generate_ranges(start, end, step=16, take=8):
         assert isinstance(self.fd_config.model_config.moe_num_experts, list)
         total_moe_num = sum(self.fd_config.model_config.moe_num_experts)
+        if not trainer_degree:
+            trainer_degree = self.fd_config.parallel_config.tensor_parallel_size
         expert_num_per_rank = self.fd_config.model_config.moe_num_experts[0] // trainer_degree
         # Process MoE layers
         for layer_idx in range(text_moe_layer_start_index, text_moe_layer_end_index):
             _add_expert_mappings(layer_idx, "text", expert_start=0)
@@ -285,7 +284,7 @@ def name(self) -> str:
         """name"""
         return "Qwen2ForCausalLMRL"
 
-    def get_name_mappings_to_training(self, trainer_degree=1) -> Dict[str, str]:
+    def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
         """Generate mapping between inference and training parameter for RL(donot delete!)."""
         # Prepare placeholders
         place_holders = ["weight"]
@@ -327,7 +326,7 @@ def name(self) -> str:
         """name"""
         return "Qwen3MoeForCausalLMRL"
 
-    def get_name_mappings_to_training(self, trainer_degree=1) -> Dict[str, str]:
+    def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
         """Generate mapping between inference and training parameter for RL(donot delete!)."""
         # Prepare placeholders
         place_holders = ["weight"]
@@ -394,5 +393,5 @@ def name(self) -> str:
         """name"""
         return "Qwen3ForCausalLMRL"
 
-    def get_name_mappings_to_training(self, trainer_degree=1) -> Dict[str, str]:
+    def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
         pass
\ No newline at end of file
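
After this patch, a `None` (or zero) trainer degree falls back to the rollout model's tensor-parallel size. A minimal sketch of that resolution logic (the function name and bare arguments are hypothetical; the real lookup is `fd_config.parallel_config.tensor_parallel_size`):

def resolve_trainer_degree(trainer_degree, tensor_parallel_size):
    # None or 0 falls back to the rollout model's own parallel degree,
    # mirroring `if not trainer_degree: trainer_degree = ...` in the patch.
    return trainer_degree if trainer_degree else tensor_parallel_size

assert resolve_trainer_degree(None, 8) == 8   # default: follow rollout TP size
assert resolve_trainer_degree(2, 8) == 2      # explicit trainer degree wins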