
Commit ea4fd14

Cherry-pick some PRs from PaddleNLP (#2821)
1 parent 8029b9c commit ea4fd14

5 files changed: +65 −30 lines

paddleformers/trainer/trainer.py

Lines changed: 10 additions & 2 deletions
@@ -394,6 +394,7 @@ def __init__(
                 self.model,
                 self.optimizer,
                 remap_parameter_name=self.args.load_sharded_model_remap_parameter_name,
+                is_ema=self.args.sharded_model_from_ema,
             )
         if self.args.unified_checkpoint:
             self.unified_checkpoint_handler = UnifiedCheckpointHandler(self.args)
@@ -836,9 +837,16 @@ def create_zcc_manager(self, unwrapped_model, resume_from_checkpoint=None):
         if resume_from_checkpoint is not None:
             path = _add_variant(PADDLE_OPTIMIZER_NAME, self.args.optimizer_name_suffix)
             path = os.path.join(resume_from_checkpoint, path).replace("optimizer", "ema")
+            if self.args.zcc_save_ema_coef is not None and self.sharding_io is not None:
+                success, err_msg = self.sharding_io.check_same_strategy(resume_from_checkpoint)
+            else:
+                success, err_msg = True, None
             if os.path.exists(path):
-                logger.info(f"ZCC EMA load from {path}")
-                self.zcc_manager.set_ema_state_dict(path)
+                if success:
+                    logger.info(f"ZCC EMA load from {path}")
+                    self.zcc_manager.set_ema_state_dict(path)
+                else:
+                    logger.info(f"ZCC EMA does not load {path} because {err_msg}")
             else:
                 logger.info(f"ZCC EMA state dict not found, in: {path}")

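A note on the EMA path rewrite above: create_zcc_manager derives the EMA checkpoint file name by substituting "optimizer" with "ema" in the optimizer checkpoint path. A minimal sketch of that derivation (the checkpoint directory and variant suffix below are hypothetical, not taken from the repo):

import os

# Hypothetical checkpoint layout, for illustration only.
resume_from_checkpoint = "./checkpoints/checkpoint-1000"
optimizer_file = "optimizer.tp00_shard00.pdopt"  # PADDLE_OPTIMIZER_NAME plus a variant suffix

# Mirrors the .replace("optimizer", "ema") call in create_zcc_manager.
ema_path = os.path.join(resume_from_checkpoint, optimizer_file).replace("optimizer", "ema")
print(ema_path)  # ./checkpoints/checkpoint-1000/ema.tp00_shard00.pdopt
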
paddleformers/trainer/training_args.py

Lines changed: 6 additions & 3 deletions
@@ -634,6 +634,11 @@ class TrainingArguments:
         metadata={"help": "Whether to remap parameter name when load_sharded_model = true."},
     )
 
+    sharded_model_from_ema: bool = field(
+        default=False,
+        metadata={"help": "Whether to load sharded model from EMA."},
+    )
+
     tensor_parallel_degree: int = field(
         default=-1,
         metadata={
@@ -2504,9 +2509,7 @@ def should_save_sharding_stage1_model(self):
     def should_load_sharding_stage1_model(self):
         if self.enable_auto_parallel:
             return False
-        return (
-            ShardingOption.SHARD_OP in self.sharding and self.sharding_parallel_degree > 1 and self.load_sharded_model
-        )
+        return self.load_sharded_model
 
     @property
     def should_load_dataset(self):

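The new sharded_model_from_ema flag pairs with the existing load_sharded_model flag. A minimal usage sketch, assuming the usual keyword construction of TrainingArguments and an assumed import path (the field names come from the diff above; output_dir is the standard required argument):

from paddleformers.trainer import TrainingArguments  # assumed import path

args = TrainingArguments(
    output_dir="./checkpoints",
    load_sharded_model=True,      # existing flag: load stage-1 sharded checkpoints on resume
    sharded_model_from_ema=True,  # new flag: read the weights from the EMA files instead
)

With the second hunk, should_load_sharding_stage1_model depends only on load_sharded_model, so sharded loading is no longer gated on ShardingOption.SHARD_OP or a sharding degree greater than 1.
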
paddleformers/trainer/utils/reshard/common.py

Lines changed: 11 additions & 9 deletions
@@ -102,7 +102,7 @@ def convert_opt_name_to_tname(tensor_names, opt_names):
                 opt_to_t[t] = t[: -len(s)]
                 _find = True
                 break
-        assert _find
+        assert _find, t
     return opt_to_t
 
 
@@ -609,12 +609,13 @@ def map_func(weight):
             weight = weight.numpy()
         return weight
 
+    group_rank = max(group.rank, 0)
     state_dict = {k: map_func(v) for (k, v) in state_dict.items()}
 
     meta_dict = {}
     for (k, v) in state_dict.items():
         # src rank
-        meta_dict[k] = (v.dtype, v.shape, group.rank)
+        meta_dict[k] = (v.dtype, v.shape, group_rank)
 
     meta_dict_list = all_gather_simple_object(meta_dict, group)
 
@@ -628,20 +629,21 @@ def map_func(weight):
     meta_list = sorted(meta_list, key=lambda x: x[0])
     for (k, meta) in meta_list:
         dtype, shape, rank = meta
-        if rank == group.rank:
+        if rank == group_rank:
             assert k in state_dict
             tensor = paddle.to_tensor(state_dict[k])
             del state_dict[k]
         else:
             tensor = paddle.to_tensor(np.empty(shape, dtype))
         logger.info(f"broadcast {k} from {rank}, group {group}")
         # broadcast the tensor
-        paddle.distributed.broadcast(
-            tensor,
-            src=group.ranks[rank],
-            group=group,
-            sync_op=True,
-        )
+        if group.nranks > 1:
+            paddle.distributed.broadcast(
+                tensor,
+                src=group.ranks[rank],
+                group=group,
+                sync_op=True,
+            )
         if filter_func(k):
             res[k] = tensor.cpu()
         del tensor

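The motivation for group_rank = max(group.rank, 0) and the nranks > 1 guard above is the degenerate single-member group that appears when the sharding degree is 1: the group rank can be reported as -1 in that case, and a broadcast is neither needed nor meaningful. A standalone sketch of the two guards (plain Python, no Paddle dependency, illustrative only):

def clamp_rank(rank: int) -> int:
    # A rank of -1 (process not part of a multi-member group) is treated as rank 0.
    return max(rank, 0)

def needs_broadcast(nranks: int) -> bool:
    # A collective only makes sense when more than one rank participates.
    return nranks > 1

assert clamp_rank(-1) == 0 and clamp_rank(3) == 3
assert not needs_broadcast(1) and needs_broadcast(8)
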
paddleformers/trainer/utils/reshard/sharding_v1.py

Lines changed: 17 additions & 10 deletions
@@ -17,20 +17,27 @@
 )
 
 from ....transformers.model_utils import unwrap_optimizer
+from .common import is_sharding_opt
 
 
 def shard(node_model_state, model, optimizer):
     cur_rank = max(node_model_state.group.rank, 0)
-    optimizer = unwrap_optimizer(optimizer, DygraphShardingOptimizer)
-    assert optimizer is not None
-    param2rank = optimizer._param2rank
-
-    def filter_func(key):
-        names = key
-        param_name = names[1]
-        assert param_name in param2rank
-        dst_rank = param2rank[param_name]
-        return dst_rank == cur_rank
+    unwrapped_optimizer = unwrap_optimizer(optimizer, DygraphShardingOptimizer)
+    if unwrapped_optimizer is not None:
+        optimizer = unwrapped_optimizer
+        assert not is_sharding_opt(optimizer)
+        param2rank = optimizer._param2rank
+
+        def filter_func(key):
+            names = key
+            param_name = names[1]
+            assert param_name in param2rank
+            dst_rank = param2rank[param_name]
+            return dst_rank == cur_rank
+
+    else:
+        assert not is_sharding_opt(optimizer)
+        filter_func = lambda key: True
 
     node_model_state.reshard(filter_func)
     return node_model_state

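The rewrite above makes shard() tolerate the case where no DygraphShardingOptimizer can be unwrapped (sharding degree 1): instead of asserting, it falls back to a filter that keeps every key. An illustrative sketch of the two filter behaviours (the helper name here is hypothetical, not from the repo):

def make_filter(param2rank, cur_rank):
    if param2rank is not None:
        # Sharding optimizer present: keep only keys whose parameter this rank owns.
        return lambda key: param2rank[key[1]] == cur_rank
    # No sharding optimizer: every key belongs to this rank.
    return lambda key: True

keep = make_filter({"linear_0.w_0": 0, "linear_0.b_0": 1}, cur_rank=0)
assert keep(("opt_state", "linear_0.w_0"))
assert not keep(("opt_state", "linear_0.b_0"))
assert make_filter(None, cur_rank=0)(("opt_state", "linear_0.b_0"))
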
paddleformers/trainer/utils/sharding_io.py

Lines changed: 21 additions & 6 deletions
@@ -269,7 +269,7 @@ def get_group_ids(self):
 
 
 class ShardingIO:
-    def __init__(self, args, model, optimizer=None, hcg=None, remap_parameter_name=False):
+    def __init__(self, args, model, optimizer=None, hcg=None, remap_parameter_name=False, is_ema=False):
         self.args = args
         self.model = model
         self.optimizer = optimizer
@@ -281,6 +281,7 @@ def __init__(self, args, model, optimizer=None, hcg=None, remap_parameter_name=F
 
         self.remap_parameter_name = remap_parameter_name
         self.remapper = None
+        self.is_ema = is_ema
 
     def _get_remapper(self, checkpoint):
         if not self.remap_parameter_name:
@@ -351,7 +352,9 @@ def load_model_slices():
             structure_name_map = split_structure_name_mapping(structure_name_map, group_getter)
             for i in range(self.args.sharding_parallel_rank, sharding_degree, cur_sharding_degree):
                 tmp = self._load_one_state_dict_from_checkpoint(
-                    checkpoint, base_weight_name, self.args.sharded_name_suffix(i, j)
+                    checkpoint,
+                    base_weight_name,
+                    self.args.sharded_name_suffix(i, j, sharding_parallel_degree=sharding_degree),
                 )
                 tmp = split_model_state(tmp, group_getter)
                 for gid in gids:
@@ -399,24 +402,33 @@ def _load_one_state_dict_from_checkpoint(self, resume_from_checkpoint, base_weig
         """
        load state_dict of one shard from_checkpoint, Only load model state dict.
         """
+        if self.is_ema:
+            base_weight_name = base_weight_name.replace("model_state", "ema").replace("pdparams", "pdopt")
         file_path = os.path.join(resume_from_checkpoint, _add_variant(base_weight_name, weight_name_suffix))
         if not os.path.isfile(file_path):
             raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}, no {file_path}")
 
         logger.info(f"Loading model from {resume_from_checkpoint} .")
         # We load the model state dict on the CPU to avoid an OOM error.
         state_dict = paddle.load(file_path, return_numpy=True)
+        if self.is_ema:
+            state_dict.pop("master_weights", None)
         state_dict = self._remap_parameter_name(resume_from_checkpoint, state_dict, is_opt=False)
         return state_dict
 
     def _load_optimizer_state_of_one_shard(self, checkpoint, base_opt_name, optimizer_name_suffix, group_getter=None):
+        if self.is_ema:
+            base_opt_name = base_opt_name.replace("optimizer", "ema")
         optimizer_name = _add_variant(base_opt_name, optimizer_name_suffix)
         path = os.path.join(checkpoint, optimizer_name)
         logger.info(f"load optimizer state from {path}")
         if os.path.isfile(path):
+            opt_state = paddleformers_load(path, map_location="cpu")
+            if self.is_ema:
+                opt_state = {"master_weights": opt_state.get("master_weights", {})}
             return self._remap_parameter_name(
                 checkpoint,
-                self._modify_ckpt_for_compatibility(paddleformers_load(path, map_location="cpu")),
+                self._modify_ckpt_for_compatibility(opt_state),
                 is_opt=True,
             )
         logger.info(f"{path} not exists")
@@ -449,9 +461,12 @@ def _need_reshard(self, checkpoint):
         if sharding_strategy == SHARDING_STRATEGY_V1:
             param2rank = sharding_meta["param2rank"]
             optimizer = unwrap_optimizer(self.optimizer, DygraphShardingOptimizer)
-            assert optimizer
-            if len(param2rank) == 0:
-                logger.warning("The param2rank is empty. Force reshard would be performed.")
+            if self.args.sharding_parallel_degree > 1:
+                assert optimizer is not None
+            else:
+                assert optimizer is None
+            if len(param2rank) == 0 or optimizer is None:
+                logger.warning("The param2rank is empty or sharding degree is 1. Force reshard would be performed.")
                 return True
             assert len(param2rank) == len(optimizer._param2rank)
             for (k, v) in param2rank.items():

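The is_ema handling above boils down to rewriting the base checkpoint file names and keeping only the master_weights section of the EMA optimizer state. A minimal sketch of the name rewriting, using the conventional Paddle file names for illustration:

base_weight_name = "model_state.pdparams"  # illustrative; the usual PADDLE_WEIGHTS_NAME
base_opt_name = "optimizer.pdopt"          # illustrative; the usual PADDLE_OPTIMIZER_NAME

ema_weight_name = base_weight_name.replace("model_state", "ema").replace("pdparams", "pdopt")
ema_opt_name = base_opt_name.replace("optimizer", "ema")

# Both loads end up pointing at the same family of ema*.pdopt files.
assert ema_weight_name == "ema.pdopt"
assert ema_opt_name == "ema.pdopt"
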