Commit eb9337e

Clean up LoRA
1 parent dc47621 commit eb9337e

1 file changed (+10 −16)

labml_nn/lora/__init__.py

Lines changed: 10 additions & 16 deletions
@@ -1,18 +1,17 @@
+"""
+# LoRA
+"""
+
 import torch
 import torch.nn as nn


 class Linear(nn.Module):
-    def __init__(
-            self,
-            in_features: int,
-            out_features: int,
-            bias: bool,
-            r: int,
-            alpha: int = None):
+    def __init__(self, in_features: int, out_features: int, bias: bool,
+                 r: int, alpha: int = None):
+        super().__init__()
         if alpha is None:
             alpha = r
-        super().__init__()
         self.weight = nn.Parameter(torch.empty((out_features, in_features)))
         self.weight.requires_grad = False

@@ -39,16 +38,11 @@ def forward(self, x: torch.Tensor):


 class Embedding(nn.Module):
-    def __init__(
-            self,
-            num_embeddings: int,
-            embedding_dim: int,
-            r: int,
-            alpha: int = None,
-    ):
+    def __init__(self, num_embeddings: int, embedding_dim: int,
+                 r: int, alpha: int = None):
+        super().__init__()
         if alpha is None:
             alpha = r
-        super().__init__()

         self.weight = nn.Parameter(torch.empty((num_embeddings, embedding_dim)))
         self.weight.requires_grad = False
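
The hunks above only show the frozen base weight; the trainable low-rank factors live in the elided parts of the file. For context, here is a minimal sketch of how such a LoRA Linear layer could be completed, assuming the standard LoRA update W x + (alpha / r) B A x from the LoRA paper. The names LoRALinear, lora_a, lora_b, and scaling are illustrative assumptions, not identifiers taken from labml_nn/lora/__init__.py.

# Minimal LoRA Linear sketch (assumed formulation: W x + (alpha / r) * B A x).
# lora_a / lora_b / scaling are hypothetical names, not from the actual file.
import torch
import torch.nn as nn


class LoRALinear(nn.Module):
    def __init__(self, in_features: int, out_features: int, bias: bool,
                 r: int, alpha: int = None):
        super().__init__()
        if alpha is None:
            alpha = r
        # Frozen pretrained weight (and optional frozen bias)
        self.weight = nn.Parameter(torch.empty((out_features, in_features)))
        self.weight.requires_grad = False
        self.bias = nn.Parameter(torch.empty(out_features)) if bias else None
        if self.bias is not None:
            self.bias.requires_grad = False
        # Trainable low-rank factors: A is (r, in), B is (out, r)
        self.scaling = alpha / r
        self.lora_a = nn.Parameter(torch.zeros((r, in_features)))
        self.lora_b = nn.Parameter(torch.zeros((out_features, r)))
        # A gets a small random init; B stays zero so the update starts at zero
        nn.init.normal_(self.lora_a, std=0.02)

    def forward(self, x: torch.Tensor):
        # Frozen path plus the scaled low-rank correction
        result = nn.functional.linear(x, self.weight, bias=self.bias)
        result = result + self.scaling * nn.functional.linear(
            nn.functional.linear(x, self.lora_a), self.lora_b)
        return result


# Usage: a rank-4 adapter; only lora_a and lora_b receive gradients.
layer = LoRALinear(768, 768, bias=False, r=4)
nn.init.normal_(layer.weight, std=0.02)  # stand-in for loading pretrained weights
y = layer(torch.randn(2, 768))

Initializing B to zero means the adapted layer starts out identical to the frozen base layer, which is the usual LoRA convention; the alpha / r scaling keeps the update magnitude comparable across different ranks.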
