diff --git a/src/rulesets/Base/arraymath.jl b/src/rulesets/Base/arraymath.jl index 9b442c502..9d0ba20ba 100644 --- a/src/rulesets/Base/arraymath.jl +++ b/src/rulesets/Base/arraymath.jl @@ -392,9 +392,9 @@ end function rrule(::typeof(+), arrs::AbstractArray...) y = +(arrs...) - arr_axs = map(axes, arrs) + projects = map(ProjectTo, arrs) function add_pullback(dy) - return (NoTangent(), map(ax -> reshape(dy, ax), arr_axs)...) + return (NoTangent(), map(project->project(dy), projects)...) end return y, add_pullback end diff --git a/test/rulesets/Base/arraymath.jl b/test/rulesets/Base/arraymath.jl index b25b16acf..3179e1a8d 100644 --- a/test/rulesets/Base/arraymath.jl +++ b/test/rulesets/Base/arraymath.jl @@ -185,5 +185,8 @@ @testset "addition" begin test_rrule(+, randn(4, 4), randn(4, 4), randn(4, 4)) test_rrule(+, randn(3), randn(3,1), randn(3,1,1)) + test_rrule(+, randn(3,3), Diagonal(randn(3)), randn(3,3,1)) + test_rrule(+, randn(3,3), Diagonal(randn(3)), Symmetric(randn(3,3))) + end end