From c2f840d919a76ea2e0f307e68073434d0d5a7a77 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 10 Oct 2025 08:46:29 +0000 Subject: [PATCH 1/9] add tensor dedup and param offloading --- tests/test_activation_offloading.py | 46 +++++++++ trl/models/activation_offloading.py | 146 +++++++++++++++++++--------- 2 files changed, 148 insertions(+), 44 deletions(-) diff --git a/tests/test_activation_offloading.py b/tests/test_activation_offloading.py index d1a9ea921f5..847123e1475 100644 --- a/tests/test_activation_offloading.py +++ b/tests/test_activation_offloading.py @@ -154,3 +154,49 @@ def test_real_hf_model(self): assert torch.allclose(out1, out2, rtol=1e-5) for g1, g2 in zip(grads1, grads2): assert torch.allclose(g1, g2, rtol=1e-5) + + @require_torch_accelerator + def test_tensor_deduplication(self): + """Test that deduplication works correctly for tensors sharing storage""" + # Create a model that produces tensor views + class ModelWithViews(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(100, 100) + + def forward(self, x): + out = self.linear(x) + # Create views - they share storage but are different tensor objects + view1 = out.view(-1) + view2 = out.transpose(0, 1) + return view1.sum() + view2.sum() + + model = ModelWithViews().to(torch_device) + offload_ctx = OffloadActivations() + offload_ctx.update_model_params(model) + + # Run forward+backward with offloading + x = torch.randn(10, 100, device=torch_device, requires_grad=True) + with offload_ctx: + loss = model(x) + loss.backward() + + # Test passes if no errors occur - deduplication prevents redundant offloading + + @require_torch_accelerator + def test_parameter_filtering(self): + """Test that model parameters are filtered during offloading""" + model = nn.Sequential(nn.Linear(10, 20), nn.Linear(20, 10)).to(torch_device) + offload_ctx = OffloadActivations() + + # Update parameter storages + offload_ctx.update_model_params(model) + + # Verify parameters were tracked + assert len(offload_ctx.param_storages) > 0, "Should have tracked parameter storages" + + # Get actual parameter storage pointers + param_ptrs = {p.data.untyped_storage().data_ptr() for p in model.parameters()} + + # Verify they match + assert offload_ctx.param_storages == param_ptrs, "Tracked storages should match parameter storages" diff --git a/trl/models/activation_offloading.py b/trl/models/activation_offloading.py index 911b5258304..dbee56c3051 100644 --- a/trl/models/activation_offloading.py +++ b/trl/models/activation_offloading.py @@ -34,6 +34,22 @@ logger = logging.get_logger(__name__) +def _get_unique_tensor_key(tensor: torch.Tensor) -> tuple: + """ + Get a unique key for a tensor based on its storage pointer and dtype. This allows deduplication of tensors that + share the same underlying storage. From: + https://github.com/volcengine/verl/blob/main/verl/utils/activation_offload.py + + Args: + tensor: The tensor to get the key for + + Returns: + A tuple of (storage_pointer, dtype) that uniquely identifies the tensor's storage + """ + storage_ptr = tensor.untyped_storage().data_ptr() + tensor.storage_offset() + return (storage_ptr, tensor.dtype) + + class OffloadActivations(saved_tensors_hooks): """ Context manager under which activation tensors created in the forward pass will be offloaded. 
@@ -91,6 +107,12 @@ def __init__( self.is_first_backward_call = True self.is_first_forward_pass = True + # Storage deduplication: maps storage key to tensor_id to avoid offloading same storage multiple times + self.storage_to_tensor_id = {} + + # Parameter filtering: track parameter storage pointers to skip them during offloading + self.param_storages = set() + # Managing cpu memory self.use_pin_memory = use_pin_memory self.virtual_memory_safe_pct = 60 # we should not exceed this percentage of memory @@ -152,60 +174,82 @@ def pack_tensor(activation: torch.Tensor) -> int: # set training phase trackers self.is_first_forward_call = False self.is_first_backward_call = True + # Reset deduplication map for new forward pass + self.storage_to_tensor_id = {} # query for basic tensor info num_bytes = get_num_bytes_tensor(activation) tensor_id = get_tensor_id() - # only offload hefty bois if they're activations on CUDA (our heuristic - # for that is to check if they're not params or buffers)! - if ( - activation.device.type in ["cuda", "xpu", "npu"] - and num_bytes >= self.min_tensor_size_bytes - and ( - not isinstance(activation, torch.nn.Parameter) - and not (hasattr(torch.nn, "Buffer") and isinstance(activation, torch.nn.Buffer)) - ) + # Check for tensor deduplication using storage pointer + # If this storage is already being tracked, we still create a new tensor_id + # but don't offload again (just keep the tensor in GPU) + storage_key = _get_unique_tensor_key(activation) + if storage_key in self.storage_to_tensor_id: + # Storage already offloaded - don't offload again, just track the reference + self.tracker[tensor_id] = (activation, False) # Keep on GPU, don't offload + return tensor_id + + # Check if tensor is on CPU (skip offloading) + if activation.device.type not in ["cuda", "xpu", "npu"]: + self.tracker[tensor_id] = (activation, False) + return tensor_id + + # Check if tensor is too small + if num_bytes < self.min_tensor_size_bytes: + self.tracker[tensor_id] = (activation, False) + return tensor_id + + # Check if tensor is a parameter or buffer + if isinstance(activation, torch.nn.Parameter) or ( + hasattr(torch.nn, "Buffer") and isinstance(activation, torch.nn.Buffer) ): - if self.use_streams: - # First, sync back and dereference previously offloaded tensors - # as the offloading should be done sufficiently long ago. - for id in list(self.fwd_stash.keys()): - if id <= tensor_id - self.max_fwd_stash_size: - _, ev = self.fwd_stash[id] - self.s0.wait_event(ev) - del self.fwd_stash[id] - else: - break + self.tracker[tensor_id] = (activation, False) + return tensor_id + + # Check if tensor storage is a model parameter (for FSDP compatibility) + if activation.untyped_storage().data_ptr() in self.param_storages: + self.tracker[tensor_id] = (activation, False) + return tensor_id + + # Tensor qualifies for offloading + if self.use_streams: + # First, sync back and dereference previously offloaded tensors + # as the offloading should be done sufficiently long ago. 
+ for id in list(self.fwd_stash.keys()): + if id <= tensor_id - self.max_fwd_stash_size: + _, ev = self.fwd_stash[id] + self.s0.wait_event(ev) + del self.fwd_stash[id] + else: + break - # Sync in, offload, and add an event to sync back later - self.s1.wait_stream(self.s0) + # Sync in, offload, and add an event to sync back later + self.s1.wait_stream(self.s0) - stream = self.s1 if self.use_streams else self.s0 - if self.accelerator_type == "xpu": - stream_ctx = torch.xpu.stream(stream) - elif self.accelerator_type == "npu": - stream_ctx = torch.npu.stream(stream) - else: - stream_ctx = torch.cuda.stream(stream) - with stream_ctx: - cpu_tensor = torch.empty_like(activation, pin_memory=self.use_pin_memory, device="cpu") - cpu_tensor.copy_(activation, non_blocking=True) - self.tracker[tensor_id] = ( - cpu_tensor, - True, # True = (in future) modified - ) - - if self.use_streams: - event = self.s1.record_event() - - # Stash to keep activation alive til s1 is done - self.fwd_stash[tensor_id] = (activation, event) + stream = self.s1 if self.use_streams else self.s0 + if self.accelerator_type == "xpu": + stream_ctx = torch.xpu.stream(stream) + elif self.accelerator_type == "npu": + stream_ctx = torch.npu.stream(stream) else: + stream_ctx = torch.cuda.stream(stream) + with stream_ctx: + cpu_tensor = torch.empty_like(activation, pin_memory=self.use_pin_memory, device="cpu") + cpu_tensor.copy_(activation, non_blocking=True) self.tracker[tensor_id] = ( - activation, - False, - ) # False = not modified, tensor is as is + cpu_tensor, + True, # True = (in future) modified + ) + + if self.use_streams: + event = self.s1.record_event() + + # Stash to keep activation alive til s1 is done + self.fwd_stash[tensor_id] = (activation, event) + + # Track this storage for deduplication + self.storage_to_tensor_id[storage_key] = tensor_id return tensor_id @@ -368,6 +412,17 @@ def hook(outputs, inputs): unpack_tensor = unpack_tensor_with_streams if self.use_streams else unpack_tensor_single_stream super().__init__(pack_tensor, unpack_tensor) + def update_model_params(self, model: nn.Module): + """ + Update the set of parameter storage pointers from the model. This allows filtering out model parameters during + offloading, which is especially important for FSDP models where parameters may not be detected by isinstance + checks. + + Args: + model: The model whose parameters should be tracked + """ + self.param_storages = {p.data.untyped_storage().data_ptr() for p in model.parameters()} + class NoOpManager(saved_tensors_hooks): """ @@ -433,6 +488,9 @@ def get_act_offloading_ctx_manager( max_fwd_stash_size=max_fwd_stash_size, ) + # Update parameter storages to filter them during offloading (important for FSDP) + activations_handling_ctx.update_model_params(model) + # Below is our hack to disable offloading the last output Linear in every # step, as the cost for offloading the activation and then soon after bringing # it back is expensive. 
From d4bf577e66149c3c06751edaf8b73b957e99319a Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 10 Oct 2025 08:52:43 +0000 Subject: [PATCH 2/9] fix formatting --- tests/test_activation_offloading.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_activation_offloading.py b/tests/test_activation_offloading.py index 847123e1475..20e474ce1ee 100644 --- a/tests/test_activation_offloading.py +++ b/tests/test_activation_offloading.py @@ -158,6 +158,7 @@ def test_real_hf_model(self): @require_torch_accelerator def test_tensor_deduplication(self): """Test that deduplication works correctly for tensors sharing storage""" + # Create a model that produces tensor views class ModelWithViews(nn.Module): def __init__(self): From d86e8252ebd9d67cb2c7986e6a44f084d95edb51 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 10 Oct 2025 09:02:31 +0000 Subject: [PATCH 3/9] check if unique_storages_offloaded < total tensors --- tests/test_activation_offloading.py | 30 +++++++++++++++++------------ 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/tests/test_activation_offloading.py b/tests/test_activation_offloading.py index 20e474ce1ee..33a5c2f825e 100644 --- a/tests/test_activation_offloading.py +++ b/tests/test_activation_offloading.py @@ -159,7 +159,6 @@ def test_real_hf_model(self): def test_tensor_deduplication(self): """Test that deduplication works correctly for tensors sharing storage""" - # Create a model that produces tensor views class ModelWithViews(nn.Module): def __init__(self): super().__init__() @@ -167,37 +166,44 @@ def __init__(self): def forward(self, x): out = self.linear(x) - # Create views - they share storage but are different tensor objects view1 = out.view(-1) view2 = out.transpose(0, 1) return view1.sum() + view2.sum() model = ModelWithViews().to(torch_device) - offload_ctx = OffloadActivations() + offload_ctx = OffloadActivations(min_offload_size=1) offload_ctx.update_model_params(model) - # Run forward+backward with offloading x = torch.randn(10, 100, device=torch_device, requires_grad=True) with offload_ctx: loss = model(x) - loss.backward() - # Test passes if no errors occur - deduplication prevents redundant offloading + total_tensor_ids = offload_ctx.tensor_id + assert total_tensor_ids > 0, "Should have created tensor IDs" + + # modified=True means offloaded to CPU, modified=False means kept on GPU (deduplicated) + deduplicated_count = sum(1 for _, modified in offload_ctx.tracker.values() if not modified) + offloaded_count = sum(1 for _, modified in offload_ctx.tracker.values() if modified) + + assert offloaded_count > 0, "Should have offloaded at least one tensor" + assert deduplicated_count > 0, "Should have deduplicated at least one tensor (view)" + + unique_storages_offloaded = len(offload_ctx.storage_to_tensor_id) + assert unique_storages_offloaded < total_tensor_ids, ( + f"Deduplication should result in fewer storages ({unique_storages_offloaded}) " + f"than total tensors ({total_tensor_ids})" + ) + + loss.backward() @require_torch_accelerator def test_parameter_filtering(self): """Test that model parameters are filtered during offloading""" model = nn.Sequential(nn.Linear(10, 20), nn.Linear(20, 10)).to(torch_device) offload_ctx = OffloadActivations() - - # Update parameter storages offload_ctx.update_model_params(model) - # Verify parameters were tracked assert len(offload_ctx.param_storages) > 0, "Should have tracked parameter storages" - # Get actual parameter storage pointers param_ptrs = {p.data.untyped_storage().data_ptr() for p in 
model.parameters()} - - # Verify they match assert offload_ctx.param_storages == param_ptrs, "Tracked storages should match parameter storages" From f70a6134d6d5e65d630e3886b54aa0ffde6ca73a Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 10 Oct 2025 09:58:00 +0000 Subject: [PATCH 4/9] fix for FSDP v2 --- trl/models/activation_offloading.py | 26 +++++++++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/trl/models/activation_offloading.py b/trl/models/activation_offloading.py index dbee56c3051..f908ef38884 100644 --- a/trl/models/activation_offloading.py +++ b/trl/models/activation_offloading.py @@ -30,6 +30,11 @@ if is_torch_npu_available(): import torch_npu # noqa: F401 +# Try to import DTensor for FSDP v2 support +try: + from torch.distributed._tensor import DTensor +except (ImportError, AttributeError): + DTensor = None logger = logging.get_logger(__name__) @@ -418,10 +423,29 @@ def update_model_params(self, model: nn.Module): offloading, which is especially important for FSDP models where parameters may not be detected by isinstance checks. + For FSDP v2, this method handles DTensor parameters which may be sharded across ranks and not have valid local + storage on all ranks. We extract the local tensor from DTensors using _local_tensor when available. + Args: model: The model whose parameters should be tracked """ - self.param_storages = {p.data.untyped_storage().data_ptr() for p in model.parameters()} + param_storages = set() + + for p in model.parameters(): + # For FSDP v2: extract local tensor from DTensor + actual_tensor = p + if DTensor is not None and isinstance(p, DTensor) and hasattr(p, "_local_tensor"): + actual_tensor = p._local_tensor + + # Try to get storage pointer + try: + storage_ptr = actual_tensor.untyped_storage().data_ptr() + param_storages.add(storage_ptr) + except RuntimeError: + # Parameter doesn't have accessible storage (e.g., FSDP v2 sharded without local shard) + continue + + self.param_storages = param_storages class NoOpManager(saved_tensors_hooks): From 3fab796fbd6ab5b8b623aae44655803b8551c2c6 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sat, 11 Oct 2025 20:06:58 +0000 Subject: [PATCH 5/9] ignore fp8 --- trl/models/activation_offloading.py | 60 +++++++++++++++++++++++------ 1 file changed, 49 insertions(+), 11 deletions(-) diff --git a/trl/models/activation_offloading.py b/trl/models/activation_offloading.py index f908ef38884..6e4ce432724 100644 --- a/trl/models/activation_offloading.py +++ b/trl/models/activation_offloading.py @@ -22,6 +22,7 @@ import psutil import torch from accelerate import logging +from accelerate.utils.versions import is_torch_version from torch import nn from torch.autograd.graph import saved_tensors_hooks from transformers import is_torch_npu_available @@ -30,11 +31,17 @@ if is_torch_npu_available(): import torch_npu # noqa: F401 -# Try to import DTensor for FSDP v2 support -try: - from torch.distributed._tensor import DTensor -except (ImportError, AttributeError): - DTensor = None +# Import DTensor for FSDP v2 support with version-aware import path +DTensor = None +if torch.distributed.is_available(): + try: + if is_torch_version(">=", "2.5.0"): + from torch.distributed.tensor import DTensor + else: + # from torch 2.0.0 (oldest supported accelerate torch version), DTensor is in torch.distributed._tensor + from torch.distributed._tensor import DTensor + except (ImportError, AttributeError): + DTensor = None logger = logging.get_logger(__name__) @@ -51,8 +58,22 @@ def 
_get_unique_tensor_key(tensor: torch.Tensor) -> tuple: Returns: A tuple of (storage_pointer, dtype) that uniquely identifies the tensor's storage """ - storage_ptr = tensor.untyped_storage().data_ptr() + tensor.storage_offset() - return (storage_ptr, tensor.dtype) + # Handle special tensor types - primarily for FSDP v2 DTensor + actual_tensor = tensor + + # For DTensor (FSDP v2), extract the local tensor + if DTensor is not None and isinstance(tensor, DTensor) and hasattr(tensor, "_local_tensor"): + actual_tensor = tensor._local_tensor + + # Try to get storage pointer, but fall back to tensor id if not accessible + try: + storage_ptr = actual_tensor.untyped_storage().data_ptr() + actual_tensor.storage_offset() + except (RuntimeError, AttributeError): + # For tensors with invalid storage, use tensor id + # This won't enable deduplication for these tensors, but allows offloading to work + storage_ptr = id(actual_tensor) + + return (storage_ptr, actual_tensor.dtype) class OffloadActivations(saved_tensors_hooks): @@ -212,11 +233,26 @@ def pack_tensor(activation: torch.Tensor) -> int: self.tracker[tensor_id] = (activation, False) return tensor_id - # Check if tensor storage is a model parameter (for FSDP compatibility) - if activation.untyped_storage().data_ptr() in self.param_storages: + # Check if tensor is an FP8 tensor (TorchAO) - skip offloading as they're already compressed + tensor_class_name = type(activation).__name__ + if tensor_class_name in ["Float8TrainingTensor", "ScaledMMConfig", "LinearMMConfig"]: self.tracker[tensor_id] = (activation, False) return tensor_id + # Check if tensor storage is a model parameter (for FSDP compatibility) + try: + # Extract actual tensor for DTensor + check_tensor = activation + if DTensor is not None and isinstance(activation, DTensor) and hasattr(activation, "_local_tensor"): + check_tensor = activation._local_tensor + + if check_tensor.untyped_storage().data_ptr() in self.param_storages: + self.tracker[tensor_id] = (activation, False) + return tensor_id + except (RuntimeError, AttributeError): + # If we can't get data_ptr, skip this check + pass + # Tensor qualifies for offloading if self.use_streams: # First, sync back and dereference previously offloaded tensors @@ -440,9 +476,11 @@ def update_model_params(self, model: nn.Module): # Try to get storage pointer try: storage_ptr = actual_tensor.untyped_storage().data_ptr() - param_storages.add(storage_ptr) + if storage_ptr != 0: + param_storages.add(storage_ptr) except RuntimeError: - # Parameter doesn't have accessible storage (e.g., FSDP v2 sharded without local shard) + # Parameter doesn't have accessible storage (e.g., FSDP v2 sharded without local shard, FP8 parameters) + # These will be caught by other checks (isinstance for Parameter, class name for FP8) continue self.param_storages = param_storages From 3a29eb8c07d5fb39f696f33b1b588fbe95faeb8e Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 17 Oct 2025 11:42:45 +0000 Subject: [PATCH 6/9] checking if events exist before accessing them --- trl/models/activation_offloading.py | 39 +++++++++++++++++------------ 1 file changed, 23 insertions(+), 16 deletions(-) diff --git a/trl/models/activation_offloading.py b/trl/models/activation_offloading.py index 6e4ce432724..383fced095d 100644 --- a/trl/models/activation_offloading.py +++ b/trl/models/activation_offloading.py @@ -328,8 +328,9 @@ def unpack_tensor_with_streams(unpack_tensor_id: int) -> torch.Tensor: def wait_and_del_remaining_references() -> None: for id in 
list(self.bwd_tensor_stash.keys()): - event = self.bwd_ev_stash[id] - self.s1.wait_event(event) + if id in self.bwd_ev_stash: + event = self.bwd_ev_stash[id] + self.s1.wait_event(event) del self.bwd_tensor_stash[id] # Register a callback to the end of autograd to clean everything up @@ -415,17 +416,19 @@ def hook(outputs, inputs): # compute stream (s0 here). Note that the con here is we introduce # non-deterministic (thus higher) memory usage, but this case # should not happen often. - unpacked_tensor = self.bwd_tensor_stash[unpack_tensor_id] - if self.accelerator_type == "npu": - storage_count = torch_npu._C._storage_Use_Count(unpacked_tensor.untyped_storage()._cdata) - else: - storage_count = torch._C._storage_Use_Count(unpacked_tensor.untyped_storage()._cdata) - if storage_count > storage_refcount: - unpacked_tensor.record_stream(self.s0) - del self.bwd_tensor_stash[unpack_tensor_id] - else: - event = self.s0.record_event() - self.bwd_ev_stash[unpack_tensor_id] = event + # Check if tensor still exists (might have been cleaned up by a previous node) + if unpack_tensor_id in self.bwd_tensor_stash: + unpacked_tensor = self.bwd_tensor_stash[unpack_tensor_id] + if self.accelerator_type == "npu": + storage_count = torch_npu._C._storage_Use_Count(unpacked_tensor.untyped_storage()._cdata) + else: + storage_count = torch._C._storage_Use_Count(unpacked_tensor.untyped_storage()._cdata) + if storage_count > storage_refcount: + unpacked_tensor.record_stream(self.s0) + del self.bwd_tensor_stash[unpack_tensor_id] + else: + event = self.s0.record_event() + self.bwd_ev_stash[unpack_tensor_id] = event # if there are still things in the fwd_stash, get rid of them as we're in bwd now for id in list(self.fwd_stash.keys()): @@ -435,9 +438,13 @@ def hook(outputs, inputs): # wait on prev node's events and del those for id in prev_node_ids: - event = self.bwd_ev_stash[id] - self.s1.wait_event(event) - del self.bwd_tensor_stash[id] + # Only wait on events that exist (some tensors may have used record_stream instead) + if id in self.bwd_ev_stash: + event = self.bwd_ev_stash[id] + self.s1.wait_event(event) + del self.bwd_ev_stash[id] + if id in self.bwd_tensor_stash: + del self.bwd_tensor_stash[id] return outputs From 254a34d8e62da8ea3ce602a4564cecc8027a042e Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 17 Oct 2025 14:10:07 +0000 Subject: [PATCH 7/9] preserve stride information --- trl/models/activation_offloading.py | 50 +++++++++++++++++++++++------ 1 file changed, 41 insertions(+), 9 deletions(-) diff --git a/trl/models/activation_offloading.py b/trl/models/activation_offloading.py index 383fced095d..7fabf52d667 100644 --- a/trl/models/activation_offloading.py +++ b/trl/models/activation_offloading.py @@ -213,30 +213,30 @@ def pack_tensor(activation: torch.Tensor) -> int: storage_key = _get_unique_tensor_key(activation) if storage_key in self.storage_to_tensor_id: # Storage already offloaded - don't offload again, just track the reference - self.tracker[tensor_id] = (activation, False) # Keep on GPU, don't offload + self.tracker[tensor_id] = (activation, False, None, None) # Keep on GPU, don't offload return tensor_id # Check if tensor is on CPU (skip offloading) if activation.device.type not in ["cuda", "xpu", "npu"]: - self.tracker[tensor_id] = (activation, False) + self.tracker[tensor_id] = (activation, False, None, None) return tensor_id # Check if tensor is too small if num_bytes < self.min_tensor_size_bytes: - self.tracker[tensor_id] = (activation, False) + self.tracker[tensor_id] = (activation, 
False, None, None) return tensor_id # Check if tensor is a parameter or buffer if isinstance(activation, torch.nn.Parameter) or ( hasattr(torch.nn, "Buffer") and isinstance(activation, torch.nn.Buffer) ): - self.tracker[tensor_id] = (activation, False) + self.tracker[tensor_id] = (activation, False, None, None) return tensor_id # Check if tensor is an FP8 tensor (TorchAO) - skip offloading as they're already compressed tensor_class_name = type(activation).__name__ if tensor_class_name in ["Float8TrainingTensor", "ScaledMMConfig", "LinearMMConfig"]: - self.tracker[tensor_id] = (activation, False) + self.tracker[tensor_id] = (activation, False, None, None) return tensor_id # Check if tensor storage is a model parameter (for FSDP compatibility) @@ -247,7 +247,7 @@ def pack_tensor(activation: torch.Tensor) -> int: check_tensor = activation._local_tensor if check_tensor.untyped_storage().data_ptr() in self.param_storages: - self.tracker[tensor_id] = (activation, False) + self.tracker[tensor_id] = (activation, False, None, None) return tensor_id except (RuntimeError, AttributeError): # If we can't get data_ptr, skip this check @@ -276,11 +276,19 @@ def pack_tensor(activation: torch.Tensor) -> int: else: stream_ctx = torch.cuda.stream(stream) with stream_ctx: + # Save original stride information to restore during unpack + original_stride = activation.stride() + original_storage_offset = activation.storage_offset() + cpu_tensor = torch.empty_like(activation, pin_memory=self.use_pin_memory, device="cpu") cpu_tensor.copy_(activation, non_blocking=True) + + # Store CPU tensor along with original stride information self.tracker[tensor_id] = ( cpu_tensor, True, # True = (in future) modified + original_stride, # Save original GPU stride + original_storage_offset, # Save original storage offset ) if self.use_streams: @@ -308,9 +316,20 @@ def unpack_tensor_single_stream(unpack_tensor_id: int) -> torch.Tensor: if unpack_tensor_id not in self.tracker: raise ValueError(f"Untracked tensor with id {unpack_tensor_id}") - maybe_accelerator_tensor, modified = self.tracker[unpack_tensor_id] + maybe_accelerator_tensor, modified, original_stride, original_storage_offset = self.tracker[ + unpack_tensor_id + ] + if modified: accelerator_tensor = maybe_accelerator_tensor.to(self.accelerator_type, non_blocking=True) + # Restore original stride if we saved it (only for offloaded tensors) + if original_stride is not None: + accelerator_tensor = torch.as_strided( + accelerator_tensor, + size=accelerator_tensor.size(), + stride=original_stride, + storage_offset=original_storage_offset, + ) maybe_accelerator_tensor = accelerator_tensor # clear tensor from tracking @@ -346,7 +365,10 @@ def wait_and_del_remaining_references() -> None: if unpack_tensor_id not in self.tracker: raise ValueError(f"untracked tensor with id {unpack_tensor_id}") - maybe_accelerator_tensor, modified = self.tracker[unpack_tensor_id] + maybe_accelerator_tensor, modified, original_stride, original_storage_offset = self.tracker[ + unpack_tensor_id + ] + if modified: # Get data on the current autograd node graph_id = torch._C._current_graph_task_id() @@ -372,6 +394,14 @@ def wait_and_del_remaining_references() -> None: stream_ctx = torch.cuda.stream(self.s1) with stream_ctx: accelerator_tensor = maybe_accelerator_tensor.to(self.accelerator_type, non_blocking=True) + # Restore original stride if we saved it + if original_stride is not None: + accelerator_tensor = torch.as_strided( + accelerator_tensor, + size=accelerator_tensor.size(), + 
stride=original_stride, + storage_offset=original_storage_offset, + ) maybe_accelerator_tensor = accelerator_tensor # Tell comp stream to wait for the info to be loaded before executing @@ -420,7 +450,9 @@ def hook(outputs, inputs): if unpack_tensor_id in self.bwd_tensor_stash: unpacked_tensor = self.bwd_tensor_stash[unpack_tensor_id] if self.accelerator_type == "npu": - storage_count = torch_npu._C._storage_Use_Count(unpacked_tensor.untyped_storage()._cdata) + storage_count = torch_npu._C._storage_Use_Count( + unpacked_tensor.untyped_storage()._cdata + ) else: storage_count = torch._C._storage_Use_Count(unpacked_tensor.untyped_storage()._cdata) if storage_count > storage_refcount: From 4660b1c20bf5198186591c231ee0cfb3e3cc5dcb Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 17 Oct 2025 14:31:51 +0000 Subject: [PATCH 8/9] handle both broadcast and non-broadcast cases! --- trl/models/activation_offloading.py | 76 +++++++++++++++++++++-------- 1 file changed, 55 insertions(+), 21 deletions(-) diff --git a/trl/models/activation_offloading.py b/trl/models/activation_offloading.py index 7fabf52d667..a8596128e15 100644 --- a/trl/models/activation_offloading.py +++ b/trl/models/activation_offloading.py @@ -213,30 +213,30 @@ def pack_tensor(activation: torch.Tensor) -> int: storage_key = _get_unique_tensor_key(activation) if storage_key in self.storage_to_tensor_id: # Storage already offloaded - don't offload again, just track the reference - self.tracker[tensor_id] = (activation, False, None, None) # Keep on GPU, don't offload + self.tracker[tensor_id] = (activation, False, None, None, None) # Keep on GPU, don't offload return tensor_id # Check if tensor is on CPU (skip offloading) if activation.device.type not in ["cuda", "xpu", "npu"]: - self.tracker[tensor_id] = (activation, False, None, None) + self.tracker[tensor_id] = (activation, False, None, None, None) return tensor_id # Check if tensor is too small if num_bytes < self.min_tensor_size_bytes: - self.tracker[tensor_id] = (activation, False, None, None) + self.tracker[tensor_id] = (activation, False, None, None, None) return tensor_id # Check if tensor is a parameter or buffer if isinstance(activation, torch.nn.Parameter) or ( hasattr(torch.nn, "Buffer") and isinstance(activation, torch.nn.Buffer) ): - self.tracker[tensor_id] = (activation, False, None, None) + self.tracker[tensor_id] = (activation, False, None, None, None) return tensor_id # Check if tensor is an FP8 tensor (TorchAO) - skip offloading as they're already compressed tensor_class_name = type(activation).__name__ if tensor_class_name in ["Float8TrainingTensor", "ScaledMMConfig", "LinearMMConfig"]: - self.tracker[tensor_id] = (activation, False, None, None) + self.tracker[tensor_id] = (activation, False, None, None, None) return tensor_id # Check if tensor storage is a model parameter (for FSDP compatibility) @@ -247,7 +247,7 @@ def pack_tensor(activation: torch.Tensor) -> int: check_tensor = activation._local_tensor if check_tensor.untyped_storage().data_ptr() in self.param_storages: - self.tracker[tensor_id] = (activation, False, None, None) + self.tracker[tensor_id] = (activation, False, None, None, None) return tensor_id except (RuntimeError, AttributeError): # If we can't get data_ptr, skip this check @@ -276,19 +276,43 @@ def pack_tensor(activation: torch.Tensor) -> int: else: stream_ctx = torch.cuda.stream(stream) with stream_ctx: - # Save original stride information to restore during unpack + # Save original stride and shape information original_stride = 
activation.stride() original_storage_offset = activation.storage_offset() + original_shape = activation.size() + + # Check if tensor has broadcast dimensions (stride == 0) + # If so, copy the underlying storage directly instead of materializing the broadcast + has_broadcast = 0 in original_stride + + if has_broadcast: + # Copy only the actual underlying storage, not the materialized broadcast + # Create CPU tensor with same storage size as original + storage_size = activation.untyped_storage().size() + cpu_storage = torch.empty( + storage_size // activation.element_size(), + dtype=activation.dtype, + pin_memory=self.use_pin_memory, + device="cpu", + ) + # Copy the raw storage + cpu_storage_view = torch.as_strided( + activation, size=(storage_size // activation.element_size(),), stride=(1,), storage_offset=0 + ) + cpu_storage.copy_(cpu_storage_view, non_blocking=True) + cpu_tensor = cpu_storage + else: + # No broadcast - use normal contiguous copy + cpu_tensor = torch.empty_like(activation, pin_memory=self.use_pin_memory, device="cpu") + cpu_tensor.copy_(activation, non_blocking=True) - cpu_tensor = torch.empty_like(activation, pin_memory=self.use_pin_memory, device="cpu") - cpu_tensor.copy_(activation, non_blocking=True) - - # Store CPU tensor along with original stride information + # Store CPU tensor along with stride information self.tracker[tensor_id] = ( cpu_tensor, True, # True = (in future) modified original_stride, # Save original GPU stride original_storage_offset, # Save original storage offset + original_shape, # Save original shape for broadcast restoration ) if self.use_streams: @@ -316,17 +340,22 @@ def unpack_tensor_single_stream(unpack_tensor_id: int) -> torch.Tensor: if unpack_tensor_id not in self.tracker: raise ValueError(f"Untracked tensor with id {unpack_tensor_id}") - maybe_accelerator_tensor, modified, original_stride, original_storage_offset = self.tracker[ - unpack_tensor_id - ] + ( + maybe_accelerator_tensor, + modified, + original_stride, + original_storage_offset, + original_shape, + ) = self.tracker[unpack_tensor_id] if modified: + # Restore tensor to GPU accelerator_tensor = maybe_accelerator_tensor.to(self.accelerator_type, non_blocking=True) - # Restore original stride if we saved it (only for offloaded tensors) + # Restore original stride if we saved it (handles both broadcast and non-broadcast cases) if original_stride is not None: accelerator_tensor = torch.as_strided( accelerator_tensor, - size=accelerator_tensor.size(), + size=original_shape, stride=original_stride, storage_offset=original_storage_offset, ) @@ -365,9 +394,13 @@ def wait_and_del_remaining_references() -> None: if unpack_tensor_id not in self.tracker: raise ValueError(f"untracked tensor with id {unpack_tensor_id}") - maybe_accelerator_tensor, modified, original_stride, original_storage_offset = self.tracker[ - unpack_tensor_id - ] + ( + maybe_accelerator_tensor, + modified, + original_stride, + original_storage_offset, + original_shape, + ) = self.tracker[unpack_tensor_id] if modified: # Get data on the current autograd node @@ -393,12 +426,13 @@ def wait_and_del_remaining_references() -> None: else: stream_ctx = torch.cuda.stream(self.s1) with stream_ctx: + # Restore tensor to GPU accelerator_tensor = maybe_accelerator_tensor.to(self.accelerator_type, non_blocking=True) - # Restore original stride if we saved it + # Restore original stride if we saved it (handles both broadcast and non-broadcast cases) if original_stride is not None: accelerator_tensor = torch.as_strided( 
accelerator_tensor, - size=accelerator_tensor.size(), + size=original_shape, stride=original_stride, storage_offset=original_storage_offset, ) From fcd0e127746364cbcd621bd67390394b031fc950 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 21 Oct 2025 09:52:22 +0000 Subject: [PATCH 9/9] fix test --- tests/test_activation_offloading.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_activation_offloading.py b/tests/test_activation_offloading.py index 33a5c2f825e..769ef5e3b2b 100644 --- a/tests/test_activation_offloading.py +++ b/tests/test_activation_offloading.py @@ -182,8 +182,8 @@ def forward(self, x): assert total_tensor_ids > 0, "Should have created tensor IDs" # modified=True means offloaded to CPU, modified=False means kept on GPU (deduplicated) - deduplicated_count = sum(1 for _, modified in offload_ctx.tracker.values() if not modified) - offloaded_count = sum(1 for _, modified in offload_ctx.tracker.values() if modified) + deduplicated_count = sum(1 for _, modified, _, _, _ in offload_ctx.tracker.values() if not modified) + offloaded_count = sum(1 for _, modified, _, _, _ in offload_ctx.tracker.values() if modified) assert offloaded_count > 0, "Should have offloaded at least one tensor" assert deduplicated_count > 0, "Should have deduplicated at least one tensor (view)"