|
| 1 | +from __future__ import annotations |
1 | 2 | from copy import deepcopy
|
2 | 3 |
|
3 | 4 | import torch
|
4 |
| -from torch import cat, randperm |
| 5 | +from torch import cat, randperm, Tensor |
5 | 6 |
|
6 | 7 | from einops import rearrange
|
7 |
| -from nGPT_pytorch.nGPT import FeedForward, Attention |
| 8 | +from nGPT_pytorch.nGPT import NormLinear, FeedForward, Attention |
| 9 | + |
| 10 | +# helper functions |
| 11 | + |
def exists(v):
    """Return True when *v* holds a value, i.e. is not None."""
    return v is not None
| 14 | + |
| 15 | +# cross over normlinear |
| 16 | + |
@torch.no_grad()
def cross_over_linear(
    parent1: NormLinear,
    parent2: NormLinear,
    parent1_indices: Tensor,
    parent2_indices: Tensor,
    child: NormLinear | None = None,
    dim: int = 0
) -> NormLinear:
    """
    Breed two NormLinear layers by splicing slices of their weight matrices.

    The child's weight becomes the concatenation of parent1's weight at
    `parent1_indices` with parent2's weight at `parent2_indices`, taken along
    rows (dim = 0) or columns (dim = 1). The combined number of selected
    indices must match the child's weight size along `dim` (enforced by the
    in-place `copy_`).

    Args:
        parent1: first parent layer; also the template for `child` when
            `child` is not supplied.
        parent2: second parent layer, expected to be compatible with parent1.
        parent1_indices: indices of slices inherited from parent1.
        parent2_indices: indices of slices inherited from parent2.
        child: optional pre-built layer to write into; if None, a deep copy
            of parent1 is created.
        dim: 0 to cross over rows, 1 to cross over columns.

    Returns:
        The child layer with its weight overwritten in place
        (runs under torch.no_grad()).
    """
    # validate arguments before paying for the deepcopy below
    assert dim in {0, 1}
    # NOTE(review): for a plain nn.Module `==` is identity, so this only
    # passes when both arguments are the same object — confirm NormLinear
    # defines structural `__eq__`, otherwise this assert is over-strict.
    assert parent1 == parent2

    if child is None:
        child = deepcopy(parent1)

    w1 = parent1.weight
    w2 = parent2.weight

    # select each parent's gene segment along the requested axis, then stitch
    if dim == 0:
        crossover_weight = cat((w1[parent1_indices], w2[parent2_indices]), dim = 0)
    else:
        crossover_weight = cat((w1[:, parent1_indices], w2[:, parent2_indices]), dim = 1)

    child.weight.copy_(crossover_weight)
    return child
8 | 43 |
|
9 | 44 | # breeding feedforwards
|
10 | 45 |
|
11 | 46 | @torch.no_grad()
|
12 | 47 | def cross_over_feedforward(
|
13 | 48 | parent1: FeedForward,
|
14 |
| - parent2: FeedForward |
| 49 | + parent2: FeedForward, |
| 50 | + child: FeedForward | None = None |
15 | 51 | ) -> FeedForward:
|
16 | 52 | assert parent1 == parent2
|
17 | 53 |
|
18 |
| - child = deepcopy(parent1) |
| 54 | + if not exists(child): |
| 55 | + child = deepcopy(parent1) |
19 | 56 |
|
20 | 57 | dim_inner = parent1.dim_inner
|
21 | 58 |
|
@@ -60,15 +97,17 @@ def cross_over_feedforward(
|
60 | 97 | @torch.no_grad()
|
61 | 98 | def cross_over_attention(
|
62 | 99 | parent1: Attention,
|
63 |
| - parent2: Attention |
| 100 | + parent2: Attention, |
| 101 | + child: Attention | None = None |
64 | 102 | ) -> Attention:
|
65 | 103 |
|
66 | 104 | assert parent1 == parent2
|
67 | 105 |
|
68 | 106 | heads = parent1.heads
|
69 | 107 | assert heads > 1
|
70 | 108 |
|
71 |
| - child = deepcopy(parent1) |
| 109 | + if not exists(child): |
| 110 | + child = deepcopy(parent1) |
72 | 111 |
|
73 | 112 | split_heads_first_dim = lambda t: rearrange(t, '(h d) ... -> h d ...', h = heads)
|
74 | 113 | split_heads_last_dim = lambda t: rearrange(t, 'e (h d) -> e h d', h = heads)
|
|
0 commit comments