21 changes: 20 additions & 1 deletion src/peft/peft_model.py
@@ -142,6 +142,8 @@ def __init__(
if hasattr(self.base_model, "config") and hasattr(self.base_model.config, "pretraining_tp"):
self.base_model.config.pretraining_tp = 1

self._adapters_disabled = False

@property
def peft_config(self) -> dict[str, PeftConfig]:
if self._is_prompt_learning:
@@ -167,6 +169,17 @@ def active_adapters(self) -> list[str]:
adapters = [adapters]
return adapters

@property
def has_active_enabled_adapter(self) -> bool:
"""Reflects whether the adapters are purposefully disabled (via disable_adapter) or if there
are no active adapters (enabled but inactive). They are two separate mechanisms but sometimes it is helpful to
know whether the model has any active/enabled adapter at all.
"""
if self.peft_config[self.active_adapter].is_prompt_learning:
return not self._adapters_disabled

return not self._adapters_disabled or not self.active_adapters

@peft_config.setter
def peft_config(self, value: dict[str, PeftConfig]):
if self._is_prompt_learning:
@@ -890,7 +903,7 @@ def __getattr__(self, name: str):
def _enable_peft_forward_hooks(self, *args, **kwargs):
# If the base model has a method called _enable_peft_forward_hooks, it is invoked as a context. Otherwise, this
# runs without any changes
if hasattr(self.base_model, "_enable_peft_forward_hooks"):
if hasattr(self.base_model, "_enable_peft_forward_hooks") and self.has_active_enabled_adapter:
with self.base_model._enable_peft_forward_hooks(*args, **kwargs):
yield
return
@@ -940,17 +953,21 @@ def disable_adapter(self):
self.forward = self.base_model.forward
old_prepare_inputs_for_generation = self.prepare_inputs_for_generation
self.prepare_inputs_for_generation = self.base_model.prepare_inputs_for_generation
self._adapters_disabled = True
yield
finally:
self.forward = old_forward
self.prepare_inputs_for_generation = old_prepare_inputs_for_generation
self._adapters_disabled = False

elif self.peft_config[self.active_adapter].is_adaption_prompt:
try:
self.base_model.disable_adapter_layers()
self._adapters_disabled = True
yield
finally:
self.base_model.enable_adapter_layers()
self._adapters_disabled = False

else: # LoRA, LoHa, etc.
model_status = self.get_model_status()
@@ -962,11 +979,13 @@
)
try:
self.base_model.disable_adapter_layers()
self._adapters_disabled = True
yield
finally:
if model_status.enabled is not False:
# model_status.enabled is `True` or `"irregular"`
self.base_model.enable_adapter_layers()
self._adapters_disabled = False

def get_base_model(self) -> torch.nn.Module:
"""
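To illustrate how the new `_adapters_disabled` flag surfaces through `has_active_enabled_adapter`, here is a minimal usage sketch (not part of the PR; the tiny test model is just a placeholder, and the asserts assume the property behaves as defined above):

```python
from transformers import AutoModelForCausalLM
from peft import LoraConfig, TaskType, get_peft_model

base = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-random-LlamaForCausalLM")
config = LoraConfig(task_type=TaskType.CAUSAL_LM, target_modules="all-linear")
model = get_peft_model(base, config)

# the default adapter is active and enabled
assert model.has_active_enabled_adapter

with model.disable_adapter():
    # inside the context, _adapters_disabled is True, so the property reports False
    # and _enable_peft_forward_hooks skips registering any PEFT forward hooks
    assert not model.has_active_enabled_adapter

# restored when the context exits
assert model.has_active_enabled_adapter
```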
2 changes: 2 additions & 0 deletions src/peft/tuners/cpt/model.py
@@ -53,6 +53,8 @@ def __init__(self, config, word_embeddings):
word_embedding_weights = word_embedding_weights.to(torch.float32)
self.embedding.weight = torch.nn.Parameter(word_embedding_weights)

self.embedding.requires_grad_(False)

# Initialize delta embedding with zero weights
self.delta_embedding = torch.nn.Embedding(num_virtual_tokens, config.token_dim)
self.delta_embedding.weight.data = torch.zeros_like(self.delta_embedding.weight).to(torch.float32)
38 changes: 37 additions & 1 deletion src/peft/tuners/lora/model.py
Member:
The _enable_peft_forward_hooks method is becoming quite complex at this point. It could be worth refactoring it into the mixed-batch part and the aLoRA part. Not necessarily in this PR, but it could be done later.
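For illustration, a rough sketch of what such a split might look like; the helper names `_mixed_batch_forward_hooks` and `_alora_forward_hooks` are hypothetical and not part of this PR:

```python
from contextlib import ExitStack, contextmanager

@contextmanager
def _enable_peft_forward_hooks(self, *args, **kwargs):
    adapter_names = kwargs.pop("adapter_names", None)
    alora_offsets = kwargs.pop("alora_offsets", None)
    with ExitStack() as stack:
        if adapter_names is not None:
            # mixed-batch inference: per-sample adapter routing
            stack.enter_context(self._mixed_batch_forward_hooks(adapter_names))
        if alora_offsets is not None:
            # activated LoRA: pass offsets down to each LoraLayer
            stack.enter_context(self._alora_forward_hooks(alora_offsets))
        yield
```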

@@ -23,6 +23,7 @@

import torch
from torch import nn
from transformers.modeling_layers import GradientCheckpointingLayer

from peft.import_utils import is_bnb_4bit_available, is_bnb_available
from peft.tuners.tuners_utils import (
@@ -351,13 +352,48 @@ def _enable_peft_forward_hooks(self, *args, **kwargs):
# If adapter_names is passed as an argument, we inject it into the forward arguments.
adapter_names = kwargs.pop("adapter_names", None)
alora_offsets = kwargs.pop("alora_offsets", None)

if adapter_names is None and alora_offsets is None:
# nothing to do
yield
return
hook_handles = []

if alora_offsets is not None:
for layer in self.modules():
for n, layer in self.named_modules():
# gradient checkpointing layers run their forward again outside of the 'normal' forward call
# (during the backward step, the gradient checkpointing layer's forward is re-executed).
# this means that by the time that re-execution happens, the _enable_peft_forward_hooks
# context manager is long gone. to be consistent with the normal forward, we need to register the pre
# hooks for this re-executed forward call as well.
#
# Note that this will lead to double application of whatever the callbacks do in the normal forward.
# Make sure that whatever change is done can be applied more than once without harm (idempotency).
if isinstance(layer, GradientCheckpointingLayer) and layer.gradient_checkpointing:

def forward_pre_hook(name, module, inputs, **kwargs):
for submodule in module.modules():
if isinstance(submodule, LoraLayer):
handle = submodule.register_forward_pre_hook(
partial(_alora_offsets_pre_forward_hook, alora_offsets=kwargs["alora_offsets"]),
with_kwargs=True,
)
module._peft_gradient_checkpointing_forward_hooks.append(handle)

def backward_hook(name, module, *grad_output, **kwargs):
while module._peft_gradient_checkpointing_forward_hooks:
module._peft_gradient_checkpointing_forward_hooks.pop().remove()

if getattr(layer, "_peft_gradient_checkpointing_forward_hooks", []):
raise ValueError(
"Multiple invocations of PEFT forward hooks before .backward() with enabled gradient "
"checkpointing. Disable gradient checkpointing or only call forward once per backward."
)
layer._peft_gradient_checkpointing_forward_hooks = []
Member:

Should we check for pre-existing _peft_gradient_checkpointing_forward_hooks? Right now, we know they don't exist, but since this is supposed to be a general solution (not aLoRA-specific), I'd say it's safer to check.

Collaborator Author:

I think we can keep it as is; once there are more methods that require access, we can move to method-specific entries in a dictionary instead of a global list.
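A rough sketch of what such method-specific bookkeeping could look like (hypothetical, not part of this PR):

```python
from collections import defaultdict

def _register_gc_hook(layer, feature, handle):
    # feature could be e.g. "alora" or "mixed_batch"
    if not hasattr(layer, "_peft_gradient_checkpointing_hooks"):
        layer._peft_gradient_checkpointing_hooks = defaultdict(list)
    layer._peft_gradient_checkpointing_hooks[feature].append(handle)

def _remove_gc_hooks(layer, feature):
    # remove only the handles registered by the given feature
    hooks = getattr(layer, "_peft_gradient_checkpointing_hooks", {})
    while hooks.get(feature):
        hooks[feature].pop().remove()
```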

handle = layer.register_forward_pre_hook(partial(forward_pre_hook, n, alora_offsets=alora_offsets))
layer._peft_gradient_checkpointing_forward_hooks.append(handle)
handle = layer.register_full_backward_hook(partial(backward_hook, n))
layer._peft_gradient_checkpointing_forward_hooks.append(handle)
if isinstance(layer, LoraLayer):
pre_forward = partial(_alora_offsets_pre_forward_hook, alora_offsets=alora_offsets)
handle = layer.register_forward_pre_hook(pre_forward, with_kwargs=True)
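As background for the hook bookkeeping above, a small self-contained sketch (plain PyTorch, independent of PEFT) showing that a checkpointed module's forward is executed a second time during the backward pass, which is why hooks registered only while the _enable_peft_forward_hooks context is open would otherwise be missing:

```python
import torch
from torch.utils.checkpoint import checkpoint

calls = []

class Block(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        calls.append("forward")  # record every execution of this forward
        return self.linear(x).relu()

block = Block()
x = torch.randn(2, 4, requires_grad=True)

out = checkpoint(block, x, use_reentrant=False)
out.sum().backward()

# the forward ran twice: once in the forward pass, once recomputed during backward
print(calls)  # ['forward', 'forward']
```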
7 changes: 5 additions & 2 deletions tests/test_custom_models.py
@@ -1802,8 +1802,11 @@ def test_training_custom_models_layer_indexing(self, test_name, model_id, config
pass

@pytest.mark.parametrize("test_name, model_id, config_cls, config_kwargs", TEST_CASES)
def test_training_custom_models_gradient_checkpointing(self, test_name, model_id, config_cls, config_kwargs):
self._test_training_gradient_checkpointing(model_id, config_cls, config_kwargs)
@pytest.mark.parametrize("use_reentrant", [True, False])
def test_training_custom_models_gradient_checkpointing(
self, test_name, model_id, config_cls, config_kwargs, use_reentrant
):
self._test_training_gradient_checkpointing(model_id, config_cls, config_kwargs, use_reentrant=use_reentrant)

@pytest.mark.parametrize("test_name, model_id, config_cls, config_kwargs", TEST_CASES)
def test_inference_safetensors(self, test_name, model_id, config_cls, config_kwargs):
7 changes: 5 additions & 2 deletions tests/test_decoder_models.py
@@ -526,9 +526,12 @@ def test_training_decoders_layer_indexing(self, model_id, config_cls, config_kwa

@pytest.mark.parametrize("model_id", PEFT_DECODER_MODELS_TO_TEST)
@pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS)
def test_training_decoders_gradient_checkpointing(self, model_id, config_cls, config_kwargs):
@pytest.mark.parametrize("use_reentrant", [True, False])
def test_training_decoders_gradient_checkpointing(self, model_id, config_cls, config_kwargs, use_reentrant):
_skip_if_not_conv1d_supported(model_id, config_cls)
self._test_training_gradient_checkpointing(model_id, config_cls, config_kwargs.copy())
self._test_training_gradient_checkpointing(
model_id, config_cls, config_kwargs.copy(), use_reentrant=use_reentrant
)

@pytest.mark.parametrize("model_id", PEFT_DECODER_MODELS_TO_TEST)
@pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS)
7 changes: 5 additions & 2 deletions tests/test_encoder_decoder_models.py
@@ -339,8 +339,11 @@ def test_training_encoder_decoders_layer_indexing(self, model_id, config_cls, co

@pytest.mark.parametrize("model_id", PEFT_ENCODER_DECODER_MODELS_TO_TEST)
@pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS)
def test_training_encoder_decoders_gradient_checkpointing(self, model_id, config_cls, config_kwargs):
self._test_training_gradient_checkpointing(model_id, config_cls, config_kwargs)
@pytest.mark.parametrize("use_reentrant", [True, False])
def test_training_encoder_decoders_gradient_checkpointing(
self, model_id, config_cls, config_kwargs, use_reentrant
):
self._test_training_gradient_checkpointing(model_id, config_cls, config_kwargs, use_reentrant=use_reentrant)

@pytest.mark.parametrize("model_id", PEFT_ENCODER_DECODER_MODELS_TO_TEST)
@pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS)
5 changes: 3 additions & 2 deletions tests/test_feature_extraction_models.py
@@ -330,9 +330,10 @@ def test_training_layer_indexing(self, model_id, config_cls, config_kwargs):

@pytest.mark.parametrize("model_id", PEFT_FEATURE_EXTRACTION_MODELS_TO_TEST)
@pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS)
def test_training_gradient_checkpointing(self, model_id, config_cls, config_kwargs):
@pytest.mark.parametrize("use_reentrant", [True, False])
def test_training_gradient_checkpointing(self, model_id, config_cls, config_kwargs, use_reentrant):
skip_deberta_lora_tests(config_cls, model_id)
self._test_training_gradient_checkpointing(model_id, config_cls, config_kwargs)
self._test_training_gradient_checkpointing(model_id, config_cls, config_kwargs, use_reentrant=use_reentrant)

@pytest.mark.parametrize("model_id", PEFT_FEATURE_EXTRACTION_MODELS_TO_TEST)
@pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS)
55 changes: 52 additions & 3 deletions tests/test_lora_variants.py
@@ -15,8 +15,9 @@
import pytest
import torch
from torch import nn
from transformers import AutoModelForCausalLM

from peft import LoraConfig, get_peft_model
from peft import LoraConfig, TaskType, get_peft_model
from peft.tuners.lora.layer import Conv1d as LoraConv1d
from peft.tuners.lora.layer import Conv2d as LoraConv2d
from peft.tuners.lora.layer import Embedding as LoraEmbedding
@@ -32,6 +33,8 @@
get_alora_offsets_for_generate,
)

from .testing_common import hub_online_once


# Custom model featuring embeddings and a 'visual stack'
class CustomModel(nn.Module):
@@ -73,6 +76,9 @@ def __init__(self, vocab_size: int = 10, hidden_dim: int = 8):
self.embed = nn.Embedding(vocab_size, hidden_dim)
self.linear = nn.Linear(hidden_dim, vocab_size)

def prepare_inputs_for_generation(self, *args, **kwargs):
return kwargs

def forward(self, X=None, embeds=None, num_beams=None, alora_offsets=None):
if X is not None:
embeds = self.embed(X)
@@ -181,7 +187,7 @@ class TestActivatedLora:
)
# Verify alora_offsets are calculated correctly
def test_calculate_alora_offsets(self, input_ids, alora_invocation_tokens, expected_offsets):
config = LoraConfig(alora_invocation_tokens=alora_invocation_tokens)
config = LoraConfig(task_type=TaskType.CAUSAL_LM, alora_invocation_tokens=alora_invocation_tokens)
peft_config = {"default": config}

# compute offsets
@@ -233,7 +239,12 @@ def test_alora_activation_matches_base_until_invocation(self):
def test_input_embeds_warning(self):
transformers_class = MockTransformerWrapper
base_model = transformers_class.from_pretrained()
cfg = LoraConfig(target_modules=["linear"], alora_invocation_tokens=[2], init_lora_weights=False)
cfg = LoraConfig(
task_type=TaskType.CAUSAL_LM,
target_modules=["linear"],
alora_invocation_tokens=[2],
init_lora_weights=False,
)
lora_model = get_peft_model(base_model, cfg)
lora_model.eval()

@@ -265,3 +276,41 @@ def test_num_beams_error(self):
with torch.no_grad():
lora_out = lora_model(X=input_ids, num_beams=2, alora_offsets=[3])
assert "Beam search not yet supported for aLoRA." in str(e.value)

def test_gradient_checkpointing_double_forward_raises(self):
model_id = "trl-internal-testing/tiny-random-LlamaForCausalLM"

with hub_online_once(model_id):
base_model = AutoModelForCausalLM.from_pretrained(model_id)
cfg = LoraConfig(task_type=TaskType.CAUSAL_LM, target_modules="all-linear", alora_invocation_tokens=[0])
lora_model = get_peft_model(base_model, cfg)
lora_model.train()

lora_model.prepare_model_for_gradient_checkpointing(lora_model)
lora_model.gradient_checkpointing_enable()

inputs = {"input_ids": torch.tensor([[0, 1, 2, 3]])}

lora_model.forward(**inputs)

with pytest.raises(ValueError, match="Multiple invocations of PEFT forward hooks.*"):
lora_model.forward(**inputs)

def test_gradient_checkpointing_dpo_doesnt_raise(self):
model_id = "trl-internal-testing/tiny-random-LlamaForCausalLM"

with hub_online_once(model_id):
base_model = AutoModelForCausalLM.from_pretrained(model_id)
cfg = LoraConfig(task_type=TaskType.CAUSAL_LM, target_modules="all-linear", alora_invocation_tokens=[0])
lora_model = get_peft_model(base_model, cfg)
lora_model.train()

lora_model.prepare_model_for_gradient_checkpointing(lora_model)
lora_model.gradient_checkpointing_enable()

inputs = {"input_ids": torch.tensor([[0, 1, 2, 3]])}

with lora_model.disable_adapter():
lora_model.forward(**inputs)

lora_model.forward(**inputs)
46 changes: 42 additions & 4 deletions tests/testing_common.py
@@ -25,6 +25,7 @@

import pytest
import torch
import transformers
import yaml
from diffusers import StableDiffusionPipeline
from packaging import version
@@ -1315,41 +1316,78 @@ def _test_training_layer_indexing(self, model_id, config_cls, config_kwargs):
# more than 1 layer, i.e. setting layers_to_transform=[0] should target fewer layers
assert nb_trainable < nb_trainable_all

def _test_training_gradient_checkpointing(self, model_id, config_cls, config_kwargs):
def _test_training_gradient_checkpointing(self, model_id, config_cls, config_kwargs, use_reentrant=True):
# Note that certain configurations, such as activated LoRA with 'alora_invocation_tokens': [1000], do not
# generate gradients since the adapter is never activated, so this test is effectively a no-op for them. It
# is still a valid test, but it might be confusing to see it pass when it does not really exercise anything.

if config_cls == PrefixTuningConfig:
return pytest.skip(f"Test not applicable for {config_cls}")

if (config_cls == AdaLoraConfig) and ("roberta" in model_id.lower()):
# TODO: no gradients on the "dense" layer, other layers work, not sure why
self.skipTest("AdaLora with RoBERTa does not work correctly")

if "bart" in model_id.lower() and version.parse(transformers.__version__) <= version.parse("5.0"):
self.skipTest(
"Bart in transformers < 5.0 doesn't handle input sharing well enough. See transformers#41821"
)

if (config_cls == OFTConfig) and ("deberta" in model_id.lower()):
# TODO: no gradients on the "dense" layer, other layers work, not sure why
self.skipTest("OFT with Deberta does not work correctly")

if "gptbigcode" in model_id.lower():
self.skipTest("GPTBigCode currently doesn't implement gradient checkpointing correctly.")

with hub_online_once(model_id):
model = self.transformers_class.from_pretrained(model_id)

if not getattr(model, "supports_gradient_checkpointing", False):
return pytest.skip(f"Model {model_id} does not support gradient checkpointing")

model.gradient_checkpointing_enable()
# Disable lora_dropout and friends to remove non-determinism in gradient creation
for key in list(config_kwargs.keys()):
if key.endswith("dropout"):
del config_kwargs[key]

config = config_cls(
base_model_name_or_path=model_id,
**config_kwargs,
)
model = get_peft_model(model, config)
model = model.to(self.torch_device)
params = [(n, p) for n, p in model.named_parameters() if p.requires_grad]

# if we don't set this, gradient checkpointing is not activated.
model.train(True)

inputs = self.prepare_inputs_for_testing()

# check if `training` works
output = model(**inputs)[0]
# invocation to get the reference non-zero grads that are supposed to exist without gradient checkpointing;
# note we're squaring the output for bigger gradients
output = model(**inputs)[0] ** 2

loss = output.sum()
loss.backward()

non_zero_grad_params_normal = {n for n, p in params if p.grad.abs().sum() > 0}

for name, param in params:
param.grad = None

# invocation with gradient checkpointing for comparison
model.prepare_model_for_gradient_checkpointing(model)
model.gradient_checkpointing_enable({"use_reentrant": use_reentrant})

output = model(**inputs)[0] ** 2

loss = output.sum()
loss.backward()

non_zero_grad_params_checkpointing = {n for n, p in params if p.grad.abs().sum() > 0}
assert non_zero_grad_params_normal == non_zero_grad_params_checkpointing

for n, param in model.named_parameters():
if "prompt_encoder." in n: # prompt tuning methods
if not issubclass(config_cls, CPTConfig):