
Commit 8c161d3

sryap authored and facebook-github-bot committed
Split the Cutlass src file into forward and backward files (#4786)
Summary:
X-link: facebookresearch/FBGEMM#1809

Pull Request resolved: #4786

Split the source file to improve the build speed. Roughly improved the build time from ~176s -> ~135s (~30% faster builds).

Reviewed By: y-sq, jianyuh, q10

Differential Revision: D81201719
1 parent 124975c commit 8c161d3

File tree

4 files changed: +853, -830 lines

Lines changed: 348 additions & 0 deletions
@@ -0,0 +1,348 @@
// @nolint
#include "blackwell_fmha_utils.hpp"
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)

template <
    typename Element,
    typename ActiveMask,
    bool kIsVarlen,
    class... KernelOptions>
std::tuple<at::Tensor, at::Tensor, at::Tensor> fmha_bwd(
    const at::Tensor& dO,
    const at::Tensor& q,
    const at::Tensor& k,
    const at::Tensor& v,
    const at::Tensor& o,
    const at::Tensor& softmax_lse,
    const std::optional<const at::Tensor>& cu_seqlens_q,
    const std::optional<const at::Tensor>& cu_seqlens_k,
    std::optional<int> max_seq_len_q,
    std::optional<int> max_seq_len_k,
    int64_t window_size_left,
    int64_t window_size_right
) {
  const auto device = q.device();
  at::cuda::CUDAGuard device_guard(device);

  using ElementAccumulator = float;

  // Q K D D_VO ((H_R, H_K) B)
  using ProblemShapeType = std::conditional_t<
      kIsVarlen,
      cute::tuple<VariableLength, VariableLength, int, int, cute::tuple<cute::tuple<int, int>, int>>,
      cute::tuple<int, int, int, int, cute::tuple<cute::tuple<int, int>, int>>
  >;

  using TileShape = Shape<_128, _128, _128>;

  using Operation = cutlass::fmha::device::
      Sm100FmhaBwd<ProblemShapeType, Element, ElementAccumulator, TileShape, /*kIsMla=*/false, ActiveMask>;

  using StrideQ = Stride<int, _1, Stride<Stride<int, int>, int>>; // Q D ((H_R, H_K), B)
  using StrideK = Stride<int, _1, Stride<Stride<_0, int>, int>>; // K D ((H_R, H_K), B)
  using StrideV = StrideK; // K D_VO ((H_R, H_K), B)
  using StrideO = StrideQ; // Q D_VO ((H_R, H_K), B)
  using StrideLSE = Stride<_1, Stride<Stride<int, int>, int>>; // Q ((H_R, H_K), B)

  // Backwards specific
  using StrideDQ = StrideQ;
  using StrideDK = StrideK;
  using StrideDV = StrideV;
  using StrideDO = StrideO;

  if (kIsVarlen) {
    TORCH_CHECK(
        q.dim() == 3,
        "Expect Q shape to be (total_Q_seqlen, num_Q_heads, head_dim) ",
        "Found shape ", q.sizes());
    TORCH_CHECK(
        k.dim() == 3,
        "Expect K shape to be (total_KV_seqlen, num_KV_heads, head_dim) ",
        "Found shape ", k.sizes());
    TORCH_CHECK(
        v.dim() == 3,
        "Expect V shape to be (total_KV_seqlen, num_KV_heads, head_dim) ",
        "Found shape ", v.sizes());
  }
  else {
    TORCH_CHECK(
        q.dim() == 4,
        "Expect Q shape to be (batch_size, Q_seqlen, num_Q_heads, head_dim). ",
        "Found shape ", q.sizes());
    TORCH_CHECK(
        k.dim() == 4,
        "Expect K shape to be (batch_size, KV_seqlen, num_KV_heads, head_dim) ",
        "Found shape ", k.sizes());
    TORCH_CHECK(
        v.dim() == 4,
        "Expect V shape to be (batch_size, KV_seqlen, num_KV_heads, head_dim) ",
        "Found shape ", v.sizes());
  }

  if constexpr (kIsVarlen) {
    TORCH_CHECK(cu_seqlens_q.has_value(), "cu_seqlens_q should be set");
    TORCH_CHECK(cu_seqlens_k.has_value(), "cu_seqlens_k should be set");
    TORCH_CHECK(max_seq_len_q.has_value(), "max_seq_len_q should be set");
    TORCH_CHECK(max_seq_len_k.has_value(), "max_seq_len_k should be set");
  }

  int B = kIsVarlen ? cu_seqlens_q->size(0) - 1 : q.size(0);
  // Q represents SumB(Q) for varlen (jagged len)
  int Q = kIsVarlen ? q.size(0) : q.size(1);
  int K = kIsVarlen ? k.size(0) : k.size(1);
  int H_Q = kIsVarlen ? q.size(1) : q.size(2);
  int H_K = kIsVarlen ? k.size(1) : k.size(2);
  int D = q.size(q.dim() - 1); // Head dimension (D)

  TORCH_CHECK(H_Q % H_K == 0, "Q heads must be a multiple of KV heads");
  int H_R = H_Q / H_K;

  ProblemShapeType problem_shape;
  if constexpr (kIsVarlen) {
    problem_shape = cute::make_tuple(
        VariableLength{
            *max_seq_len_q, static_cast<int*>(cu_seqlens_q->data_ptr()), int(q.size(0))},
        VariableLength{
            *max_seq_len_k, static_cast<int*>(cu_seqlens_k->data_ptr()), int(k.size(0))},
        D,
        D,
        make_shape(make_shape(H_R, H_K), B));
  }
  else {
    problem_shape = cute::make_tuple(
        Q, K, D, D, make_shape(make_shape(H_R, H_K), B));
  }
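
  // For example (illustrative fixed-length sizes, not taken from this diff):
  // B=2, Q=K=512, H_Q=8, H_K=2, D=128 gives H_R=4 and
  // problem_shape == (512, 512, 128, 128, ((4, 2), 2)).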

  TORCH_CHECK(D % 8 == 0); // Alignment
  if constexpr (!kIsVarlen) {
    TORCH_CHECK(Q % 8 == 0); // Alignment
  }

  // Reshape to get strides
  auto B_ = kIsVarlen ? 1 : B;
  auto q_ = q.reshape({B_, Q, H_K, H_R, D});
  auto k_ = k.reshape({B_, K, H_K, 1, D}).expand({B_, K, H_K, H_R, D});
  auto lse_ = softmax_lse.reshape({B_, H_K, H_R, Q});
  auto ndim = q_.dim();

  TORCH_CHECK(q_.stride(ndim - 1) == 1, "The head dim in Q must be contiguous");
  TORCH_CHECK(k_.stride(ndim - 1) == 1, "The head dim in KV must be contiguous");
  if (H_R != 1) {
    TORCH_CHECK(k_.stride(3) == 0, "The shared KV head stride must be zero");
  }

  // Note: We use a different layout from 77_blackwell_fmha_bwd.cu.
  // Q shape = (B, Q, H_K, H_R, D)
  StrideQ stride_Q = make_stride(
      static_cast<int>(q_.stride(1)), _1{},
      make_stride(
          make_stride(static_cast<int>(q_.stride(3)), static_cast<int>(q_.stride(2))),
          static_cast<int>(q_.stride(0))));
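
  // Illustration (assuming a contiguous fixed-length q): the reshaped q_ has
  // strides (Q*H_Q*D, H_Q*D, H_R*D, D, 1) with H_Q = H_K*H_R, so stride_Q
  // becomes (H_Q*D, _1, ((D, H_R*D), Q*H_Q*D)). The strides are always taken
  // from the reshaped tensor, so non-default layouts are handled the same way.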

  // K shape = (B, K, H_K, 1, D)
  StrideK stride_K = make_stride(
      static_cast<int>(k_.stride(1)), _1{},
      make_stride(
          make_stride(_0{}, static_cast<int>(k_.stride(2))),
          static_cast<int>(k_.stride(0))));

  // LSE shape = (B, H_K, H_R, Q)
  StrideLSE stride_LSE = make_stride(
      _1{},
      make_stride(
          make_stride(static_cast<int>(lse_.stride(2)), static_cast<int>(lse_.stride(1))),
          static_cast<int>(lse_.stride(0))));
  StrideV stride_V = stride_K;
  StrideO stride_O = stride_Q;

  if constexpr (kIsVarlen) {
    get<2, 1>(stride_Q) = 0;
    get<2, 1>(stride_K) = 0;
    get<2, 1>(stride_V) = 0;
    get<2, 1>(stride_O) = 0;
    get<1, 1>(stride_LSE) = 0;
  }

  StrideDQ stride_dQ = stride_Q;
  StrideDK stride_dK = stride_K;
  StrideDV stride_dV = stride_V;
  StrideDO stride_dO = stride_O;

  // TODO: pass in softmax_scale?
  ElementAccumulator softmax_scale = 1.0f / sqrtf(D);
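  // This default is the standard 1/sqrt(head_dim) scaling of scaled
  // dot-product attention; for correct gradients it should match the scale
  // used in the forward pass.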

  at::Tensor dQ = torch::empty_like(q);
  at::Tensor dK = torch::empty_like(k);
  at::Tensor dV = torch::empty_like(v);

  cutlass::KernelHardwareInfo hw_info;
  hw_info.device_id = device.index();
  hw_info.sm_count =
      cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
          hw_info.device_id);

  typename Operation::Arguments arguments{
      problem_shape,
      static_cast<Element*>(q.data_ptr()),
      stride_Q,
      static_cast<Element*>(k.data_ptr()),
      stride_K,
      static_cast<Element*>(v.data_ptr()),
      stride_V,
      static_cast<Element*>(o.data_ptr()),
      stride_O,
      static_cast<ElementAccumulator*>(softmax_lse.data_ptr()),
      stride_LSE,
      static_cast<Element*>(dO.data_ptr()),
      stride_dO,
      static_cast<Element*>(dQ.data_ptr()),
      stride_dQ,
      static_cast<Element*>(dK.data_ptr()),
      stride_dK,
      static_cast<Element*>(dV.data_ptr()),
      stride_dV,
      softmax_scale,
      hw_info};
  launch_fmha_op<Operation>(arguments);

  return std::make_tuple(dQ, dK, dV);
}

struct KernelCoop {};

std::tuple<at::Tensor, at::Tensor, at::Tensor> dispatch_fmha_bwd(
    const at::Tensor& dOutput,
    const at::Tensor& query,
    const at::Tensor& key,
    const at::Tensor& value,
    const at::Tensor& output,
    const at::Tensor& softmax_lse,
    const std::optional<at::Tensor>& cu_seqlens_q,
    const std::optional<at::Tensor>& cu_seqlens_k,
    std::optional<int64_t> max_seq_len_q,
    std::optional<int64_t> max_seq_len_k,
    bool causal,
    int64_t window_size_left,
    int64_t window_size_right
) {
  TORCH_CHECK(dOutput.is_contiguous());
  TORCH_CHECK(query.is_contiguous());
  TORCH_CHECK(key.is_contiguous());
  TORCH_CHECK(value.is_contiguous());
  TORCH_CHECK(output.is_contiguous());
  TORCH_CHECK(softmax_lse.is_contiguous());

  // This workaround initializes the CUDA context to prevent the 201 error
  // (invalid context). When this function is invoked through PyTorch
  // autograd, it runs on a new thread that hasn't been associated with a CUDA
  // context. To bind this thread to a CUDA context, we call a CUDA runtime API
  // (e.g., cudaFree), which will automatically initialize the context. This
  // ensures that subsequent calls to driver APIs, which assume an initialized
  // CUDA context, do not result in an invalid context error.
  // TODO: initialize context properly
  cudaFree(0);

  // Handle local attention parameters
  bool local = (window_size_left >= 0 || window_size_right >= 0);
  if (local) {
    // If causal is enabled, override window_size_right to 0 for causal+local behavior
    if (causal) {
      window_size_right = 0;
      causal = false; // Use local attention instead of causal
    }
    // Expand -1 window sizes to full sequence length if available
    if (window_size_left < 0 && max_seq_len_k.has_value()) {
      window_size_left = max_seq_len_k.value();
    }
    if (window_size_right < 0 && max_seq_len_k.has_value()) {
      window_size_right = max_seq_len_k.value();
    }
  }
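
  // For example, causal=true combined with window_size_left=128 is rewritten
  // above as causal=false with window (left=128, right=0), i.e. the causal
  // constraint is folded into the local attention window.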

  auto dispatch_fmha =
      [&](auto element, auto element_out, auto varlen, auto mask, auto... kernel_options) {
        return fmha_bwd<
            decltype(element),
            decltype(mask),
            varlen,
            decltype(kernel_options)...>(
            dOutput,
            query,
            key,
            value,
            output,
            softmax_lse,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seq_len_q,
            max_seq_len_k,
            window_size_left,
            window_size_right);
      };

  auto dispatch_type = [&](auto varlen, auto mask) {
    if (query.dtype() == torch::kFloat16) {
      return dispatch_fmha(cutlass::half_t{}, cutlass::half_t{}, varlen, mask);
    }
    else if (query.dtype() == torch::kBFloat16) {
      return dispatch_fmha(
          cutlass::bfloat16_t{}, cutlass::bfloat16_t{}, varlen, mask);
    }
    else if (query.dtype() == torch::kFloat8_e4m3fn) {
      return dispatch_fmha(
          cutlass::float_e4m3_t{}, cutlass::bfloat16_t{}, varlen, mask);
    }
    TORCH_CHECK(false, "Unsupported dtype for q: ", query.dtype());
  };

  auto dispatch_mask = [&](auto varlen) {
    if (causal) {
      return dispatch_type(varlen, CausalForBackwardMask</*kIsQBegin=*/false>{});
    }
    else if (varlen || key.size(1) % 128 != 0) {
      // Use the residual mask for varlen or when K seqlen is not multiple of
      // blockN
      return dispatch_type(varlen, ResidualMaskForBackward{});
    }
    else {
      return dispatch_type(varlen, NoMask{});
    }
  };

  if (max_seq_len_q.has_value()) {
    return dispatch_mask(std::bool_constant<true>{});
  } else {
    TORCH_CHECK(query.dim() == 4, "q must be [B, M, H, D] for fixed length")
    return dispatch_mask(std::bool_constant<false>{});
  }
}

// -------------------------------------------------------------------------------------------------
// Op registration
// -------------------------------------------------------------------------------------------------
TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
  m.def("fmha_bwd("
        " Tensor dOutput, "
        " Tensor query, "
        " Tensor key, "
        " Tensor value, "
        " Tensor output, "
        " Tensor softmax_lse, "
        " Tensor? cu_seqlens_q=None, "
        " Tensor? cu_seqlens_k=None, "
        " int? max_seq_len_q=None, "
        " int? max_seq_len_k=None, "
        " bool causal=False, "
        " int window_size_left=-1, "
        " int window_size_right=-1"
        ") -> (Tensor, Tensor, Tensor)"
  );
}

TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) {
  m.impl("fmha_bwd", dispatch_fmha_bwd);
}
#endif // CUTLASS_ARCH_MMA_SM100_SUPPORTED
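
For reference, here is a minimal caller sketch (not part of this commit) showing how the registered fbgemm::fmha_bwd op could be invoked from C++ through the PyTorch dispatcher. The tensor sizes, the bf16 dtype, and the dense fixed-length layout are illustrative assumptions; the schema defined above is the source of truth for argument order and defaults.

// Hypothetical caller sketch -- illustration only, not part of the diff.
#include <ATen/ATen.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <optional>
#include <tuple>

std::tuple<at::Tensor, at::Tensor, at::Tensor> call_fmha_bwd_example() {
  // Dense fixed-length layout checked above: (B, seqlen, num_heads, head_dim).
  const int64_t B = 2, Q = 128, K = 128, H = 8, D = 128;
  auto opts = at::TensorOptions().dtype(at::kBFloat16).device(at::kCUDA);
  auto q = at::randn({B, Q, H, D}, opts);
  auto k = at::randn({B, K, H, D}, opts);
  auto v = at::randn({B, K, H, D}, opts);
  auto o = at::randn({B, Q, H, D}, opts); // forward output
  auto dO = at::randn_like(o);            // upstream gradient
  // softmax_lse is fp32; fmha_bwd reshapes it to (B, H_K, H_R, Q).
  auto lse = at::randn({B, H, Q}, opts.dtype(at::kFloat));

  // Resolve the op by name; assumes the library registering it is loaded.
  static auto op =
      c10::Dispatcher::singleton()
          .findSchemaOrThrow("fbgemm::fmha_bwd", "")
          .typed<std::tuple<at::Tensor, at::Tensor, at::Tensor>(
              const at::Tensor&, const at::Tensor&, const at::Tensor&,
              const at::Tensor&, const at::Tensor&, const at::Tensor&,
              const std::optional<at::Tensor>&, const std::optional<at::Tensor>&,
              std::optional<int64_t>, std::optional<int64_t>,
              bool, int64_t, int64_t)>();

  // Dense causal call: no cu_seqlens/max_seq_len, default (-1) window sizes.
  return op.call(
      dO, q, k, v, o, lse,
      std::nullopt, std::nullopt, std::nullopt, std::nullopt,
      /*causal=*/true, /*window_size_left=*/-1, /*window_size_right=*/-1);
}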
