Skip to content

Commit d42b793

Browse files
committed
evolution will be part of the final equation.
1 parent 9acfe26 commit d42b793

File tree

3 files changed

+75
-4
lines changed

3 files changed

+75
-4
lines changed

nGPT_pytorch/evo.py

Lines changed: 65 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,20 @@
33
import torch
44
from torch import cat, randperm
55

6-
from nGPT_pytorch.nGPT import FeedForward
6+
from einops import rearrange
7+
from nGPT_pytorch.nGPT import FeedForward, Attention
8+
9+
# breeding feedforwards
710

811
@torch.no_grad()
9-
def cross_over(
12+
def cross_over_feedforward(
1013
parent1: FeedForward,
1114
parent2: FeedForward
1215
) -> FeedForward:
16+
assert parent1 == parent2
1317

1418
child = deepcopy(parent1)
1519

16-
assert parent1.dim == parent2.dim and parent1.expand_factor == parent2.expand_factor
1720
dim_inner = parent1.dim_inner
1821

1922
parent1_w1 = parent1.to_hidden.weight
@@ -51,3 +54,62 @@ def cross_over(
5154
child_g_scale.copy_(cat((parent1_g_scale[parent1_indices], parent2_g_scale[parent2_indices])))
5255

5356
return child
57+
58+
# breeding attention
59+
60+
@torch.no_grad()
61+
def cross_over_attention(
62+
parent1: Attention,
63+
parent2: Attention
64+
) -> Attention:
65+
66+
assert parent1 == parent2
67+
68+
heads = parent1.heads
69+
assert heads > 1
70+
71+
child = deepcopy(parent1)
72+
73+
split_heads_first_dim = lambda t: rearrange(t, '(h d) ... -> h d ...', h = heads)
74+
split_heads_last_dim = lambda t: rearrange(t, 'e (h d) -> e h d', h = heads)
75+
76+
flatten_first = lambda t: rearrange(t, 'h d ... -> (h d) ...')
77+
flatten_last = lambda t: rearrange(t, 'e h d -> e (h d)')
78+
79+
parent1_q = split_heads_first_dim(parent1.to_q.weight)
80+
parent2_q = split_heads_first_dim(parent2.to_q.weight)
81+
child_q = child.to_q.weight
82+
83+
parent1_k = split_heads_first_dim(parent1.to_k.weight)
84+
parent2_k = split_heads_first_dim(parent2.to_k.weight)
85+
child_k = child.to_k.weight
86+
87+
parent1_v = split_heads_first_dim(parent1.to_v.weight)
88+
parent2_v = split_heads_first_dim(parent2.to_v.weight)
89+
child_v = child.to_v.weight
90+
91+
parent1_o = split_heads_last_dim(parent1.to_out.weight)
92+
parent2_o = split_heads_last_dim(parent2.to_out.weight)
93+
child_o = child.to_out.weight
94+
95+
parent1_qk_scale = split_heads_first_dim(parent1.qk_scale.scale)
96+
parent2_qk_scale = split_heads_first_dim(parent2.qk_scale.scale)
97+
child_qk_scale = child.qk_scale.scale
98+
99+
# randomly select heads from parent1 and parent2 for crossover
100+
101+
midpoint = heads // 2
102+
rand_indices = randperm(heads)
103+
104+
parent1_indices, parent2_indices = rand_indices[:midpoint], rand_indices[midpoint:]
105+
106+
# select out the correct parameters for attention heads from parent 1 and 2
107+
108+
child_q.copy_(flatten_first(cat((parent1_q[parent1_indices], parent2_q[parent2_indices]), dim = 0)))
109+
child_k.copy_(flatten_first(cat((parent1_k[parent1_indices], parent2_k[parent2_indices]), dim = 0)))
110+
child_v.copy_(flatten_first(cat((parent1_v[parent1_indices], parent2_v[parent2_indices]), dim = 0)))
111+
child_qk_scale.copy_(flatten_first(cat((parent1_qk_scale[parent1_indices], parent2_qk_scale[parent2_indices]), dim = 0)))
112+
113+
child_o.copy_(flatten_last(cat((parent1_o[:, parent1_indices], parent2_o[:, parent2_indices]), dim = 1)))
114+
115+
return child

nGPT_pytorch/nGPT.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,9 @@ def __init__(
195195
num_hyperspheres = 1,
196196
):
197197
super().__init__()
198+
self.dim = dim
199+
self.dim_head = dim_head
200+
198201
self.heads = heads
199202
self.causal = causal
200203

@@ -225,6 +228,9 @@ def __init__(
225228

226229
self.to_out = NormLinear_(dim_inner, dim, norm_dim_in = False)
227230

231+
def __eq__(x, y):
232+
return x.dim == y.dim and x.heads == y.heads and x.dim_head == y.dim_head
233+
228234
def forward(
229235
self,
230236
x,
@@ -321,6 +327,9 @@ def __init__(
321327

322328
self.to_out = NormLinear_(dim_inner, dim, norm_dim_in = False)
323329

330+
def __eq__(x, y):
331+
return x.dim == y.dim and x.expand_factor == y.expand_factor
332+
324333
def forward(self, x):
325334
hidden, gate = self.to_hidden(x), self.to_gate(x)
326335

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "nGPT-pytorch"
3-
version = "0.2.10"
3+
version = "0.2.11"
44
description = "nGPT"
55
authors = [
66
{ name = "Phil Wang", email = "lucidrains@gmail.com" }

0 commit comments

Comments
 (0)