```diff
@@ -40,7 +40,13 @@
     pass

 from paddlenlp.generation import GenerationConfig, TextIteratorStreamer
-from paddlenlp.peft import LoRAConfig, LoRAModel, PrefixConfig, PrefixModelForCausalLM
+from paddlenlp.peft import (
+    LoRAConfig,
+    LoRAModel,
+    PrefixConfig,
+    PrefixModelForCausalLM,
+    TAREModel,
+)
 from paddlenlp.taskflow.utils import static_mode_guard
 from paddlenlp.trainer import PdArgumentParser
 from paddlenlp.transformers import (
@@ -85,6 +91,9 @@ class PredictorArgument:
     device: str = field(default="gpu", metadata={"help": "Device"})
     dtype: str = field(default=None, metadata={"help": "Model dtype"})
     lora_path: str = field(default=None, metadata={"help": "The directory of LoRA parameters. Default to None"})
+    tare_path: str = field(default=None, metadata={"help": "The directory of TARE parameters. Default to None"})
+    tare_n: int = field(default=8, metadata={"help": "The num of TARE editors. Default to 8."})
+    tare_k: int = field(default=7, metadata={"help": "The num of TARE selected editors. Default to 7."})
     export_precache: bool = field(default=False, metadata={"help": "whether use prefix weight to do infer"})
     prefix_path: str = field(
         default=None, metadata={"help": "The directory of Prefix Tuning parameters. Default to None"}
@@ -355,6 +364,11 @@ def __init__(
                 prefix_path=config.prefix_path,
                 postprocess_past_key_value=prefix_tuning_params["postprocess_past_key_value"],
             )
+
+        if config.tare_path is not None:
+            self.model = TAREModel(base_model=self.model, n=config.tare_n, k=config.tare_k)
+            self.model.load_model(os.path.join(config.tare_path, "delta_vector.pth"))
+
         self.model.eval()

     @paddle.no_grad()
```
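For reference, here is a minimal sketch of the runtime path this change wires up, using only the calls visible in the diff (`TAREModel(base_model=..., n=..., k=...)` and `load_model(...)`). The base checkpoint name and the `./tare_checkpoint` directory are placeholders, and `AutoModelForCausalLM` stands in for whatever model class the predictor has already constructed:

```python
import os

from paddlenlp.peft import TAREModel
from paddlenlp.transformers import AutoModelForCausalLM

# Placeholder base checkpoint; the predictor normally builds this model itself.
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b")

# Wrap the base model with n TARE editors, k of which are selected,
# mirroring the PredictorArgument defaults added above (tare_n=8, tare_k=7).
model = TAREModel(base_model=model, n=8, k=7)

# Load the trained delta vectors; "delta_vector.pth" matches the filename
# the new predictor code expects inside the --tare_path directory.
model.load_model(os.path.join("./tare_checkpoint", "delta_vector.pth"))
model.eval()
```

Since `PredictorArgument` is parsed with `PdArgumentParser`, the three new dataclass fields should surface as `--tare_path`, `--tare_n`, and `--tare_k` flags on the predict script without any extra CLI plumbing.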