Skip to content

Commit 2c02199

Browse files
authored
fix sharding reshard bug (#10613)
1 parent 494a8a1 commit 2c02199

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

paddlenlp/trainer/utils/sharding_io.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,10 @@ def exclude_paramters_in_state_dict(
123123
)
124124
# allgather parameter names in sharding group
125125
tmp = []
126-
paddle.distributed.all_gather_object(tmp, param_names_in_master_weights, group=sharding_group)
126+
if sharding_group.nranks > 1:
127+
paddle.distributed.all_gather_object(tmp, param_names_in_master_weights, group=sharding_group)
128+
else:
129+
tmp = [param_names_in_master_weights]
127130
param_names_in_master_weights = set([v for item in tmp for v in item])
128131
logger.info("sharding_group_param_names:{}".format(param_names_in_master_weights))
129132
non_parameters_state_dict = copy.copy(model_state_dict)

0 commit comments

Comments
 (0)