Skip to content

Commit 13ccc86

Browse files
authored
Explicitly unthunk in a few rules (#670)
* unthunk in rrule for + * unthunk in unbroadcast * v1.44.5
1 parent 39c2d17 commit 13ccc86

File tree

3 files changed

+7
-4
lines changed

3 files changed

+7
-4
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRules"
22
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
3-
version = "1.44.4"
3+
version = "1.44.5"
44

55
[deps]
66
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/rulesets/Base/arraymath.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -415,7 +415,8 @@ frule((_, ΔAs...), ::typeof(+), As::AbstractArray...) = +(As...), +(ΔAs...)
415415
function rrule(::typeof(+), arrs::AbstractArray...)
416416
y = +(arrs...)
417417
arr_axs = map(axes, arrs)
418-
function add_pullback(dy)
418+
function add_pullback(dy_raw)
419+
dy = unthunk(dy_raw) # reshape will otherwise unthunk N times
419420
return (NoTangent(), map(ax -> reshape(dy, ax), arr_axs)...)
420421
end
421422
return y, add_pullback

src/rulesets/Base/broadcast.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,8 @@ rrule(::typeof(broadcasted), ::typeof(complex), x::Number) = rrule(complex, x) |
316316
# When sizes disagree, broadcasting gradient uses `unbroadcast` to reduce to correct shape.
317317
# It's sometimes a little wasteful to allocate a too-large `dx`, but difficult to make more efficient.
318318

319-
function unbroadcast(x::Base.AbstractArrayOrBroadcasted, dx)
319+
function unbroadcast(x::Base.AbstractArrayOrBroadcasted, dx_raw)
320+
dx = unthunk(dx_raw)
320321
N = ndims(dx)
321322
if length(x) == length(dx)
322323
ProjectTo(x)(dx) # handles trivial reshapes, offsets, structured matrices, row vectors
@@ -327,7 +328,8 @@ function unbroadcast(x::Base.AbstractArrayOrBroadcasted, dx)
327328
end
328329
unbroadcast(x::Base.AbstractArrayOrBroadcasted, dx::AbstractZero) = dx
329330

330-
function unbroadcast(x::T, dx) where {T<:Tuple{Vararg{Any,N}}} where {N}
331+
function unbroadcast(x::T, dx_raw) where {T<:Tuple{Vararg{Any,N}}} where {N}
332+
dx = unthunk(dx_raw)
331333
val = if N == length(dx)
332334
dx
333335
else

0 commit comments

Comments
 (0)