We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 5cb3398 commit 2c2563bCopy full SHA for 2c2563b
src/axolotl/utils/gradient_checkpointing/__init__.py
@@ -1,5 +1,7 @@
1
"""custom checkpointing utils"""
2
3
+from functools import partial
4
+
5
from axolotl.utils.gradient_checkpointing.unsloth import (
6
Unsloth_Offloaded_Gradient_Checkpointer,
7
)
@@ -9,6 +11,10 @@ def hf_grad_checkpoint_offload_wrapper(
9
11
decoder_layer, *args, use_reentrant=None
10
12
): # pylint: disable=unused-argument
13
return Unsloth_Offloaded_Gradient_Checkpointer.apply(
- decoder_layer.__self__,
14
+ (
15
+ decoder_layer.func.__self__
16
+ if isinstance(decoder_layer, partial)
17
+ else decoder_layer.__self__
18
+ ),
19
*args,
20
0 commit comments