@@ -4,7 +4,7 @@
 import inspect
 import logging
 import types
-from typing import Type
+from typing import Generator, Tuple, Type

 import torch
 from accelerate.logging import get_logger
@@ -200,6 +200,46 @@ def patch_self_attn_lora(cfg: DictDefault):
     )


+def find_self_attn_in_layer(
+    layer: nn.Module,
+) -> Generator[nn.Module, None, None]:
+    # general case of most models
+    if hasattr(layer, "self_attn"):
+        if all(
+            hasattr(layer.self_attn, proj)
+            for proj in ["q_proj", "k_proj", "v_proj", "o_proj"]
+        ):
+            yield layer.self_attn
+
+
+def find_mlp_in_layer(
+    layer: nn.Module,
+) -> Generator[Tuple[nn.Module, nn.Module, nn.Module, nn.Module], None, None]:
+    # general case of most models
+    if hasattr(layer, "mlp"):
+        if all(
+            hasattr(layer.mlp, proj) for proj in ["gate_proj", "up_proj", "down_proj"]
+        ):
+            yield layer.mlp.gate_proj, layer.mlp.up_proj, layer.mlp.down_proj, layer.mlp
+    # llama4 linearized experts
+    if hasattr(layer, "feedforward") and hasattr(layer.feedforward, "shared_expert"):
+        mlp = layer.feedforward.shared_expert
+        yield mlp.gate_proj, mlp.up_proj, mlp.down_proj, mlp
+    if hasattr(layer, "feedforward") and hasattr(layer.feedforward, "experts"):
+        if all(
+            hasattr(layer.feedforward.experts, proj)
+            for proj in ["gate_projs", "up_projs", "down_projs"]
+        ):
+            for gate_proj, up_proj, down_proj in zip(
+                layer.feedforward.experts.gate_projs,
+                layer.feedforward.experts.up_projs,
+                layer.feedforward.experts.down_projs,
+            ):
+                yield gate_proj, up_proj, down_proj, FakeMLP(
+                    gate_proj, up_proj, down_proj
+                )
+
+
 def apply_lora_kernel_patches(
     model: PeftModelForCausalLM, cfg: DictDefault
 ) -> PeftModelForCausalLM:
@@ -286,74 +326,82 @@ def apply_lora_kernel_patches(
     for layer in layers:
         # Add QKV, O fallback implementations to start
         # These will be overwritten later (if some conditions apply)
-        layer.self_attn.apply_qkv = types.MethodType(
-            original_apply_qkv, layer.self_attn
-        )
-        layer.self_attn.apply_o = types.MethodType(original_apply_o, layer.self_attn)
-
-        if cfg.lora_mlp_kernel:
-            # MLP patching
-            gate_proj = layer.mlp.gate_proj
-            up_proj = layer.mlp.up_proj
-            down_proj = layer.mlp.down_proj
-
-            can_patch_mlp = all(
-                hasattr(proj, "lora_A")
-                and getattr(proj, "base_layer", proj).bias is None
-                and len(getattr(proj, "lora_magnitude_vector", []) or []) == 0
-                for proj in (gate_proj, up_proj, down_proj)
-            )
-
-            if can_patch_mlp:
-                apply_fn = APPLY_FN_MAPPING[activation]
-                layer.mlp.forward = types.MethodType(apply_fn, layer.mlp)
-            else:
-                LOG.warning_once(
-                    "Cannot patch some MLP layers - requires LoRA adapters with no bias"
+        for self_attn in find_self_attn_in_layer(layer):
+            self_attn.apply_qkv = types.MethodType(original_apply_qkv, self_attn)
+            self_attn.apply_o = types.MethodType(original_apply_o, self_attn)
+
+            if cfg.lora_qkv_kernel:
+                # Query, key, value patching
+                layer_modules = [
+                    getattr(self_attn, linear_proj)
+                    for linear_proj in ["q_proj", "k_proj", "v_proj"]
+                ]
+                can_patch_qkv = all(
+                    hasattr(module, "lora_A")
+                    and getattr(module, "base_layer", module).bias is None
+                    and len(getattr(module, "lora_magnitude_vector", []) or []) == 0
+                    for module in layer_modules
                 )
-
-        if cfg.lora_qkv_kernel:
-            # Query, key, value patching
-            layer_modules = [
-                getattr(layer.self_attn, linear_proj)
-                for linear_proj in ["q_proj", "k_proj", "v_proj"]
-            ]
-            can_patch_qkv = all(
-                hasattr(module, "lora_A")
-                and getattr(module, "base_layer", module).bias is None
-                and len(getattr(module, "lora_magnitude_vector", []) or []) == 0
-                for module in layer_modules
-            )
-
-            if can_patch_qkv:
-                # Add optimized implementation
-                layer.self_attn.apply_qkv = types.MethodType(
-                    apply_lora_qkv, layer.self_attn
-                )
-            else:
-                LOG.warning_once(
-                    "Cannot patch some attention QKV projections - requires LoRA adapters with no bias"
-                )
-        if cfg.lora_o_kernel:
-            # Output patching
-            layer_modules = [
-                getattr(layer.self_attn, linear_proj) for linear_proj in ["o_proj"]
-            ]
-            can_patch_o = all(
-                hasattr(module, "lora_A")
-                and getattr(module, "base_layer", module).bias is None
-                and len(getattr(module, "lora_magnitude_vector", []) or []) == 0
-                for module in layer_modules
-            )
-
-            if can_patch_o:
-                layer.self_attn.apply_o = types.MethodType(
-                    apply_lora_o, layer.self_attn
+
+                if can_patch_qkv:
+                    # Add optimized implementation
+                    self_attn.apply_qkv = types.MethodType(apply_lora_qkv, self_attn)
+                else:
+                    LOG.warning_once(
+                        "Cannot patch some attention QKV projections - requires LoRA adapters with no bias"
+                    )
+            if cfg.lora_o_kernel:
+                # Output patching
+                layer_modules = [
+                    getattr(self_attn, linear_proj) for linear_proj in ["o_proj"]
+                ]
+                can_patch_o = all(
+                    hasattr(module, "lora_A")
+                    and getattr(module, "base_layer", module).bias is None
+                    and len(getattr(module, "lora_magnitude_vector", []) or []) == 0
+                    for module in layer_modules
                 )
-            else:
-                LOG.warning_once(
-                    "Cannot patch some attention output projection - requires LoRA adapters with no bias"
+
+                if can_patch_o:
+                    self_attn.apply_o = types.MethodType(apply_lora_o, self_attn)
+                else:
+                    LOG.warning_once(
+                        "Cannot patch some attention output projection - requires LoRA adapters with no bias"
+                    )
+        for gate_proj, up_proj, down_proj, mlp in find_mlp_in_layer(layer):
+            if cfg.lora_mlp_kernel:
+                # MLP patching
+                can_patch_mlp = all(
+                    hasattr(proj, "lora_A")
+                    and getattr(proj, "base_layer", proj).bias is None
+                    and len(getattr(proj, "lora_magnitude_vector", []) or []) == 0
+                    for proj in (gate_proj, up_proj, down_proj)
                 )

+                if can_patch_mlp:
+                    apply_fn = APPLY_FN_MAPPING[activation]
+                    mlp.forward = types.MethodType(apply_fn, mlp)
+                else:
+                    LOG.warning_once(
+                        "Cannot patch some MLP layers - requires LoRA adapters with no bias"
+                    )
+
     LOG.setLevel(original_level)

     return model
+
+
+class FakeMLP(nn.Module):
+    """
+    Placeholder MLP wrapper for Triton kernel patching of expert projections.
+    """
+
+    gate_proj: nn.Linear
+    up_proj: nn.Linear
+    down_proj: nn.Linear
+
+    def __init__(self, gate_proj, up_proj, down_proj):
+        super().__init__()
+        self.gate_proj = gate_proj
+        self.up_proj = up_proj
+        self.down_proj = down_proj
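A minimal sketch of how the new traversal helpers are expected to behave on a Llama-4-style MoE layer. The toy classes below are hypothetical stand-ins for the real decoder layer, and the import path `axolotl.monkeypatch.lora_kernels` is an assumption (the diff does not show the file name); this is illustration only, not part of the patch.

# Sketch only: toy modules mimicking the attribute layout the helpers probe for.
import torch.nn as nn

# Assumed import path; adjust to wherever this diff actually lands.
from axolotl.monkeypatch.lora_kernels import (
    FakeMLP,
    find_mlp_in_layer,
    find_self_attn_in_layer,
)


class ToyMLP(nn.Module):
    """Dense MLP with the gate/up/down projections the general case expects."""

    def __init__(self, dim: int):
        super().__init__()
        self.gate_proj = nn.Linear(dim, dim)
        self.up_proj = nn.Linear(dim, dim)
        self.down_proj = nn.Linear(dim, dim)


class ToyExperts(nn.Module):
    """Linearized experts exposing per-expert gate_projs/up_projs/down_projs lists."""

    def __init__(self, n_experts: int, dim: int):
        super().__init__()
        self.gate_projs = nn.ModuleList(nn.Linear(dim, dim) for _ in range(n_experts))
        self.up_projs = nn.ModuleList(nn.Linear(dim, dim) for _ in range(n_experts))
        self.down_projs = nn.ModuleList(nn.Linear(dim, dim) for _ in range(n_experts))


class ToyMoELayer(nn.Module):
    """Hypothetical Llama4-style layer: feedforward with shared + routed experts."""

    def __init__(self, n_experts: int = 2, dim: int = 8):
        super().__init__()
        self.feedforward = nn.Module()
        self.feedforward.shared_expert = ToyMLP(dim)
        self.feedforward.experts = ToyExperts(n_experts, dim)


layer = ToyMoELayer()
# No self_attn attribute on this toy layer, so the attention helper yields nothing.
assert list(find_self_attn_in_layer(layer)) == []
# One tuple for the shared expert plus one per routed expert: 1 + 2 = 3.
mlps = list(find_mlp_in_layer(layer))
assert len(mlps) == 3
# Routed experts are wrapped in FakeMLP so the kernel patch has a module to bind to.
assert isinstance(mlps[-1][-1], FakeMLP)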