
Commit 011ae71

[Unified Checkpoint] Fix split param loading directly when using ignore_merge_optimizer (#9935)
* fix
* refine code
* update typename
1 parent b6956a3 commit 011ae71

2 files changed: +187 -50 lines


paddlenlp/trainer/unified_checkpoint/sharding_split_param_utils.py

Lines changed: 166 additions & 49 deletions
@@ -15,18 +15,27 @@

 import gc
 import os
+import re
 from itertools import chain

 import paddle
 import paddle.distributed as dist
 from paddle.distributed import fleet
+from safetensors import safe_open
 from tqdm.auto import tqdm

 from paddlenlp.peft import LoRAModel, PrefixModelForCausalLM
-from paddlenlp.transformers.model_utils import load_state_dict, unwrap_model
+from paddlenlp.transformers.model_utils import (
+    _add_variant,
+    load_state_dict,
+    unwrap_model,
+)
+from paddlenlp.transformers.utils import device_guard
 from paddlenlp.utils.env import (
     SAFE_MASTER_WEIGHTS_INDEX_NAME,
+    SAFE_MASTER_WEIGHTS_NAME,
     SAFE_OPTIMIZER_INDEX_NAME,
+    SAFE_OPTIMIZER_NAME,
 )
 from paddlenlp.utils.nested import nested_copy

@@ -175,6 +184,49 @@ def gather_splited_param_for_optimizer(optimizer, ckpt_quant_stage="O0"):
     return optim_state_dict, master_weights


+def get_params_info(comm_buffer_list):
+    expected_keys = []
+    param_slice_info = {}
+    param_shape_info = {}
+
+    for buffer in comm_buffer_list:
+        for key in buffer._sharding_param_grad_view.keys():
+            begin = buffer._sharding_param_grad_view[key]._param_begin
+            end = buffer._sharding_param_grad_view[key]._param_end
+            if end > begin:
+                expected_keys.append(key)
+            shape = buffer._sharding_param_grad_view[key]._param.shape
+            numel = buffer._sharding_param_grad_view[key]._param.numel().item()
+            index = buffer._sharding_param_grad_view[key]._index
+            padded_size = buffer._sharding_param_grad_view[key]._padded_size
+            param_slice_info[key] = (begin, end)
+            param_shape_info[key] = (shape, numel, index, padded_size)
+    return expected_keys, param_slice_info, param_shape_info
+
+
+def reshape_params(state_dict, struct2static_name_mappings, param_shape_info, param_slice_info):
+    """Reshape params to 1-D tensors"""
+    for key in list(state_dict.keys()):
+        key_name = key.split("/")[0]
+        static_name = struct2static_name_mappings.get(key_name, None)
+        if int(state_dict[key].numel()) > 1:
+            begin, end = param_slice_info[static_name]
+            _, numel, index, padded_size = param_shape_info[static_name]
+            state_dict[key] = state_dict[key].reshape([-1])
+            state_dict[key] = state_dict[key][begin - index : end - index]
+
+            padding_start = max(begin, index + numel)
+            padding_end = min(end, index + padded_size)
+            if padding_start < padding_end:
+                state_dict[key] = paddle.concat(
+                    (
+                        state_dict[key],
+                        paddle.zeros([padding_end - padding_start], dtype=state_dict[key].dtype),
+                    )
+                )
+    return state_dict
+
+
 def load_unified_optimizer_split_param(args, model, optimizer, resume_from_checkpoint, ckpt_quant_stage="O0"):
     returned_optim_state_dict = nested_copy(optimizer.state_dict())
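For reference, here is a minimal standalone sketch of the slice-and-pad step that reshape_params() performs on one parameter. The buffer offsets below are hypothetical stand-ins for the values that get_params_info() reads from the sharding buffer views:

import paddle

# Hypothetical bookkeeping for a single parameter inside a fused sharding buffer.
numel = 10          # real number of elements in the parameter
padded_size = 12    # its size in the buffer after alignment padding
index = 0           # offset of the parameter inside the buffer
begin, end = 6, 12  # slice of the buffer owned by this sharding rank

param = paddle.arange(numel, dtype="float32")  # stands in for a checkpointed weight
flat = param.reshape([-1])                     # flatten to 1-D, as reshape_params() does
local = flat[begin - index : end - index]      # keep only this rank's slice (clamped at numel)

# If the rank's slice reaches into the padded tail, append explicit zeros so the
# local shard ends up with exactly end - begin elements.
padding_start = max(begin, index + numel)
padding_end = min(end, index + padded_size)
if padding_start < padding_end:
    local = paddle.concat((local, paddle.zeros([padding_end - padding_start], dtype=local.dtype)))

print(local.shape)  # [6]: four real elements plus two zeros of padding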

@@ -196,28 +248,12 @@ def load_unified_optimizer_split_param(args, model, optimizer, resume_from_check
     static2struct_name_mappings = {v.name: k for k, v in model_state_dict.items()}  # get optimizer param mappings
     struct2static_name_mappings = {k: v.name for k, v in model_state_dict.items()}

-    expected_keys = []
-    param_slice_info = {}
-    param_shape_info = {}
-
     comm_buffer_list = optimizer._inner_opt._comm_buffer_list
     if hasattr(args, "enable_sharding_comm_overlap") and args.enable_sharding_comm_overlap:
         comm_buffer_list = list(chain(*model._chunk_2_comm_buffers.values()))
         model = unwrap_model(model)

-    for buffer in comm_buffer_list:
-        for key in buffer._sharding_param_grad_view.keys():
-            begin = buffer._sharding_param_grad_view[key]._param_begin
-            end = buffer._sharding_param_grad_view[key]._param_end
-            if end > begin:
-                expected_keys.append(key)
-            shape = buffer._sharding_param_grad_view[key]._param.shape
-            numel = buffer._sharding_param_grad_view[key]._param.numel().item()
-            index = buffer._sharding_param_grad_view[key]._index
-            padded_size = buffer._sharding_param_grad_view[key]._padded_size
-            param_slice_info[key] = (begin, end)
-            param_shape_info[key] = (shape, numel, index, padded_size)
-
+    expected_keys, param_slice_info, param_shape_info = get_params_info(comm_buffer_list)
     expected_keys = set([static2struct_name_mappings.get(name, None) for name in expected_keys])
     expected_keys_optim = []
     for key in expected_keys:
@@ -285,25 +321,10 @@ def load_resolved_archive_file(
     )

     # need to split param for different sharding rank, maybe need to deal with oom issue.
+    reshape_params(state_dict_optim, struct2static_name_mappings, param_shape_info, param_slice_info)
     for key in list(state_dict_optim.keys()):
         key_name = key.split("/")
         static_name = struct2static_name_mappings.get(key_name[0], None)
-
-        if int(state_dict_optim[key].numel()) > 1:
-            begin, end = param_slice_info[static_name]
-            shape, numel, index, padded_size = param_shape_info[static_name]
-            state_dict_optim[key] = state_dict_optim[key].reshape([-1])
-            state_dict_optim[key] = state_dict_optim[key][begin - index : end - index]
-
-            padding_start = max(begin, index + numel)
-            padding_end = min(end, index + padded_size)
-            if padding_start < padding_end:
-                state_dict_optim[key] = paddle.concat(
-                    (
-                        state_dict_optim[key],
-                        paddle.zeros([padding_end - padding_start], dtype=state_dict_optim[key].dtype),
-                    )
-                )
         if has_master_weights:
             if model_state_dict[key_name[0]].dtype != paddle.float32:
                 key_name = "_".join([static_name, FP32_MASTER, key_name[1]])
@@ -325,24 +346,10 @@ def load_resolved_archive_file(
         expected_keys,
         is_master_weights=True,
     )
+    reshape_params(state_dict_master_weight, struct2static_name_mappings, param_shape_info, param_slice_info)

     for key in list(state_dict_master_weight.keys()):
         static_name = struct2static_name_mappings.get(key, None)
-        if int(state_dict_master_weight[key].numel()) > 1:
-            begin, end = param_slice_info[static_name]
-            shape, numel, index, padded_size = param_shape_info[static_name]
-            state_dict_master_weight[key] = state_dict_master_weight[key].reshape([-1])
-            state_dict_master_weight[key] = state_dict_master_weight[key][begin - index : end - index]
-
-            padding_start = max(begin, index + numel)
-            padding_end = min(end, index + padded_size)
-            if padding_start < padding_end:
-                state_dict_master_weight[key] = paddle.concat(
-                    (
-                        state_dict_master_weight[key],
-                        paddle.zeros([padding_end - padding_start], dtype=state_dict_master_weight[key].dtype),
-                    )
-                )
         state_dict_master_weight[key] = state_dict_master_weight[key]._copy_to(
             paddle.framework._current_expected_place(), False
         )
@@ -357,3 +364,113 @@ def load_resolved_archive_file(
             returned_optim_state_dict["master_weights"][static_name].name = "_".join([static_name, FP32_MASTER])

     return returned_optim_state_dict
+
+
+def load_non_merge_optimizer_with_split_param(args, model, optimizer, resume_from_checkpoint, ckpt_quant_stage="O0"):
+    returned_optim_state_dict = nested_copy(optimizer.state_dict())
+
+    optimizer_name = _add_variant(SAFE_OPTIMIZER_NAME, args.optimizer_name_suffix)
+    master_weights_name = _add_variant(SAFE_MASTER_WEIGHTS_NAME, args.optimizer_name_suffix)
+    optimizer_path = os.path.join(resume_from_checkpoint, optimizer_name)
+    master_weights_path = os.path.join(resume_from_checkpoint, master_weights_name)
+
+    # no quantization & no master weight represent O1 AMP strategy.
+    is_amp_o1 = args.fp16_opt_level == "O1"
+
+    model_state_dict = get_expected_state_dict(model)
+    static2struct_name_mappings = {v.name: k for k, v in model_state_dict.items()}  # get optimizer param mappings
+    struct2static_name_mappings = {k: v.name for k, v in model_state_dict.items()}
+
+    comm_buffer_list = optimizer._inner_opt._comm_buffer_list
+    if hasattr(args, "enable_sharding_comm_overlap") and args.enable_sharding_comm_overlap:
+        comm_buffer_list = list(chain(*model._chunk_2_comm_buffers.values()))
+
+    expected_keys, param_slice_info, param_shape_info = get_params_info(comm_buffer_list)
+    expected_keys = set([static2struct_name_mappings.get(name, None) for name in expected_keys])
+    expected_keys_optim = []
+    sharding_typename_set, typename_set = [], []
+    with safe_open(optimizer_path, framework="numpy") as f:
+        optim_keys = f.keys()
+        for key in optim_keys:
+            _, typename = key.split("/")
+            typename_set.append(typename)
+
+    # To avoid incomplete typename in some shard files, communication is performed.
+    hcg = fleet.get_hybrid_communicate_group()
+    sharding_group = hcg.get_sharding_parallel_group()
+    dist.all_gather_object(sharding_typename_set, typename_set, sharding_group)
+    typename_set = set(chain(*sharding_typename_set))
+    for key in expected_keys:
+        for typename in typename_set:
+            expected_keys_optim.append(f"{key}/{typename}")
+    expected_keys_optim = set(expected_keys_optim)
+
+    optimizer_state_dict = load_state_dict(
+        optimizer_path, None, None, device="expected", ckpt_quant_stage=ckpt_quant_stage
+    )
+    master_weights = {}
+    # normal AMP O2
+    if not is_amp_o1 and os.path.isfile(master_weights_path):
+        master_weights = load_state_dict(master_weights_path, None, None, device="expected")
+
+    def get_unfound_params(unfound_keys, state_dict, is_optimizer=True):
+        if len(unfound_keys) > 0:
+            backup_files = []
+            files = os.listdir(resume_from_checkpoint)
+            name = optimizer_name if is_optimizer else master_weights_name
+            name_without_shard = re.sub(r"_?shard\d+_?", "", name)
+            name_ = "optimizer" if is_optimizer else "master_weights"
+            for f in files:
+                if f.startswith(name_) and f.endswith("safetensors") and f != name:
+                    if re.sub(r"_?shard\d+_?", "", f) == name_without_shard:
+                        backup_files.append(f)
+            for f in backup_files:
+                new_path = os.path.join(resume_from_checkpoint, f)
+                with safe_open(new_path, framework="numpy") as fin:
+                    keys = fin.keys()
+                    for key in unfound_keys:
+                        if key in keys:
+                            tensor = fin.get_tensor(key)
+                            with device_guard():
+                                tensor = paddle.Tensor(tensor, zero_copy=True)
+                            state_dict[key] = tensor._copy_to(paddle.framework._current_expected_place(), False)
+
+    # Get other optimizer parameters which may be in other shard files.
+    unfound_keys = expected_keys_optim - optimizer_state_dict.keys()
+    get_unfound_params(unfound_keys, optimizer_state_dict, True)
+
+    # Get other master weight parameters which may be in other shard files.
+    if master_weights != {}:
+        unfound_keys = expected_keys - master_weights.keys()
+        get_unfound_params(unfound_keys, master_weights, False)
+    reshape_params(optimizer_state_dict, struct2static_name_mappings, param_shape_info, param_slice_info)
+
+    # rename and move to paddle.Tensor
+    for key in list(optimizer_state_dict.keys()):
+        key_name = key.split("/")
+        model_weight_key = key_name[0]
+        static_name = struct2static_name_mappings[key_name[0]]
+        if not is_amp_o1:
+            if model_state_dict[key_name[0]].dtype != paddle.float32:
+                key_name = "_".join([static_name, FP32_MASTER, key_name[1]])
+            else:
+                key_name = "_".join([static_name, key_name[1]])
+        else:
+            key_name = "_".join([static_name, key_name[1]])
+        returned_optim_state_dict[key_name] = optimizer_state_dict.pop(key)
+        returned_optim_state_dict[key_name].name = key_name

+        # master weight cast (only in AMP O2 + remove_master_weight)
+        if not is_amp_o1 and not os.path.isfile(master_weights_path):
+            master_weights[model_weight_key] = paddle.cast(model_state_dict[model_weight_key], dtype=paddle.float32)
+
+    if not is_amp_o1:
+        reshape_params(master_weights, struct2static_name_mappings, param_shape_info, param_slice_info)
+
+        returned_optim_state_dict["master_weights"] = {}
+        for key in list(master_weights.keys()):
+            static_name = struct2static_name_mappings[key]
+            returned_optim_state_dict["master_weights"][static_name] = master_weights.pop(key)
+            returned_optim_state_dict["master_weights"][static_name].name = "_".join([static_name, FP32_MASTER])
+
+    return returned_optim_state_dict
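The get_unfound_params() helper above falls back to sibling shard files when a key is missing from the rank's own optimizer file. A small sketch of its matching rule, using made-up file names (the real names come from _add_variant() and args.optimizer_name_suffix, so the exact scheme may differ):

import re

def strip_shard(name):
    # Same normalization as in get_unfound_params(): drop the "shardN" fragment.
    return re.sub(r"_?shard\d+_?", "", name)

# Hypothetical checkpoint directory contents.
own_file = "optimizer_tp00_pp00_shard00.safetensors"
files = [
    "optimizer_tp00_pp00_shard01.safetensors",       # same tp/pp group, other shard -> backup candidate
    "optimizer_tp01_pp00_shard00.safetensors",       # different tensor-parallel rank -> skipped
    "master_weights_tp00_pp00_shard01.safetensors",  # wrong prefix for an optimizer lookup -> skipped
]

backup_files = [
    f
    for f in files
    if f.startswith("optimizer") and f.endswith("safetensors") and f != own_file
    and strip_shard(f) == strip_shard(own_file)
]
print(backup_files)  # ['optimizer_tp00_pp00_shard01.safetensors']

Only files that differ solely in their shard index are searched, which is how keys written by other sharding ranks are recovered when the per-rank shard files hold different subsets of the optimizer state.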

paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py

Lines changed: 21 additions & 1 deletion
@@ -62,7 +62,10 @@
     save_single_card_checkpoint,
     save_single_card_optimizer,
 )
-from .sharding_split_param_utils import gather_splited_param_for_optimizer
+from .sharding_split_param_utils import (
+    gather_splited_param_for_optimizer,
+    load_non_merge_optimizer_with_split_param,
+)
 from .utils import (
     FP32_MASTER,
     UnifiedCheckpointOption,
@@ -263,6 +266,23 @@ def save_non_merge_optimizer(self, model, optim_state_dict, master_weights, outp
         )

     def load_non_merge_optimizer(self, model, optimizer, resume_from_checkpoint, ckpt_quant_stage="O0"):
+        """load non merge optimizer
+
+        Args:
+            model (PretrainedModel): model used to get key mapping.
+            optimizer (Optimizer): optimizer to load
+            resume_from_checkpoint (str): path of the checkpoint to load
+            ckpt_quant_stage (str): ckpt quant stage
+
+        Returns:
+            dict: optimizer state dict
+        """
+
+        if is_sharding_split_param_mode(self.args):
+            return load_non_merge_optimizer_with_split_param(
+                self.args, model, optimizer, resume_from_checkpoint, ckpt_quant_stage
+            )
+
         # init and get optimizer LR_Scheduler
         returned_optim_state_dict = nested_copy(optimizer.state_dict())
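Once load_non_merge_optimizer() dispatches to the split-param loader, the returned state dict follows the same key convention as the regular path: "<struct_name>/<typename>" entries become "<static_name>_<typename>", with the FP32 master suffix spliced in for non-float32 parameters under AMP O2. A small sketch with made-up names; FP32_MASTER is imported from the package's utils module in the real code, and its literal value here is an assumption for illustration:

FP32_MASTER = "fp32_master_0"  # assumed literal; the real constant lives in unified_checkpoint.utils

struct2static_name_mappings = {"llama.embed_tokens.weight": "embedding_0.w_0"}  # made-up names
is_amp_o1 = False       # AMP O2, so float32 master weights exist
param_is_fp32 = False   # the parameter itself is stored in bf16/fp16

key = "llama.embed_tokens.weight/moment1_0"  # "<struct_name>/<optimizer state typename>"
struct_name, typename = key.split("/")
static_name = struct2static_name_mappings[struct_name]

if not is_amp_o1 and not param_is_fp32:
    renamed = "_".join([static_name, FP32_MASTER, typename])
else:
    renamed = "_".join([static_name, typename])

print(renamed)  # embedding_0.w_0_fp32_master_0_moment1_0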
