Commit 454ed42

fix gpt sp model

1 parent 6ee8fe0 commit 454ed42
File tree

1 file changed (+36 -35 lines changed)

paddlenlp/transformers/gpt/modeling_auto.py

Lines changed: 36 additions & 35 deletions
@@ -32,13 +32,6 @@
 from paddle.distributed.fleet.utils import recompute
 from paddle.utils import try_import
 
-try:
-    from paddle.distributed.fleet.utils.sequence_parallel_utils import (
-        mark_as_sequence_parallel_parameter,
-    )
-except:
-    pass
-
 from ...utils.converter import StateDictNameMapping
 from .. import PretrainedModel, register_base_model
 from ..model_outputs import BaseModelOutputWithPastAndCrossAttentions
@@ -209,19 +202,19 @@ def __init__(self, config, ipp=None):
         )
 
     def _fuse_prepare_qkv(self, query, use_cache=False, past_key_value=None):
-        if self.config.sequence_parallel:
-            # [bs, seq_len, num_head * head_dim] -> [bs / n, seq_len, num_head, head_dim] (n is model parallelism)
-            target_shape = [-1, self.config.seq_length, self.num_attention_heads, 3 * self.head_dim]
-        else:
-            target_shape = [0, 0, self.num_attention_heads, 3 * self.head_dim]
-
+        target_shape = [0, 0, self.num_attention_heads, 3 * self.head_dim]
         # bs, seq_len, num_head * 3*head_dim
         mix_layer = self.qkv_proj(query)
         # bs, seq_len, num_head, 3*head_dim
         mix_layer = paddle.reshape_(mix_layer, target_shape)
         # query_states, key_states, value_states => bs, seq_len, num_head, head_dim
         query_states, key_states, value_states = paddle.split(mix_layer, num_or_sections=3, axis=-1)
-
+        if self.config.sequence_parallel:
+            # [seq_len, bs, num_head * head_dim] -> [bs, seq_len, num_head * head_dim] (if sequence_parallel)
+            # FA and rope not support sequence first
+            query_states = paddle.transpose(query_states, [1, 0, 2, 3])
+            key_states = paddle.transpose(key_states, [1, 0, 2, 3])
+            value_states = paddle.transpose(value_states, [1, 0, 2, 3])
         # [bs, seq_len, num_head, head_dim]
         if past_key_value is not None:
             # reuse k, v, self_attention
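
The rewritten `_fuse_prepare_qkv` keeps a single batch-first `target_shape` and, when sequence parallel is enabled, transposes the split q/k/v from sequence-first back to batch-first, since FlashAttention and RoPE expect batch-first tensors. A minimal single-process sketch of that layout change (the sizes are illustrative, not taken from the model):

import paddle

# Illustrative sizes; under sequence parallel the activations arrive sequence-first.
bs, seq_len, num_heads, head_dim = 2, 8, 4, 16

# mix_layer as produced by qkv_proj + reshape_: [seq_len, bs, num_heads, 3 * head_dim]
mix_layer = paddle.randn([seq_len, bs, num_heads, 3 * head_dim])

# Split into q/k/v along the last axis, as in the patch.
q, k, v = paddle.split(mix_layer, num_or_sections=3, axis=-1)

# FA and RoPE want batch-first, so swap the first two axes exactly as the patch does.
q = paddle.transpose(q, [1, 0, 2, 3])  # -> [bs, seq_len, num_heads, head_dim]
k = paddle.transpose(k, [1, 0, 2, 3])
v = paddle.transpose(v, [1, 0, 2, 3])
print(q.shape)  # [2, 8, 4, 16]
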
@@ -326,6 +319,8 @@ def forward(
         Applies multi-head attention to map queries and a set of key-value pairs
         to outputs.
         """
+        if self.config.sequence_parallel:
+            query = dist.reshard(query, get_mesh(self.ipp), [dist.Shard(1), dist.Replicate()])
         key = query if key is None else key
         value = query if value is None else value
         if self.config.fuse_attention_qkv:
@@ -363,11 +358,11 @@ def forward(
         # else their shape are [bs, q_len, num_head * head_dim / n], n is mp parallelism.
 
         if self.config.sequence_parallel:
-            bs, seq_len, dim = out.shape
-            out = out.reshape([bs * seq_len, dim])  # [bs, seq_len, dim / n] => [bs * seq_len, dim / n]
-
+            out = paddle.transpose(out, [1, 0, 2])
         # project to output
         out = self.out_proj(out)
+        if self.config.sequence_parallel:
+            out = dist.reshard(out, get_mesh(self.ipp), [dist.Shard(1), dist.Shard(0)])
         # if sequence_parallel is true, out shape are [bs * seq_len / n, dim]
         # else their shape are [bs, seq_len, dim], n is mp parallelism.
         outs = [out]
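
On the attention output path the old flatten-to-[bs * seq_len, dim] bookkeeping is replaced by a plain transpose back to sequence-first followed by a reshard that re-splits the sequence dimension across the mp ranks after out_proj. A hedged sketch of that pattern, assuming a hypothetical 2x2 ProcessMesh with ["dp", "mp"] axes in place of get_mesh(self.ipp) and a 4-GPU job started with paddle.distributed.launch:

import paddle
import paddle.distributed as dist

# Hypothetical 2x2 mesh standing in for get_mesh(self.ipp); assumes a 4-GPU job
# launched with `python -m paddle.distributed.launch`.
mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=["dp", "mp"])

bs, seq_len, hidden = 2, 8, 64
out = paddle.randn([bs, seq_len, hidden])  # attention output, batch-first
out = dist.shard_tensor(out, mesh, [dist.Shard(0), dist.Replicate()])

# Back to sequence-first before the output projection, as the patch does.
out = paddle.transpose(out, [1, 0, 2])  # [B, S, H] -> [S, B, H]

# ... out_proj would run here ...

# Re-split the sequence dim across "mp" and keep batch over "dp",
# mirroring dist.reshard(out, get_mesh(self.ipp), [dist.Shard(1), dist.Shard(0)]).
out = dist.reshard(out, mesh, [dist.Shard(1), dist.Shard(0)])
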
@@ -390,9 +385,6 @@ def __init__(self, config, decoder_layers, norm=None, hidden_size=None):
         self.layers = decoder_layers
 
         self.norm = GPTLayerNorm(config, config.hidden_size, epsilon=1e-5)
-        if config.sequence_parallel:
-            mark_as_sequence_parallel_parameter(self.norm.weight)
-            mark_as_sequence_parallel_parameter(self.norm.bias)
 
         # Note that we will actually perform a recompute only if both enable_recompute and layerwise_recompute are set to True
         # Enable_recompute defaults to False and is controlled by Trainer
@@ -529,7 +521,7 @@ def __init__(self, config: GPTConfig, ipp=None):
         self.linear2 = nn.Linear(config.intermediate_size, config.hidden_size, bias_attr=True)
 
         self.linear1.weight = dist.shard_tensor(self.linear1.weight, get_mesh(ipp), [dist.Replicate(), dist.Shard(1)])
-        self.linear1.bias = dist.shard_tensor(self.linear1.bias, get_mesh(ipp), [dist.Replicate(), dist.Replicate()])
+        self.linear1.bias = dist.shard_tensor(self.linear1.bias, get_mesh(ipp), [dist.Replicate(), dist.Shard(0)])
         self.linear2.weight = dist.shard_tensor(self.linear2.weight, get_mesh(ipp), [dist.Replicate(), dist.Shard(0)])
         self.linear2.bias = dist.shard_tensor(self.linear2.bias, get_mesh(ipp), [dist.Replicate(), dist.Replicate()])
         # fix : change nn.LayerNorm(config.hidden_size, epsilon=1e-5, bias_attr=True) to GPTLayerNorm()
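
The one-placement change here makes linear1's bias follow its column-parallel weight: the weight is split on its output axis over the mp mesh dim ([Replicate(), Shard(1)]), so the bias, a 1-D tensor over the same output features, is now split on its only axis ([Replicate(), Shard(0)]) instead of being replicated. A hedged sketch with a hypothetical 2x2 ["dp", "mp"] mesh in place of get_mesh(ipp):

import paddle
import paddle.nn as nn
import paddle.distributed as dist

mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=["dp", "mp"])  # hypothetical mesh

hidden, intermediate = 64, 256
linear1 = nn.Linear(hidden, intermediate, bias_attr=True)

# Column-parallel linear: the [hidden, intermediate] weight is split on its
# output (column) axis over "mp"; the matching [intermediate] bias is split
# on its only axis so each rank keeps the bias slice for its own output columns.
linear1.weight = dist.shard_tensor(linear1.weight, mesh, [dist.Replicate(), dist.Shard(1)])
linear1.bias = dist.shard_tensor(linear1.bias, mesh, [dist.Replicate(), dist.Shard(0)])
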
@@ -588,6 +580,12 @@ def forward(
 
         # Use a ternary operator for a more concise assignment of current_seed
         current_seed = "local_seed" if self.config.sequence_parallel else "global_seed"
+        if self.config.sequence_parallel:
+            hidden_states = dist.reshard(
+                hidden_states,
+                get_mesh(self.ipp),
+                [dist.Shard(1), dist.Shard(0)],
+            )
 
         # The 'with' block ensures the correct seed context is used
         with seed_guard_context(current_seed):
@@ -602,14 +600,17 @@ def forward(
         residual = hidden_states
         if self.config.normalize_before:
             hidden_states = self.norm2(hidden_states)
-
+        if self.config.sequence_parallel:
+            hidden_states = dist.reshard(hidden_states, get_mesh(self.ipp), [dist.Shard(1), dist.Replicate()])
         # when sequence_parallel=True:
         # hidden_states => [bs * seq_len / n, embed_dim]
         with seed_guard_context(current_seed):
             if not self.config.use_fused_dropout_add:
                 l_1 = self.linear1(hidden_states)
                 act = self.activation(l_1, approximate=True)
                 l_2 = self.linear2(act)
+                if self.config.sequence_parallel:
+                    l_2 = dist.reshard(l_2, get_mesh(self.ipp), [dist.Shard(1), dist.Shard(0)])
                 hidden_states = residual + self.dropout2(l_2)
             else:
                 hidden_states = self.fused_dropout_add2(
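
In the decoder layer the dropout inputs are pinned to the sequence-parallel layout ([Shard(1), Shard(0)]) so they run under local_seed on a per-rank shard, while the MLP branch reshards to [Shard(1), Replicate()] before linear1 and back to [Shard(1), Shard(0)] after linear2, which reads as the usual gather/scatter pair of sequence parallelism expressed through dist.reshard. A hedged sketch of that round trip, again under a hypothetical 2x2 ["dp", "mp"] mesh:

import paddle
import paddle.distributed as dist

mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=["dp", "mp"])  # hypothetical mesh

seq_len, bs, hidden = 8, 2, 64
x = paddle.randn([seq_len, bs, hidden])  # sequence-first activations

# Resident sequence-parallel layout: batch over "dp", sequence over "mp".
x = dist.shard_tensor(x, mesh, [dist.Shard(1), dist.Shard(0)])

# Before the tensor-parallel linears: gather the sequence dim across "mp".
x = dist.reshard(x, mesh, [dist.Shard(1), dist.Replicate()])

# ... linear1 -> activation -> linear2 would run here ...

# After linear2: scatter the sequence dim back across "mp".
x = dist.reshard(x, mesh, [dist.Shard(1), dist.Shard(0)])
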
@@ -680,18 +681,15 @@ def forward(self, input_ids, position_ids=None, inputs_embeddings=None):
         position_embeddings = self.position_embeddings(position_ids)
         embeddings = inputs_embeddings + position_embeddings
 
-        # exit()
-        if self.config.sequence_parallel:
-            # embeddings = dist.shard_tensor(embeddings,get_mesh(),[dist.Replicate(),dist.Replicate()])
-            bs, seq_len, hidden_size = embeddings.shape
-            # [bs, seq_len, dim] -> [bs * seq_len, dim]
-            embeddings = paddle.reshape_(embeddings, [bs * seq_len, hidden_size])
-            # [bs * seq_len / n, dim] (n is mp parallelism)
-            # embeddings = ScatterOp.apply(embeddings)
-            embeddings = dist.reshard(embeddings, get_mesh(), [dist.Replicate(), dist.Shard(1)])
         # Use a ternary operator for a more concise assignment of current_seed
         current_seed = "local_seed" if self.config.sequence_parallel else "global_seed"
         # The 'with' block ensures the correct seed context is used
+        if self.config.sequence_parallel:
+            # [B, S, H] -> [S, B, H]
+            embeddings = paddle.transpose(embeddings, [1, 0, 2])
+            embeddings = dist.reshard(embeddings, get_mesh(), [dist.Shard(1), dist.Shard(0)])
+        else:
+            embeddings = dist.reshard(embeddings, get_mesh(), [dist.Shard(0), dist.Replicate()])
         with seed_guard_context(current_seed):
             embeddings = self.dropout(embeddings)
         return embeddings
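
The embedding layer now establishes the sequence-first layout up front instead of flattening to [bs * seq_len, dim]: under sequence parallel the summed embeddings are transposed to [S, B, H] and resharded with the sequence dim split over mp and the batch dim over dp; otherwise they stay batch-first with only the batch dim split over dp. A hedged sketch of both branches with a hypothetical 2x2 ["dp", "mp"] mesh in place of get_mesh():

import paddle
import paddle.distributed as dist

mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=["dp", "mp"])  # hypothetical mesh
sequence_parallel = True  # stands in for self.config.sequence_parallel

bs, seq_len, hidden = 2, 8, 64
embeddings = paddle.randn([bs, seq_len, hidden])  # [B, S, H]
embeddings = dist.shard_tensor(embeddings, mesh, [dist.Replicate(), dist.Replicate()])

if sequence_parallel:
    embeddings = paddle.transpose(embeddings, [1, 0, 2])  # [B, S, H] -> [S, B, H]
    # sequence over "mp", batch over "dp"
    embeddings = dist.reshard(embeddings, mesh, [dist.Shard(1), dist.Shard(0)])
else:
    # batch-first layout, batch split over "dp" only
    embeddings = dist.reshard(embeddings, mesh, [dist.Shard(0), dist.Replicate()])
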
@@ -1171,13 +1169,16 @@ def __init__(self, config: GPTConfig, embedding_weights=None, ipp=None):
             shape=[config.vocab_size, config.hidden_size],
             dtype=paddle.get_default_dtype(),
         )
-        self.weight = dist.shard_tensor(self.weight, get_mesh(self.ipp), [dist.Replicate(), dist.Shard(0)])
+        self.weight = dist.shard_tensor(self.weight, get_mesh(self.ipp), [dist.Replicate(), dist.Shard(1)])
 
     def forward(self, hidden_states, tensor_parallel_output=None):
-
         if self.config.sequence_parallel:
-            hidden_states = dist.reshard(hidden_states, get_mesh(self.ipp), [dist.Replicate(), dist.Replicate()])
-            hidden_states = paddle.reshape(hidden_states, [-1, self.config.seq_length, self.config.hidden_size])
+            hidden_states = dist.reshard(
+                hidden_states,
+                get_mesh(self.ipp),
+                [dist.Shard(1), dist.Shard(0)],
+            )
+            hidden_states = paddle.transpose(hidden_states, [1, 0, 2])
 
         if tensor_parallel_output is None:
             tensor_parallel_output = self.config.tensor_parallel_output
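
In the LM head the weight is now split along its hidden axis (Shard(1) on the [vocab_size, hidden_size] tensor) rather than the vocab axis, and sequence-parallel hidden states are first pinned to [Shard(1), Shard(0)] and then transposed back to batch-first before projecting to the vocabulary. A hedged sketch of those placements and a typical logits matmul (the mesh and the matmul form are illustrative assumptions, not copied from the file):

import paddle
import paddle.distributed as dist

mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=["dp", "mp"])  # hypothetical mesh

vocab_size, hidden = 1024, 64
weight = paddle.randn([vocab_size, hidden])
# "mp" now splits the hidden axis of the head weight instead of the vocab axis.
weight = dist.shard_tensor(weight, mesh, [dist.Replicate(), dist.Shard(1)])

seq_len, bs = 8, 2
hidden_states = paddle.randn([seq_len, bs, hidden])  # sequence-first input
hidden_states = dist.shard_tensor(hidden_states, mesh, [dist.Shard(1), dist.Shard(0)])

# Back to batch-first before the vocab projection.
hidden_states = paddle.transpose(hidden_states, [1, 0, 2])        # [S, B, H] -> [B, S, H]
logits = paddle.matmul(hidden_states, weight, transpose_y=True)   # [B, S, vocab_size]
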
