// @nolint
#include "blackwell_fmha_utils.hpp"
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)

template <
    typename Element,
    typename ActiveMask,
    bool kIsVarlen,
    class... KernelOptions>
std::tuple<at::Tensor, at::Tensor, at::Tensor> fmha_bwd(
    const at::Tensor& dO,
    const at::Tensor& q,
    const at::Tensor& k,
    const at::Tensor& v,
    const at::Tensor& o,
    const at::Tensor& softmax_lse,
    const std::optional<const at::Tensor>& cu_seqlens_q,
    const std::optional<const at::Tensor>& cu_seqlens_k,
    std::optional<int> max_seq_len_q,
    std::optional<int> max_seq_len_k,
    int64_t window_size_left,
    int64_t window_size_right
) {
  const auto device = q.device();
  at::cuda::CUDAGuard device_guard(device);

  using ElementAccumulator = float;

  // Problem shape: (Q, K, D, D_VO, ((H_R, H_K), B))
  using ProblemShapeType = std::conditional_t<
      kIsVarlen,
      cute::tuple<VariableLength, VariableLength, int, int, cute::tuple<cute::tuple<int, int>, int>>,
      cute::tuple<int, int, int, int, cute::tuple<cute::tuple<int, int>, int>>
  >;

  using TileShape = Shape<_128, _128, _128>;

  using Operation = cutlass::fmha::device::
      Sm100FmhaBwd<ProblemShapeType, Element, ElementAccumulator, TileShape, /*kIsMla=*/false, ActiveMask>;

  using StrideQ = Stride<int, _1, Stride<Stride<int, int>, int>>; // Q D ((H_R, H_K), B)
  using StrideK = Stride<int, _1, Stride<Stride<_0, int>, int>>; // K D ((H_R, H_K), B)
  using StrideV = StrideK; // K D_VO ((H_R, H_K), B)
  using StrideO = StrideQ; // Q D_VO ((H_R, H_K), B)
  using StrideLSE = Stride<_1, Stride<Stride<int, int>, int>>; // Q ((H_R, H_K), B)

  // Backward-specific strides; the gradients share the forward layouts.
  using StrideDQ = StrideQ;
  using StrideDK = StrideK;
  using StrideDV = StrideV;
  using StrideDO = StrideO;

  if constexpr (kIsVarlen) {
    TORCH_CHECK(
        q.dim() == 3,
        "Expect Q shape to be (total_Q_seqlen, num_Q_heads, head_dim). ",
        "Found shape ", q.sizes());
    TORCH_CHECK(
        k.dim() == 3,
        "Expect K shape to be (total_KV_seqlen, num_KV_heads, head_dim). ",
        "Found shape ", k.sizes());
    TORCH_CHECK(
        v.dim() == 3,
        "Expect V shape to be (total_KV_seqlen, num_KV_heads, head_dim). ",
        "Found shape ", v.sizes());
  }
  else {
    TORCH_CHECK(
        q.dim() == 4,
        "Expect Q shape to be (batch_size, Q_seqlen, num_Q_heads, head_dim). ",
        "Found shape ", q.sizes());
    TORCH_CHECK(
        k.dim() == 4,
        "Expect K shape to be (batch_size, KV_seqlen, num_KV_heads, head_dim). ",
        "Found shape ", k.sizes());
    TORCH_CHECK(
        v.dim() == 4,
        "Expect V shape to be (batch_size, KV_seqlen, num_KV_heads, head_dim). ",
        "Found shape ", v.sizes());
  }

  if constexpr (kIsVarlen) {
    TORCH_CHECK(cu_seqlens_q.has_value(), "cu_seqlens_q should be set");
    TORCH_CHECK(cu_seqlens_k.has_value(), "cu_seqlens_k should be set");
    TORCH_CHECK(max_seq_len_q.has_value(), "max_seq_len_q should be set");
    TORCH_CHECK(max_seq_len_k.has_value(), "max_seq_len_k should be set");
  }

  int B = kIsVarlen ? cu_seqlens_q->size(0) - 1 : q.size(0);
  // For varlen (jagged) inputs, Q is the total number of query tokens summed
  // over the batch; likewise K is the total number of KV tokens.
  int Q = kIsVarlen ? q.size(0) : q.size(1);
  int K = kIsVarlen ? k.size(0) : k.size(1);
  int H_Q = kIsVarlen ? q.size(1) : q.size(2);
  int H_K = kIsVarlen ? k.size(1) : k.size(2);
  int D = q.size(q.dim() - 1); // Head dimension (D)

  TORCH_CHECK(H_Q % H_K == 0, "Q heads must be a multiple of KV heads");
  int H_R = H_Q / H_K;
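  // Example (grouped-query attention): H_Q = 32 query heads with H_K = 8 KV
  // heads gives H_R = 4, i.e. each KV head is shared by 4 query heads.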

  ProblemShapeType problem_shape;
  if constexpr (kIsVarlen) {
    problem_shape = cute::make_tuple(
        VariableLength{
            *max_seq_len_q, static_cast<int*>(cu_seqlens_q->data_ptr()), int(q.size(0))},
        VariableLength{
            *max_seq_len_k, static_cast<int*>(cu_seqlens_k->data_ptr()), int(k.size(0))},
        D,
        D,
        make_shape(make_shape(H_R, H_K), B));
  }
  else {
    problem_shape = cute::make_tuple(
        Q, K, D, D, make_shape(make_shape(H_R, H_K), B));
  }

  TORCH_CHECK(D % 8 == 0, "head_dim must be a multiple of 8"); // Alignment
  if constexpr (!kIsVarlen) {
    TORCH_CHECK(Q % 8 == 0, "Q_seqlen must be a multiple of 8"); // Alignment
  }
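  // The multiple-of-8 requirements presumably reflect the kernel's vectorized
  // access granularity (8 fp16/bf16 elements = 16 bytes); treat that rationale
  // as an assumption rather than documented behavior.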

  // Reshape to get strides
  auto B_ = kIsVarlen ? 1 : B;
  auto q_ = q.reshape({B_, Q, H_K, H_R, D});
  auto k_ = k.reshape({B_, K, H_K, 1, D}).expand({B_, K, H_K, H_R, D});
  auto lse_ = softmax_lse.reshape({B_, H_K, H_R, Q});
  auto ndim = q_.dim();

  TORCH_CHECK(q_.stride(ndim - 1) == 1, "The head dim in Q must be contiguous");
  TORCH_CHECK(k_.stride(ndim - 1) == 1, "The head dim in KV must be contiguous");
  if (H_R != 1) {
    TORCH_CHECK(k_.stride(3) == 0, "The shared KV head stride must be zero");
  }
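  // Note: expand() materializes no data; it yields a view whose broadcast H_R
  // dimension has stride 0, so every query head within a group reads the same
  // KV head. The check above relies on exactly that stride-0 property.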

  // Note: We use a different layout from 77_blackwell_fmha_bwd.cu.
  // Q shape = (B, Q, H_K, H_R, D)
  StrideQ stride_Q = make_stride(
      static_cast<int>(q_.stride(1)), _1{},
      make_stride(
          make_stride(static_cast<int>(q_.stride(3)), static_cast<int>(q_.stride(2))),
          static_cast<int>(q_.stride(0))));
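  // Illustrative example (not part of the kernel contract): for a contiguous,
  // fixed-length q of shape (B, Q, H_Q, D), the reshape above gives
  //   stride_Q = (H_Q*D, 1, ((D, H_R*D), Q*H_Q*D)),
  // i.e. (per-token, per-element, ((per-H_R-head, per-H_K-head), per-batch)).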

  // K shape = (B, K, H_K, 1, D)
  StrideK stride_K = make_stride(
      static_cast<int>(k_.stride(1)), _1{},
      make_stride(
          make_stride(_0{}, static_cast<int>(k_.stride(2))),
          static_cast<int>(k_.stride(0))));

  // LSE shape = (B, H_K, H_R, Q)
  StrideLSE stride_LSE = make_stride(
      _1{},
      make_stride(
          make_stride(static_cast<int>(lse_.stride(2)), static_cast<int>(lse_.stride(1))),
          static_cast<int>(lse_.stride(0))));
  StrideV stride_V = stride_K;
  StrideO stride_O = stride_Q;

  if constexpr (kIsVarlen) {
    get<2, 1>(stride_Q) = 0;
    get<2, 1>(stride_K) = 0;
    get<2, 1>(stride_V) = 0;
    get<2, 1>(stride_O) = 0;
    get<1, 1>(stride_LSE) = 0;
  }
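  // For varlen, sequences are packed along the token dimension and per-sequence
  // offsets come from cu_seqlens, so the batch component of each stride is
  // zeroed rather than advancing the base pointer.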

  StrideDQ stride_dQ = stride_Q;
  StrideDK stride_dK = stride_K;
  StrideDV stride_dV = stride_V;
  StrideDO stride_dO = stride_O;

  // TODO: pass in softmax_scale?
  ElementAccumulator softmax_scale = 1.0f / sqrtf(D);
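  // e.g. D = 128 gives softmax_scale = 1/sqrt(128) ~= 0.0884.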

  at::Tensor dQ = torch::empty_like(q);
  at::Tensor dK = torch::empty_like(k);
  at::Tensor dV = torch::empty_like(v);

  cutlass::KernelHardwareInfo hw_info;
  hw_info.device_id = device.index();
  hw_info.sm_count =
      cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
          hw_info.device_id);

  typename Operation::Arguments arguments{
      problem_shape,
      static_cast<Element*>(q.data_ptr()),
      stride_Q,
      static_cast<Element*>(k.data_ptr()),
      stride_K,
      static_cast<Element*>(v.data_ptr()),
      stride_V,
      static_cast<Element*>(o.data_ptr()),
      stride_O,
      static_cast<ElementAccumulator*>(softmax_lse.data_ptr()),
      stride_LSE,
      static_cast<Element*>(dO.data_ptr()),
      stride_dO,
      static_cast<Element*>(dQ.data_ptr()),
      stride_dQ,
      static_cast<Element*>(dK.data_ptr()),
      stride_dK,
      static_cast<Element*>(dV.data_ptr()),
      stride_dV,
      softmax_scale,
      hw_info};
  launch_fmha_op<Operation>(arguments);

  return std::make_tuple(dQ, dK, dV);
}

struct KernelCoop {};

std::tuple<at::Tensor, at::Tensor, at::Tensor> dispatch_fmha_bwd(
    const at::Tensor& dOutput,
    const at::Tensor& query,
    const at::Tensor& key,
    const at::Tensor& value,
    const at::Tensor& output,
    const at::Tensor& softmax_lse,
    const std::optional<at::Tensor>& cu_seqlens_q,
    const std::optional<at::Tensor>& cu_seqlens_k,
    std::optional<int64_t> max_seq_len_q,
    std::optional<int64_t> max_seq_len_k,
    bool causal,
    int64_t window_size_left,
    int64_t window_size_right
) {

  TORCH_CHECK(dOutput.is_contiguous());
  TORCH_CHECK(query.is_contiguous());
  TORCH_CHECK(key.is_contiguous());
  TORCH_CHECK(value.is_contiguous());
  TORCH_CHECK(output.is_contiguous());
  TORCH_CHECK(softmax_lse.is_contiguous());

  // This workaround initializes the CUDA context to prevent CUDA error 201
  // (invalid device context). When this function is invoked through PyTorch
  // autograd, it runs on a new thread that has not yet been associated with a
  // CUDA context. Calling a CUDA runtime API (e.g., cudaFree) binds the thread
  // to a context by initializing it, so that subsequent driver API calls,
  // which assume an initialized context, do not fail with an invalid-context
  // error.
  // TODO: initialize context properly
  cudaFree(0);

  // Handle local attention parameters
  bool local = (window_size_left >= 0 || window_size_right >= 0);
  if (local) {
    // If causal is enabled, override window_size_right to 0 for causal+local behavior
    if (causal) {
      window_size_right = 0;
      causal = false; // Use local attention instead of causal
    }
    // Expand -1 window sizes to the full sequence length if available
    if (window_size_left < 0 && max_seq_len_k.has_value()) {
      window_size_left = max_seq_len_k.value();
    }
    if (window_size_right < 0 && max_seq_len_k.has_value()) {
      window_size_right = max_seq_len_k.value();
    }
  }
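  // Note: the normalized window sizes are forwarded to fmha_bwd, but
  // dispatch_mask below only selects causal, residual, or no-mask kernels;
  // no local/sliding-window mask variant is dispatched yet.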

  auto dispatch_fmha =
      [&](auto element, auto element_out, auto varlen, auto mask, auto... kernel_options) {
        return fmha_bwd<
            decltype(element),
            decltype(mask),
            varlen,
            decltype(kernel_options)...>
            (
                dOutput,
                query,
                key,
                value,
                output,
                softmax_lse,
                cu_seqlens_q,
                cu_seqlens_k,
                max_seq_len_q,
                max_seq_len_k,
                window_size_left,
                window_size_right);
      };
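  // Note: element_out is accepted here (e.g. BF16 outputs for FP8 inputs) but
  // is not forwarded to fmha_bwd, which allocates gradients with empty_like on
  // the inputs; it appears to be reserved for future use.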

  auto dispatch_type = [&](auto varlen, auto mask) {
    if (query.dtype() == torch::kFloat16) {
      return dispatch_fmha(cutlass::half_t{}, cutlass::half_t{}, varlen, mask);
    }
    else if (query.dtype() == torch::kBFloat16) {
      return dispatch_fmha(
          cutlass::bfloat16_t{}, cutlass::bfloat16_t{}, varlen, mask);
    }
    else if (query.dtype() == torch::kFloat8_e4m3fn) {
      return dispatch_fmha(
          cutlass::float_e4m3_t{}, cutlass::bfloat16_t{}, varlen, mask);
    }
    TORCH_CHECK(false, "Unsupported dtype for q: ", query.dtype());
  };

  auto dispatch_mask = [&](auto varlen) {
    if (causal) {
      return dispatch_type(varlen, CausalForBackwardMask</*kIsQBegin=*/false>{});
    }
    else if (varlen || key.size(1) % 128 != 0) {
      // Use the residual mask for varlen, or when the K seqlen is not a
      // multiple of blockN (128, the K extent of the tile shape).
      return dispatch_type(varlen, ResidualMaskForBackward{});
    }
    else {
      return dispatch_type(varlen, NoMask{});
    }
  };

  if (max_seq_len_q.has_value()) {
    return dispatch_mask(std::bool_constant<true>{});
  } else {
    TORCH_CHECK(query.dim() == 4, "q must be [B, M, H, D] for fixed length");
    return dispatch_mask(std::bool_constant<false>{});
  }
}

// -------------------------------------------------------------------------------------------------
// Op registration
// -------------------------------------------------------------------------------------------------
TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
  m.def("fmha_bwd("
      " Tensor dOutput, "
      " Tensor query, "
      " Tensor key, "
      " Tensor value, "
      " Tensor output, "
      " Tensor softmax_lse, "
      " Tensor? cu_seqlens_q=None, "
      " Tensor? cu_seqlens_k=None, "
      " int? max_seq_len_q=None, "
      " int? max_seq_len_k=None, "
      " bool causal=False, "
      " int window_size_left=-1, "
      " int window_size_right=-1"
      ") -> (Tensor, Tensor, Tensor)"
  );
}
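
// Illustrative (hypothetical) Python-side call of the registered op, assuming
// fixed-length BF16 tensors of shape (B, S, H, D) and a float32 softmax_lse of
// shape (B, H, S) saved from the forward pass:
//
//   dq, dk, dv = torch.ops.fbgemm.fmha_bwd(
//       d_out, query, key, value, out, softmax_lse, causal=True)
//
// Variable names and shapes here are examples only; see the checks in
// fmha_bwd/dispatch_fmha_bwd above for the exact requirements.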

TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) {
  m.impl("fmha_bwd", dispatch_fmha_bwd);
}
#endif // CUTLASS_ARCH_MMA_SM100_SUPPORTED