diff --git a/src/rulesets/Base/sort.jl b/src/rulesets/Base/sort.jl index 0805da91f..bbd45bbc6 100644 --- a/src/rulesets/Base/sort.jl +++ b/src/rulesets/Base/sort.jl @@ -60,10 +60,13 @@ end function rrule(::typeof(sortslices), x::AbstractArray; dims::Integer, kw...) p = sortperm(collect(eachslice(x; dims=dims)); kw...) - inds = ntuple(d -> d == dims ? p : (:), ndims(x)) function sortslices_pullback(dy) - return (NoTangent(), ∇getindex(x, unthunk(dy), inds...)) + # avoid closing over `inds` as it doesn't fully infer and that makes it worse + # recomputing is cheap + inds_inner = ntuple(d -> d == dims ? p : (:), ndims(x)) + return (NoTangent(), ∇getindex(x, unthunk(dy), inds_inner...)) end + inds = ntuple(d -> d == dims ? p : (:), ndims(x)) return x[inds...], sortslices_pullback end diff --git a/test/rulesets/Base/sort.jl b/test/rulesets/Base/sort.jl index 052045d1e..00b08efbb 100644 --- a/test/rulesets/Base/sort.jl +++ b/test/rulesets/Base/sort.jl @@ -26,7 +26,7 @@ test_rrule(sortslices, rand(3,4); fkwargs=(; dims=2)) test_rrule(sortslices, rand(5,4); fkwargs=(; dims=1, rev=true, by=last)) - test_rrule(sortslices, rand(3,4,5); fkwargs=(; dims=3, by=sum), check_inferred=false) + test_rrule(sortslices, rand(3,4,5); fkwargs=(; dims=3, by=sum)) @test_throws Exception sortslices(Diagonal(1:3), dims=1) end