Open
Description
Summary
The `_apply_fn_to_data` pattern below is very common and can be implemented generically.
The only time it needs to change is when a subclass must spoof its actual size, which is uncommon; NJT is the only example I can think of.
def _apply_fn_to_data(
    self,
    fn: Callable,
    *,
    outer_size=None,
    outer_stride=None,
):
    """Apply ``fn`` to every inner tensor component and rebuild the object.

    The inner components are discovered through ``self.__tensor_flatten__()``,
    which returns the attribute names of the inner tensors together with a
    context object. ``fn`` is called once on each named attribute. A new
    object of the same class is then built with ``__tensor_unflatten__`` from
    the transformed components and the unchanged context. No attribute of
    ``self`` is reassigned by this method.

    Args:
        fn: Callable applied to each inner component; its return values
            become the components of the new object.
        outer_size: Forwarded as the ``outer_size`` argument of
            ``__tensor_unflatten__``. Defaults to ``None``. Pass an explicit
            value only for subclasses that need to spoof their outer size.
        outer_stride: Forwarded as the ``outer_stride`` argument of
            ``__tensor_unflatten__``. Defaults to ``None``.

    Returns:
        Whatever ``self.__class__.__tensor_unflatten__`` returns for the
        transformed components.
    """
    tensor_names, ctx = self.__tensor_flatten__()
    # Transform each inner component, keyed by its attribute name, which is
    # the shape __tensor_unflatten__ expects for its first argument.
    new_tensors = {name: fn(getattr(self, name)) for name in tensor_names}
    return self.__class__.__tensor_unflatten__(
        new_tensors,
        ctx,
        outer_size,  # outer_size parameter (None unless the caller spoofs size)
        outer_stride,  # outer_stride parameter (None unless the caller spoofs it)
    )