Skip to content

[NF4] Support nf4 tensor shard and gather #2449

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 12 commits into
base: main
Choose a base branch
from

Conversation

mori360
Copy link

@mori360 mori360 commented Jun 26, 2025

Add nf4_all_gather_into_tensor and scatter_nf4tensor to enable dispatch of scatter and all_gather_into_tensor
Add unit test to show that nf4 tensor keeps the same after distribute and gather.

Copy link

pytorch-bot bot commented Jun 26, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2449

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure

As of commit 2b2f768 with merge base 8b57afe (image):

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jun 26, 2025
@mori360 mori360 marked this pull request as ready for review June 26, 2025 19:30
@mori360 mori360 added the topic: for developers Use this tag if this PR is mainly developer facing label Jun 26, 2025
@mori360 mori360 marked this pull request as draft June 26, 2025 21:06
@msaroufim
Copy link
Member

msaroufim commented Jun 26, 2025

Presumably we should observe a comms time reduction? would be nice to see some profiles

@mori360
Copy link
Author

mori360 commented Jun 26, 2025

Presumably we should observe a comms time reduction? would be nice to see some profiles

Are there any baseline I could compare with?

@drisspg drisspg requested a review from weifengpy June 26, 2025 23:28
)
updated_attrs.update(
{
"stride": (nf4tensor.size()[1], 1),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why hardcode the stride?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

)
else:
updated_attrs = {}
if nf4tensor.numel() != nf4tensor.size()[0] * nf4tensor.size()[1]:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is it possible to have out of bound access with nf4tensor.size()[1]? in the else branch, could len(size) == 0?

if input_tensors:
for input_tensor in input_tensors[0]:
if hasattr(input_tensor, attr):
input_attrs.append(getattr(input_tensor, attr))
Copy link
Contributor

@weifengpy weifengpy Jun 27, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what happens when tensor are not evenly divisible? or is there is possibility for uneven sharding?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if sharded unevenly(did not go through the split dispatch), we will compare the input and output sizes here

@@ -22,7 +22,44 @@
c10d_functional = torch.ops.c10d_functional


NF4_OPS_TABLE: Dict[Any, Any] = {}
def nf4_all_gather_into_tensor(func, *args, **kwargs):
nf4tensor = args[0][0]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we assert len(args) and len(args[0]) before accessing them?


@pytest.mark.skipif(
version.parse(torch.__version__).base_version < "2.4.0",
reason="torch >= 2.4 required",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

which api is needed in 2.4? DTensor?

@@ -435,7 +435,7 @@ def test_tensor_view_valid(self, input_size: Union[Tuple[int], int]):
inner_tensor = getattr(viewed_tensor, attr)
self.assertEqual(inner_tensor.size(0), inner_tensor.numel())

@parametrize("input_size", [(512 * 512,), (512, 512)])
@parametrize("input_size", [(512 * 512,)])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why removing (512, 512) ?

Copy link
Author

@mori360 mori360 Jun 27, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tensor=torch.randn(512,512) and tensor.view(512,512) is now valid after changes at nf4_view, move it to test_tensor_2d_view_valid

Copy link
Contributor

@weifengpy weifengpy left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

left some comments. probably need to polish the pr more

@weifengpy
Copy link
Contributor

Presumably we should observe a comms time reduction? would be nice to see some profiles

perf should be on-par. This BE refactoring upstreams NF4 specific logic from torchtune to DTensor. It creates an example for people to follow to handle tensor subclass + DTensor state dict

@mori360 mori360 marked this pull request as ready for review June 29, 2025 00:16
@mori360 mori360 requested a review from weifengpy June 29, 2025 00:16
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. topic: for developers Use this tag if this PR is mainly developer facing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants