Skip to content

Commit 672ee07

Browse files
nmacchionifacebook-github-bot
authored andcommitted
tune tritonbench gemm
Summary: adding better tunings for mi300, ~150tflops -> ~250tflops Reviewed By: adamomainz Differential Revision: D65581507 fbshipit-source-id: b4e780d42f708924b5a2948b13f674b921a11294
1 parent 06d867a commit 672ee07

File tree

1 file changed

+257
-82
lines changed

1 file changed

+257
-82
lines changed

torchbenchmark/operators/gemm/triton_matmul.py

Lines changed: 257 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -15,88 +15,263 @@
1515
# - An auto-tuning *key* whose change in values will trigger evaluation of all the
1616
# provided configs
1717
@triton.autotune(
18-
configs=[
19-
triton.Config(
20-
{
21-
"BLOCK_SIZE_M": 128,
22-
"BLOCK_SIZE_N": 256,
23-
"BLOCK_SIZE_K": 64,
24-
"GROUP_SIZE_M": 8,
25-
},
26-
num_stages=3,
27-
num_warps=8,
28-
),
29-
triton.Config(
30-
{
31-
"BLOCK_SIZE_M": 64,
32-
"BLOCK_SIZE_N": 256,
33-
"BLOCK_SIZE_K": 32,
34-
"GROUP_SIZE_M": 8,
35-
},
36-
num_stages=4,
37-
num_warps=4,
38-
),
39-
triton.Config(
40-
{
41-
"BLOCK_SIZE_M": 128,
42-
"BLOCK_SIZE_N": 128,
43-
"BLOCK_SIZE_K": 32,
44-
"GROUP_SIZE_M": 8,
45-
},
46-
num_stages=4,
47-
num_warps=4,
48-
),
49-
triton.Config(
50-
{
51-
"BLOCK_SIZE_M": 128,
52-
"BLOCK_SIZE_N": 64,
53-
"BLOCK_SIZE_K": 32,
54-
"GROUP_SIZE_M": 8,
55-
},
56-
num_stages=4,
57-
num_warps=4,
58-
),
59-
triton.Config(
60-
{
61-
"BLOCK_SIZE_M": 64,
62-
"BLOCK_SIZE_N": 128,
63-
"BLOCK_SIZE_K": 32,
64-
"GROUP_SIZE_M": 8,
65-
},
66-
num_stages=4,
67-
num_warps=4,
68-
),
69-
triton.Config(
70-
{
71-
"BLOCK_SIZE_M": 128,
72-
"BLOCK_SIZE_N": 32,
73-
"BLOCK_SIZE_K": 32,
74-
"GROUP_SIZE_M": 8,
75-
},
76-
num_stages=4,
77-
num_warps=4,
78-
),
79-
triton.Config(
80-
{
81-
"BLOCK_SIZE_M": 64,
82-
"BLOCK_SIZE_N": 32,
83-
"BLOCK_SIZE_K": 32,
84-
"GROUP_SIZE_M": 8,
85-
},
86-
num_stages=5,
87-
num_warps=2,
88-
),
89-
triton.Config(
90-
{
91-
"BLOCK_SIZE_M": 32,
92-
"BLOCK_SIZE_N": 64,
93-
"BLOCK_SIZE_K": 32,
94-
"GROUP_SIZE_M": 8,
95-
},
96-
num_stages=5,
97-
num_warps=2,
98-
),
99-
],
18+
configs=(
19+
[
20+
triton.Config(
21+
{
22+
"BLOCK_SIZE_M": 64,
23+
"BLOCK_SIZE_N": 16,
24+
"BLOCK_SIZE_K": 256,
25+
"GROUP_SIZE_M": 8,
26+
},
27+
num_stages=2,
28+
num_warps=4,
29+
),
30+
triton.Config(
31+
{
32+
"BLOCK_SIZE_M": 32,
33+
"BLOCK_SIZE_N": 16,
34+
"BLOCK_SIZE_K": 128,
35+
"GROUP_SIZE_M": 8,
36+
},
37+
num_stages=2,
38+
num_warps=2,
39+
),
40+
triton.Config(
41+
{
42+
"BLOCK_SIZE_M": 128,
43+
"BLOCK_SIZE_N": 256,
44+
"BLOCK_SIZE_K": 32,
45+
"GROUP_SIZE_M": 8,
46+
},
47+
num_stages=2,
48+
num_warps=8,
49+
),
50+
triton.Config(
51+
{
52+
"BLOCK_SIZE_M": 256,
53+
"BLOCK_SIZE_N": 256,
54+
"BLOCK_SIZE_K": 32,
55+
"GROUP_SIZE_M": 8,
56+
},
57+
num_stages=2,
58+
num_warps=8,
59+
),
60+
triton.Config(
61+
{
62+
"BLOCK_SIZE_M": 128,
63+
"BLOCK_SIZE_N": 64,
64+
"BLOCK_SIZE_K": 128,
65+
"GROUP_SIZE_M": 8,
66+
},
67+
num_stages=2,
68+
num_warps=8,
69+
),
70+
triton.Config(
71+
{
72+
"BLOCK_SIZE_M": 256,
73+
"BLOCK_SIZE_N": 128,
74+
"BLOCK_SIZE_K": 32,
75+
"GROUP_SIZE_M": 8,
76+
},
77+
num_stages=2,
78+
num_warps=8,
79+
),
80+
triton.Config(
81+
{
82+
"BLOCK_SIZE_M": 32,
83+
"BLOCK_SIZE_N": 64,
84+
"BLOCK_SIZE_K": 128,
85+
"GROUP_SIZE_M": 8,
86+
},
87+
num_stages=2,
88+
num_warps=2,
89+
),
90+
triton.Config(
91+
{
92+
"BLOCK_SIZE_M": 128,
93+
"BLOCK_SIZE_N": 128,
94+
"BLOCK_SIZE_K": 128,
95+
"GROUP_SIZE_M": 8,
96+
},
97+
num_stages=2,
98+
num_warps=8,
99+
),
100+
triton.Config(
101+
{
102+
"BLOCK_SIZE_M": 128,
103+
"BLOCK_SIZE_N": 128,
104+
"BLOCK_SIZE_K": 64,
105+
"GROUP_SIZE_M": 8,
106+
},
107+
num_stages=2,
108+
num_warps=2,
109+
),
110+
triton.Config(
111+
{
112+
"BLOCK_SIZE_M": 128,
113+
"BLOCK_SIZE_N": 128,
114+
"BLOCK_SIZE_K": 64,
115+
"GROUP_SIZE_M": 8,
116+
},
117+
num_stages=2,
118+
num_warps=8,
119+
),
120+
triton.Config(
121+
{
122+
"BLOCK_SIZE_M": 128,
123+
"BLOCK_SIZE_N": 128,
124+
"BLOCK_SIZE_K": 64,
125+
"GROUP_SIZE_M": 8,
126+
},
127+
num_stages=1,
128+
num_warps=4,
129+
),
130+
triton.Config(
131+
{
132+
"BLOCK_SIZE_M": 64,
133+
"BLOCK_SIZE_N": 64,
134+
"BLOCK_SIZE_K": 128,
135+
"GROUP_SIZE_M": 8,
136+
},
137+
num_stages=2,
138+
num_warps=4,
139+
),
140+
triton.Config(
141+
{
142+
"BLOCK_SIZE_M": 64,
143+
"BLOCK_SIZE_N": 64,
144+
"BLOCK_SIZE_K": 64,
145+
"GROUP_SIZE_M": 8,
146+
},
147+
num_stages=2,
148+
num_warps=4,
149+
),
150+
triton.Config(
151+
{
152+
"BLOCK_SIZE_M": 16,
153+
"BLOCK_SIZE_N": 16,
154+
"BLOCK_SIZE_K": 256,
155+
"GROUP_SIZE_M": 8,
156+
},
157+
num_stages=2,
158+
num_warps=4,
159+
),
160+
triton.Config(
161+
{
162+
"BLOCK_SIZE_M": 128,
163+
"BLOCK_SIZE_N": 128,
164+
"BLOCK_SIZE_K": 64,
165+
"GROUP_SIZE_M": 8,
166+
},
167+
num_stages=1,
168+
num_warps=8,
169+
),
170+
triton.Config(
171+
{
172+
"BLOCK_SIZE_M": 64,
173+
"BLOCK_SIZE_N": 128,
174+
"BLOCK_SIZE_K": 64,
175+
"GROUP_SIZE_M": 8,
176+
},
177+
num_stages=2,
178+
num_warps=8,
179+
),
180+
triton.Config(
181+
{
182+
"BLOCK_SIZE_M": 16,
183+
"BLOCK_SIZE_N": 16,
184+
"BLOCK_SIZE_K": 128,
185+
"GROUP_SIZE_M": 8,
186+
},
187+
num_stages=2,
188+
num_warps=4,
189+
),
190+
]
191+
if torch.version.hip is not None
192+
else [
193+
triton.Config(
194+
{
195+
"BLOCK_SIZE_M": 128,
196+
"BLOCK_SIZE_N": 256,
197+
"BLOCK_SIZE_K": 64,
198+
"GROUP_SIZE_M": 8,
199+
},
200+
num_stages=3,
201+
num_warps=8,
202+
),
203+
triton.Config(
204+
{
205+
"BLOCK_SIZE_M": 64,
206+
"BLOCK_SIZE_N": 256,
207+
"BLOCK_SIZE_K": 32,
208+
"GROUP_SIZE_M": 8,
209+
},
210+
num_stages=4,
211+
num_warps=4,
212+
),
213+
triton.Config(
214+
{
215+
"BLOCK_SIZE_M": 128,
216+
"BLOCK_SIZE_N": 128,
217+
"BLOCK_SIZE_K": 32,
218+
"GROUP_SIZE_M": 8,
219+
},
220+
num_stages=4,
221+
num_warps=4,
222+
),
223+
triton.Config(
224+
{
225+
"BLOCK_SIZE_M": 128,
226+
"BLOCK_SIZE_N": 64,
227+
"BLOCK_SIZE_K": 32,
228+
"GROUP_SIZE_M": 8,
229+
},
230+
num_stages=4,
231+
num_warps=4,
232+
),
233+
triton.Config(
234+
{
235+
"BLOCK_SIZE_M": 64,
236+
"BLOCK_SIZE_N": 128,
237+
"BLOCK_SIZE_K": 32,
238+
"GROUP_SIZE_M": 8,
239+
},
240+
num_stages=4,
241+
num_warps=4,
242+
),
243+
triton.Config(
244+
{
245+
"BLOCK_SIZE_M": 128,
246+
"BLOCK_SIZE_N": 32,
247+
"BLOCK_SIZE_K": 32,
248+
"GROUP_SIZE_M": 8,
249+
},
250+
num_stages=4,
251+
num_warps=4,
252+
),
253+
triton.Config(
254+
{
255+
"BLOCK_SIZE_M": 64,
256+
"BLOCK_SIZE_N": 32,
257+
"BLOCK_SIZE_K": 32,
258+
"GROUP_SIZE_M": 8,
259+
},
260+
num_stages=5,
261+
num_warps=2,
262+
),
263+
triton.Config(
264+
{
265+
"BLOCK_SIZE_M": 32,
266+
"BLOCK_SIZE_N": 64,
267+
"BLOCK_SIZE_K": 32,
268+
"GROUP_SIZE_M": 8,
269+
},
270+
num_stages=5,
271+
num_warps=2,
272+
),
273+
]
274+
),
100275
key=["M", "N", "K"],
101276
)
102277
@triton.jit

0 commit comments

Comments
 (0)