File tree Expand file tree Collapse file tree 2 files changed +21
-0
lines changed Expand file tree Collapse file tree 2 files changed +21
-0
lines changed Original file line number Diff line number Diff line change @@ -49,6 +49,14 @@ def __init__(self, data):
49
49
with self .assertRaisesRegex (NotImplementedError , "arg_types" ):
50
50
l .weight = torch .nn .Parameter (MyTensor (l .weight ))
51
51
52
+ def test_apply_fn_to_data (self ):
53
+ self .assertTrue (hasattr (TorchAOBaseTensor , "_apply_fn_to_data" ))
54
+ self .assertTrue (callable (getattr (TorchAOBaseTensor , "_apply_fn_to_data" )))
55
+
56
+ method = getattr (TorchAOBaseTensor , "_apply_fn_to_data" )
57
+ self .assertFalse (isinstance (method , classmethod ))
58
+ self .assertFalse (isinstance (method , staticmethod ))
59
+
52
60
53
61
if __name__ == "__main__" :
54
62
unittest .main ()
Original file line number Diff line number Diff line change @@ -577,6 +577,19 @@ def __tensor_unflatten__(
577
577
):
578
578
raise NotImplementedError ("Subclasses must implement __tensor_unflatten__" )
579
579
580
+ def _apply_fn_to_data (self , fn : Callable ):
581
+ """Applies a fn to all tensor components stored on this class"""
582
+ tensor_names , ctx = self .__tensor_flatten__ ()
583
+ new_tensors = {}
584
+ for name in tensor_names :
585
+ new_tensors [name ] = fn (getattr (self , name ))
586
+ return self .__class__ .__tensor_unflatten__ (
587
+ new_tensors ,
588
+ ctx ,
589
+ None ,
590
+ None ,
591
+ )
592
+
580
593
def __repr__ (self ):
581
594
raise NotImplementedError ("Subclasses must implement __repr__" )
582
595
You can’t perform that action at this time.
0 commit comments