[NF4] Support nf4 tensor shard and gather #2449
base: main
Conversation
See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2449
Note: Links to docs will display an error until the docs builds have been completed. As of commit 2b2f768 with merge base 8b57afe: 1 new failure (the following job has failed).
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Presumably we should observe a comms time reduction? Would be nice to see some profiles.
Are there any baselines I could compare with?
torchao/dtypes/nf4tensor.py (Outdated)

    )
    updated_attrs.update(
        {
            "stride": (nf4tensor.size()[1], 1),
why hardcode the stride?
fixed
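For context, the hardcoded `(size[1], 1)` is just the contiguous (row-major) stride of a 2-D tensor, which a general helper can derive from the size instead. A minimal sketch in plain Python (the function name is illustrative, not from the PR):

```python
def contiguous_stride(size):
    """Row-major (contiguous) strides for an arbitrary size tuple."""
    stride = [1] * len(size)
    # Walk dimensions right-to-left; each stride is the product
    # of all dimension sizes to its right.
    for i in range(len(size) - 2, -1, -1):
        stride[i] = stride[i + 1] * size[i + 1]
    return tuple(stride)

print(contiguous_stride((4, 8)))     # (8, 1) -- matches (size[1], 1) in 2-D
print(contiguous_stride((2, 3, 4)))  # (12, 4, 1)
```

This mirrors what `torch.Tensor.stride()` returns for a freshly allocated contiguous tensor, so it generalizes beyond the 2-D case the hardcoded tuple covered.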
torchao/dtypes/nf4tensor.py (Outdated)

    )
    else:
        updated_attrs = {}
    if nf4tensor.numel() != nf4tensor.size()[0] * nf4tensor.size()[1]:
Is it possible to have an out-of-bound access with nf4tensor.size()[1]? In the else branch, could len(size) == 0?
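One way to address both concerns is to guard the rank before indexing `size()[1]`. A hypothetical sketch in plain Python (`is_full_2d` is an illustrative name, not code from the PR):

```python
def is_full_2d(size, numel):
    """True when a 2-D size tuple exactly accounts for numel elements.

    Checks len(size) first, so 0-d and 1-d tensors never trigger an
    out-of-bound index into size[1].
    """
    if len(size) < 2:
        return False
    return numel == size[0] * size[1]

print(is_full_2d((512, 512), 512 * 512))  # True
print(is_full_2d((), 1))                  # False -- no size[1] to read
```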
if input_tensors:
    for input_tensor in input_tensors[0]:
        if hasattr(input_tensor, attr):
            input_attrs.append(getattr(input_tensor, attr))
What happens when tensors are not evenly divisible? Or is there a possibility of uneven sharding?
If sharded unevenly (i.e., the tensor did not go through the split dispatch), we compare the input and output sizes here.
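For reference, torch.chunk-style splitting yields uneven trailing shards whenever the dimension is not divisible by the number of ranks; the input/output size comparison mentioned above catches exactly those cases. A plain-Python sketch of the split arithmetic (illustrative, not the PR's code):

```python
def chunk_sizes(dim_size, chunks):
    """Per-chunk lengths mimicking torch.chunk: every chunk holds
    ceil(dim_size / chunks) elements except possibly the last, and
    empty trailing chunks are dropped."""
    per_chunk = -(-dim_size // chunks)  # ceil division
    sizes = []
    remaining = dim_size
    for _ in range(chunks):
        take = min(per_chunk, remaining)
        if take == 0:
            break
        sizes.append(take)
        remaining -= take
    return sizes

print(chunk_sizes(512, 4))  # [128, 128, 128, 128] -- even split
print(chunk_sizes(10, 4))   # [3, 3, 3, 1]         -- uneven last shard
```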
@@ -22,7 +22,44 @@
c10d_functional = torch.ops.c10d_functional

NF4_OPS_TABLE: Dict[Any, Any] = {}

def nf4_all_gather_into_tensor(func, *args, **kwargs):
    nf4tensor = args[0][0]
can we assert len(args) and len(args[0]) before accessing them?
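The suggested guard could look like the following sketch (the handler body is a placeholder; assertion messages are illustrative):

```python
def nf4_all_gather_into_tensor(func, *args, **kwargs):
    # Review suggestion: validate the nesting before reading args[0][0].
    assert len(args) >= 1, "expected at least one positional arg"
    assert len(args[0]) >= 1, "expected a non-empty input tensor list"
    nf4tensor = args[0][0]
    return nf4tensor  # placeholder; the real handler dispatches the gather
```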
test/dtypes/test_nf4.py (Outdated)

@pytest.mark.skipif(
    version.parse(torch.__version__).base_version < "2.4.0",
    reason="torch >= 2.4 required",
Which API is needed in 2.4? DTensor?
@@ -435,7 +435,7 @@ def test_tensor_view_valid(self, input_size: Union[Tuple[int], int]):
    inner_tensor = getattr(viewed_tensor, attr)
    self.assertEqual(inner_tensor.size(0), inner_tensor.numel())

-@parametrize("input_size", [(512 * 512,), (512, 512)])
+@parametrize("input_size", [(512 * 512,)])
Why remove (512, 512)?
tensor = torch.randn(512, 512) with tensor.view(512, 512) is now valid after the changes to nf4_view, so this case was moved to test_tensor_2d_view_valid.
Left some comments. The PR probably needs more polish.
Perf should be on par. This BE refactoring upstreams NF4-specific logic from torchtune to DTensor. It creates an example for people to follow when handling tensor subclass + DTensor state dicts.
Add nf4_all_gather_into_tensor and scatter_nf4tensor to enable dispatch of scatter and all_gather_into_tensor.
Add a unit test showing that an NF4 tensor stays the same after distribute and gather.
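The new handlers hook into the existing NF4_OPS_TABLE dispatch pattern visible in the diff above. A simplified sketch of that registration pattern (handler bodies and the string op keys are placeholders standing in for the real c10d_functional ops, not the PR's exact code):

```python
from typing import Any, Callable, Dict, List

NF4_OPS_TABLE: Dict[Any, Any] = {}

def implements(ops: List[Any]):
    """Register a handler for each given op key in the NF4 dispatch table."""
    def decorator(func: Callable) -> Callable:
        for op in ops:
            NF4_OPS_TABLE[op] = func
        return func
    return decorator

@implements(["all_gather_into_tensor"])
def nf4_all_gather_into_tensor(func, *args, **kwargs):
    ...  # gather each inner tensor, then reassemble the NF4 tensor

@implements(["scatter"])
def scatter_nf4tensor(func, *args, **kwargs):
    ...  # shard each inner tensor across ranks

print(sorted(NF4_OPS_TABLE))  # ['all_gather_into_tensor', 'scatter']
```

When the subclass's dispatch sees one of the registered collective ops, it looks up the handler in the table, which is what lets scatter and all_gather_into_tensor round-trip an NF4 tensor unchanged.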