@@ -15,12 +15,13 @@
 import multiprocessing
 import os
 
-# def get_gencode_flags():
-#     import paddle
 
-#     prop = paddle.device.cuda.get_device_properties()
-#     cc = prop.major * 10 + prop.minor
-#     return ["-gencode", "arch=compute_{0},code=sm_{0}".format(cc)]
+def get_gencode_flags():
+    import paddle
+
+    prop = paddle.device.cuda.get_device_properties()
+    cc = prop.major * 10 + prop.minor
+    return ["-gencode", "arch=compute_{0},code=sm_{0}".format(cc)]
 
 
 def run(func):
@@ -42,7 +43,7 @@ def setup_fast_ln():
     if is_compiled_with_rocm():
         print("The 'fast_ln' feature is temporarily not supported on the ROCm platform!!!")
     else:
-        # gencode_flags = get_gencode_flags()
+        gencode_flags = get_gencode_flags()
         change_pwd()
         setup(
             name="fast_ln",
@@ -66,7 +67,8 @@ def setup_fast_ln():
                         "--expt-relaxed-constexpr",
                         "--expt-extended-lambda",
                         "--use_fast_math",
-                    ],
+                    ]
+                    + gencode_flags,
                 },
             ),
         )
@@ -76,7 +78,7 @@ def setup_fused_ln():
     from paddle.device import is_compiled_with_rocm
     from paddle.utils.cpp_extension import CUDAExtension, setup
 
-    # gencode_flags = get_gencode_flags()
+    gencode_flags = get_gencode_flags()
     change_pwd()
     if is_compiled_with_rocm():
         setup(
@@ -122,93 +124,12 @@ def setup_fused_ln():
                         "--expt-extended-lambda",
                         "--use_fast_math",
                         "-maxrregcount=50",
-                    ],
+                    ]
+                    + gencode_flags,
                 },
             ),
         )
 
 
-def setup_fused_quant_ops():
-    """setup_fused_fp8_ops"""
-    from paddle.utils.cpp_extension import CUDAExtension, setup
-
-    # gencode_flags = get_gencode_flags()
-    change_pwd()
-    setup(
-        name="FusedQuantOps",
-        ext_modules=CUDAExtension(
-            sources=[
-                "fused_quanted_ops/fused_swiglu_act_quant.cu",
-                "fused_quanted_ops/fused_act_quant.cu",
-                "fused_quanted_ops/fused_act_dequant.cu",
-                "fused_quanted_ops/fused_act_dequant_transpose_act_quant.cu",
-            ],
-            extra_compile_args={
-                "cxx": ["-O3", "-w", "-Wno-abi", "-fPIC", "-std=c++17"],
-                "nvcc": [
-                    "-O3",
-                    "-U__CUDA_NO_HALF_OPERATORS__",
-                    "-U__CUDA_NO_HALF_CONVERSIONS__",
-                    "-U__CUDA_NO_BFLOAT16_OPERATORS__",
-                    "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
-                    "-U__CUDA_NO_BFLOAT162_OPERATORS__",
-                    "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
-                    "-DCUTE_ARCH_MMA_SM90A_ENABLE",
-                    "--expt-relaxed-constexpr",
-                    "--expt-extended-lambda",
-                    "--use_fast_math",
-                    "-lineinfo",
-                    "-DCUTLASS_DEBUG_TRACE_LEVEL=0",
-                    "-maxrregcount=50",
-                    "-gencode=arch=compute_90a,code=sm_90a",
-                    "-DNDEBUG",
-                ],
-            },
-        ),
-    )
-
-
-def setup_token_dispatcher_utils():
-    from paddle.utils.cpp_extension import CUDAExtension, setup
-
-    change_pwd()
-    setup(
-        name="TokenDispatcherUtils",
-        ext_modules=CUDAExtension(
-            sources=[
-                "token_dispatcher_utils/topk_to_multihot.cu",
-                "token_dispatcher_utils/topk_to_multihot_grad.cu",
-                "token_dispatcher_utils/tokens_unzip_and_zip.cu",
-                "token_dispatcher_utils/tokens_stable_unzip.cu",
-                "token_dispatcher_utils/tokens_guided_unzip.cu",
-                "token_dispatcher_utils/regroup_tokens.cu",
-            ],
-            extra_compile_args={
-                "cxx": ["-O3", "-w", "-Wno-abi", "-fPIC", "-std=c++17"],
-                "nvcc": [
-                    "-O3",
-                    "-U__CUDA_NO_HALF_OPERATORS__",
-                    "-U__CUDA_NO_HALF_CONVERSIONS__",
-                    "-U__CUDA_NO_BFLOAT16_OPERATORS__",
-                    "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
-                    "-U__CUDA_NO_BFLOAT162_OPERATORS__",
-                    "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
-                    "-DCUTE_ARCH_MMA_SM90A_ENABLE",
-                    "--expt-relaxed-constexpr",
-                    "--expt-extended-lambda",
-                    "--use_fast_math",
-                    "-maxrregcount=80",
-                    "-lineinfo",
-                    "-DCUTLASS_DEBUG_TRACE_LEVEL=0",
-                    "-gencode=arch=compute_90a,code=sm_90a",
-                    "-DNDEBUG",
-                ],
-            },
-        ),
-    )
-
-
-run(setup_token_dispatcher_utils)
-run(setup_fused_quant_ops)
 run(setup_fast_ln)
 run(setup_fused_ln)
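
For reference, a minimal sketch of what the re-enabled helper evaluates to at build time; the example devices and resulting flag values in the comments are illustrative assumptions, not part of the diff:

# Sketch only: mirrors get_gencode_flags() from the diff; the A100/H100
# compute-capability numbers below are illustrative assumptions.
import paddle

prop = paddle.device.cuda.get_device_properties()  # properties of the current GPU
cc = prop.major * 10 + prop.minor                  # e.g. 8.0 -> 80 (A100), 9.0 -> 90 (H100)
flags = ["-gencode", "arch=compute_{0},code=sm_{0}".format(cc)]
# On an A100 this yields ["-gencode", "arch=compute_80,code=sm_80"]; each setup
# function now appends these flags to its nvcc arguments, so the extensions are
# compiled for the GPU actually present instead of a hard-coded architecture.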