15 changes: 13 additions & 2 deletions llm/auto_parallel/llama/run_pretrain_auto.py
@@ -41,6 +41,7 @@
LinearAnnealingWithWarmupDecay,
LlamaConfig,
LlamaForCausalLM3DAuto,
LlamaForCausalLM3DAutoPP,
LlamaForCausalLMNet,
LlamaPretrainingCriterion3DAuto,
LlamaPretrainingCriterionNet,
@@ -49,6 +50,7 @@

MODEL_CLASSES = {
"llama": (LlamaConfig, LlamaForCausalLM3DAuto, LlamaPretrainingCriterion3DAuto),
"llama_hybrid_pp": (LlamaConfig, LlamaForCausalLM3DAutoPP, LlamaPretrainingCriterion3DAuto),
"llama_network": (LlamaConfig, LlamaForCausalLMNet, LlamaPretrainingCriterionNet),
}

@@ -86,13 +88,18 @@ class PreTrainingArguments(AutoTrainingArguments):
)
sr: Optional[int] = field(default=0, metadata={"help": "The count of chunks without recompute."})
virtual_pipeline_seg_method: str = field(
default="LlamaDecoderLayerAuto", metadata={"help": "The seg method of splitting pp layer for virtual pipeline."}
default="LlamaDecoderLayerAuto",
metadata={"help": "The seg method of splitting pp layer for virtual pipeline."},
)
# NOTE(gongenlei): new add autotuner_benchmark
autotuner_benchmark: bool = field(
default=False,
metadata={"help": "Weather to run benchmark by autotuner. True for from_scratch and pad_max_length."},
)
n_microbatches: int = field(
default=1,
metadata={"help": "Control the num of microbatches in one pp step."},
)

def __post_init__(self):
super().__post_init__()
@@ -557,7 +564,11 @@ def main():

print("Final pre-training config:", config)

if "replace_with_parallel_cross_entropy" in training_args.tensor_parallel_config and config.tensor_parallel_degree > 1 and config.to_static is False:
if (
"replace_with_parallel_cross_entropy" in training_args.tensor_parallel_config
and config.tensor_parallel_degree > 1
and config.to_static is False
):
from llm.utils.replace_ops import replace_cross_entropy

replace_cross_entropy()
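For orientation: run_pretrain_auto.py selects its classes by looking the model_type argument up in MODEL_CLASSES, so the new pipeline variant is chosen simply by passing llama_hybrid_pp, and the new n_microbatches training argument is later handed to the pipeline schedule inside AutoTrainer (see the get_pp_schedule call in the next file). Below is a minimal sketch of that lookup, assuming it runs inside run_pretrain_auto.py where MODEL_CLASSES is in scope; the variable names are illustrative, not taken from the diff.

# Sketch (not part of the diff): how the new "llama_hybrid_pp" entry is consumed.
# Assumes MODEL_CLASSES from run_pretrain_auto.py is in scope.
config_class, model_class, criterion_class = MODEL_CLASSES["llama_hybrid_pp"]

print(config_class.__name__)     # LlamaConfig
print(model_class.__name__)      # LlamaForCausalLM3DAutoPP
print(criterion_class.__name__)  # LlamaPretrainingCriterion3DAuto

The tuple is unpacked the same way as the existing llama and llama_network entries, which is why the rest of the script needs no further changes in this diff.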
54 changes: 50 additions & 4 deletions paddlenlp/trainer/auto_trainer.py
@@ -30,6 +30,7 @@
from paddlenlp.trainer import Trainer

from ..transformers.model_utils import unwrap_model
from ..transformers import get_pp_schedule
from ..utils.batch_sampler import DistributedBatchSampler as NlpDistributedBatchSampler
from ..utils.env import (
PREFIX_CHECKPOINT_DIR,
@@ -89,16 +90,17 @@ def loss_func(loss, outputs):
self.auto_dist_config = kwargs.pop("auto_dist_config")
model = kwargs["model"]
for param in model.parameters():
# NOTE(zhangwl):in pipeline mode , param my be initialized before while delte init_func ,but param is still not is_initialized
# NOTE(zhangwl):in pipeline mode , param may be initialized before while delete init_func, but param is still not is_initialized
if not param._is_initialized() and param._init_func is not None:
param.initialize()
kwargs["model"] = model

super().__init__(*args, **kwargs)
assert self.args.enable_auto_parallel

self.global_mesh = fleet.auto.get_mesh()
self.comm_group_in_pp = fleet.get_hybrid_communicate_group().get_pipe_parallel_group()
if self.args.pipeline_parallel_degree > 1:
self.pp_schedule = get_pp_schedule(model, self.args.n_microbatches, self.criterion, self.args.pipeline_schedule_mode, self.args.pipeline_parallel_degree, self.comm_group_in_pp)
self._in_pir_mode = paddle.base.framework.get_flags("FLAGS_enable_pir_api")["FLAGS_enable_pir_api"]

@classmethod
@@ -109,7 +111,7 @@ def parallel_model(cls, model, training_args: AutoTrainingArguments):
model (paddle.nn.Layer): the model to be parallelized.
training_args (AutoTrainingArguments) : Training arguments which contain distributed information
Returns:
the model after parallelize and config conatins distributed strategy
the parallelized model and a config that contains the distributed strategy
"""
if not training_args.use_intermediate_api:
return model, None
@@ -438,7 +440,7 @@ def _inner_training_loop(
)
assert (
paddle.sum(paddle.stack(global_step_list) - global_step_list[0]) == 0
), f"Error, get different globel step, please check! step list: {[x.item() for x in global_step_list]}"
), f"Error, get different global step, please check! step list: {[x.item() for x in global_step_list]}"

epochs_trained = self.state.global_step // num_update_steps_per_epoch
if not args.ignore_data_skip:
@@ -703,7 +705,51 @@ def to_list(value):

return (loss, outputs) if return_outputs else loss

def compute_pipeline_loss(self, model, inputs, return_outputs=False):
"""
How the loss is computed by Trainer. By default, all models return the loss in the first element.
Subclass and override for custom behavior.
"""
if self.criterion is not None:
if "labels" in inputs:
labels = inputs.pop("labels")
elif "start_positions" in inputs and "end_positions" in inputs:
labels = (inputs.pop("start_positions"), inputs.pop("end_positions"))
elif self.args.label_names is not None:
labels = []
for label in self.label_names:
labels.append(inputs.pop(label))
labels = tuple(labels)
elif "generator_labels" in inputs:
labels = inputs["generator_labels"]
else:
labels = None

pp_rank = self.comm_group_in_pp.rank
losses = []
if pp_rank == 0:  # first pp stage: feed the input data into the pipeline
self.pp_schedule.step(**inputs)
elif pp_rank == self.args.pipeline_parallel_degree - 1:  # last pp stage: feed the labels in and collect the loss
self.pp_schedule.step(target=labels, losses=losses)
else:
self.pp_schedule.step()

final_loss = None
if len(losses) != 0:
final_loss = paddle.stack(losses).mean()

return final_loss

def dynamic_pipeline_training(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor, Any]]) -> paddle.Tensor:
assert self.args.pipeline_parallel_degree > 1, "pipeline_parallel_degree must be greater than 1."
with self.autocast_smart_context_manager():
loss = self.compute_pipeline_loss(model, inputs)

return loss

def dynamic_training(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor, Any]]) -> paddle.Tensor:
if self.args.pipeline_parallel_degree > 1:
return self.dynamic_pipeline_training(model, inputs)
with self.autocast_smart_context_manager():
loss = self.compute_loss(model, inputs)

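Two behavioural points in the trainer changes above are worth spelling out: only the last pipeline stage passes target=labels and a losses list into pp_schedule.step(...), so compute_pipeline_loss returns a real loss only on that stage, and the per-microbatch losses are reduced with a stacked mean. A small standalone sketch of that reduction with synthetic values, independent of the trainer and of get_pp_schedule:

import paddle

# Synthetic per-microbatch losses; in the trainer these are filled in by
# pp_schedule.step(target=labels, losses=losses) on the last pipeline stage.
losses = [paddle.to_tensor([2.0]), paddle.to_tensor([1.5]), paddle.to_tensor([1.0]), paddle.to_tensor([2.5])]

final_loss = None
if len(losses) != 0:
    final_loss = paddle.stack(losses).mean()  # mean over the n_microbatches entries

print(final_loss)  # tensor holding 1.75

Note that on rank 0 and the intermediate stages the losses list stays empty, so compute_pipeline_loss (and hence dynamic_pipeline_training) returns None there.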
1 change: 1 addition & 0 deletions paddlenlp/transformers/llama/__init__.py
@@ -15,6 +15,7 @@
from .configuration import *
from .modeling import *
from .modeling_auto import *
from .modeling_auto_pp import *
from .modeling_network import *
from .modeling_pp import *
from .tokenizer import *
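Because the package __init__ re-exports everything from the new modeling_auto_pp module, the pipeline-parallel classes used above become importable from the shared transformers namespace. A small hedged example; get_pp_schedule is imported this way in auto_trainer.py, while the top-level export of LlamaForCausalLM3DAutoPP is assumed to follow the same pattern as the other Llama variants:

# Assumes paddlenlp.transformers re-exports the llama submodules, as it does for
# the other modeling_* variants; get_pp_schedule is imported like this in auto_trainer.py.
from paddlenlp.transformers import LlamaForCausalLM3DAutoPP, get_pp_schedule

print(LlamaForCausalLM3DAutoPP.__name__)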