Commit 1b7b596

committed
add model for rwkv6 (no adaptation has been implemented yet)
1 parent c319a21 commit 1b7b596

File tree

1 file changed

+377 -0 lines changed

vllm/model_executor/models/rwkv_6.py

Lines changed: 377 additions & 0 deletions
@@ -0,0 +1,377 @@
# coding=utf-8
"""PyTorch RWKV6 model (native PyTorch version).

author: @Zhiyuan Li
email:
date: 2024-07-22
"""
from dataclasses import dataclass
from typing import Dict, Iterable, List, Optional, Tuple

import torch
import torch.nn as nn
from transformers import RwkvConfig

from vllm.config import LoRAConfig, CacheConfig, SchedulerConfig
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.distributed import (get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               MergedColumnParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.models.interfaces import HasInnerState
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, SamplerOutput
from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE,
                                      _get_graph_batch_size)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader

# TorchScript aliases: blocks compile to ScriptModules, hot paths to
# script methods.
MyModule = torch.jit.ScriptModule
MyFunction = torch.jit.script_method
KVCache = Tuple[torch.Tensor, torch.Tensor]

@dataclass
class RwkvCacheParams:
    is_prompt: bool = False
    ssm_state: torch.Tensor = torch.Tensor()


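# Per-layer state layout (matches the indexing in the mixing methods below):
# each layer owns (2 + head_size) rows of the state tensor of shape
# [batch, num_layer * (2 + head_size), n_embd]:
#   row 0        -> channel-mixing token-shift state
#   row 1        -> time-mixing token-shift state
#   rows 2..S+1  -> per-head S x S WKV state (S = head_size), flattened
#                   to [S, n_embd]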
class Rwkv_Block(MyModule):
    def __init__(self, block_w: dict, hidden_size: int, n_head: int):
        super().__init__()
        self.hidden_size = hidden_size
        self.n_head = n_head
        self.head_size = hidden_size // n_head

        self.ln1 = nn.LayerNorm(hidden_size)
        self.ln1.weight = nn.Parameter(block_w['ln1.weight'])
        self.ln1.bias = nn.Parameter(block_w['ln1.bias'])
        self.ln2 = nn.LayerNorm(hidden_size)
        self.ln2.weight = nn.Parameter(block_w['ln2.weight'])
        self.ln2.bias = nn.Parameter(block_w['ln2.bias'])

        self.silu = nn.SiLU(inplace=False)

        self.att_time_maa_x = nn.Parameter(block_w['att.time_maa_x'])
        self.att_time_maa_w = nn.Parameter(block_w['att.time_maa_w'])
        self.att_time_maa_k = nn.Parameter(block_w['att.time_maa_k'])
        self.att_time_maa_v = nn.Parameter(block_w['att.time_maa_v'])
        self.att_time_maa_r = nn.Parameter(block_w['att.time_maa_r'])
        self.att_time_maa_g = nn.Parameter(block_w['att.time_maa_g'])
        self.att_time_maa_w1 = nn.Parameter(block_w['att.time_maa_w1'])
        self.att_time_maa_w2 = nn.Parameter(block_w['att.time_maa_w2'])
        self.att_time_decay = nn.Parameter(block_w['att.time_decay'])
        self.att_time_decay_w1 = nn.Parameter(block_w['att.time_decay_w1'])
        self.att_time_decay_w2 = nn.Parameter(block_w['att.time_decay_w2'])
        self.att_time_faaaa = nn.Parameter(block_w['att.time_faaaa'])
        self.att_receptance = nn.Linear(self.hidden_size, self.hidden_size,
                                        bias=False)
        self.att_receptance.weight = nn.Parameter(block_w['att.receptance.weight'])
        self.att_key = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.att_key.weight = nn.Parameter(block_w['att.key.weight'])
        self.att_value = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.att_value.weight = nn.Parameter(block_w['att.value.weight'])
        self.att_output = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.att_output.weight = nn.Parameter(block_w['att.output.weight'])
        self.att_gate = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.att_gate.weight = nn.Parameter(block_w['att.gate.weight'])

        self.att_group_norm = nn.GroupNorm(num_groups=n_head,
                                           num_channels=hidden_size,
                                           eps=1e-5, affine=True)
        self.att_group_norm.weight = nn.Parameter(block_w['att.ln_x.weight'])
        self.att_group_norm.bias = nn.Parameter(block_w['att.ln_x.bias'])

        self.ffn_time_maa_k = nn.Parameter(block_w['ffn.time_maa_k'])
        self.ffn_time_maa_r = nn.Parameter(block_w['ffn.time_maa_r'])
        self.ffn_key = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.ffn_key.weight = nn.Parameter(block_w['ffn.key.weight'])
        self.ffn_receptance = nn.Linear(self.hidden_size, self.hidden_size,
                                        bias=False)
        self.ffn_receptance.weight = nn.Parameter(block_w['ffn.receptance.weight'])
        self.ffn_value = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.ffn_value.weight = nn.Parameter(block_w['ffn.value.weight'])
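    # The time_maa_* tensors above drive the data-dependent token-shift
    # interpolation; time_decay (with time_decay_w1/w2) produces the
    # per-channel decay rates, and time_faaaa is the per-head "bonus" (u)
    # weighting the current token's contribution in the WKV update.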
    @MyFunction
    def channel_mixing(self, x: torch.Tensor, state: torch.Tensor,
                       i: int) -> Tuple[torch.Tensor, torch.Tensor]:
        # Token shift: mix the current input with the previous token,
        # kept at row (2 + head_size) * i of the state.
        i0 = (2 + self.head_size) * i + 0
        sx = state[:, i0] - x
        state[:, i0] = x
        xk = x + sx * self.ffn_time_maa_k
        xr = x + sx * self.ffn_time_maa_r
        # Squared-ReLU "key" activation gated by a sigmoid receptance.
        r = torch.sigmoid(self.ffn_receptance(xr))
        k = torch.relu(self.ffn_key(xk)).pow(2)
        output = r * self.ffn_value(k)
        return output, state

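    # The *_parallel variants below handle a whole prefill chunk at once:
    # inside the chunk the token-shift difference sx[t] = x[t-1] - x[t]
    # comes from the chunk itself; only position 0 reads the carried state.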
    @MyFunction
    def channel_mixing_parallel(self, x: torch.Tensor, state: torch.Tensor,
                                i: int) -> Tuple[torch.Tensor, torch.Tensor]:
        i0 = (2 + self.head_size) * i + 0

        # Token shift across the chunk; position 0 uses the carried state.
        sx_lerp = torch.empty_like(x)
        sx_lerp[:, 0] = state[:, i0] - x[:, 0]
        sx_lerp[:, 1:] = x[:, :-1] - x[:, 1:]

        state[:, i0] = x[:, -1]

        xk = x + sx_lerp * self.ffn_time_maa_k
        xr = x + sx_lerp * self.ffn_time_maa_r

        r = torch.sigmoid(self.ffn_receptance(xr))  # [batch, L, hidden_size]
        k = torch.relu(self.ffn_key(xk)).pow(2)

        output = r * self.ffn_value(k)
        return output, state

    def time_mixing(self, x: torch.Tensor, state: torch.Tensor,
                    i: int) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size, H, S = x.size(0), self.n_head, self.head_size
        x, state, g = self.time_mixing_jit(x, state, i, batch_size, H, S)

        x = self.time_mixing_jit2(x, g)

        return x, state

    @MyFunction
    def time_mixing_jit(self, x: torch.Tensor, state: torch.Tensor, i: int,
                        batch_size: int, H: int, S: int):
        i1 = (2 + S) * i + 1  # i is the block index

        # Token shift: row 1 of this layer's state holds the previous token.
        sx = state[:, i1] - x
        state[:, i1] = x

        # Data-dependent interpolation coefficients for w, k, v, r, g
        # (low-rank via time_maa_w1/time_maa_w2).
        xxx = x + sx * self.att_time_maa_x
        xxx = torch.tanh(xxx @ self.att_time_maa_w1).view(batch_size, 5, 1, -1)
        xxx = torch.matmul(xxx, self.att_time_maa_w2).view(batch_size, 5, -1)
        mw, mk, mv, mr, mg = xxx.unbind(dim=1)

        xw = x + sx * (self.att_time_maa_w + mw)
        xk = x + sx * (self.att_time_maa_k + mk)
        xv = x + sx * (self.att_time_maa_v + mv)
        xr = x + sx * (self.att_time_maa_r + mr)
        xg = x + sx * (self.att_time_maa_g + mg)

        # Calculate w, r, k, v, g.
        w = (self.att_time_decay +
             (torch.tanh(xw @ self.att_time_decay_w1) @ self.att_time_decay_w2))
        w = -torch.exp(w.view(batch_size, H, S, 1))

        r = self.att_receptance(xr).view(batch_size, H, 1, S)
        k = self.att_key(xk).view(batch_size, H, S, 1)
        v = self.att_value(xv).view(batch_size, H, 1, S)
        g = self.silu(self.att_gate(xg))

        # Update the per-head WKV state.
        s = state[:, (2+S)*i+2:(2+S)*(i+1), :].view(batch_size, H, S, S)
        a = k @ v
        x = r @ (self.att_time_faaaa * a + s)
        s = a + torch.exp(w) * s
        # Write the updated WKV state of the i-th layer back.
        state[:, (2+S)*i+2:(2+S)*(i+1), :] = s.view(batch_size, S, -1)
        return x, state, g

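    # Per head, the single-step WKV recurrence computed above is:
    #   a_t   = k_t v_t^T                  rank-1 update, [S, S]
    #   out_t = r_t (u * a_t + s_{t-1})    u = att_time_faaaa (the "bonus")
    #   s_t   = a_t + exp(w_t) * s_{t-1}   w_t < 0, so exp(w_t) in (0, 1)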
    @MyFunction
    def time_mixing_jit2(self, x: torch.Tensor, g: torch.Tensor):
        # Per-head group norm over the flattened heads, then output gating.
        return self.att_output(self.att_group_norm(x.flatten(start_dim=1)) * g)

    def time_mixing_parallel(self, x: torch.Tensor, state: torch.Tensor,
                             i: int) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size, L, H, S = x.size(0), x.size(1), self.n_head, self.head_size
        x, state, g = self.time_mixing_parallel_jit1(x, state, i,
                                                     batch_size, L, H, S)

        x = self.time_mixing_parallel_jit2(x, g, batch_size, L)

        return x, state

    @MyFunction
    def time_mixing_parallel_jit1(self, x: torch.Tensor, state: torch.Tensor,
                                  i: int, batch_size: int, L: int, H: int,
                                  S: int):
        i1 = (2 + S) * i + 1
        # Token shift across the chunk; position 0 uses the carried state.
        sx_lerp = torch.empty_like(x)
        sx_lerp[:, 0] = state[:, i1] - x[:, 0]
        sx_lerp[:, 1:] = x[:, :-1] - x[:, 1:]

        state[:, i1] = x[:, -1]

        xxx = x + sx_lerp * self.att_time_maa_x  # [batch, L, hidden_size]
        xxx = torch.tanh(xxx @ self.att_time_maa_w1).view(
            batch_size, L, 5, 1, -1)  # att_time_maa_w1: [hidden_size, 160]
        xxx = torch.matmul(xxx, self.att_time_maa_w2).view(
            batch_size, L, 5, -1)  # [batch, L, 5, hidden_size]

        mw, mk, mv, mr, mg = xxx.unbind(dim=2)  # each [batch, L, hidden_size]

        xw = x + sx_lerp * (self.att_time_maa_w + mw)
        xk = x + sx_lerp * (self.att_time_maa_k + mk)
        xv = x + sx_lerp * (self.att_time_maa_v + mv)
        xr = x + sx_lerp * (self.att_time_maa_r + mr)
        xg = x + sx_lerp * (self.att_time_maa_g + mg)

        w = (self.att_time_decay +
             (torch.tanh(xw @ self.att_time_decay_w1) @ self.att_time_decay_w2))
        w = -torch.exp(w.view(batch_size, L, H, S, 1))

        r = self.att_receptance(xr).view(batch_size, L, H, 1, S)
        k = self.att_key(xk).view(batch_size, L, H, S, 1)
        v = self.att_value(xv).view(batch_size, L, H, 1, S)
        g = self.silu(self.att_gate(xg))  # [batch, L, hidden_size]
        # TODO: apply a fused kernel here (CUDA or fla/Triton).

        w = torch.exp(w)
        s = state[:, (2+S)*i+2:(2+S)*(i+1)].view(batch_size, H, S, S)
        a = k @ v  # a: [batch_size, L, H, S, S]

        # Materialize the pre-update state for every position in the chunk.
        state_s = torch.zeros(batch_size, L, H, S, S,
                              dtype=x.dtype, device=x.device)
        state_s[:, 0] = s

        # Sequential scan over the chunk: s <- a_t + exp(w_t) * s.
        for l in range(L - 1):
            s = a[:, l] + w[:, l] * s
            state_s[:, l + 1] = s
        s = a[:, -1] + w[:, -1] * s

        state[:, (2+S)*i+2:(2+S)*(i+1)] = s.view(batch_size, S, -1)

        x = r @ (self.att_time_faaaa * a + state_s)
        return x, state, g

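    # Note: the scan above is sequential in L and materializes state_s of
    # shape [B, L, H, S, S], which dominates prefill memory; a fused kernel
    # (the TODO above) would avoid both costs.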
    @MyFunction
    def time_mixing_parallel_jit2(self, x: torch.Tensor, g: torch.Tensor,
                                  batch_size: int, L: int):
        return self.att_output(
            self.att_group_norm(
                x.flatten(start_dim=2).view(batch_size * L, -1)
            ).view(batch_size, L, -1) * g)

    @torch.no_grad()
    def forward(self, x: torch.Tensor, state: torch.Tensor,
                i: int) -> Tuple[torch.Tensor, torch.Tensor]:
        x_time, state = self.time_mixing(self.ln1(x), state, i)
        x = x + x_time
        x_channel, state = self.channel_mixing(self.ln2(x), state, i)
        x = x + x_channel

        return x, state

    @torch.no_grad()
    def forward_parallel(self, x: torch.Tensor, state: torch.Tensor,
                         i: int) -> Tuple[torch.Tensor, torch.Tensor]:
        # Time mixing
        x_time, state = self.time_mixing_parallel(self.ln1(x), state, i)
        x = x + x_time

        # Channel mixing
        x_channel, state = self.channel_mixing_parallel(self.ln2(x), state, i)
        x = x + x_channel

        return x, state

class RwkvModel(MyModule):
    def __init__(self, args: dict):
        super().__init__()
        self.args = args
        self.load_params()
        self.eval()

    def load_params(self, load_from_file: bool = True,
                    w: Optional[dict] = None):
        # TODO: load through vLLM's weight-loading utilities.
        if load_from_file:
            if not self.args['MODEL_NAME'].endswith('.pth'):
                self.args['MODEL_NAME'] += '.pth'
            w = torch.load(self.args['MODEL_NAME'], map_location="cpu")
        else:
            assert w is not None

        self.num_layer = 0
        for k in w.keys():
            if '.time_' in k: w[k] = w[k].squeeze()
            if '.time_faaaa' in k: w[k] = w[k].unsqueeze(-1)
            if "blocks" in k:
                self.num_layer = max(self.num_layer, int(k.split(".")[1]))
        self.num_layer += 1

        self.n_head = w['blocks.0.att.time_faaaa'].shape[0]
        self.n_embd = w['blocks.0.ln1.weight'].shape[0]
        self.head_size = self.n_embd // self.n_head
        self.state_size = [self.num_layer * (2 + self.head_size), self.n_embd]

        self.emb = nn.Embedding.from_pretrained(w['emb.weight'], freeze=True)

        self.ln0 = nn.LayerNorm(self.n_embd)
        self.ln0.weight = nn.Parameter(w['blocks.0.ln0.weight'])
        self.ln0.bias = nn.Parameter(w['blocks.0.ln0.bias'])

        self.blocks = nn.ModuleList()

        for i in range(self.num_layer):
            block_w = {k[len(f'blocks.{i}.'):]: v
                       for k, v in w.items() if f'blocks.{i}.' in k}
            self.blocks.append(Rwkv_Block(block_w, self.n_embd, self.n_head))

        self.ln_out = nn.LayerNorm(self.n_embd)
        self.ln_out.weight = nn.Parameter(w['ln_out.weight'])
        self.ln_out.bias = nn.Parameter(w['ln_out.bias'])

        self.head = nn.Linear(self.n_embd, self.args['vocab_size'], bias=False)
        self.head.weight = nn.Parameter(w['head.weight'])

    @torch.no_grad()
    def forward(self,
                input_ids: torch.Tensor,
                state: torch.Tensor,
                attn_metadata: AttentionMetadata,
                ) -> Tuple[torch.Tensor, torch.Tensor]:

        x = self.forward_jit1(input_ids)

        if attn_metadata.prefill_metadata is not None:
            # Prefill phase: process the whole prompt chunk in parallel.
            for i, block in enumerate(self.blocks):
                x, state = block.forward_parallel(x, state, i)
        else:
            # Decoding phase: one token at a time.
            for i, block in enumerate(self.blocks):
                x, state = block(x, state, i)

        x = self.forward_jit2(x)

        return x, state

    @MyFunction
    def forward_jit1(self, token: torch.Tensor) -> torch.Tensor:
        return self.ln0(self.emb(token))

    @MyFunction
    def forward_jit2(self, x: torch.Tensor) -> torch.Tensor:
        return self.head(self.ln_out(x))

    @torch.no_grad()
    def forward_parallel_slices(self,
                                input_ids: torch.Tensor,
                                state: torch.Tensor,
                                attn_metadata: AttentionMetadata,
                                slice_len: int = 64
                                ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Chunked prefill forward pass of the RWKV6 model.

        Args:
            input_ids (torch.Tensor): Input token ids, shape [batch, L].
            state (torch.Tensor): Hidden state tensor, shape
                [batch, state_size[0], n_embd].
            attn_metadata (AttentionMetadata): Metadata for the forward pass.
            slice_len (int): Chunk length; smaller values trade speed for
                lower peak memory.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Output tensor and updated state.
        """
        # FIXME!
        data_len = input_ids.shape[1]
        num_slices = (data_len + slice_len - 1) // slice_len  # ceil division
        for i in range(num_slices):
            start = i * slice_len
            end = min((i + 1) * slice_len, data_len)
            input_ids_ith_chunk = input_ids[:, start:end]
            token_out, state = self.forward(input_ids_ith_chunk, state,
                                            attn_metadata)

        return token_out, state

    def init_state(self, batch_size: int) -> torch.Tensor:
        state = torch.zeros(batch_size, self.state_size[0], self.state_size[1])
        return state
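
# Minimal usage sketch (hypothetical; the path and vocab size below are
# assumptions, and attn_metadata must be a real AttentionMetadata whose
# prefill_metadata is set during prefill):
#
#   args = {'MODEL_NAME': '/path/to/rwkv6', 'vocab_size': 65536}
#   model = RwkvModel(args)
#   state = model.init_state(batch_size=1)
#   input_ids = torch.randint(0, args['vocab_size'], (1, 128))
#   logits, state = model.forward_parallel_slices(
#       input_ids, state, attn_metadata, slice_len=64)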
