
Commit 0dac2dd

Llama4 linearized (axolotl-ai-cloud#2502)

* llama4 support for linearized experts
* clean up fsdp2 sharding to prevent hang
* add yaml config
* cleanup example [skip ci]

1 parent a6c0321 commit 0dac2dd

File tree: 10 files changed (+386, -65 lines)

File renamed without changes.
Lines changed: 93 additions & 0 deletions
@@ -0,0 +1,93 @@
base_model: axolotl-quants/Llama-4-Scout-17B-16E-Linearized-bnb-nf4-bf16
model_type: Llama4ForConditionalGeneration
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name

strict: false

# torch_compile: true
plugins:
  - axolotl.integrations.liger.LigerPlugin

liger_glu_activation: true
liger_rms_norm: true
liger_layer_norm: true

llama4_linearized_experts: true
load_in_4bit: true
adapter: qlora
lora_r: 32
lora_alpha: 64
lora_target_modules:
  - self_attn.q_proj
  - self_attn.k_proj
  - self_attn.v_proj
  - self_attn.o_proj
  - shared_expert.gate_proj
  - shared_expert.up_proj
  - shared_expert.down_proj
  # - experts.gate_projs.[0-9]+$
  # - experts.up_projs.[0-9]+$
  # - experts.down_projs.[0-9]+$
lora_modules_to_save:
  - lm_head
  - embed_tokens

chat_template: llama4
datasets:
  - path: mlabonne/FineTome-100k
    type: chat_template
    split: train[:20%]
    field_messages: conversations
    message_property_mappings:
      role: from
      content: value

dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out

sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 2e-5

bf16: true
tf32: true

logging_steps: 1
flash_attention: true

warmup_steps: 100
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
fsdp:
  - auto_wrap
  - full_shard
fsdp_config:
  fsdp_transformer_layer_cls_to_wrap: Llama4TextDecoderLayer
  fsdp_limit_all_gathers: true
  fsdp_sync_module_states: true
  fsdp_offload_params: true
  fsdp_use_orig_params: false
  fsdp_cpu_ram_efficient_loading: true
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_activation_checkpointing: true
special_tokens:
  pad_token: <|finetune_right_pad_id|>
  eos_token: <|eot|>
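
Note on the commented-out lora_target_modules entries: the [0-9]+$ suffix is a regular-expression-style pattern. With llama4_linearized_experts enabled, each routed expert exposes its own numbered gate/up/down projection, so targeting them means matching the index suffix. A minimal, purely illustrative sketch of how such patterns match (the module names below are hypothetical, and how axolotl/PEFT actually resolves these entries is not shown in this diff):

import re

# Hypothetical module names for a linearized Llama 4 layer; the exact prefixes
# are an assumption for illustration only.
module_names = [
    "model.layers.0.feedforward.shared_expert.gate_proj",
    "model.layers.0.feedforward.experts.gate_projs.0",
    "model.layers.0.feedforward.experts.gate_projs.15",
    "model.layers.0.feedforward.experts.up_projs.3",
]

# Patterns mirroring the commented-out per-expert targets in the config above.
patterns = [r"experts\.gate_projs\.[0-9]+$", r"experts\.up_projs\.[0-9]+$"]

for name in module_names:
    hit = any(re.search(pattern, name) for pattern in patterns)
    print(f"{name}: {'would be targeted' if hit else 'not matched'}")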

requirements-dev.txt

Lines changed: 2 additions & 0 deletions
@@ -4,3 +4,5 @@ mypy
 types-requests
 quartodoc
 jupyter
+blobfile
+tiktoken

src/axolotl/monkeypatch/accelerate/__init__.py

Whitespace-only changes.
Lines changed: 63 additions & 0 deletions
@@ -0,0 +1,63 @@
"""
Monkeypatch for an accelerate FSDP2 fix when modifying an OrderedDict during iteration.
"""

import logging
import sys

import torch

LOG = logging.getLogger(__name__)


def fsdp2_load_full_state_dict(accelerator, model: torch.nn.Module, full_sd: dict):
    """
    Loads the full state dict (could be only on rank 0) into the sharded model. This is done by broadcasting the
    parameters from rank 0 to all other ranks. This function modifies the model in-place.

    Args:
        accelerator (`Accelerator`): The accelerator instance
        model (`torch.nn.Module`): The model to load the state dict into
        full_sd (`dict`): The full state dict to load, can only be on rank 0
    """
    import torch.distributed as dist
    from torch.distributed.tensor import distribute_tensor

    LOG.info("Broadcasting full state dict to all ranks...")
    sharded_sd = model.state_dict()
    # Iterate over a fixed, sorted list of keys so the dict is not mutated while
    # being iterated and all ranks walk the parameters in the same order.
    param_names = sorted(sharded_sd.keys())
    for param_name in param_names:
        mesh = sharded_sd[param_name].device_mesh
        if accelerator.is_main_process:
            # Use the corresponding tensor from full_sd (assuming the key exists in full_sd)
            full_param = full_sd[param_name].detach().cuda()
            dist.broadcast(full_param, src=0, group=mesh.get_group())
            sharded_tensor = distribute_tensor(
                full_param, mesh, sharded_sd[param_name].placements
            )
            sharded_sd[param_name] = sharded_tensor
        else:
            # Prepare a tensor of matching shape and dtype
            full_tensor = torch.empty(
                sharded_sd[param_name].size(),
                device="cuda",
                dtype=sharded_sd[param_name].dtype,
            )
            dist.broadcast(full_tensor, src=0, group=mesh.get_group())
            sharded_tensor = distribute_tensor(
                full_tensor, mesh, sharded_sd[param_name].placements
            )
            sharded_sd[param_name] = sharded_tensor

    model.load_state_dict(sharded_sd)


def patch_accelerate_fsdp_utils():
    from accelerate.utils import fsdp_utils

    fsdp_utils.fsdp2_load_full_state_dict = fsdp2_load_full_state_dict
    setattr(
        sys.modules["accelerate.utils.fsdp_utils"],
        "fsdp2_load_full_state_dict",
        fsdp2_load_full_state_dict,
    )
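
The patch above only takes effect once patch_accelerate_fsdp_utils() has been called before accelerate's FSDP2 full-state-dict loading path runs. A minimal sketch of applying it (the import path and call site are assumptions; where axolotl actually invokes the patch is not shown in this diff):

# Hypothetical call site: apply the monkeypatch before accelerate loads the full
# state dict into the FSDP2-sharded model. The module path below is assumed, not
# taken from this diff.
from accelerate import Accelerator

from axolotl.monkeypatch.accelerate.fsdp2 import patch_accelerate_fsdp_utils  # hypothetical path

patch_accelerate_fsdp_utils()  # replaces accelerate.utils.fsdp_utils.fsdp2_load_full_state_dict
accelerator = Accelerator()    # subsequent FSDP2 full-state-dict loads use the patched function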

src/axolotl/monkeypatch/lora_kernels.py

Lines changed: 113 additions & 65 deletions
@@ -4,7 +4,7 @@
 import inspect
 import logging
 import types
-from typing import Type
+from typing import Generator, Tuple, Type

 import torch
 from accelerate.logging import get_logger
@@ -200,6 +200,46 @@ def patch_self_attn_lora(cfg: DictDefault):
     )


+def find_self_attn_in_layer(
+    layer: nn.Module,
+) -> Generator[Tuple[nn.Module], None, None]:
+    # general case of most models
+    if hasattr(layer, "self_attn"):
+        if all(
+            hasattr(layer.self_attn, proj)
+            for proj in ["q_proj", "k_proj", "v_proj", "o_proj"]
+        ):
+            yield layer.self_attn
+
+
+def find_mlp_in_layer(
+    layer: nn.Module,
+) -> Generator[Tuple[nn.Module, nn.Module, nn.Module, nn.Module], None, None]:
+    # general case of most models
+    if hasattr(layer, "mlp"):
+        if all(
+            hasattr(layer.mlp, proj) for proj in ["gate_proj", "up_proj", "down_proj"]
+        ):
+            yield layer.mlp.gate_proj, layer.mlp.up_proj, layer.mlp.down_proj, layer.mlp
+    # llama4 linearized experts
+    if hasattr(layer, "feedforward") and hasattr(layer.feedforward, "shared_expert"):
+        mlp = layer.feedforward.shared_expert
+        yield mlp.gate_proj, mlp.up_proj, mlp.down_proj, mlp
+    if hasattr(layer, "feedforward") and hasattr(layer.feedforward, "experts"):
+        if all(
+            hasattr(layer.feedforward.experts, proj)
+            for proj in ["gate_projs", "up_projs", "down_projs"]
+        ):
+            for gate_proj, up_proj, down_proj in zip(
+                layer.feedforward.experts.gate_projs,
+                layer.feedforward.experts.up_projs,
+                layer.feedforward.experts.down_projs,
+            ):
+                yield gate_proj, up_proj, down_proj, FakeMLP(
+                    gate_proj, up_proj, down_proj
+                )
+
+
 def apply_lora_kernel_patches(
     model: PeftModelForCausalLM, cfg: DictDefault
 ) -> PeftModelForCausalLM:
@@ -286,74 +326,82 @@ def apply_lora_kernel_patches(
     for layer in layers:
         # Add QKV, O fallback implementations to start
         # These will be overwritten later (if some conditions apply)
-        layer.self_attn.apply_qkv = types.MethodType(
-            original_apply_qkv, layer.self_attn
-        )
-        layer.self_attn.apply_o = types.MethodType(original_apply_o, layer.self_attn)
-
-        if cfg.lora_mlp_kernel:
-            # MLP patching
-            gate_proj = layer.mlp.gate_proj
-            up_proj = layer.mlp.up_proj
-            down_proj = layer.mlp.down_proj
-
-            can_patch_mlp = all(
-                hasattr(proj, "lora_A")
-                and getattr(proj, "base_layer", proj).bias is None
-                and len(getattr(proj, "lora_magnitude_vector", []) or []) == 0
-                for proj in (gate_proj, up_proj, down_proj)
-            )
-
-            if can_patch_mlp:
-                apply_fn = APPLY_FN_MAPPING[activation]
-                layer.mlp.forward = types.MethodType(apply_fn, layer.mlp)
-            else:
-                LOG.warning_once(
-                    "Cannot patch some MLP layers - requires LoRA adapters with no bias"
+        for self_attn in find_self_attn_in_layer(layer):
+            self_attn.apply_qkv = types.MethodType(original_apply_qkv, self_attn)
+            self_attn.apply_o = types.MethodType(original_apply_o, self_attn)
+
+            if cfg.lora_qkv_kernel:
+                # Query, key, value patching
+                layer_modules = [
+                    getattr(self_attn, linear_proj)
+                    for linear_proj in ["q_proj", "k_proj", "v_proj"]
+                ]
+                can_patch_qkv = all(
+                    hasattr(module, "lora_A")
+                    and getattr(module, "base_layer", module).bias is None
+                    and len(getattr(module, "lora_magnitude_vector", []) or []) == 0
+                    for module in layer_modules
                 )
-        if cfg.lora_qkv_kernel:
-            # Query, key, value patching
-            layer_modules = [
-                getattr(layer.self_attn, linear_proj)
-                for linear_proj in ["q_proj", "k_proj", "v_proj"]
-            ]
-            can_patch_qkv = all(
-                hasattr(module, "lora_A")
-                and getattr(module, "base_layer", module).bias is None
-                and len(getattr(module, "lora_magnitude_vector", []) or []) == 0
-                for module in layer_modules
-            )
-
-            if can_patch_qkv:
-                # Add optimized implementation
-                layer.self_attn.apply_qkv = types.MethodType(
-                    apply_lora_qkv, layer.self_attn
-                )
-            else:
-                LOG.warning_once(
-                    "Cannot patch some attention QKV projections - requires LoRA adapters with no bias"
-                )
-        if cfg.lora_o_kernel:
-            # Output patching
-            layer_modules = [
-                getattr(layer.self_attn, linear_proj) for linear_proj in ["o_proj"]
-            ]
-            can_patch_o = all(
-                hasattr(module, "lora_A")
-                and getattr(module, "base_layer", module).bias is None
-                and len(getattr(module, "lora_magnitude_vector", []) or []) == 0
-                for module in layer_modules
-            )
-
-            if can_patch_o:
-                layer.self_attn.apply_o = types.MethodType(
-                    apply_lora_o, layer.self_attn
+
+                if can_patch_qkv:
+                    # Add optimized implementation
+                    self_attn.apply_qkv = types.MethodType(apply_lora_qkv, self_attn)
+                else:
+                    LOG.warning_once(
+                        "Cannot patch some attention QKV projections - requires LoRA adapters with no bias"
+                    )
+            if cfg.lora_o_kernel:
+                # Output patching
+                layer_modules = [
+                    getattr(self_attn, linear_proj) for linear_proj in ["o_proj"]
+                ]
+                can_patch_o = all(
+                    hasattr(module, "lora_A")
+                    and getattr(module, "base_layer", module).bias is None
+                    and len(getattr(module, "lora_magnitude_vector", []) or []) == 0
+                    for module in layer_modules
                 )
-            else:
-                LOG.warning_once(
-                    "Cannot patch some attention output projection - requires LoRA adapters with no bias"
+
+                if can_patch_o:
+                    self_attn.apply_o = types.MethodType(apply_lora_o, self_attn)
+                else:
+                    LOG.warning_once(
+                        "Cannot patch some attention output projection - requires LoRA adapters with no bias"
+                    )
+        for gate_proj, up_proj, down_proj, mlp in find_mlp_in_layer(layer):
+            if cfg.lora_mlp_kernel:
+                # MLP patching
+                can_patch_mlp = all(
+                    hasattr(proj, "lora_A")
+                    and getattr(proj, "base_layer", proj).bias is None
+                    and len(getattr(proj, "lora_magnitude_vector", []) or []) == 0
+                    for proj in (gate_proj, up_proj, down_proj)
                 )

+                if can_patch_mlp:
+                    apply_fn = APPLY_FN_MAPPING[activation]
+                    layer.mlp.forward = types.MethodType(apply_fn, mlp)
+                else:
+                    LOG.warning_once(
+                        "Cannot patch some MLP layers - requires LoRA adapters with no bias"
+                    )
+
     LOG.setLevel(original_level)

     return model
+
+
+class FakeMLP(nn.Module):
+    """
+    placeholder MLP for triton patching
+    """
+
+    gate_proj: nn.Linear
+    up_proj: nn.Linear
+    down_proj: nn.Linear
+
+    def __init__(self, gate_proj, up_proj, down_proj):
+        super().__init__()
+        self.gate_proj = gate_proj
+        self.up_proj = up_proj
+        self.down_proj = down_proj
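
To make the new helpers concrete, here is a minimal, self-contained sketch of what find_mlp_in_layer yields for a layer with linearized experts. The toy modules only mirror the attribute layout the generator checks for (feedforward.shared_expert and feedforward.experts.gate_projs/up_projs/down_projs); the shapes and names are made up, and the import assumes an environment where axolotl with this change (and its kernel dependencies) imports cleanly.

import torch.nn as nn

from axolotl.monkeypatch.lora_kernels import find_mlp_in_layer


class ToyExperts(nn.Module):
    """Linearized experts: one nn.Linear per routed expert, per projection."""

    def __init__(self, num_experts=2, hidden=8, intermediate=16):
        super().__init__()
        self.gate_projs = nn.ModuleList(nn.Linear(hidden, intermediate, bias=False) for _ in range(num_experts))
        self.up_projs = nn.ModuleList(nn.Linear(hidden, intermediate, bias=False) for _ in range(num_experts))
        self.down_projs = nn.ModuleList(nn.Linear(intermediate, hidden, bias=False) for _ in range(num_experts))


class ToyLayer(nn.Module):
    """Toy decoder layer exposing only the attributes find_mlp_in_layer looks for."""

    def __init__(self):
        super().__init__()
        self.feedforward = nn.Module()
        self.feedforward.shared_expert = nn.Module()
        self.feedforward.shared_expert.gate_proj = nn.Linear(8, 16, bias=False)
        self.feedforward.shared_expert.up_proj = nn.Linear(8, 16, bias=False)
        self.feedforward.shared_expert.down_proj = nn.Linear(16, 8, bias=False)
        self.feedforward.experts = ToyExperts()


layer = ToyLayer()
for gate_proj, up_proj, down_proj, mlp in find_mlp_in_layer(layer):
    # The shared expert is yielded directly; each routed expert is wrapped in a
    # FakeMLP so kernel patching sees a uniform gate/up/down interface.
    print(type(mlp).__name__, tuple(gate_proj.weight.shape))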

src/axolotl/monkeypatch/models/llama4/__init__.py

Whitespace-only changes.
