Skip to content

Commit 33103dc

Browse files
committed
fix
Signed-off-by: Icey <1790571317@qq.com>
1 parent a1abfb2 commit 33103dc

File tree

1 file changed

+227
-2
lines changed

1 file changed

+227
-2
lines changed

vllm_ascend/ops/sigmoid_gating.py

Lines changed: 227 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,10 @@
1010
# mypy: ignore-errors
1111

1212
import os
13+
from typing import Optional
1314

14-
from vllm.model_executor.layers.fla.ops import fused_recurrent
15+
import torch
16+
import vllm
1517
from vllm.triton_utils import tl, tldevice, triton
1618

1719
if os.environ.get('FLA_USE_FAST_OPS', '0') == '1':
@@ -177,4 +179,227 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
177179
# p_beta += HV * (V if IS_BETA_HEADWISE else 1)
178180

179181

180-
fused_recurrent.fused_recurrent_gated_delta_rule_fwd_kernel = fused_recurrent_gated_delta_rule_fwd_kernel
182+
def fused_recurrent_gated_delta_rule_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    scale: float,
    initial_state: torch.Tensor,
    inplace_final_state: bool = True,
    cu_seqlens: Optional[torch.LongTensor] = None,
    ssm_state_indices: Optional[torch.Tensor] = None,
    num_accepted_tokens: Optional[torch.Tensor] = None,
    use_qk_l2norm_in_kernel: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Allocate buffers and launch ``fused_recurrent_gated_delta_rule_fwd_kernel``.

    Args:
        q: Query tensor; the output buffer inherits its dtype/device.
        k: Key tensor; must be 4-D, its shape is read as ``(B, T, H, K)``.
        v: Value tensor; ``v.shape[2]`` is read as ``HV`` and
            ``v.shape[-1]`` as ``V``.
        g: Gate tensor, forwarded to the kernel untouched.
        beta: Beta tensor, forwarded to the kernel untouched. The kernel
            flag ``IS_BETA_HEADWISE`` is set when ``beta.ndim == v.ndim``.
        scale: Scalar forwarded to the kernel.
        initial_state: State tensor passed to the kernel as ``h0``. It is
            dereferenced unconditionally (``stride(0)``), so it must not be
            ``None``.
        inplace_final_state: When true, ``initial_state`` itself is handed
            to the kernel as the final-state buffer; otherwise a fresh
            ``(T, HV, K, V)`` buffer of the same dtype is allocated.
        cu_seqlens: Optional cumulative sequence lengths; when given, the
            number of sequences is ``len(cu_seqlens) - 1`` instead of ``B``.
        ssm_state_indices: Optional 1-D or 2-D index tensor whose strides
            are forwarded to the kernel.
        num_accepted_tokens: Optional tensor forwarded to the kernel.
        use_qk_l2norm_in_kernel: Forwarded as ``USE_QK_L2NORM_IN_KERNEL``.

    Returns:
        ``(o, final_state)`` where ``o`` has the shape of ``v`` and
        ``final_state`` is the buffer the kernel received as ``ht``.
    """
    # Shape bookkeeping.
    V = v.shape[-1]
    B, T, H, K = k.shape
    HV = v.shape[2]
    if cu_seqlens is None:
        N = B
    else:
        N = len(cu_seqlens) - 1

    # Tile sizes: the whole key dimension in one tile, value tiles capped at 8.
    BK = triton.next_power_of_2(K)
    BV = min(triton.next_power_of_2(V), 8)
    NK = triton.cdiv(K, BK)
    NV = triton.cdiv(V, BV)
    assert NK == 1, "NK > 1 is not supported yet"

    # Output carries a leading NK axis that is squeezed away after the launch.
    o = q.new_empty(NK, *v.shape)
    final_state = (initial_state if inplace_final_state else q.new_empty(
        T, HV, K, V, dtype=initial_state.dtype))

    # Strides of the optional state-index tensor (placeholders when absent).
    if ssm_state_indices is None:
        index_strides = (1, 1)
    elif ssm_state_indices.ndim == 1:
        index_strides = (ssm_state_indices.stride(0), 1)
    else:
        index_strides = ssm_state_indices.stride()
    stride_indices_seq, stride_indices_tok = index_strides

    # One program per (key tile, value tile, sequence x value head).
    launch_grid = (NK, NV, N * HV)
    fused_recurrent_gated_delta_rule_fwd_kernel[launch_grid](
        # tensors
        q=q,
        k=k,
        v=v,
        g=g,
        beta=beta,
        o=o,
        h0=initial_state,
        ht=final_state,
        cu_seqlens=cu_seqlens,
        ssm_state_indices=ssm_state_indices,
        num_accepted_tokens=num_accepted_tokens,
        # scalars and sizes
        scale=scale,
        N=N,
        T=T,
        B=B,
        H=H,
        HV=HV,
        K=K,
        V=V,
        BK=BK,
        BV=BV,
        # strides
        stride_init_state_token=initial_state.stride(0),
        stride_final_state_token=final_state.stride(0),
        stride_indices_seq=stride_indices_seq,
        stride_indices_tok=stride_indices_tok,
        # compile-time switches
        IS_BETA_HEADWISE=beta.ndim == v.ndim,
        USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,
        INPLACE_FINAL_STATE=inplace_final_state,
        # launch configuration
        num_warps=1,
        num_stages=3,
    )
    return o.squeeze(0), final_state
266+
267+
268+
class FusedRecurrentFunction(torch.autograd.Function):
    """Autograd wrapper around ``fused_recurrent_gated_delta_rule_fwd``.

    Only ``forward`` is defined on this class.
    """

    @staticmethod
    def forward(ctx,
                q: torch.Tensor,
                k: torch.Tensor,
                v: torch.Tensor,
                g: torch.Tensor,
                beta: torch.Tensor,
                scale: float,
                initial_state: torch.Tensor,
                inplace_final_state: bool = True,
                cu_seqlens: Optional[torch.LongTensor] = None,
                ssm_state_indices: Optional[torch.Tensor] = None,
                num_accepted_tokens: Optional[torch.Tensor] = None,
                use_qk_l2norm_in_kernel: bool = False):
        """Make the five input tensors contiguous and run the forward helper.

        ``ctx`` is not used. All remaining arguments are forwarded
        unchanged. Returns the ``(o, final_state)`` pair produced by
        ``fused_recurrent_gated_delta_rule_fwd``.
        """
        # The helper receives contiguous copies/views of the main operands;
        # the state and index tensors are passed through as given.
        q, k, v, g, beta = (t.contiguous() for t in (q, k, v, g, beta))
        o, final_state = fused_recurrent_gated_delta_rule_fwd(
            q=q,
            k=k,
            v=v,
            g=g,
            beta=beta,
            scale=scale,
            initial_state=initial_state,
            inplace_final_state=inplace_final_state,
            cu_seqlens=cu_seqlens,
            ssm_state_indices=ssm_state_indices,
            num_accepted_tokens=num_accepted_tokens,
            use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
        )
        return o, final_state
300+
301+
302+
def fused_recurrent_gated_delta_rule(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: Optional[torch.Tensor] = None,
    scale: Optional[float] = None,
    initial_state: Optional[torch.Tensor] = None,
    inplace_final_state: bool = True,
    cu_seqlens: Optional[torch.LongTensor] = None,
    ssm_state_indices: Optional[torch.Tensor] = None,
    num_accepted_tokens: Optional[torch.Tensor] = None,
    use_qk_l2norm_in_kernel: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    r"""
    Validate the arguments, fill in defaults and run the fused recurrent
    gated delta rule through ``FusedRecurrentFunction``.

    Args:
        q (torch.Tensor):
            queries of shape `[B, T, H, K]`.
        k (torch.Tensor):
            keys of shape `[B, T, H, K]`.
        v (torch.Tensor):
            values of shape `[B, T, HV, V]`.
            GVA is applied if `HV > H`.
        g (torch.Tensor):
            g (decays) of shape `[B, T, HV]`.
        beta (Optional[torch.Tensor]):
            betas of shape `[B, T, HV]`.
            If not provided, a tensor of ones of shape `v.shape[:-1]`
            (i.e. `[B, T, HV]`) with the dtype/device of `q` is used.
        scale (Optional[float]):
            Scale factor for the attention scores.
            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
            An explicit value must be positive.
        initial_state (Optional[torch.Tensor]):
            Initial state of shape `[N, HV, K, V]` for `N` input sequences.
            For equal-length input sequences, `N` equals the batch size `B`.
            Default: `None`.
            NOTE: the forward helper in this file reads
            `initial_state.stride(0)` unconditionally, so `None` is not
            actually usable with the current implementation.
        inplace_final_state: bool:
            Whether to store the final state in-place to save memory.
            Default: `True`.
        cu_seqlens (torch.LongTensor):
            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
            consistent with the FlashAttention API.
            When given, the batch size of `q` must be 1.
        ssm_state_indices (Optional[torch.Tensor]):
            Indices to map the input sequences to the initial/final states.
        num_accepted_tokens (Optional[torch.Tensor]):
            Number of accepted tokens for each sequence during decoding.
        use_qk_l2norm_in_kernel (bool):
            Forwarded to the kernel launch as `USE_QK_L2NORM_IN_KERNEL`.
            Default: `False`.
    Returns:
        o (torch.Tensor):
            Outputs of shape `[B, T, HV, V]`.
        final_state (torch.Tensor):
            Final state of shape `[N, HV, K, V]`.
    Raises:
        ValueError:
            If `cu_seqlens` is given and `q.shape[0] != 1`.
        AssertionError:
            If an explicit `scale` is not positive.
    Examples::
        >>> import torch
        >>> import torch.nn.functional as F
        >>> from einops import rearrange
        >>> from vllm_ascend.ops.sigmoid_gating import fused_recurrent_gated_delta_rule
        # inputs with equal lengths
        >>> B, T, H, HV, K, V = 4, 2048, 4, 8, 512, 512
        >>> q = torch.randn(B, T, H, K, device='cuda')
        >>> k = F.normalize(torch.randn(B, T, H, K, device='cuda'), p=2, dim=-1)
        >>> v = torch.randn(B, T, HV, V, device='cuda')
        >>> g = F.logsigmoid(torch.rand(B, T, HV, device='cuda'))
        >>> beta = torch.rand(B, T, HV, device='cuda').sigmoid()
        >>> h0 = torch.randn(B, HV, K, V, device='cuda')
        >>> o, ht = fused_recurrent_gated_delta_rule(
            q, k, v, g, beta,
            initial_state=h0,
        )
        # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
        >>> q, k, v, g, beta = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, g, beta))
        # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
        >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
        >>> o_var, ht_var = fused_recurrent_gated_delta_rule(
            q, k, v, g, beta,
            initial_state=h0,
            cu_seqlens=cu_seqlens
        )
    """
    if cu_seqlens is not None and q.shape[0] != 1:
        raise ValueError(
            f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`. "
            "Please flatten variable-length inputs before processing.")
    if scale is None:
        scale = k.shape[-1]**-0.5
    else:
        assert scale > 0, "scale must be positive"
    if beta is None:
        # Beta is indexed per *value* head (HV), which differs from the
        # query head count H under GVA, so derive the shape from `v`
        # rather than from `q`. Identical to the old default when HV == H.
        beta = q.new_ones(v.shape[:-1])
    o, final_state = FusedRecurrentFunction.apply(
        q,
        k,
        v,
        g,
        beta,
        scale,
        initial_state,
        inplace_final_state,
        cu_seqlens,
        ssm_state_indices,
        num_accepted_tokens,
        use_qk_l2norm_in_kernel,
    )
    return o, final_state
403+
404+
405+
# Monkey-patch vLLM's FLA `fused_recurrent` module so that lookups of
# `fused_recurrent_gated_delta_rule` through that module resolve to the
# implementation defined in this file.
# NOTE(review): this attribute chain only works if the
# `vllm.model_executor.layers.fla.ops.fused_recurrent` submodule has already
# been imported somewhere; `import vllm` by itself may not load it - confirm.
# NOTE(review): code that bound the original function with
# `from ... import fused_recurrent_gated_delta_rule` before this line runs
# presumably keeps the old reference - verify against the call sites.
vllm.model_executor.layers.fla.ops.fused_recurrent.fused_recurrent_gated_delta_rule = fused_recurrent_gated_delta_rule

0 commit comments

Comments
 (0)