
Commit 2805797

patch fused_recurrent_gated_delta_rule_fwd_kernel
Signed-off-by: Icey <1790571317@qq.com>
1 parent 33103dc commit 2805797
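
What the change does: before this commit, vllm_ascend/ops/sigmoid_gating.py re-implemented and replaced vLLM's entire fused_recurrent_gated_delta_rule stack (the forward driver, the autograd Function, and the public wrapper). After this commit it replaces only the Triton kernel, fused_recurrent_gated_delta_rule_fwd_kernel, and vLLM's own Python wrapper keeps doing the grid setup, state handling, and argument validation. A minimal sketch of the narrowed patch pattern (the module and attribute names come from the diff below; the stub kernel body is illustrative only, not the real Ascend kernel):

    import vllm.model_executor.layers.fla.ops.fused_recurrent as fused_recurrent

    def ascend_fwd_kernel_stub(*args, **kwargs):
        # Hypothetical stand-in for the @triton.jit kernel defined in
        # vllm_ascend/ops/sigmoid_gating.py.
        raise NotImplementedError

    # Rebind the module-level name; every later launch made by vLLM's wrapper
    # resolves this global at call time and therefore uses the replacement.
    fused_recurrent.fused_recurrent_gated_delta_rule_fwd_kernel = ascend_fwd_kernel_stub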

File tree

2 files changed: +3, -228 lines


vllm_ascend/ops/sigmoid_gating.py

Lines changed: 2 additions & 227 deletions
@@ -10,10 +10,8 @@
 # mypy: ignore-errors
 
 import os
-from typing import Optional
 
-import torch
-import vllm
+import vllm.model_executor.layers.fla.ops.fused_recurrent
 from vllm.triton_utils import tl, tldevice, triton
 
 if os.environ.get('FLA_USE_FAST_OPS', '0') == '1':
@@ -179,227 +177,4 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
     # p_beta += HV * (V if IS_BETA_HEADWISE else 1)
 
 
-def fused_recurrent_gated_delta_rule_fwd(
-    q: torch.Tensor,
-    k: torch.Tensor,
-    v: torch.Tensor,
-    g: torch.Tensor,
-    beta: torch.Tensor,
-    scale: float,
-    initial_state: torch.Tensor,
-    inplace_final_state: bool = True,
-    cu_seqlens: Optional[torch.LongTensor] = None,
-    ssm_state_indices: Optional[torch.Tensor] = None,
-    num_accepted_tokens: Optional[torch.Tensor] = None,
-    use_qk_l2norm_in_kernel: bool = False,
-) -> tuple[torch.Tensor, torch.Tensor]:
-    B, T, H, K, V = *k.shape, v.shape[-1]
-    HV = v.shape[2]
-    N = B if cu_seqlens is None else len(cu_seqlens) - 1
-    BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 8)
-    NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
-    assert NK == 1, "NK > 1 is not supported yet"
-    num_stages = 3
-    num_warps = 1
-
-    o = q.new_empty(NK, *v.shape)
-    if inplace_final_state:
-        final_state = initial_state
-    else:
-        final_state = q.new_empty(T, HV, K, V, dtype=initial_state.dtype)
-
-    stride_init_state_token = initial_state.stride(0)
-    stride_final_state_token = final_state.stride(0)
-
-    if ssm_state_indices is None:
-        stride_indices_seq, stride_indices_tok = 1, 1
-    elif ssm_state_indices.ndim == 1:
-        stride_indices_seq, stride_indices_tok = ssm_state_indices.stride(0), 1
-    else:
-        stride_indices_seq, stride_indices_tok = ssm_state_indices.stride()
-
-    # print("N: ", N)
-    # print("T: ", T)
-    # print("B: ", B)
-    # print("H: ", H)
-    # print("HV: ", HV)
-    # print("K: ", K)
-    # print("V: ", V)
-    # print("BK: ", BK)
-    # print("BV: ", BV)
-
-    grid = (NK, NV, N * HV)
-    fused_recurrent_gated_delta_rule_fwd_kernel[grid](
-        q=q,
-        k=k,
-        v=v,
-        g=g,
-        beta=beta,
-        o=o,
-        h0=initial_state,
-        ht=final_state,
-        cu_seqlens=cu_seqlens,
-        ssm_state_indices=ssm_state_indices,
-        num_accepted_tokens=num_accepted_tokens,
-        scale=scale,
-        N=N,
-        T=T,
-        B=B,
-        H=H,
-        HV=HV,
-        K=K,
-        V=V,
-        BK=BK,
-        BV=BV,
-        stride_init_state_token=stride_init_state_token,
-        stride_final_state_token=stride_final_state_token,
-        stride_indices_seq=stride_indices_seq,
-        stride_indices_tok=stride_indices_tok,
-        IS_BETA_HEADWISE=beta.ndim == v.ndim,
-        USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,
-        INPLACE_FINAL_STATE=inplace_final_state,
-        num_warps=num_warps,
-        num_stages=num_stages,
-    )
-    o = o.squeeze(0)
-    return o, final_state
-
-
-class FusedRecurrentFunction(torch.autograd.Function):
-
-    @staticmethod
-    def forward(ctx,
-                q: torch.Tensor,
-                k: torch.Tensor,
-                v: torch.Tensor,
-                g: torch.Tensor,
-                beta: torch.Tensor,
-                scale: float,
-                initial_state: torch.Tensor,
-                inplace_final_state: bool = True,
-                cu_seqlens: Optional[torch.LongTensor] = None,
-                ssm_state_indices: Optional[torch.Tensor] = None,
-                num_accepted_tokens: Optional[torch.Tensor] = None,
-                use_qk_l2norm_in_kernel: bool = False):
-        o, final_state = fused_recurrent_gated_delta_rule_fwd(
-            q=q.contiguous(),
-            k=k.contiguous(),
-            v=v.contiguous(),
-            g=g.contiguous(),
-            beta=beta.contiguous(),
-            scale=scale,
-            initial_state=initial_state,
-            inplace_final_state=inplace_final_state,
-            cu_seqlens=cu_seqlens,
-            ssm_state_indices=ssm_state_indices,
-            num_accepted_tokens=num_accepted_tokens,
-            use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
-        )
-
-        return o, final_state
-
-
-def fused_recurrent_gated_delta_rule(
-    q: torch.Tensor,
-    k: torch.Tensor,
-    v: torch.Tensor,
-    g: torch.Tensor,
-    beta: torch.Tensor = None,
-    scale: float = None,
-    initial_state: torch.Tensor = None,
-    inplace_final_state: bool = True,
-    cu_seqlens: Optional[torch.LongTensor] = None,
-    ssm_state_indices: Optional[torch.Tensor] = None,
-    num_accepted_tokens: Optional[torch.Tensor] = None,
-    use_qk_l2norm_in_kernel: bool = False,
-) -> tuple[torch.Tensor, torch.Tensor]:
-    r"""
-    Args:
-        q (torch.Tensor):
-            queries of shape `[B, T, H, K]`.
-        k (torch.Tensor):
-            keys of shape `[B, T, H, K]`.
-        v (torch.Tensor):
-            values of shape `[B, T, HV, V]`.
-            GVA is applied if `HV > H`.
-        g (torch.Tensor):
-            g (decays) of shape `[B, T, HV]`.
-        beta (torch.Tensor):
-            betas of shape `[B, T, HV]`.
-        scale (Optional[int]):
-            Scale factor for the RetNet attention scores.
-            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
-        initial_state (Optional[torch.Tensor]):
-            Initial state of shape `[N, HV, K, V]` for `N` input sequences.
-            For equal-length input sequences, `N` equals the batch size `B`.
-            Default: `None`.
-        inplace_final_state (bool):
-            Whether to store the final state in-place to save memory.
-            Default: `True`.
-        cu_seqlens (torch.LongTensor):
-            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
-            consistent with the FlashAttention API.
-        ssm_state_indices (Optional[torch.Tensor]):
-            Indices to map the input sequences to the initial/final states.
-        num_accepted_tokens (Optional[torch.Tensor]):
-            Number of accepted tokens for each sequence during decoding.
-    Returns:
-        o (torch.Tensor):
-            Outputs of shape `[B, T, HV, V]`.
-        final_state (torch.Tensor):
-            Final state of shape `[N, HV, K, V]`.
-    Examples::
-        >>> import torch
-        >>> import torch.nn.functional as F
-        >>> from einops import rearrange
-        >>> from fla.ops.gated_delta_rule import fused_recurrent_gated_delta_rule
-        # inputs with equal lengths
-        >>> B, T, H, HV, K, V = 4, 2048, 4, 8, 512, 512
-        >>> q = torch.randn(B, T, H, K, device='cuda')
-        >>> k = F.normalize(torch.randn(B, T, H, K, device='cuda'), p=2, dim=-1)
-        >>> v = torch.randn(B, T, HV, V, device='cuda')
-        >>> g = F.logsigmoid(torch.rand(B, T, HV, device='cuda'))
-        >>> beta = torch.rand(B, T, HV, device='cuda').sigmoid()
-        >>> h0 = torch.randn(B, HV, K, V, device='cuda')
-        >>> o, ht = fused_recurrent_gated_delta_rule(
-            q, k, v, g, beta,
-            initial_state=h0,
-        )
-        # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
-        >>> q, k, v, g, beta = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, g, beta))
-        # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
-        >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
-        >>> o_var, ht_var = fused_recurrent_gated_delta_rule(
-            q, k, v, g, beta,
-            initial_state=h0,
-            cu_seqlens=cu_seqlens
-        )
-    """
-    if cu_seqlens is not None and q.shape[0] != 1:
-        raise ValueError(
-            f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`. "
-            f"Please flatten variable-length inputs before processing.")
-    if scale is None:
-        scale = k.shape[-1]**-0.5
-    else:
-        assert scale > 0, "scale must be positive"
-    if beta is None:
-        beta = torch.ones_like(q[..., 0])
-    o, final_state = FusedRecurrentFunction.apply(
-        q,
-        k,
-        v,
-        g,
-        beta,
-        scale,
-        initial_state,
-        inplace_final_state,
-        cu_seqlens,
-        ssm_state_indices,
-        num_accepted_tokens,
-        use_qk_l2norm_in_kernel,
-    )
-    return o, final_state
-
-
-vllm.model_executor.layers.fla.ops.fused_recurrent.fused_recurrent_gated_delta_rule = fused_recurrent_gated_delta_rule
+vllm.model_executor.layers.fla.ops.fused_recurrent.fused_recurrent_gated_delta_rule_fwd_kernel = fused_recurrent_gated_delta_rule_fwd_kernel
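
Why rebinding a single module attribute is enough: a Python function looks up global names in its defining module's namespace at call time, not at definition time, so vLLM's fused_recurrent_gated_delta_rule_fwd picks up the replacement kernel on its next launch. A self-contained toy demonstration of that mechanism (toy module and names, not vLLM code):

    import types

    mod = types.ModuleType("toy")
    exec(
        "def kernel():\n"
        "    return 'original'\n"
        "def wrapper():\n"
        "    return kernel()\n",
        mod.__dict__,
    )

    assert mod.wrapper() == "original"
    mod.kernel = lambda: "patched"      # rebind the module-level name...
    assert mod.wrapper() == "patched"   # ...and wrapper() now calls the patch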

vllm_ascend/patch/worker/patch_common/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -17,6 +17,6 @@
 
 import vllm_ascend.patch.worker.patch_common.patch_distributed  # noqa
 import vllm_ascend.patch.worker.patch_common.patch_logits  # noqa
-
+import vllm_ascend.ops.sigmoid_gating
 # TODO: revert me when triton import is fixed
 # import vllm_ascend.patch.worker.patch_common.patch_minicpm  # noqa
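
The patch takes effect purely as an import side effect: importing vllm_ascend.ops.sigmoid_gating from patch_common/__init__.py runs the assignment above before any worker code calls into the fla ops. A quick sanity check, assuming both packages are installed (the assertion simply mirrors the assignment in the diff):

    import vllm.model_executor.layers.fla.ops.fused_recurrent as fused_recurrent
    import vllm_ascend.ops.sigmoid_gating as sigmoid_gating  # applies the patch on import

    # After the import, vLLM's module attribute is the Ascend kernel object.
    assert (fused_recurrent.fused_recurrent_gated_delta_rule_fwd_kernel
            is sigmoid_gating.fused_recurrent_gated_delta_rule_fwd_kernel)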
