@@ -16,7 +16,7 @@ function ChainRulesCore.rrule(::DiffractorRuleConfig, g::∇getindex, Δ)
16
16
g (Δ), Δ′′-> (nothing , Δ′′[1 ][g. i... ])
17
17
end
18
18
19
- function ChainRulesCore. rrule (:: DiffractorRuleConfig , :: typeof (getindex), xs:: Array , i... )
19
+ function ChainRulesCore. rrule (:: DiffractorRuleConfig , :: typeof (getindex), xs:: Array{<:Number} , i... )
20
20
xs[i... ], ∇getindex (xs, i)
21
21
end
22
22
@@ -220,26 +220,31 @@ struct BackMap{T}
220
220
end
221
221
(f:: BackMap{N} )(args... ) where {N} = ∂⃖¹ (getfield (f, :f ), args... )
222
222
back_apply (x, y) = x (y)
223
- back_apply_zero (x) = x (Zero ())
223
+ back_apply_zero (x) = x (Zero ()) # Zero is not defined
224
224
225
225
function ChainRules. rrule (:: DiffractorRuleConfig , :: typeof (map), f, args:: Tuple )
226
226
a, b = unzip_tuple (map (BackMap (f), args))
227
- function back (Δ)
227
+ function map_back (Δ)
228
228
(fs, xs) = unzip_tuple (map (back_apply, b, Δ))
229
229
(NoTangent (), sum (fs), xs)
230
230
end
231
- function back (Δ:: ZeroTangent )
232
- (fs, xs) = unzip_tuple (map (back_apply_zero, b))
233
- (NoTangent (), sum (fs), xs)
234
- end
235
- a, back
231
+ map_back (Δ:: AbstractZero ) = (NoTangent (), NoTangent (), NoTangent ())
232
+ # function back(Δ::ZeroTangent)
233
+ # (fs, xs) = unzip_tuple(map(back_apply_zero, b))
234
+ # (NoTangent(), sum(fs), xs)
235
+ # end
236
+ a, map_back
236
237
end
237
238
239
+ ChainRules. rrule (:: DiffractorRuleConfig , :: typeof (map), f, args:: Tuple{} ) = (), _ -> (NoTangent (), NoTangent (), NoTangent ())
240
+
238
241
function ChainRules. rrule (:: DiffractorRuleConfig , :: typeof (Base. ntuple), f, n)
239
242
a, b = unzip_tuple (ntuple (BackMap (f), n))
240
- a, function (Δ)
243
+ function ntuple_back (Δ)
241
244
(NoTangent (), sum (map (back_apply, b, Δ)), NoTangent ())
242
245
end
246
+ ntuple_back (:: AbstractZero ) = (NoTangent (), NoTangent (), NoTangent ())
247
+ a, ntuple_back
243
248
end
244
249
245
250
function ChainRules. frule (:: DiffractorRuleConfig , _, :: Type{Vector{T}} , undef:: UndefInitializer , dims:: Int... ) where {T}
0 commit comments