Skip to content

Commit f4bb9e1

Browse files
committed
fix: re-enable get_gencode_flags in external_ops/setup.py and move setup_fused_quant_ops / setup_token_dispatcher_utils into a separate setup script
1 parent 10ee376 commit f4bb9e1

File tree

2 files changed

+134
-91
lines changed

2 files changed

+134
-91
lines changed

slm/model_zoo/gpt-3/external_ops/setup.py

Lines changed: 12 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,13 @@
1515
import multiprocessing
1616
import os
1717

18-
# def get_gencode_flags():
19-
# import paddle
2018

21-
# prop = paddle.device.cuda.get_device_properties()
22-
# cc = prop.major * 10 + prop.minor
23-
# return ["-gencode", "arch=compute_{0},code=sm_{0}".format(cc)]
19+
def get_gencode_flags():
    """Return nvcc ``-gencode`` flags matching the current GPU.

    Queries the device's compute capability via paddle and returns
    ``["-gencode", "arch=compute_CC,code=sm_CC"]`` (e.g. CC = 80 for A100).
    """
    # Local import: paddle initializes CUDA, so defer until actually needed.
    import paddle

    prop = paddle.device.cuda.get_device_properties()
    cc = prop.major * 10 + prop.minor
    return ["-gencode", "arch=compute_{0},code=sm_{0}".format(cc)]
2425

2526

2627
def run(func):
@@ -42,7 +43,7 @@ def setup_fast_ln():
4243
if is_compiled_with_rocm():
4344
print("The 'fasl_ln' feature is temporarily not supported on the ROCm platform !!!")
4445
else:
45-
# gencode_flags = get_gencode_flags()
46+
gencode_flags = get_gencode_flags()
4647
change_pwd()
4748
setup(
4849
name="fast_ln",
@@ -66,7 +67,8 @@ def setup_fast_ln():
6667
"--expt-relaxed-constexpr",
6768
"--expt-extended-lambda",
6869
"--use_fast_math",
69-
],
70+
]
71+
+ gencode_flags,
7072
},
7173
),
7274
)
@@ -76,7 +78,7 @@ def setup_fused_ln():
7678
from paddle.device import is_compiled_with_rocm
7779
from paddle.utils.cpp_extension import CUDAExtension, setup
7880

79-
# gencode_flags = get_gencode_flags()
81+
gencode_flags = get_gencode_flags()
8082
change_pwd()
8183
if is_compiled_with_rocm():
8284
setup(
@@ -122,93 +124,12 @@ def setup_fused_ln():
122124
"--expt-extended-lambda",
123125
"--use_fast_math",
124126
"-maxrregcount=50",
125-
],
127+
]
128+
+ gencode_flags,
126129
},
127130
),
128131
)
129132

130133

