
Commit dc47621: Clean up LoRA

1 parent 957ade6 commit dc47621

File tree: 9 files changed, +78 -171 lines changed


labml_nn/transformers/LoRA/GPT2.py renamed to labml_nn/lora/gpt2.py

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 import torch
 import torch.nn as nn
 from transformers import AutoTokenizer
-from labml_nn.transformers.LoRA import Linear, Embedding
+from labml_nn.lora import Linear, Embedding
 
 tokenizer = AutoTokenizer.from_pretrained("gpt2")
 
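For context, the Linear and Embedding imported here from labml_nn.lora are LoRA-adapted layers: the pretrained weight stays frozen and only a low-rank update W + (alpha / r) * B A is trained. Below is a minimal sketch of such a linear layer; the class name, initialization scheme, and defaults are illustrative assumptions, not the exact labml_nn implementation.

import torch
import torch.nn as nn
import torch.nn.functional as F


class LoRALinear(nn.Module):
    """A frozen pretrained linear layer plus a trainable low-rank update."""

    def __init__(self, in_features: int, out_features: int, r: int = 32, alpha: int = 32):
        super().__init__()
        # Pretrained weights: frozen, filled later from the checkpoint
        self.weight = nn.Parameter(torch.empty(out_features, in_features), requires_grad=False)
        self.bias = nn.Parameter(torch.zeros(out_features), requires_grad=False)
        # Low-rank factors: A maps down to rank r, B maps back up.
        # B starts at zero so the update is a no-op before training.
        self.lora_a = nn.Parameter(torch.randn(r, in_features) * 0.01)
        self.lora_b = nn.Parameter(torch.zeros(out_features, r))
        self.scaling = alpha / r

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # y = x W^T + b + (alpha / r) * x A^T B^T
        base = F.linear(x, self.weight, self.bias)
        update = F.linear(F.linear(x, self.lora_a), self.lora_b)
        return base + self.scaling * update
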
labml_nn/transformers/LoRA/train.ipynb renamed to labml_nn/lora/train.ipynb

Lines changed: 28 additions & 26 deletions
@@ -1,5 +1,22 @@
 {
  "cells": [
+  {
+   "metadata": {},
+   "cell_type": "code",
+   "outputs": [],
+   "execution_count": null,
+   "source": [
+    "import torch\n",
+    "from torch.optim import Adam\n",
+    "from torch.utils.data import DataLoader, TensorDataset\n",
+    "from torch.utils.data import random_split\n",
+    "from transformers import AutoTokenizer\n",
+    "\n",
+    "from labml import tracker, experiment\n",
+    "from labml_nn.lora.gpt2 import GPTModel"
+   ],
+   "id": "f072832ec9d346e1"
+  },
   {
    "cell_type": "code",
    "id": "initial_id",
@@ -29,8 +46,6 @@
    "id": "ac8e51ae5bbfcae7",
    "metadata": {},
    "source": [
-    "from transformers import AutoTokenizer\n",
-    "\n",
     "tokenizer = AutoTokenizer.from_pretrained(\"gpt2\")\n",
     "\n",
     "tokens = tokenizer.encode(text, add_special_tokens=False)"
@@ -64,11 +79,7 @@
    "cell_type": "code",
    "id": "5c4cc78ac1a02c1d",
    "metadata": {},
-   "source": [
-    "import torch\n",
-    "\n",
-    "input_ids = torch.tensor(tokens).view(-1, context_length)"
-   ],
+   "source": "input_ids = torch.tensor(tokens).view(-1, context_length)",
    "outputs": [],
    "execution_count": null
   },
@@ -77,10 +88,6 @@
    "id": "7037fd75e2161382",
    "metadata": {},
    "source": [
-    "from torch.utils.data import DataLoader, TensorDataset\n",
-    "from torch.optim import Adam\n",
-    "from torch.utils.data import random_split\n",
-    "\n",
     "dataset = TensorDataset(input_ids)\n",
     "\n",
     "train_ratio = 0.8\n",
@@ -102,8 +109,6 @@
    "id": "a98b7baa064b8494",
    "metadata": {},
    "source": [
-    "from labml_nn.transformers.LoRA.GPT2 import GPTModel\n",
-    "\n",
     "model = GPTModel()\n",
     "state_dict = torch.load('transformed.pth', weights_only=True)\n",
     "\n",
@@ -128,8 +133,6 @@
    "id": "e2f5076894770740",
    "metadata": {},
    "source": [
-    "from labml import tracker, experiment\n",
-    "\n",
     "optimizer = Adam(model.parameters(), lr=5e-5)\n",
     "criterion = torch.nn.CrossEntropyLoss()\n",
     "\n",
@@ -143,39 +146,38 @@
     "        inputs = batch[0]\n",
     "        inputs = inputs.to(device)\n",
     "        labels = inputs.clone()\n",
-    "        \n",
+    "\n",
     "        outputs = model(inputs)\n",
-    "        \n",
+    "\n",
     "        shift_logits = outputs[..., :-1, :]\n",
     "        shift_labels = labels[..., 1:]\n",
-    "        \n",
+    "\n",
     "        loss = criterion(shift_logits.reshape(-1, shift_logits.size(-1)), shift_labels.reshape(-1))\n",
-    "        \n",
+    "\n",
     "        optimizer.zero_grad()\n",
     "        loss.backward()\n",
     "        optimizer.step()\n",
-    "        \n",
+    "\n",
     "        tracker.save(step, {'loss': loss})\n",
     "        step += 1\n",
     "    print(f'Epoch: {epoch + 1}, Loss: {loss.item()}')\n",
-    "    \n",
+    "\n",
     "    test_loss = 0\n",
     "    for batch in test_dataloader:\n",
     "        inputs = batch[0]\n",
     "        inputs = inputs.to(device)\n",
     "        labels = inputs.clone()\n",
-    "        \n",
+    "\n",
     "        outputs = model(inputs)\n",
-    "        \n",
+    "\n",
     "        shift_logits = outputs[..., :-1, :]\n",
     "        shift_labels = labels[..., 1:]\n",
-    "        \n",
+    "\n",
     "        loss = criterion(shift_logits.reshape(-1, shift_logits.size(-1)), shift_labels.reshape(-1))\n",
-    "        \n",
+    "\n",
     "        test_loss += loss.item()\n",
     "    test_loss /= len(test_dataloader)\n",
     "    tracker.save(step, {'test_loss': test_loss})\n",
-    "    \n",
     "\n",
     "print(\"Training complete.\")"
    ],
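The training loop in this notebook computes a standard causal language-modeling loss: the logit at position t is scored against the token at position t + 1, which is why both tensors are shifted by one before the cross-entropy. A self-contained sketch of that computation, with illustrative shapes:

import torch

# Illustrative shapes: batch of 2, context length 8, GPT-2 vocabulary of 50257
logits = torch.randn(2, 8, 50257)           # model output: (batch, seq, vocab)
labels = torch.randint(0, 50257, (2, 8))    # inputs reused as labels

# Predict token t + 1 from position t: drop the last logit and the first label
shift_logits = logits[..., :-1, :]          # (batch, seq - 1, vocab)
shift_labels = labels[..., 1:]              # (batch, seq - 1)

criterion = torch.nn.CrossEntropyLoss()
loss = criterion(shift_logits.reshape(-1, shift_logits.size(-1)),
                 shift_labels.reshape(-1))
print(loss.item())  # roughly ln(50257), about 10.8, for random logits
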
labml_nn/lora/load_hf.py (new file)

Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
+import torch
+from transformers import AutoModelForCausalLM
+
+
+def transform_hf_model():
+    model = AutoModelForCausalLM.from_pretrained("gpt2")
+
+    state_dict = model.state_dict()
+
+    mapping = {
+        'transformer.wte.weight': 'token_embedding.weight',
+        'transformer.wpe.weight': 'position_embedding.weight',
+        'transformer.ln_f.weight': 'final_norm.weight',
+        'transformer.ln_f.bias': 'final_norm.bias',
+        'lm_head.weight': 'lm_head.weight'
+    }
+
+    for i in range(12):
+        mapping[f'transformer.h.{i}.ln_1.weight'] = f'blocks.{i}.pre_norm.weight'
+        mapping[f'transformer.h.{i}.ln_1.bias'] = f'blocks.{i}.pre_norm.bias'
+        mapping[f'transformer.h.{i}.attn.c_attn.weight'] = f'blocks.{i}.attn.c_att.weight'
+        mapping[f'transformer.h.{i}.attn.c_attn.bias'] = f'blocks.{i}.attn.c_att.bias'
+        mapping[f'transformer.h.{i}.attn.c_proj.weight'] = f'blocks.{i}.attn.c_proj.weight'
+        mapping[f'transformer.h.{i}.attn.c_proj.bias'] = f'blocks.{i}.attn.c_proj.bias'
+        mapping[f'transformer.h.{i}.ln_2.weight'] = f'blocks.{i}.post_norm.weight'
+        mapping[f'transformer.h.{i}.ln_2.bias'] = f'blocks.{i}.post_norm.bias'
+        mapping[f'transformer.h.{i}.mlp.c_fc.weight'] = f'blocks.{i}.ffn.c_fc.weight'
+        mapping[f'transformer.h.{i}.mlp.c_fc.bias'] = f'blocks.{i}.ffn.c_fc.bias'
+        mapping[f'transformer.h.{i}.mlp.c_proj.weight'] = f'blocks.{i}.ffn.c_proj.weight'
+        mapping[f'transformer.h.{i}.mlp.c_proj.bias'] = f'blocks.{i}.ffn.c_proj.bias'
+
+    new_state_dict = {}
+    for old_key, new_key in mapping.items():
+        if old_key in state_dict:
+            new_state_dict[new_key] = state_dict[old_key]
+
+    # Transpose the weight matrices of Hugging Face's Conv1D layers so they can be loaded into linear layers
+    convo_layers = ([f'blocks.{i}.ffn.c_fc.weight' for i in range(12)] +
+                    [f'blocks.{i}.ffn.c_proj.weight' for i in range(12)] +
+                    [f'blocks.{i}.attn.c_att.weight' for i in range(12)] +
+                    [f'blocks.{i}.attn.c_proj.weight' for i in range(12)])
+
+    for layer in convo_layers:
+        new_state_dict[layer] = torch.transpose(new_state_dict[layer], 0, 1)
+
+    torch.save(new_state_dict, 'transformed.pth')
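The final transpose is needed because Hugging Face's GPT-2 stores attention and MLP weights in Conv1D modules, whose weight shape is (in_features, out_features), while nn.Linear (and the LoRA Linear used here) expects (out_features, in_features). A sketch of how the converted checkpoint would then be consumed; passing strict=False is an assumption, on the grounds that the freshly initialized LoRA matrices are absent from the converted checkpoint:

import torch
from labml_nn.lora.gpt2 import GPTModel

# Convert the Hugging Face checkpoint once, writing transformed.pth
transform_hf_model()

model = GPTModel()
state_dict = torch.load('transformed.pth', weights_only=True)
# strict=False (assumed): the LoRA parameters are trained from scratch,
# so they are missing from the converted checkpoint
missing, unexpected = model.load_state_dict(state_dict, strict=False)
print('Missing keys (expected: LoRA parameters):', missing)
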
File renamed without changes.
File renamed without changes.

labml_nn/RWKV/experiment.py renamed to labml_nn/rwkv/experiment.py

Lines changed: 3 additions & 3 deletions
@@ -3,10 +3,10 @@
 
 import torch
 import torch.nn as nn
-from labml_nn.RWKV.configs import RWKVConfigs
+from labml_nn.rwkv.configs import RWKVConfigs
 
-from labml_nn.RWKV import RWKV
-from labml_nn.RWKV import TimeMixing
+from labml_nn.rwkv import RWKV
+from labml_nn.rwkv import TimeMixing
 from labml import experiment
 from labml.configs import option
 from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs

labml_nn/transformers/LoRA/experiment.ipynb

Lines changed: 0 additions & 97 deletions
This file was deleted.

labml_nn/transformers/LoRA/load_hf.py

Lines changed: 0 additions & 44 deletions
This file was deleted.
