53 changes: 53 additions & 0 deletions tests/test_activation_offloading.py
@@ -154,3 +154,56 @@ def test_real_hf_model(self):
assert torch.allclose(out1, out2, rtol=1e-5)
for g1, g2 in zip(grads1, grads2):
assert torch.allclose(g1, g2, rtol=1e-5)

@require_torch_accelerator
def test_tensor_deduplication(self):
"""Test that deduplication works correctly for tensors sharing storage"""

class ModelWithViews(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(100, 100)

def forward(self, x):
out = self.linear(x)
view1 = out.view(-1)
view2 = out.transpose(0, 1)
return view1.sum() + view2.sum()

model = ModelWithViews().to(torch_device)
offload_ctx = OffloadActivations(min_offload_size=1)
offload_ctx.update_model_params(model)

x = torch.randn(10, 100, device=torch_device, requires_grad=True)
with offload_ctx:
loss = model(x)

total_tensor_ids = offload_ctx.tensor_id
assert total_tensor_ids > 0, "Should have created tensor IDs"

# modified=True means offloaded to CPU, modified=False means kept on GPU (deduplicated)
deduplicated_count = sum(1 for _, modified in offload_ctx.tracker.values() if not modified)
offloaded_count = sum(1 for _, modified in offload_ctx.tracker.values() if modified)

assert offloaded_count > 0, "Should have offloaded at least one tensor"
assert deduplicated_count > 0, "Should have deduplicated at least one tensor (view)"

unique_storages_offloaded = len(offload_ctx.storage_to_tensor_id)
assert unique_storages_offloaded < total_tensor_ids, (
f"Deduplication should result in fewer storages ({unique_storages_offloaded}) "
f"than total tensors ({total_tensor_ids})"
)

loss.backward()

@require_torch_accelerator
def test_parameter_filtering(self):
"""Test that model parameters are filtered during offloading"""
model = nn.Sequential(nn.Linear(10, 20), nn.Linear(20, 10)).to(torch_device)
offload_ctx = OffloadActivations()
offload_ctx.update_model_params(model)

assert len(offload_ctx.param_storages) > 0, "Should have tracked parameter storages"

param_ptrs = {p.data.untyped_storage().data_ptr() for p in model.parameters()}
assert offload_ctx.param_storages == param_ptrs, "Tracked storages should match parameter storages"
170 changes: 126 additions & 44 deletions trl/models/activation_offloading.py
@@ -30,10 +30,31 @@
if is_torch_npu_available():
import torch_npu # noqa: F401

# Try to import DTensor for FSDP v2 support
try:
from torch.distributed._tensor import DTensor
Member: do you know in which version DTensor was introduced? I'm wondering if this try/except is needed

Member: For reference we have the following in accelerate:

if is_torch_version(">=", "2.5.0"):
    from torch.distributed.tensor import DTensor
else:
    # from torch 2.0.0 (oldest supported accelerate torch version), DTensor is in torch.distributed._tensor
    from torch.distributed._tensor import DTensor

but we also need to check torch.distributed.is_available(), otherwise you might get an import issue.

except (ImportError, AttributeError):
DTensor = None
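A minimal sketch of the guarded import the review above suggests, assuming plain torch plus packaging rather than accelerate's is_torch_version helper:

import torch
from packaging import version

DTensor = None
if torch.distributed.is_available():  # guard against builds compiled without distributed support
    if version.parse(torch.__version__.split("+")[0]) >= version.parse("2.5.0"):
        from torch.distributed.tensor import DTensor  # public location since torch 2.5
    else:
        from torch.distributed._tensor import DTensor  # private location on older torch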

logger = logging.get_logger(__name__)


def _get_unique_tensor_key(tensor: torch.Tensor) -> tuple:
"""
Get a unique key for a tensor based on its storage pointer and dtype. This allows deduplication of tensors that
share the same underlying storage. From:
https://github.com/volcengine/verl/blob/main/verl/utils/activation_offload.py

Args:
tensor: The tensor to get the key for

Returns:
A tuple of (storage_pointer, dtype) that uniquely identifies the tensor's storage
"""
storage_ptr = tensor.untyped_storage().data_ptr() + tensor.storage_offset()
Using data_ptr() can be a bit tricky with, for example, TorchAO quantized tensors, as those can return 0 for data_ptr(). I don't have a concrete example, just something to be aware of.

return (storage_ptr, tensor.dtype)
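As an illustration (not part of the diff): assuming the _get_unique_tensor_key defined above, views that share storage and offset collapse onto one key, which is what drives the deduplication below; the reviewer's caveat about data_ptr() returning 0 for exotic tensor subclasses still applies.

import torch

base = torch.randn(4, 4)
flat = base.view(-1)               # same storage, offset 0 -> same key as base
transposed = base.transpose(0, 1)  # same storage, offset 0 -> same key as base
tail = base[2:]                    # same storage but nonzero storage_offset -> different key

assert _get_unique_tensor_key(base) == _get_unique_tensor_key(flat)
assert _get_unique_tensor_key(base) == _get_unique_tensor_key(transposed)
assert _get_unique_tensor_key(base) != _get_unique_tensor_key(tail)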


class OffloadActivations(saved_tensors_hooks):
"""
Context manager under which activation tensors created in the forward pass will be offloaded.
@@ -91,6 +112,12 @@ def __init__(
self.is_first_backward_call = True
self.is_first_forward_pass = True

# Storage deduplication: maps storage key to tensor_id to avoid offloading the same storage multiple times
self.storage_to_tensor_id = {}

# Parameter filtering: track parameter storage pointers to skip them during offloading
self.param_storages = set()

# Managing cpu memory
self.use_pin_memory = use_pin_memory
self.virtual_memory_safe_pct = 60 # we should not exceed this percentage of memory
@@ -152,60 +179,82 @@ def pack_tensor(activation: torch.Tensor) -> int:
# set training phase trackers
self.is_first_forward_call = False
self.is_first_backward_call = True
# Reset deduplication map for new forward pass
self.storage_to_tensor_id = {}

# query for basic tensor info
num_bytes = get_num_bytes_tensor(activation)
tensor_id = get_tensor_id()

# only offload hefty bois if they're activations on CUDA (our heuristic
# for that is to check if they're not params or buffers)!
if (
activation.device.type in ["cuda", "xpu", "npu"]
and num_bytes >= self.min_tensor_size_bytes
and (
not isinstance(activation, torch.nn.Parameter)
and not (hasattr(torch.nn, "Buffer") and isinstance(activation, torch.nn.Buffer))
)
# Check for tensor deduplication using storage pointer
# If this storage is already being tracked, we still create a new tensor_id
# but don't offload again (just keep the tensor on the GPU)
storage_key = _get_unique_tensor_key(activation)
if storage_key in self.storage_to_tensor_id:
# Storage already offloaded - don't offload again, just track the reference
self.tracker[tensor_id] = (activation, False) # Keep on GPU, don't offload
return tensor_id

# Check if tensor is on CPU (skip offloading)
if activation.device.type not in ["cuda", "xpu", "npu"]:
self.tracker[tensor_id] = (activation, False)
return tensor_id

# Check if tensor is too small
if num_bytes < self.min_tensor_size_bytes:
self.tracker[tensor_id] = (activation, False)
return tensor_id

# Check if tensor is a parameter or buffer
if isinstance(activation, torch.nn.Parameter) or (
hasattr(torch.nn, "Buffer") and isinstance(activation, torch.nn.Buffer)
Member: same question here, is Buffer a recent addition?

Collaborator (Author): no, Buffer has always been there from the start, I can clean this up

):
self.tracker[tensor_id] = (activation, False)
return tensor_id

# Check if tensor storage is a model parameter (for FSDP compatibility)
if activation.untyped_storage().data_ptr() in self.param_storages:
self.tracker[tensor_id] = (activation, False)
return tensor_id

# Tensor qualifies for offloading
if self.use_streams:
# First, sync back and dereference previously offloaded tensors
# as the offloading should be done sufficiently long ago.
for id in list(self.fwd_stash.keys()):
if id <= tensor_id - self.max_fwd_stash_size:
_, ev = self.fwd_stash[id]
self.s0.wait_event(ev)
del self.fwd_stash[id]
else:
break

# Sync in, offload, and add an event to sync back later
self.s1.wait_stream(self.s0)

stream = self.s1 if self.use_streams else self.s0
if self.accelerator_type == "xpu":
stream_ctx = torch.xpu.stream(stream)
elif self.accelerator_type == "npu":
stream_ctx = torch.npu.stream(stream)
else:
stream_ctx = torch.cuda.stream(stream)
with stream_ctx:
cpu_tensor = torch.empty_like(activation, pin_memory=self.use_pin_memory, device="cpu")
cpu_tensor.copy_(activation, non_blocking=True)
self.tracker[tensor_id] = (
cpu_tensor,
True, # True = (in future) modified
)

if self.use_streams:
event = self.s1.record_event()

# Stash to keep activation alive til s1 is done
self.fwd_stash[tensor_id] = (activation, event)
self.tracker[tensor_id] = (
    activation,
    False,
)  # False = not modified, tensor is as is

# Track this storage for deduplication
self.storage_to_tensor_id[storage_key] = tensor_id

return tensor_id

@@ -368,6 +417,36 @@ def hook(outputs, inputs):
unpack_tensor = unpack_tensor_with_streams if self.use_streams else unpack_tensor_single_stream
super().__init__(pack_tensor, unpack_tensor)

def update_model_params(self, model: nn.Module):
"""
Update the set of parameter storage pointers from the model. This allows filtering out model parameters during
offloading, which is especially important for FSDP models where parameters may not be detected by isinstance
checks.

For FSDP v2, this method handles DTensor parameters which may be sharded across ranks and not have valid local
storage on all ranks. We extract the local tensor from DTensors using _local_tensor when available.

Args:
model: The model whose parameters should be tracked
"""
param_storages = set()

for p in model.parameters():
# For FSDP v2: extract local tensor from DTensor
actual_tensor = p
if DTensor is not None and isinstance(p, DTensor) and hasattr(p, "_local_tensor"):
actual_tensor = p._local_tensor
# Try to get storage pointer
try:
storage_ptr = actual_tensor.untyped_storage().data_ptr()
param_storages.add(storage_ptr)
except RuntimeError:
# Parameter doesn't have accessible storage (e.g., FSDP v2 sharded without local shard)
continue

self.param_storages = param_storages
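For orientation, a minimal usage sketch consistent with the tests above (device and sizes are illustrative; any FSDP wrapping would happen before this point):

model = nn.Sequential(nn.Linear(10, 20), nn.Linear(20, 10)).to(torch_device)
offload_ctx = OffloadActivations(min_offload_size=1)
offload_ctx.update_model_params(model)  # record parameter storages so they are never offloaded

x = torch.randn(8, 10, device=torch_device, requires_grad=True)
with offload_ctx:
    loss = model(x).sum()
loss.backward()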


class NoOpManager(saved_tensors_hooks):
"""
@@ -433,6 +512,9 @@ def get_act_offloading_ctx_manager(
max_fwd_stash_size=max_fwd_stash_size,
)

# Update parameter storages to filter them during offloading (important for FSDP)
activations_handling_ctx.update_model_params(model)

# Below is our hack to disable offloading the last output Linear in every
# step, as the cost for offloading the activation and then soon after bringing
# it back is expensive.
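The code implementing this is collapsed in the diff. As a rough sketch only (the hook wiring follows the NoOpManager pattern defined in this file; get_output_embeddings is an assumed way to reach the model's final Linear), the exemption could look like:

noop_ctx = NoOpManager()
output_head = model.get_output_embeddings()  # assumption: the wrapped model exposes its last Linear here
# Enter the no-op context just before the head runs and leave it right after,
# so saved_tensors_hooks offloading is disabled only for this module.
output_head.register_forward_pre_hook(lambda *args: noop_ctx.__enter__())
output_head.register_forward_hook(lambda *args: noop_ctx.__exit__(), always_call=True)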