## How to add a custom operator benchmark

1. Create a new folder in the `operators` directory.
2. Add an `operator.py` and `__init__.py` file to the new folder.
3. Implement the `Operator` class.
4. Register the operator benchmarks in the `operator.py` file.

### Example

```
operators/
  my_operator/
    __init__.py
    operator.py
```

## `__init__.py`

The `__init__.py` file only needs to import the operator to trigger the registration of the benchmarks.

```
from .operator import Operator
```

## `operator.py`

The `operator.py` file needs to implement the following:

1. `Operator` class: This class should inherit from `BenchmarkOperator`.
2. `get_input_iter`: This method should return an iterator of input examples for the operator.
3. `@register_benchmark`: This decorator should be used to register the benchmarks for the operator.
4. `get_bwd_fn`: This method should return a callable that performs the backward pass for the operator, when needed (see the backward-pass sketch after the example below).
5. `get_grad_to_none`: This method should be overridden to specify which input tensors should have their gradients reset to `None` between runs, when needed.

### Example

```
import argparse
from typing import Callable, Generator, List, Optional

import torch
import triton

# BenchmarkOperator is assumed to live in the same module as register_benchmark.
from torchbenchmark.util.benchmark_registry import BenchmarkOperator, register_benchmark


# Eager PyTorch baseline: a single linear layer.
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(10, 10)

    def forward(self, x):
        return self.linear(x)


@triton.jit
def _kernel(XXX):
    # your triton kernel implementation
    pass


def kernel_wrapper(a, b, activation=""):
    M, K = a.shape
    K, N = b.shape
    # Allocate the output.
    c = torch.empty((M, N), device=a.device, dtype=a.dtype)
    # 1D launch grid where each block gets its own program.
    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
    )
    _kernel[grid](XXX)
    return c


class Operator(BenchmarkOperator):
    def __init__(self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None):
        super().__init__(tb_args, extra_args)
        self.model = Model()

    def get_input_iter(self) -> Generator:
        # Each yielded value is passed to the registered benchmark functions below.
        for i in range(10):
            yield torch.randn(10, 10)

    @register_benchmark(baseline=True)
    def my_operator(self, input) -> Callable:
        # Eager PyTorch baseline.
        return lambda: self.model(input)

    @register_benchmark()
    def my_operator2(self, input) -> Callable:
        # Triton implementation of the same operator.
        return lambda: kernel_wrapper(input, input)
```
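
The example above only exercises the forward path. Below is a minimal sketch of the optional backward hooks from the list above, `get_bwd_fn` and `get_grad_to_none`; the signatures shown here (a forward callable passed to `get_bwd_fn`, the benchmark inputs passed to `get_grad_to_none`) are assumptions for illustration, not the definitive API.

```
class Operator(BenchmarkOperator):
    # ... forward-path methods from the example above ...

    # Hypothetical sketch with assumed signatures.
    def get_bwd_fn(self, fwd_fn: Callable) -> Callable:
        # Assumes fwd_fn is the forward callable returned by a registered benchmark
        # and that the inputs were created with requires_grad=True.
        y = fwd_fn()
        dy = torch.randn_like(y)
        return lambda: y.backward(dy, retain_graph=True)

    def get_grad_to_none(self, args) -> List[torch.Tensor]:
        # Reset the gradient of the first input between runs so accumulated
        # gradients from earlier iterations do not affect timing.
        return [args[0]]
```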