Skip to content

Commit 00c9b9e

Browse files
xuzhao9facebook-github-bot
authored andcommitted
Make sure all ci-enabled impls are in the output
Summary: In the CI, we will check that all registered impls are available in the output, unless they are specified as`ci=False`. We add the `ci=` flag because right now we don't have lazy imports to import optional backend modules while we want different behavior between flags `enabled` and `ci`. For `enabled` flag, we want "best-effort". If a module is not available (e.g. flash attention 3 is not available on A100), we should check if it is not available, then skip it automatically instead of error out for the best user experience. For `ci` flag, we want to make sure that things are already setup in fbcode CI, and if flash attention 3 is not available, it is a red flag and we have to report it in the unit test. Reviewed By: bertmaher Differential Revision: D64473609 fbshipit-source-id: 320255f73942705038d50aac1f14d318b62a4765
1 parent eec8612 commit 00c9b9e

File tree

4 files changed

+70
-35
lines changed

4 files changed

+70
-35
lines changed

torchbenchmark/operators/flash_attention/operator.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,7 @@ def xformers_preprocess(
287287
)
288288
return fhma_input
289289

290-
@register_benchmark(enabled=False)
290+
@register_benchmark(ci=False)
291291
def xformers(
292292
self,
293293
q: torch.Tensor,
@@ -298,7 +298,7 @@ def xformers(
298298
xformers_cutlass_fhma = xformers.ops.fmha.cutlass.FwOp
299299
return lambda: xformers_cutlass_fhma().apply(fhma_input, needs_gradient=False)
300300

301-
@register_benchmark(enabled=False)
301+
@register_benchmark(ci=False)
302302
def xformers_splitk(
303303
self,
304304
q: torch.Tensor,
@@ -316,7 +316,7 @@ def colfax_cutlass_preprocess(self, q, k, v):
316316
torch.transpose(v, 1, 2),
317317
)
318318

319-
@register_benchmark(enabled=False)
319+
@register_benchmark(ci=False)
320320
def colfax_cutlass(self, q, k, v):
321321
default_scale = 1.0 / math.sqrt(float(self.D_HEAD))
322322
colfax_q, colfax_k, colfax_v = self.colfax_cutlass_preprocess(q, k, v)
@@ -330,7 +330,7 @@ def colfax_cutlass(self, q, k, v):
330330
default_scale,
331331
)
332332

333-
@register_benchmark(enabled=False)
333+
@register_benchmark(enabled=(tk_fwd is not None))
334334
def tk(self, q, k, v):
335335
o = torch.zeros_like(v)
336336

@@ -343,7 +343,7 @@ def tk_dispatcher():
343343

344344
return tk_dispatcher
345345

346-
@register_benchmark(enabled=False, label=f"cudnn_{torch.backends.cudnn.version()}")
346+
@register_benchmark(ci=False, label=f"cudnn_{torch.backends.cudnn.version()}")
347347
def cudnn(self, q, k, v):
348348
os.environ["TORCH_CUDNN_SDPA_ENABLED"] = "1"
349349

torchbenchmark/operators/fp8_fused_quant_gemm_rowwise/operator.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,13 @@
1717
silu_mul,
1818
silu_mul_fp8_rowwise_quant,
1919
)
20+
from gen_ai.llm_inference.fb.llm.quantization.quantize import quantize_fp8_row
2021

2122
HAS_FB_IMPORT = True
2223
except ImportError:
2324
HAS_FB_IMPORT = False
2425

