Skip to content

Commit 0f76a0d

Browse files
Merge pull request #187 from foundation-model-stack/scaled_mm_out
fix: remove custom scaled bmm op on cpu and fix fp8 test
2 parents 47b8716 + 1c4f22e commit 0f76a0d

File tree

3 files changed

+62
-62
lines changed

3 files changed

+62
-62
lines changed

fms_mo/aiu_addons/fp8/fp8_spyre_op.py

Lines changed: 56 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from typing import Optional
1818

1919
# Third Party
20+
from packaging.version import Version
2021
from torch import Tensor
2122
import torch
2223
import torch.nn.functional as F
@@ -29,60 +30,63 @@
2930
# open issue in PyLint: https://github.com/pytorch/pytorch/issues/119482
3031

3132

32-
def _scaled_mm_cpu_out(
33-
mat1: Tensor,
34-
mat2: Tensor,
35-
scale1: Tensor,
36-
scale2: Tensor,
37-
bias: Optional[Tensor] = None,
38-
scale_result: Optional[Tensor] = None,
39-
out_dtype: Optional[torch.dtype] = None,
40-
use_fast_accum: bool = False,
41-
*,
42-
out: Optional[Tensor] = None,
43-
) -> Tensor:
44-
if out_dtype is None:
45-
out_dtype = torch.float32
46-
mat1 = (mat1.to(dtype=out_dtype) * scale1).to(dtype=out_dtype)
47-
mat2 = (mat2.to(dtype=out_dtype) * scale2).to(dtype=out_dtype)
48-
49-
if bias is not None:
50-
ret = torch.addmm(bias, mat1, mat2).to(dtype=out_dtype)
51-
else:
52-
ret = torch.mm(mat1, mat2).to(dtype=out_dtype)
53-
54-
if out is not None:
55-
out.copy_(ret)
56-
return out
57-
return ret
58-
59-
60-
torch.library.register_kernel(torch.ops.aten._scaled_mm.out, "cpu", _scaled_mm_cpu_out)
61-
62-
63-
@torch.library.register_kernel("aten::_scaled_mm", "cpu")
64-
def _scaled_mm_cpu(
65-
mat1: Tensor,
66-
mat2: Tensor,
67-
scale1: Tensor,
68-
scale2: Tensor,
69-
bias: Optional[Tensor] = None,
70-
scale_result: Optional[Tensor] = None,
71-
out_dtype: Optional[torch.dtype] = None,
72-
use_fast_accum: bool = False,
73-
) -> Tensor:
74-
return _scaled_mm_cpu_out(
75-
mat1,
76-
mat2,
77-
scale1,
78-
scale2,
79-
bias,
80-
scale_result,
81-
out_dtype,
82-
use_fast_accum,
83-
out=None,
33+
if Version(torch.__version__) <= Version("2.7"):
34+
# PyTorch 2.8 adds scaled_mm_out op for CPU in the ATen set,
35+
# while for earlier versions we need a custom definition
36+
def _scaled_mm_cpu_out(
37+
mat1: Tensor,
38+
mat2: Tensor,
39+
scale1: Tensor,
40+
scale2: Tensor,
41+
bias: Optional[Tensor] = None,
42+
scale_result: Optional[Tensor] = None,
43+
out_dtype: Optional[torch.dtype] = None,
44+
use_fast_accum: bool = False,
45+
*,
46+
out: Optional[Tensor] = None,
47+
) -> Tensor:
48+
if out_dtype is None:
49+
out_dtype = torch.float32
50+
mat1 = (mat1.to(dtype=out_dtype) * scale1).to(dtype=out_dtype)
51+
mat2 = (mat2.to(dtype=out_dtype) * scale2).to(dtype=out_dtype)
52+
53+
if bias is not None:
54+
ret = torch.addmm(bias, mat1, mat2).to(dtype=out_dtype)
55+
else:
56+
ret = torch.mm(mat1, mat2).to(dtype=out_dtype)
57+
58+
if out is not None:
59+
out.copy_(ret)
60+
return out
61+
return ret
62+
63+
torch.library.register_kernel(
64+
torch.ops.aten._scaled_mm.out, "cpu", _scaled_mm_cpu_out
8465
)
8566

67+
@torch.library.register_kernel("aten::_scaled_mm", "cpu")
68+
def _scaled_mm_cpu(
69+
mat1: Tensor,
70+
mat2: Tensor,
71+
scale1: Tensor,
72+
scale2: Tensor,
73+
bias: Optional[Tensor] = None,
74+
scale_result: Optional[Tensor] = None,
75+
out_dtype: Optional[torch.dtype] = None,
76+
use_fast_accum: bool = False,
77+
) -> Tensor:
78+
return _scaled_mm_cpu_out(
79+
mat1,
80+
mat2,
81+
scale1,
82+
scale2,
83+
bias,
84+
scale_result,
85+
out_dtype,
86+
use_fast_accum,
87+
out=None,
88+
)
89+
8690

8791
@torch.library.custom_op("spyre::scaled_bmm", mutates_args=())
8892
def spyre_scaled_bmm(

fms_mo/aiu_addons/i8i8/i8i8_aiu_op.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,8 @@ def implement_op_decorator(op_namespace_id):
3636
Always compare against pytorch version in current environment.
3737
"""
3838

39-
torch_version = Version(torch.__version__.split("+", maxsplit=1)[0])
40-
4139
def decorator(func):
42-
if torch_version < Version("2.4"):
40+
if Version(torch.__version__) < Version("2.4"):
4341
return torch.library.impl(op_namespace_id, "default")(func)
4442
return torch.library.custom_op(op_namespace_id, mutates_args=())(func)
4543

@@ -51,10 +49,8 @@ def register_op_decorator(op_namespace_id):
5149
Always compare against pytorch version in current environment.
5250
"""
5351

54-
torch_version = Version(torch.__version__.split("+", maxsplit=1)[0])
55-
5652
def decorator(func):
57-
if torch_version < Version("2.4"):
53+
if Version(torch.__version__) < Version("2.4"):
5854
return torch.library.impl_abstract(op_namespace_id)(func)
5955
return torch.library.register_fake(op_namespace_id)(func)
6056

@@ -73,7 +69,7 @@ def register_aiu_i8i8_op():
7369
logger.warning("AIU op has already been registered")
7470
return
7571
op_namespace_id = "fms_mo::i8i8_aiu"
76-
if Version(torch.__version__.split("+", maxsplit=1)[0]) < Version("2.4"):
72+
if Version(torch.__version__) < Version("2.4"):
7773
torch.library.define(
7874
op_namespace_id,
7975
"(Tensor x, Tensor weight, Tensor bias, Tensor qdata, "

tests/aiu_addons/test_fp8_addon.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,9 @@ def test_fp8_op() -> None:
5151
# Local
5252
from fms_mo.aiu_addons.fp8.fp8_attn import _math_fp8_compute_op
5353

54-
query = torch.randn((1, 32, 64, 128), dtype=torch.bfloat16, device="cuda")
55-
key = torch.randn((1, 32, 64, 128), dtype=torch.bfloat16, device="cuda")
56-
value = torch.randn((1, 32, 64, 128), dtype=torch.bfloat16, device="cuda")
54+
query = torch.randn((1, 64, 32, 128), dtype=torch.bfloat16, device="cuda")
55+
key = torch.randn((1, 64, 32, 128), dtype=torch.bfloat16, device="cuda")
56+
value = torch.randn((1, 64, 32, 128), dtype=torch.bfloat16, device="cuda")
5757

5858
out = _math_fp8_compute_op(query, key, value, 32, 32, 0.0, None)
5959
assert out.size() == query.size()

0 commit comments

Comments
 (0)