# mypy: ignore-errors

import os
+from typing import Optional

-from vllm.model_executor.layers.fla.ops import fused_recurrent
+import torch
+import vllm
from vllm.triton_utils import tl, tldevice, triton

if os.environ.get('FLA_USE_FAST_OPS', '0') == '1':
@@ -177,4 +179,227 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
    # p_beta += HV * (V if IS_BETA_HEADWISE else 1)


-fused_recurrent.fused_recurrent_gated_delta_rule_fwd_kernel = fused_recurrent_gated_delta_rule_fwd_kernel
+def fused_recurrent_gated_delta_rule_fwd(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    g: torch.Tensor,
+    beta: torch.Tensor,
+    scale: float,
+    initial_state: torch.Tensor,
+    inplace_final_state: bool = True,
+    cu_seqlens: Optional[torch.LongTensor] = None,
+    ssm_state_indices: Optional[torch.Tensor] = None,
+    num_accepted_tokens: Optional[torch.Tensor] = None,
+    use_qk_l2norm_in_kernel: bool = False,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    # B: batch, T: tokens, H: query/key heads, HV: value heads, K/V: head dims
+    B, T, H, K, V = *k.shape, v.shape[-1]
+    HV = v.shape[2]
+    # one logical sequence per batch entry, or per cu_seqlens segment
+    N = B if cu_seqlens is None else len(cu_seqlens) - 1
+    # the whole K dimension fits in one block; V is tiled in blocks of at most 8
+    BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 8)
+    NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
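+    # worked sizing example (illustrative values, not from this commit):
+    # for K = 128, V = 128 -> BK = next_power_of_2(128) = 128 and
+    # BV = min(128, 8) = 8, giving NK = 1 and NV = 16, i.e. K is handled
+    # by a single block while V is split into 16 tiles across programs.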
+    assert NK == 1, "NK > 1 is not supported yet"
+    num_stages = 3
+    num_warps = 1
+
+    o = q.new_empty(NK, *v.shape)
+    if inplace_final_state:
+        final_state = initial_state
+    else:
+        # one state slot per token, addressed through ssm_state_indices
+        final_state = q.new_empty(T, HV, K, V, dtype=initial_state.dtype)
+
+    stride_init_state_token = initial_state.stride(0)
+    stride_final_state_token = final_state.stride(0)
+
+    if ssm_state_indices is None:
+        stride_indices_seq, stride_indices_tok = 1, 1
+    elif ssm_state_indices.ndim == 1:
+        stride_indices_seq, stride_indices_tok = ssm_state_indices.stride(0), 1
+    else:
+        stride_indices_seq, stride_indices_tok = ssm_state_indices.stride()
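+    # ssm_state_indices is assumed to come in one of two layouts: a 1-D
+    # tensor with one state-slot index per sequence, or a 2-D tensor with
+    # one slot per (sequence, token), e.g. for speculative decoding; the
+    # two strides above let the kernel address either layout uniformly.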
+
+    # debug aid: print(N, T, B, H, HV, K, V, BK, BV)
+
+    # one program per (K block, V block, sequence × value head)
+    grid = (NK, NV, N * HV)
+    fused_recurrent_gated_delta_rule_fwd_kernel[grid](
+        q=q,
+        k=k,
+        v=v,
+        g=g,
+        beta=beta,
+        o=o,
+        h0=initial_state,
+        ht=final_state,
+        cu_seqlens=cu_seqlens,
+        ssm_state_indices=ssm_state_indices,
+        num_accepted_tokens=num_accepted_tokens,
+        scale=scale,
+        N=N,
+        T=T,
+        B=B,
+        H=H,
+        HV=HV,
+        K=K,
+        V=V,
+        BK=BK,
+        BV=BV,
+        stride_init_state_token=stride_init_state_token,
+        stride_final_state_token=stride_final_state_token,
+        stride_indices_seq=stride_indices_seq,
+        stride_indices_tok=stride_indices_tok,
+        IS_BETA_HEADWISE=beta.ndim == v.ndim,
+        USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,
+        INPLACE_FINAL_STATE=inplace_final_state,
+        num_warps=num_warps,
+        num_stages=num_stages,
+    )
+    # drop the leading NK dimension (asserted to be 1 above)
+    o = o.squeeze(0)
+    return o, final_state
+
+
+class FusedRecurrentFunction(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx,
+                q: torch.Tensor,
+                k: torch.Tensor,
+                v: torch.Tensor,
+                g: torch.Tensor,
+                beta: torch.Tensor,
+                scale: float,
+                initial_state: torch.Tensor,
+                inplace_final_state: bool = True,
+                cu_seqlens: Optional[torch.LongTensor] = None,
+                ssm_state_indices: Optional[torch.Tensor] = None,
+                num_accepted_tokens: Optional[torch.Tensor] = None,
+                use_qk_l2norm_in_kernel: bool = False):
+        # the Triton kernel indexes with flat strides, so force contiguity
+        o, final_state = fused_recurrent_gated_delta_rule_fwd(
+            q=q.contiguous(),
+            k=k.contiguous(),
+            v=v.contiguous(),
+            g=g.contiguous(),
+            beta=beta.contiguous(),
+            scale=scale,
+            initial_state=initial_state,
+            inplace_final_state=inplace_final_state,
+            cu_seqlens=cu_seqlens,
+            ssm_state_indices=ssm_state_indices,
+            num_accepted_tokens=num_accepted_tokens,
+            use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
+        )
+
+        return o, final_state
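+# Note: only `forward` is implemented, so FusedRecurrentFunction is
+# inference-only; backpropagating through its outputs raises PyTorch's
+# NotImplementedError for Functions without a `backward`. A sketch of
+# what would fail (assuming q.requires_grad):
+#
+#   o, _ = FusedRecurrentFunction.apply(q, k, v, g, beta, scale, h0)
+#   o.sum().backward()  # NotImplementedError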
+
+
+def fused_recurrent_gated_delta_rule(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    g: torch.Tensor,
+    beta: Optional[torch.Tensor] = None,
+    scale: Optional[float] = None,
+    initial_state: Optional[torch.Tensor] = None,
+    inplace_final_state: bool = True,
+    cu_seqlens: Optional[torch.LongTensor] = None,
+    ssm_state_indices: Optional[torch.Tensor] = None,
+    num_accepted_tokens: Optional[torch.Tensor] = None,
+    use_qk_l2norm_in_kernel: bool = False,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    r"""
+    Args:
+        q (torch.Tensor):
+            queries of shape `[B, T, H, K]`.
+        k (torch.Tensor):
+            keys of shape `[B, T, H, K]`.
+        v (torch.Tensor):
+            values of shape `[B, T, HV, V]`.
+            GVA is applied if `HV > H`.
+        g (torch.Tensor):
+            g (decays) of shape `[B, T, HV]`.
+        beta (torch.Tensor):
+            betas of shape `[B, T, HV]`.
+        scale (Optional[float]):
+            Scale factor for the attention scores.
+            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
+        initial_state (Optional[torch.Tensor]):
+            Initial state of shape `[N, HV, K, V]` for `N` input sequences.
+            For equal-length input sequences, `N` equals the batch size `B`.
+            Default: `None`.
+        inplace_final_state (bool):
+            Whether to store the final state in-place to save memory.
+            Default: `True`.
+        cu_seqlens (torch.LongTensor):
+            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
+            consistent with the FlashAttention API.
+        ssm_state_indices (Optional[torch.Tensor]):
+            Indices mapping the input sequences to their initial/final state slots.
+        num_accepted_tokens (Optional[torch.Tensor]):
+            Number of accepted tokens for each sequence during decoding.
+    Returns:
+        o (torch.Tensor):
+            Outputs of shape `[B, T, HV, V]`.
+        final_state (torch.Tensor):
+            Final state of shape `[N, HV, K, V]`.
+    Examples::
+        >>> import torch
+        >>> import torch.nn.functional as F
+        >>> from einops import rearrange
+        >>> from fla.ops.gated_delta_rule import fused_recurrent_gated_delta_rule
+        # inputs with equal lengths
+        >>> B, T, H, HV, K, V = 4, 2048, 4, 8, 512, 512
+        >>> q = torch.randn(B, T, H, K, device='cuda')
+        >>> k = F.normalize(torch.randn(B, T, H, K, device='cuda'), p=2, dim=-1)
+        >>> v = torch.randn(B, T, HV, V, device='cuda')
+        >>> g = F.logsigmoid(torch.rand(B, T, HV, device='cuda'))
+        >>> beta = torch.rand(B, T, HV, device='cuda').sigmoid()
+        >>> h0 = torch.randn(B, HV, K, V, device='cuda')
+        >>> o, ht = fused_recurrent_gated_delta_rule(
+            q, k, v, g, beta,
+            initial_state=h0,
+        )
+        # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
+        >>> q, k, v, g, beta = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, g, beta))
+        # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions is expected
+        >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
+        >>> o_var, ht_var = fused_recurrent_gated_delta_rule(
+            q, k, v, g, beta,
+            initial_state=h0,
+            cu_seqlens=cu_seqlens
+        )
+    """
+    if cu_seqlens is not None and q.shape[0] != 1:
+        raise ValueError(
+            f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`. "
+            f"Please flatten variable-length inputs before processing.")
+    if scale is None:
+        scale = k.shape[-1]**-0.5
+    else:
+        assert scale > 0, "scale must be positive"
+    if beta is None:
+        beta = torch.ones_like(q[..., 0])
+    o, final_state = FusedRecurrentFunction.apply(
+        q,
+        k,
+        v,
+        g,
+        beta,
+        scale,
+        initial_state,
+        inplace_final_state,
+        cu_seqlens,
+        ssm_state_indices,
+        num_accepted_tokens,
+        use_qk_l2norm_in_kernel,
+    )
+    return o, final_state
+
+
+# install the wrapper above as vLLM's fused recurrent gated delta rule
+vllm.model_executor.layers.fla.ops.fused_recurrent.fused_recurrent_gated_delta_rule = fused_recurrent_gated_delta_rule
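+
+# Minimal usage sketch (an assumption, not part of this commit): this module
+# must be imported before any code that binds its own reference, e.g. via
+# `from vllm.model_executor.layers.fla.ops import fused_recurrent_gated_delta_rule`,
+# since the assignment above only rebinds the module attribute:
+#
+#   import fla_fused_recurrent_patch  # hypothetical name for this file
+#   import vllm.model_executor.layers.fla.ops.fused_recurrent as fr
+#   assert fr.fused_recurrent_gated_delta_rule is fla_fused_recurrent_patch.fused_recurrent_gated_delta_rule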