131-
def setup_fused_quant_ops():
132-
"""setup_fused_fp8_ops"""
133-
from paddle.utils.cpp_extension import CUDAExtension, setup
134-
135-
# gencode_flags = get_gencode_flags()
136-
change_pwd()
137-
setup(
138-
name="FusedQuantOps",
139-
ext_modules=CUDAExtension(
140-
sources=[
141-
"fused_quanted_ops/fused_swiglu_act_quant.cu",
142-
"fused_quanted_ops/fused_act_quant.cu",
143-
"fused_quanted_ops/fused_act_dequant.cu",
144-
"fused_quanted_ops/fused_act_dequant_transpose_act_quant.cu",
145-
],
146-
extra_compile_args={
147-
"cxx": ["-O3", "-w", "-Wno-abi", "-fPIC", "-std=c++17"],
148-
"nvcc": [
149-
"-O3",
150-
"-U__CUDA_NO_HALF_OPERATORS__",
151-
"-U__CUDA_NO_HALF_CONVERSIONS__",
152-
"-U__CUDA_NO_BFLOAT16_OPERATORS__",
153-
"-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
154-
"-U__CUDA_NO_BFLOAT162_OPERATORS__",
155-
"-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
156-
"-DCUTE_ARCH_MMA_SM90A_ENABLE",
157-
"--expt-relaxed-constexpr",
158-
"--expt-extended-lambda",
159-
"--use_fast_math",
160-
"-lineinfo",
161-
"-DCUTLASS_DEBUG_TRACE_LEVEL=0",
162-
"-maxrregcount=50",
163-
"-gencode=arch=compute_90a,code=sm_90a",
164-
"-DNDEBUG",
165-
],
166-
},
167-
),
168-
)
169-
170-
171-
def setup_token_dispatcher_utils():
172-
from paddle.utils.cpp_extension import CUDAExtension, setup
173-
174-
change_pwd()
175-
setup(
176-
name="TokenDispatcherUtils",
177-
ext_modules=CUDAExtension(
178-
sources=[
179-
"token_dispatcher_utils/topk_to_multihot.cu",
180-
"token_dispatcher_utils/topk_to_multihot_grad.cu",
181-
"token_dispatcher_utils/tokens_unzip_and_zip.cu",
182-
"token_dispatcher_utils/tokens_stable_unzip.cu",
183-
"token_dispatcher_utils/tokens_guided_unzip.cu",
184-
"token_dispatcher_utils/regroup_tokens.cu",
185-
],
186-
extra_compile_args={
187-
"cxx": ["-O3", "-w", "-Wno-abi", "-fPIC", "-std=c++17"],
188-
"nvcc": [
189-
"-O3",
190-
"-U__CUDA_NO_HALF_OPERATORS__",
191-
"-U__CUDA_NO_HALF_CONVERSIONS__",
192-
"-U__CUDA_NO_BFLOAT16_OPERATORS__",
193-
"-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
194-
"-U__CUDA_NO_BFLOAT162_OPERATORS__",
195-
"-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
196-
"-DCUTE_ARCH_MMA_SM90A_ENABLE",
197-
"--expt-relaxed-constexpr",
198-
"--expt-extended-lambda",
199-
"--use_fast_math",
200-
"-maxrregcount=80",
201-
"-lineinfo",
202-
"-DCUTLASS_DEBUG_TRACE_LEVEL=0",
203-
"-gencode=arch=compute_90a,code=sm_90a",
204-
"-DNDEBUG",
205-
],
206-
},
207-
),
208-
)
209-
210-
211-
run(setup_token_dispatcher_utils)
212-
run(setup_fused_quant_ops)
213134
run(setup_fast_ln)
214135
run(setup_fused_ln)
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import multiprocessing
16+
import os
17+
18+
19+
def get_gencode_flags():
    """Build the nvcc ``-gencode`` option for the GPU visible to this process.

    Returns a two-element list suitable for appending to an nvcc flag list,
    e.g. ``["-gencode", "arch=compute_90,code=sm_90"]``.
    """
    # Deferred import so merely loading this module does not require paddle/CUDA.
    import paddle

    device_props = paddle.device.cuda.get_device_properties()
    compute_cap = device_props.major * 10 + device_props.minor
    return ["-gencode", "arch=compute_{0},code=sm_{0}".format(compute_cap)]
25+
26+
27+
def run(func):
    """Execute *func* in a fresh subprocess and wait for it to finish.

    Each extension build runs in its own process so repeated ``setup()``
    invocations cannot pollute one another's interpreter state.
    """
    worker = multiprocessing.Process(target=func)
    worker.start()
    worker.join()
31+
32+
33+
def change_pwd():
    """Switch the working directory to this script's own directory.

    No-op when ``__file__`` has no directory component (i.e. the script was
    launched from its own directory via a bare relative path).
    """
    script_dir = os.path.dirname(__file__)
    if script_dir:
        os.chdir(script_dir)
