
Commit 4952069

[GPT-3] Fix shared weights sync for PipelineLayer (#7775)
1 parent fe3c052 commit 4952069

File tree: 1 file changed (+16 −1)


paddlenlp/transformers/model_utils.py

Lines changed: 16 additions & 1 deletion
@@ -42,7 +42,10 @@
 )
 from huggingface_hub.utils import EntryNotFoundError
 from paddle import Tensor
-from paddle.distributed.fleet.meta_parallel.parallel_layers import SharedLayerDesc
+from paddle.distributed.fleet.meta_parallel.parallel_layers import (
+    PipelineLayer,
+    SharedLayerDesc,
+)
 from paddle.nn import Embedding, Layer
 
 # TODO(fangzeyang) Temporary fix and replace by paddle framework downloader later
@@ -935,6 +938,18 @@ def _post_init(self, original_init, *args, **kwargs):
         ):
             self.init_weights()
 
+        # Note:
+        # 1. PipelineLayer will create parameters for each layer and
+        #    call `_synchronize_shared_weights()` to synchronize the shared parameters.
+        # 2. When setting the model `state_dict`, `_synchronize_shared_weights` will be called to
+        #    synchronize the shared parameters. However, `self._init_weights` will re-initialize
+        #    the parameters without synchronizing the shared parameters. If the following step
+        #    does not load a checkpoint, the shared parameters will be different.
+
+        if isinstance(self, PipelineLayer):
+            self._synchronize_shared_weights()
+
     def _init_weights(self, layer):
         """
         Initialize the weights. This method should be overridden by derived class.
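
For illustration, a minimal sketch (not part of the commit) of the pattern the change applies inside `_post_init`: the standalone helper `post_init_weight_sync` is hypothetical, while `init_weights()`, `PipelineLayer`, and `_synchronize_shared_weights()` are the names used in the diff above.

from paddle.distributed.fleet.meta_parallel.parallel_layers import PipelineLayer

def post_init_weight_sync(model):
    """Hypothetical helper mirroring the logic this commit adds to `_post_init`."""
    # Re-initialization assigns fresh values to every parameter on every
    # pipeline stage independently, so weights that are meant to be shared
    # across stages (e.g. declared via SharedLayerDesc) can diverge here.
    model.init_weights()

    # If the model is a PipelineLayer and no checkpoint is loaded afterwards,
    # broadcast the shared parameters again so all stages hold identical values.
    if isinstance(model, PipelineLayer):
        model._synchronize_shared_weights()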
