|
3 | 3 | import torch
|
4 | 4 | from torch import cat, randperm
|
5 | 5 |
|
6 |
| -from nGPT_pytorch.nGPT import FeedForward |
| 6 | +from einops import rearrange |
| 7 | +from nGPT_pytorch.nGPT import FeedForward, Attention |
| 8 | + |
| 9 | +# breeding feedforwards |
7 | 10 |
|
8 | 11 | @torch.no_grad()
|
9 |
| -def cross_over( |
| 12 | +def cross_over_feedforward( |
10 | 13 | parent1: FeedForward,
|
11 | 14 | parent2: FeedForward
|
12 | 15 | ) -> FeedForward:
|
| 16 | + assert parent1 == parent2 |
13 | 17 |
|
14 | 18 | child = deepcopy(parent1)
|
15 | 19 |
|
16 |
| - assert parent1.dim == parent2.dim and parent1.expand_factor == parent2.expand_factor |
17 | 20 | dim_inner = parent1.dim_inner
|
18 | 21 |
|
19 | 22 | parent1_w1 = parent1.to_hidden.weight
|
@@ -51,3 +54,62 @@ def cross_over(
|
51 | 54 | child_g_scale.copy_(cat((parent1_g_scale[parent1_indices], parent2_g_scale[parent2_indices])))
|
52 | 55 |
|
53 | 56 | return child
|
| 57 | + |
| 58 | +# breed attention |
| 59 | + |
@torch.no_grad()
def cross_over_attention(
    parent1: Attention,
    parent2: Attention
) -> Attention:
    """Breed two Attention modules into a child by mixing whole heads.

    Roughly half of the child's attention heads — their q/k/v projection
    rows, their per-head qk scale entries, and the matching output
    projection columns — are copied from parent1, the rest from parent2,
    with the head assignment drawn from a random permutation.
    """

    # NOTE(review): nn.Module has no default __eq__, so this presumably
    # relies on Attention defining hyperparameter equality — confirm;
    # otherwise it only passes when both arguments are the same object.
    assert parent1 == parent2

    num_heads = parent1.heads
    assert num_heads > 1  # need at least two heads to split between parents

    child = deepcopy(parent1)

    # reshape helpers: expose the head axis on the row (first) or the
    # column (last) dimension of a parameter, and undo it afterwards
    def by_head_rows(t):  # (h*d, ...) -> (h, d, ...)
        return t.unflatten(0, (num_heads, -1))

    def by_head_cols(t):  # (e, h*d) -> (e, h, d)
        return t.unflatten(-1, (num_heads, -1))

    def merge_rows(t):    # (h, d, ...) -> (h*d, ...)
        return t.flatten(0, 1)

    def merge_cols(t):    # (e, h, d) -> (e, h*d)
        return t.flatten(-2)

    # randomly decide which heads the child inherits from which parent

    shuffled = randperm(num_heads)
    half = num_heads // 2
    from_p1, from_p2 = shuffled[:half], shuffled[half:]

    def breed_rows(p1_param, p2_param, dst):
        # select the chosen heads from each parent and write them into the child
        mixed = cat((by_head_rows(p1_param)[from_p1], by_head_rows(p2_param)[from_p2]), dim = 0)
        dst.copy_(merge_rows(mixed))

    breed_rows(parent1.to_q.weight, parent2.to_q.weight, child.to_q.weight)
    breed_rows(parent1.to_k.weight, parent2.to_k.weight, child.to_k.weight)
    breed_rows(parent1.to_v.weight, parent2.to_v.weight, child.to_v.weight)
    breed_rows(parent1.qk_scale.scale, parent2.qk_scale.scale, child.qk_scale.scale)

    # the output projection consumes head outputs, so heads live on its columns

    p1_out = by_head_cols(parent1.to_out.weight)
    p2_out = by_head_cols(parent2.to_out.weight)
    child.to_out.weight.copy_(merge_cols(cat((p1_out[:, from_p1], p2_out[:, from_p2]), dim = 1)))

    return child
0 commit comments