We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 494a8a1 commit 2c02199Copy full SHA for 2c02199
paddlenlp/trainer/utils/sharding_io.py
@@ -123,7 +123,10 @@ def exclude_paramters_in_state_dict(
123
)
124
# allgather parameter names in sharding group
125
tmp = []
126
- paddle.distributed.all_gather_object(tmp, param_names_in_master_weights, group=sharding_group)
+ if sharding_group.nranks > 1:
127
+ paddle.distributed.all_gather_object(tmp, param_names_in_master_weights, group=sharding_group)
128
+ else:
129
+ tmp = [param_names_in_master_weights]
130
param_names_in_master_weights = set([v for item in tmp for v in item])
131
logger.info("sharding_group_param_names:{}".format(param_names_in_master_weights))
132
non_parameters_state_dict = copy.copy(model_state_dict)
0 commit comments