|
17 | 17 | from typing import Optional |
18 | 18 |
|
19 | 19 | # Third Party |
| 20 | +from packaging.version import Version |
20 | 21 | from torch import Tensor |
21 | 22 | import torch |
22 | 23 | import torch.nn.functional as F |
|
29 | 30 | # open issue in PyLint: https://github.com/pytorch/pytorch/issues/119482 |
30 | 31 |
|
31 | 32 |
|
32 | | -def _scaled_mm_cpu_out( |
33 | | - mat1: Tensor, |
34 | | - mat2: Tensor, |
35 | | - scale1: Tensor, |
36 | | - scale2: Tensor, |
37 | | - bias: Optional[Tensor] = None, |
38 | | - scale_result: Optional[Tensor] = None, |
39 | | - out_dtype: Optional[torch.dtype] = None, |
40 | | - use_fast_accum: bool = False, |
41 | | - *, |
42 | | - out: Optional[Tensor] = None, |
43 | | -) -> Tensor: |
44 | | - if out_dtype is None: |
45 | | - out_dtype = torch.float32 |
46 | | - mat1 = (mat1.to(dtype=out_dtype) * scale1).to(dtype=out_dtype) |
47 | | - mat2 = (mat2.to(dtype=out_dtype) * scale2).to(dtype=out_dtype) |
48 | | - |
49 | | - if bias is not None: |
50 | | - ret = torch.addmm(bias, mat1, mat2).to(dtype=out_dtype) |
51 | | - else: |
52 | | - ret = torch.mm(mat1, mat2).to(dtype=out_dtype) |
53 | | - |
54 | | - if out is not None: |
55 | | - out.copy_(ret) |
56 | | - return out |
57 | | - return ret |
58 | | - |
59 | | - |
60 | | -torch.library.register_kernel(torch.ops.aten._scaled_mm.out, "cpu", _scaled_mm_cpu_out) |
61 | | - |
62 | | - |
63 | | -@torch.library.register_kernel("aten::_scaled_mm", "cpu") |
64 | | -def _scaled_mm_cpu( |
65 | | - mat1: Tensor, |
66 | | - mat2: Tensor, |
67 | | - scale1: Tensor, |
68 | | - scale2: Tensor, |
69 | | - bias: Optional[Tensor] = None, |
70 | | - scale_result: Optional[Tensor] = None, |
71 | | - out_dtype: Optional[torch.dtype] = None, |
72 | | - use_fast_accum: bool = False, |
73 | | -) -> Tensor: |
74 | | - return _scaled_mm_cpu_out( |
75 | | - mat1, |
76 | | - mat2, |
77 | | - scale1, |
78 | | - scale2, |
79 | | - bias, |
80 | | - scale_result, |
81 | | - out_dtype, |
82 | | - use_fast_accum, |
83 | | - out=None, |
if Version(torch.__version__).release < (2, 8):
    # PyTorch 2.8 adds a CPU kernel for aten::_scaled_mm / _scaled_mm.out in
    # the ATen set; earlier versions need this custom CPU implementation.
    #
    # Compare release tuples rather than `Version(torch.__version__) <=
    # Version("2.7")`: that check is False for patch releases (2.7.1 > 2.7),
    # which would leave them without a CPU kernel even though they also lack
    # the ATen op. `.release < (2, 8)` covers every pre-2.8 version while
    # still treating 2.8 dev/nightly builds as already having the op.
    def _scaled_mm_cpu_out(
        mat1: Tensor,
        mat2: Tensor,
        scale1: Tensor,
        scale2: Tensor,
        bias: Optional[Tensor] = None,
        scale_result: Optional[Tensor] = None,
        out_dtype: Optional[torch.dtype] = None,
        use_fast_accum: bool = False,
        *,
        out: Optional[Tensor] = None,
    ) -> Tensor:
        """Reference CPU emulation of ``aten::_scaled_mm.out``.

        Up-casts ``mat1``/``mat2`` to ``out_dtype`` (default ``float32``),
        applies their scales, performs the (optionally biased) matmul, and
        copies the result into ``out`` when one is supplied.

        ``scale_result`` and ``use_fast_accum`` are accepted only for
        signature compatibility with the ATen op and are ignored here.
        """
        if out_dtype is None:
            out_dtype = torch.float32
        # Apply the scales in the output dtype before the matmul.
        mat1 = (mat1.to(dtype=out_dtype) * scale1).to(dtype=out_dtype)
        mat2 = (mat2.to(dtype=out_dtype) * scale2).to(dtype=out_dtype)

        if bias is not None:
            ret = torch.addmm(bias, mat1, mat2).to(dtype=out_dtype)
        else:
            ret = torch.mm(mat1, mat2).to(dtype=out_dtype)

        # Honor the out= contract: write in place and return the out tensor.
        if out is not None:
            out.copy_(ret)
            return out
        return ret

    torch.library.register_kernel(
        torch.ops.aten._scaled_mm.out, "cpu", _scaled_mm_cpu_out
    )

    @torch.library.register_kernel("aten::_scaled_mm", "cpu")
    def _scaled_mm_cpu(
        mat1: Tensor,
        mat2: Tensor,
        scale1: Tensor,
        scale2: Tensor,
        bias: Optional[Tensor] = None,
        scale_result: Optional[Tensor] = None,
        out_dtype: Optional[torch.dtype] = None,
        use_fast_accum: bool = False,
    ) -> Tensor:
        """CPU kernel for ``aten::_scaled_mm``; delegates to the out variant."""
        return _scaled_mm_cpu_out(
            mat1,
            mat2,
            scale1,
            scale2,
            bias,
            scale_result,
            out_dtype,
            use_fast_accum,
            out=None,
        )
86 | 90 |
|
87 | 91 | @torch.library.custom_op("spyre::scaled_bmm", mutates_args=()) |
88 | 92 | def spyre_scaled_bmm( |
|
0 commit comments