
Commit 6c054d0

Method comparison: Support more options for the optimizer (#2479)

Allow setting a different optimizer, including PEFT-specific ones like LoRA+. Add an experiment for LoRA-FA, update the parameter name, and remove obsolete directories.

1 parent eb5e9bc · commit 6c054d0

File tree

8 files changed: +101 −29 lines


method_comparison/MetaMathQA/default_training_params.json

Lines changed: 1 addition & 0 deletions

```diff
@@ -9,6 +9,7 @@
     "compile": false,
     "seed": 0,
     "grad_norm_clip": 1.0,
+    "optimizer_type": "AdamW",
     "optimizer_kwargs": {
         "lr": 1e-4,
         "weight_decay": 0.1
```
New file: LoRA adapter config for the LoRA-FA experiment (path not shown)

Lines changed: 30 additions & 0 deletions

```diff
@@ -0,0 +1,30 @@
+{
+    "alpha_pattern": {},
+    "auto_mapping": null,
+    "base_model_name_or_path": null,
+    "bias": "none",
+    "corda_config": null,
+    "eva_config": null,
+    "exclude_modules": null,
+    "fan_in_fan_out": false,
+    "inference_mode": false,
+    "init_lora_weights": true,
+    "layer_replication": null,
+    "layers_pattern": null,
+    "layers_to_transform": null,
+    "loftq_config": {},
+    "lora_alpha": 64,
+    "lora_bias": false,
+    "lora_dropout": 0.0,
+    "megatron_config": null,
+    "megatron_core": "megatron.core",
+    "modules_to_save": null,
+    "peft_type": "LORA",
+    "r": 32,
+    "rank_pattern": {},
+    "revision": null,
+    "target_modules": null,
+    "task_type": "CAUSAL_LM",
+    "use_dora": false,
+    "use_rslora": false
+}
```
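This is a standard serialized LoraConfig with rank 32. A minimal in-code equivalent, as a sketch only (not part of the commit; unspecified fields fall back to LoraConfig defaults):

```python
from peft import LoraConfig

# Sketch of the in-code equivalent of the adapter config above; fields not
# listed here keep their LoraConfig defaults.
config = LoraConfig(
    r=32,
    lora_alpha=64,
    lora_dropout=0.0,
    bias="none",
    task_type="CAUSAL_LM",
)
```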
New file: training parameters for the LoRA-FA experiment (path not shown)

Lines changed: 9 additions & 0 deletions

```diff
@@ -0,0 +1,9 @@
+{
+    "optimizer_type": "lora-fa",
+    "optimizer_kwargs": {
+        "r": 32,
+        "lora_alpha": 64,
+        "lr": 1e-4,
+        "weight_decay": 0.1
+    }
+}
```
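Because `optimizer_kwargs` are forwarded verbatim to the PEFT optimizer factory (see `get_optimizer_and_scheduler` in utils.py below), this experiment config amounts to the following call; `model` stands for the PeftModel built from the adapter config above:

```python
from peft.optimizers import create_lorafa_optimizer

# What the "lora-fa" experiment config above resolves to (sketch; `model` is
# the PeftModel built from the adapter config above).
optimizer = create_lorafa_optimizer(model, r=32, lora_alpha=64, lr=1e-4, weight_decay=0.1)
```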

method_comparison/MetaMathQA/run.py

Lines changed: 18 additions & 29 deletions

```diff
@@ -34,7 +34,7 @@
 from torch import nn
 from torch.amp import GradScaler, autocast
 from tqdm import tqdm
-from transformers import GenerationConfig, get_cosine_schedule_with_warmup, set_seed
+from transformers import GenerationConfig, set_seed
 from utils import (
     FILE_NAME_TRAIN_PARAMS,
     BucketIterator,
@@ -44,6 +44,7 @@
     get_base_model_info,
     get_dataset_info,
     get_model,
+    get_optimizer_and_scheduler,
     get_tokenizer,
     get_train_config,
     init_cuda,
@@ -63,7 +64,6 @@

 dtype_to_bytes_linear = {"float32": 4, "float16": 2, "bfloat16": 2, "int8": 1, "int4": 0.5}
 # if lr scheduler with warmup is used, the ratio of warmup steps to total steps
-WARMUP_STEP_RATIO = 0.1
 BUCKET_FACTOR = 20  # number of batches per bucket, increasing this further has diminishing returns


@@ -98,18 +98,6 @@ def evaluate(model, tokenizer, ds, batch_size, generate_kwargs, use_tqdm: bool =
     return predictions, responses


-class DummyScheduler:
-    # if no lr scheduler is being used
-    def __init__(self, lr):
-        self.lr = lr
-
-    def get_last_lr(self):
-        return [self.lr]
-
-    def step(self):
-        pass
-
-
 class DummyGradScaler:
     # if no mixed precision is being used
     def scale(self, loss):
@@ -136,6 +124,7 @@ def train(
     eval_steps: int,
     generation_kwargs: dict[str, Any],
     grad_norm_clip: float,
+    optimizer_type: str,
     optimizer_kwargs: dict[str, Any],
     query_template: str,
     lr_scheduler_arg: Optional[Literal["cosine"]],
@@ -156,16 +145,20 @@ def train(
     else:
         grad_scaler = DummyGradScaler()
         autocast_ctx = nullcontext
-    optimizer = torch.optim.AdamW(model.parameters(), **optimizer_kwargs)
-    if lr_scheduler_arg == "cosine":
-        warmup_steps = int(WARMUP_STEP_RATIO * max_steps)
-        lr_scheduler = get_cosine_schedule_with_warmup(
-            optimizer, num_warmup_steps=warmup_steps, num_training_steps=max_steps
-        )
-    elif lr_scheduler_arg is None:
-        lr_scheduler = DummyScheduler(optimizer_kwargs["lr"])
-    else:
-        raise ValueError(f"Invalid lr_scheduler argument: {lr_scheduler_arg}")
+
+    optimizer, lr_scheduler = get_optimizer_and_scheduler(
+        model,
+        optimizer_type=optimizer_type,
+        max_steps=max_steps,
+        lr_scheduler_arg=lr_scheduler_arg,
+        **optimizer_kwargs,
+    )
+    # print this after getting the optimizer, in case it modifies requires_grad
+    num_trainable_params, num_params = model.get_nb_trainable_parameters()
+    print_verbose(
+        f"trainable params: {num_trainable_params:,d} || all params: {num_params:,d} || "
+        f"trainable: {100 * num_trainable_params / num_params:.4f}%"
+    )

     status = TrainStatus.FAILED
     tic_train = time.perf_counter()
@@ -371,11 +364,6 @@ def main(*, path_experiment: str, experiment_name: str, clean: bool) -> None:
         autocast_adapter_dtype=train_config.autocast_adapter_dtype,
     )
     print_verbose(model)
-    num_trainable_params, num_params = model.get_nb_trainable_parameters()
-    print_verbose(
-        f"trainable params: {num_trainable_params:,d} || all params: {num_params:,d} || "
-        f"trainable: {100 * num_trainable_params / num_params:.4f}%"
-    )

     # train model
     try:
@@ -389,6 +377,7 @@ def main(*, path_experiment: str, experiment_name: str, clean: bool) -> None:
         eval_steps=train_config.eval_steps,
         generation_kwargs=train_config.generation_kwargs,
         grad_norm_clip=train_config.grad_norm_clip,
+        optimizer_type=train_config.optimizer_type,
         optimizer_kwargs=train_config.optimizer_kwargs,
         query_template=train_config.query_template,
         lr_scheduler_arg=train_config.lr_scheduler,
```
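Note the relocated trainable-parameter printout: it now runs after the optimizer is created because, as the added comment says, some optimizers modify `requires_grad`. LoRA-FA, for instance, trains only the `lora_B` matrices and leaves `lora_A` frozen. A hypothetical sketch of why the ordering matters (not part of the commit; `model` is assumed to be the PeftModel):

```python
# Hypothetical sketch: counting before optimizer creation can overstate the
# trainable parameters for optimizers such as LoRA-FA, which freeze lora_A.
def count_trainable(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

before = count_trainable(model)  # counts lora_A and lora_B
optimizer = create_lorafa_optimizer(model, r=32, lora_alpha=64, lr=1e-4)
after = count_trainable(model)   # lora_A may now be frozen, so after <= before
```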

method_comparison/MetaMathQA/utils.py

Lines changed: 43 additions & 0 deletions

```diff
@@ -37,10 +37,12 @@
     AutoModelForCausalLM,
     AutoTokenizer,
     BitsAndBytesConfig,
+    get_cosine_schedule_with_warmup,
 )

 import peft
 from peft import PeftConfig, get_peft_model, prepare_model_for_kbit_training
+from peft.optimizers import create_lorafa_optimizer, create_loraplus_optimizer
 from peft.utils import CONFIG_NAME


@@ -58,6 +60,7 @@
 # cancelled results
 RESULT_PATH_CANCELLED = os.path.join(os.path.dirname(__file__), "cancelled_results")
 hf_api = huggingface_hub.HfApi()
+WARMUP_STEP_RATIO = 0.1


 @dataclass
@@ -76,6 +79,7 @@ class TrainConfig:
     query_template: The template for the query
     seed: The random seed
     grad_norm_clip: The gradient norm clipping value (set to 0 to skip)
+    optimizer_type: The name of a torch optimizer (e.g. AdamW) or a PEFT method ("lora+", "lora-fa")
     optimizer_kwargs: The optimizer keyword arguments (lr etc.)
     lr_scheduler: The learning rate scheduler (currently only None or 'cosine' are supported)
     use_amp: Whether to use automatic mixed precision
@@ -95,6 +99,7 @@ class TrainConfig:
     query_template: str
     seed: int
     grad_norm_clip: float  # set to 0 to skip
+    optimizer_type: str
     optimizer_kwargs: dict[str, Any]
     lr_scheduler: Optional[Literal["cosine"]]
     use_amp: bool
@@ -121,6 +126,8 @@ def __post_init__(self) -> None:
             raise ValueError(f"Invalid eval_steps: {self.eval_steps} > max_steps: {self.max_steps}")
         if self.grad_norm_clip < 0:
             raise ValueError(f"Invalid grad_norm_clip: {self.grad_norm_clip}")
+        if self.optimizer_type not in ["lora+", "lora-fa"] and not hasattr(torch.optim, self.optimizer_type):
+            raise ValueError(f"Invalid optimizer_type: {self.optimizer_type}")
         if self.lr_scheduler not in [None, "cosine"]:
             raise ValueError(f"Invalid lr_scheduler: {self.lr_scheduler}, must be None or 'cosine'")
         if "{query}" not in self.query_template:
@@ -246,6 +253,42 @@ def get_model(
     return model


+class DummyScheduler:
+    # if no lr scheduler is being used
+    def __init__(self, lr):
+        self.lr = lr
+
+    def get_last_lr(self):
+        return [self.lr]
+
+    def step(self):
+        pass
+
+
+def get_optimizer_and_scheduler(
+    model, *, optimizer_type: str, max_steps: int, lr_scheduler_arg: Optional[Literal["cosine"]], **optimizer_kwargs
+) -> tuple[torch.optim.Optimizer, Any]:
+    if optimizer_type == "lora+":
+        optimizer = create_loraplus_optimizer(model, optimizer_cls=torch.optim.AdamW, **optimizer_kwargs)
+    elif optimizer_type == "lora-fa":
+        optimizer = create_lorafa_optimizer(model, **optimizer_kwargs)
+    else:
+        cls = getattr(torch.optim, optimizer_type)
+        optimizer = cls(model.parameters(), **optimizer_kwargs)
+
+    if lr_scheduler_arg == "cosine":
+        warmup_steps = int(WARMUP_STEP_RATIO * max_steps)
+        lr_scheduler = get_cosine_schedule_with_warmup(
+            optimizer, num_warmup_steps=warmup_steps, num_training_steps=max_steps
+        )
+    elif lr_scheduler_arg is None:
+        lr_scheduler = DummyScheduler(optimizer_kwargs["lr"])
+    else:
+        raise ValueError(f"Invalid lr_scheduler argument: {lr_scheduler_arg}")
+
+    return optimizer, lr_scheduler
+
+
 class BucketIterator:
     """
     Iterator that yields batches of data from a torch Dataset, grouped in buckets by sequence length
```
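With this helper, a plain torch optimizer name and a PEFT-specific one share a single entry point. A usage sketch based directly on the configs above (not part of the commit; `model` is assumed to be a PeftModel and `max_steps=1000` is an arbitrary placeholder):

```python
# Usage sketch (assumes `model` is a PeftModel; not part of the commit).
# Default config: a torch optimizer resolved by name via getattr(torch.optim, ...).
optimizer, lr_scheduler = get_optimizer_and_scheduler(
    model, optimizer_type="AdamW", max_steps=1000, lr_scheduler_arg="cosine",
    lr=1e-4, weight_decay=0.1,
)

# LoRA-FA experiment: kwargs are forwarded to create_lorafa_optimizer.
optimizer, lr_scheduler = get_optimizer_and_scheduler(
    model, optimizer_type="lora-fa", max_steps=1000, lr_scheduler_arg=None,
    r=32, lora_alpha=64, lr=1e-4, weight_decay=0.1,
)
```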

method_comparison/experiments/.gitkeep

Whitespace-only changes.

method_comparison/results/.gitkeep

Whitespace-only changes.

method_comparison/temporary_results/.gitkeep

Whitespace-only changes.
