diff --git a/test/derivatives/ArrayFunctionTests.jl b/test/derivatives/ArrayFunctionTests.jl index 3640235b..75dcf927 100644 --- a/test/derivatives/ArrayFunctionTests.jl +++ b/test/derivatives/ArrayFunctionTests.jl @@ -27,25 +27,42 @@ function testcat(f, args::Tuple, type, kwargs=NamedTuple()) @test value(x) == f(args...; kwargs...) else @assert length(args) == 2 - x = f(track(args[1]), args[2]; kwargs...) - @test x isa type - @test value(x) == f(args...; kwargs...) - x = f(args[1], track(args[2]); kwargs...) - @test x isa type - @test value(x) == f(args...; kwargs...) + broken = f == hcat && (args[2] isa AbstractMatrix) + if broken && VERSION >= v"1.4" + @test_broken f(track(args[1]), args[2]; kwargs...) isa type + @test_broken value(f(track(args[1]), args[2]; kwargs...)) == f(args...; kwargs...) + else + @test f(track(args[1]), args[2]; kwargs...) isa type + @test value(f(track(args[1]), args[2]; kwargs...)) == f(args...; kwargs...) + end + + broken = f == hcat && (args[1] isa AbstractMatrix) + if broken && VERSION >= v"1.4" + @test_broken f(args[1], track(args[2]); kwargs...) isa type + @test_broken value(f(args[1], track(args[2]); kwargs...)) == f(args...; kwargs...) + else + @test f(args[1], track(args[2]); kwargs...) isa type + @test value(f(args[1], track(args[2]); kwargs...)) == f(args...; kwargs...) + end end args = (args..., args...) - x = f(track.(args)...; kwargs...) - @test x isa type - @test value(x) == f(args...; kwargs...) - sizes = size.(args) + broken = (f in (vcat, hcat) && (args[2] isa AbstractArray)) + if broken && VERSION >= v"1.4" + @test_broken f(track.(args)...; kwargs...) isa type + @test_broken value(f(track.(args)...; kwargs...)) == f(args...; kwargs...) + else + @test f(track.(args)...; kwargs...) isa type + @test value(f(track.(args)...; kwargs...)) == f(args...; kwargs...) + end + F = vecx -> sum(f(unpack(sizes, vecx)...; kwargs...)) X = pack(args) @test ForwardDiff.gradient(F, X) == gradient(F, X) end + function pack(xs) return mapreduce(vcat, xs) do x x isa Number ? x : vec(x) diff --git a/test/runtests.jl b/test/runtests.jl index 5dbba762..358d7304 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,7 +1,11 @@ +using Test + const TESTDIR = dirname(@__FILE__) test_println(kind, f, pad = " ") = println(pad, "testing $(kind): `$(f)`...") +@testset "ReverseDiff" begin + println("running TapeTests...") t = @elapsed include(joinpath(TESTDIR, "TapeTests.jl")) println("done (took $t seconds).") @@ -53,3 +57,5 @@ println("done (took $t seconds).") println("running CompatTests...") t = @elapsed include(joinpath(TESTDIR, "compat/CompatTests.jl")) println("done (took $t seconds).") + +end