@@ -529,7 +529,7 @@ def __init__(self, config: GPTConfig, ipp=None):
         self.linear2 = nn.Linear(config.intermediate_size, config.hidden_size, bias_attr=True)

         self.linear1.weight = dist.shard_tensor(self.linear1.weight, get_mesh(ipp), [dist.Replicate(), dist.Shard(1)])
-        self.linear1.bias = dist.shard_tensor(self.linear1.bias, get_mesh(ipp), [dist.Replicate(), dist.Shard(0)])
+        self.linear1.bias = dist.shard_tensor(self.linear1.bias, get_mesh(ipp), [dist.Replicate(), dist.Replicate()])
         self.linear2.weight = dist.shard_tensor(self.linear2.weight, get_mesh(ipp), [dist.Replicate(), dist.Shard(0)])
         self.linear2.bias = dist.shard_tensor(self.linear2.bias, get_mesh(ipp), [dist.Replicate(), dist.Replicate()])
         # fix: change nn.LayerNorm(config.hidden_size, epsilon=1e-5, bias_attr=True) to GPTLayerNorm()
@@ -658,7 +658,7 @@ def __init__(
             config.hidden_size,
         )
         self.word_embeddings.weight = dist.shard_tensor(
-            self.word_embeddings.weight, get_mesh(), [dist.Replicate(), dist.Replicate()]
+            self.word_embeddings.weight, get_mesh(), [dist.Replicate(), dist.Shard(1)]
         )
         self.position_embeddings.weight = dist.shard_tensor(
             self.position_embeddings.weight, get_mesh(), [dist.Replicate(), dist.Shard(1)]
@@ -699,6 +699,7 @@ def forward(self, input_ids, position_ids=None, inputs_embeddings=None):
         # The 'with' block ensures the correct seed context is used
         with seed_guard_context(current_seed):
             embeddings = self.dropout(embeddings)
+        embeddings = dist.reshard(embeddings, get_mesh(), [dist.Replicate(), dist.Replicate()])
         return embeddings


@@ -1176,7 +1177,7 @@ def __init__(self, config: GPTConfig, embedding_weights=None, ipp=None):
             shape=[config.vocab_size, config.hidden_size],
             dtype=paddle.get_default_dtype(),
         )
-        self.weight = dist.shard_tensor(self.weight, get_mesh(self.ipp), [dist.Replicate(), dist.Shard(0)])
+        self.weight = dist.shard_tensor(self.weight, get_mesh(self.ipp), [dist.Replicate(), dist.Shard(1)])

     def forward(self, hidden_states, tensor_parallel_output=None):

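For reference, a minimal sketch (not part of the patch) of the placement semantics these hunks rely on. The 2-D mesh below and its axis names are illustrative stand-ins for whatever get_mesh() returns in the repo; only paddle.distributed APIs that the diff itself uses (shard_tensor, reshard, Replicate, Shard) appear.

    # Assumed setup: 4 ranks launched via `python -m paddle.distributed.launch`.
    import paddle
    import paddle.distributed as dist

    # 2-D mesh: first axis ~ pipeline stages, second axis ~ tensor-parallel ranks.
    mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=["pp", "mp"])

    w = paddle.randn([1024, 4096])
    # [Replicate(), Shard(1)]: replicate over the first mesh axis, split dim 1 of the
    # tensor over the second axis -- the placement given to linear1.weight above.
    w_dist = dist.shard_tensor(w, mesh, [dist.Replicate(), dist.Shard(1)])

    # dist.reshard converts an existing distributed tensor to new placements, e.g.
    # gathering a sharded activation back to fully replicated, as the line added to
    # the embedding forward() does.
    x = dist.shard_tensor(paddle.randn([8, 1024]), mesh, [dist.Replicate(), dist.Shard(1)])
    x_rep = dist.reshard(x, mesh, [dist.Replicate(), dist.Replicate()])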