
Conversation

@kashif (Collaborator) commented Oct 10, 2025

What does this PR do?

  • Prevents redundant offloading when multiple tensor views share the same storage (see the sketch below)
  • Tracks and filters out model parameters during offloading
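
A minimal sketch of the idea (hypothetical helper names, not the exact code in this PR): key each activation by its underlying storage so that views sharing storage are saved only once, and skip tensors whose storage belongs to a model parameter.

import torch

def _storage_key(t: torch.Tensor):
    # Views of the same allocation share an untyped storage; the offset and
    # dtype distinguish slices/reinterpretations within that storage.
    return (t.untyped_storage().data_ptr(), t.storage_offset(), t.dtype)

def should_offload(t: torch.Tensor, seen: set, param_ptrs: set) -> bool:
    # Skip model parameters/buffers (they are not activations) and any tensor
    # whose storage has already been offloaded via another view.
    if t.untyped_storage().data_ptr() in param_ptrs:
        return False
    key = _storage_key(t)
    if key in seen:
        return False
    seen.add(key)
    return True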

@kashif (Collaborator, Author) commented Oct 10, 2025

@sywangyi would you be kind enough to test this on your hardware and give me some feedback? thank you!

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.


# Try to import DTensor for FSDP v2 support
try:
    from torch.distributed._tensor import DTensor

Member

Do you know in which version DTensor was introduced? I'm wondering if this try/except is needed.

Member

For reference we have the following in accelerate:

if is_torch_version(">=", "2.5.0"):
    from torch.distributed.tensor import DTensor
else:
    # from torch 2.0.0 (oldest supported accelerate torch version), DTensor is in torch.distributed._tensor
    from torch.distributed._tensor import DTensor

but we also need to check torch.distributed.is_available(), otherwise you might get an import issue.
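
A sketch of a guarded import that folds both suggestions together (the torch.distributed.is_available() check plus a fallback across the new and old module paths); shown for illustration, not necessarily what the PR ends up doing:

import torch

DTensor = None
if torch.distributed.is_available():
    try:
        # torch >= 2.5 exposes DTensor under the public path
        from torch.distributed.tensor import DTensor
    except ImportError:
        try:
            # older torch versions keep it under the private module
            from torch.distributed._tensor import DTensor
        except ImportError:
            DTensor = None  # no DTensor available: skip FSDP v2 handling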


# Check if tensor is a parameter or buffer
if isinstance(activation, torch.nn.Parameter) or (
    hasattr(torch.nn, "Buffer") and isinstance(activation, torch.nn.Buffer)

Member

same question here, is Buffer a recent addition?

Collaborator Author

No, Buffer has been there from the start; I can clean this up.

@SunMarc (Member) left a comment

LGTM for the FSDP v2 part!


    Returns:
        A tuple of (storage_pointer, dtype) that uniquely identifies the tensor's storage
    """
    storage_ptr = tensor.untyped_storage().data_ptr() + tensor.storage_offset()

Using data_ptr() can be a bit tricky with, for example, TorchAO quantized tensors, as those can return 0 for data_ptr(). I don't have a concrete example, just something to be aware of.
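
A defensive variant of the storage key (illustrative only, using a hypothetical _storage_key helper): fall back to object identity when the reported pointer is 0, so such tensor subclasses are never conflated with one another.

import torch

def _storage_key(t: torch.Tensor):
    ptr = t.untyped_storage().data_ptr()
    if ptr == 0:
        # Some tensor subclasses (e.g. certain TorchAO quantized tensors) can
        # report a 0 data_ptr; fall back to identity so they never collide.
        return ("id", id(t))
    return (ptr, t.storage_offset(), t.dtype)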


# For FSDP v2: extract local tensor from DTensor
actual_tensor = p
if DTensor is not None and isinstance(p, DTensor) and hasattr(p, "_local_tensor"):
    actual_tensor = p._local_tensor

@sywangyi (Contributor) commented:

> @sywangyi would you be kind enough to test this on your hardware and give me some feedback? thank you!

pytest tests/test_activation_offloading.py::TestActivationOffloading::test_parameter_filtering
pytest tests/test_activation_offloading.py::TestActivationOffloading::test_tensor_deduplication

These two cases pass on Intel XPU.

@kashif (Collaborator, Author) commented Oct 11, 2025

@S1ro1 OK, I'll just skip FP8 activations.

@kashif (Collaborator, Author) commented Oct 17, 2025

@SunMarc I have added support for broadcast and non-contiguous tensors
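
For context, a quick illustration (not the PR's code) of why these cases matter: expand() and transpose() both return views that share storage with their base tensor, so a storage-based key lets the offloader save the underlying buffer only once.

import torch

base = torch.randn(4, 1)
broadcast_view = base.expand(4, 8)   # stride 0 along dim 1, no data copied

mat = torch.randn(4, 3)
noncontig = mat.t()                  # transposed view, not contiguous

# Both views share storage with their base, so deduplicating by storage
# pointer avoids offloading the same bytes twice.
assert broadcast_view.untyped_storage().data_ptr() == base.untyped_storage().data_ptr()
assert noncontig.untyped_storage().data_ptr() == mat.untyped_storage().data_ptr()
assert not noncontig.is_contiguous()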

@sergiopaniego (Member) left a comment

LGTM!

@kashif merged commit e2ab435 into main on Oct 21, 2025; 10 of 12 checks passed.
@kashif deleted the activation-dedup branch on October 21, 2025 at 10:34.