Skip to content

Commit bdb12a0

Browse files
unwrap args and kwargs
1 parent ef13402 commit bdb12a0

File tree

2 files changed

+4
-1
lines changed

2 files changed

+4
-1
lines changed

torchao/prototype/moe_training/scaled_grouped_mm.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ def _scaled_grouped_mm(
3535
offs (int32 torch.Tensor): The offsets to use to mark the starting index of each group along dim0 of the A tensor.
3636
out_dtype (Optional[torch.dtype]): The dtype of the output tensor. Currently only torch.bfloat16 is supported.
3737
"""
38+
print("SCALED_GROUPED_MM")
3839
return _Float8GroupedMM.apply(
3940
A,
4041
B_t,

torchao/prototype/moe_training/tensor.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,9 @@ def __torch_function__(cls, func, types, args, kwargs={}):
4242

4343
@classmethod
4444
def __torch_dispatch__(cls, func, types, args, kwargs={}):
45+
unwrap = lambda x: x._data if isinstance(x, cls) else x
4546
wrap = lambda x: cls(x) if isinstance(x, torch.Tensor) else x
46-
output = super().__torch_dispatch__(func, types, args, kwargs)
47+
unwrapped_args, unwrapped_kwargs = tree_map(unwrap, (args, kwargs))
48+
output = super().__torch_dispatch__(func, types, unwrapped_args, unwrapped_kwargs)
4749
wrapped_output = tree_map(wrap, output)
4850
return wrapped_output

0 commit comments

Comments
 (0)