26+
2527
from torchbenchmark.util.triton_op import (
2628
BenchmarkOperator,
2729
BenchmarkOperatorMetrics,
@@ -33,7 +35,7 @@
3335

3436
def parse_args(args: List[str]) -> argparse.Namespace:
3537
parser = argparse.ArgumentParser(
36-
description="TorchBench FP8 fused quant gemm rowwise operator Benchmark"
38+
description="Tritonbench FP8 fused quant gemm rowwise operator Benchmark"
3739
)
3840
parser.add_argument("--m", type=int)
3941
parser.add_argument("--n", type=int)

torchbenchmark/operators/gemm/operator.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -162,15 +162,15 @@ def triton_persistent_matmul(self, a, b, bias) -> Callable:
162162
else:
163163
return lambda: matmul_persistent(a, b)
164164

165-
@register_benchmark(enabled=not IS_FBCODE)
165+
@register_benchmark(ci=not IS_FBCODE)
166166
def triton_tma_persistent_matmul(self, a, b, bias) -> Callable:
167167
b = b.T.contiguous()
168168
if not bias == None:
169169
return lambda: matmul_tma_persistent(a, b) + bias
170170
else:
171171
return lambda: matmul_tma_persistent(a, b)
172172

173-
@register_benchmark(enabled=not IS_FBCODE)
173+
@register_benchmark(ci=not IS_FBCODE)
174174
def triton_tma_persistent_cached_matmul(self, a, b, bias) -> Callable:
175175
b = b.T.contiguous()
176176
if not bias == None:
@@ -198,7 +198,7 @@ def hstu_triton_matmul(self, a, b, bias) -> Callable:
198198
else:
199199
return lambda: hstu_triton_matmul(a, b)
200200

201-
@register_benchmark(enabled=bool(colfax_gemm))
201+
@register_benchmark(ci=False) # colfax_cutlass build is broken on CUDA 12.4
202202
def colfax_cutlass_matmul(self, a, b, bias) -> Callable:
203203
assert colfax_gemm, f"colfax_gemm operator is not available."
204204
if not bias == None:

torchbenchmark/util/triton_op.py

Lines changed: 59 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,27 @@
3737

3838
logger = logging.getLogger(__name__)
3939

40+
41+
@dataclass
42+
class BenchmarkOperatorBackend:
43+
# backend name
44+
name: str
45+
# backend label
46+
label: str
47+
# baseline
48+
baseline: bool = False
49+
# enabled
50+
enabled: bool = True
51+
# need to be tested in ci
52+
# ci = False implies enabled = False
53+
ci: bool = True
54+
55+
4056
IS_FBCODE = not hasattr(torch.version, "git_version")
4157
DEFAULT_WARMUP = 25
4258
DEFAULT_RUN_ITERS = 100
4359
DEFAULT_QUANTILES = [0.5, 0.1, 0.9]
44-
REGISTERED_BENCHMARKS: Dict[str, OrderedDict[str, str]] = {}
60+
REGISTERED_BENCHMARKS: Dict[str, OrderedDict[str, BenchmarkOperatorBackend]] = {}
4561
ENABLED_BENCHMARKS: Dict[str, List[str]] = {}
4662
REGISTERED_METRICS: Dict[str, List[str]] = {}
4763
REGISTERED_X_VALS: Dict[str, str] = {}
@@ -220,6 +236,7 @@ class BenchmarkOperatorResult:
220236
op_name: str
221237
op_mode: str
222238
metrics: List[str]
239+
# Tuple: (x_val, Dict[impl_name, BenchmarkOperatorMetrics])
223240
result: List[Tuple[Any, Dict[str, BenchmarkOperatorMetrics]]]
224241
_result_dict: Optional[Dict[Number, Dict[str, BenchmarkOperatorMetrics]]] = None
225242

@@ -230,61 +247,62 @@ def _table(self):
230247
if len(self.result) == 0:
231248
return headers, table
232249
y_val = self.result[0][1]
233-
y_val_keys = list(y_val.keys())
250+
backends = list(y_val.keys())
234251
# move the baseline benchmark to the front of the list if exists
235252
if (
236253
self.op_name in BASELINE_BENCHMARKS
237-
and BASELINE_BENCHMARKS[self.op_name] in y_val_keys
254+
and BASELINE_BENCHMARKS[self.op_name] in backends
238255
):
239-
y_val_keys.insert(
240-
0, y_val_keys.pop(y_val_keys.index(BASELINE_BENCHMARKS[self.op_name]))
256+
backends.insert(
257+
0, backends.pop(backends.index(BASELINE_BENCHMARKS[self.op_name]))
241258
)
242-
y_val_keys = [(x, REGISTERED_BENCHMARKS[self.op_name][x]) for x in y_val_keys]
243259
key_metrics = {}
244260
# Add header for x_only_metrics
245261
x_only_metrics = sorted(
246262
[metric for metric in self.metrics if metric in X_ONLY_METRICS]
247263
)
248264
headers.extend(x_only_metrics)
249-
for k, label in y_val_keys:
265+
for backend in backends:
266+
label = REGISTERED_BENCHMARKS[self.op_name][backend].label
250267

251-
def select_metric(m):
268+
def select_metric(backend, m):
252269
if m in x_only_metrics:
253270
return False
254271
if (
255272
m in BASELINE_SKIP_METRICS
256-
and k == BASELINE_BENCHMARKS[self.op_name]
273+
and backend == BASELINE_BENCHMARKS[self.op_name]
257274
):
258275
return False
259276
return True
260277

261-
key_metrics[k] = sorted(filter(select_metric, self.metrics))
262-
for metric in key_metrics[k]:
278+
key_metrics[backend] = [
279+
metric for metric in self.metrics if select_metric(backend, metric)
280+
]
281+
for metric in key_metrics[backend]:
263282
# add extra metrics
264283
headers.append(f"{label}-{metric}")
265284
# generate rows
266285
for x_val, y_val in self.result:
267286
row = []
268287
row.append(x_val)
269-
# Append x_val_only metrics
288+
# Append x_only metrics
270289
for x_only_metric in x_only_metrics:
271-
x_only_metric_dict = asdict(
272-
y_val[y_val_keys[0][0]]
273-
) # retrieve canonical name for metric function, where y_val_keys[0] = (canonical name, customized label name)
290+
# retrieve x_only metrics from the first backend metrics
291+
x_only_metric_dict = asdict(y_val[backends[0]])
274292
if (
275293
"extra_metrics" in x_only_metric_dict
276294
and x_only_metric in x_only_metric_dict["extra_metrics"]
277295
):
278296
row.append(x_only_metric_dict["extra_metrics"][x_only_metric])
279297
else:
280298
row.append(x_only_metric_dict[x_only_metric])
281-
for k, _label in y_val_keys:
282-
metrics_dict = asdict(y_val[k])
299+
for backend in backends:
300+
metrics_dict = asdict(y_val[backend])
283301
if metrics_dict["error_msg"]:
284302
row.append(metrics_dict["error_msg"])
285-
row.extend([None] * (len(key_metrics[k]) - 1))
303+
row.extend([None] * (len(key_metrics[backend]) - 1))
286304
continue
287-
for metric in key_metrics[k]:
305+
for metric in key_metrics[backend]:
288306
_metrics_dict = (
289307
metrics_dict["extra_metrics"]
290308
if metric in metrics_dict["extra_metrics"]
@@ -384,18 +402,26 @@ def _inner(self, *args, **kwargs):
384402

385403

386404
def register_benchmark(
387-
baseline: bool = False, enabled: bool = True, label: Optional[str] = None
405+
baseline: bool = False,
406+
enabled: bool = True,
407+
ci: bool = True,
408+
label: Optional[str] = None,
388409
):
389410
def decorator(function):
390411
operator_name = _find_op_name_from_module_path(function.__module__)
412+
backend_config = BenchmarkOperatorBackend(
413+
name=function.__name__,
414+
label=label if label else function.__name__,
415+
baseline=baseline,
416+
enabled=enabled if ci else False,
417+
ci=ci,
418+
)
391419
if not operator_name in REGISTERED_BENCHMARKS:
392420
REGISTERED_BENCHMARKS[operator_name] = OrderedDict()
393-
REGISTERED_BENCHMARKS[operator_name][function.__name__] = (
394-
function.__name__ if not label else label
395-
)
396-
if baseline:
421+
REGISTERED_BENCHMARKS[operator_name][function.__name__] = backend_config
422+
if backend_config.baseline:
397423
BASELINE_BENCHMARKS[operator_name] = function.__name__
398-
if enabled:
424+
if backend_config.enabled:
399425
if not operator_name in ENABLED_BENCHMARKS:
400426
ENABLED_BENCHMARKS[operator_name] = []
401427
ENABLED_BENCHMARKS[operator_name].append(function.__name__)
@@ -414,6 +440,7 @@ def register_benchmark_mannually(
414440
baseline: bool = False,
415441
enabled: bool = True,
416442
label: Optional[str] = None,
443+
ci: bool = True,
417444
):
418445
"""
419446
Manually register a benchmark function for a given operator.
@@ -435,7 +462,13 @@ def register_benchmark_mannually(
435462
"""
436463
if not operator_name in REGISTERED_BENCHMARKS:
437464
REGISTERED_BENCHMARKS[operator_name] = OrderedDict()
438-
REGISTERED_BENCHMARKS[operator_name][func_name] = func_name if not label else label
465+
REGISTERED_BENCHMARKS[operator_name][func_name] = BenchmarkOperatorBackend(
466+
name=function.__name__,
467+
label=label if label else function.__name__,
468+
baseline=baseline,
469+
enabled=enabled,
470+
ci=ci,
471+
)
439472
if baseline:
440473
BASELINE_BENCHMARKS[operator_name] = func_name
441474
if enabled:

0 commit comments

Comments
 (0)