@@ -114,6 +114,28 @@ def onload_layer(layer):
114
114
layer .base_layer ._hf_hook .post_forward (layer .base_layer , torch .tensor ([]))
115
115
116
116
117
def _check_lora_target_modules_mamba(peft_config: PeftConfig, model: nn.Module, target_name: str):
    """
    Prevent applying LoRA to incompatible modules in specific architectures (e.g., Mamba).

    Raises a ``ValueError`` when a LoRA-like adapter targets a module name that is known to
    break Mamba-family models; otherwise returns silently.

    Args:
        peft_config: The adapter config; only LoRA-like ``peft_type`` values are checked.
        model: The base model; its ``config.model_type`` decides whether the check applies.
        target_name: The short name of the module about to be wrapped (e.g. ``"out_proj"``).
    """
    lora_like_types = {"LORA", "ADALORA", "XLORA", "RANDLORA"}
    incompatible_modules = {"out_proj", "conv1d"}
    mamba_model_types = {"falcon_h1", "mamba", "mamba2", "falcon_mamba"}

    # Only LoRA-style adapters are affected; other PEFT methods pass through untouched.
    if peft_config.peft_type not in lora_like_types:
        return

    # Collapsed hasattr/getattr chain: yields None when `model.config` is absent,
    # which simply fails the membership test below.
    model_type = getattr(getattr(model, "config", None), "model_type", None)
    if model_type not in mamba_model_types:
        return

    if target_name in incompatible_modules:
        raise ValueError(
            f"[PEFT:{peft_config.peft_type}] Module '{target_name}' is incompatible with Mamba-based models "
            f"(model_type='{model.config.model_type}'). Incompatible modules: {incompatible_modules}. "
            "Please remove it from `target_modules` to avoid compatibility issues."
        )
117
139
class BaseTuner (nn .Module , ABC ):
118
140
r"""
119
141
A base tuner model that provides the common methods and attributes for all tuners that are injectable into a
@@ -398,6 +420,12 @@ def _check_merge_allowed(self):
398
420
+ example_code
399
421
)
400
422
423
    def _check_target_module_compatiblity(self, peft_config: PeftConfig, model: nn.Module, target_name: str):
        """
        Prevent applying LoRA to incompatible modules in specific architectures (e.g., Mamba).

        Hook called from ``inject_adapter`` for every matched target before it is wrapped;
        currently delegates to the module-level Mamba check, which raises ``ValueError`` on
        an incompatible ``target_name`` and otherwise does nothing. Subclasses may override
        this to add architecture-specific compatibility checks.

        Args:
            peft_config: The adapter config being injected.
            model: The base model the adapter is being attached to.
            target_name: Short name of the candidate target module.
        """
        # NOTE(review): "compatiblity" is a typo for "compatibility", but the name is the
        # public hook invoked by inject_adapter — renaming would break callers; keep as is.
        _check_lora_target_modules_mamba(peft_config, model, target_name)
401
429
def inject_adapter (
402
430
self , model : nn .Module , adapter_name : str , autocast_adapter_dtype : bool = True , low_cpu_mem_usage : bool = False
403
431
) -> None :
@@ -497,6 +525,7 @@ def inject_adapter(
497
525
else :
498
526
self .targeted_module_names .append (key )
499
527
parent , target , target_name = _get_submodules (model , key )
528
+ self ._check_target_module_compatiblity (peft_config , model , target_name )
500
529
ctx = init_empty_weights if low_cpu_mem_usage else nullcontext
501
530
with ctx ():
502
531
self ._create_and_replace (peft_config , adapter_name , target , target_name , parent , current_key = key )
0 commit comments