Skip to content

Commit 0d0d729

Browse files
committed
prepare for breeding the nGPT feedforwards
1 parent 5208aad commit 0d0d729

File tree

3 files changed

+45
-1
lines changed

3 files changed

+45
-1
lines changed

nGPT_pytorch/evo.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
from copy import deepcopy
2+
3+
import torch
4+
from torch import cat
5+
6+
from nGPT_pytorch.nGPT import FeedForward
7+
8+
@torch.no_grad()
9+
def cross_over(
10+
parent1: FeedForward,
11+
parent2: FeedForward
12+
) -> FeedForward:
13+
14+
child = deepcopy(parent1)
15+
16+
dim = parent1.dim
17+
assert parent1.dim == parent2.dim and parent1.expand_factor == parent2.expand_factor
18+
19+
parent1_w1 = parent1.to_hidden.weight
20+
parent2_w1 = parent2.to_hidden.weight
21+
child_w1 = child.to_hidden.weight
22+
23+
parent1_gate = parent1.to_gate.weight
24+
parent2_gate = parent2.to_gate.weight
25+
child_gate = child.to_gate.weight
26+
27+
parent1_w2 = parent1.to_out.weight
28+
parent2_w2 = parent2.to_out.weight
29+
child_w2 = child.to_out.weight
30+
31+
midpoint = dim // 2
32+
rand_indices = torch.randperm(dim)
33+
34+
# randomly select vectors from the feedforward weight matrices from both parents to constitute the child
35+
36+
parent1_indices, parent2_indices = rand_indices[:midpoint], rand_indices[midpoint:]
37+
38+
child_w1.copy_(cat((parent1_w1[:, parent1_indices], parent2_w1[:, parent2_indices]), dim = 1))
39+
child_gate.copy_(cat((parent1_gate[:, parent1_indices], parent2_gate[:, parent2_indices]), dim = 1))
40+
child_w2.copy_(cat((parent1_w2[parent1_indices, :], parent2_w2[parent2_indices, :]), dim = 0))
41+
42+
return child

nGPT_pytorch/nGPT.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,8 @@ def __init__(
307307
NormLinear_ = partial(NormLinear, parametrize = not manual_norm_weights, norm_eps = norm_eps, groups = num_hyperspheres)
308308

309309
self.dim = dim
310+
self.expand_factor = expand_factor
311+
310312
dim_inner = int(dim * expand_factor * 2 / 3)
311313

312314
self.to_hidden = NormLinear_(dim, dim_inner)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "nGPT-pytorch"
3-
version = "0.2.7"
3+
version = "0.2.8"
44
description = "nGPT"
55
authors = [
66
{ name = "Phil Wang", email = "lucidrains@gmail.com" }

0 commit comments

Comments
 (0)