
Commit 2ac9ac8 (merge commit, 2 parents: 74878bd + 119ed11)

Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleNLP into masterweight

File tree

17 files changed: 984 additions and 448 deletions

csrc/gpu/unittest/test_get_padding_offset_v2.py

Lines changed: 0 additions & 1 deletion
@@ -64,6 +64,5 @@ def test_get_padding_offset_v2(self):
         assert sum(ref_cu_seqlens_q - cu_seqlens_q) == 0, "Check cu_seqlens_q failed."
         assert sum(ref_cu_seqlens_k - cu_seqlens_k) == 0, "Check cu_seqlens_k failed."
 
-
 if __name__ == "__main__":
     unittest.main()

paddlenlp/experimental/transformers/llama/modeling.py

Lines changed: 137 additions & 46 deletions
Large diffs are not rendered by default.

paddlenlp/trainer/trainer.py

Lines changed: 202 additions & 66 deletions
Large diffs are not rendered by default.

paddlenlp/trainer/trainer_utils.py

Lines changed: 88 additions & 0 deletions
@@ -37,6 +37,10 @@
 import paddle
 import paddle.distributed as dist
 from paddle.distributed import fleet
+from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import (
+    DygraphShardingOptimizer,
+    DygraphShardingOptimizerV2,
+)
 from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
 from paddle.io import IterableDataset
 from paddle.optimizer.lr import LambdaDecay
@@ -1357,3 +1361,87 @@ def set_comm_config(configs, attr, dict_obj):
     set_comm_config("moe_sharding_configs", "check_nccl_config", nccl_config.get("moe_sharding_check", None))
     set_comm_config("default_comm_group_configs", "nccl_config", nccl_config.get("default", None))
     return strategy
+
+
+def init_optimizer(optimizer, model_sharded_state_dict, state_dict_metadata):
+    """
+    Initialize the optimizer's states according to its type.
+
+    For DygraphShardingOptimizer (V1), initializes accumulators for local parameters.
+    For DygraphShardingOptimizerV2, manually initializes master weights and state dict for sharded parameters.
+    For other cases, initializes accumulators for all parameters.
+
+    Args:
+        optimizer: The optimizer instance to be initialized.
+    """
+    optimizer_state_names = [".moment1_0", ".moment2_0", ".beta1_pow_acc_0", ".beta2_pow_acc_0", ".w_0"]
+    inner_opt = getattr(optimizer, "_inner_opt", None)
+    static_to_struct_mapping = {}
+    model_sharded_state_dict = dict(sorted(model_sharded_state_dict.items()))
+    for k, v in model_sharded_state_dict.items():
+        if v.local_tensor.name not in static_to_struct_mapping:
+            static_to_struct_mapping[v.local_tensor.name] = k
+
+    if isinstance(inner_opt, DygraphShardingOptimizer):
+        local_params = optimizer._rank2params[optimizer._sharding_rank]
+        param_list = []
+        for param in local_params:
+            param_name = param.name
+            struct_name = static_to_struct_mapping[param_name]
+            if not any(struct_name + state_name in state_dict_metadata for state_name in optimizer_state_names):
+                continue
+            param_list.append(param)
+        optimizer._create_accumulators(paddle.base.framework.default_main_program().global_block(), param_list)
+        return
+
+    elif isinstance(inner_opt, DygraphShardingOptimizerV2):
+
+        def init_param_optimizer_states(param_iter):
+            master_weights = {}
+            state_dict = {}
+            moments = ("moment1_0", "moment2_0")
+            betas = ("beta1_pow_acc_0", "beta2_pow_acc_0")
+            for static_name, shape, no_need_master_weights in param_iter:
+                if not no_need_master_weights:
+                    master_weights[static_name] = paddle.zeros(shape, dtype="float32")
+                    prefix = f"{static_name}_fp32_master_0_"
+                else:
+                    prefix = f"{static_name}_"
+
+                for moment in moments:
+                    key = f"{prefix}{moment}"
+                    state_dict[key] = paddle.zeros(shape, dtype="float32")
+                for beta in betas:
+                    key = f"{prefix}{beta}"
+                    state_dict[key] = paddle.zeros((1,), dtype="float32")
+            return master_weights, state_dict
+
+        def buffer_params():
+            for buffer in optimizer._comm_buffer_list:
+                for param_name, grad_view in buffer._sharding_param_grad_view.items():
+                    struct_name = static_to_struct_mapping[param_name]
+                    if not any(
+                        struct_name + state_name in state_dict_metadata for state_name in optimizer_state_names
+                    ):
+                        continue
+                    param_begin = grad_view._param_begin
+                    param_end = grad_view._param_end
+                    shape = (param_end - param_begin,)
+                    no_need_master_weights = grad_view._param.dtype == paddle.float32
+
+                    if shape[0] > 0:
+                        yield param_name, shape, no_need_master_weights
+
+        master_weights, state_dict = init_param_optimizer_states(buffer_params())
+        state_dict["master_weights"] = master_weights
+        state_dict["LR_Scheduler"] = {"last_epoch": 1, "last_lr": 5e-06}
+        optimizer.set_state_dict(state_dict)
+        return
+    param_list = []
+    for param in optimizer._parameter_list:
+        param_name = param.name
+        struct_name = static_to_struct_mapping[param_name]
+        if not any(struct_name + state_name in state_dict_metadata for state_name in optimizer_state_names):
+            continue
+        param_list.append(param)
+    optimizer._create_accumulators(paddle.base.framework.default_main_program().global_block(), param_list)
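
The trainer can call this helper right before loading a FlexCheckpoint so that only the optimizer states actually present in the checkpoint get allocated. Below is a minimal usage sketch; how `model_sharded_state_dict` and `state_dict_metadata` are obtained is an assumption about the FlexCheckpoint loading path, and only `init_optimizer` itself comes from this commit.

# Hypothetical call site (not part of the commit).
from paddlenlp.trainer.trainer_utils import init_optimizer

def warm_up_optimizer_states(optimizer, model_sharded_state_dict, state_dict_metadata):
    # model_sharded_state_dict: {structured_name: sharded_tensor}, where each
    # sharded tensor exposes `.local_tensor.name` (the static parameter name);
    # init_optimizer uses it to map static names back to structured names.
    # state_dict_metadata: the keys stored in the checkpoint, e.g.
    # "llama.embed_tokens.weight.moment1_0"; only parameters with at least one
    # matching optimizer-state suffix get accumulators / master weights created.
    init_optimizer(optimizer, model_sharded_state_dict, state_dict_metadata)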

paddlenlp/trainer/training_args.py

Lines changed: 80 additions & 8 deletions
@@ -407,6 +407,12 @@ class TrainingArguments:
             Whether to release gradients during training. Default is `False`.
         ckpt_quant_stage (`str`, *optional*):
             Whether activate checkpoint quantization. O0: deactivate, O1: Int8 compression, O2: Int4 compression. (default: O0).
+        save_checkpoint_format (`str`, *optional*):
+            Specifies the format for saving checkpoints. Options are: None, 'sharding_io', 'unified_checkpoint', 'flex_checkpoint'. (default: None). This setting is ignored if the corresponding switch is configured.
+        load_checkpoint_format (`str`, *optional*):
+            Specifies the format for loading checkpoints. Options are: None, 'sharding_io', 'unified_checkpoint', 'flex_checkpoint'. (default: None). This setting is ignored if the corresponding switch is configured.
+        aoa_config (`Optional[dict[str, list[str]]]`, *optional*):
+            The AoA configuration of FlexCheckpoint, used to describe the mapping between model weights and the checkpoint content. Default is None.
     """
 
     output_dir: str = field(
@@ -941,6 +947,29 @@ class TrainingArguments:
         default=False,
         metadata={"help": "Whether to use async_save instead of paddle.save."},
     )
+    save_checkpoint_format: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": (
+                "Specifies the format used to save checkpoints. "
+                "Available options: 'sharding_io', 'unified_checkpoint', "
+                "'flex_checkpoint'."
+                "This setting is ignored if the corresponding switch is configured."
+            )
+        },
+    )
+
+    load_checkpoint_format: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": (
+                "Specifies the format used to load checkpoints. "
+                "Available options: 'sharding_io', 'unified_checkpoint', "
+                "'flex_checkpoint'."
+                "This setting is ignored if the corresponding switch is configured."
+            )
+        },
+    )
     ordered_save_group_size: int = field(
         default=0,
         metadata={
@@ -1106,6 +1135,13 @@ class TrainingArguments:
         default=None, metadata={"help": "NCCL中通信组的细粒度控制的配置文件路径, 默认值为None, 代表不启用此项配置"}
     )
 
+    aoa_config: Optional[dict[str, list[str]]] = field(
+        default=None,
+        metadata={
+            "help": "The AoA configuration of FlexCheckpoint, used to describe the mapping between model weights and the checkpoint content. Default is None."
+        },
+    )
+
     def __post_init__(self):
         world_size = paddle.distributed.get_world_size()
         if in_auto_parallel_align_mode():
@@ -1210,7 +1246,8 @@ def __post_init__(self):
             raise ValueError("AdamW Mini currently doesn't support tensor parallelism.")
 
         self._post_init_parallel_degree()
-
+        self._post_init_save_checkpoint_format()
+        self._post_init_load_checkpoint_format()
         if self.to_static:
             assert world_size == 1 or self.enable_auto_parallel, (
                 "It's not supported for training in static mode except the following cases : "
@@ -1864,24 +1901,31 @@ def is_context_parallel_supported():
         else:
             if world_size > 1:
                 if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized():
-                    if self.unified_checkpoint:
+                    if self.save_checkpoint_format in [
+                        "unified_checkpoint",
+                        "flex_checkpoint",
+                    ] or self.load_checkpoint_format in ["unified_checkpoint", "flex_checkpoint"]:
                         # DP use hybrid group
                         strategy = fleet.DistributedStrategy()
                         fleet.init(is_collective=True, strategy=strategy)
                     else:
                         paddle.distributed.init_parallel_env()
 
             if (
-                self.unified_checkpoint
+                (
+                    self.save_checkpoint_format == "unified_checkpoint"
+                    or self.load_checkpoint_format == "unified_checkpoint"
+                )
                 and self.sharding_parallel_degree > 0
                 and ShardingOption.FULL_SHARD in self.sharding
             ):
                 logger.warning(
-                    "Unified checkpoint currently do not support sharding stage3, set `unified_checkpoint` to False."
+                    "Unified checkpoint currently do not support sharding stage3, disabling unified_checkpoint format."
                 )
-                self.unified_checkpoint = False
+                self.save_checkpoint_format = None
+                self.load_checkpoint_format = None
 
-            if self.unified_checkpoint:
+            if self.save_checkpoint_format == "unified_checkpoint" or self.load_checkpoint_format == "unified_checkpoint":
                 unified_checkpoint_config = set(self.unified_checkpoint_config.split(" "))
                 if sys.platform.startswith("win") and "async_save" in self.unified_checkpoint_config:
                     raise ValueError("Currently do not support asynchronous saving for Windows system!")
@@ -2134,6 +2178,30 @@ def _post_init_parallel_degree(self):
         if self.use_hybrid_parallel and self.enable_auto_parallel:
             self.use_hybrid_parallel = False
 
+    def _post_init_save_checkpoint_format(self):
+        if self.save_checkpoint_format:
+            valid_modes = ["unified_checkpoint", "sharding_io", "flex_checkpoint"]
+            assert (
+                self.save_checkpoint_format in valid_modes
+            ), f"Invalid save_checkpoint_format: {self.save_checkpoint_format}, Only these formats are allowed: {valid_modes}."
+        else:
+            if self.unified_checkpoint:
+                self.save_checkpoint_format = "unified_checkpoint"
+            elif self.save_sharded_model:
+                self.save_checkpoint_format = "sharding_io"
+
+    def _post_init_load_checkpoint_format(self):
+        if self.load_checkpoint_format:
+            valid_modes = ["unified_checkpoint", "sharding_io", "flex_checkpoint"]
+            assert (
+                self.load_checkpoint_format in valid_modes
+            ), f"Invalid load_checkpoint_format: {self.load_checkpoint_format}, Only these formats are allowed: {valid_modes}."
+        else:
+            if self.unified_checkpoint:
+                self.load_checkpoint_format = "unified_checkpoint"
+            elif self.load_sharded_model:
+                self.load_checkpoint_format = "sharding_io"
+
     def add_moe_comm_group(self):
         hybrid_configs = fleet.fleet._user_defined_strategy.hybrid_configs
         hcg = fleet.get_hybrid_communicate_group()
@@ -2462,6 +2530,8 @@ def should_save_model_state(self):
             return True
         elif self.enable_auto_parallel:
             return True
+        elif self.save_checkpoint_format == "flex_checkpoint":
+            return False
         elif self.use_hybrid_parallel:
             # save on dataset rank 0
             return self.sharding_parallel_rank == 0 and (self.data_parallel_rank == 0 or self.use_expert_parallel)
@@ -2480,14 +2550,16 @@ def should_save_sharding_stage1_model(self):
         if self.enable_auto_parallel:
             return False
         return (
-            ShardingOption.SHARD_OP in self.sharding and self.sharding_parallel_degree > 1 and self.save_sharded_model
+            ShardingOption.SHARD_OP in self.sharding
+            and self.sharding_parallel_degree > 1
+            and self.save_checkpoint_format == "sharding_io"
         )
 
     @property
     def should_load_sharding_stage1_model(self):
         if self.enable_auto_parallel:
             return False
-        return self.load_sharded_model
+        return self.load_checkpoint_format == "sharding_io"
 
     @property
     def should_load_dataset(self):
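
The new save_checkpoint_format / load_checkpoint_format fields coexist with the older unified_checkpoint, save_sharded_model, and load_sharded_model switches: _post_init_save_checkpoint_format and _post_init_load_checkpoint_format validate an explicit value and otherwise derive the format from those switches. Below is a minimal sketch of setting the fields when constructing TrainingArguments directly; the aoa_config value is illustrative only and a single-process run is assumed.

# Hedged sketch, not taken from the commit.
from paddlenlp.trainer import TrainingArguments

args = TrainingArguments(
    output_dir="./checkpoints",
    save_checkpoint_format="flex_checkpoint",  # 'sharding_io' | 'unified_checkpoint' | 'flex_checkpoint'
    load_checkpoint_format="flex_checkpoint",
    aoa_config={"weight_mapping": ["model.embed_tokens.weight <- embedding.weight"]},  # illustrative value
)

# Leaving the formats unset falls back to the legacy switches in __post_init__:
# unified_checkpoint=True resolves both formats to "unified_checkpoint", while
# save_sharded_model / load_sharded_model resolve them to "sharding_io".
legacy = TrainingArguments(output_dir="./checkpoints", unified_checkpoint=True)
print(legacy.save_checkpoint_format, legacy.load_checkpoint_format)  # unified_checkpoint unified_checkpoint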
