
llama auto-parallel PP network construction #10626


Closed · wants to merge 12 commits
43 changes: 43 additions & 0 deletions llm/auto_parallel/llama/run_llama2.sh
@@ -0,0 +1,43 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# For debugging only

set -x
unset CUDA_VISIBLE_DEVICES

task_name="llama3_dp2pp4sd2"
rm -rf output/$task_name/
rm -rf "output/$task_name""_log"

export SOT_LOG_LEVEL=4
export PYTHONPATH=../../../:$PYTHONPATH

#ulimit -c unlimited
# export GLOG_v=6
export NCCL_DEBUG=INFO

# export FLAGS_call_stack_level=3
# export FLAGS_use_cuda_managed_memory=true

# export FLAGS_embedding_deterministic=1
# export FLAGS_cudnn_deterministic=1
# export NVIDIA_TF32_OVERRIDE=0
rm -rf core.*
python -u -m paddle.distributed.launch \
--gpus "0,1,2,3,4,5,6,7" \
--log_dir "output/${task_name}_log" \
./run_pretrain_auto.py \
../../../tests/test_tipc/static/auto_parallel/llama2/pretrain_config_llama2_13b/pretrain-llama2_13b.json

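Note: the launcher starts eight ranks, while the trainer changes below hard-code a two-stage pipeline split over them. A small illustration (not part of this PR) of that rank-to-stage mapping:

```python
# Illustrative only: mirrors the hard-coded split in auto_trainer.py below,
# where ranks 0-3 build pipeline stage 0 and ranks 4-7 build stage 1.
def stage_for_rank(rank: int) -> int:
    return 0 if rank in (0, 1, 2, 3) else 1

for rank in range(8):
    print(f"rank {rank} -> pipeline stage {stage_for_rank(rank)}")
```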
4 changes: 3 additions & 1 deletion llm/auto_parallel/llama/run_pretrain_auto.py
@@ -41,14 +41,16 @@
    LinearAnnealingWithWarmupDecay,
    LlamaConfig,
    LlamaForCausalLM3DAuto,
    LlamaForCausalLM3DAutoPP,
    LlamaForCausalLMNet,
    LlamaPretrainingCriterion3DAuto,
    LlamaPretrainingCriterionNet,
)
from paddlenlp.utils.log import logger
from paddle.distributed.auto_parallel.pipelining.schedules import ScheduleGPipe

MODEL_CLASSES = {
    "llama": (LlamaConfig, LlamaForCausalLM3DAuto, LlamaPretrainingCriterion3DAuto),
    "llama": (LlamaConfig, LlamaForCausalLM3DAutoPP, LlamaPretrainingCriterion3DAuto),
    "llama_network": (LlamaConfig, LlamaForCausalLMNet, LlamaPretrainingCriterionNet),
}

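Note: the duplicated "llama" entry above is the diff's before/after pair; after this PR the "llama" key resolves to LlamaForCausalLM3DAutoPP. A minimal sketch (illustrative, not in this file) of how the tuple is presumably consumed:

```python
# Hypothetical lookup, assuming the MODEL_CLASSES dict defined above is in
# scope and the model_type is "llama", as in pretrain-llama2_13b.json.
config_class, model_class, criterion_class = MODEL_CLASSES["llama"]
config = config_class()              # LlamaConfig with default fields
model = model_class(config)          # LlamaForCausalLM3DAutoPP after this PR
criterion = criterion_class(config)  # LlamaPretrainingCriterion3DAuto
```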
160 changes: 121 additions & 39 deletions paddlenlp/trainer/auto_trainer.py
@@ -57,13 +57,74 @@
    from ..quantization.quantization_linear import QuantizationLinear
except:
    QuantizationLinear = None

from paddle.distributed.auto_parallel.pipelining.schedules import ScheduleGPipe, Schedule1F1B
from paddle.distributed.auto_parallel.pipelining.stage import PipelineStage

MODEL_NAME = "model"
OPTIMIZER_NAME = "optimizer"
DIST_CKPT_PATH = "dist_ckpt"
DIST_MODEL_PATH = "dist_model"
FREE_SVAE_LOAD_KEY_PATTERNS = ["learning_rate_", "gradient_merge_", "@GRAD@MERG", "eager_tmp"]

is_split_model = False
local_stage = None


def manual_model_split(model, stage_idx, group):
    global is_split_model
    global local_stage

    if is_split_model:
        return local_stage
    if stage_idx == 0:
        # remove decoder layers 10-19 from this stage's copy of the model
        for i in range(10):
            del model.layers[10]

        def forward0(
            self,
            input_ids=None,
            labels=None,
            position_ids=None,
            attention_mask=None,
            inputs_embeds=None,
            use_cache=False,
            past_key_values=None,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=None,
        ):
            outputs = tuple([input_ids, attention_mask, position_ids])

            # decoder layers
            for idx, (decoder_layer) in enumerate(self.layers):
                outputs = decoder_layer(outputs)
            return outputs

        setattr(model.__class__, "forward", forward0)

    elif stage_idx == 1:
        # remove decoder layers 0-9 from this stage's copy of the model
        for i in range(10):
            del model.layers[0]

        def forward1(self, *args):
            outputs = args
            # decoder layers
            for idx, (decoder_layer) in enumerate(self.layers):
                outputs = decoder_layer(outputs)
            return outputs

        setattr(model.__class__, "forward", forward1)
    else:
        raise ValueError("Invalid stage index.")

    stage = PipelineStage(
        model,
        stage_idx,
        2,  # two pipeline stages in total in this debug setup
        group=group,
    )
    is_split_model = True
    local_stage = stage
    return stage

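Note: a minimal sketch (not part of this diff) of how the returned stage is driven; it mirrors the compute_loss changes further down and assumes the two-stage, eight-rank split hard-coded in this PR. `model`, `criterion`, `inputs`, `labels`, and `pp_group` are assumed to exist in the caller.

```python
import paddle.distributed as dist
from paddle.distributed.auto_parallel.pipelining.schedules import Schedule1F1B

rank = dist.get_rank()
stage_idx = 0 if rank in (0, 1, 2, 3) else 1
stage = manual_model_split(model, stage_idx, pp_group)
schedule = Schedule1F1B(stage, n_microbatches=2, loss_fn=criterion)
if stage_idx == 0:
    schedule.step(**inputs)                      # first stage: feed inputs forward
else:
    losses = []
    schedule.step(target=labels, losses=losses)  # last stage: collect per-microbatch losses
```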
class AutoTrainer(Trainer):
def __init__(self, *args, **kwargs):
@@ -88,7 +88,7 @@
), "if use AutoTrainer.parallel_model , auto_dist_config obtained from parallel_model should be passed to AutoTrainer "
self.auto_dist_config = kwargs.pop("auto_dist_config")
model = kwargs["model"]
for param in model.parameters():
for name, param in model.named_parameters():
                # NOTE(zhangwl): in pipeline mode, a param may have been created earlier while its init_func was deleted, yet the param is still not initialized
                if not param._is_initialized() and param._init_func is not None:
                    param.initialize()
@@ -98,7 +98,7 @@
        assert self.args.enable_auto_parallel

        self.global_mesh = fleet.auto.get_mesh()
        self.comm_group_in_pp = fleet.get_hybrid_communicate_group().get_pipe_parallel_group()
        self._in_pir_mode = paddle.base.framework.get_flags("FLAGS_enable_pir_api")["FLAGS_enable_pir_api"]

    @classmethod
@@ -670,50 +731,71 @@
            labels = inputs["generator_labels"]
        else:
            labels = None

        def get_mesh(pp_idx=0):
            mesh = fleet.auto.get_mesh()
            if "pp" in mesh.dim_names:
                mesh = mesh.get_mesh_with_dim("pp", pp_idx)
            return mesh

        rank = dist.get_rank()
        if rank in (0, 1, 2, 3):
            stage = manual_model_split(model, 0, self.comm_group_in_pp)
        else:
            stage = manual_model_split(model, 1, self.comm_group_in_pp)

        schedule = Schedule1F1B(stage, n_microbatches=2, loss_fn=self.criterion)

        if rank in (0, 1, 2, 3):
            inputs["input_ids"] = dist.reshard(inputs["input_ids"], get_mesh(0), [dist.Replicate(), dist.Replicate()])
            schedule.step(**inputs)
        else:
            labels = dist.reshard(labels, get_mesh(1), [dist.Replicate(), dist.Replicate()])
            losses = []
            schedule.step(target=labels, losses=losses)
            print("losses: ", losses)
        return 0
        # outputs = model(**inputs)

        # if self.criterion is not None:

        #     def to_list(value):
        #         if value is None:
        #             return value
        #         if isinstance(value, (list, tuple)):
        #             return list(value)
        #         return [value]

        #     criterion_inputs = to_list(outputs)
        #     criterion_labels = to_list(labels)
        #     loss = self.criterion(*(criterion_inputs + criterion_labels))
        #     outputs = (loss, outputs)

        # # Save past state if it exists
        # # TODO: this needs to be fixed and made cleaner later.
        # if self.args.past_index >= 0:
        #     self._past = outputs[self.args.past_index]

        # # We don't use .loss here since the model may return tuples instead of ModelOutput.
        # loss = outputs["loss"] if isinstance(outputs, dict) else outputs
        # if isinstance(outputs, dict):
        #     loss = outputs["loss"]
        # elif isinstance(outputs, tuple):
        #     loss = outputs[0]
        # else:
        #     loss = outputs

        # return (loss, outputs) if return_outputs else loss
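Note: on the last stage, schedule.step above fills `losses` with what is presumably one entry per microbatch (two here, given n_microbatches=2). A minimal sketch of reducing them to a single scalar, assuming each entry is a scalar paddle.Tensor:

```python
import paddle

# Hypothetical per-microbatch losses, as populated by schedule.step above.
losses = [paddle.to_tensor(2.0), paddle.to_tensor(1.5)]
mean_loss = sum(losses) / len(losses)  # average over n_microbatches
print(float(mean_loss))  # 1.75
```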

    def dynamic_training(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor, Any]]) -> paddle.Tensor:
        with self.autocast_smart_context_manager():
            loss = self.compute_loss(model, inputs)

        # if loss is not None and self.args.gradient_accumulation_steps > 1 and not self._enable_delay_scale_loss():
        #     loss = loss / self.args.gradient_accumulation_steps

        # if self.do_grad_scaling:
        #     self.scaler.scale(loss).backward()
        # else:
        #     loss.backward()

        return loss

1 change: 1 addition & 0 deletions paddlenlp/transformers/llama/__init__.py
@@ -15,6 +15,7 @@
from .configuration import *
from .modeling import *
from .modeling_auto import *
from .modeling_auto_pp import *
from .modeling_network import *
from .modeling_pp import *
from .tokenizer import *