Commit 8af29c6

added support for Conv1d for DoRA (#2531)
DoRA now supports Conv1d layers. Notably, the check that decides how to handle non-linear layers was relaxed from requiring at least 4 weight dimensions to at least 3, since `Conv1d` weights have 3 dimensions instead of 4.
1 parent 6c48949 commit 8af29c6
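
The dimension counts behind that relaxed `ndim` check, illustrated with plain PyTorch (not part of the commit): a `Linear` weight has 2 dimensions, a `Conv1d` weight 3, and a `Conv2d` weight 4, so checking for `>= 3` lets `Conv1d` share the conv code path while linear layers keep their own branch.

    import torch.nn as nn

    print(nn.Linear(8, 16).weight.ndim)                  # 2 -> still takes the non-conv branch
    print(nn.Conv1d(8, 16, kernel_size=3).weight.ndim)   # 3 -> (out_channels, in_channels, kernel_size)
    print(nn.Conv2d(8, 16, kernel_size=3).weight.ndim)   # 4 -> (out_channels, in_channels, kH, kW)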

File tree

4 files changed (+20 −6 lines changed)

    method_comparison/MetaMathQA/run.py
    src/peft/tuners/lora/dora.py
    src/peft/tuners/lora/variants.py
    tests/test_custom_models.py

method_comparison/MetaMathQA/run.py
Lines changed: 3 additions & 3 deletions

@@ -31,6 +31,9 @@
 from typing import Any, Callable, ContextManager, Literal, Optional

 import torch
+from data import (
+    get_train_valid_test_datasets,
+)
 from torch import nn
 from torch.amp import GradScaler, autocast
 from tqdm import tqdm
@@ -52,9 +55,6 @@
     validate_experiment_path,
 )

-from data import (
-    get_train_valid_test_datasets,
-)
 from peft import AdaLoraConfig, PeftConfig
 from peft.utils import SAFETENSORS_WEIGHTS_NAME

src/peft/tuners/lora/dora.py
Lines changed: 7 additions & 1 deletion

@@ -48,7 +48,7 @@ def update_layer(self, *, base_layer, lora_A, lora_B, scaling, place_on_cpu=False)
         base_layer = deepcopy(base_layer)

         weight = dequantize_module_weight(base_layer)
-        if weight.data.ndim >= 4:  # For handling LoRAs applied to Conv layers.
+        if weight.data.ndim >= 3:  # For handling LoRAs applied to Conv layers.
             lora_weight = torch.mm(lora_B.flatten(start_dim=1), lora_A.flatten(start_dim=1))
             lora_weight = lora_weight.reshape(weight.shape)
         else:
@@ -183,6 +183,12 @@ def __repr__(self) -> str:
         return "lora.dora." + rep


+class DoraConv1dLayer(_DoraConvNdLayer):
+    def __init__(self, fan_in_fan_out):
+        super().__init__(fan_in_fan_out)
+        self.conv_fn = F.conv1d
+
+
 class DoraConv2dLayer(_DoraConvNdLayer):
     def __init__(self, fan_in_fan_out):
         super().__init__(fan_in_fan_out)
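
For intuition on the conv branch that `Conv1d` weights now reach, here is a rough standalone sketch of the flatten/mm/reshape step with plain tensors and arbitrary shapes. It assumes the usual conv-LoRA factorization, where lora_A keeps the kernel size and lora_B uses a kernel of 1; that factorization is not shown in this diff.

    import torch

    out_ch, in_ch, kernel, r = 16, 8, 3, 4
    weight = torch.randn(out_ch, in_ch, kernel)   # Conv1d base weight: 3 dims
    lora_A = torch.randn(r, in_ch, kernel)        # down-projection conv weight (assumed shape)
    lora_B = torch.randn(out_ch, r, 1)            # up-projection conv weight, kernel size 1 (assumed shape)

    # Mirrors the branch above: collapse both factors to 2-D, multiply, restore the conv shape.
    lora_weight = torch.mm(lora_B.flatten(start_dim=1), lora_A.flatten(start_dim=1))
    lora_weight = lora_weight.reshape(weight.shape)
    print(lora_weight.shape)  # torch.Size([16, 8, 3]), same shape as the base weight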

src/peft/tuners/lora/variants.py
Lines changed: 9 additions & 2 deletions

@@ -21,8 +21,8 @@

 from peft.utils.other import transpose

-from .dora import DoraConv2dLayer, DoraConv3dLayer, DoraEmbeddingLayer, DoraLinearLayer
-from .layer import Conv2d, Conv3d, Embedding, Linear, LoraVariant, _ConvNd
+from .dora import DoraConv1dLayer, DoraConv2dLayer, DoraConv3dLayer, DoraEmbeddingLayer, DoraLinearLayer
+from .layer import Conv1d, Conv2d, Conv3d, Embedding, Linear, LoraVariant, _ConvNd


 class DoraLinearVariant(LoraVariant):
@@ -296,6 +296,13 @@ def forward(module: _ConvNd, active_adapter: str, x: torch.Tensor, result: torch.Tensor)
         return result


+class DoraConv1dVariant(_DoraConvNdVariant):
+    @staticmethod
+    def init(module: Conv1d, adapter_name: str, **kwargs: Any) -> None:
+        dora_layer = DoraConv1dLayer(fan_in_fan_out=False)
+        _DoraConvNdVariant.init_convd_variant(module, adapter_name, dora_layer=dora_layer)
+
+
 class DoraConv2dVariant(_DoraConvNdVariant):
     @staticmethod
     def init(module: Conv2d, adapter_name: str, **kwargs: Any) -> None:

tests/test_custom_models.py
Lines changed: 1 addition & 0 deletions

@@ -113,6 +113,7 @@
         {"target_modules": ["conv1d"], "trainable_token_indices": {"emb": [0, 10]}},
     ),
     ("Conv1d LoRA", "Conv1d", LoraConfig, {"target_modules": ["conv1d"]}),
+    ("Conv1d LoRA with DoRA", "Conv1d", LoraConfig, {"target_modules": ["conv1d"], "use_dora": True}),
     ("Conv2d 1 LoRA", "Conv2d", LoraConfig, {"target_modules": ["conv2d"]}),
     ("Conv2d 2 LoRA", "Conv2d", LoraConfig, {"target_modules": ["conv2d", "lin0"]}),
     ("Conv2d 1 LoRA with DoRA", "Conv2d", LoraConfig, {"target_modules": ["conv2d"], "use_dora": True}),
