Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,12 @@ repos:
language: python
files: \.(md|markdown|rst)$
pass_filenames: true

- repo: local
hooks:
- id: clang-format
name: clang-format
description: Format files with ClangFormat.
entry: bash ./tools/codestyle/clang_format.sh -i
language: system
files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|xpu|kps)$
19 changes: 16 additions & 3 deletions paddlenlp/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,13 @@ def __init__(
self.optimizer_grouped_parameters = None
self.sharding_io = None
if self.args.should_save_sharding_stage1_model or self.args.should_load_sharding_stage1_model:
self.sharding_io = ShardingIO(self.args, self.model, self.optimizer)
self.sharding_io = ShardingIO(
self.args,
self.model,
self.optimizer,
remap_parameter_name=self.args.load_sharded_model_remap_parameter_name,
)

if self.args.unified_checkpoint:
self.unified_checkpoint_handler = UnifiedCheckpointHandler(self.args)

Expand Down Expand Up @@ -804,9 +810,16 @@ def create_zcc_manager(self, unwrapped_model, resume_from_checkpoint=None):
if resume_from_checkpoint is not None:
path = _add_variant(PADDLE_OPTIMIZER_NAME, self.args.optimizer_name_suffix)
path = os.path.join(resume_from_checkpoint, path).replace("optimizer", "ema")
if self.args.zcc_save_ema_coef is not None and self.sharding_io is not None:
success, err_msg = self.sharding_io.check_same_strategy(resume_from_checkpoint)
else:
success, err_msg = True, None
if os.path.exists(path):
logger.info(f"ZCC EMA load from {path}")
self.zcc_manager.set_ema_state_dict(path)
if success:
logger.info(f"ZCC EMA load from {path}")
self.zcc_manager.set_ema_state_dict(path)
else:
logger.info(f"ZCC EMA does not load {path} because {err_msg}")
else:
logger.info(f"ZCC EMA state dict not found, in: {path}")

Expand Down
39 changes: 30 additions & 9 deletions paddlenlp/trainer/training_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -633,6 +633,11 @@ class TrainingArguments:
},
)

load_sharded_model_remap_parameter_name: bool = field(
default=False,
metadata={"help": "Whether to remap parameter name when load_sharded_model = true."},
)

