37
37
38
38
logger = logging .getLogger (__name__ )
39
39
40
+
41
+ @dataclass
42
+ class BenchmarkOperatorBackend :
43
+ # backend name
44
+ name : str
45
+ # backend label
46
+ label : str
47
+ # baseline
48
+ baseline : bool = False
49
+ # enabled
50
+ enabled : bool = True
51
+ # need to be tested in ci
52
+ # ci = False implies enabled = False
53
+ ci : bool = True
54
+
55
+
40
56
IS_FBCODE = not hasattr (torch .version , "git_version" )
41
57
DEFAULT_WARMUP = 25
42
58
DEFAULT_RUN_ITERS = 100
43
59
DEFAULT_QUANTILES = [0.5 , 0.1 , 0.9 ]
44
- REGISTERED_BENCHMARKS : Dict [str , OrderedDict [str , str ]] = {}
60
+ REGISTERED_BENCHMARKS : Dict [str , OrderedDict [str , BenchmarkOperatorBackend ]] = {}
45
61
ENABLED_BENCHMARKS : Dict [str , List [str ]] = {}
46
62
REGISTERED_METRICS : Dict [str , List [str ]] = {}
47
63
REGISTERED_X_VALS : Dict [str , str ] = {}
@@ -220,6 +236,7 @@ class BenchmarkOperatorResult:
220
236
op_name : str
221
237
op_mode : str
222
238
metrics : List [str ]
239
+ # Tuple: (x_val, Dict[impl_name, BenchmarkOperatorMetrics])
223
240
result : List [Tuple [Any , Dict [str , BenchmarkOperatorMetrics ]]]
224
241
_result_dict : Optional [Dict [Number , Dict [str , BenchmarkOperatorMetrics ]]] = None
225
242
@@ -230,61 +247,62 @@ def _table(self):
230
247
if len (self .result ) == 0 :
231
248
return headers , table
232
249
y_val = self .result [0 ][1 ]
233
- y_val_keys = list (y_val .keys ())
250
+ backends = list (y_val .keys ())
234
251
# move the baseline benchmark to the front of the list if exists
235
252
if (
236
253
self .op_name in BASELINE_BENCHMARKS
237
- and BASELINE_BENCHMARKS [self .op_name ] in y_val_keys
254
+ and BASELINE_BENCHMARKS [self .op_name ] in backends
238
255
):
239
- y_val_keys .insert (
240
- 0 , y_val_keys .pop (y_val_keys .index (BASELINE_BENCHMARKS [self .op_name ]))
256
+ backends .insert (
257
+ 0 , backends .pop (backends .index (BASELINE_BENCHMARKS [self .op_name ]))
241
258
)
242
- y_val_keys = [(x , REGISTERED_BENCHMARKS [self .op_name ][x ]) for x in y_val_keys ]
243
259
key_metrics = {}
244
260
# Add header for x_only_metrics
245
261
x_only_metrics = sorted (
246
262
[metric for metric in self .metrics if metric in X_ONLY_METRICS ]
247
263
)
248
264
headers .extend (x_only_metrics )
249
- for k , label in y_val_keys :
265
+ for backend in backends :
266
+ label = REGISTERED_BENCHMARKS [self .op_name ][backend ].label
250
267
251
- def select_metric (m ):
268
+ def select_metric (backend , m ):
252
269
if m in x_only_metrics :
253
270
return False
254
271
if (
255
272
m in BASELINE_SKIP_METRICS
256
- and k == BASELINE_BENCHMARKS [self .op_name ]
273
+ and backend == BASELINE_BENCHMARKS [self .op_name ]
257
274
):
258
275
return False
259
276
return True
260
277
261
- key_metrics [k ] = sorted (filter (select_metric , self .metrics ))
262
- for metric in key_metrics [k ]:
278
+ key_metrics [backend ] = [
279
+ metric for metric in self .metrics if select_metric (backend , metric )
280
+ ]
281
+ for metric in key_metrics [backend ]:
263
282
# add extra metrics
264
283
headers .append (f"{ label } -{ metric } " )
265
284
# generate rows
266
285
for x_val , y_val in self .result :
267
286
row = []
268
287
row .append (x_val )
269
- # Append x_val_only metrics
288
+ # Append x_only metrics
270
289
for x_only_metric in x_only_metrics :
271
- x_only_metric_dict = asdict (
272
- y_val [y_val_keys [0 ][0 ]]
273
- ) # retrieve canonical name for metric function, where y_val_keys[0] = (canonical name, customized label name)
290
+ # retrieve x_only metrics from the first backend metrics
291
+ x_only_metric_dict = asdict (y_val [backends [0 ]])
274
292
if (
275
293
"extra_metrics" in x_only_metric_dict
276
294
and x_only_metric in x_only_metric_dict ["extra_metrics" ]
277
295
):
278
296
row .append (x_only_metric_dict ["extra_metrics" ][x_only_metric ])
279
297
else :
280
298
row .append (x_only_metric_dict [x_only_metric ])
281
- for k , _label in y_val_keys :
282
- metrics_dict = asdict (y_val [k ])
299
+ for backend in backends :
300
+ metrics_dict = asdict (y_val [backend ])
283
301
if metrics_dict ["error_msg" ]:
284
302
row .append (metrics_dict ["error_msg" ])
285
- row .extend ([None ] * (len (key_metrics [k ]) - 1 ))
303
+ row .extend ([None ] * (len (key_metrics [backend ]) - 1 ))
286
304
continue
287
- for metric in key_metrics [k ]:
305
+ for metric in key_metrics [backend ]:
288
306
_metrics_dict = (
289
307
metrics_dict ["extra_metrics" ]
290
308
if metric in metrics_dict ["extra_metrics" ]
@@ -384,18 +402,26 @@ def _inner(self, *args, **kwargs):
384
402
385
403
386
404
def register_benchmark (
387
- baseline : bool = False , enabled : bool = True , label : Optional [str ] = None
405
+ baseline : bool = False ,
406
+ enabled : bool = True ,
407
+ ci : bool = True ,
408
+ label : Optional [str ] = None ,
388
409
):
389
410
def decorator (function ):
390
411
operator_name = _find_op_name_from_module_path (function .__module__ )
412
+ backend_config = BenchmarkOperatorBackend (
413
+ name = function .__name__ ,
414
+ label = label if label else function .__name__ ,
415
+ baseline = baseline ,
416
+ enabled = enabled if ci else False ,
417
+ ci = ci ,
418
+ )
391
419
if not operator_name in REGISTERED_BENCHMARKS :
392
420
REGISTERED_BENCHMARKS [operator_name ] = OrderedDict ()
393
- REGISTERED_BENCHMARKS [operator_name ][function .__name__ ] = (
394
- function .__name__ if not label else label
395
- )
396
- if baseline :
421
+ REGISTERED_BENCHMARKS [operator_name ][function .__name__ ] = backend_config
422
+ if backend_config .baseline :
397
423
BASELINE_BENCHMARKS [operator_name ] = function .__name__
398
- if enabled :
424
+ if backend_config . enabled :
399
425
if not operator_name in ENABLED_BENCHMARKS :
400
426
ENABLED_BENCHMARKS [operator_name ] = []
401
427
ENABLED_BENCHMARKS [operator_name ].append (function .__name__ )
@@ -414,6 +440,7 @@ def register_benchmark_mannually(
414
440
baseline : bool = False ,
415
441
enabled : bool = True ,
416
442
label : Optional [str ] = None ,
443
+ ci : bool = True ,
417
444
):
418
445
"""
419
446
Manually register a benchmark function for a given operator.
@@ -435,7 +462,13 @@ def register_benchmark_mannually(
435
462
"""
436
463
if not operator_name in REGISTERED_BENCHMARKS :
437
464
REGISTERED_BENCHMARKS [operator_name ] = OrderedDict ()
438
- REGISTERED_BENCHMARKS [operator_name ][func_name ] = func_name if not label else label
465
+ REGISTERED_BENCHMARKS [operator_name ][func_name ] = BenchmarkOperatorBackend (
466
+ name = function .__name__ ,
467
+ label = label if label else function .__name__ ,
468
+ baseline = baseline ,
469
+ enabled = enabled ,
470
+ ci = ci ,
471
+ )
439
472
if baseline :
440
473
BASELINE_BENCHMARKS [operator_name ] = func_name
441
474
if enabled :
0 commit comments