Commit 93ddb10

FIX Use SFTConfig instead of SFTTrainer keyword args (#2150)
Update training scripts that use trl to fix deprecated argument usage.
1 parent c039b00 commit 93ddb10
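
The gist of the change: dataset-related options that were previously passed directly to `SFTTrainer` (`dataset_text_field`, `max_seq_length`, `packing`, `dataset_kwargs`) now live on `SFTConfig` and reach the trainer via `args`. A minimal sketch of the new pattern, assembled from the updated examples in this commit (the model and dataset choices mirror the modified READMEs and are illustrative only):

```python
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTConfig, SFTTrainer

# Small model and tiny dataset slice, as in the updated READMEs.
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m", torch_dtype=torch.bfloat16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
dataset = load_dataset("imdb", split="train[:1%]")

# Dataset options move from SFTTrainer keyword arguments onto SFTConfig.
training_args = SFTConfig(
    output_dir="sft-output",    # placeholder output directory
    dataset_text_field="text",  # previously an SFTTrainer keyword argument
    max_seq_length=128,         # previously an SFTTrainer keyword argument
    packing=False,              # previously an SFTTrainer keyword argument
)

trainer = SFTTrainer(
    model=model,
    args=training_args,         # the SFTConfig carries the dataset options
    train_dataset=dataset,
    tokenizer=tokenizer,
)
trainer.train()
```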

6 files changed: 20 additions and 51 deletions
docs/source/accelerate/deepspeed.md

Lines changed: 1 addition & 8 deletions
@@ -128,7 +128,7 @@ Notice that we are using LoRA with rank=8, alpha=16 and targeting all linear la
 
 Let's dive a little deeper into the script so you can see what's going on, and understand how it works.
 
-The first thing to know is that the script uses DeepSpeed for distributed training as the DeepSpeed config has been passed. The `SFTTrainer` class handles all the heavy lifting of creating the PEFT model using the peft config that is passed. After that, when you call `trainer.train()`, `SFTTrainer` internally uses 🤗 Accelerate to prepare the model, optimizer and trainer using the DeepSpeed config to create DeepSpeed engine which is then trained. The main code snippet is below:
+The first thing to know is that the script uses DeepSpeed for distributed training as the DeepSpeed config has been passed. The [`~trl.SFTTrainer`] class handles all the heavy lifting of creating the PEFT model using the peft config that is passed. After that, when you call `trainer.train()`, [`~trl.SFTTrainer`] internally uses 🤗 Accelerate to prepare the model, optimizer and trainer using the DeepSpeed config to create DeepSpeed engine which is then trained. The main code snippet is below:
 
 ```python
 # trainer
@@ -139,13 +139,6 @@ trainer = SFTTrainer(
     train_dataset=train_dataset,
     eval_dataset=eval_dataset,
     peft_config=peft_config,
-    packing=data_args.packing,
-    dataset_kwargs={
-        "append_concat_token": data_args.append_concat_token,
-        "add_special_tokens": data_args.add_special_tokens,
-    },
-    dataset_text_field=data_args.dataset_text_field,
-    max_seq_length=data_args.max_seq_length,
 )
 trainer.accelerator.print(f"{trainer.model}")
 

docs/source/accelerate/fsdp.md

Lines changed: 1 addition & 8 deletions
@@ -108,7 +108,7 @@ Notice that we are using LoRA with rank=8, alpha=16 and targeting all linear la
 
 Let's dive a little deeper into the script so you can see what's going on, and understand how it works.
 
-The first thing to know is that the script uses FSDP for distributed training as the FSDP config has been passed. The `SFTTrainer` class handles all the heavy lifting of creating PEFT model using the peft config that is passed. After that when you call `trainer.train()`, Trainer internally uses 🤗 Accelerate to prepare model, optimizer and trainer using the FSDP config to create FSDP wrapped model which is then trained. The main code snippet is below:
+The first thing to know is that the script uses FSDP for distributed training as the FSDP config has been passed. The [`~trl.SFTTrainer`] class handles all the heavy lifting of creating PEFT model using the peft config that is passed. After that when you call `trainer.train()`, Trainer internally uses 🤗 Accelerate to prepare model, optimizer and trainer using the FSDP config to create FSDP wrapped model which is then trained. The main code snippet is below:
 
 ```python
 # trainer
@@ -119,13 +119,6 @@ trainer = SFTTrainer(
     train_dataset=train_dataset,
     eval_dataset=eval_dataset,
     peft_config=peft_config,
-    packing=data_args.packing,
-    dataset_kwargs={
-        "append_concat_token": data_args.append_concat_token,
-        "add_special_tokens": data_args.add_special_tokens,
-    },
-    dataset_text_field=data_args.dataset_text_field,
-    max_seq_length=data_args.max_seq_length,
 )
 trainer.accelerator.print(f"{trainer.model}")
 if model_args.use_peft_lora:

examples/olora_finetuning/README.md

Lines changed: 2 additions & 3 deletions
@@ -8,7 +8,7 @@
 import torch
 from peft import LoraConfig, get_peft_model
 from transformers import AutoTokenizer, AutoModelForCausalLM
-from trl import SFTTrainer
+from trl import SFTConfig, SFTTrainer
 from datasets import load_dataset
 
 model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m", torch_dtype=torch.bfloat16, device_map="auto")
@@ -18,11 +18,10 @@ lora_config = LoraConfig(
     init_lora_weights="olora"
 )
 peft_model = get_peft_model(model, lora_config)
+training_args = SFTConfig(dataset_text_field="text", max_seq_length=128)
 trainer = SFTTrainer(
     model=peft_model,
     train_dataset=dataset,
-    dataset_text_field="text",
-    max_seq_length=512,
     tokenizer=tokenizer,
 )
 trainer.train()
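
Note that the `SFTConfig` created above only takes effect once it is handed to the trainer through `args`, as the PiSSA example below does. A minimal sketch of the wired-up call, reusing the names from the snippet above:

```python
trainer = SFTTrainer(
    model=peft_model,
    args=training_args,     # pass the SFTConfig so dataset_text_field/max_seq_length apply
    train_dataset=dataset,
    tokenizer=tokenizer,
)
trainer.train()
```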

examples/pissa_finetuning/README.md

Lines changed: 3 additions & 4 deletions
@@ -6,8 +6,7 @@ PiSSA represents a matrix $W\in\mathbb{R}^{m\times n}$ within the model by the p
 ```python
 import torch
 from peft import LoraConfig, get_peft_model
-from transformers import AutoTokenizer, AutoModelForCausalLM
-from trl import SFTTrainer
+from transformers import AutoTokenizer, AutoModelForCausalLMfrom trl import SFTConfig, SFTTrainer
 from datasets import load_dataset
 
 model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", torch_dtype=torch.bfloat16, device_map="auto")
@@ -23,11 +22,11 @@ peft_model.print_trainable_parameters()
 
 dataset = load_dataset("imdb", split="train[:1%]")
 
+training_args = SFTConfig(dataset_text_field="text", max_seq_length=128)
 trainer = SFTTrainer(
     model=peft_model,
+    args=training_args,
     train_dataset=dataset,
-    dataset_text_field="text",
-    max_seq_length=128,
     tokenizer=tokenizer,
 )
 trainer.train()

examples/pissa_finetuning/pissa_finetuning.py

Lines changed: 5 additions & 12 deletions
@@ -14,18 +14,18 @@
 
 import os
 from dataclasses import dataclass, field
-from typing import List, Optional
+from typing import Optional
 
 import torch
 from datasets import load_dataset
-from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, HfArgumentParser, TrainingArguments
-from trl import SFTTrainer
+from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, HfArgumentParser
+from trl import SFTConfig, SFTTrainer
 
 from peft import LoraConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training
 
 
 @dataclass
-class TrainingArguments(TrainingArguments):
+class ScriptArguments(SFTConfig):
     # model configs
     base_model_name_or_path: Optional[str] = field(
         default=None, metadata={"help": "The name or path of the fp32/16 base model."}
@@ -46,14 +46,9 @@ class TrainingArguments(TrainingArguments):
     # dataset configs
     data_path: str = field(default="imdb", metadata={"help": "Path to the training data."})
     dataset_split: str = field(default="train[:1%]", metadata={"help": "(`['train', 'test', 'eval']`):"})
-    dataset_field: List[str] = field(default=None, metadata={"help": "Fields of dataset input and output."})
-    max_seq_length: int = field(
-        default=512,
-        metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
-    )
 
 
-parser = HfArgumentParser(TrainingArguments)
+parser = HfArgumentParser(ScriptArguments)
 script_args = parser.parse_args_into_dataclasses()[0]
 print(script_args)
 
@@ -133,8 +128,6 @@ class TrainingArguments(TrainingArguments):
     model=peft_model,
     args=script_args,
     train_dataset=dataset,
-    dataset_text_field="text",
-    max_seq_length=script_args.max_seq_length,
     tokenizer=tokenizer,
 )
 trainer.train()
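
The pattern used here, a script-specific dataclass that extends `SFTConfig`, lets a single `HfArgumentParser` expose both the custom fields and all of the `SFTConfig`/`TrainingArguments` fields as command-line flags. A minimal, self-contained sketch of the idea, borrowing the `base_model_name_or_path` field from this diff:

```python
from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser
from trl import SFTConfig


@dataclass
class ScriptArguments(SFTConfig):
    # Script-specific option layered on top of the regular SFTConfig fields.
    base_model_name_or_path: Optional[str] = field(
        default=None, metadata={"help": "The name or path of the fp32/16 base model."}
    )


if __name__ == "__main__":
    # Parses e.g.: --output_dir out --max_seq_length 512 --base_model_name_or_path meta-llama/Llama-2-7b-hf
    parser = HfArgumentParser(ScriptArguments)
    script_args = parser.parse_args_into_dataclasses()[0]
    print(script_args.base_model_name_or_path, script_args.max_seq_length)
```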

examples/sft/train.py

Lines changed: 8 additions & 16 deletions
@@ -3,8 +3,8 @@
 from dataclasses import dataclass, field
 from typing import Optional
 
-from transformers import HfArgumentParser, TrainingArguments, set_seed
-from trl import SFTTrainer
+from transformers import HfArgumentParser, set_seed
+from trl import SFTConfig, SFTTrainer
 from utils import create_and_prepare_model, create_datasets
 
 
@@ -79,12 +79,6 @@ class DataTrainingArguments:
         default="timdettmers/openassistant-guanaco",
         metadata={"help": "The preference dataset to use."},
     )
-    packing: Optional[bool] = field(
-        default=False,
-        metadata={"help": "Use packing dataset creating."},
-    )
-    dataset_text_field: str = field(default="text", metadata={"help": "Dataset field to use as input text."})
-    max_seq_length: Optional[int] = field(default=512)
     append_concat_token: Optional[bool] = field(
         default=False,
         metadata={"help": "If True, appends `eos_token_id` at the end of each sample being packed."},
@@ -112,6 +106,11 @@ def main(model_args, data_args, training_args):
     if training_args.gradient_checkpointing:
         training_args.gradient_checkpointing_kwargs = {"use_reentrant": model_args.use_reentrant}
 
+    training_args.dataset_kwargs = {
+        "append_concat_token": data_args.append_concat_token,
+        "add_special_tokens": data_args.add_special_tokens,
+    }
+
     # datasets
     train_dataset, eval_dataset = create_datasets(
         tokenizer,
@@ -128,13 +127,6 @@ def main(model_args, data_args, training_args):
         train_dataset=train_dataset,
         eval_dataset=eval_dataset,
         peft_config=peft_config,
-        packing=data_args.packing,
-        dataset_kwargs={
-            "append_concat_token": data_args.append_concat_token,
-            "add_special_tokens": data_args.add_special_tokens,
-        },
-        dataset_text_field=data_args.dataset_text_field,
-        max_seq_length=data_args.max_seq_length,
     )
     trainer.accelerator.print(f"{trainer.model}")
     if hasattr(trainer.model, "print_trainable_parameters"):
@@ -153,7 +145,7 @@ def main(model_args, data_args, training_args):
 
 
 if __name__ == "__main__":
-    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
+    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, SFTConfig))
     if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
         # If we pass only one argument to the script and it's the path to a json file,
         # let's parse it to get our arguments.
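
With `SFTConfig` now parsed directly, `packing`, `dataset_text_field` and `max_seq_length` are ordinary config fields supplied on the command line rather than members of `DataTrainingArguments`, and the `dataset_kwargs` assignment above could equally be done at construction time. A small sketch of an equivalent config (output directory and values are placeholders that mirror the old defaults in this script):

```python
from trl import SFTConfig

# Equivalent to parsing these options from the command line and then setting
# training_args.dataset_kwargs the way the updated train.py does.
training_args = SFTConfig(
    output_dir="sft-output",           # placeholder
    packing=False,
    dataset_text_field="text",
    max_seq_length=512,
    dataset_kwargs={
        "append_concat_token": False,  # same keys as data_args in this script; values illustrative
        "add_special_tokens": False,
    },
)
```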
