@@ -6,13 +6,14 @@ gate(x::AbstractMatrix, h, n) = view(x, gate(h,n), :)
# AD-friendly helper for dividing monolithic RNN params into equally sized gates.
# Returns an `N`-tuple whose `n`-th entry is `gate(x, h, n)`.
function multigate(x::AbstractArray, h, ::Val{N}) where {N}
    return ntuple(N) do n
        gate(x, h, n)
    end
end
88
# Reverse-mode rule for `multigate`.
#
# The forward value is `multigate(x, h, c)`. The pullback accumulates the per-gate
# cotangents into a single array with the axes of `x`; the function itself, `h`
# and `c` all receive `NoTangent()`.
function ChainRulesCore.rrule(::typeof(multigate), x::AbstractArray, h, c)
    function multigate_pullback(dy)
        # Zero-filled buffer with the axes of `x` and a floating-point eltype.
        dx = map!(zero, similar(x, float(eltype(x)), axes(x)), x)
        # The incoming cotangent may be a thunk; resolve it before iterating.
        # Each `dxᵢ` is expected to alias `dx` (the matrix method of `gate`
        # returns a view), so the in-place update below writes into `dx`.
        foreach(multigate(dx, h, c), ChainRulesCore.unthunk(dy)) do dxᵢ, dyᵢ
            gᵢ = ChainRulesCore.unthunk(dyᵢ)
            # Skip gates whose cotangent is an `AbstractZero`.
            gᵢ isa AbstractZero && return
            @. dxᵢ += gᵢ
        end
        return (NoTangent(), dx, NoTangent(), NoTangent())
    end
    return multigate(x, h, c), multigate_pullback
end
@@ -380,7 +381,7 @@ julia> g(rand(Float32, 3, 10)) |> size # batch size of 10
# Convenience constructor: forwards every positional and keyword argument to
# `GRUv3Cell` and wraps the resulting cell in `Recur`.
function GRUv3(args...; kwargs...)
    cell = GRUv3Cell(args...; kwargs...)
    return Recur(cell)
end
# Wrap a `GRUv3Cell` in `Recur`, using the cell's own `state0` field as the
# second argument.
function Recur(m::GRUv3Cell)
    return Recur(m, m.state0)
end
382383
383-
# TODO move to ChainRulesCore?
# Differentiating a broadcast of a `Recur` layer is delegated to Zygote's `∇map`.
# NOTE(review): presumably this makes the stateful layer differentiate like an
# element-by-element `map` — confirm against Zygote's `∇map`.
@adjoint function Broadcast.broadcasted(f::Recur, xs...)
    Zygote.∇map(__context__, f, xs...)
end
0 commit comments