Skip to content

Commit 721e970

Browse files
committed
Add _apply_fn_to_data method to TorchAOBaseTensor base class
- Implements generic pattern for applying functions to tensor components - Uses __tensor_flatten__ and __tensor_unflatten__ pattern - Fixes pytorch#2349
1 parent 3d75039 commit 721e970

File tree

2 files changed

+21
-0
lines changed

2 files changed

+21
-0
lines changed

test/test_utils.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,14 @@ def __init__(self, data):
4949
with self.assertRaisesRegex(NotImplementedError, "arg_types"):
5050
l.weight = torch.nn.Parameter(MyTensor(l.weight))
5151

52+
def test_apply_fn_to_data(self):
53+
self.assertTrue(hasattr(TorchAOBaseTensor, "_apply_fn_to_data"))
54+
self.assertTrue(callable(getattr(TorchAOBaseTensor, "_apply_fn_to_data")))
55+
56+
method = getattr(TorchAOBaseTensor, "_apply_fn_to_data")
57+
self.assertFalse(isinstance(method, classmethod))
58+
self.assertFalse(isinstance(method, staticmethod))
59+
5260

5361
if __name__ == "__main__":
5462
unittest.main()

torchao/utils.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -577,6 +577,19 @@ def __tensor_unflatten__(
577577
):
578578
raise NotImplementedError("Subclasses must implement __tensor_unflatten__")
579579

580+
def _apply_fn_to_data(self, fn: Callable):
581+
"""Applies a fn to all tensor components stored on this class"""
582+
tensor_names, ctx = self.__tensor_flatten__()
583+
new_tensors = {}
584+
for name in tensor_names:
585+
new_tensors[name] = fn(getattr(self, name))
586+
return self.__class__.__tensor_unflatten__(
587+
new_tensors,
588+
ctx,
589+
None,
590+
None,
591+
)
592+
580593
def __repr__(self):
581594
raise NotImplementedError("Subclasses must implement __repr__")
582595

0 commit comments

Comments
 (0)