
Commit d933ced

FindHao authored and facebook-github-bot committed
Add doc for adding custom ops (#2509)
Summary: Add documentation for adding custom ops.

Pull Request resolved: #2509
Reviewed By: xuzhao9
Differential Revision: D64497281
Pulled By: FindHao
fbshipit-source-id: 20f4096ebbce53c7d9a713cacbde016c521aa7c3
1 parent 2feadb6 commit d933ced

File tree

1 file changed: +85, -0 lines


torchbenchmark/operators/readme.md

Lines changed: 85 additions & 0 deletions
## How to add a custom operator benchmark

1. Create a new folder in the `operators` directory.
2. Add an `operator.py` and an `__init__.py` file to the new folder.
3. Implement the `Operator` class.
4. Register the operator benchmarks in the `operator.py` file.

### Example

```
operators/
  my_operator/
    __init__.py
    operator.py
```

## `__init__.py`

The `__init__.py` file only needs to import the operator to trigger the registration of the benchmarks.

```
from .operator import Operator
```

## `operator.py`

The `operator.py` file needs to implement the following:

1. `Operator` class: This class should inherit from `BenchmarkOperator`.
2. `get_input_iter`: This method should return an iterator of input examples for the operator.
3. `@register_benchmark`: This decorator should be used to register the benchmarks for the operator.
4. `get_bwd_fn`: This method should return a callable that performs the backward pass for the operator, when needed.
5. `get_grad_to_none`: This method should be overridden to specify which arguments should have their gradients reset to `None` between benchmark runs, when needed (a sketch of `get_bwd_fn` and `get_grad_to_none` follows the example below).

### Example

```
import argparse
from typing import Callable, Generator, List, Optional

import torch
import triton

from torchbenchmark.util.benchmark_registry import BenchmarkOperator, register_benchmark


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(10, 10)

    def forward(self, x):
        return self.linear(x)


@triton.jit
def _kernel(XXX):
    # Your Triton kernel implementation. XXX is a placeholder for the kernel arguments.
    pass


def kernel_wrapper(a, b, activation=""):
    M, K = a.shape
    K, N = b.shape
    # Allocate the output.
    c = torch.empty((M, N), device=a.device, dtype=a.dtype)
    # 1D launch grid where each block gets its own program.
    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
    )
    _kernel[grid](XXX)
    return c


class Operator(BenchmarkOperator):
    def __init__(self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None):
        super().__init__(tb_args, extra_args)
        self.model = Model()

    def get_input_iter(self) -> Generator:
        # Each yielded item is one benchmark input.
        for i in range(10):
            yield torch.randn(10, 10)

    @register_benchmark(baseline=True)
    def my_operator(self, input) -> Callable:
        return lambda: self.model(input)

    @register_benchmark()
    def my_operator2(self, input) -> Callable:
        return lambda: kernel_wrapper(input, input)
```
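
The example above only benchmarks the forward pass. When the backward pass should be measured as well, `get_bwd_fn` and `get_grad_to_none` can be added to the `Operator` class. The following is a minimal sketch, not taken from the example above: the exact signatures (a forward callable passed to `get_bwd_fn`, the benchmark arguments passed to `get_grad_to_none`) are assumptions and should be checked against the `BenchmarkOperator` base class in your checkout.

```
    def get_bwd_fn(self, fwd_fn: Callable) -> Callable:
        # Run the forward pass once, then return a callable that backpropagates
        # through its output on every benchmark iteration.
        y = fwd_fn()
        grad = torch.ones_like(y)
        return lambda: y.backward(grad, retain_graph=True)

    def get_grad_to_none(self, args) -> List[torch.Tensor]:
        # Tensors returned here have their .grad reset to None between runs,
        # e.g. inputs created with requires_grad=True, so gradients do not
        # accumulate across iterations.
        return [args[0]]
```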
