```diff
@@ -10,10 +10,8 @@
 # mypy: ignore-errors
 
 import os
-from typing import Optional
 
-import torch
-import vllm
+import vllm.model_executor.layers.fla.ops.fused_recurrent
 from vllm.triton_utils import tl, tldevice, triton
 
 if os.environ.get('FLA_USE_FAST_OPS', '0') == '1':
```
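The import hunk is load-bearing for the patch at the bottom of the file: the assignment writes an attribute on `vllm.model_executor.layers.fla.ops.fused_recurrent`, and a bare `import vllm` does not guarantee that submodule is loaded, while importing it explicitly does. The patch then takes effect because the kernel is launched through a module-global lookup at call time (the removed wrapper below shows exactly that pattern), so rebinding the attribute redirects every launch. A toy sketch of that mechanism in plain Python, with no vLLM dependency:

```python
# Toy model of the module-attribute monkey-patch used in this file:
# a wrapper that resolves `kernel` through its module globals picks up
# the rebound attribute on the next call.
import types

toy = types.ModuleType("toy")
exec(
    "def kernel(x):\n"
    "    return x + 1\n"
    "def wrapper(x):\n"
    "    return kernel(x)\n",
    toy.__dict__,
)

assert toy.wrapper(1) == 2
toy.kernel = lambda x: x + 100  # the monkey-patch: rebind the module global
assert toy.wrapper(1) == 101    # wrapper now dispatches to the replacement
```

The second hunk deletes the local copies of the Python wrapper stack (`fused_recurrent_gated_delta_rule_fwd`, `FusedRecurrentFunction`, and `fused_recurrent_gated_delta_rule`) and retargets the patch at the Triton kernel alone, so vLLM's own wrapper and its argument handling remain in use: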
```diff
@@ -179,227 +177,4 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
     # p_beta += HV * (V if IS_BETA_HEADWISE else 1)
 
 
-def fused_recurrent_gated_delta_rule_fwd(
-    q: torch.Tensor,
-    k: torch.Tensor,
-    v: torch.Tensor,
-    g: torch.Tensor,
-    beta: torch.Tensor,
-    scale: float,
-    initial_state: torch.Tensor,
-    inplace_final_state: bool = True,
-    cu_seqlens: Optional[torch.LongTensor] = None,
-    ssm_state_indices: Optional[torch.Tensor] = None,
-    num_accepted_tokens: Optional[torch.Tensor] = None,
-    use_qk_l2norm_in_kernel: bool = False,
-) -> tuple[torch.Tensor, torch.Tensor]:
-    B, T, H, K, V = *k.shape, v.shape[-1]
-    HV = v.shape[2]
-    N = B if cu_seqlens is None else len(cu_seqlens) - 1
-    BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 8)
-    NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
-    assert NK == 1, "NK > 1 is not supported yet"
-    num_stages = 3
-    num_warps = 1
-
-    o = q.new_empty(NK, *v.shape)
-    if inplace_final_state:
-        final_state = initial_state
-    else:
-        final_state = q.new_empty(T, HV, K, V, dtype=initial_state.dtype)
-
-    stride_init_state_token = initial_state.stride(0)
-    stride_final_state_token = final_state.stride(0)
-
-    if ssm_state_indices is None:
-        stride_indices_seq, stride_indices_tok = 1, 1
-    elif ssm_state_indices.ndim == 1:
-        stride_indices_seq, stride_indices_tok = ssm_state_indices.stride(0), 1
-    else:
-        stride_indices_seq, stride_indices_tok = ssm_state_indices.stride()
-
-    # print("N: ", N)
-    # print("T: ", T)
-    # print("B: ", B)
-    # print("H: ", H)
-    # print("HV: ", HV)
-    # print("K: ", K)
-    # print("V: ", V)
-    # print("BK: ", BK)
-    # print("BV: ", BV)
-
-    grid = (NK, NV, N * HV)
-    fused_recurrent_gated_delta_rule_fwd_kernel[grid](
-        q=q,
-        k=k,
-        v=v,
-        g=g,
-        beta=beta,
-        o=o,
-        h0=initial_state,
-        ht=final_state,
-        cu_seqlens=cu_seqlens,
-        ssm_state_indices=ssm_state_indices,
-        num_accepted_tokens=num_accepted_tokens,
-        scale=scale,
-        N=N,
-        T=T,
-        B=B,
-        H=H,
-        HV=HV,
-        K=K,
-        V=V,
-        BK=BK,
-        BV=BV,
-        stride_init_state_token=stride_init_state_token,
-        stride_final_state_token=stride_final_state_token,
-        stride_indices_seq=stride_indices_seq,
-        stride_indices_tok=stride_indices_tok,
-        IS_BETA_HEADWISE=beta.ndim == v.ndim,
-        USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,
-        INPLACE_FINAL_STATE=inplace_final_state,
-        num_warps=num_warps,
-        num_stages=num_stages,
-    )
-    o = o.squeeze(0)
-    return o, final_state
-
-
-class FusedRecurrentFunction(torch.autograd.Function):
-
-    @staticmethod
-    def forward(ctx,
-                q: torch.Tensor,
-                k: torch.Tensor,
-                v: torch.Tensor,
-                g: torch.Tensor,
-                beta: torch.Tensor,
-                scale: float,
-                initial_state: torch.Tensor,
-                inplace_final_state: bool = True,
-                cu_seqlens: Optional[torch.LongTensor] = None,
-                ssm_state_indices: Optional[torch.Tensor] = None,
-                num_accepted_tokens: Optional[torch.Tensor] = None,
-                use_qk_l2norm_in_kernel: bool = False):
-        o, final_state = fused_recurrent_gated_delta_rule_fwd(
-            q=q.contiguous(),
-            k=k.contiguous(),
-            v=v.contiguous(),
-            g=g.contiguous(),
-            beta=beta.contiguous(),
-            scale=scale,
-            initial_state=initial_state,
-            inplace_final_state=inplace_final_state,
-            cu_seqlens=cu_seqlens,
-            ssm_state_indices=ssm_state_indices,
-            num_accepted_tokens=num_accepted_tokens,
-            use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
-        )
-
-        return o, final_state
-
-
-def fused_recurrent_gated_delta_rule(
-    q: torch.Tensor,
-    k: torch.Tensor,
-    v: torch.Tensor,
-    g: torch.Tensor,
-    beta: torch.Tensor = None,
-    scale: float = None,
-    initial_state: torch.Tensor = None,
-    inplace_final_state: bool = True,
-    cu_seqlens: Optional[torch.LongTensor] = None,
-    ssm_state_indices: Optional[torch.Tensor] = None,
-    num_accepted_tokens: Optional[torch.Tensor] = None,
-    use_qk_l2norm_in_kernel: bool = False,
-) -> tuple[torch.Tensor, torch.Tensor]:
-    r"""
-    Args:
-        q (torch.Tensor):
-            queries of shape `[B, T, H, K]`.
-        k (torch.Tensor):
-            keys of shape `[B, T, H, K]`.
-        v (torch.Tensor):
-            values of shape `[B, T, HV, V]`.
-            GVA is applied if `HV > H`.
-        g (torch.Tensor):
-            g (decays) of shape `[B, T, HV]`.
-        beta (torch.Tensor):
-            betas of shape `[B, T, HV]`.
-        scale (Optional[int]):
-            Scale factor for the RetNet attention scores.
-            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
-        initial_state (Optional[torch.Tensor]):
-            Initial state of shape `[N, HV, K, V]` for `N` input sequences.
-            For equal-length input sequences, `N` equals the batch size `B`.
-            Default: `None`.
-        inplace_final_state: bool:
-            Whether to store the final state in-place to save memory.
-            Default: `True`.
-        cu_seqlens (torch.LongTensor):
-            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
-            consistent with the FlashAttention API.
-        ssm_state_indices (Optional[torch.Tensor]):
-            Indices to map the input sequences to the initial/final states.
-        num_accepted_tokens (Optional[torch.Tensor]):
-            Number of accepted tokens for each sequence during decoding.
-    Returns:
-        o (torch.Tensor):
-            Outputs of shape `[B, T, HV, V]`.
-        final_state (torch.Tensor):
-            Final state of shape `[N, HV, K, V]`.
-    Examples::
-        >>> import torch
-        >>> import torch.nn.functional as F
-        >>> from einops import rearrange
-        >>> from fla.ops.gated_delta_rule import fused_recurrent_gated_delta_rule
-        # inputs with equal lengths
-        >>> B, T, H, HV, K, V = 4, 2048, 4, 8, 512, 512
-        >>> q = torch.randn(B, T, H, K, device='cuda')
-        >>> k = F.normalize(torch.randn(B, T, H, K, device='cuda'), p=2, dim=-1)
-        >>> v = torch.randn(B, T, HV, V, device='cuda')
-        >>> g = F.logsigmoid(torch.rand(B, T, HV, device='cuda'))
-        >>> beta = torch.rand(B, T, HV, device='cuda').sigmoid()
-        >>> h0 = torch.randn(B, HV, K, V, device='cuda')
-        >>> o, ht = fused_gated_recurrent_delta_rule(
-            q, k, v, g, beta,
-            initial_state=h0,
-        )
-        # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
-        >>> q, k, v, g, beta = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, g, beta))
-        # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
-        >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
-        >>> o_var, ht_var = fused_gated_recurrent_delta_rule(
-            q, k, v, g, beta,
-            initial_state=h0,
-            cu_seqlens=cu_seqlens
-        )
-    """
-    if cu_seqlens is not None and q.shape[0] != 1:
-        raise ValueError(
-            f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
-            f"Please flatten variable-length inputs before processing.")
-    if scale is None:
-        scale = k.shape[-1]**-0.5
-    else:
-        assert scale > 0, "scale must be positive"
-    if beta is None:
-        beta = torch.ones_like(q[..., 0])
-    o, final_state = FusedRecurrentFunction.apply(
-        q,
-        k,
-        v,
-        g,
-        beta,
-        scale,
-        initial_state,
-        inplace_final_state,
-        cu_seqlens,
-        ssm_state_indices,
-        num_accepted_tokens,
-        use_qk_l2norm_in_kernel,
-    )
-    return o, final_state
-
-
-vllm.model_executor.layers.fla.ops.fused_recurrent.fused_recurrent_gated_delta_rule = fused_recurrent_gated_delta_rule
+vllm.model_executor.layers.fla.ops.fused_recurrent.fused_recurrent_gated_delta_rule_fwd_kernel = fused_recurrent_gated_delta_rule_fwd_kernel
```
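A quick way to confirm the retargeted patch landed, sketched under stated assumptions: `fast_ops` is a hypothetical import name for this file (the diff does not show its module path), vLLM must be installed, and the `FLA_USE_FAST_OPS` check at the top of the file suggests the fast kernels are opt-in, so the flag is set before anything imports the module (the body of that `if` is not shown in the diff):

```python
import os

# Assumption: the FLA_USE_FAST_OPS flag seen at the top of the diff gates
# the fast path, so opt in before the modules are imported.
os.environ["FLA_USE_FAST_OPS"] = "1"

import fast_ops  # hypothetical module name for the patched file above
import vllm.model_executor.layers.fla.ops.fused_recurrent as fr

# If the module-level assignment ran, vLLM now launches the local kernel.
assert (fr.fused_recurrent_gated_delta_rule_fwd_kernel
        is fast_ops.fused_recurrent_gated_delta_rule_fwd_kernel)
```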