import gc
import os
+ import re
from itertools import chain

import paddle
import paddle.distributed as dist
from paddle.distributed import fleet
+ from safetensors import safe_open
from tqdm.auto import tqdm

from paddlenlp.peft import LoRAModel, PrefixModelForCausalLM
- from paddlenlp.transformers.model_utils import load_state_dict, unwrap_model
+ from paddlenlp.transformers.model_utils import (
+     _add_variant,
+     load_state_dict,
+     unwrap_model,
+ )
+ from paddlenlp.transformers.utils import device_guard
from paddlenlp.utils.env import (
    SAFE_MASTER_WEIGHTS_INDEX_NAME,
+     SAFE_MASTER_WEIGHTS_NAME,
    SAFE_OPTIMIZER_INDEX_NAME,
+     SAFE_OPTIMIZER_NAME,
)
from paddlenlp.utils.nested import nested_copy

@@ -175,6 +184,49 @@ def gather_splited_param_for_optimizer(optimizer, ckpt_quant_stage="O0"):
     return optim_state_dict, master_weights


+ def get_params_info(comm_buffer_list):
+     expected_keys = []
+     param_slice_info = {}
+     param_shape_info = {}
+
+     for buffer in comm_buffer_list:
+         for key in buffer._sharding_param_grad_view.keys():
+             begin = buffer._sharding_param_grad_view[key]._param_begin
+             end = buffer._sharding_param_grad_view[key]._param_end
+             if end > begin:
+                 expected_keys.append(key)
+             shape = buffer._sharding_param_grad_view[key]._param.shape
+             numel = buffer._sharding_param_grad_view[key]._param.numel().item()
+             index = buffer._sharding_param_grad_view[key]._index
+             padded_size = buffer._sharding_param_grad_view[key]._padded_size
+             param_slice_info[key] = (begin, end)
+             param_shape_info[key] = (shape, numel, index, padded_size)
+     return expected_keys, param_slice_info, param_shape_info
+
+
+ def reshape_params(state_dict, struct2static_name_mappings, param_shape_info, param_slice_info):
+     """Reshape params to 1-D tensors"""
+     for key in list(state_dict.keys()):
+         key_name = key.split("/")[0]
+         static_name = struct2static_name_mappings.get(key_name, None)
+         if int(state_dict[key].numel()) > 1:
+             begin, end = param_slice_info[static_name]
+             _, numel, index, padded_size = param_shape_info[static_name]
+             state_dict[key] = state_dict[key].reshape([-1])
+             state_dict[key] = state_dict[key][begin - index : end - index]
+
+             padding_start = max(begin, index + numel)
+             padding_end = min(end, index + padded_size)
+             if padding_start < padding_end:
+                 state_dict[key] = paddle.concat(
+                     (
+                         state_dict[key],
+                         paddle.zeros([padding_end - padding_start], dtype=state_dict[key].dtype),
+                     )
+                 )
+     return state_dict
+
+
 def load_unified_optimizer_split_param(args, model, optimizer, resume_from_checkpoint, ckpt_quant_stage="O0"):
     returned_optim_state_dict = nested_copy(optimizer.state_dict())

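The arithmetic in reshape_params is easiest to see on concrete numbers: each loaded tensor is flattened, cut down to the slice owned by the local sharding rank, and zero-padded when that slice runs into the buffer's padded tail. A minimal runnable sketch, with illustrative values that are not taken from this patch:

import paddle

# Illustrative values only: a parameter with numel=10 is padded to padded_size=16
# inside the comm buffer starting at index=0; the local rank owns [begin, end) = [8, 16).
begin, end, index, numel, padded_size = 8, 16, 0, 10, 16

param = paddle.arange(numel, dtype="float32")              # full 1-D parameter
local = param.reshape([-1])[begin - index : end - index]   # only 2 real elements fall inside the shard
padding_start = max(begin, index + numel)                  # 10: first padded position covered by the shard
padding_end = min(end, index + padded_size)                # 16: end of the shard inside the padding
if padding_start < padding_end:
    local = paddle.concat((local, paddle.zeros([padding_end - padding_start], dtype=local.dtype)))

assert local.shape[0] == end - begin                       # every rank ends up with exactly end - begin elements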
@@ -196,28 +248,12 @@ def load_unified_optimizer_split_param(args, model, optimizer, resume_from_check
     static2struct_name_mappings = {v.name: k for k, v in model_state_dict.items()}  # get optimizer param mappings
     struct2static_name_mappings = {k: v.name for k, v in model_state_dict.items()}

-     expected_keys = []
-     param_slice_info = {}
-     param_shape_info = {}
-
     comm_buffer_list = optimizer._inner_opt._comm_buffer_list
     if hasattr(args, "enable_sharding_comm_overlap") and args.enable_sharding_comm_overlap:
         comm_buffer_list = list(chain(*model._chunk_2_comm_buffers.values()))
         model = unwrap_model(model)

-     for buffer in comm_buffer_list:
-         for key in buffer._sharding_param_grad_view.keys():
-             begin = buffer._sharding_param_grad_view[key]._param_begin
-             end = buffer._sharding_param_grad_view[key]._param_end
-             if end > begin:
-                 expected_keys.append(key)
-             shape = buffer._sharding_param_grad_view[key]._param.shape
-             numel = buffer._sharding_param_grad_view[key]._param.numel().item()
-             index = buffer._sharding_param_grad_view[key]._index
-             padded_size = buffer._sharding_param_grad_view[key]._padded_size
-             param_slice_info[key] = (begin, end)
-             param_shape_info[key] = (shape, numel, index, padded_size)
-
+     expected_keys, param_slice_info, param_shape_info = get_params_info(comm_buffer_list)
     expected_keys = set([static2struct_name_mappings.get(name, None) for name in expected_keys])
     expected_keys_optim = []
     for key in expected_keys:
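For context, expected_keys holds structured parameter names, and the loop above expands each of them with every optimizer-state typename into keys of the form "structured_name/typename". A small sketch with hypothetical names (the typenames are the usual Paddle AdamW accumulators, assumed here for illustration):

# Hypothetical structured names and typenames, for illustration only.
expected_keys = {"llama.layers.0.mlp.up_proj.weight", "llama.embed_tokens.weight"}
typename_set = {"moment1_0", "moment2_0", "beta1_pow_acc_0", "beta2_pow_acc_0"}

expected_keys_optim = {f"{key}/{typename}" for key in expected_keys for typename in typename_set}
# e.g. "llama.embed_tokens.weight/moment1_0", the "structured_name/typename" layout
# that the loading code later splits on with key.split("/").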
@@ -285,25 +321,10 @@ def load_resolved_archive_file(
     )

     # need to split param for different sharding rank, maybe need to deal with oom issue.
+     reshape_params(state_dict_optim, struct2static_name_mappings, param_shape_info, param_slice_info)
     for key in list(state_dict_optim.keys()):
         key_name = key.split("/")
         static_name = struct2static_name_mappings.get(key_name[0], None)
-
-         if int(state_dict_optim[key].numel()) > 1:
-             begin, end = param_slice_info[static_name]
-             shape, numel, index, padded_size = param_shape_info[static_name]
-             state_dict_optim[key] = state_dict_optim[key].reshape([-1])
-             state_dict_optim[key] = state_dict_optim[key][begin - index : end - index]
-
-             padding_start = max(begin, index + numel)
-             padding_end = min(end, index + padded_size)
-             if padding_start < padding_end:
-                 state_dict_optim[key] = paddle.concat(
-                     (
-                         state_dict_optim[key],
-                         paddle.zeros([padding_end - padding_start], dtype=state_dict_optim[key].dtype),
-                     )
-                 )
         if has_master_weights:
             if model_state_dict[key_name[0]].dtype != paddle.float32:
                 key_name = "_".join([static_name, FP32_MASTER, key_name[1]])
@@ -325,24 +346,10 @@ def load_resolved_archive_file(
         expected_keys,
         is_master_weights=True,
     )
+     reshape_params(state_dict_master_weight, struct2static_name_mappings, param_shape_info, param_slice_info)

     for key in list(state_dict_master_weight.keys()):
         static_name = struct2static_name_mappings.get(key, None)
-         if int(state_dict_master_weight[key].numel()) > 1:
-             begin, end = param_slice_info[static_name]
-             shape, numel, index, padded_size = param_shape_info[static_name]
-             state_dict_master_weight[key] = state_dict_master_weight[key].reshape([-1])
-             state_dict_master_weight[key] = state_dict_master_weight[key][begin - index : end - index]
-
-             padding_start = max(begin, index + numel)
-             padding_end = min(end, index + padded_size)
-             if padding_start < padding_end:
-                 state_dict_master_weight[key] = paddle.concat(
-                     (
-                         state_dict_master_weight[key],
-                         paddle.zeros([padding_end - padding_start], dtype=state_dict_master_weight[key].dtype),
-                     )
-                 )
         state_dict_master_weight[key] = state_dict_master_weight[key]._copy_to(
             paddle.framework._current_expected_place(), False
         )
@@ -357,3 +364,113 @@ def load_resolved_archive_file(
         returned_optim_state_dict["master_weights"][static_name].name = "_".join([static_name, FP32_MASTER])

     return returned_optim_state_dict
+
+
+ def load_non_merge_optimizer_with_split_param(args, model, optimizer, resume_from_checkpoint, ckpt_quant_stage="O0"):
+     returned_optim_state_dict = nested_copy(optimizer.state_dict())
+
+     optimizer_name = _add_variant(SAFE_OPTIMIZER_NAME, args.optimizer_name_suffix)
+     master_weights_name = _add_variant(SAFE_MASTER_WEIGHTS_NAME, args.optimizer_name_suffix)
+     optimizer_path = os.path.join(resume_from_checkpoint, optimizer_name)
+     master_weights_path = os.path.join(resume_from_checkpoint, master_weights_name)
+
+     # No quantization and no master weights indicate the O1 AMP strategy.
+     is_amp_o1 = args.fp16_opt_level == "O1"
+
+     model_state_dict = get_expected_state_dict(model)
+     static2struct_name_mappings = {v.name: k for k, v in model_state_dict.items()}  # get optimizer param mappings
+     struct2static_name_mappings = {k: v.name for k, v in model_state_dict.items()}
+
+     comm_buffer_list = optimizer._inner_opt._comm_buffer_list
+     if hasattr(args, "enable_sharding_comm_overlap") and args.enable_sharding_comm_overlap:
+         comm_buffer_list = list(chain(*model._chunk_2_comm_buffers.values()))
+
+     expected_keys, param_slice_info, param_shape_info = get_params_info(comm_buffer_list)
+     expected_keys = set([static2struct_name_mappings.get(name, None) for name in expected_keys])
+     expected_keys_optim = []
+     sharding_typename_set, typename_set = [], []
+     with safe_open(optimizer_path, framework="numpy") as f:
+         optim_keys = f.keys()
+         for key in optim_keys:
+             _, typename = key.split("/")
+             typename_set.append(typename)
+
+     # Some shard files may contain an incomplete set of typenames, so gather them across the sharding group.
+     hcg = fleet.get_hybrid_communicate_group()
+     sharding_group = hcg.get_sharding_parallel_group()
+     dist.all_gather_object(sharding_typename_set, typename_set, sharding_group)
+     typename_set = set(chain(*sharding_typename_set))
+     for key in expected_keys:
+         for typename in typename_set:
+             expected_keys_optim.append(f"{key}/{typename}")
+     expected_keys_optim = set(expected_keys_optim)
+
+     optimizer_state_dict = load_state_dict(
+         optimizer_path, None, None, device="expected", ckpt_quant_stage=ckpt_quant_stage
+     )
+     master_weights = {}
+     # normal AMP O2
+     if not is_amp_o1 and os.path.isfile(master_weights_path):
+         master_weights = load_state_dict(master_weights_path, None, None, device="expected")
+
+     def get_unfound_params(unfound_keys, state_dict, is_optimizer=True):
+         if len(unfound_keys) > 0:
+             backup_files = []
+             files = os.listdir(resume_from_checkpoint)
+             name = optimizer_name if is_optimizer else master_weights_name
+             name_without_shard = re.sub(r"_?shard\d+_?", "", name)
+             name_ = "optimizer" if is_optimizer else "master_weights"
+             for f in files:
+                 if f.startswith(name_) and f.endswith("safetensors") and f != name:
+                     if re.sub(r"_?shard\d+_?", "", f) == name_without_shard:
+                         backup_files.append(f)
+             for f in backup_files:
+                 new_path = os.path.join(resume_from_checkpoint, f)
+                 with safe_open(new_path, framework="numpy") as fin:
+                     keys = fin.keys()
+                     for key in unfound_keys:
+                         if key in keys:
+                             tensor = fin.get_tensor(key)
+                             with device_guard():
+                                 tensor = paddle.Tensor(tensor, zero_copy=True)
+                             state_dict[key] = tensor._copy_to(paddle.framework._current_expected_place(), False)
+
+     # Get other optimizer parameters which may be in other shard files.
+     unfound_keys = expected_keys_optim - optimizer_state_dict.keys()
+     get_unfound_params(unfound_keys, optimizer_state_dict, True)
+
+     # Get other master weight parameters which may be in other shard files.
+     if master_weights != {}:
+         unfound_keys = expected_keys - master_weights.keys()
+         get_unfound_params(unfound_keys, master_weights, False)
+     reshape_params(optimizer_state_dict, struct2static_name_mappings, param_shape_info, param_slice_info)
+
+     # rename and move to paddle.Tensor
+     for key in list(optimizer_state_dict.keys()):
+         key_name = key.split("/")
+         model_weight_key = key_name[0]
+         static_name = struct2static_name_mappings[key_name[0]]
+         if not is_amp_o1:
+             if model_state_dict[key_name[0]].dtype != paddle.float32:
+                 key_name = "_".join([static_name, FP32_MASTER, key_name[1]])
+             else:
+                 key_name = "_".join([static_name, key_name[1]])
+         else:
+             key_name = "_".join([static_name, key_name[1]])
+         returned_optim_state_dict[key_name] = optimizer_state_dict.pop(key)
+         returned_optim_state_dict[key_name].name = key_name
+
+         # master weight cast (only in AMP O2 + remove_master_weight)
+         if not is_amp_o1 and not os.path.isfile(master_weights_path):
+             master_weights[model_weight_key] = paddle.cast(model_state_dict[model_weight_key], dtype=paddle.float32)
+
+     if not is_amp_o1:
+         reshape_params(master_weights, struct2static_name_mappings, param_shape_info, param_slice_info)
+
+         returned_optim_state_dict["master_weights"] = {}
+         for key in list(master_weights.keys()):
+             static_name = struct2static_name_mappings[key]
+             returned_optim_state_dict["master_weights"][static_name] = master_weights.pop(key)
+             returned_optim_state_dict["master_weights"][static_name].name = "_".join([static_name, FP32_MASTER])
+
+     return returned_optim_state_dict
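The backup-file lookup in get_unfound_params treats two checkpoint files as the same logical file when they differ only in their shard index, which it detects by stripping the shard suffix with a regex. A small sketch, using file names that are only illustrative (the real names come from _add_variant and args.optimizer_name_suffix):

import re

shard_pattern = re.compile(r"_?shard\d+_?")

name = "optimizer-tp00_pp00_shard00.safetensors"   # hypothetical local shard file
for candidate in [
    "optimizer-tp00_pp00_shard01.safetensors",     # differs only in the shard index: matches
    "optimizer-tp01_pp00_shard00.safetensors",     # different tensor-parallel rank: does not match
]:
    same_group = shard_pattern.sub("", candidate) == shard_pattern.sub("", name)
    print(candidate, same_group)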