Skip to content

Commit 29f97bb

Browse files
TARE submission (#11151)
Co-authored-by: zsy <zsy03260058@163.com>
1 parent 10014d8 commit 29f97bb

File tree

6 files changed

+383
-1
lines changed

6 files changed

+383
-1
lines changed

llm/predict/predictor.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,13 @@
4040
pass
4141

4242
from paddlenlp.generation import GenerationConfig, TextIteratorStreamer
43-
from paddlenlp.peft import LoRAConfig, LoRAModel, PrefixConfig, PrefixModelForCausalLM
43+
from paddlenlp.peft import (
44+
LoRAConfig,
45+
LoRAModel,
46+
PrefixConfig,
47+
PrefixModelForCausalLM,
48+
TAREModel,
49+
)
4450
from paddlenlp.taskflow.utils import static_mode_guard
4551
from paddlenlp.trainer import PdArgumentParser
4652
from paddlenlp.transformers import (
@@ -85,6 +91,9 @@ class PredictorArgument:
8591
device: str = field(default="gpu", metadata={"help": "Device"})
8692
dtype: str = field(default=None, metadata={"help": "Model dtype"})
8793
lora_path: str = field(default=None, metadata={"help": "The directory of LoRA parameters. Default to None"})
94+
tare_path: str = field(default=None, metadata={"help": "The directory of TARE parameters. Default to None"})
95+
tare_n: int = field(default=8, metadata={"help": "The num of TARE editors. Default to 8."})
96+
tare_k: int = field(default=7, metadata={"help": "The num of TARE selected editors. Default to 7."})
8897
export_precache: bool = field(default=False, metadata={"help": "whether use prefix weight to do infer"})
8998
prefix_path: str = field(
9099
default=None, metadata={"help": "The directory of Prefix Tuning parameters. Default to None"}
@@ -355,6 +364,11 @@ def __init__(
355364
prefix_path=config.prefix_path,
356365
postprocess_past_key_value=prefix_tuning_params["postprocess_past_key_value"],
357366
)
367+
368+
if config.tare_path is not None:
369+
self.model = TAREModel(base_model=self.model, n=config.tare_n, k=config.tare_k)
370+
self.model.load_model(os.path.join(config.tare_path, "delta_vector.pth"))
371+
358372
self.model.eval()
359373

360374
@paddle.no_grad()

llm/run_finetune.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
LoRAModel,
4040
PrefixConfig,
4141
PrefixModelForCausalLM,
42+
TAREModel,
4243
VeRAConfig,
4344
VeRAModel,
4445
)
@@ -500,6 +501,10 @@ def compute_metrics_do_generation(eval_preds):
500501
elif last_checkpoint is not None:
501502
checkpoint = last_checkpoint
502503
train_result = trainer.train(resume_from_checkpoint=checkpoint)
504+
505+
if model_args.tare:
506+
model.save_model(os.path.join(training_args.output_dir, "delta_vector.pth"))
507+
503508
if model_args.neftune:
504509
neft_post_hook_handle.remove()
505510
if training_args.benchmark:
@@ -725,6 +730,10 @@ def create_peft_model(
725730
model.mark_only_vera_as_trainable(notfreezeB=True)
726731
model.print_trainable_parameters()
727732

733+
if model_args.tare:
734+
model = TAREModel(base_model=model, n=model_args.tare_n, k=model_args.tare_k)
735+
model.print_trainable_parameters()
736+
728737
return model
729738

730739

paddlenlp/peft/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,5 @@
1717
from .lora import LoRAAutoConfig, LoRAAutoModel, LoRAConfig, LoRAModel
1818
from .prefix import PrefixConfig, PrefixModelForCausalLM
1919
from .reft import ReFTModel
20+
from .tare import TAREModel
2021
from .vera import VeRAConfig, VeRAModel

paddlenlp/peft/tare/__init__.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from .tare_model import TAREModel
16+
17+
__all__ = ["TAREModel"]

0 commit comments

Comments (0)