
Commit 49d9292

sryap authored and facebook-github-bot committed
Move Cutlass kernels to FBGEMM Gen-AI
Summary: As title
Differential Revision: D81189637
1 parent 472f4c4 commit 49d9292

File tree

49 files changed: +24756 −0 lines changed

Lines changed: 259 additions & 0 deletions
@@ -0,0 +1,259 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional, Tuple

import torch
from torch.library import register_fake


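# Schema definitions for the Blackwell FMHA forward and backward custom ops.
# Both are tagged PT2-compliant so torch.compile can trace through them.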
torch.library.define(
    "blackwell_fmha::fmha_fwd",
    "(Tensor q, Tensor k, Tensor v, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int? max_seq_len_q, int? max_seq_len_k, float? softmax_scale, bool? causal, Tensor? seqlen_kv) -> (Tensor, Tensor)",
    tags=[torch.Tag.pt2_compliant_tag],
)

torch.library.define(
    "blackwell_fmha::fmha_bwd",
    "(Tensor dout, Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int? max_seq_len_q, int? max_seq_len_k, bool? causal) -> (Tensor, Tensor, Tensor)",
    tags=[torch.Tag.pt2_compliant_tag],
)


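# CUDA implementation of the forward op: validates that q/k/v are contiguous
# CUDA tensors, then dispatches to the underlying fbgemm kernel.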
@torch.library.impl("blackwell_fmha::fmha_fwd", "cuda")
def custom_op_fmha(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    max_seq_len_q: Optional[int] = None,
    max_seq_len_k: Optional[int] = None,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    seqlen_kv: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    assert q.is_contiguous(), "q is not contiguous"
    assert k.is_contiguous(), "k is not contiguous"
    assert v.is_contiguous(), "v is not contiguous"
    assert q.is_cuda, "q must be on GPU"
    assert k.is_cuda, "k must be on GPU"
    assert v.is_cuda, "v must be on GPU"
    return torch.ops.fbgemm.fmha_fwd(
        q,
        k,
        v,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seq_len_q=max_seq_len_q,
        max_seq_len_k=max_seq_len_k,
        softmax_scale=softmax_scale,
        causal=causal,
        seqlen_kv=seqlen_kv,
    )


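# Fake (meta) implementation of the forward op: allocates empty outputs with
# the shapes and dtypes the real kernel would produce, so the op can be traced
# without running on a GPU.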
@register_fake("blackwell_fmha::fmha_fwd")
def fmha_fwd_meta(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    max_seq_len_q: Optional[int] = None,
    max_seq_len_k: Optional[int] = None,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    seqlen_kv: Optional[torch.Tensor] = None,
):
    if q.dtype == torch.float16:
        out_dtype = torch.float16
    elif q.dtype == torch.bfloat16:
        out_dtype = torch.bfloat16
    elif q.dtype == torch.float8_e4m3fn:
        # Output is BF16 when input is FP8
        out_dtype = torch.bfloat16
    else:
        raise RuntimeError(f"Unsupported dtype for q: {q.dtype}")

    kIsVarlen = max_seq_len_q is not None
    if kIsVarlen:
        SQ = q.shape[0]
        H_Q = q.shape[1]
        B = cu_seqlens_q.shape[0] - 1
    else:
        SQ = q.shape[1]
        H_Q = q.shape[2]
        B = q.shape[0]
    device = q.device
    options2 = {"dtype": torch.float32, "device": device}
    if kIsVarlen:
        out = torch.empty_like(q, dtype=out_dtype)
        size = out.size()
        stride = out.stride()
        storage_offset = q.shape[-1] * max_seq_len_q * H_Q  # example scalar offset
        out1 = torch.as_strided(
            out, size=size, stride=stride, storage_offset=storage_offset
        )
    else:
        out1 = torch.empty_like(q, dtype=out_dtype)

    if kIsVarlen:
        out2 = torch.empty((1, H_Q, SQ), **options2)
    else:
        out2 = torch.empty((B, H_Q, SQ), **options2)
    return out1, out2


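# CUDA implementation of the backward op: dispatches directly to the fbgemm
# backward kernel.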
@torch.library.impl("blackwell_fmha::fmha_bwd", "cuda")
def custom_op_fmha_bwd(
    dOutput: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    softmax_lse: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    max_seq_len_q: Optional[int] = None,
    max_seq_len_k: Optional[int] = None,
    causal: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    return torch.ops.fbgemm.fmha_bwd(
        dOutput,
        query,
        key,
        value,
        output,
        softmax_lse,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seq_len_q=max_seq_len_q,
        max_seq_len_k=max_seq_len_k,
        causal=causal,
    )


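# Fake (meta) implementation of the backward op: gradients share the shapes
# and dtypes of query, key, and value.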
@register_fake("blackwell_fmha::fmha_bwd")
def fmha_bwd_meta(
    dOutput: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    softmax_lse: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    max_seq_len_q: Optional[int] = None,
    max_seq_len_k: Optional[int] = None,
    causal: bool = False,
):
    return (
        torch.empty_like(query),
        torch.empty_like(key),
        torch.empty_like(value),
    )


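# Autograd backward: restores the tensors saved in _setup_context, makes them
# contiguous if necessary, and calls the registered backward op. The trailing
# None values are gradients for the non-differentiable forward arguments.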
def _backward(ctx, *grad):
    if ctx.is_gen:
        # For the gen case, no backward pass is needed (generation is inference only)
        raise RuntimeError("Backward pass is not supported for generation phase (sq=1)")
    q, k, v, out, softmax_lse = ctx.saved_tensors
    if not grad[0].is_contiguous():
        grad0 = grad[0].contiguous()
    else:
        grad0 = grad[0]
    if not softmax_lse.is_contiguous():
        softmax_lse = softmax_lse.contiguous()
    if not out.is_contiguous():
        out = out.contiguous()
    if not q.is_contiguous():
        q = q.contiguous()
    if not k.is_contiguous():
        k = k.contiguous()

    dq, dk, dv = torch.ops.blackwell_fmha.fmha_bwd(
        grad0,
        q,
        k,
        v,
        out,
        softmax_lse,
        ctx.cu_seqlens_q,
        ctx.cu_seqlens_k,
        ctx.max_seq_len_q,
        ctx.max_seq_len_k,
        ctx.causal,
    )
    return dq, dk, dv, None, None, None, None, None, None, None


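# Stashes the tensors and metadata that _backward needs on the autograd context.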
def _setup_context(ctx, inputs, output):
    (
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seq_len_q,
        max_seq_len_k,
        softmax_scale,
        causal,
        seqlen_kv,
    ) = inputs
    (out, softmax_lse) = output
    ctx.save_for_backward(q, k, v, out, softmax_lse)
    ctx.softmax_scale = softmax_scale
    ctx.causal = causal
    ctx.max_seq_len_q = max_seq_len_q
    ctx.max_seq_len_k = max_seq_len_k
    ctx.cu_seqlens_q = cu_seqlens_q
    ctx.cu_seqlens_k = cu_seqlens_k
    ctx.is_gen = False


# This adds training support for the operator: we provide the backward formula
# and a `setup_context` function that saves the values needed by the backward pass.
torch.library.register_autograd(
    "blackwell_fmha::fmha_fwd", _backward, setup_context=_setup_context
)


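# Convenience wrapper: runs the forward op and returns only the attention
# output, discarding the softmax log-sum-exp.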
def cutlass_blackwell_fmha_custom_op(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    max_seq_len_q: Optional[int] = None,
    max_seq_len_k: Optional[int] = None,
    seqlen_kv: Optional[torch.Tensor] = None,
):
    return torch.ops.blackwell_fmha.fmha_fwd(
        q=q,
        k=k,
        v=v,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seq_len_q=max_seq_len_q,
        max_seq_len_k=max_seq_len_k,
        softmax_scale=softmax_scale,
        causal=causal,
        seqlen_kv=seqlen_kv,
    )[0]
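
For reference, a minimal usage sketch of the wrapper above. The sizes, BF16 dtype, and (batch, seq_len, heads, head_dim) layout follow the non-varlen branch of fmha_fwd_meta; they are illustrative assumptions, not values taken from this commit.

# Hypothetical example: dense (non-varlen) causal attention in BF16.
B, S, H, D = 2, 128, 8, 128
q = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16, requires_grad=True)
k = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16, requires_grad=True)
v = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16, requires_grad=True)

out = cutlass_blackwell_fmha_custom_op(q, k, v, causal=True)
out.float().sum().backward()  # gradients flow through blackwell_fmha::fmha_bwd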
