@@ -15,13 +15,13 @@
 import multiprocessing
 import os
 
-def get_gencode_flags():
-    import paddle
+# def get_gencode_flags():
+#     import paddle
 
-    prop = paddle.device.cuda.get_device_properties()
-    cc = prop.major * 10 + prop.minor
-    return ["-gencode", "arch=compute_{0},code=sm_{0}".format(cc)]
+#     prop = paddle.device.cuda.get_device_properties()
+#     cc = prop.major * 10 + prop.minor
+#     return ["-gencode", "arch=compute_{0},code=sm_{0}".format(cc)]
 
 
 def run(func):
     p = multiprocessing.Process(target=func)
@@ -36,13 +36,13 @@ def change_pwd():
 
 
 def setup_fast_ln():
-    from paddle.utils.cpp_extension import CUDAExtension, setup
     from paddle.device import is_compiled_with_rocm
+    from paddle.utils.cpp_extension import CUDAExtension, setup
 
-    if (is_compiled_with_rocm()):
+    if is_compiled_with_rocm():
         print("The 'fasl_ln' feature is temporarily not supported on the ROCm platform !!!")
     else:
-        gencode_flags = get_gencode_flags()
+        # gencode_flags = get_gencode_flags()
         change_pwd()
         setup(
             name="fast_ln",
@@ -66,20 +66,19 @@ def setup_fast_ln():
                     "--expt-relaxed-constexpr",
                     "--expt-extended-lambda",
                     "--use_fast_math",
-                ]
-                + gencode_flags,
+                ],
                 },
             ),
         )
 
 
 def setup_fused_ln():
-    from paddle.utils.cpp_extension import CUDAExtension, setup
     from paddle.device import is_compiled_with_rocm
+    from paddle.utils.cpp_extension import CUDAExtension, setup
 
-    gencode_flags = get_gencode_flags()
+    # gencode_flags = get_gencode_flags()
     change_pwd()
-    if (is_compiled_with_rocm()):
+    if is_compiled_with_rocm():
         setup(
             name="fused_ln",
             ext_modules=CUDAExtension(
@@ -97,7 +96,7 @@ def setup_fused_ln():
                     "-U__CUDA_NO_BFLOAT162_OPERATORS__",
                     "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
                     "-DPADDLE_WITH_HIP",
-                ]
+                ],
                 },
             ),
         )
@@ -123,17 +122,17 @@ def setup_fused_ln():
                     "--expt-extended-lambda",
                     "--use_fast_math",
                     "-maxrregcount=50",
-                ]
-                + gencode_flags,
+                ],
                 },
             ),
         )
 
+
 def setup_fused_quant_ops():
     """setup_fused_fp8_ops"""
     from paddle.utils.cpp_extension import CUDAExtension, setup
 
-    gencode_flags = get_gencode_flags()
+    # gencode_flags = get_gencode_flags()
     change_pwd()
     setup(
         name="FusedQuantOps",
@@ -145,13 +144,7 @@ def setup_fused_quant_ops():
                 "fused_quanted_ops/fused_act_dequant_transpose_act_quant.cu",
             ],
             extra_compile_args={
-                "cxx": [
-                    "-O3",
-                    "-w",
-                    "-Wno-abi",
-                    "-fPIC",
-                    "-std=c++17"
-                ],
+                "cxx": ["-O3", "-w", "-Wno-abi", "-fPIC", "-std=c++17"],
                 "nvcc": [
                     "-O3",
                     "-U__CUDA_NO_HALF_OPERATORS__",
@@ -168,12 +161,13 @@ def setup_fused_quant_ops():
                     "-DCUTLASS_DEBUG_TRACE_LEVEL=0",
                     "-maxrregcount=50",
                     "-gencode=arch=compute_90a,code=sm_90a",
-                    "-DNDEBUG"
-                ] + gencode_flags,
+                    "-DNDEBUG",
+                ],
             },
         ),
     )
 
+
 def setup_token_dispatcher_utils():
     from paddle.utils.cpp_extension import CUDAExtension, setup
 
@@ -190,35 +184,30 @@ def setup_token_dispatcher_utils():
                 "token_dispatcher_utils/regroup_tokens.cu",
             ],
             extra_compile_args={
-                "cxx": [
+                "cxx": ["-O3", "-w", "-Wno-abi", "-fPIC", "-std=c++17"],
+                "nvcc": [
                     "-O3",
-                    "-w",
-                    "-Wno-abi",
-                    "-fPIC",
-                    "-std=c++17"
+                    "-U__CUDA_NO_HALF_OPERATORS__",
+                    "-U__CUDA_NO_HALF_CONVERSIONS__",
+                    "-U__CUDA_NO_BFLOAT16_OPERATORS__",
+                    "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
+                    "-U__CUDA_NO_BFLOAT162_OPERATORS__",
+                    "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
+                    "-DCUTE_ARCH_MMA_SM90A_ENABLE",
+                    "--expt-relaxed-constexpr",
+                    "--expt-extended-lambda",
+                    "--use_fast_math",
+                    "-maxrregcount=80",
+                    "-lineinfo",
+                    "-DCUTLASS_DEBUG_TRACE_LEVEL=0",
+                    "-gencode=arch=compute_90a,code=sm_90a",
+                    "-DNDEBUG",
                 ],
-                "nvcc": [
-                    "-O3",
-                    "-U__CUDA_NO_HALF_OPERATORS__",
-                    "-U__CUDA_NO_HALF_CONVERSIONS__",
-                    "-U__CUDA_NO_BFLOAT16_OPERATORS__",
-                    "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
-                    "-U__CUDA_NO_BFLOAT162_OPERATORS__",
-                    "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
-                    "-DCUTE_ARCH_MMA_SM90A_ENABLE",
-                    "--expt-relaxed-constexpr",
-                    "--expt-extended-lambda",
-                    "--use_fast_math",
-                    "-maxrregcount=80",
-                    "-lineinfo",
-                    "-DCUTLASS_DEBUG_TRACE_LEVEL=0",
-                    "-gencode=arch=compute_90a,code=sm_90a",
-                    "-DNDEBUG"
-                ]
             },
         ),
     )
 
+
 run(setup_token_dispatcher_utils)
 run(setup_fused_quant_ops)
 run(setup_fast_ln)
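
Note on the gencode change: with every call to get_gencode_flags() commented out, fast_ln and fused_ln are now compiled without an explicit -gencode flag (NVCC falls back to its default target architecture), while FusedQuantOps and the token dispatcher utils keep only the hard-coded -gencode=arch=compute_90a,code=sm_90a, i.e. Hopper-only (sm_90a) binaries. For reference, here is a cleaned-up sketch of the helper this diff removes; it assumes a CUDA build of Paddle, where paddle.device.cuda.get_device_properties() called with no argument describes the current device:

    # Sketch of the commented-out helper, restored for illustration only.
    def get_gencode_flags():
        import paddle

        # Compute capability of the current CUDA device,
        # e.g. major=9, minor=0 on an H100 gives cc == 90.
        prop = paddle.device.cuda.get_device_properties()
        cc = prop.major * 10 + prop.minor
        # Returns e.g. ["-gencode", "arch=compute_90,code=sm_90"].
        return ["-gencode", "arch=compute_{0},code=sm_{0}".format(cc)]

Restoring this helper (and re-appending its result to the nvcc flag lists, as the removed "+ gencode_flags" lines did) would let the build follow the local GPU again instead of pinning sm_90a.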