
Commit 705bff0

q10 authored and facebook-github-bot committed
Enable HSTU builds in fbcode (#4290)
Summary:
Pull Request resolved: #4290
X-link: facebookresearch/FBGEMM#1366

- Enable HSTU builds in fbcode

Reviewed By: ionuthristodorescu

Differential Revision: D76093631

fbshipit-source-id: e60e23c7ed7df1916661a17b519ca85e12cfeeaa
1 parent ba16adc commit 705bff0

8 files changed: +162 -112 lines
fbgemm_gpu/experimental/hstu/hstu/__init__.py

Lines changed: 30 additions & 13 deletions
@@ -1,13 +1,15 @@
 #!/usr/bin/env python3
-# Copyright (c) 2024, NVIDIA Corporation & AFFILIATES.
-# Copyright (c) Meta Platforms, Inc. and affiliates.
+# Portions Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# Copyright (c) 2024, NVIDIA Corporation & AFFILIATES.
+
 # pyre-strict
 
+import logging
 import os
 
 import torch
@@ -23,15 +25,30 @@
 except Exception:
     open_source: bool = False
 
-# pyre-ignore[16]
-if open_source:
-    torch.ops.load_library(
-        os.path.join(os.path.dirname(__file__), "fbgemm_gpu_experimental_hstu.so")
-    )
-    torch.classes.load_library(
-        os.path.join(os.path.dirname(__file__), "fbgemm_gpu_experimental_hstu.so")
-    )
+if (
+    torch.cuda.is_available()
+    and torch.version.cuda is not None
+    and torch.version.cuda > "12.4"
+):
+    if open_source:
+        torch.ops.load_library(
+            os.path.join(os.path.dirname(__file__), "fbgemm_gpu_experimental_hstu.so")
+        )
+        torch.classes.load_library(
+            os.path.join(os.path.dirname(__file__), "fbgemm_gpu_experimental_hstu.so")
+        )
+    else:
+        torch.ops.load_library(
+            "//deeplearning/fbgemm/fbgemm_gpu/experimental/hstu/src:hstu_ops_gpu_sm80"
+        )
+
+        if torch.cuda.get_device_capability() >= (9, 0):
+            torch.ops.load_library(
+                "//deeplearning/fbgemm/fbgemm_gpu/experimental/hstu/src:hstu_ops_gpu_sm90"
+            )
+
 else:
-    torch.ops.load_library(
-        "//deeplearning/fbgemm/fbgemm_gpu/experimental/hstu:hstu_ops"
-    )
+    logging.warning("CUDA is not available for FBGEMM HSTU")
+
+
+from .cuda_hstu_attention import hstu_attn_varlen_func, HstuAttnVarlenFunc  # noqa: F401
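
The module now registers the HSTU ops only when a CUDA toolkit newer than 12.4 is present, and pulls in the SM90 build only on devices reporting compute capability 9.0 or higher. Below is a minimal, self-contained sketch of the same gating pattern; the library file names are illustrative, not the actual build outputs. Note that torch.version.cuda is a string, so the diff's > "12.4" check is lexicographic; the sketch parses it numerically instead:

import logging
import os

import torch


def load_hstu_ops(lib_dir: str) -> None:
    # Sketch: register HSTU ops only where the kernels can actually run.
    if not torch.cuda.is_available() or torch.version.cuda is None:
        logging.warning("CUDA is not available for FBGEMM HSTU")
        return

    major, minor = (int(v) for v in torch.version.cuda.split(".")[:2])
    if (major, minor) <= (12, 4):  # the kernels require CUDA > 12.4
        logging.warning("FBGEMM HSTU requires a CUDA toolkit newer than 12.4")
        return

    # SM80 (Ampere) kernels always load; SM90 (Hopper) kernels load only
    # when the current device is compute capability 9.0 or newer.
    torch.ops.load_library(os.path.join(lib_dir, "fbgemm_gpu_experimental_hstu.so"))
    if torch.cuda.get_device_capability() >= (9, 0):
        # Hypothetical path for a split SM90 library; in the OSS build a
        # single .so carries both architectures.
        torch.ops.load_library(os.path.join(lib_dir, "fbgemm_gpu_experimental_hstu_sm90.so"))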

fbgemm_gpu/experimental/hstu/hstu/cuda_hstu_attention.py

Lines changed: 18 additions & 16 deletions
@@ -1,22 +1,23 @@
 #!/usr/bin/env python3
-# Copyright (c) 2024, NVIDIA Corporation & AFFILIATES.
-# Copyright (c) Meta Platforms, Inc. and affiliates.
+# Portions Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# Copyright (c) 2024, NVIDIA Corporation & AFFILIATES.
+
 # pyre-strict
 
-from typing import Tuple
+from typing import Any, Optional, Tuple
 
 import torch
 
 
 class HstuAttnVarlenFunc(torch.autograd.Function):
     @staticmethod
-    def forward(
-        ctx,
+    def forward(  # pyre-ignore[14]
+        ctx,  # pyre-ignore[2]
         q: torch.Tensor,  # need grad
         k: torch.Tensor,  # need grad
         v: torch.Tensor,  # need grad
@@ -29,12 +30,12 @@ def forward(
         target_group_size: int,
         window_size: Tuple[int, int] = (-1, -1),
         alpha: float = 1.0,
-        rab: torch.Tensor = None,  # need grad
+        rab: Optional[torch.Tensor] = None,  # need grad
         has_drab: bool = False,
         is_delta_q: bool = False,
-        descale_q: torch.Tensor = None,
-        descale_k: torch.Tensor = None,
-        descale_v: torch.Tensor = None,
+        descale_q: Optional[torch.Tensor] = None,
+        descale_k: Optional[torch.Tensor] = None,
+        descale_v: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         assert q.dim() == 3, "q shape should be (L, num_heads, head_dim)"
         assert k.dim() == 3, "k shape should be (L, num_heads, head_dim)"
@@ -104,10 +105,10 @@ def forward(
         return out
 
     @staticmethod
-    def backward(
-        ctx,
+    def backward(  # pyre-ignore[14]
+        ctx,  # pyre-ignore[2]
         dout: torch.Tensor,
-        *args: any,
+        *args: Any,
     ) -> tuple[
         torch.Tensor,
         torch.Tensor,
@@ -214,6 +215,7 @@ def backward(
         )
 
 
+# pyre-ignore[3]
 def hstu_attn_varlen_func(
     q: torch.Tensor,
     k: torch.Tensor,
@@ -227,12 +229,12 @@ def hstu_attn_varlen_func(
     target_group_size: int = 1,
     window_size: Tuple[int, int] = (-1, -1),
     alpha: float = 1.0,
-    rab: torch.Tensor = None,
+    rab: Optional[torch.Tensor] = None,
     has_drab: bool = False,
     is_delta_q: bool = False,
-    descale_q: torch.Tensor = None,
-    descale_k: torch.Tensor = None,
-    descale_v: torch.Tensor = None,
+    descale_q: Optional[torch.Tensor] = None,
+    descale_k: Optional[torch.Tensor] = None,
+    descale_v: Optional[torch.Tensor] = None,
 ):
     """
     Arguments:
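
The typing changes in this file are what pyre-strict demands: a parameter whose default is None must be annotated Optional, since pyre rejects implicit-Optional defaults like the previous rab: torch.Tensor = None. A minimal illustration of the corrected pattern (a toy function, not the HSTU signature):

from typing import Optional

import torch


def add_bias(scores: torch.Tensor, rab: Optional[torch.Tensor] = None) -> torch.Tensor:
    # Under pyre-strict, `rab: torch.Tensor = None` is a type error;
    # the Optional annotation makes the None default explicit, and the
    # None case must then be handled before use.
    if rab is not None:
        scores = scores + rab
    return scores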

fbgemm_gpu/experimental/hstu/src/generate_kernels.py

Lines changed: 38 additions & 19 deletions
@@ -1,15 +1,21 @@
-# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
-# Copyright (c) Meta Platforms, Inc. and affiliates.
+#!/usr/bin/env python3
+# Portions Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
+
 import itertools
 import os
 
 
-def generate_kernels_ampere():
+def generate_kernels_ampere(install_dir: str):
+    """
+    Generate HSTU kernels for Ampere architecture.
+    """
+
     DTYPE_16 = ["bf16", "fp16"]
     HEAD_DIMENSIONS = [32, 64, 128, 256]
     RAB = ["", "_rab"]
@@ -30,6 +36,8 @@ def generate_kernels_ampere():
         "fp16": "cutlass::half_t",
     }
 
+    os.makedirs(install_dir, exist_ok=True)
+
     ampere_fwd_file_head = """
 /*
  * Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
@@ -51,9 +59,7 @@ def generate_kernels_ampere():
     for hdim, dtype, rab, mask in itertools.product(
         HEAD_DIMENSIONS, DTYPE_16, RAB, MASK
     ):
-        file_name = (
-            f"hstu_ampere/instantiations/hstu_fwd_hdim{hdim}_{dtype}{rab}{mask}_sm80.cu"
-        )
+        file_name = f"{install_dir}/hstu_fwd_hdim{hdim}_{dtype}{rab}{mask}_sm80.cu"
         if not os.path.exists(file_name):
             with open(file_name, "w") as f:
                 f.write(
@@ -90,7 +96,7 @@ def generate_kernels_ampere():
     for hdim, dtype, rab_drab, mask in itertools.product(
         HEAD_DIMENSIONS, DTYPE_16, RAB_DRAB, MASK
     ):
-        file_name = f"hstu_ampere/instantiations/hstu_bwd_hdim{hdim}_{dtype}{rab_drab}{mask}_sm80.cu"
+        file_name = f"{install_dir}/hstu_bwd_hdim{hdim}_{dtype}{rab_drab}{mask}_sm80.cu"
         if not os.path.exists(file_name):
             with open(file_name, "w") as f:
                 f.write(
@@ -108,7 +114,11 @@ def generate_kernels_ampere():
                 )
 
 
-def generate_kernels_hopper():
+def generate_kernels_hopper(install_dir: str):
+    """
+    Generate HSTU kernels for Hopper architecture.
+    """
+
     DTYPE_16 = ["bf16", "fp16"]
     HEAD_DIMENSIONS = [32, 64, 128, 256]
     RAB = ["", "_rab"]
@@ -130,6 +140,8 @@ def generate_kernels_hopper():
         "fp16": "cutlass::half_t",
     }
 
+    os.makedirs(install_dir, exist_ok=True)
+
     hopper_fwd_file_head = """
 /*
  * Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
@@ -151,9 +163,7 @@ def generate_kernels_hopper():
     for hdim, dtype, rab, mask in itertools.product(
         HEAD_DIMENSIONS, DTYPE_16, RAB, MASK
     ):
-        file_name = (
-            f"hstu_hopper/instantiations/hstu_fwd_hdim{hdim}_{dtype}{rab}{mask}_sm90.cu"
-        )
+        file_name = f"{install_dir}/hstu_fwd_hdim{hdim}_{dtype}{rab}{mask}_sm90.cu"
         if not os.path.exists(file_name):
             with open(file_name, "w") as f:
                 f.write(
@@ -172,9 +182,7 @@ def generate_kernels_hopper():
     for hdim, rab, mask in itertools.product(HEAD_DIMENSIONS, RAB, FP8_MASK):
         if hdim == 32:
             continue
-        file_name = (
-            f"hstu_hopper/instantiations/hstu_fwd_hdim{hdim}_e4m3{rab}{mask}_sm90.cu"
-        )
+        file_name = f"{install_dir}/hstu_fwd_hdim{hdim}_e4m3{rab}{mask}_sm90.cu"
         if not os.path.exists(file_name):
             with open(file_name, "w") as f:
                 f.write(
@@ -211,7 +219,7 @@ def generate_kernels_hopper():
     for hdim, dtype, rab_drab, mask in itertools.product(
         HEAD_DIMENSIONS, DTYPE_16, RAB_DRAB, MASK
    ):
-        file_name = f"hstu_hopper/instantiations/hstu_bwd_hdim{hdim}_{dtype}{rab_drab}{mask}_sm90.cu"
+        file_name = f"{install_dir}/hstu_bwd_hdim{hdim}_{dtype}{rab_drab}{mask}_sm90.cu"
         if not os.path.exists(file_name):
             with open(file_name, "w") as f:
                 f.write(
@@ -229,8 +237,7 @@ def generate_kernels_hopper():
                 )
 
 
-if __name__ == "__main__":
-
+def main() -> None:
     import argparse
 
     parser = argparse.ArgumentParser()
@@ -240,10 +247,22 @@ def generate_kernels_hopper():
         default="8.0 9.0",
         help="Comma-separated list of CUDA architectures to generate kernels for",
     )
+    parser.add_argument(
+        "--install_dir",
+        type=str,
+        default=None,
+        help="Output directory for generated source files",
+    )
     args = parser.parse_args()
 
     if "8.0" in args.arch_list:
-        generate_kernels_ampere()
+        # In OSS, the generated files will be written to hstu_ampere/instantiations
+        generate_kernels_ampere(args.install_dir or "hstu_ampere/instantiations")
 
     if "9.0" in args.arch_list:
-        generate_kernels_hopper()
+        # In OSS, the generated files will be written to hstu_hopper/instantiations
+        generate_kernels_hopper(args.install_dir or "hstu_hopper/instantiations")
+
+
+if __name__ == "__main__":
+    main()
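
Factoring the entry point into main() and threading install_dir through both generators makes the script usable from a build rule as well as by hand. A usage sketch, with arbitrary example output paths:

# From the command line (OSS defaults apply when --install_dir is omitted):
#   python generate_kernels.py --arch_list "8.0 9.0"
#       -> hstu_ampere/instantiations and hstu_hopper/instantiations
#   python generate_kernels.py --arch_list "9.0" --install_dir build/gen
#       -> build/gen

# Or programmatically, e.g. from a build script:
from generate_kernels import generate_kernels_ampere, generate_kernels_hopper

generate_kernels_ampere("build/gen/sm80")  # creates the directory if needed
generate_kernels_hopper("build/gen/sm90")  # files that already exist are left untouched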

fbgemm_gpu/experimental/hstu/src/hstu_ampere/hstu_ops_gpu.cpp

Lines changed: 16 additions & 14 deletions
@@ -1,26 +1,28 @@
 /*
- * Copyright (c) 2023, Tri Dao.
- * Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
- * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * Portions Copyright (c) Meta Platforms, Inc. and affiliates.
  * All rights reserved.
  *
  * This source code is licensed under the BSD-style license found in the
  * LICENSE file in the root directory of this source tree.
  */
 
+/*
+ * Copyright (c) 2023, Tri Dao.
+ * Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
+ */
+
 // Include these 2 headers instead of torch/extension.h since we don't need all
 // of the torch headers.
 #include <ATen/ATen.h>
-#include <torch/library.h>
-
-#include "c10/core/ScalarType.h"
-
 #include <ATen/core/op_registration/op_registration.h>
 #include <ATen/cuda/CUDAContext.h>
+#include <c10/core/ScalarType.h>
 #include <c10/cuda/CUDAGuard.h>
-#include <c10/util/Optional.h>
+#include <torch/library.h>
 #include <torch/nn/functional.h>
 
+#include <optional>
+
 #include "hstu.h"
 #include "static_switch.h"
 
@@ -440,13 +442,13 @@ std::tuple<at::Tensor, at::Tensor> hstu_varlen_fwd_80(
     const at::Tensor& cu_seqlens_k, // b+1
     const int64_t max_seqlen_q,
     const int64_t max_seqlen_k,
-    const c10::optional<at::Tensor>& num_contexts, // b
-    const c10::optional<at::Tensor>& num_targets, // b
+    const std::optional<at::Tensor>& num_contexts, // b
+    const std::optional<at::Tensor>& num_targets, // b
     const int64_t target_group_size,
     int64_t window_size_left,
     int64_t window_size_right,
     const double alpha,
-    c10::optional<at::Tensor> rab,
+    std::optional<at::Tensor> rab,
     const bool is_delta_q) {
   auto dprops = at::cuda::getCurrentDeviceProperties();
   TORCH_CHECK(dprops->major >= 8, "HSTU only supports Ampere GPUs or newer.");
@@ -738,13 +740,13 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> hstu_varlen_bwd_80(
     const at::Tensor& cu_seqlens_k, // b+1
     const int64_t max_seqlen_q,
     const int64_t max_seqlen_k,
-    const c10::optional<at::Tensor>& num_contexts, // b
-    const c10::optional<at::Tensor>& num_targets, // b
+    const std::optional<at::Tensor>& num_contexts, // b
+    const std::optional<at::Tensor>& num_targets, // b
     const int64_t target_group_size,
     int64_t window_size_left,
     int64_t window_size_right,
     const double alpha,
-    const c10::optional<at::Tensor>& rab,
+    const std::optional<at::Tensor>& rab,
     const bool has_drab,
     const bool is_delta_q,
     const bool deterministic) {

fbgemm_gpu/experimental/hstu/src/hstu_hopper/hstu_bwd_launch_template.h

Lines changed: 8 additions & 5 deletions
@@ -1,14 +1,17 @@
 /*
- * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar,
- * Pradeep Ramani, Tri Dao.
- * Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
- * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * Portions Copyright (c) Meta Platforms, Inc. and affiliates.
  * All rights reserved.
  *
 * This source code is licensed under the BSD-style license found in the
  * LICENSE file in the root directory of this source tree.
  */
 
+/*
+ * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar,
+ * Pradeep Ramani, Tri Dao.
+ * Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
+ */
+
 #pragma once
 
 #include "cute/tensor.hpp"
@@ -126,7 +129,7 @@ void run_hstu_bwd(Hstu_bwd_params& params, cudaStream_t stream) {
         params.window_size_left,
         params.window_size_right,
         params.target_group_size,
-        1.0 / params.target_group_size,
+        1.0f / params.target_group_size,
         params.alpha,
         params.dq_semaphore});
