File tree Expand file tree Collapse file tree 1 file changed +257
-82
lines changed
torchbenchmark/operators/gemm Expand file tree Collapse file tree 1 file changed +257
-82
lines changed Original file line number Diff line number Diff line change 15
15
# - An auto-tuning *key* whose change in values will trigger evaluation of all the
16
16
# provided configs
17
17
@triton .autotune (
18
- configs = [
19
- triton .Config (
20
- {
21
- "BLOCK_SIZE_M" : 128 ,
22
- "BLOCK_SIZE_N" : 256 ,
23
- "BLOCK_SIZE_K" : 64 ,
24
- "GROUP_SIZE_M" : 8 ,
25
- },
26
- num_stages = 3 ,
27
- num_warps = 8 ,
28
- ),
29
- triton .Config (
30
- {
31
- "BLOCK_SIZE_M" : 64 ,
32
- "BLOCK_SIZE_N" : 256 ,
33
- "BLOCK_SIZE_K" : 32 ,
34
- "GROUP_SIZE_M" : 8 ,
35
- },
36
- num_stages = 4 ,
37
- num_warps = 4 ,
38
- ),
39
- triton .Config (
40
- {
41
- "BLOCK_SIZE_M" : 128 ,
42
- "BLOCK_SIZE_N" : 128 ,
43
- "BLOCK_SIZE_K" : 32 ,
44
- "GROUP_SIZE_M" : 8 ,
45
- },
46
- num_stages = 4 ,
47
- num_warps = 4 ,
48
- ),
49
- triton .Config (
50
- {
51
- "BLOCK_SIZE_M" : 128 ,
52
- "BLOCK_SIZE_N" : 64 ,
53
- "BLOCK_SIZE_K" : 32 ,
54
- "GROUP_SIZE_M" : 8 ,
55
- },
56
- num_stages = 4 ,
57
- num_warps = 4 ,
58
- ),
59
- triton .Config (
60
- {
61
- "BLOCK_SIZE_M" : 64 ,
62
- "BLOCK_SIZE_N" : 128 ,
63
- "BLOCK_SIZE_K" : 32 ,
64
- "GROUP_SIZE_M" : 8 ,
65
- },
66
- num_stages = 4 ,
67
- num_warps = 4 ,
68
- ),
69
- triton .Config (
70
- {
71
- "BLOCK_SIZE_M" : 128 ,
72
- "BLOCK_SIZE_N" : 32 ,
73
- "BLOCK_SIZE_K" : 32 ,
74
- "GROUP_SIZE_M" : 8 ,
75
- },
76
- num_stages = 4 ,
77
- num_warps = 4 ,
78
- ),
79
- triton .Config (
80
- {
81
- "BLOCK_SIZE_M" : 64 ,
82
- "BLOCK_SIZE_N" : 32 ,
83
- "BLOCK_SIZE_K" : 32 ,
84
- "GROUP_SIZE_M" : 8 ,
85
- },
86
- num_stages = 5 ,
87
- num_warps = 2 ,
88
- ),
89
- triton .Config (
90
- {
91
- "BLOCK_SIZE_M" : 32 ,
92
- "BLOCK_SIZE_N" : 64 ,
93
- "BLOCK_SIZE_K" : 32 ,
94
- "GROUP_SIZE_M" : 8 ,
95
- },
96
- num_stages = 5 ,
97
- num_warps = 2 ,
98
- ),
99
- ],
18
+ configs = (
19
+ [
20
+ triton .Config (
21
+ {
22
+ "BLOCK_SIZE_M" : 64 ,
23
+ "BLOCK_SIZE_N" : 16 ,
24
+ "BLOCK_SIZE_K" : 256 ,
25
+ "GROUP_SIZE_M" : 8 ,
26
+ },
27
+ num_stages = 2 ,
28
+ num_warps = 4 ,
29
+ ),
30
+ triton .Config (
31
+ {
32
+ "BLOCK_SIZE_M" : 32 ,
33
+ "BLOCK_SIZE_N" : 16 ,
34
+ "BLOCK_SIZE_K" : 128 ,
35
+ "GROUP_SIZE_M" : 8 ,
36
+ },
37
+ num_stages = 2 ,
38
+ num_warps = 2 ,
39
+ ),
40
+ triton .Config (
41
+ {
42
+ "BLOCK_SIZE_M" : 128 ,
43
+ "BLOCK_SIZE_N" : 256 ,
44
+ "BLOCK_SIZE_K" : 32 ,
45
+ "GROUP_SIZE_M" : 8 ,
46
+ },
47
+ num_stages = 2 ,
48
+ num_warps = 8 ,
49
+ ),
50
+ triton .Config (
51
+ {
52
+ "BLOCK_SIZE_M" : 256 ,
53
+ "BLOCK_SIZE_N" : 256 ,
54
+ "BLOCK_SIZE_K" : 32 ,
55
+ "GROUP_SIZE_M" : 8 ,
56
+ },
57
+ num_stages = 2 ,
58
+ num_warps = 8 ,
59
+ ),
60
+ triton .Config (
61
+ {
62
+ "BLOCK_SIZE_M" : 128 ,
63
+ "BLOCK_SIZE_N" : 64 ,
64
+ "BLOCK_SIZE_K" : 128 ,
65
+ "GROUP_SIZE_M" : 8 ,
66
+ },
67
+ num_stages = 2 ,
68
+ num_warps = 8 ,
69
+ ),
70
+ triton .Config (
71
+ {
72
+ "BLOCK_SIZE_M" : 256 ,
73
+ "BLOCK_SIZE_N" : 128 ,
74
+ "BLOCK_SIZE_K" : 32 ,
75
+ "GROUP_SIZE_M" : 8 ,
76
+ },
77
+ num_stages = 2 ,
78
+ num_warps = 8 ,
79
+ ),
80
+ triton .Config (
81
+ {
82
+ "BLOCK_SIZE_M" : 32 ,
83
+ "BLOCK_SIZE_N" : 64 ,
84
+ "BLOCK_SIZE_K" : 128 ,
85
+ "GROUP_SIZE_M" : 8 ,
86
+ },
87
+ num_stages = 2 ,
88
+ num_warps = 2 ,
89
+ ),
90
+ triton .Config (
91
+ {
92
+ "BLOCK_SIZE_M" : 128 ,
93
+ "BLOCK_SIZE_N" : 128 ,
94
+ "BLOCK_SIZE_K" : 128 ,
95
+ "GROUP_SIZE_M" : 8 ,
96
+ },
97
+ num_stages = 2 ,
98
+ num_warps = 8 ,
99
+ ),
100
+ triton .Config (
101
+ {
102
+ "BLOCK_SIZE_M" : 128 ,
103
+ "BLOCK_SIZE_N" : 128 ,
104
+ "BLOCK_SIZE_K" : 64 ,
105
+ "GROUP_SIZE_M" : 8 ,
106
+ },
107
+ num_stages = 2 ,
108
+ num_warps = 2 ,
109
+ ),
110
+ triton .Config (
111
+ {
112
+ "BLOCK_SIZE_M" : 128 ,
113
+ "BLOCK_SIZE_N" : 128 ,
114
+ "BLOCK_SIZE_K" : 64 ,
115
+ "GROUP_SIZE_M" : 8 ,
116
+ },
117
+ num_stages = 2 ,
118
+ num_warps = 8 ,
119
+ ),
120
+ triton .Config (
121
+ {
122
+ "BLOCK_SIZE_M" : 128 ,
123
+ "BLOCK_SIZE_N" : 128 ,
124
+ "BLOCK_SIZE_K" : 64 ,
125
+ "GROUP_SIZE_M" : 8 ,
126
+ },
127
+ num_stages = 1 ,
128
+ num_warps = 4 ,
129
+ ),
130
+ triton .Config (
131
+ {
132
+ "BLOCK_SIZE_M" : 64 ,
133
+ "BLOCK_SIZE_N" : 64 ,
134
+ "BLOCK_SIZE_K" : 128 ,
135
+ "GROUP_SIZE_M" : 8 ,
136
+ },
137
+ num_stages = 2 ,
138
+ num_warps = 4 ,
139
+ ),
140
+ triton .Config (
141
+ {
142
+ "BLOCK_SIZE_M" : 64 ,
143
+ "BLOCK_SIZE_N" : 64 ,
144
+ "BLOCK_SIZE_K" : 64 ,
145
+ "GROUP_SIZE_M" : 8 ,
146
+ },
147
+ num_stages = 2 ,
148
+ num_warps = 4 ,
149
+ ),
150
+ triton .Config (
151
+ {
152
+ "BLOCK_SIZE_M" : 16 ,
153
+ "BLOCK_SIZE_N" : 16 ,
154
+ "BLOCK_SIZE_K" : 256 ,
155
+ "GROUP_SIZE_M" : 8 ,
156
+ },
157
+ num_stages = 2 ,
158
+ num_warps = 4 ,
159
+ ),
160
+ triton .Config (
161
+ {
162
+ "BLOCK_SIZE_M" : 128 ,
163
+ "BLOCK_SIZE_N" : 128 ,
164
+ "BLOCK_SIZE_K" : 64 ,
165
+ "GROUP_SIZE_M" : 8 ,
166
+ },
167
+ num_stages = 1 ,
168
+ num_warps = 8 ,
169
+ ),
170
+ triton .Config (
171
+ {
172
+ "BLOCK_SIZE_M" : 64 ,
173
+ "BLOCK_SIZE_N" : 128 ,
174
+ "BLOCK_SIZE_K" : 64 ,
175
+ "GROUP_SIZE_M" : 8 ,
176
+ },
177
+ num_stages = 2 ,
178
+ num_warps = 8 ,
179
+ ),
180
+ triton .Config (
181
+ {
182
+ "BLOCK_SIZE_M" : 16 ,
183
+ "BLOCK_SIZE_N" : 16 ,
184
+ "BLOCK_SIZE_K" : 128 ,
185
+ "GROUP_SIZE_M" : 8 ,
186
+ },
187
+ num_stages = 2 ,
188
+ num_warps = 4 ,
189
+ ),
190
+ ]
191
+ if torch .version .hip is not None
192
+ else [
193
+ triton .Config (
194
+ {
195
+ "BLOCK_SIZE_M" : 128 ,
196
+ "BLOCK_SIZE_N" : 256 ,
197
+ "BLOCK_SIZE_K" : 64 ,
198
+ "GROUP_SIZE_M" : 8 ,
199
+ },
200
+ num_stages = 3 ,
201
+ num_warps = 8 ,
202
+ ),
203
+ triton .Config (
204
+ {
205
+ "BLOCK_SIZE_M" : 64 ,
206
+ "BLOCK_SIZE_N" : 256 ,
207
+ "BLOCK_SIZE_K" : 32 ,
208
+ "GROUP_SIZE_M" : 8 ,
209
+ },
210
+ num_stages = 4 ,
211
+ num_warps = 4 ,
212
+ ),
213
+ triton .Config (
214
+ {
215
+ "BLOCK_SIZE_M" : 128 ,
216
+ "BLOCK_SIZE_N" : 128 ,
217
+ "BLOCK_SIZE_K" : 32 ,
218
+ "GROUP_SIZE_M" : 8 ,
219
+ },
220
+ num_stages = 4 ,
221
+ num_warps = 4 ,
222
+ ),
223
+ triton .Config (
224
+ {
225
+ "BLOCK_SIZE_M" : 128 ,
226
+ "BLOCK_SIZE_N" : 64 ,
227
+ "BLOCK_SIZE_K" : 32 ,
228
+ "GROUP_SIZE_M" : 8 ,
229
+ },
230
+ num_stages = 4 ,
231
+ num_warps = 4 ,
232
+ ),
233
+ triton .Config (
234
+ {
235
+ "BLOCK_SIZE_M" : 64 ,
236
+ "BLOCK_SIZE_N" : 128 ,
237
+ "BLOCK_SIZE_K" : 32 ,
238
+ "GROUP_SIZE_M" : 8 ,
239
+ },
240
+ num_stages = 4 ,
241
+ num_warps = 4 ,
242
+ ),
243
+ triton .Config (
244
+ {
245
+ "BLOCK_SIZE_M" : 128 ,
246
+ "BLOCK_SIZE_N" : 32 ,
247
+ "BLOCK_SIZE_K" : 32 ,
248
+ "GROUP_SIZE_M" : 8 ,
249
+ },
250
+ num_stages = 4 ,
251
+ num_warps = 4 ,
252
+ ),
253
+ triton .Config (
254
+ {
255
+ "BLOCK_SIZE_M" : 64 ,
256
+ "BLOCK_SIZE_N" : 32 ,
257
+ "BLOCK_SIZE_K" : 32 ,
258
+ "GROUP_SIZE_M" : 8 ,
259
+ },
260
+ num_stages = 5 ,
261
+ num_warps = 2 ,
262
+ ),
263
+ triton .Config (
264
+ {
265
+ "BLOCK_SIZE_M" : 32 ,
266
+ "BLOCK_SIZE_N" : 64 ,
267
+ "BLOCK_SIZE_K" : 32 ,
268
+ "GROUP_SIZE_M" : 8 ,
269
+ },
270
+ num_stages = 5 ,
271
+ num_warps = 2 ,
272
+ ),
273
+ ]
274
+ ),
100
275
key = ["M" , "N" , "K" ],
101
276
)
102
277
@triton .jit
You can’t perform that action at this time.
0 commit comments