Skip to content

Commit db3cf91

Browse files
authored
Widen in _grad! (#66)
1 parent 2bf0efa commit db3cf91

File tree

2 files changed

+52
-5
lines changed

2 files changed

+52
-5
lines changed

src/destructure.jl

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -127,15 +127,22 @@ function _grad!(x, dx, off, flat::AbstractVector)
127127
x′, _ = functor(typeof(x), x)
128128
dx′, _ = functor(typeof(x), base(dx))
129129
off′, _ = functor(typeof(x), off)
130-
foreach((xᵢ, dxᵢ, oᵢ) -> _grad!(xᵢ, dxᵢ, oᵢ, flat), x′, dx′, off′)
130+
for (xᵢ, dxᵢ, oᵢ) in zip(x′, dx′, off′)
131+
flat = _grad!(xᵢ, dxᵢ, oᵢ, flat)
132+
end
131133
flat
132134
end
133-
function _grad!(x, dx, off::Integer, flat::AbstractVector)
134-
@views flat[off .+ (1:length(x))] .+= vec(dx) # must visit all tied nodes
135+
function _grad!(x, dx, off::Integer, flat::AbstractVector{T}) where T
136+
dx_un = unthunk(dx)
137+
T2 = promote_type(T, eltype(dx_un))
138+
if T != T2 # then we must widen the type
139+
flat = copyto!(similar(flat, T2), flat)
140+
end
141+
@views flat[off .+ (1:length(x))] .+= vec(dx_un) # must visit all tied nodes
135142
flat
136143
end
137-
_grad!(x, dx::Zero, off, flat::AbstractVector) = dx
138-
_grad!(x, dx::Zero, off::Integer, flat::AbstractVector) = dx # ambiguity
144+
_grad!(x, dx::Zero, off, flat::AbstractVector) = flat
145+
_grad!(x, dx::Zero, off::Integer, flat::AbstractVector) = flat # ambiguity
139146

140147
# These are only needed for 2nd derivatives:
141148
function ChainRulesCore.rrule(::typeof(_grad!), x, dx, off, flat)

test/destructure.jl

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,3 +180,43 @@ end
180180
4(sum(m.x) + sum(m.y)) + 13*sum(m.z) # again two gradients are ===, so it eliminates one
181181
end == ([17,17,4,4],) # Flux gave ([4.0, 4.0, 13.0, 13.0],)
182182
end
183+
184+
@testset "DiffEqFlux issue 699" begin
185+
# The gradient of `re` is a vector into which we accumulate contributions, and the issue
186+
# is that one contribution may have a wider type than `v`, especially for `Dual` numbers.
187+
v, re = destructure((x=Float32[1,2], y=Float32[3,4,5]))
188+
_, bk = Zygote.pullback(re, ones(Float32, 5))
189+
# Testing with `Complex` isn't ideal, but this was an error on 0.2.1.
190+
# If some upgrade inserts ProjectTo, this will fail, and can be changed:
191+
@test bk((x=[1.0,im], y=nothing)) == ([1,im,0,0,0],)
192+
193+
@test bk((x=nothing, y=[10,20,30]))[1] isa Vector{Float32} # despite some ZeroTangent
194+
@test bk((x=nothing, y=nothing)) == ([0,0,0,0,0],)
195+
@test bk((x=nothing, y=@thunk [1,2,3] .* 10.0)) == ([0,0,10,20,30],)
196+
@test bk((x=[1.2, 3.4], y=Float32[5,6,7])) == ([1.2, 3.4, 5, 6, 7],)
197+
end
198+
199+
#=
200+
201+
# Adapted from https://github.com/SciML/DiffEqFlux.jl/pull/699#issuecomment-1092846657
202+
using ForwardDiff, Zygote, Flux, Optimisers, Test
203+
204+
y = Float32[0.8564646, 0.21083355]
205+
p = randn(Float32, 27);
206+
t = 1.5f0
207+
λ = [ForwardDiff.Dual(0.87135935, 1, 0, 0, 0, 0, 0), ForwardDiff.Dual(1.5225363, 0, 1, 0, 0, 0, 0)]
208+
209+
model = Chain(x -> x .^ 3,
210+
Dense(2 => 5, tanh),
211+
Dense(5 => 2))
212+
213+
p,re = Optimisers.destructure(model)
214+
f(u, p, t) = re(p)(u)
215+
_dy, back = Zygote.pullback(y, p) do u, p
216+
vec(f(u, p, t))
217+
end
218+
tmp1, tmp2 = back(λ);
219+
tmp1
220+
@test tmp2 isa Vector{<:ForwardDiff.Dual}
221+
222+
=#

0 commit comments

Comments (0)