Skip to content

Commit 2c2563b

Browse files
ekojsalimwinglian
andauthored
fix: gradient checkpointing functools.partial object has no attribute __self__ (axolotl-ai-cloud#2563) [skip ci]
* fix: gradient checkpointing causing functools.partial error * lint * chore: lint --------- Co-authored-by: Wing Lian <wing@axolotl.ai>
1 parent 5cb3398 commit 2c2563b

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

src/axolotl/utils/gradient_checkpointing/__init__.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
"""custom checkpointing utils"""
22

3+
from functools import partial
4+
35
from axolotl.utils.gradient_checkpointing.unsloth import (
46
Unsloth_Offloaded_Gradient_Checkpointer,
57
)
@@ -9,6 +11,10 @@ def hf_grad_checkpoint_offload_wrapper(
911
decoder_layer, *args, use_reentrant=None
1012
): # pylint: disable=unused-argument
1113
return Unsloth_Offloaded_Gradient_Checkpointer.apply(
12-
decoder_layer.__self__,
14+
(
15+
decoder_layer.func.__self__
16+
if isinstance(decoder_layer, partial)
17+
else decoder_layer.__self__
18+
),
1319
*args,
1420
)

0 commit comments

Comments
 (0)