 from paddle.distributed.fleet.utils import recompute
 from paddle.utils import try_import
 
-try:
-    from paddle.distributed.fleet.utils.sequence_parallel_utils import (
-        mark_as_sequence_parallel_parameter,
-    )
-except:
-    pass
-
 from ...utils.converter import StateDictNameMapping
 from .. import PretrainedModel, register_base_model
 from ..model_outputs import BaseModelOutputWithPastAndCrossAttentions
@@ -209,19 +202,19 @@ def __init__(self, config, ipp=None):
         )
 
     def _fuse_prepare_qkv(self, query, use_cache=False, past_key_value=None):
-        if self.config.sequence_parallel:
-            # [bs, seq_len, num_head * head_dim] -> [bs / n, seq_len, num_head, head_dim] (n is model parallelism)
-            target_shape = [-1, self.config.seq_length, self.num_attention_heads, 3 * self.head_dim]
-        else:
-            target_shape = [0, 0, self.num_attention_heads, 3 * self.head_dim]
-
+        target_shape = [0, 0, self.num_attention_heads, 3 * self.head_dim]
         # bs, seq_len, num_head * 3*head_dim
         mix_layer = self.qkv_proj(query)
         # bs, seq_len, num_head, 3*head_dim
         mix_layer = paddle.reshape_(mix_layer, target_shape)
         # query_states, key_states, value_states => bs, seq_len, num_head, head_dim
         query_states, key_states, value_states = paddle.split(mix_layer, num_or_sections=3, axis=-1)
-
+        if self.config.sequence_parallel:
+            # [seq_len, bs, num_head, head_dim] -> [bs, seq_len, num_head, head_dim]
+            # FlashAttention and RoPE do not support the sequence-first layout.
+            query_states = paddle.transpose(query_states, [1, 0, 2, 3])
+            key_states = paddle.transpose(key_states, [1, 0, 2, 3])
+            value_states = paddle.transpose(value_states, [1, 0, 2, 3])
         # [bs, seq_len, num_head, head_dim]
         if past_key_value is not None:
             # reuse k, v, self_attention
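
For reference, a shape-only sketch of the new fused-QKV path: the `[0, 0, num_heads, 3 * head_dim]` reshape is now unconditional, and the transpose back to batch-first only happens under `sequence_parallel`, because the input arrives sequence-first in that mode. Sizes are made up and a plain `nn.Linear` stands in for `qkv_proj`; this is an illustration, not code from the PR.

```python
import paddle
import paddle.nn as nn

# Hypothetical sizes, for illustration only.
seq_len, bs, num_heads, head_dim = 16, 2, 4, 8
hidden = num_heads * head_dim

qkv_proj = nn.Linear(hidden, 3 * hidden)     # stand-in for self.qkv_proj
query = paddle.randn([seq_len, bs, hidden])  # sequence-first input: [seq_len, bs, hidden]

mix_layer = qkv_proj(query)                                    # [seq_len, bs, 3 * hidden]
mix_layer = paddle.reshape_(mix_layer, [0, 0, num_heads, 3 * head_dim])
q, k, v = paddle.split(mix_layer, num_or_sections=3, axis=-1)  # each [seq_len, bs, num_heads, head_dim]

# Back to batch-first, since FlashAttention / RoPE expect [bs, seq_len, ...].
q = paddle.transpose(q, [1, 0, 2, 3])
k = paddle.transpose(k, [1, 0, 2, 3])
v = paddle.transpose(v, [1, 0, 2, 3])
print(q.shape)  # [2, 16, 4, 8]
```
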
@@ -326,6 +319,8 @@ def forward(
         Applies multi-head attention to map queries and a set of key-value pairs
         to outputs.
         """
+        if self.config.sequence_parallel:
+            query = dist.reshard(query, get_mesh(self.ipp), [dist.Shard(1), dist.Replicate()])
         key = query if key is None else key
         value = query if value is None else value
         if self.config.fuse_attention_qkv:
@@ -363,11 +358,11 @@ def forward(
         # else their shape are [bs, q_len, num_head * head_dim / n], n is mp parallelism.
 
         if self.config.sequence_parallel:
-            bs, seq_len, dim = out.shape
-            out = out.reshape([bs * seq_len, dim])  # [bs, seq_len, dim / n] => [bs * seq_len, dim / n]
-
+            out = paddle.transpose(out, [1, 0, 2])
         # project to output
         out = self.out_proj(out)
+        if self.config.sequence_parallel:
+            out = dist.reshard(out, get_mesh(self.ipp), [dist.Shard(1), dist.Shard(0)])
         # if sequence_parallel is true, out shape are [bs * seq_len / n, dim]
         # else their shape are [bs, seq_len, dim], n is mp parallelism.
         outs = [out]
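
The pair of reshards above is the auto-parallel replacement for the old `ScatterOp`-style sequence parallelism: activations stay sequence-first and only their placements on the 2-D `[dp, mp]` mesh change, from `[Shard(1), Replicate()]` (sequence gathered for attention) to `[Shard(1), Shard(0)]` (sequence re-split across the mp group after the output projection). Below is a minimal standalone sketch of that placement switch; the 2x2 mesh, tensor sizes, and the initial `shard_tensor` are assumptions for illustration, and it has to be launched on four devices (e.g. with `paddle.distributed.launch`).

```python
import paddle
import paddle.distributed as dist

# Hypothetical 2x2 mesh: mesh dim 0 = data parallel, mesh dim 1 = model parallel.
mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=["dp", "mp"])

seq_len, bs, hidden = 16, 4, 32
x = paddle.randn([seq_len, bs, hidden])  # sequence-first activation

# Sequence-parallel layout: batch (dim 1) split over dp, sequence (dim 0) split over mp.
x = dist.shard_tensor(x, mesh, [dist.Shard(1), dist.Shard(0)])

# Before attention: gather the sequence dim (replicate over mp), keep the dp split.
x = dist.reshard(x, mesh, [dist.Shard(1), dist.Replicate()])

# After the output projection: re-split the sequence dim over mp.
x = dist.reshard(x, mesh, [dist.Shard(1), dist.Shard(0)])
```
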
@@ -390,9 +385,6 @@ def __init__(self, config, decoder_layers, norm=None, hidden_size=None):
         self.layers = decoder_layers
 
         self.norm = GPTLayerNorm(config, config.hidden_size, epsilon=1e-5)
-        if config.sequence_parallel:
-            mark_as_sequence_parallel_parameter(self.norm.weight)
-            mark_as_sequence_parallel_parameter(self.norm.bias)
 
         # Note that we will actually perform a recompute only if both enable_recompute and layerwise_recompute are set to True
         # Enable_recompute defaults to False and is controlled by Trainer
@@ -529,7 +521,7 @@ def __init__(self, config: GPTConfig, ipp=None):
         self.linear2 = nn.Linear(config.intermediate_size, config.hidden_size, bias_attr=True)
 
         self.linear1.weight = dist.shard_tensor(self.linear1.weight, get_mesh(ipp), [dist.Replicate(), dist.Shard(1)])
-        self.linear1.bias = dist.shard_tensor(self.linear1.bias, get_mesh(ipp), [dist.Replicate(), dist.Replicate()])
+        self.linear1.bias = dist.shard_tensor(self.linear1.bias, get_mesh(ipp), [dist.Replicate(), dist.Shard(0)])
         self.linear2.weight = dist.shard_tensor(self.linear2.weight, get_mesh(ipp), [dist.Replicate(), dist.Shard(0)])
         self.linear2.bias = dist.shard_tensor(self.linear2.bias, get_mesh(ipp), [dist.Replicate(), dist.Replicate()])
         # fix : change nn.LayerNorm(config.hidden_size, epsilon=1e-5, bias_attr=True) to GPTLayerNorm()
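
The one-line fix above makes `linear1` consistent with the usual column-parallel convention: its weight is split along the output dimension (`Shard(1)` on the mp axis), so its bias must be split along its only dimension (`Shard(0)`) rather than replicated, while the row-parallel `linear2` keeps a replicated bias. A sketch of the same placements on a toy MLP, assuming the illustrative 2x2 `[dp, mp]` mesh from the previous snippet and launch on four devices; `Replicate()` on the dp axis just keeps the weights identical across data-parallel ranks.

```python
import paddle.nn as nn
import paddle.distributed as dist

mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=["dp", "mp"])

hidden, intermediate = 32, 128
linear1 = nn.Linear(hidden, intermediate, bias_attr=True)  # column-parallel
linear2 = nn.Linear(intermediate, hidden, bias_attr=True)  # row-parallel

# Column-parallel: weight [hidden, intermediate] split on dim 1, bias [intermediate] split on dim 0.
linear1.weight = dist.shard_tensor(linear1.weight, mesh, [dist.Replicate(), dist.Shard(1)])
linear1.bias = dist.shard_tensor(linear1.bias, mesh, [dist.Replicate(), dist.Shard(0)])

# Row-parallel: weight [intermediate, hidden] split on dim 0, bias [hidden] replicated.
linear2.weight = dist.shard_tensor(linear2.weight, mesh, [dist.Replicate(), dist.Shard(0)])
linear2.bias = dist.shard_tensor(linear2.bias, mesh, [dist.Replicate(), dist.Replicate()])
```
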
@@ -588,6 +580,12 @@ def forward(
 
         # Use a ternary operator for a more concise assignment of current_seed
         current_seed = "local_seed" if self.config.sequence_parallel else "global_seed"
+        if self.config.sequence_parallel:
+            hidden_states = dist.reshard(
+                hidden_states,
+                get_mesh(self.ipp),
+                [dist.Shard(1), dist.Shard(0)],
+            )
 
         # The 'with' block ensures the correct seed context is used
         with seed_guard_context(current_seed):
@@ -602,14 +600,17 @@ def forward(
         residual = hidden_states
         if self.config.normalize_before:
             hidden_states = self.norm2(hidden_states)
-
+        if self.config.sequence_parallel:
+            hidden_states = dist.reshard(hidden_states, get_mesh(self.ipp), [dist.Shard(1), dist.Replicate()])
         # when sequence_parallel=True:
         # hidden_states => [bs * seq_len / n, embed_dim]
         with seed_guard_context(current_seed):
             if not self.config.use_fused_dropout_add:
                 l_1 = self.linear1(hidden_states)
                 act = self.activation(l_1, approximate=True)
                 l_2 = self.linear2(act)
+                if self.config.sequence_parallel:
+                    l_2 = dist.reshard(l_2, get_mesh(self.ipp), [dist.Shard(1), dist.Shard(0)])
                 hidden_states = residual + self.dropout2(l_2)
             else:
                 hidden_states = self.fused_dropout_add2(
@@ -680,18 +681,15 @@ def forward(self, input_ids, position_ids=None, inputs_embeddings=None):
         position_embeddings = self.position_embeddings(position_ids)
         embeddings = inputs_embeddings + position_embeddings
 
-        # exit()
-        if self.config.sequence_parallel:
-            # embeddings = dist.shard_tensor(embeddings,get_mesh(),[dist.Replicate(),dist.Replicate()])
-            bs, seq_len, hidden_size = embeddings.shape
-            # [bs, seq_len, dim] -> [bs * seq_len, dim]
-            embeddings = paddle.reshape_(embeddings, [bs * seq_len, hidden_size])
-            # [bs * seq_len / n, dim] (n is mp parallelism)
-            # embeddings = ScatterOp.apply(embeddings)
-            embeddings = dist.reshard(embeddings, get_mesh(), [dist.Replicate(), dist.Shard(1)])
         # Use a ternary operator for a more concise assignment of current_seed
         current_seed = "local_seed" if self.config.sequence_parallel else "global_seed"
         # The 'with' block ensures the correct seed context is used
+        if self.config.sequence_parallel:
+            # [B, S, H] -> [S, B, H]
+            embeddings = paddle.transpose(embeddings, [1, 0, 2])
+            embeddings = dist.reshard(embeddings, get_mesh(), [dist.Shard(1), dist.Shard(0)])
+        else:
+            embeddings = dist.reshard(embeddings, get_mesh(), [dist.Shard(0), dist.Replicate()])
         with seed_guard_context(current_seed):
             embeddings = self.dropout(embeddings)
         return embeddings
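
After the transpose the embeddings are `[S, B, H]`, so on the `[dp, mp]` mesh `Shard(1)` splits the batch across dp and `Shard(0)` splits the sequence across mp; without sequence parallelism the tensor stays `[B, S, H]` and only the batch dim is split. A small sketch of the two branches, again assuming the illustrative 2x2 mesh and using `shard_tensor` on a dense tensor in place of the `reshard` in the diff:

```python
import paddle
import paddle.distributed as dist

mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=["dp", "mp"])

bs, seq_len, hidden = 4, 16, 32
embeddings = paddle.randn([bs, seq_len, hidden])  # [B, S, H]

sequence_parallel = True  # illustrative flag
if sequence_parallel:
    embeddings = paddle.transpose(embeddings, [1, 0, 2])  # [S, B, H]
    # dp splits dim 1 (batch), mp splits dim 0 (sequence)
    embeddings = dist.shard_tensor(embeddings, mesh, [dist.Shard(1), dist.Shard(0)])
else:
    # plain data parallelism: batch split over dp, replicated over mp
    embeddings = dist.shard_tensor(embeddings, mesh, [dist.Shard(0), dist.Replicate()])
```
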
@@ -1171,13 +1169,16 @@ def __init__(self, config: GPTConfig, embedding_weights=None, ipp=None):
                 shape=[config.vocab_size, config.hidden_size],
                 dtype=paddle.get_default_dtype(),
             )
-            self.weight = dist.shard_tensor(self.weight, get_mesh(self.ipp), [dist.Replicate(), dist.Shard(0)])
+            self.weight = dist.shard_tensor(self.weight, get_mesh(self.ipp), [dist.Replicate(), dist.Shard(1)])
 
     def forward(self, hidden_states, tensor_parallel_output=None):
-
         if self.config.sequence_parallel:
-            hidden_states = dist.reshard(hidden_states, get_mesh(self.ipp), [dist.Replicate(), dist.Replicate()])
-            hidden_states = paddle.reshape(hidden_states, [-1, self.config.seq_length, self.config.hidden_size])
+            hidden_states = dist.reshard(
+                hidden_states,
+                get_mesh(self.ipp),
+                [dist.Shard(1), dist.Shard(0)],
+            )
+            hidden_states = paddle.transpose(hidden_states, [1, 0, 2])
 
         if tensor_parallel_output is None:
             tensor_parallel_output = self.config.tensor_parallel_output