Fix #2826: implement gradient checkpoint callbacks #2860
Changes from all commits: 5e7f03d, 0f48774, ed6fa30, 5650c76, 28a3571, a66a4cb, 8e02e15
@@ -23,6 +23,7 @@
 import torch
 from torch import nn
+from transformers.modeling_layers import GradientCheckpointingLayer

 from peft.import_utils import is_bnb_4bit_available, is_bnb_available
 from peft.tuners.tuners_utils import (
@@ -351,13 +352,48 @@ def _enable_peft_forward_hooks(self, *args, **kwargs):
         # If adapter_names is passed as an argument, we inject it into the forward arguments.
         adapter_names = kwargs.pop("adapter_names", None)
         alora_offsets = kwargs.pop("alora_offsets", None)

         if adapter_names is None and alora_offsets is None:
             # nothing to do
             yield
             return
         hook_handles = []

         if alora_offsets is not None:
-            for layer in self.modules():
+            for n, layer in self.named_modules():
+                # gradient checkpointing layers are executed concurrently to the 'normal' forward call
+                # (in the backward step the gradient checkpointing layer's forward will be executed again).
+                # this means that when the gradient checkpointing layer is called, the _enable_peft_forward_hooks
+                # context manager is long gone. to be consistent with the normal forward we need to register the pre
+                # hooks for this concurrent forward call as well.
+                #
+                # Note that this will lead to double application of whatever the callbacks do in normal forward.
+                # Make sure that whatever change is done can be applied more than once without harm (idempotency).
+                if isinstance(layer, GradientCheckpointingLayer) and layer.gradient_checkpointing:
+
+                    def forward_pre_hook(name, module, inputs, **kwargs):
+                        for submodule in module.modules():
+                            if isinstance(submodule, LoraLayer):
+                                handle = submodule.register_forward_pre_hook(
+                                    partial(_alora_offsets_pre_forward_hook, alora_offsets=kwargs["alora_offsets"]),
+                                    with_kwargs=True,
+                                )
+                                module._peft_gradient_checkpointing_forward_hooks.append(handle)
+
+                    def backward_hook(name, module, *grad_output, **kwargs):
+                        while module._peft_gradient_checkpointing_forward_hooks:
+                            module._peft_gradient_checkpointing_forward_hooks.pop().remove()
+
+                    if getattr(layer, "_peft_gradient_checkpointing_forward_hooks", []):
+                        raise ValueError(
+                            "Multiple invocations of PEFT forward hooks before .backward() with enabled gradient "
+                            "checkpointing. Disable gradient checkpointing or only call forward once per backward."
+                        )
+                    layer._peft_gradient_checkpointing_forward_hooks = []
Review comment (on `layer._peft_gradient_checkpointing_forward_hooks = []`): Should we check for pre-existing …

Reply: I think we can keep it as is; once there are more methods that require access, we can move to method-specific entries in a dictionary instead of a global list.
+
+                    handle = layer.register_forward_pre_hook(partial(forward_pre_hook, n, alora_offsets=alora_offsets))
+                    layer._peft_gradient_checkpointing_forward_hooks.append(handle)
+                    handle = layer.register_full_backward_hook(partial(backward_hook, n))
+                    layer._peft_gradient_checkpointing_forward_hooks.append(handle)
                 if isinstance(layer, LoraLayer):
                     pre_forward = partial(_alora_offsets_pre_forward_hook, alora_offsets=alora_offsets)
                     handle = layer.register_forward_pre_hook(pre_forward, with_kwargs=True)
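The comment in the diff hinges on how activation checkpointing behaves: the checkpointed layer's forward (and therefore any forward pre-hook registered on it) runs a second time during `.backward()`, long after a surrounding context manager has exited. A minimal, self-contained sketch of that behavior, not part of the PR (the toy `Block` module and `calls` list are made up for illustration):

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

calls = []


class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)


block = Block()
# The pre-hook fires on every invocation of block's forward, including recomputation.
block.register_forward_pre_hook(lambda module, inputs: calls.append("forward"))

x = torch.randn(2, 4, requires_grad=True)
out = checkpoint(block, x, use_reentrant=False)  # forward #1, activations are discarded
out.sum().backward()                             # forward #2 during recomputation, hook fires again

print(calls)  # ['forward', 'forward']
```

This is why the PR both re-registers the aLoRA pre-hooks from inside the checkpointed layer's own forward pre-hook and insists that whatever the hooks do must be idempotent: the same callback effect is applied once in the normal forward and once more during the recomputation triggered by the backward pass.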
Review comment: `_enable_peft_forward_hooks` is becoming quite complex at this point. It could be worth refactoring it into the mixed-batch part and the aLoRA part. Not necessarily in this PR, but it could be done later.
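As a rough illustration of that suggestion, here is one hypothetical shape such a split could take. This is a sketch under assumptions, not actual PEFT code: the `_register_mixed_batch_hooks` and `_register_alora_offset_hooks` helpers are invented names, and the real method also handles the gradient-checkpointing bookkeeping shown in the diff above.

```python
from contextlib import contextmanager


class PeftModelSketch:
    def _register_mixed_batch_hooks(self, adapter_names):
        # hypothetical helper: would register the adapter_names injection pre-hooks
        # and return their handles
        return []

    def _register_alora_offset_hooks(self, alora_offsets):
        # hypothetical helper: would register the aLoRA offset pre-hooks, including
        # the gradient-checkpointing re-registration logic, and return their handles
        return []

    @contextmanager
    def _enable_peft_forward_hooks(self, *args, **kwargs):
        adapter_names = kwargs.pop("adapter_names", None)
        alora_offsets = kwargs.pop("alora_offsets", None)
        handles = []
        if adapter_names is not None:
            handles += self._register_mixed_batch_hooks(adapter_names)
        if alora_offsets is not None:
            handles += self._register_alora_offset_hooks(alora_offsets)
        try:
            yield
        finally:
            # hooks registered for this forward call are always cleaned up on exit
            for handle in handles:
                handle.remove()
```

The context manager then only orchestrates and cleans up, while each concern (mixed adapter batches vs. aLoRA offsets) lives in its own helper that can be read and tested in isolation.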