From 1484c8d4b40ff40011be14d6a1efd2b0f484d78d Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 21 Jan 2022 19:18:03 +0000 Subject: [PATCH 1/2] use ProjectTo in Array addition --- src/rulesets/Base/arraymath.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/rulesets/Base/arraymath.jl b/src/rulesets/Base/arraymath.jl index 9b442c502..9d0ba20ba 100644 --- a/src/rulesets/Base/arraymath.jl +++ b/src/rulesets/Base/arraymath.jl @@ -392,9 +392,9 @@ end function rrule(::typeof(+), arrs::AbstractArray...) y = +(arrs...) - arr_axs = map(axes, arrs) + projects = map(ProjectTo, arrs) function add_pullback(dy) - return (NoTangent(), map(ax -> reshape(dy, ax), arr_axs)...) + return (NoTangent(), map(project->project(dy), projects)...) end return y, add_pullback end From 6d44bb2d5497d89e99753c30bfe59bcd9a7e99fb Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Mon, 24 Jan 2022 11:57:01 +0000 Subject: [PATCH 2/2] Add tests (failing) --- test/rulesets/Base/arraymath.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/rulesets/Base/arraymath.jl b/test/rulesets/Base/arraymath.jl index b25b16acf..3179e1a8d 100644 --- a/test/rulesets/Base/arraymath.jl +++ b/test/rulesets/Base/arraymath.jl @@ -185,5 +185,8 @@ @testset "addition" begin test_rrule(+, randn(4, 4), randn(4, 4), randn(4, 4)) test_rrule(+, randn(3), randn(3,1), randn(3,1,1)) + test_rrule(+, randn(3,3), Diagonal(randn(3)), randn(3,3,1)) + test_rrule(+, randn(3,3), Diagonal(randn(3)), Symmetric(randn(3,3))) + end end