Commit 90db0e7

init shared parameters
1 parent 44eff1f commit 90db0e7

File tree: 2 files changed (+33 -4 lines)

  paddlenlp/transformers/gpt/modeling_auto.py
  paddlenlp/transformers/gpt/modeling_auto_pp.py

paddlenlp/transformers/gpt/modeling_auto.py

Lines changed: 4 additions & 3 deletions

@@ -529,7 +529,7 @@ def __init__(self, config: GPTConfig, ipp=None):
         self.linear2 = nn.Linear(config.intermediate_size, config.hidden_size, bias_attr=True)

         self.linear1.weight = dist.shard_tensor(self.linear1.weight, get_mesh(ipp), [dist.Replicate(), dist.Shard(1)])
-        self.linear1.bias = dist.shard_tensor(self.linear1.bias, get_mesh(ipp), [dist.Replicate(), dist.Shard(0)])
+        self.linear1.bias = dist.shard_tensor(self.linear1.bias, get_mesh(ipp), [dist.Replicate(), dist.Replicate()])
         self.linear2.weight = dist.shard_tensor(self.linear2.weight, get_mesh(ipp), [dist.Replicate(), dist.Shard(0)])
         self.linear2.bias = dist.shard_tensor(self.linear2.bias, get_mesh(ipp), [dist.Replicate(), dist.Replicate()])
         # fix : change nn.LayerNorm(config.hidden_size, epsilon=1e-5, bias_attr=True) to GPTLayerNorm()
@@ -658,7 +658,7 @@ def __init__(
             config.hidden_size,
         )
         self.word_embeddings.weight = dist.shard_tensor(
-            self.word_embeddings.weight, get_mesh(), [dist.Replicate(), dist.Replicate()]
+            self.word_embeddings.weight, get_mesh(), [dist.Replicate(), dist.Shard(1)]
         )
         self.position_embeddings.weight = dist.shard_tensor(
             self.position_embeddings.weight, get_mesh(), [dist.Replicate(), dist.Shard(1)]
@@ -699,6 +699,7 @@ def forward(self, input_ids, position_ids=None, inputs_embeddings=None):
         # The 'with' block ensures the correct seed context is used
         with seed_guard_context(current_seed):
             embeddings = self.dropout(embeddings)
+        embeddings = dist.reshard(embeddings, get_mesh(), [dist.Replicate(), dist.Replicate()])
         return embeddings


@@ -1176,7 +1177,7 @@ def __init__(self, config: GPTConfig, embedding_weights=None, ipp=None):
             shape=[config.vocab_size, config.hidden_size],
             dtype=paddle.get_default_dtype(),
         )
-        self.weight = dist.shard_tensor(self.weight, get_mesh(self.ipp), [dist.Replicate(), dist.Shard(0)])
+        self.weight = dist.shard_tensor(self.weight, get_mesh(self.ipp), [dist.Replicate(), dist.Shard(1)])

     def forward(self, hidden_states, tensor_parallel_output=None):

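Note on the placements above: each entry in the list passed to dist.shard_tensor describes the layout along the corresponding axis of the mesh returned by get_mesh(); Replicate() keeps a full copy on every rank of that mesh axis, while Shard(k) splits tensor dimension k across it. For the [vocab_size, hidden_size] weights here, the commit moves both word_embeddings.weight and GPTLMHeadAuto's weight to [Replicate(), Shard(1)] (sharded along the hidden dimension), so the two copies that are paired as shared parameters in modeling_auto_pp.py below end up with the same layout, and the embedding output is resharded to a fully replicated layout before being returned. A minimal sketch of these placement semantics follows; the 2x2 mesh, the axis names "dp"/"mp", and the tensor sizes are assumptions for illustration (not taken from this commit), and it is meant to run under a 4-rank paddle.distributed.launch job.

# Illustrative sketch of the shard_tensor / reshard placements used above.
# Assumptions: 4 ranks arranged as a 2 x 2 mesh with made-up axis names.
import paddle
import paddle.distributed as dist

mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=["dp", "mp"])

# A [vocab_size, hidden_size]-shaped weight, as for the embedding / LM head.
w = paddle.randn([50304, 1024])

# [Replicate(), Shard(1)]: full copy along the first mesh axis, hidden (second)
# tensor dimension split along the second mesh axis -- the new layout for
# word_embeddings.weight and GPTLMHeadAuto's weight.
w_dist = dist.shard_tensor(w, mesh, [dist.Replicate(), dist.Shard(1)])

# dist.reshard converts an existing distributed tensor to another layout, e.g.
# the fully replicated embeddings now returned from forward().
w_repl = dist.reshard(w_dist, mesh, [dist.Replicate(), dist.Replicate()])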
paddlenlp/transformers/gpt/modeling_auto_pp.py

Lines changed: 29 additions & 1 deletion

@@ -139,6 +139,13 @@ def manual_model_split(model, stage_idx, group, mode, pp_degree):

     layer_lists = model.layers

+
+    shared_params_names = {
+        "gpt_shared_weight": ["embedding_0.w_0.dist", "gptlm_head_auto_0.w_0.dist"]
+    }
+
+    shared_mp = build_shared_param_map(model, shared_params_names)
+
     def _build_stage(model, stage_idx, group):
         new_model = None
         if stage_idx == 0:
@@ -151,7 +158,7 @@ def _build_stage(model, stage_idx, group):
             new_model = GPTChunk(
                 layer_lists[stage_idx * chunk_size : (stage_idx + 1) * chunk_size], is_first=False, is_last=False
             )
-        stage = PipelineStage(new_model, stage_idx, chunk_num, group=group)
+        stage = PipelineStage(new_model, stage_idx, chunk_num, group=group, shared_map=shared_mp)
         return stage

     stages = []
@@ -160,6 +167,27 @@ def _build_stage(model, stage_idx, group):
         stages.append(stage)
     return stages

+def build_shared_param_map(model, shared_params_names):
+    shared_mp = []
+    for key, pair in shared_params_names.items():
+        assert len(pair) == 2, (
+            "Only exactly two parameters are supported for sharing."
+        )
+        ori_name = pair[0]
+        sync_name = pair[1]
+        ori_param = get_param_from_name(ori_name, model)
+        sync_param = get_param_from_name(sync_name, model)
+        shared_mp.append({
+            "params": [ori_param, sync_param]
+        })
+    return shared_mp
+
+def get_param_from_name(param_name, model):
+    for param in model.parameters():
+        if param.name == param_name:
+            return param
+    raise ValueError(f"{param_name} not found in model parameters")
+

 def get_gpt_pp_schedule(model, n_microbatches, loss_fn, mode, pp_degree, group):
     assert mode in ["VPP", "1F1B", "FThenB"]
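build_shared_param_map resolves the framework-generated parameter names listed in shared_params_names into the actual parameter objects (here the word-embedding weight "embedding_0.w_0.dist" and the LM-head weight "gptlm_head_auto_0.w_0.dist") and hands the result to every PipelineStage via shared_map. How PipelineStage consumes shared_map is not part of this diff; one common way to keep such tied weights consistent across pipeline stages is to sum their gradients over the process group that holds the copies. The sketch below is a hypothetical illustration of that idea only; the function name sync_shared_grads and the group argument are not from this commit.

# Hypothetical sketch -- not the PipelineStage implementation. Sums the
# gradients of each shared pair so the tied embedding / LM-head weights
# receive the same update; `group` is assumed to span the ranks holding copies.
import paddle.distributed as dist

def sync_shared_grads(shared_mp, group):
    for entry in shared_mp:
        for param in entry["params"]:
            if param.grad is not None:
                dist.all_reduce(param.grad, op=dist.ReduceOp.SUM, group=group)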
