
Commit 9748367

prepare to crossover entire transformer

1 parent: d42b793

File tree

- nGPT_pytorch/evo.py
- nGPT_pytorch/nGPT.py
- pyproject.toml

3 files changed: +56 −7 lines

nGPT_pytorch/evo.py

Lines changed: 45 additions & 6 deletions
```diff
@@ -1,21 +1,58 @@
+from __future__ import annotations
 from copy import deepcopy
 
 import torch
-from torch import cat, randperm
+from torch import cat, randperm, Tensor
 
 from einops import rearrange
-from nGPT_pytorch.nGPT import FeedForward, Attention
+from nGPT_pytorch.nGPT import NormLinear, FeedForward, Attention
+
+# helper functions
+
+def exists(v):
+    return v is not None
+
+# cross over normlinear
+
+@torch.no_grad()
+def cross_over_linear(
+    parent1: NormLinear,
+    parent2: NormLinear,
+    parent1_indices: Tensor,
+    parent2_indices: Tensor,
+    child: NormLinear | None = None,
+    dim: int = 0
+) -> NormLinear:
+
+    if not exists(child):
+        child = deepcopy(parent1)
+
+    assert dim in {0, 1}
+    assert parent1 == parent2
+
+    w1 = parent1.weight
+    w2 = parent2.weight
+
+    if dim == 0:
+        crossover_weight = cat((w1[parent1_indices], w2[parent2_indices]), dim = 0)
+    else:
+        crossover_weight = cat((w1[:, parent1_indices], w2[:, parent2_indices]), dim = 1)
+
+    child.weight.copy_(crossover_weight)
+    return child
 
 # breeding feedforwards
 
 @torch.no_grad()
 def cross_over_feedforward(
     parent1: FeedForward,
-    parent2: FeedForward
+    parent2: FeedForward,
+    child: FeedForward | None = None
 ) -> FeedForward:
     assert parent1 == parent2
 
-    child = deepcopy(parent1)
+    if not exists(child):
+        child = deepcopy(parent1)
 
     dim_inner = parent1.dim_inner
 
@@ -60,15 +97,17 @@ def cross_over_feedforward(
 @torch.no_grad()
 def cross_over_attention(
     parent1: Attention,
-    parent2: Attention
+    parent2: Attention,
+    child: Attention | None = None
 ) -> Attention:
 
     assert parent1 == parent2
 
     heads = parent1.heads
     assert heads > 1
 
-    child = deepcopy(parent1)
+    if not exists(child):
+        child = deepcopy(parent1)
 
     split_heads_first_dim = lambda t: rearrange(t, '(h d) ... -> h d ...', h = heads)
     split_heads_last_dim = lambda t: rearrange(t, 'e (h d) -> e h d', h = heads)
```
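
The new `cross_over_linear` generalizes the existing feedforward and attention breeders down to a single projection: it splices the rows (`dim = 0`) or columns (`dim = 1`) of two parents' weight matrices into a child. Below is a minimal usage sketch, not part of this commit, assuming `NormLinear(dim, dim_out)` constructs a layer with default arguments and that the two index sets together select exactly `dim_out` rows:

```python
# hypothetical usage sketch (assumption: NormLinear(dim, dim_out) with defaults;
# the index sets must jointly cover every row of the (dim_out, dim) weight)
from torch import randperm
from nGPT_pytorch.nGPT import NormLinear
from nGPT_pytorch.evo import cross_over_linear

dim, dim_out = 64, 128
parent1 = NormLinear(dim, dim_out)
parent2 = NormLinear(dim, dim_out)

perm = randperm(dim_out)                    # shuffle the output units

child = cross_over_linear(
    parent1,
    parent2,
    parent1_indices = perm[:dim_out // 2],  # half the rows from parent1
    parent2_indices = perm[dim_out // 2:],  # the other half from parent2
    dim = 0                                 # cross over along the output axis
)

assert child.weight.shape == parent1.weight.shape  # (dim_out, dim) preserved
```

Splitting a single `randperm` keeps the two index sets disjoint and jointly exhaustive, so the child inherits every output unit exactly once, half from each parent.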

nGPT_pytorch/nGPT.py

Lines changed: 10 additions & 0 deletions
```diff
@@ -139,6 +139,9 @@ def __init__(
         groups = 1
     ):
         super().__init__()
+        self.dim = dim
+        self.dim_out = dim_out
+
         self.linear = nn.Linear(dim, dim_out, bias = False)
 
         self.scale = groups ** -1
@@ -154,6 +157,13 @@ def __init__(
 
         self.norm_weights_()
 
+    def __eq__(self, x):
+        return (
+            isinstance(x, NormLinear) and
+            self.dim == x.dim and
+            self.dim_out == x.dim_out
+        )
+
     @torch.no_grad()
     def norm_weights_(self):
         if self.parametrize:
```
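
Caching `dim` and `dim_out` gives the new `__eq__` something cheap to compare: two `NormLinear` layers now count as equal when their shapes match, regardless of their weights, which is what the `assert parent1 == parent2` guards in `evo.py` rely on. An illustrative sketch, not from this commit:

```python
# illustrative sketch: equality means "same architecture", not "same parameters"
from nGPT_pytorch.nGPT import NormLinear

a = NormLinear(64, 128)
b = NormLinear(64, 128)   # freshly initialized, so different weights
c = NormLinear(64, 256)

assert a == b             # shapes match, so the layers can be crossed over
assert not (a == c)       # dim_out differs, crossover would be ill-defined
```

One Python caveat worth flagging: a class that defines `__eq__` without also defining `__hash__` has its `__hash__` set to `None`, making instances unhashable; since `nn.Module` machinery such as `named_modules` keeps modules in a set, adding `__hash__ = nn.Module.__hash__` alongside this `__eq__` would preserve hashability.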

pyproject.toml

Lines changed: 1 addition & 1 deletion
```diff
@@ -1,6 +1,6 @@
 [project]
 name = "nGPT-pytorch"
-version = "0.2.11"
+version = "0.2.12"
 description = "nGPT"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
```
