
Gradient checkpointing with multiple adapters #2832

@NguyenRichard

Description


I'm not sure if this can be considered a bug, since I might be using the library differently from how it's supposed to be used.

Context:

I have a PeftModel that needs to run inference on 2 different inputs.
For each input I have a pretrained adapter that is frozen and a new adapter that is being fine-tuned.

My forward pass does the following:

for name, x in inputs:
    # Activate the frozen pretrained adapter and the trainable fine-tuning adapter for this input
    mypeft_model.base_model.set_adapter([name + 'pretrain', name + 'ft'])
    # set_adapter forces requires_grad=True on the active adapters (cf. #2759),
    # so re-freeze the pretrained adapter manually
    custom_set_pretrain_grad_false_ft_true()
    feature = mypeft_model(x)

(#2759 (comment))
Issue:

  1. If mypeft_model contains cp.checkpoint(mymodule, x), backpropagation does not properly update the weights of the LoRA layers inside that module, either because the recomputation did not 'see' the set_adapter call or because it did not 'see' the forced requires_grad.
  2. A workaround I have found is to wrap the whole loop body in a single cp.checkpoint (see the sketch after this list), but it is very heavy on memory because everything has to be kept on the GPU until the end of the backbone (a ViT-G with 40 transformer blocks).
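
For concreteness, the workaround in point 2 looks roughly like this (an untested sketch; run_one_adapter_pair is just an illustrative name, and I am assuming torch.utils.checkpoint with use_reentrant=False):

import torch.utils.checkpoint as cp

def run_one_adapter_pair(name, x):
    # The adapter switch and the requires_grad re-freeze live inside the
    # checkpointed function, so they are re-applied when the forward is
    # recomputed during the backward pass.
    mypeft_model.base_model.set_adapter([name + 'pretrain', name + 'ft'])
    custom_set_pretrain_grad_false_ft_true()
    return mypeft_model(x)

for name, x in inputs:
    feature = cp.checkpoint(run_one_adapter_pair, name, x, use_reentrant=False)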

Question:
Is there any way to 'provide' the context to the backpropagation when switching adapters in the forward pass, while still using gradient checkpointing? (See the sketch below for what I mean.)
I have not explored Hugging Face Transformers' gradient_checkpointing_enable() since I'm using a custom model and I'm not sure it fits my problem.
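
To be concrete about what I mean by 'providing' the context, something along these lines is what I have in mind (an untested sketch; block_with_adapters is a made-up helper, and the adapter names would have to be threaded down from the outer loop):

def block_with_adapters(block, adapter_names, x):
    # Re-apply the adapter context on every (re)computation, so the forward
    # that gradient checkpointing recomputes during backward matches the
    # original forward pass.
    mypeft_model.base_model.set_adapter(list(adapter_names))
    custom_set_pretrain_grad_false_ft_true()
    return block(x)

# inside the custom model, instead of cp.checkpoint(mymodule, x):
feature = cp.checkpoint(block_with_adapters, mymodule, (name + 'pretrain', name + 'ft'), x,
                        use_reentrant=False)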