37+
38+
39+
def setup_fused_quant_ops():
    """Build the ``FusedQuantOps`` CUDA extension (fused FP8 act/quant kernels).

    Targets sm_90a explicitly and additionally appends the gencode flags for
    the locally detected GPU.
    """
    from paddle.utils.cpp_extension import CUDAExtension, setup

    gencode_flags = get_gencode_flags()
    change_pwd()

    cuda_sources = [
        "fused_quanted_ops/fused_swiglu_act_quant.cu",
        "fused_quanted_ops/fused_act_quant.cu",
        "fused_quanted_ops/fused_act_dequant.cu",
        "fused_quanted_ops/fused_act_dequant_transpose_act_quant.cu",
        "fused_quanted_ops/fused_spaq.cu",
    ]
    nvcc_flags = [
        "-O3",
        "-U__CUDA_NO_HALF_OPERATORS__",
        "-U__CUDA_NO_HALF_CONVERSIONS__",
        "-U__CUDA_NO_BFLOAT16_OPERATORS__",
        "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
        "-U__CUDA_NO_BFLOAT162_OPERATORS__",
        "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
        "-DCUTE_ARCH_MMA_SM90A_ENABLE",
        "--expt-relaxed-constexpr",
        "--expt-extended-lambda",
        "--use_fast_math",
        "-lineinfo",
        "-DCUTLASS_DEBUG_TRACE_LEVEL=0",
        "-maxrregcount=50",
        "-gencode=arch=compute_90a,code=sm_90a",
        "-DNDEBUG",
    ]
    setup(
        name="FusedQuantOps",
        ext_modules=CUDAExtension(
            sources=cuda_sources,
            extra_compile_args={
                "cxx": ["-O3", "-w", "-Wno-abi", "-fPIC", "-std=c++17"],
                "nvcc": nvcc_flags + gencode_flags,
            },
        ),
    )
79+
80+
81+
def setup_token_dispatcher_utils():
    """Build the ``TokenDispatcherUtils`` CUDA extension (token routing kernels).

    Hard-codes the sm_90a target only (no per-device gencode flags are added
    here, unlike setup_fused_quant_ops).
    """
    from paddle.utils.cpp_extension import CUDAExtension, setup

    change_pwd()

    kernel_sources = [
        "token_dispatcher_utils/topk_to_multihot.cu",
        "token_dispatcher_utils/topk_to_multihot_grad.cu",
        "token_dispatcher_utils/tokens_unzip_and_zip.cu",
        "token_dispatcher_utils/tokens_stable_unzip.cu",
        "token_dispatcher_utils/tokens_guided_unzip.cu",
        "token_dispatcher_utils/regroup_tokens.cu",
    ]
    nvcc_flags = [
        "-O3",
        "-U__CUDA_NO_HALF_OPERATORS__",
        "-U__CUDA_NO_HALF_CONVERSIONS__",
        "-U__CUDA_NO_BFLOAT16_OPERATORS__",
        "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
        "-U__CUDA_NO_BFLOAT162_OPERATORS__",
        "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
        "-DCUTE_ARCH_MMA_SM90A_ENABLE",
        "--expt-relaxed-constexpr",
        "--expt-extended-lambda",
        "--use_fast_math",
        "-maxrregcount=80",
        "-lineinfo",
        "-DCUTLASS_DEBUG_TRACE_LEVEL=0",
        "-gencode=arch=compute_90a,code=sm_90a",
        "-DNDEBUG",
    ]
    setup(
        name="TokenDispatcherUtils",
        ext_modules=CUDAExtension(
            sources=kernel_sources,
            extra_compile_args={
                "cxx": ["-O3", "-w", "-Wno-abi", "-fPIC", "-std=c++17"],
                "nvcc": nvcc_flags,
            },
        ),
    )
119+
120+
121+
# Build each extension in its own subprocess (see run()) so the two setup()
# invocations cannot interfere with each other's interpreter state.
run(setup_token_dispatcher_utils)
run(setup_fused_quant_ops)

0 commit comments

Comments
 (0)