Skip to content

Commit ef13402

Browse files
fsdp support in moe training
1 parent 60c583e commit ef13402

File tree

1 file changed

+13
-0
lines changed

1 file changed

+13
-0
lines changed

torchao/prototype/moe_training/tensor.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import torch
2+
from torch.utils._pytree import tree_map
23

34
from torchao.prototype.moe_training import _scaled_grouped_mm
45

@@ -16,8 +17,12 @@ class ScaledGroupedMMTensor(torch.Tensor):
1617
def __init__(self, data: torch.Tensor):
    """Store the high-precision tensor this wrapper subclass shadows.

    Args:
        data: the underlying tensor; kept by reference (no copy).
    """
    # NOTE(review): only a reference is held — mutations to `data`
    # after construction are visible through this wrapper.
    self._data = data
1819

20+
def __repr__(self):
    """Debug representation showing the wrapped data, its dtype and device."""
    inner = self._data
    return (
        f"ScaledGroupedMMTensor({inner}, "
        f"dtype={inner.dtype}, device={inner.device})"
    )
22+
1923
@classmethod
2024
def __torch_function__(cls, func, types, args, kwargs={}):
25+
print(func.__name__)
2126
if func.__name__ == cls.grouped_mm_func_name:
2227
# Use torchao scaled grouped mm with dynamic quant for
2328
# "2d x 3d with offsets" case (used for routed experts).
@@ -32,4 +37,12 @@ def __torch_function__(cls, func, types, args, kwargs={}):
3237
has_offs = kwargs.get(cls.offs_arg_name) is not None
3338
if A_is_2d and B_is_3d and has_offs:
3439
return _scaled_grouped_mm(*args, **kwargs)
40+
3541
return super().__torch_function__(func, types, args, kwargs)
42+
43+
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs=None):
    """Dispatch ``func`` on the underlying tensors, re-wrapping results.

    Runs the op via the default dispatch path, then wraps every
    ``torch.Tensor`` in the (possibly nested) output back into ``cls`` so
    the subclass propagates through arbitrary ops (required e.g. for FSDP,
    which shards/gathers parameters through dispatch-level ops).

    Args:
        func: the ATen operator being dispatched.
        types: tensor subclass types participating in dispatch.
        args: positional op arguments.
        kwargs: keyword op arguments; ``None`` is treated as empty,
            matching the standard ``__torch_dispatch__`` convention
            (the original used a mutable ``{}`` default — fixed here).

    Returns:
        The op's output with all tensor leaves wrapped as ``cls``.
    """
    # Normalize the PyTorch-conventional `None` sentinel to an empty dict
    # so the super() call always receives a mapping.
    kwargs = kwargs if kwargs is not None else {}

    output = super().__torch_dispatch__(func, types, args, kwargs)

    # Named function instead of an assigned lambda (PEP 8 / E731).
    def wrap(x):
        return cls(x) if isinstance(x, torch.Tensor) else x

    # tree_map handles arbitrarily nested tuple/list/dict outputs.
    return tree_map(wrap, output)

0 commit comments

Comments
 (0)