- 
                Notifications
    You must be signed in to change notification settings 
- Fork 2.3k
[Activation-checkpointing] add tensor dedup and param offloading #4247
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 4 commits
c2f840d
              d4bf577
              d86e825
              f70a613
              6dd3fc6
              3fab796
              14bb4c5
              a463d40
              3a29eb8
              254a34d
              4660b1c
              e3f0d3c
              fa35f8d
              fcd0e12
              ecadf3d
              File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
|  | @@ -30,10 +30,31 @@ | |
| if is_torch_npu_available(): | ||
| import torch_npu # noqa: F401 | ||
|  | ||
| # Try to import DTensor for FSDP v2 support | ||
| try: | ||
| from torch.distributed._tensor import DTensor | ||
| except (ImportError, AttributeError): | ||
| DTensor = None | ||
|  | ||
| logger = logging.get_logger(__name__) | ||
|  | ||
|  | ||
| def _get_unique_tensor_key(tensor: torch.Tensor) -> tuple: | ||
| """ | ||
| Get a unique key for a tensor based on its storage pointer and dtype. This allows deduplication of tensors that | ||
| share the same underlying storage. From: | ||
| https://github.com/volcengine/verl/blob/main/verl/utils/activation_offload.py | ||
|  | ||
| Args: | ||
| tensor: The tensor to get the key for | ||
|  | ||
| Returns: | ||
| A tuple of (storage_pointer, dtype) that uniquely identifies the tensor's storage | ||
| """ | ||
| storage_ptr = tensor.untyped_storage().data_ptr() + tensor.storage_offset() | ||
|          | ||
| return (storage_ptr, tensor.dtype) | ||
|  | ||
|  | ||
| class OffloadActivations(saved_tensors_hooks): | ||
| """ | ||
| Context manager under which activation tensors created in the forward pass will be offloaded. | ||
|  | @@ -91,6 +112,12 @@ def __init__( | |
| self.is_first_backward_call = True | ||
| self.is_first_forward_pass = True | ||
|  | ||
| # Storage deduplication: maps storage key to tensor_id to avoid offloading same storage multiple times | ||
| self.storage_to_tensor_id = {} | ||
|  | ||
| # Parameter filtering: track parameter storage pointers to skip them during offloading | ||
| self.param_storages = set() | ||
|  | ||
| # Managing cpu memory | ||
| self.use_pin_memory = use_pin_memory | ||
| self.virtual_memory_safe_pct = 60 # we should not exceed this percentage of memory | ||
|  | @@ -152,60 +179,82 @@ def pack_tensor(activation: torch.Tensor) -> int: | |
| # set training phase trackers | ||
| self.is_first_forward_call = False | ||
| self.is_first_backward_call = True | ||
| # Reset deduplication map for new forward pass | ||
| self.storage_to_tensor_id = {} | ||
|  | ||
| # query for basic tensor info | ||
| num_bytes = get_num_bytes_tensor(activation) | ||
| tensor_id = get_tensor_id() | ||
|  | ||
| # only offload hefty bois if they're activations on CUDA (our heuristic | ||
| # for that is to check if they're not params or buffers)! | ||
| if ( | ||
| activation.device.type in ["cuda", "xpu", "npu"] | ||
| and num_bytes >= self.min_tensor_size_bytes | ||
| and ( | ||
| not isinstance(activation, torch.nn.Parameter) | ||
| and not (hasattr(torch.nn, "Buffer") and isinstance(activation, torch.nn.Buffer)) | ||
| ) | ||
| # Check for tensor deduplication using storage pointer | ||
| # If this storage is already being tracked, we still create a new tensor_id | ||
| # but don't offload again (just keep the tensor in GPU) | ||
| storage_key = _get_unique_tensor_key(activation) | ||
| if storage_key in self.storage_to_tensor_id: | ||
| # Storage already offloaded - don't offload again, just track the reference | ||
| self.tracker[tensor_id] = (activation, False) # Keep on GPU, don't offload | ||
| return tensor_id | ||
|  | ||
| # Check if tensor is on CPU (skip offloading) | ||
| if activation.device.type not in ["cuda", "xpu", "npu"]: | ||
| self.tracker[tensor_id] = (activation, False) | ||
| return tensor_id | ||
|  | ||
| # Check if tensor is too small | ||
| if num_bytes < self.min_tensor_size_bytes: | ||
| self.tracker[tensor_id] = (activation, False) | ||
| return tensor_id | ||
|  | ||
| # Check if tensor is a parameter or buffer | ||
| if isinstance(activation, torch.nn.Parameter) or ( | ||
| hasattr(torch.nn, "Buffer") and isinstance(activation, torch.nn.Buffer) | ||
| There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same question here, is Buffer a recent addition? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. no buffer has always been there from the start, I can clean this up | ||
| ): | ||
| if self.use_streams: | ||
| # First, sync back and dereference previously offloaded tensors | ||
| # as the offloading should be done sufficiently long ago. | ||
| for id in list(self.fwd_stash.keys()): | ||
| if id <= tensor_id - self.max_fwd_stash_size: | ||
| _, ev = self.fwd_stash[id] | ||
| self.s0.wait_event(ev) | ||
| del self.fwd_stash[id] | ||
| else: | ||
| break | ||
| self.tracker[tensor_id] = (activation, False) | ||
| return tensor_id | ||
|  | ||
| # Check if tensor storage is a model parameter (for FSDP compatibility) | ||
| if activation.untyped_storage().data_ptr() in self.param_storages: | ||
| self.tracker[tensor_id] = (activation, False) | ||
| return tensor_id | ||
|  | ||
| # Tensor qualifies for offloading | ||
| if self.use_streams: | ||
| # First, sync back and dereference previously offloaded tensors | ||
| # as the offloading should be done sufficiently long ago. | ||
| for id in list(self.fwd_stash.keys()): | ||
| if id <= tensor_id - self.max_fwd_stash_size: | ||
| _, ev = self.fwd_stash[id] | ||
| self.s0.wait_event(ev) | ||
| del self.fwd_stash[id] | ||
| else: | ||
| break | ||
|  | ||
| # Sync in, offload, and add an event to sync back later | ||
| self.s1.wait_stream(self.s0) | ||
| # Sync in, offload, and add an event to sync back later | ||
| self.s1.wait_stream(self.s0) | ||
|  | ||
| stream = self.s1 if self.use_streams else self.s0 | ||
| if self.accelerator_type == "xpu": | ||
| stream_ctx = torch.xpu.stream(stream) | ||
| elif self.accelerator_type == "npu": | ||
| stream_ctx = torch.npu.stream(stream) | ||
| else: | ||
| stream_ctx = torch.cuda.stream(stream) | ||
| with stream_ctx: | ||
| cpu_tensor = torch.empty_like(activation, pin_memory=self.use_pin_memory, device="cpu") | ||
| cpu_tensor.copy_(activation, non_blocking=True) | ||
| self.tracker[tensor_id] = ( | ||
| cpu_tensor, | ||
| True, # True = (in future) modified | ||
| ) | ||
|  | ||
| if self.use_streams: | ||
| event = self.s1.record_event() | ||
|  | ||
| # Stash to keep activation alive til s1 is done | ||
| self.fwd_stash[tensor_id] = (activation, event) | ||
| stream = self.s1 if self.use_streams else self.s0 | ||
| if self.accelerator_type == "xpu": | ||
| stream_ctx = torch.xpu.stream(stream) | ||
| elif self.accelerator_type == "npu": | ||
| stream_ctx = torch.npu.stream(stream) | ||
| else: | ||
| stream_ctx = torch.cuda.stream(stream) | ||
| with stream_ctx: | ||
| cpu_tensor = torch.empty_like(activation, pin_memory=self.use_pin_memory, device="cpu") | ||
| cpu_tensor.copy_(activation, non_blocking=True) | ||
| self.tracker[tensor_id] = ( | ||
| activation, | ||
| False, | ||
| ) # False = not modified, tensor is as is | ||
| cpu_tensor, | ||
| True, # True = (in future) modified | ||
| ) | ||
|  | ||
| if self.use_streams: | ||
| event = self.s1.record_event() | ||
|  | ||
| # Stash to keep activation alive til s1 is done | ||
| self.fwd_stash[tensor_id] = (activation, event) | ||
|  | ||
| # Track this storage for deduplication | ||
| self.storage_to_tensor_id[storage_key] = tensor_id | ||
|  | ||
| return tensor_id | ||
|  | ||
|  | @@ -368,6 +417,36 @@ def hook(outputs, inputs): | |
| unpack_tensor = unpack_tensor_with_streams if self.use_streams else unpack_tensor_single_stream | ||
| super().__init__(pack_tensor, unpack_tensor) | ||
|  | ||
| def update_model_params(self, model: nn.Module): | ||
| """ | ||
| Update the set of parameter storage pointers from the model. This allows filtering out model parameters during | ||
| offloading, which is especially important for FSDP models where parameters may not be detected by isinstance | ||
| checks. | ||
|  | ||
| For FSDP v2, this method handles DTensor parameters which may be sharded across ranks and not have valid local | ||
| storage on all ranks. We extract the local tensor from DTensors using _local_tensor when available. | ||
|  | ||
| Args: | ||
| model: The model whose parameters should be tracked | ||
| """ | ||
| param_storages = set() | ||
|  | ||
| for p in model.parameters(): | ||
| # For FSDP v2: extract local tensor from DTensor | ||
| actual_tensor = p | ||
| if DTensor is not None and isinstance(p, DTensor) and hasattr(p, "_local_tensor"): | ||
| actual_tensor = p._local_tensor | ||
| There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Again, something to care for. If fp8 is used, it can return 0, viz here: https://github.com/huggingface/accelerate/blob/f0313a64a2f3de359924c85a98ee010c47b846ec/src/accelerate/accelerator.py#L3842 | ||
|  | ||
| # Try to get storage pointer | ||
| try: | ||
| storage_ptr = actual_tensor.untyped_storage().data_ptr() | ||
| param_storages.add(storage_ptr) | ||
| except RuntimeError: | ||
| # Parameter doesn't have accessible storage (e.g., FSDP v2 sharded without local shard) | ||
| continue | ||
|  | ||
| self.param_storages = param_storages | ||
|  | ||
|  | ||
| class NoOpManager(saved_tensors_hooks): | ||
| """ | ||
|  | @@ -433,6 +512,9 @@ def get_act_offloading_ctx_manager( | |
| max_fwd_stash_size=max_fwd_stash_size, | ||
| ) | ||
|  | ||
| # Update parameter storages to filter them during offloading (important for FSDP) | ||
| activations_handling_ctx.update_model_params(model) | ||
|  | ||
| # Below is our hack to disable offloading the last output Linear in every | ||
| # step, as the cost for offloading the activation and then soon after bringing | ||
| # it back is expensive. | ||
|  | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do you know in which version DTensor was introduced? I'm wondering is this try/expect is needed
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For reference we have the following in accelerate:
but we also need to check for
torch.distributed.is_available(), otherwise you might get import issue.