Skip to content

Commit 58f3b1f

Browse files
FindHaofacebook-github-bot
authored andcommitted
Add --op-collection option (#2503)
Summary: This PR add `--op-collection` to tritonbench. It can run multiple ops in defined operator collections. The default collection includes all ops not included in other collections. Operator collections are defined in `torchbenchmark/operators_collection/`. For each collection, you should define a `get_operators` function to return operators included in this collection. Pull Request resolved: #2503 Reviewed By: xuzhao9 Differential Revision: D64359380 Pulled By: FindHao fbshipit-source-id: c66dd254a3c8b70c112d9b7774482813e0236789
1 parent 9e670cd commit 58f3b1f

File tree

5 files changed

+123
-3
lines changed

5 files changed

+123
-3
lines changed
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
import importlib
2+
import pathlib
3+
from typing import List
4+
5+
OP_COLLECTION_PATH = "operators_collection"
6+
7+
8+
def list_operator_collections() -> List[str]:
9+
"""
10+
List the available operator collections.
11+
12+
This function retrieves the list of available operator collections by scanning the directories
13+
in the current path that contain an "__init__.py" file.
14+
15+
Returns:
16+
List[str]: A list of names of the available operator collections.
17+
"""
18+
p = pathlib.Path(__file__).parent
19+
# only load the directories that contain a "__init__.py" file
20+
collection_paths = sorted(
21+
str(child.absolute())
22+
for child in p.iterdir()
23+
if child.is_dir() and child.joinpath("__init__.py").exists()
24+
)
25+
filtered_collections = [pathlib.Path(path).name for path in collection_paths]
26+
return filtered_collections
27+
28+
29+
def list_operators_by_collection(op_collection: str = "default") -> List[str]:
30+
"""
31+
List the operators from the specified operator collections.
32+
33+
This function retrieves the list of operators from the specified operator collections.
34+
If the collection name is "all", it retrieves operators from all available collections.
35+
If the collection name is not specified, it defaults to the "default" collection.
36+
37+
Args:
38+
op_collection (str): Names of the operator collections to list operators from.
39+
It can be a single collection name or a comma-separated list of names.
40+
Special value "all" retrieves operators from all collections.
41+
42+
Returns:
43+
List[str]: A list of operator names from the specified collection(s).
44+
45+
Raises:
46+
ModuleNotFoundError: If the specified collection module is not found.
47+
AttributeError: If the specified collection module does not have a 'get_operators' function.
48+
"""
49+
50+
def _list_all_operators(collection_name: str):
51+
try:
52+
module_name = f".{collection_name}"
53+
module = importlib.import_module(module_name, package=__name__)
54+
if hasattr(module, "get_operators"):
55+
return module.get_operators()
56+
else:
57+
raise AttributeError(
58+
f"Module '{module_name}' does not have a 'get_operators' function"
59+
)
60+
except ModuleNotFoundError:
61+
raise ModuleNotFoundError(f"Module '{module_name}' not found")
62+
63+
if op_collection == "all":
64+
collection_names = list_operator_collections()
65+
else:
66+
collection_names = op_collection.split(",")
67+
68+
all_operators = []
69+
for collection_name in collection_names:
70+
all_operators.extend(_list_all_operators(collection_name))
71+
return all_operators
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from torchbenchmark.operators import list_operators
2+
3+
4+
def get_operators():
5+
return list_operators()
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
from torchbenchmark.operators_collection.all import get_operators as get_all_operators
2+
from torchbenchmark.operators_collection.liger import (
3+
get_operators as get_liger_operators,
4+
)
5+
6+
7+
def get_operators():
8+
"""
9+
Retrieve the list of operators for the default collection.
10+
11+
This function retrieves the list of operators for the default collection by
12+
comparing the operators from the 'all' collection and the 'liger' collection.
13+
It returns a list of operators that are present in the 'all' collection but
14+
not in the 'liger' collection.
15+
16+
In the future, if we add more operator collections, we will need to update
17+
this function to exclude desired operators in other collections.
18+
19+
other_collections = list_operator_collections()
20+
to_remove = set(other_collections).union(liger_operators)
21+
return [item for item in all_operators if item not in to_remove]
22+
23+
Returns:
24+
List[str]: A list of operator names for the default collection.
25+
"""
26+
all_operators = get_all_operators()
27+
liger_operators = get_liger_operators()
28+
return [item for item in all_operators if item not in liger_operators]
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
liger_operators = ["FusedLinearCrossEntropy"]
2+
3+
4+
def get_operators():
5+
return liger_operators

userbenchmark/triton/run.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from torch import version as torch_version
88
from torchbenchmark.operator_loader import load_opbench_by_name_from_loader
99
from torchbenchmark.operators import load_opbench_by_name
10+
from torchbenchmark.operators_collection import list_operators_by_collection
1011

1112
from torchbenchmark.util.triton_op import (
1213
BenchmarkOperatorResult,
@@ -36,6 +37,13 @@ def get_parser(args=None):
3637
required=False,
3738
help="Operators to benchmark. Split with comma if multiple.",
3839
)
40+
parser.add_argument(
41+
"--op-collection",
42+
default="default",
43+
type=str,
44+
help="Operator collections to benchmark. Split with comma."
45+
" It is conflict with --op. Choices: [default, liger, all]",
46+
)
3947
parser.add_argument(
4048
"--mode",
4149
choices=["fwd", "bwd", "fwd_bwd", "fwd_no_grad"],
@@ -158,8 +166,10 @@ def get_parser(args=None):
158166
args, extra_args = parser.parse_known_args(args)
159167
if args.op and args.ci:
160168
parser.error("cannot specify operator when in CI mode")
161-
elif not args.op and not args.ci:
162-
parser.error("must specify operator when not in CI mode")
169+
if not args.op and not args.op_collection:
170+
print(
171+
"Neither operator nor operator collection is specified. Running all operators in the default collection."
172+
)
163173
return parser
164174

165175

@@ -221,7 +231,8 @@ def run(args: List[str] = []):
221231
if args.op:
222232
ops = args.op.split(",")
223233
else:
224-
ops = []
234+
ops = list_operators_by_collection(args.op_collection)
235+
225236
with gpu_lockdown(args.gpu_lockdown):
226237
for op in ops:
227238
args.op = op

0 commit comments

Comments
 (0)