Skip to content

Commit 10ee376

Browse files
committed
fix
1 parent ab69953 commit 10ee376

File tree

1 file changed

+38
-49
lines changed
  • slm/model_zoo/gpt-3/external_ops

1 file changed

+38
-49
lines changed

slm/model_zoo/gpt-3/external_ops/setup.py

Lines changed: 38 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,13 @@
1515
import multiprocessing
1616
import os
1717

18+
# def get_gencode_flags():
19+
# import paddle
1820

19-
def get_gencode_flags():
20-
import paddle
21+
# prop = paddle.device.cuda.get_device_properties()
22+
# cc = prop.major * 10 + prop.minor
23+
# return ["-gencode", "arch=compute_{0},code=sm_{0}".format(cc)]
2124

22-
prop = paddle.device.cuda.get_device_properties()
23-
cc = prop.major * 10 + prop.minor
24-
return ["-gencode", "arch=compute_{0},code=sm_{0}".format(cc)]
2525

2626
def run(func):
2727
p = multiprocessing.Process(target=func)
@@ -36,13 +36,13 @@ def change_pwd():
3636

3737

3838
def setup_fast_ln():
39-
from paddle.utils.cpp_extension import CUDAExtension, setup
4039
from paddle.device import is_compiled_with_rocm
40+
from paddle.utils.cpp_extension import CUDAExtension, setup
4141

42-
if(is_compiled_with_rocm()):
42+
if is_compiled_with_rocm():
4343
print("The 'fasl_ln' feature is temporarily not supported on the ROCm platform !!!")
4444
else:
45-
gencode_flags = get_gencode_flags()
45+
# gencode_flags = get_gencode_flags()
4646
change_pwd()
4747
setup(
4848
name="fast_ln",
@@ -66,20 +66,19 @@ def setup_fast_ln():
6666
"--expt-relaxed-constexpr",
6767
"--expt-extended-lambda",
6868
"--use_fast_math",
69-
]
70-
+ gencode_flags,
69+
],
7170
},
7271
),
7372
)
7473

7574

7675
def setup_fused_ln():
77-
from paddle.utils.cpp_extension import CUDAExtension, setup
7876
from paddle.device import is_compiled_with_rocm
77+
from paddle.utils.cpp_extension import CUDAExtension, setup
7978

80-
gencode_flags = get_gencode_flags()
79+
# gencode_flags = get_gencode_flags()
8180
change_pwd()
82-
if(is_compiled_with_rocm()):
81+
if is_compiled_with_rocm():
8382
setup(
8483
name="fused_ln",
8584
ext_modules=CUDAExtension(
@@ -97,7 +96,7 @@ def setup_fused_ln():
9796
"-U__CUDA_NO_BFLOAT162_OPERATORS__",
9897
"-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
9998
"-DPADDLE_WITH_HIP",
100-
]
99+
],
101100
},
102101
),
103102
)
@@ -123,17 +122,17 @@ def setup_fused_ln():
123122
"--expt-extended-lambda",
124123
"--use_fast_math",
125124
"-maxrregcount=50",
126-
]
127-
+ gencode_flags,
125+
],
128126
},
129127
),
130128
)
131129

130+
132131
def setup_fused_quant_ops():
133132
"""setup_fused_fp8_ops"""
134133
from paddle.utils.cpp_extension import CUDAExtension, setup
135134

136-
gencode_flags = get_gencode_flags()
135+
# gencode_flags = get_gencode_flags()
137136
change_pwd()
138137
setup(
139138
name="FusedQuantOps",
@@ -145,13 +144,7 @@ def setup_fused_quant_ops():
145144
"fused_quanted_ops/fused_act_dequant_transpose_act_quant.cu",
146145
],
147146
extra_compile_args={
148-
"cxx": [
149-
"-O3",
150-
"-w",
151-
"-Wno-abi",
152-
"-fPIC",
153-
"-std=c++17"
154-
],
147+
"cxx": ["-O3", "-w", "-Wno-abi", "-fPIC", "-std=c++17"],
155148
"nvcc": [
156149
"-O3",
157150
"-U__CUDA_NO_HALF_OPERATORS__",
@@ -168,12 +161,13 @@ def setup_fused_quant_ops():
168161
"-DCUTLASS_DEBUG_TRACE_LEVEL=0",
169162
"-maxrregcount=50",
170163
"-gencode=arch=compute_90a,code=sm_90a",
171-
"-DNDEBUG"
172-
] + gencode_flags,
164+
"-DNDEBUG",
165+
],
173166
},
174167
),
175168
)
176169

170+
177171
def setup_token_dispatcher_utils():
178172
from paddle.utils.cpp_extension import CUDAExtension, setup
179173

@@ -190,35 +184,30 @@ def setup_token_dispatcher_utils():
190184
"token_dispatcher_utils/regroup_tokens.cu",
191185
],
192186
extra_compile_args={
193-
"cxx": [
187+
"cxx": ["-O3", "-w", "-Wno-abi", "-fPIC", "-std=c++17"],
188+
"nvcc": [
194189
"-O3",
195-
"-w",
196-
"-Wno-abi",
197-
"-fPIC",
198-
"-std=c++17"
190+
"-U__CUDA_NO_HALF_OPERATORS__",
191+
"-U__CUDA_NO_HALF_CONVERSIONS__",
192+
"-U__CUDA_NO_BFLOAT16_OPERATORS__",
193+
"-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
194+
"-U__CUDA_NO_BFLOAT162_OPERATORS__",
195+
"-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
196+
"-DCUTE_ARCH_MMA_SM90A_ENABLE",
197+
"--expt-relaxed-constexpr",
198+
"--expt-extended-lambda",
199+
"--use_fast_math",
200+
"-maxrregcount=80",
201+
"-lineinfo",
202+
"-DCUTLASS_DEBUG_TRACE_LEVEL=0",
203+
"-gencode=arch=compute_90a,code=sm_90a",
204+
"-DNDEBUG",
199205
],
200-
"nvcc": [
201-
"-O3",
202-
"-U__CUDA_NO_HALF_OPERATORS__",
203-
"-U__CUDA_NO_HALF_CONVERSIONS__",
204-
"-U__CUDA_NO_BFLOAT16_OPERATORS__",
205-
"-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
206-
"-U__CUDA_NO_BFLOAT162_OPERATORS__",
207-
"-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
208-
"-DCUTE_ARCH_MMA_SM90A_ENABLE",
209-
"--expt-relaxed-constexpr",
210-
"--expt-extended-lambda",
211-
"--use_fast_math",
212-
"-maxrregcount=80",
213-
"-lineinfo",
214-
"-DCUTLASS_DEBUG_TRACE_LEVEL=0",
215-
"-gencode=arch=compute_90a,code=sm_90a",
216-
"-DNDEBUG"
217-
]
218206
},
219207
),
220208
)
221209

210+
222211
run(setup_token_dispatcher_utils)
223212
run(setup_fused_quant_ops)
224213
run(setup_fast_ln)

0 commit comments

Comments
 (0)