|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD-style license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +from typing import Optional, Tuple |
| 8 | + |
| 9 | +import torch |
| 10 | +from torch.library import register_fake |
| 11 | + |
| 12 | + |
| 13 | +torch.library.define( |
| 14 | + "blackwell_fmha::fmha_fwd", |
| 15 | + "(Tensor q, Tensor k, Tensor v, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int? max_seq_len_q, int? max_seq_len_k, float? softmax_scale, bool? causal, Tensor? seqlen_kv) -> (Tensor, Tensor)", |
| 16 | + tags=[torch.Tag.pt2_compliant_tag], |
| 17 | +) |
| 18 | + |
| 19 | +torch.library.define( |
| 20 | + "blackwell_fmha::fmha_bwd", |
| 21 | + "(Tensor dout, Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int? max_seq_len_q, int? max_seq_len_k, bool? causal) -> (Tensor, Tensor, Tensor)", |
| 22 | + tags=[torch.Tag.pt2_compliant_tag], |
| 23 | +) |
| 24 | + |
| 25 | + |
| 26 | +@torch.library.impl("blackwell_fmha::fmha_fwd", "cuda") |
| 27 | +def custom_op_fmha( |
| 28 | + q: torch.Tensor, |
| 29 | + k: torch.Tensor, |
| 30 | + v: torch.Tensor, |
| 31 | + cu_seqlens_q: torch.Tensor = None, |
| 32 | + cu_seqlens_k: torch.Tensor = None, |
| 33 | + max_seq_len_q: int = None, |
| 34 | + max_seq_len_k: int = None, |
| 35 | + softmax_scale: float = None, |
| 36 | + causal: bool = False, |
| 37 | + seqlen_kv: torch.Tensor = None, |
| 38 | +) -> Tuple[torch.Tensor, torch.Tensor]: |
| 39 | + assert q.is_contiguous(), "q is not contiguous" |
| 40 | + assert k.is_contiguous(), "k is not contiguous" |
| 41 | + assert v.is_contiguous(), "v is not contiguous" |
| 42 | + assert q.is_cuda, "q must be on GPU" |
| 43 | + assert k.is_cuda, "k must be on GPU" |
| 44 | + assert v.is_cuda, "v must be on GPU" |
| 45 | + return torch.ops.fbgemm.fmha_fwd( |
| 46 | + q, |
| 47 | + k, |
| 48 | + v, |
| 49 | + cu_seqlens_q=cu_seqlens_q, |
| 50 | + cu_seqlens_k=cu_seqlens_k, |
| 51 | + max_seq_len_q=max_seq_len_q, |
| 52 | + max_seq_len_k=max_seq_len_k, |
| 53 | + softmax_scale=softmax_scale, |
| 54 | + causal=causal, |
| 55 | + seqlen_kv=seqlen_kv, |
| 56 | + ) |
| 57 | + |
| 58 | + |
| 59 | +@register_fake("blackwell_fmha::fmha_fwd") |
| 60 | +def fmha_fwd_meta( |
| 61 | + q: torch.Tensor, |
| 62 | + k: torch.Tensor, |
| 63 | + v: torch.Tensor, |
| 64 | + cu_seqlens_q: torch.Tensor = None, |
| 65 | + cu_seqlens_k: torch.Tensor = None, |
| 66 | + max_seq_len_q: int = None, |
| 67 | + max_seq_len_k: int = None, |
| 68 | + softmax_scale: float = None, |
| 69 | + causal: bool = False, |
| 70 | + seqlen_kv: torch.Tensor = None, |
| 71 | +): |
| 72 | + if q.dtype == torch.float16: |
| 73 | + out_dtype = torch.float16 |
| 74 | + elif q.dtype == torch.bfloat16: |
| 75 | + out_dtype = torch.bfloat16 |
| 76 | + elif q.dtype == torch.float8_e4m3fn: |
| 77 | + # Output is BF16 when input is FP8 |
| 78 | + out_dtype = torch.bfloat16 |
| 79 | + else: |
| 80 | + raise RuntimeError(f"Unsupported dtype for q: {q.dtype}") |
| 81 | + |
| 82 | + kIsVarlen = max_seq_len_q is not None |
| 83 | + if kIsVarlen: |
| 84 | + SQ = q.shape[0] |
| 85 | + H_Q = q.shape[1] |
| 86 | + B = cu_seqlens_q.shape[0] - 1 |
| 87 | + else: |
| 88 | + SQ = q.shape[1] |
| 89 | + H_Q = q.shape[2] |
| 90 | + B = q.shape[0] |
| 91 | + device = q.device |
| 92 | + options2 = {"dtype": torch.float32, "device": device} |
| 93 | + if kIsVarlen: |
| 94 | + out = torch.empty_like(q, dtype=out_dtype) |
| 95 | + size = out.size() |
| 96 | + stride = out.stride() |
| 97 | + storage_offset = q.shape[-1] * max_seq_len_q * H_Q # example scalar offset |
| 98 | + out1 = torch.as_strided( |
| 99 | + out, size=size, stride=stride, storage_offset=storage_offset |
| 100 | + ) |
| 101 | + else: |
| 102 | + out1 = torch.empty_like(q, dtype=out_dtype) |
| 103 | + |
| 104 | + if kIsVarlen: |
| 105 | + out2 = torch.empty((1, H_Q, SQ), **options2) |
| 106 | + else: |
| 107 | + out2 = torch.empty((B, H_Q, SQ), **options2) |
| 108 | + return out1, out2 |
| 109 | + |
| 110 | + |
| 111 | +@torch.library.impl("blackwell_fmha::fmha_bwd", "cuda") |
| 112 | +def custom_op_fmha_bwd( |
| 113 | + dOutput: torch.Tensor, |
| 114 | + query: torch.Tensor, |
| 115 | + key: torch.Tensor, |
| 116 | + value: torch.Tensor, |
| 117 | + output: torch.Tensor, |
| 118 | + softmax_lse: torch.Tensor, |
| 119 | + cu_seqlens_q: torch.Tensor = None, |
| 120 | + cu_seqlens_k: torch.Tensor = None, |
| 121 | + max_seq_len_q: int = None, |
| 122 | + max_seq_len_k: int = None, |
| 123 | + causal: bool = False, |
| 124 | +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: |
| 125 | + return torch.ops.fbgemm.fmha_bwd( |
| 126 | + dOutput, |
| 127 | + query, |
| 128 | + key, |
| 129 | + value, |
| 130 | + output, |
| 131 | + softmax_lse, |
| 132 | + cu_seqlens_q=cu_seqlens_q, |
| 133 | + cu_seqlens_k=cu_seqlens_k, |
| 134 | + max_seq_len_q=max_seq_len_q, |
| 135 | + max_seq_len_k=max_seq_len_k, |
| 136 | + causal=causal, |
| 137 | + ) |
| 138 | + |
| 139 | + |
| 140 | +@register_fake("blackwell_fmha::fmha_bwd") |
| 141 | +def fmha_bwd_meta( |
| 142 | + dOutput: torch.Tensor, |
| 143 | + query: torch.Tensor, |
| 144 | + key: torch.Tensor, |
| 145 | + value: torch.Tensor, |
| 146 | + output: torch.Tensor, |
| 147 | + softmax_lse: torch.Tensor, |
| 148 | + cu_seqlens_q: torch.Tensor = None, |
| 149 | + cu_seqlens_k: torch.Tensor = None, |
| 150 | + max_seq_len_q: int = None, |
| 151 | + max_seq_len_k: int = None, |
| 152 | + causal: bool = False, |
| 153 | +): |
| 154 | + return ( |
| 155 | + torch.empty_like(query), |
| 156 | + torch.empty_like(key), |
| 157 | + torch.empty_like(value), |
| 158 | + ) |
| 159 | + |
| 160 | + |
| 161 | +def _backward(ctx, *grad): |
| 162 | + if ctx.is_gen: |
| 163 | + # For gen case, no backward pass is needed (generation is inference only) |
| 164 | + raise RuntimeError("Backward pass is not supported for generation phase (sq=1)") |
| 165 | + q, k, v, out, softmax_lse = ctx.saved_tensors |
| 166 | + if not grad[0].is_contiguous(): |
| 167 | + grad0 = grad[0].contiguous() |
| 168 | + else: |
| 169 | + grad0 = grad[0] |
| 170 | + if not softmax_lse.is_contiguous: |
| 171 | + softmax_lse = softmax_lse.contiguous() |
| 172 | + if not out.is_contiguous: |
| 173 | + out = out.contiguous() |
| 174 | + if not q.is_contiguous: |
| 175 | + q = q.contiguous() |
| 176 | + if not k.is_contiguous: |
| 177 | + k = k.contiguous() |
| 178 | + |
| 179 | + if not softmax_lse.is_contiguous: |
| 180 | + softmax_lse = softmax_lse.contiguous() |
| 181 | + if not out.is_contiguous: |
| 182 | + out = out.contiguous() |
| 183 | + if not q.is_contiguous: |
| 184 | + q = q.contiguous() |
| 185 | + if not k.is_contiguous: |
| 186 | + k = k.contiguous() |
| 187 | + |
| 188 | + dq, dk, dv = torch.ops.blackwell_fmha.fmha_bwd( |
| 189 | + grad0, |
| 190 | + q, |
| 191 | + k, |
| 192 | + v, |
| 193 | + out, |
| 194 | + softmax_lse, |
| 195 | + ctx.cu_seqlens_q, |
| 196 | + ctx.cu_seqlens_k, |
| 197 | + ctx.max_seq_len_q, |
| 198 | + ctx.max_seq_len_k, |
| 199 | + ctx.causal, |
| 200 | + ) |
| 201 | + return dq, dk, dv, None, None, None, None, None, None, None |
| 202 | + |
| 203 | + |
| 204 | +def _setup_context(ctx, inputs, output): |
| 205 | + ( |
| 206 | + q, |
| 207 | + k, |
| 208 | + v, |
| 209 | + cu_seqlens_q, |
| 210 | + cu_seqlens_k, |
| 211 | + max_seq_len_q, |
| 212 | + max_seq_len_k, |
| 213 | + softmax_scale, |
| 214 | + causal, |
| 215 | + seqlen_kv, |
| 216 | + ) = inputs |
| 217 | + (out, softmax_lse) = output |
| 218 | + ctx.save_for_backward(q, k, v, out, softmax_lse) |
| 219 | + ctx.softmax_scale = softmax_scale |
| 220 | + ctx.causal = causal |
| 221 | + ctx.max_seq_len_q = max_seq_len_q |
| 222 | + ctx.max_seq_len_k = max_seq_len_k |
| 223 | + ctx.cu_seqlens_q = cu_seqlens_q |
| 224 | + ctx.cu_seqlens_k = cu_seqlens_k |
| 225 | + ctx.is_gen = False |
| 226 | + |
| 227 | + |
| 228 | +# This code adds training support for the operator. You must provide us |
| 229 | +# the backward formula for the operator and a `setup_context` function |
| 230 | +# to save values to be used in the backward. |
| 231 | +torch.library.register_autograd( |
| 232 | + "blackwell_fmha::fmha_fwd", _backward, setup_context=_setup_context |
| 233 | +) |
| 234 | + |
| 235 | + |
| 236 | +def cutlass_blackwell_fmha_custom_op( |
| 237 | + q: torch.Tensor, |
| 238 | + k: torch.Tensor, |
| 239 | + v: torch.Tensor, |
| 240 | + softmax_scale: Optional[float] = None, |
| 241 | + causal: bool = False, |
| 242 | + cu_seqlens_q: Optional[torch.Tensor] = None, |
| 243 | + cu_seqlens_k: Optional[torch.Tensor] = None, |
| 244 | + max_seq_len_q: Optional[int] = None, |
| 245 | + max_seq_len_k: Optional[int] = None, |
| 246 | + seqlen_kv: Optional[torch.Tensor] = None, |
| 247 | +): |
| 248 | + return torch.ops.blackwell_fmha.fmha_fwd( |
| 249 | + q=q, |
| 250 | + k=k, |
| 251 | + v=v, |
| 252 | + cu_seqlens_q=cu_seqlens_q, |
| 253 | + cu_seqlens_k=cu_seqlens_k, |
| 254 | + max_seq_len_q=max_seq_len_q, |
| 255 | + max_seq_len_k=max_seq_len_k, |
| 256 | + softmax_scale=softmax_scale, |
| 257 | + causal=causal, |
| 258 | + seqlen_kv=seqlen_kv, |
| 259 | + )[0] |
0 commit comments