Commit 670cbd9

Refine setup.py (#10577)

* fix
* fix

1 parent ae560af

2 files changed: +112 -92 lines changed


slm/model_zoo/gpt-3/external_ops/setup.py

Lines changed: 1 addition & 92 deletions

@@ -23,6 +23,7 @@ def get_gencode_flags():
     cc = prop.major * 10 + prop.minor
     return ["-gencode", "arch=compute_{0},code=sm_{0}".format(cc)]
 
+
 def run(func):
     p = multiprocessing.Process(target=func)
     p.start()
@@ -129,98 +130,6 @@ def setup_fused_ln():
         ),
     )
 
-def setup_fused_quant_ops():
-    """setup_fused_fp8_ops"""
-    from paddle.utils.cpp_extension import CUDAExtension, setup
 
-    gencode_flags = get_gencode_flags()
-    change_pwd()
-    setup(
-        name="FusedQuantOps",
-        ext_modules=CUDAExtension(
-            sources=[
-                "fused_quanted_ops/fused_swiglu_act_quant.cu",
-                "fused_quanted_ops/fused_act_quant.cu",
-                "fused_quanted_ops/fused_act_dequant.cu",
-                "fused_quanted_ops/fused_act_dequant_transpose_act_quant.cu",
-                "fused_quanted_ops/fused_spaq.cu",
-            ],
-            extra_compile_args={
-                "cxx": [
-                    "-O3",
-                    "-w",
-                    "-Wno-abi",
-                    "-fPIC",
-                    "-std=c++17"
-                ],
-                "nvcc": [
-                    "-O3",
-                    "-U__CUDA_NO_HALF_OPERATORS__",
-                    "-U__CUDA_NO_HALF_CONVERSIONS__",
-                    "-U__CUDA_NO_BFLOAT16_OPERATORS__",
-                    "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
-                    "-U__CUDA_NO_BFLOAT162_OPERATORS__",
-                    "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
-                    "-DCUTE_ARCH_MMA_SM90A_ENABLE",
-                    "--expt-relaxed-constexpr",
-                    "--expt-extended-lambda",
-                    "--use_fast_math",
-                    "-lineinfo",
-                    "-DCUTLASS_DEBUG_TRACE_LEVEL=0",
-                    "-maxrregcount=50",
-                    "-gencode=arch=compute_90a,code=sm_90a",
-                    "-DNDEBUG"
-                ] + gencode_flags,
-            },
-        ),
-    )
-
-def setup_token_dispatcher_utils():
-    from paddle.utils.cpp_extension import CUDAExtension, setup
-
-    change_pwd()
-    setup(
-        name="TokenDispatcherUtils",
-        ext_modules=CUDAExtension(
-            sources=[
-                "token_dispatcher_utils/topk_to_multihot.cu",
-                "token_dispatcher_utils/topk_to_multihot_grad.cu",
-                "token_dispatcher_utils/tokens_unzip_and_zip.cu",
-                "token_dispatcher_utils/tokens_stable_unzip.cu",
-                "token_dispatcher_utils/tokens_guided_unzip.cu",
-                "token_dispatcher_utils/regroup_tokens.cu",
-            ],
-            extra_compile_args={
-                "cxx": [
-                    "-O3",
-                    "-w",
-                    "-Wno-abi",
-                    "-fPIC",
-                    "-std=c++17"
-                ],
-                "nvcc": [
-                    "-O3",
-                    "-U__CUDA_NO_HALF_OPERATORS__",
-                    "-U__CUDA_NO_HALF_CONVERSIONS__",
-                    "-U__CUDA_NO_BFLOAT16_OPERATORS__",
-                    "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
-                    "-U__CUDA_NO_BFLOAT162_OPERATORS__",
-                    "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
-                    "-DCUTE_ARCH_MMA_SM90A_ENABLE",
-                    "--expt-relaxed-constexpr",
-                    "--expt-extended-lambda",
-                    "--use_fast_math",
-                    "-maxrregcount=80",
-                    "-lineinfo",
-                    "-DCUTLASS_DEBUG_TRACE_LEVEL=0",
-                    "-gencode=arch=compute_90a,code=sm_90a",
-                    "-DNDEBUG"
-                ]
-            },
-        ),
-    )
-
-run(setup_token_dispatcher_utils)
-run(setup_fused_quant_ops)
 run(setup_fast_ln)
 run(setup_fused_ln)
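
For context, the pattern this setup.py relies on (and that the hunks above touch) is: each paddle `setup()` call runs in its own subprocess via `run()`, `change_pwd()` anchors the build at the script's directory, and `get_gencode_flags()` derives a `-gencode` flag from the local GPU's compute capability. The sketch below is assembled only from lines visible in the diff plus stated assumptions: the body of `get_gencode_flags()` above the first hunk is assumed to query `paddle.device.cuda.get_device_properties()`, and the extension name and source file are hypothetical placeholders, not part of the commit.

import multiprocessing
import os


def get_gencode_flags():
    # Assumption: only the last two lines of this function appear in the hunk above.
    import paddle

    prop = paddle.device.cuda.get_device_properties()
    cc = prop.major * 10 + prop.minor
    return ["-gencode", "arch=compute_{0},code=sm_{0}".format(cc)]


def run(func):
    # Run each paddle setup() in a fresh process so consecutive builds
    # launched from the same script do not interfere with each other.
    p = multiprocessing.Process(target=func)
    p.start()
    p.join()


def change_pwd():
    # Build relative to this script's directory, wherever it is invoked from.
    path = os.path.dirname(__file__)
    if path:
        os.chdir(path)


def setup_example_op():
    # Hypothetical extension, for illustration only.
    from paddle.utils.cpp_extension import CUDAExtension, setup

    change_pwd()
    setup(
        name="ExampleOp",
        ext_modules=CUDAExtension(
            sources=["example_op/example_op.cu"],
            extra_compile_args={"nvcc": ["-O3"] + get_gencode_flags()},
        ),
    )


run(setup_example_op)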
Lines changed: 111 additions & 0 deletions

@@ -0,0 +1,111 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import multiprocessing
+import os
+
+def run(func):
+    p = multiprocessing.Process(target=func)
+    p.start()
+    p.join()
+
+
+def change_pwd():
+    path = os.path.dirname(__file__)
+    if path:
+        os.chdir(path)
+
+
+def setup_fused_quant_ops():
+    """setup_fused_fp8_ops"""
+    from paddle.utils.cpp_extension import CUDAExtension, setup
+
+    change_pwd()
+    setup(
+        name="FusedQuantOps",
+        ext_modules=CUDAExtension(
+            sources=[
+                "fused_quanted_ops/fused_swiglu_act_quant.cu",
+                "fused_quanted_ops/fused_act_quant.cu",
+                "fused_quanted_ops/fused_act_dequant.cu",
+                "fused_quanted_ops/fused_act_dequant_transpose_act_quant.cu",
+                "fused_quanted_ops/fused_spaq.cu",
+            ],
+            extra_compile_args={
+                "cxx": ["-O3", "-w", "-Wno-abi", "-fPIC", "-std=c++17"],
+                "nvcc": [
+                    "-O3",
+                    "-U__CUDA_NO_HALF_OPERATORS__",
+                    "-U__CUDA_NO_HALF_CONVERSIONS__",
+                    "-U__CUDA_NO_BFLOAT16_OPERATORS__",
+                    "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
+                    "-U__CUDA_NO_BFLOAT162_OPERATORS__",
+                    "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
+                    "-DCUTE_ARCH_MMA_SM90A_ENABLE",
+                    "--expt-relaxed-constexpr",
+                    "--expt-extended-lambda",
+                    "--use_fast_math",
+                    "-lineinfo",
+                    "-DCUTLASS_DEBUG_TRACE_LEVEL=0",
+                    "-maxrregcount=50",
+                    "-gencode=arch=compute_90a,code=sm_90a",
+                    "-DNDEBUG",
+                ]
+            },
+        ),
+    )
+
+
+def setup_token_dispatcher_utils():
+    from paddle.utils.cpp_extension import CUDAExtension, setup
+
+    change_pwd()
+    setup(
+        name="TokenDispatcherUtils",
+        ext_modules=CUDAExtension(
+            sources=[
+                "token_dispatcher_utils/topk_to_multihot.cu",
+                "token_dispatcher_utils/topk_to_multihot_grad.cu",
+                "token_dispatcher_utils/tokens_unzip_and_zip.cu",
+                "token_dispatcher_utils/tokens_stable_unzip.cu",
+                "token_dispatcher_utils/tokens_guided_unzip.cu",
+                "token_dispatcher_utils/regroup_tokens.cu",
+            ],
+            extra_compile_args={
+                "cxx": ["-O3", "-w", "-Wno-abi", "-fPIC", "-std=c++17"],
+                "nvcc": [
+                    "-O3",
+                    "-U__CUDA_NO_HALF_OPERATORS__",
+                    "-U__CUDA_NO_HALF_CONVERSIONS__",
+                    "-U__CUDA_NO_BFLOAT16_OPERATORS__",
+                    "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
+                    "-U__CUDA_NO_BFLOAT162_OPERATORS__",
+                    "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
+                    "-DCUTE_ARCH_MMA_SM90A_ENABLE",
+                    "--expt-relaxed-constexpr",
+                    "--expt-extended-lambda",
+                    "--use_fast_math",
+                    "-maxrregcount=80",
+                    "-lineinfo",
+                    "-DCUTLASS_DEBUG_TRACE_LEVEL=0",
+                    "-gencode=arch=compute_90a,code=sm_90a",
+                    "-DNDEBUG",
+                ],
+            },
+        ),
+    )
+
+
+run(setup_token_dispatcher_utils)
+run(setup_fused_quant_ops)
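
As a hedged usage sketch (not part of the commit): `paddle.utils.cpp_extension.setup` installs each extension as an importable Python module named after its `name` argument, so after running the new script on an SM90a-capable GPU machine the build can be smoke-tested roughly as follows. The op-level call signatures come from the .cu sources and are not shown here.

# Smoke test for the freshly built extensions; module names mirror the
# setup(name=...) calls in the new file above. Assumes the build script
# was run first on a machine with a Hopper (SM90a) GPU and CUDA toolchain.
import FusedQuantOps
import TokenDispatcherUtils

# List the custom ops each module registered.
print([name for name in dir(FusedQuantOps) if not name.startswith("_")])
print([name for name in dir(TokenDispatcherUtils) if not name.startswith("_")])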
