# coding=utf-8
"""PyTorch RWKV6 model (native PyTorch version).

author: @Zhiyuan Li
email:
date: 2024-07-22
"""
from dataclasses import dataclass
from typing import Dict, Iterable, List, Optional, Tuple


import torch
import torch.nn as nn
from transformers import RwkvConfig
from vllm.config import LoRAConfig, CacheConfig, SchedulerConfig
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.distributed import (get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               MergedColumnParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.models.interfaces import HasInnerState
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, SamplerOutput
from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE,
                                      _get_graph_batch_size)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader


MyModule = torch.jit.ScriptModule
MyFunction = torch.jit.script_method
KVCache = Tuple[torch.Tensor, torch.Tensor]

@dataclass
class RwkvCacheParams:
    is_prompt: bool = False
    ssm_state: torch.Tensor = torch.Tensor()


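# Recurrent state layout used by the index arithmetic below: the state is a
# single tensor of shape [batch, num_layer * (2 + head_size), n_embd]. For
# layer i, row (2 + head_size)*i + 0 caches the previous token for channel
# mixing, row (2 + head_size)*i + 1 caches it for time mixing, and the
# remaining head_size rows hold the per-head WKV matrix state.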
class Rwkv_Block(MyModule):
    def __init__(self, block_w: dict, hidden_size: int, n_head: int):
        super().__init__()
        self.hidden_size = hidden_size
        self.n_head = n_head
        self.head_size = hidden_size // n_head

        self.ln1 = nn.LayerNorm(hidden_size)
        self.ln1.weight = nn.Parameter(block_w['ln1.weight'])
        self.ln1.bias = nn.Parameter(block_w['ln1.bias'])
        self.ln2 = nn.LayerNorm(hidden_size)
        self.ln2.weight = nn.Parameter(block_w['ln2.weight'])
        self.ln2.bias = nn.Parameter(block_w['ln2.bias'])

        self.silu = nn.SiLU(inplace=False)

        self.att_time_maa_x = nn.Parameter(block_w['att.time_maa_x'])
        self.att_time_maa_w = nn.Parameter(block_w['att.time_maa_w'])
        self.att_time_maa_k = nn.Parameter(block_w['att.time_maa_k'])
        self.att_time_maa_v = nn.Parameter(block_w['att.time_maa_v'])
        self.att_time_maa_r = nn.Parameter(block_w['att.time_maa_r'])
        self.att_time_maa_g = nn.Parameter(block_w['att.time_maa_g'])
        self.att_time_maa_w1 = nn.Parameter(block_w['att.time_maa_w1'])
        self.att_time_maa_w2 = nn.Parameter(block_w['att.time_maa_w2'])
        self.att_time_decay = nn.Parameter(block_w['att.time_decay'])
        self.att_time_decay_w1 = nn.Parameter(block_w['att.time_decay_w1'])
        self.att_time_decay_w2 = nn.Parameter(block_w['att.time_decay_w2'])
        self.att_time_faaaa = nn.Parameter(block_w['att.time_faaaa'])
        self.att_receptance = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.att_receptance.weight = nn.Parameter(block_w['att.receptance.weight'])
        self.att_key = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.att_key.weight = nn.Parameter(block_w['att.key.weight'])
        self.att_value = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.att_value.weight = nn.Parameter(block_w['att.value.weight'])
        self.att_output = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.att_output.weight = nn.Parameter(block_w['att.output.weight'])
        self.att_gate = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.att_gate.weight = nn.Parameter(block_w['att.gate.weight'])

        self.att_group_norm = nn.GroupNorm(num_groups=n_head, num_channels=hidden_size, eps=1e-5, affine=True)
        self.att_group_norm.weight = nn.Parameter(block_w['att.ln_x.weight'])
        self.att_group_norm.bias = nn.Parameter(block_w['att.ln_x.bias'])

        self.ffn_time_maa_k = nn.Parameter(block_w['ffn.time_maa_k'])
        self.ffn_time_maa_r = nn.Parameter(block_w['ffn.time_maa_r'])
        self.ffn_key = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.ffn_key.weight = nn.Parameter(block_w['ffn.key.weight'])
        self.ffn_receptance = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.ffn_receptance.weight = nn.Parameter(block_w['ffn.receptance.weight'])
        self.ffn_value = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.ffn_value.weight = nn.Parameter(block_w['ffn.value.weight'])

    @MyFunction
    def channel_mixing(self, x: torch.Tensor, state: torch.Tensor, i: int) -> Tuple[torch.Tensor, torch.Tensor]:
        i0 = (2 + self.head_size) * i + 0
        sx = state[:, i0] - x
        state[:, i0] = x
        xk = x + sx * self.ffn_time_maa_k
        xr = x + sx * self.ffn_time_maa_r
        r = torch.sigmoid(self.ffn_receptance(xr))
        k = torch.relu(self.ffn_key(xk)).pow(2)
        output = r * self.ffn_value(k)
        return output, state

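    # For reference, channel mixing computes, per token t, with token shift
    # sx = x_{t-1} - x_t, mu_k = ffn_time_maa_k and mu_r = ffn_time_maa_r:
    #     r_t   = sigmoid(W_r (x_t + sx * mu_r))
    #     k_t   = relu(W_k (x_t + sx * mu_k)) ** 2
    #     out_t = r_t * (W_v k_t)
    # i.e. RWKV6's squared-ReLU feed-forward behind a sigmoid gate.
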
    @MyFunction
    def channel_mixing_parallel(self, x: torch.Tensor, state: torch.Tensor, i: int) -> Tuple[torch.Tensor, torch.Tensor]:
        i0 = (2 + self.head_size) * i + 0

        # Token shift over the whole sequence: position 0 mixes with the cached
        # last token, positions 1..L-1 mix with their left neighbour.
        sx_lerp = torch.empty_like(x)
        sx_lerp[:, 0] = state[:, i0] - x[:, 0]
        sx_lerp[:, 1:] = x[:, :-1] - x[:, 1:]

        state[:, i0] = x[:, -1]

        xk = x + sx_lerp * self.ffn_time_maa_k
        xr = x + sx_lerp * self.ffn_time_maa_r

        r = torch.sigmoid(self.ffn_receptance(xr))  # [Batch, L, hidden_size]
        k = torch.relu(self.ffn_key(xk)).pow(2)

        output = r * self.ffn_value(k)
        return output, state

    def time_mixing(self, x: torch.Tensor, state: torch.Tensor, i: int) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size, H, S = x.size(0), self.n_head, self.head_size
        x, state, g = self.time_mixing_jit(x, state, i, batch_size, H, S)

        x = self.time_mixing_jit2(x, g)

        return x, state

    @MyFunction
    def time_mixing_jit(self, x: torch.Tensor, state: torch.Tensor, i: int,
                        batch_size: int, H: int, S: int):
        i1 = (2 + S) * i + 1  # i is the block number

        sx = state[:, i1] - x
        state[:, i1] = x  # Information is compressed to position 1 on each layer

        xxx = x + sx * self.att_time_maa_x
        xxx = torch.tanh(xxx @ self.att_time_maa_w1).view(batch_size, 5, 1, -1)
        xxx = torch.matmul(xxx, self.att_time_maa_w2).view(batch_size, 5, -1)
        mw, mk, mv, mr, mg = xxx.unbind(dim=1)

        xw = x + sx * (self.att_time_maa_w + mw)
        xk = x + sx * (self.att_time_maa_k + mk)
        xv = x + sx * (self.att_time_maa_v + mv)
        xr = x + sx * (self.att_time_maa_r + mr)
        xg = x + sx * (self.att_time_maa_g + mg)

        # calculate w, r, k, v, g
        w = (self.att_time_decay + (torch.tanh(xw @ self.att_time_decay_w1) @ self.att_time_decay_w2))
        w = -torch.exp(w.view(batch_size, H, S, 1))

        r = self.att_receptance(xr).view(batch_size, H, 1, S)
        k = self.att_key(xk).view(batch_size, H, S, 1)
        v = self.att_value(xv).view(batch_size, H, 1, S)
        g = self.silu(self.att_gate(xg))

        # Update state using attention mechanism
        s = state[:, (2+S)*i+2:(2+S)*(i+1), :].view(batch_size, H, S, S)
        a = k @ v
        x = r @ (self.att_time_faaaa * a + s)
        s = a + torch.exp(w) * s
        # Update the attention parameters of the i-th layer STATE
        state[:, (2+S)*i+2:(2+S)*(i+1), :] = s.view(batch_size, S, -1)
        return x, state, g

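    # For reference, the single-step WKV update above is, per head, with
    # a = k v^T (outer product), u = att_time_faaaa and decay d = exp(w):
    #     out = r @ (u * a + s)
    #     s'  = a + d * s
    # which is the RWKV6 linear-attention recurrence for one timestep.
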
    @MyFunction
    def time_mixing_jit2(self, x: torch.Tensor, g: torch.Tensor):
        return self.att_output(self.att_group_norm(x.flatten(start_dim=1)) * g)

    def time_mixing_parallel(self, x: torch.Tensor, state: torch.Tensor, i: int) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size, L, H, S = x.size(0), x.size(1), self.n_head, self.head_size
        x, state, g = self.time_mixing_parallel_jit1(x, state, i, batch_size, L, H, S)

        x = self.time_mixing_parallel_jit2(x, g, batch_size, L)

        return x, state

    @MyFunction
    def time_mixing_parallel_jit1(self, x: torch.Tensor, state: torch.Tensor, i: int,
                                  batch_size: int, L: int, H: int, S: int):
        i1 = (2 + S) * i + 1
        sx_lerp = torch.empty_like(x)
        sx_lerp[:, 0] = state[:, i1] - x[:, 0]
        sx_lerp[:, 1:] = x[:, :-1] - x[:, 1:]

        state[:, i1] = x[:, -1]

        xxx = x + sx_lerp * self.att_time_maa_x  # [Batch, L, hidden_size]
        xxx = torch.tanh(xxx @ self.att_time_maa_w1).view(batch_size, L, 5, 1, -1)  # att_time_maa_w1: [hidden_size, 160]
        xxx = torch.matmul(xxx, self.att_time_maa_w2).view(batch_size, L, 5, -1)  # [Batch, L, 5, hidden_size]

        mw, mk, mv, mr, mg = xxx.unbind(dim=2)  # each [Batch, L, hidden_size]

        xw = x + sx_lerp * (self.att_time_maa_w + mw)  # [Batch, L, hidden_size]
        xk = x + sx_lerp * (self.att_time_maa_k + mk)
        xv = x + sx_lerp * (self.att_time_maa_v + mv)
        xr = x + sx_lerp * (self.att_time_maa_r + mr)
        xg = x + sx_lerp * (self.att_time_maa_g + mg)

        w = (self.att_time_decay + (torch.tanh(xw @ self.att_time_decay_w1) @ self.att_time_decay_w2))
        w = -torch.exp(w.view(batch_size, L, H, S, 1))

        r = self.att_receptance(xr).view(batch_size, L, H, 1, S)
        k = self.att_key(xk).view(batch_size, L, H, S, 1)
        v = self.att_value(xv).view(batch_size, L, H, 1, S)
        g = self.silu(self.att_gate(xg))  # [Batch, L, hidden_size]
        # TODO: apply a fused kernel here (CUDA or fla/triton)

        w = torch.exp(w)
        s = state[:, (2+S)*i+2:(2+S)*(i+1)].view(batch_size, H, S, S)
        a = k @ v  # a: [batch_size, L, H, S, S]

        state_s = torch.zeros(batch_size, L, H, S, S, dtype=x.dtype, device=x.device)
        state_s[:, 0] = s

        for l in range(L-1):
            s = a[:, l] + w[:, l] * s
            state_s[:, l+1] = s
        s = a[:, -1] + w[:, -1] * s

        state[:, (2+S)*i+2:(2+S)*(i+1)] = s.view(batch_size, S, -1)

        x = r @ (self.att_time_faaaa * a + state_s)
        return x, state, g

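    # Note: the Python loop over L above is the only strictly sequential part
    # of the prefill path; it is what the fused kernel mentioned in the TODO
    # (CUDA or fla/triton) would replace. The projections, token shift and
    # decay computation before it are already batched over the sequence.
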
    @MyFunction
    def time_mixing_parallel_jit2(self, x: torch.Tensor, g: torch.Tensor, batch_size: int, L: int):
        return self.att_output(self.att_group_norm(x.flatten(start_dim=2).view(batch_size * L, -1)).view(batch_size, L, -1) * g)

    @torch.no_grad()
    def forward(self, x: torch.Tensor, state: torch.Tensor, i: int) -> Tuple[torch.Tensor, torch.Tensor]:
        x_time, state = self.time_mixing(self.ln1(x), state, i)
        x = x + x_time
        x_channel, state = self.channel_mixing(self.ln2(x), state, i)
        x = x + x_channel

        return x, state

    @torch.no_grad()
    def forward_parallel(self, x: torch.Tensor, state: torch.Tensor, i: int) -> Tuple[torch.Tensor, torch.Tensor]:
        # Time mixing
        x_time, state = self.time_mixing_parallel(self.ln1(x), state, i)
        x = x + x_time

        # Channel mixing
        x_channel, state = self.channel_mixing_parallel(self.ln2(x), state, i)
        x = x + x_channel

        return x, state

class RwkvModel(MyModule):
    def __init__(self, args: dict):
        super().__init__()
        self.args = args
        self.load_params()
        self.eval()

    def load_params(self, load_from_file: bool = True, w: Optional[dict] = None):
        # TODO: vllm
        if load_from_file:
            if not self.args['MODEL_NAME'].endswith('.pth'):
                self.args['MODEL_NAME'] += '.pth'
            w = torch.load(self.args['MODEL_NAME'], map_location="cpu")
        else:
            assert w is not None

        self.num_layer = 0
        for k in w.keys():
            if '.time_' in k: w[k] = w[k].squeeze()
            if '.time_faaaa' in k: w[k] = w[k].unsqueeze(-1)
            if "blocks" in k: self.num_layer = max(self.num_layer, int(k.split(".")[1]))
        self.num_layer += 1

        self.n_head = w['blocks.0.att.time_faaaa'].shape[0]
        self.n_embd = w['blocks.0.ln1.weight'].shape[0]
        self.head_size = self.n_embd // self.n_head
        self.state_size = [self.num_layer * (2 + self.head_size), self.n_embd]

        self.emb = nn.Embedding.from_pretrained(w['emb.weight'], freeze=True)

        self.ln0 = nn.LayerNorm(self.n_embd)
        self.ln0.weight = nn.Parameter(w['blocks.0.ln0.weight'])
        self.ln0.bias = nn.Parameter(w['blocks.0.ln0.bias'])

        self.blocks = nn.ModuleList()

        for i in range(self.num_layer):
            # Collect this block's weights with the 'blocks.{i}.' prefix stripped.
            block_w = {k[len(f'blocks.{i}.'):]: v for k, v in w.items() if f'blocks.{i}.' in k}
            self.blocks.append(Rwkv_Block(block_w, self.n_embd, self.n_head))

        self.ln_out = nn.LayerNorm(self.n_embd)
        self.ln_out.weight = nn.Parameter(w['ln_out.weight'])
        self.ln_out.bias = nn.Parameter(w['ln_out.bias'])

        self.head = nn.Linear(self.n_embd, self.args['vocab_size'], bias=False)
        self.head.weight = nn.Parameter(w['head.weight'])

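    # Worked example of the state footprint (hypothetical dimensions, not tied
    # to any particular checkpoint): with num_layer=24, n_embd=2048 and
    # n_head=32 (head_size=64), state_size is [24 * (2 + 64), 2048] =
    # [1584, 2048], i.e. roughly 3.2M values per sequence.
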
    @torch.no_grad()
    def forward(self,
                input_ids: torch.Tensor,
                state: torch.Tensor,
                attn_metadata: AttentionMetadata,
                ) -> Tuple[torch.Tensor, torch.Tensor]:

        x = self.forward_jit1(input_ids)

        if attn_metadata.prefill_metadata is not None:
            # Prefill phase
            for i, block in enumerate(self.blocks):
                x, state = block.forward_parallel(x, state, i)
        else:
            # Decoding phase
            for i, block in enumerate(self.blocks):
                x, state = block(x, state, i)

        x = self.forward_jit2(x)

        return x, state

    @MyFunction
    def forward_jit1(self, token: torch.Tensor) -> torch.Tensor:
        return self.ln0(self.emb(token))

    @MyFunction
    def forward_jit2(self, x: torch.Tensor) -> torch.Tensor:
        return self.head(self.ln_out(x))

    @torch.no_grad()
    def forward_parallel_slices(self,
                                input_ids: torch.Tensor,
                                state: torch.Tensor,
                                attn_metadata: AttentionMetadata,
                                slice_len: int = 64) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Prefill forward of the RWKV6 model in chunks.
        Args:
            input_ids (torch.Tensor): Input token ids, shape [Batch, L].
            state (torch.Tensor): Hidden state tensor, shape [Batch, State Size, N_embd].
            attn_metadata (AttentionMetadata): Metadata distinguishing prefill from decode.
            slice_len (int): Number of tokens processed per chunk.
        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Logits of the last chunk and the updated state.
        """
        # FIXME!
        data_len = input_ids.shape[1]
        num_slices = (data_len + slice_len - 1) // slice_len  # ceil division
        for i in range(num_slices):
            start = i * slice_len
            end = min((i + 1) * slice_len, data_len)
            input_ids_ith_chunk = input_ids[:, start:end]
            token_out, state = self.forward(input_ids_ith_chunk, state, attn_metadata)

        return token_out, state

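    # Design note: chunking the prefill keeps the [Batch, slice_len, ...]
    # activations of the parallel path bounded regardless of prompt length;
    # the recurrent state carries context across chunks, so only the last
    # chunk's logits are returned.
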
    def init_state(self, batch_size: int) -> torch.Tensor:
        state = torch.zeros(batch_size, self.state_size[0], self.state_size[1])
        return state

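# Usage sketch (illustrative only; the checkpoint path and vocab size are
# hypothetical placeholders, and attn_metadata must come from the serving
# engine so that prefill_metadata is non-None for prefill and None for decode):
#
#   args = {'MODEL_NAME': '/path/to/rwkv6_checkpoint', 'vocab_size': 65536}
#   model = RwkvModel(args)
#   state = model.init_state(batch_size=1)
#   prompt_ids = torch.randint(0, args['vocab_size'], (1, 32))
#   logits, state = model.forward_parallel_slices(prompt_ids, state, attn_metadata)
#   next_id = logits[:, -1].argmax(dim=-1, keepdim=True)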