Skip to content

Commit c2cab6d

Browse files
committed
fix some other rules
1 parent 6162295 commit c2cab6d

File tree

2 files changed

+17
-11
lines changed

2 files changed

+17
-11
lines changed

src/extra_rules.jl

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ function ChainRulesCore.rrule(::DiffractorRuleConfig, g::∇getindex, Δ)
1616
g(Δ), Δ′′->(nothing, Δ′′[1][g.i...])
1717
end
1818

19-
function ChainRulesCore.rrule(::DiffractorRuleConfig, ::typeof(getindex), xs::Array, i...)
19+
function ChainRulesCore.rrule(::DiffractorRuleConfig, ::typeof(getindex), xs::Array{<:Number}, i...)
2020
xs[i...], ∇getindex(xs, i)
2121
end
2222

@@ -220,26 +220,31 @@ struct BackMap{T}
220220
end
221221
(f::BackMap{N})(args...) where {N} = ∂⃖¹(getfield(f, :f), args...)
222222
back_apply(x, y) = x(y)
223-
back_apply_zero(x) = x(Zero())
223+
back_apply_zero(x) = x(Zero()) # Zero is not defined
224224

225225
function ChainRules.rrule(::DiffractorRuleConfig, ::typeof(map), f, args::Tuple)
226226
a, b = unzip_tuple(map(BackMap(f), args))
227-
function back(Δ)
227+
function map_back(Δ)
228228
(fs, xs) = unzip_tuple(map(back_apply, b, Δ))
229229
(NoTangent(), sum(fs), xs)
230230
end
231-
function back::ZeroTangent)
232-
(fs, xs) = unzip_tuple(map(back_apply_zero, b))
233-
(NoTangent(), sum(fs), xs)
234-
end
235-
a, back
231+
map_back::AbstractZero) = (NoTangent(), NoTangent(), NoTangent())
232+
# function back(Δ::ZeroTangent)
233+
# (fs, xs) = unzip_tuple(map(back_apply_zero, b))
234+
# (NoTangent(), sum(fs), xs)
235+
# end
236+
a, map_back
236237
end
237238

239+
ChainRules.rrule(::DiffractorRuleConfig, ::typeof(map), f, args::Tuple{}) = (), _ -> (NoTangent(), NoTangent(), NoTangent())
240+
238241
function ChainRules.rrule(::DiffractorRuleConfig, ::typeof(Base.ntuple), f, n)
239242
a, b = unzip_tuple(ntuple(BackMap(f), n))
240-
a, function (Δ)
243+
function ntuple_back(Δ)
241244
(NoTangent(), sum(map(back_apply, b, Δ)), NoTangent())
242245
end
246+
ntuple_back(::AbstractZero) = (NoTangent(), NoTangent(), NoTangent())
247+
a, ntuple_back
243248
end
244249

245250
function ChainRules.frule(::DiffractorRuleConfig, _, ::Type{Vector{T}}, undef::UndefInitializer, dims::Int...) where {T}

src/stage1/generated.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -315,13 +315,13 @@ function (::∂⃖{N})(::typeof(Core.getfield), s, field::Symbol) where {N}
315315
end
316316

317317
# TODO: Temporary - make better
318-
function (::∂⃖{N})(::typeof(Base.getindex), a::Array, inds...) where {N}
318+
function (::∂⃖{N})(::typeof(Base.getindex), a::Array{<:Number}, inds...) where {N}
319319
getindex(a, inds...), let
320320
EvenOddOdd{1, c_order(N)}(
321321
(@Base.constprop :aggressive Δ->begin
322322
Δ isa AbstractZero && return (NoTangent(), Δ, map(Returns(Δ), inds)...)
323323
BB = zero(a)
324-
BB[inds...] = Δ
324+
BB[inds...] = unthunk(Δ)
325325
(NoTangent(), BB, map(x->NoTangent(), inds)...)
326326
end),
327327
(@Base.constprop :aggressive (_, Δ, _)->begin
@@ -334,6 +334,7 @@ struct tuple_back{M}; end
334334
(::tuple_back)(Δ::Tuple) = Core.tuple(NoTangent(), Δ...)
335335
(::tuple_back{N})(Δ::AbstractZero) where {N} = Core.tuple(NoTangent(), ntuple(i->Δ, N)...)
336336
(::tuple_back{N})(Δ::Tangent) where {N} = Core.tuple(NoTangent(), ntuple(i->lifted_getfield(Δ, i), N)...)
337+
(t::tuple_back)(Δ::AbstractThunk) = t(unthunk(Δ))
337338

338339
function (::∂⃖{N})(::typeof(Core.tuple), args::Vararg{Any, M}) where {N, M}
339340
Core.tuple(args...),

0 commit comments

Comments
 (0)