1
1
import torch
2
+ from torch .utils ._pytree import tree_map
2
3
3
4
from torchao .prototype .moe_training import _scaled_grouped_mm
4
5
@@ -16,8 +17,12 @@ class ScaledGroupedMMTensor(torch.Tensor):
16
17
def __init__(self, data: torch.Tensor):
    """Wrap a plain tensor; ops on the wrapper are routed through
    ``__torch_function__`` / ``__torch_dispatch__`` below."""
    # _data holds the unwrapped tensor this subclass carries.
    self._data = data
def __repr__(self):
    """Debug string showing the wrapped data plus its dtype and device."""
    inner = self._data
    return (
        f"ScaledGroupedMMTensor({inner}, "
        f"dtype={inner.dtype}, device={inner.device})"
    )
@classmethod
def __torch_function__(cls, func, types, args, kwargs=None):
    """Intercept torch-level calls; route the grouped-mm op for the
    "2d A x 3d B with offsets" case (routed experts) to torchao's
    dynamically-quantized ``_scaled_grouped_mm``.

    All other calls fall through to the default ``torch.Tensor`` handling.
    """
    # Fix: removed leftover debug `print(func.__name__)`.
    # Fix: `kwargs=None` instead of a shared mutable `{}` default; torch's
    # own __torch_function__ convention also passes kwargs as None.
    kwargs = kwargs if kwargs is not None else {}
    if func.__name__ == cls.grouped_mm_func_name:
        # Use torchao scaled grouped mm with dynamic quant for the
        # "2d x 3d with offsets" case (used for routed experts).
        # NOTE(review): the A/B unpacking below is reconstructed from the
        # diff context — the original hunk hides these lines; confirm
        # against the upstream file.
        A, B = args[0], args[1]
        A_is_2d = A.dim() == 2
        B_is_3d = B.dim() == 3
        has_offs = kwargs.get(cls.offs_arg_name) is not None
        if A_is_2d and B_is_3d and has_offs:
            return _scaled_grouped_mm(*args, **kwargs)

    return super().__torch_function__(func, types, args, kwargs)
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs=None):
    """Run the aten op via the default dispatch, then re-wrap any plain
    tensor outputs in this subclass so it propagates through ops."""
    # Fix: `kwargs=None` instead of a shared mutable `{}` default.
    kwargs = kwargs if kwargs is not None else {}

    # Fix: named helper instead of a lambda bound to a name (PEP 8 / E731).
    def _wrap(x):
        return cls(x) if isinstance(x, torch.Tensor) else x

    output = super().__torch_dispatch__(func, types, args, kwargs)
    # tree_map re-wraps tensors nested inside tuples/lists/dicts too.
    return tree_map(_wrap, output)
0 commit comments