
Commit 6d8fa31

Wennie396 authored and AlAuAu committed
Add chunk offload optimizer (PaddlePaddle#11084)
* add chunk offload optimizer
* fix get offload_opt_buffer_size arg
1 parent d8bea8c commit 6d8fa31

File tree

paddlenlp/trainer/training_args.py
paddlenlp/trainer/utils/offload_optimizer.py

2 files changed: +25 -3 lines changed


paddlenlp/trainer/training_args.py

Lines changed: 16 additions & 0 deletions

@@ -614,6 +614,16 @@ class TrainingArguments:
             )
         },
     )
+    sharding_offload_opt_buffersize_GB: int = field(
+        default=-1,
+        metadata={
+            "help": (
+                "Set the size of the optimizer offload buffer when need_hack_offload_optimizer() is True. This option only takes effect when "
+                "use DygraphShardingOptimizerV2. The default value is -1, which means that all of the optimizer states will be offloaded. Only "
+                "works when export HACK_OFFLOAD_OPTIMIZER=1. "
+            )
+        },
+    )
 
     save_sharded_model: bool = field(
         default=False,
@@ -1531,6 +1541,11 @@ def is_context_parallel_supported():
                     self.sharding_comm_buffer_size_MB
                 )
 
+            if hasattr(strategy.hybrid_configs["sharding_configs"], "offload_opt_buffer_size"):
+                strategy.hybrid_configs["sharding_configs"].offload_opt_buffer_size = int(
+                    self.sharding_offload_opt_buffersize_GB
+                )
+
             if "split_param" in sharding_parallel_config:
                 strategy.hybrid_configs["sharding_configs"].split_param = True
                 assert self.amp_master_grad, "Currently sharding stage1 v2 only support amp_master_grad"
@@ -1631,6 +1646,7 @@ def is_context_parallel_supported():
                 self.sharding_parallel_degree
                 * self.tensor_parallel_degree
                 * self.sep_parallel_degree
+                * self.context_parallel_degree
                 * self.pipeline_parallel_degree
             )
 

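For reference, the new flag is an ordinary TrainingArguments field, so it can be passed like any other trainer argument. The sketch below is illustrative only: `sharding_offload_opt_buffersize_GB`, the `HACK_OFFLOAD_OPTIMIZER` environment variable, and the split_param / DygraphShardingOptimizerV2 requirement come from this diff and its help text, while the output directory, sharding settings, and the 4 GB value are assumptions, not part of the commit.

import os

# The help text says the option only works with HACK_OFFLOAD_OPTIMIZER=1 and the
# DygraphShardingOptimizerV2 (sharding stage1 "split_param") path.
os.environ["HACK_OFFLOAD_OPTIMIZER"] = "1"

from paddlenlp.trainer import TrainingArguments

args = TrainingArguments(
    output_dir="./checkpoints",              # illustrative value
    sharding="stage1",                       # illustrative value
    sharding_parallel_config="split_param",  # stage1 v2 path (DygraphShardingOptimizerV2)
    sharding_offload_opt_buffersize_GB=4,    # cap the offload buffer at 4 GB; -1 (default) offloads all optimizer states
)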
paddlenlp/trainer/utils/offload_optimizer.py

Lines changed: 9 additions & 3 deletions

@@ -58,9 +58,11 @@ def new_opt_op(*args):
                    reload(arg)
 
            ret = origin_op(*args)
-
+            is_offload_opt = getattr(args[0], "is_offload_opt", False)
            for i, arg in enumerate(args):
-                if i >= 2 and isinstance(arg, paddle.Tensor):  # do not offload parameter and gradient
+                if (
+                    i >= 2 and isinstance(arg, paddle.Tensor) and is_offload_opt
+                ):  # do not offload parameter and gradient
                    offload(arg)
            return ret
 
@@ -74,7 +76,11 @@ def new_insert_sync(self, sync_var, *args, **kwargs):
        origin_place = sync_var.place
        reload(sync_var)
        ret = origin_insert_sync(self, sync_var, *args, **kwargs)
-        new_sync_var = to_device(sync_var, origin_place)
+        is_offload_opt = getattr(sync_var, "is_offload_opt", False)
+        if is_offload_opt:
+            new_sync_var = to_device(sync_var, origin_place)
+        else:
+            new_sync_var = sync_var
        assert new_sync_var is sync_var, "to_device must be inplace operation"
        return ret
 

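Both patched hooks now gate on an is_offload_opt attribute read from the tensor itself: the optimizer op only offloads states that carry the flag, and the gradient-clip sync only copies a flagged tensor back to its original place. The code that sets the attribute is not part of this diff; the sketch below only illustrates the attribute-gating pattern with a hypothetical tensor and should not be read as the commit's implementation.

import paddle

# Hypothetical illustration of the is_offload_opt flag checked by the patched hooks.
moment1 = paddle.zeros([4096, 4096], dtype="float32")  # stand-in for an optimizer state
moment1.is_offload_opt = True  # presumably set by the chunked offload buffer logic, not shown in this diff

# The hooks read the flag defensively, defaulting to False for untagged tensors:
if getattr(moment1, "is_offload_opt", False):
    print("this state would be offloaded after the optimizer op")
else:
    print("this state stays on its current device")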