I'm not sure if this can be considered a bug, since I might be using the library differently from how it's supposed to be used.
Context:
I have a PeftModel that needs to run inference on 2 different inputs. For each input I have a pretrained adapter that is frozen and a new adapter for fine-tuning.
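Roughly, the adapter setup looks like the sketch below (the adapter names, the LoRA config, and build_backbone() are placeholders, not my actual code):

```python
from peft import LoraConfig, get_peft_model

backbone = build_backbone()  # custom ViT-G backbone, ~40 transformer blocks
cfg = LoraConfig(r=16, target_modules=["qkv", "proj"])  # placeholder config

# one frozen pretrained adapter + one trainable adapter per input branch
mypeft_model = get_peft_model(backbone, cfg, adapter_name="rgbpretrain")
mypeft_model.add_adapter("rgbft", cfg)
mypeft_model.add_adapter("depthpretrain", cfg)
mypeft_model.add_adapter("depthft", cfg)

# the pretrained adapters are never trained
for param_name, param in mypeft_model.named_parameters():
    if "pretrain" in param_name:
        param.requires_grad = False
```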
My forward does:
```python
for name, x in inputs:
    mypeft_model.base_model.set_adapter([name + 'pretrain', name + 'ft'])
    # needed because set_adapter() forces requires_grad=True, see #2759 (comment)
    custom_set_pretrain_grad_false_ft_true()
    feature = mypeft_model(x)
```
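What custom_set_pretrain_grad_false_ft_true() does is roughly the following (a minimal sketch, the argument passing is simplified):

```python
def custom_set_pretrain_grad_false_ft_true():
    # re-freeze the pretrained adapter after set_adapter() turned
    # requires_grad back on; only the 'ft' adapter stays trainable
    for param_name, param in mypeft_model.named_parameters():
        if "lora_" not in param_name:
            continue
        if "pretrain" in param_name:
            param.requires_grad = False
        elif "ft" in param_name:
            param.requires_grad = True
```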
Issue:
- if mypeft_model contains cp.checkpoint(mymodule, x), backpropagation does not properly update the weights of the LoRA layers in that module, either because the recomputed forward did not 'see' the set_adapter call or because it did not 'see' the forced gradients
- A workaround I have found is to wrap the whole body of the loop in cp.checkpoint (see the sketch after this list), but it is very heavy on memory because I have to keep everything on the GPU until the end of the backbone (a ViT-G with 40 transformer blocks)
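The workaround looks roughly like this (run_branch() and the use_reentrant=False flag are a simplification, not necessarily how my real code is structured):

```python
import torch.utils.checkpoint as cp

def run_branch(name, x):
    # everything that has to be replayed during backward goes inside the checkpoint
    mypeft_model.base_model.set_adapter([name + 'pretrain', name + 'ft'])
    custom_set_pretrain_grad_false_ft_true()
    return mypeft_model(x)

features = {}
for name, x in inputs:
    # the whole branch is checkpointed, so its input and surrounding state
    # stay on the GPU until backward reaches this point
    features[name] = cp.checkpoint(run_branch, name, x, use_reentrant=False)
```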
Question:
Is there any way to 'provide' this context to backpropagation when using gradient checkpointing and switching adapters inside the forward?
I have not explored Hugging Face transformers' gradient_checkpointing_enable() since I'm using a custom model and I'm unsure whether it fits my problem.