tensor_parallel_degree: int = field(
default=-1,
metadata={
Expand Down Expand Up @@ -2025,6 +2030,11 @@ def _post_init_parallel_degree(self):
sharding_parallel_degree * tensor_parallel_degree * sep_parallel_degree * pipeline_parallel_degree
)

if expert_parallel_degree > 1:
assert (
self.expert_tensor_parallel_degree <= 1
), "expert_tensor_parallel_degree > 1 is not supported when expert_parallel_degree > 1"

assert not (
self.data_parallel_degree > 1 and expert_parallel_degree > 1
), f"Currently only support use expert_data_parallel strategy together with sharding_parallel strategy, but not with data_parallel strategy. Currently data_parallel_degree is {self.data_parallel_degree}."
Expand Down Expand Up @@ -2213,6 +2223,17 @@ def pipeline_parallel_rank(self):
else:
return 0

@property
def expert_parallel_rank(self):
if self.use_hybrid_parallel:
hcg = fleet.get_hybrid_communicate_group()
if hasattr(hcg, "get_expert_parallel_rank"):
return max(hcg.get_expert_parallel_rank(), 0)
else:
return 0
else:
return 0

@property
def context_parallel_rank(self):
if self.use_hybrid_parallel:
Expand All @@ -2238,7 +2259,7 @@ def optimizer_name_suffix(self):
name.append(self._format_name("pp", self.pipeline_parallel_rank, self.pipeline_parallel_degree))
if self.sharding_parallel_degree > 1:
name.append(self._format_name("shard", self.sharding_parallel_rank, self.sharding_parallel_degree))
if self.use_expert_parallel:
if self.use_expert_parallel and self.expert_parallel_degree <= 1:
name.append(self._format_name("moe", self.data_parallel_rank, self.data_parallel_degree))
return "_".join(name)
else:
Expand All @@ -2254,7 +2275,7 @@ def weight_name_suffix(self):
name.append(self._format_name("tp", self.tensor_parallel_rank, self.tensor_parallel_degree))
if self.pipeline_parallel_degree > 1:
name.append(self._format_name("pp", self.pipeline_parallel_rank, self.pipeline_parallel_degree))
if self.use_expert_parallel:
if self.use_expert_parallel and self.expert_parallel_degree <= 1:
name.append(self._format_name("moe", self.data_parallel_rank, self.data_parallel_degree))
return "_".join(name)

Expand All @@ -2263,7 +2284,9 @@ def weight_name_suffix(self):
return self._format_name("moe", self.data_parallel_rank, self.data_parallel_degree)
return None

def sharded_name_suffix(self, shard_id=None, pp_id=None, moe_id=None):
def sharded_name_suffix(self, shard_id=None, pp_id=None, moe_id=None, sharding_parallel_degree=None):
if sharding_parallel_degree is None:
sharding_parallel_degree = self.sharding_parallel_degree
if self.use_hybrid_parallel:
name = []
if self.tensor_parallel_degree > 1:
Expand All @@ -2273,12 +2296,12 @@ def sharded_name_suffix(self, shard_id=None, pp_id=None, moe_id=None):
pp_id = self.pipeline_parallel_rank
assert isinstance(pp_id, int)
name.append(self._format_name("pp", pp_id, self.pipeline_parallel_degree))
if self.sharding_parallel_degree > 1:
if sharding_parallel_degree > 1:
if shard_id is None:
shard_id = self.sharding_parallel_rank
assert isinstance(shard_id, int)
name.append(self._format_name("shard", shard_id, self.sharding_parallel_degree))
if self.use_expert_parallel:
name.append(self._format_name("shard", shard_id, sharding_parallel_degree))
if self.use_expert_parallel and self.expert_parallel_degree <= 1:
if moe_id is None:
moe_id = self.data_parallel_rank
assert isinstance(moe_id, int)
Expand Down Expand Up @@ -2404,9 +2427,7 @@ def should_save_sharding_stage1_model(self):
def should_load_sharding_stage1_model(self):
if self.enable_auto_parallel:
return False
return (
ShardingOption.SHARD_OP in self.sharding and self.sharding_parallel_degree > 1 and self.load_sharded_model
)
return self.load_sharded_model

@property
def should_load_dataset(self):
Expand Down
81 changes: 81 additions & 0 deletions paddlenlp/trainer/utils/offload_optimizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from paddle import _C_ops
from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.hybrid_parallel_optimizer import (
HybridParallelOptimizer,
)
from paddle.optimizer import Optimizer

from .sharding_io import to_device


def offload(tensor):
if paddle.is_compiled_with_cuda():
place = paddle.CUDAPinnedPlace()
else:
place = paddle.CPUPlace()

new_tensor = to_device(tensor, place)
assert new_tensor is tensor, "to_device must be inplace operation"


def reload(tensor):
new_tensor = to_device(tensor)
assert new_tensor is tensor, "to_device must be inplace operation"


def hack_offload_optimizer():
# Step 1: mock _add_accumulator
origin_add_accumulator = getattr(Optimizer, "_add_accumulator")

def new_add_accumulator(self, *args, **kwargs):
x = origin_add_accumulator(self, *args, **kwargs)
offload(x)
return x

setattr(Optimizer, "_add_accumulator", new_add_accumulator)

# Step 2: mock _C_ops.adamw_ and _C_ops.adamw
for name in ["adam_", "adamw_"]:
origin_op = getattr(_C_ops, name)

def new_opt_op(*args):
for arg in args:
if isinstance(arg, paddle.Tensor):
reload(arg)

ret = origin_op(*args)

for i, arg in enumerate(args):
if i >= 2 and isinstance(arg, paddle.Tensor): # do not offload parameter and gradient
offload(arg)
return ret

setattr(_C_ops, name, new_opt_op)

# Step 3: mock _insert_sync
opt_type = HybridParallelOptimizer
origin_insert_sync = getattr(opt_type, "_insert_sync")

def new_insert_sync(self, sync_var, *args, **kwargs):
origin_place = sync_var.place
reload(sync_var)
ret = origin_insert_sync(self, sync_var, *args, **kwargs)
new_sync_var = to_device(sync_var, origin_place)
assert new_sync_var is sync_var, "to_device must be inplace operation"
return ret

setattr(opt_type, "_insert_sync", new_insert_sync)
8 changes: 8 additions & 0 deletions paddlenlp/trainer/utils/reshard/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,14 @@
SHARDING_STRATEGY_V2,
NodeModelState,
all_gather_state_dict,
convert_opt_name_to_tname,
get_moe_sharding_group,
get_param_sharding_group,
get_sharding_strategy,
is_sharding_opt,
merge_model_state,
merge_opt_state,
split_model_state,
split_opt_state,
split_structure_name_mapping,
)
Loading
Loading