Skip to content

Commit 40031df

Browse files
committed
Use old custom grad
1 parent 3a81f8f commit 40031df

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

lib/axon/compiler.ex

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -889,7 +889,7 @@ defmodule Axon.Compiler do
889889

890890
if event? and mode? do
891891
if on_event == :backward do
892-
Nx.Defn.Kernel.custom_grad(expr, [expr], fn g ->
892+
Nx.Defn.Kernel.custom_grad(expr, fn _ans, g ->
893893
hooked_g = Nx.Defn.Kernel.hook(g, hook_fn)
894894
[{expr, hooked_g}]
895895
end)

0 commit comments

Comments
 (0)