From 90db0e7329f8cc4c060ed2cd46825831a61d08e9 Mon Sep 17 00:00:00 2001
From: xuexixi
Date: Thu, 10 Jul 2025 20:49:16 +0800
Subject: [PATCH 1/7] init shared parameters

---
 paddlenlp/transformers/gpt/modeling_auto.py  |  7 +++--
 .../transformers/gpt/modeling_auto_pp.py     | 30 ++++++++++++++++++-
 2 files changed, 33 insertions(+), 4 deletions(-)

diff --git a/paddlenlp/transformers/gpt/modeling_auto.py b/paddlenlp/transformers/gpt/modeling_auto.py
index e21067ba42c3..97b42570e1a9 100644
--- a/paddlenlp/transformers/gpt/modeling_auto.py
+++ b/paddlenlp/transformers/gpt/modeling_auto.py
@@ -529,7 +529,7 @@ def __init__(self, config: GPTConfig, ipp=None):
         self.linear2 = nn.Linear(config.intermediate_size, config.hidden_size, bias_attr=True)
 
         self.linear1.weight = dist.shard_tensor(self.linear1.weight, get_mesh(ipp), [dist.Replicate(), dist.Shard(1)])
-        self.linear1.bias = dist.shard_tensor(self.linear1.bias, get_mesh(ipp), [dist.Replicate(), dist.Shard(0)])
+        self.linear1.bias = dist.shard_tensor(self.linear1.bias, get_mesh(ipp), [dist.Replicate(), dist.Replicate()])
         self.linear2.weight = dist.shard_tensor(self.linear2.weight, get_mesh(ipp), [dist.Replicate(), dist.Shard(0)])
         self.linear2.bias = dist.shard_tensor(self.linear2.bias, get_mesh(ipp), [dist.Replicate(), dist.Replicate()])
         # fix : change nn.LayerNorm(config.hidden_size, epsilon=1e-5, bias_attr=True) to GPTLayerNorm()
@@ -658,7 +658,7 @@ def __init__(
             config.hidden_size,
         )
         self.word_embeddings.weight = dist.shard_tensor(
-            self.word_embeddings.weight, get_mesh(), [dist.Replicate(), dist.Replicate()]
+            self.word_embeddings.weight, get_mesh(), [dist.Replicate(), dist.Shard(1)]
         )
         self.position_embeddings.weight = dist.shard_tensor(
             self.position_embeddings.weight, get_mesh(), [dist.Replicate(), dist.Shard(1)]
         )
@@ -699,6 +699,7 @@ def forward(self, input_ids, position_ids=None, inputs_embeddings=None):
         # The 'with' block ensures the correct seed context is used
         with seed_guard_context(current_seed):
             embeddings = self.dropout(embeddings)
+        embeddings = dist.reshard(embeddings, get_mesh(), [dist.Replicate(), dist.Replicate()])
         return embeddings
 
 
@@ -1176,7 +1177,7 @@ def __init__(self, config: GPTConfig, embedding_weights=None, ipp=None):
                 shape=[config.vocab_size, config.hidden_size],
                 dtype=paddle.get_default_dtype(),
             )
-            self.weight = dist.shard_tensor(self.weight, get_mesh(self.ipp), [dist.Replicate(), dist.Shard(0)])
+            self.weight = dist.shard_tensor(self.weight, get_mesh(self.ipp), [dist.Replicate(), dist.Shard(1)])
 
     def forward(self, hidden_states, tensor_parallel_output=None):
 
diff --git a/paddlenlp/transformers/gpt/modeling_auto_pp.py b/paddlenlp/transformers/gpt/modeling_auto_pp.py
index 010638cdf863..23e73627b49c 100644
--- a/paddlenlp/transformers/gpt/modeling_auto_pp.py
+++ b/paddlenlp/transformers/gpt/modeling_auto_pp.py
@@ -139,6 +139,13 @@ def manual_model_split(model, stage_idx, group, mode, pp_degree):
 
     layer_lists = model.layers
+
+    shared_params_names = {
+        "gpt_shared_weight": ["embedding_0.w_0.dist", "gptlm_head_auto_0.w_0.dist"]
+    }
+
+    shared_mp = build_shared_param_map(model, shared_params_names)
+
     def _build_stage(model, stage_idx, group):
         new_model = None
         if stage_idx == 0:
@@ -151,7 +158,7 @@ def _build_stage(model, stage_idx, group):
             new_model = GPTChunk(
                 layer_lists[stage_idx * chunk_size : (stage_idx + 1) * chunk_size], is_first=False, is_last=False
             )
-        stage = PipelineStage(new_model, stage_idx, chunk_num, group=group)
+        stage = PipelineStage(new_model, stage_idx, chunk_num, group=group, shared_map=shared_mp)
         return stage
 
     stages = []
@@ -160,6 +167,27 @@ def _build_stage(model, stage_idx, group):
         stages.append(stage)
     return stages
 
+def build_shared_param_map(model, shared_params_names):
+    shared_mp = []
+    for key, pair in shared_params_names.items():
+        assert len(pair) == 2, (
+            "Only exactly two parameters are supported for sharing."
+        )
+        ori_name = pair[0]
+        sync_name = pair[1]
+        ori_param = get_param_from_name(ori_name, model)
+        sync_param = get_param_from_name(sync_name, model)
+        shared_mp.append({
+            "params": [ori_param, sync_param]
+        })
+    return shared_mp
+
+def get_param_from_name(param_name, model):
+    for param in model.parameters():
+        if param.name == param_name:
+            return param
+    raise ValueError(f"{param_name} not found in model parameters")
+
 def get_gpt_pp_schedule(model, n_microbatches, loss_fn, mode, pp_degree, group):
     assert mode in ["VPP", "1F1B", "FThenB"]
 

From 18b121b741578c74683ee62dcb9408f1118346a9 Mon Sep 17 00:00:00 2001
From: xuexixi
Date: Thu, 10 Jul 2025 20:55:19 +0800
Subject: [PATCH 2/7] update shared_parameters

---
 paddlenlp/transformers/gpt/modeling_auto_pp.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/paddlenlp/transformers/gpt/modeling_auto_pp.py b/paddlenlp/transformers/gpt/modeling_auto_pp.py
index 23e73627b49c..7368aaeb647b 100644
--- a/paddlenlp/transformers/gpt/modeling_auto_pp.py
+++ b/paddlenlp/transformers/gpt/modeling_auto_pp.py
@@ -158,7 +158,7 @@ def _build_stage(model, stage_idx, group):
             new_model = GPTChunk(
                 layer_lists[stage_idx * chunk_size : (stage_idx + 1) * chunk_size], is_first=False, is_last=False
             )
-        stage = PipelineStage(new_model, stage_idx, chunk_num, group=group, shared_map=shared_mp)
+        stage = PipelineStage(new_model, stage_idx, chunk_num, group=group, shared_parameters=shared_mp)
         return stage
 
     stages = []

From 080e701b5363799658a2262e674e820b04a49afe Mon Sep 17 00:00:00 2001
From: xuexixi
Date: Tue, 15 Jul 2025 11:30:17 +0800
Subject: [PATCH 3/7] fix status

---
 paddlenlp/transformers/gpt/modeling_auto.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/paddlenlp/transformers/gpt/modeling_auto.py b/paddlenlp/transformers/gpt/modeling_auto.py
index 97b42570e1a9..ab3b983a6152 100644
--- a/paddlenlp/transformers/gpt/modeling_auto.py
+++ b/paddlenlp/transformers/gpt/modeling_auto.py
@@ -529,7 +529,7 @@ def __init__(self, config: GPTConfig, ipp=None):
         self.linear2 = nn.Linear(config.intermediate_size, config.hidden_size, bias_attr=True)
 
         self.linear1.weight = dist.shard_tensor(self.linear1.weight, get_mesh(ipp), [dist.Replicate(), dist.Shard(1)])
-        self.linear1.bias = dist.shard_tensor(self.linear1.bias, get_mesh(ipp), [dist.Replicate(), dist.Replicate()])
+        self.linear1.bias = dist.shard_tensor(self.linear1.bias, get_mesh(ipp), [dist.Replicate(), dist.Shard(0)])
         self.linear2.weight = dist.shard_tensor(self.linear2.weight, get_mesh(ipp), [dist.Replicate(), dist.Shard(0)])
         self.linear2.bias = dist.shard_tensor(self.linear2.bias, get_mesh(ipp), [dist.Replicate(), dist.Replicate()])
         # fix : change nn.LayerNorm(config.hidden_size, epsilon=1e-5, bias_attr=True) to GPTLayerNorm()
@@ -699,7 +699,7 @@ def forward(self, input_ids, position_ids=None, inputs_embeddings=None):
         # The 'with' block ensures the correct seed context is used
         with seed_guard_context(current_seed):
             embeddings = self.dropout(embeddings)
-        embeddings = dist.reshard(embeddings, get_mesh(), [dist.Replicate(), dist.Replicate()])
+        embeddings = dist.reshard(embeddings, get_mesh(), [dist.Shard(0), dist.Replicate()])
         return embeddings
 
From 4b3b571186d172796ecf8365a20c1902fbb5f802 Mon Sep 17 00:00:00 2001
From: xuexixi
Date: Wed, 16 Jul 2025 17:23:18 +0800
Subject: [PATCH 4/7] update shard param placements

---
 paddlenlp/transformers/gpt/modeling_auto.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/paddlenlp/transformers/gpt/modeling_auto.py b/paddlenlp/transformers/gpt/modeling_auto.py
index ab3b983a6152..8451dfcc3fb6 100644
--- a/paddlenlp/transformers/gpt/modeling_auto.py
+++ b/paddlenlp/transformers/gpt/modeling_auto.py
@@ -658,10 +658,10 @@ def __init__(
             config.hidden_size,
         )
         self.word_embeddings.weight = dist.shard_tensor(
-            self.word_embeddings.weight, get_mesh(), [dist.Replicate(), dist.Shard(1)]
+            self.word_embeddings.weight, get_mesh(), [dist.Replicate(), dist.Shard(0)]
         )
         self.position_embeddings.weight = dist.shard_tensor(
-            self.position_embeddings.weight, get_mesh(), [dist.Replicate(), dist.Shard(1)]
+            self.position_embeddings.weight, get_mesh(), [dist.Replicate(), dist.Replicate()]
         )
 
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
@@ -1177,7 +1177,7 @@ def __init__(self, config: GPTConfig, embedding_weights=None, ipp=None):
                 shape=[config.vocab_size, config.hidden_size],
                 dtype=paddle.get_default_dtype(),
             )
-            self.weight = dist.shard_tensor(self.weight, get_mesh(self.ipp), [dist.Replicate(), dist.Shard(1)])
+            self.weight = dist.shard_tensor(self.weight, get_mesh(self.ipp), [dist.Replicate(), dist.Shard(0)])
 
     def forward(self, hidden_states, tensor_parallel_output=None):
 

From e4f88f483425bec9290925358451aafc4b4d7ff1 Mon Sep 17 00:00:00 2001
From: xuexixi
Date: Wed, 16 Jul 2025 17:35:37 +0800
Subject: [PATCH 5/7] update shard param placements

---
 paddlenlp/transformers/gpt/modeling_auto.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/paddlenlp/transformers/gpt/modeling_auto.py b/paddlenlp/transformers/gpt/modeling_auto.py
index 8451dfcc3fb6..5ce107e511d1 100644
--- a/paddlenlp/transformers/gpt/modeling_auto.py
+++ b/paddlenlp/transformers/gpt/modeling_auto.py
@@ -661,7 +661,7 @@ def __init__(
             self.word_embeddings.weight, get_mesh(), [dist.Replicate(), dist.Shard(0)]
         )
         self.position_embeddings.weight = dist.shard_tensor(
-            self.position_embeddings.weight, get_mesh(), [dist.Replicate(), dist.Replicate()]
+            self.position_embeddings.weight, get_mesh(), [dist.Replicate(), dist.Shard(0)]
         )
 
         self.dropout = nn.Dropout(config.hidden_dropout_prob)

From 49ba28cd8ce2462c6a3f4b520d5ff4cb77efa0da Mon Sep 17 00:00:00 2001
From: xuexixi
Date: Thu, 17 Jul 2025 11:17:24 +0800
Subject: [PATCH 6/7] fix lint

---
 paddlenlp/transformers/gpt/modeling_auto_pp.py | 15 +++++----------
 1 file changed, 5 insertions(+), 10 deletions(-)

diff --git a/paddlenlp/transformers/gpt/modeling_auto_pp.py b/paddlenlp/transformers/gpt/modeling_auto_pp.py
index 7368aaeb647b..1fe6a4a841c6 100644
--- a/paddlenlp/transformers/gpt/modeling_auto_pp.py
+++ b/paddlenlp/transformers/gpt/modeling_auto_pp.py
@@ -139,10 +139,7 @@ def manual_model_split(model, stage_idx, group, mode, pp_degree):
 
     layer_lists = model.layers
-
-    shared_params_names = {
-        "gpt_shared_weight": ["embedding_0.w_0.dist", "gptlm_head_auto_0.w_0.dist"]
-    }
+    shared_params_names = {"gpt_shared_weight": ["embedding_0.w_0.dist", "gptlm_head_auto_0.w_0.dist"]}
 
     shared_mp = build_shared_param_map(model, shared_params_names)
 
@@ -167,21 +164,19 @@ def _build_stage(model, stage_idx, group):
         stages.append(stage)
     return stages
 
+
 def build_shared_param_map(model, shared_params_names):
     shared_mp = []
     for key, pair in shared_params_names.items():
-        assert len(pair) == 2, (
-            "Only exactly two parameters are supported for sharing."
-        )
+        assert len(pair) == 2, "Only exactly two parameters are supported for sharing."
         ori_name = pair[0]
         sync_name = pair[1]
         ori_param = get_param_from_name(ori_name, model)
         sync_param = get_param_from_name(sync_name, model)
-        shared_mp.append({
-            "params": [ori_param, sync_param]
-        })
+        shared_mp.append({"params": [ori_param, sync_param]})
     return shared_mp
 
+
 def get_param_from_name(param_name, model):
     for param in model.parameters():
         if param.name == param_name:
             return param

From 20c3b849ab267a3ab10c45694d6779330d889f66 Mon Sep 17 00:00:00 2001
From: xuexixi
Date: Thu, 17 Jul 2025 17:37:13 +0800
Subject: [PATCH 7/7] update loss base

---
 scripts/distribute/ci_case_auto.sh | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/scripts/distribute/ci_case_auto.sh b/scripts/distribute/ci_case_auto.sh
index b9e7728a99f2..96fc1372163b 100755
--- a/scripts/distribute/ci_case_auto.sh
+++ b/scripts/distribute/ci_case_auto.sh
@@ -2506,11 +2506,11 @@ function llm_gpt_dygraph_auto_bs8_fp32_DP2() {
     ips=-1
     mem=-1
     echo "result: loss=$loss ips=$ips mem=$mem loss_md5=$loss_md5"
-    loss_base=10.55853653 # output of dropout is different after supporting spmd
+    loss_base=10.55727577 # output of dropout is different after supporting spmd
     ips_base=-1
    mem_base=-1
     if [ $IS_A100 -ne 0 ];then
-        loss_base=10.56019211 # after add dropout spmd
+        loss_base=10.56668472 # after add dropout spmd
     fi
     check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
     echo "=========== $FUNCNAME run end ==========="
@@ -2578,11 +2578,11 @@ function llm_gpt_dygraph_auto_bs8_fp32_DP2-MP2() {
     ips=-1
     mem=-1
     echo "result: loss=$loss ips=$ips mem=$mem loss_md5=$loss_md5"
-    loss_base=10.5657959 # output of dropout is different after supporting spmd
+    loss_base=10.49585533 # output of dropout is different after supporting spmd
     ips_base=-1
     mem_base=-1
     if [ $IS_A100 -ne 0 ];then
-        loss_base=10.5760107 # after add dropout spmd
+        loss_base=10.51038742 # after add dropout spmd
     fi
     check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
     echo "=========== $FUNCNAME run end ==========="
@@ -2651,11 +2651,11 @@ function llm_gpt_dygraph_auto_bs8_fp32_DP2-MP2-PP2() {
     mem=-1
     echo "result: loss=$loss ips=$ips mem=$mem loss_md5=$loss_md5"
     # loss_base=10.59993172 # note: need to debug
-    loss_base=10.57174778 # output of dropout is different after supporting spmd
+    loss_base=10.49603939 # output of dropout is different after supporting spmd
     ips_base=-1
     mem_base=-1
     if [ $IS_A100 -ne 0 ];then
-        loss_base=10.57701015 # after add dropout spmd
+        loss_base=10.51580238 # after add dropout spmd
     fi
     check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
     echo "=========== $FUNCNAME run end ==========="
@@ -2724,11 +2724,11 @@ function llm_gpt_dygraph_auto_bs8_fp16_DP2-MP2-PP2() {
     mem=-1
     echo "result: loss=$loss ips=$ips mem=$mem loss_md5=$loss_md5"
     # loss_base=10.58456802 # note: need to debug
-    loss_base=10.57304478
+    loss_base=10.49809837
     ips_base=-1
     mem_base=-1
     if [ $IS_A100 -ne 0 ];then
-        loss_base=10.57861042 # after add dropout spmd
+        loss_base=10.51762962 # after add dropout spmd
     fi
     check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
     echo "=========== $FUNCNAME run end ==========="
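
Note on the shared-parameter wiring introduced in PATCH 1 and reformatted in PATCH 6: the tying is purely name-based, and the resulting list of {"params": [...]} entries is what gets passed to PipelineStage(..., shared_parameters=shared_mp). The following is a minimal standalone sketch of that lookup, not part of the patches: the _Param and _Model classes are illustrative stand-ins (not PaddleNLP APIs) so the logic runs without Paddle; only the parameter names and the map layout mirror the patch.

    # Hypothetical stand-ins for distributed parameters and the model, for illustration only.
    class _Param:
        def __init__(self, name):
            self.name = name

    class _Model:
        def __init__(self, names):
            self._params = [_Param(n) for n in names]

        def parameters(self):
            return list(self._params)

    def get_param_from_name(param_name, model):
        # Same lookup as in the patch: match on the parameter's name attribute.
        for param in model.parameters():
            if param.name == param_name:
                return param
        raise ValueError(f"{param_name} not found in model parameters")

    def build_shared_param_map(model, shared_params_names):
        # One {"params": [src, dst]} entry per shared pair, as consumed by the pipeline stage.
        shared_mp = []
        for _, pair in shared_params_names.items():
            assert len(pair) == 2, "Only exactly two parameters are supported for sharing."
            shared_mp.append({"params": [get_param_from_name(pair[0], model), get_param_from_name(pair[1], model)]})
        return shared_mp

    model = _Model(["embedding_0.w_0.dist", "gptlm_head_auto_0.w_0.dist"])
    shared_mp = build_shared_param_map(
        model, {"gpt_shared_weight": ["embedding_0.w_0.dist", "gptlm_head_auto_0.w_0.dist"]}
    )
    print([p.name for p in shared_mp[0]["params"]])
    # ['embedding_0.w_0.dist', 'gptlm_head_auto_0.w_0.dist']

Running the sketch prints the embedding and LM-head weight names grouped into a single shared entry, which is the structure the pipeline stages keep in sync.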