|
180 | 180 | 4(sum(m.x) + sum(m.y)) + 13*sum(m.z) # again two gradients are ===, so it eliminates one
|
181 | 181 | end == ([17,17,4,4],) # Flux gave ([4.0, 4.0, 13.0, 13.0],)
|
182 | 182 | end
|
| 183 | + |
| 184 | +@testset "DiffEqFlux issue 699" begin |
| 185 | + # The gradient of `re` is a vector into which we accumulate contributions, and the issue |
| 186 | + # is that one contribution may have a wider type than `v`, especially for `Dual` numbers. |
| 187 | + v, re = destructure((x=Float32[1,2], y=Float32[3,4,5])) |
| 188 | + _, bk = Zygote.pullback(re, ones(Float32, 5)) |
| 189 | + # Testing with `Complex` isn't ideal, but this was an error on 0.2.1. |
| 190 | + # If some upgrade inserts ProjectTo, this will fail, and can be changed: |
| 191 | + @test bk((x=[1.0,im], y=nothing)) == ([1,im,0,0,0],) |
| 192 | + |
| 193 | + @test bk((x=nothing, y=[10,20,30]))[1] isa Vector{Float32} # despite some ZeroTangent |
| 194 | + @test bk((x=nothing, y=nothing)) == ([0,0,0,0,0],) |
| 195 | + @test bk((x=nothing, y=@thunk [1,2,3] .* 10.0)) == ([0,0,10,20,30],) |
| 196 | + @test bk((x=[1.2, 3.4], y=Float32[5,6,7])) == ([1.2, 3.4, 5, 6, 7],) |
| 197 | +end |

#=

# Adapted from https://github.com/SciML/DiffEqFlux.jl/pull/699#issuecomment-1092846657
using ForwardDiff, Zygote, Flux, Optimisers, Test

y = Float32[0.8564646, 0.21083355]
p = randn(Float32, 27);
t = 1.5f0
λ = [ForwardDiff.Dual(0.87135935, 1, 0, 0, 0, 0, 0), ForwardDiff.Dual(1.5225363, 0, 1, 0, 0, 0, 0)]

model = Chain(x -> x .^ 3,
          Dense(2 => 5, tanh),
          Dense(5 => 2))

p,re = Optimisers.destructure(model)
f(u, p, t) = re(p)(u)
_dy, back = Zygote.pullback(y, p) do u, p
    vec(f(u, p, t))
end
tmp1, tmp2 = back(λ);
tmp1
@test tmp2 isa Vector{<:ForwardDiff.Dual}

=#