File tree Expand file tree Collapse file tree 2 files changed +264
-257
lines changed
torchbenchmark/operators/gemm Expand file tree Collapse file tree 2 files changed +264
-257
lines changed Original file line number Diff line number Diff line change 8
8
import triton
9
9
import triton .language as tl
10
10
11
+ from .triton_matmul_configs import configs
12
+
11
13
12
14
# `triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes:
13
15
# - A list of `triton.Config` objects that define different configurations of
14
16
# meta-parameters (e.g., `BLOCK_SIZE_M`) and compilation options (e.g., `num_warps`) to try
15
17
# - An auto-tuning *key* whose change in values will trigger evaluation of all the
16
18
# provided configs
17
19
@triton .autotune (
18
- configs = (
19
- [
20
- triton .Config (
21
- {
22
- "BLOCK_SIZE_M" : 64 ,
23
- "BLOCK_SIZE_N" : 16 ,
24
- "BLOCK_SIZE_K" : 256 ,
25
- "GROUP_SIZE_M" : 8 ,
26
- },
27
- num_stages = 2 ,
28
- num_warps = 4 ,
29
- ),
30
- triton .Config (
31
- {
32
- "BLOCK_SIZE_M" : 32 ,
33
- "BLOCK_SIZE_N" : 16 ,
34
- "BLOCK_SIZE_K" : 128 ,
35
- "GROUP_SIZE_M" : 8 ,
36
- },
37
- num_stages = 2 ,
38
- num_warps = 2 ,
39
- ),
40
- triton .Config (
41
- {
42
- "BLOCK_SIZE_M" : 128 ,
43
- "BLOCK_SIZE_N" : 256 ,
44
- "BLOCK_SIZE_K" : 32 ,
45
- "GROUP_SIZE_M" : 8 ,
46
- },
47
- num_stages = 2 ,
48
- num_warps = 8 ,
49
- ),
50
- triton .Config (
51
- {
52
- "BLOCK_SIZE_M" : 256 ,
53
- "BLOCK_SIZE_N" : 256 ,
54
- "BLOCK_SIZE_K" : 32 ,
55
- "GROUP_SIZE_M" : 8 ,
56
- },
57
- num_stages = 2 ,
58
- num_warps = 8 ,
59
- ),
60
- triton .Config (
61
- {
62
- "BLOCK_SIZE_M" : 128 ,
63
- "BLOCK_SIZE_N" : 64 ,
64
- "BLOCK_SIZE_K" : 128 ,
65
- "GROUP_SIZE_M" : 8 ,
66
- },
67
- num_stages = 2 ,
68
- num_warps = 8 ,
69
- ),
70
- triton .Config (
71
- {
72
- "BLOCK_SIZE_M" : 256 ,
73
- "BLOCK_SIZE_N" : 128 ,
74
- "BLOCK_SIZE_K" : 32 ,
75
- "GROUP_SIZE_M" : 8 ,
76
- },
77
- num_stages = 2 ,
78
- num_warps = 8 ,
79
- ),
80
- triton .Config (
81
- {
82
- "BLOCK_SIZE_M" : 32 ,
83
- "BLOCK_SIZE_N" : 64 ,
84
- "BLOCK_SIZE_K" : 128 ,
85
- "GROUP_SIZE_M" : 8 ,
86
- },
87
- num_stages = 2 ,
88
- num_warps = 2 ,
89
- ),
90
- triton .Config (
91
- {
92
- "BLOCK_SIZE_M" : 128 ,
93
- "BLOCK_SIZE_N" : 128 ,
94
- "BLOCK_SIZE_K" : 128 ,
95
- "GROUP_SIZE_M" : 8 ,
96
- },
97
- num_stages = 2 ,
98
- num_warps = 8 ,
99
- ),
100
- triton .Config (
101
- {
102
- "BLOCK_SIZE_M" : 128 ,
103
- "BLOCK_SIZE_N" : 128 ,
104
- "BLOCK_SIZE_K" : 64 ,
105
- "GROUP_SIZE_M" : 8 ,
106
- },
107
- num_stages = 2 ,
108
- num_warps = 2 ,
109
- ),
110
- triton .Config (
111
- {
112
- "BLOCK_SIZE_M" : 128 ,
113
- "BLOCK_SIZE_N" : 128 ,
114
- "BLOCK_SIZE_K" : 64 ,
115
- "GROUP_SIZE_M" : 8 ,
116
- },
117
- num_stages = 2 ,
118
- num_warps = 8 ,
119
- ),
120
- triton .Config (
121
- {
122
- "BLOCK_SIZE_M" : 128 ,
123
- "BLOCK_SIZE_N" : 128 ,
124
- "BLOCK_SIZE_K" : 64 ,
125
- "GROUP_SIZE_M" : 8 ,
126
- },
127
- num_stages = 1 ,
128
- num_warps = 4 ,
129
- ),
130
- triton .Config (
131
- {
132
- "BLOCK_SIZE_M" : 64 ,
133
- "BLOCK_SIZE_N" : 64 ,
134
- "BLOCK_SIZE_K" : 128 ,
135
- "GROUP_SIZE_M" : 8 ,
136
- },
137
- num_stages = 2 ,
138
- num_warps = 4 ,
139
- ),
140
- triton .Config (
141
- {
142
- "BLOCK_SIZE_M" : 64 ,
143
- "BLOCK_SIZE_N" : 64 ,
144
- "BLOCK_SIZE_K" : 64 ,
145
- "GROUP_SIZE_M" : 8 ,
146
- },
147
- num_stages = 2 ,
148
- num_warps = 4 ,
149
- ),
150
- triton .Config (
151
- {
152
- "BLOCK_SIZE_M" : 16 ,
153
- "BLOCK_SIZE_N" : 16 ,
154
- "BLOCK_SIZE_K" : 256 ,
155
- "GROUP_SIZE_M" : 8 ,
156
- },
157
- num_stages = 2 ,
158
- num_warps = 4 ,
159
- ),
160
- triton .Config (
161
- {
162
- "BLOCK_SIZE_M" : 128 ,
163
- "BLOCK_SIZE_N" : 128 ,
164
- "BLOCK_SIZE_K" : 64 ,
165
- "GROUP_SIZE_M" : 8 ,
166
- },
167
- num_stages = 1 ,
168
- num_warps = 8 ,
169
- ),
170
- triton .Config (
171
- {
172
- "BLOCK_SIZE_M" : 64 ,
173
- "BLOCK_SIZE_N" : 128 ,
174
- "BLOCK_SIZE_K" : 64 ,
175
- "GROUP_SIZE_M" : 8 ,
176
- },
177
- num_stages = 2 ,
178
- num_warps = 8 ,
179
- ),
180
- triton .Config (
181
- {
182
- "BLOCK_SIZE_M" : 16 ,
183
- "BLOCK_SIZE_N" : 16 ,
184
- "BLOCK_SIZE_K" : 128 ,
185
- "GROUP_SIZE_M" : 8 ,
186
- },
187
- num_stages = 2 ,
188
- num_warps = 4 ,
189
- ),
190
- ]
191
- if torch .version .hip is not None
192
- else [
193
- triton .Config (
194
- {
195
- "BLOCK_SIZE_M" : 128 ,
196
- "BLOCK_SIZE_N" : 256 ,
197
- "BLOCK_SIZE_K" : 64 ,
198
- "GROUP_SIZE_M" : 8 ,
199
- },
200
- num_stages = 3 ,
201
- num_warps = 8 ,
202
- ),
203
- triton .Config (
204
- {
205
- "BLOCK_SIZE_M" : 64 ,
206
- "BLOCK_SIZE_N" : 256 ,
207
- "BLOCK_SIZE_K" : 32 ,
208
- "GROUP_SIZE_M" : 8 ,
209
- },
210
- num_stages = 4 ,
211
- num_warps = 4 ,
212
- ),
213
- triton .Config (
214
- {
215
- "BLOCK_SIZE_M" : 128 ,
216
- "BLOCK_SIZE_N" : 128 ,
217
- "BLOCK_SIZE_K" : 32 ,
218
- "GROUP_SIZE_M" : 8 ,
219
- },
220
- num_stages = 4 ,
221
- num_warps = 4 ,
222
- ),
223
- triton .Config (
224
- {
225
- "BLOCK_SIZE_M" : 128 ,
226
- "BLOCK_SIZE_N" : 64 ,
227
- "BLOCK_SIZE_K" : 32 ,
228
- "GROUP_SIZE_M" : 8 ,
229
- },
230
- num_stages = 4 ,
231
- num_warps = 4 ,
232
- ),
233
- triton .Config (
234
- {
235
- "BLOCK_SIZE_M" : 64 ,
236
- "BLOCK_SIZE_N" : 128 ,
237
- "BLOCK_SIZE_K" : 32 ,
238
- "GROUP_SIZE_M" : 8 ,
239
- },
240
- num_stages = 4 ,
241
- num_warps = 4 ,
242
- ),
243
- triton .Config (
244
- {
245
- "BLOCK_SIZE_M" : 128 ,
246
- "BLOCK_SIZE_N" : 32 ,
247
- "BLOCK_SIZE_K" : 32 ,
248
- "GROUP_SIZE_M" : 8 ,
249
- },
250
- num_stages = 4 ,
251
- num_warps = 4 ,
252
- ),
253
- triton .Config (
254
- {
255
- "BLOCK_SIZE_M" : 64 ,
256
- "BLOCK_SIZE_N" : 32 ,
257
- "BLOCK_SIZE_K" : 32 ,
258
- "GROUP_SIZE_M" : 8 ,
259
- },
260
- num_stages = 5 ,
261
- num_warps = 2 ,
262
- ),
263
- triton .Config (
264
- {
265
- "BLOCK_SIZE_M" : 32 ,
266
- "BLOCK_SIZE_N" : 64 ,
267
- "BLOCK_SIZE_K" : 32 ,
268
- "GROUP_SIZE_M" : 8 ,
269
- },
270
- num_stages = 5 ,
271
- num_warps = 2 ,
272
- ),
273
- ]
274
- ),
20
+ configs = configs ,
275
21
key = ["M" , "N" , "K" ],
276
22
)
277
23
@triton .jit
You can’t perform that action at this time.
0 commit comments