Skip to content

Commit 0a9cd8f

Browse files
adamomainzfacebook-github-bot
authored andcommitted
adding aggregates to servicelab
Summary: current aggregation does not seem to be working as expected. Adding another aggregation field before changing the previous over Reviewed By: xuzhao9 Differential Revision: D64616616 fbshipit-source-id: 676f09035e0d4427e9b60e9ed8f8c790782f0aec
1 parent f7dc0c7 commit 0a9cd8f

File tree

1 file changed

+10
-0
lines changed

1 file changed

+10
-0
lines changed

torchbenchmark/util/triton_op.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -351,13 +351,23 @@ def userbenchmark_dict(self) -> Dict[str, Any]:
351351
# tritonbench_{op_name}_{op_mode}[{x_val}-{provider}-{metric}]
352352
userbenchmark_metrics_dict = {}
353353
headers, table = self._table()
354+
num_rows = len(table)
355+
agg_data = {}
354356
for row in table:
355357
x_val = row[0]
358+
356359
for ind, value in enumerate(row[1:]):
357360
header = headers[ind + 1]
358361
provider, _dash, metrics = header.partition("-")
359362
metric_name = f"tritonbench_{self.op_name}_{self.op_mode}[x_{x_val}-{provider}]_{metrics}"
360363
userbenchmark_metrics_dict[metric_name] = value
364+
agg_metric_name = (
365+
f"tritonbench_{self.op_name}_{self.op_mode}[{provider}]_{metrics}"
366+
)
367+
agg_data[agg_metric_name] = agg_data.get(agg_metric_name, 0) + value
368+
final_agg_data = {k: v / num_rows for k, v in agg_data.items()}
369+
userbenchmark_metrics_dict.update(final_agg_data)
370+
361371
return userbenchmark_metrics_dict
362372

363373
def get_y_vals(self, x_val, provider, metric_name: str):

0 commit comments

Comments
 (0)