diff --git a/src/compiler/interface.jl b/src/compiler/interface.jl index c09d6db31..cdd2f2372 100644 --- a/src/compiler/interface.jl +++ b/src/compiler/interface.jl @@ -257,7 +257,7 @@ function Base.delete!(ps::Params, x) return ps end -Base.Broadcast.broadcasted(f, ps::Params) = broadcasted(f, ps.order) +Base.Broadcast.broadcastable(ps::Params) = ps.order @adjoint function Broadcast.broadcasted(f::Function, ps::Params) f.(ps), _ -> throw(ArgumentError("Zygote.Params does not support broadcasting within gradients, try iteration `for p in ps`")) @@ -375,40 +375,6 @@ function Base.copy(gs::Grads) merge!(gs_new, gs) end -broadcasted(f, gs::Grads, gss::ADictOrGrads...) = map(f, gs, gss...) - -broadcasted(f, a::Numeric, gs::Grads) = map(x -> f(a, x), gs) -broadcasted(f, gs::Grads, a::Numeric) = map(x -> f(x, a), gs) - -function materialize!(gs1::Grads, gs2::Grads) - issetequal(gs1.params, gs2.params) || - throw(ArgumentError("Expected Grads objects with the same Params.")) - for p in gs1.params - gs1[p] = gs2[p] - end - return gs1 -end - - -function Base.map(f, gs1::Grads, gss::ADictOrGrads...) - gsout = Grads(IdDict{Any,Any}(), Params(gs1.params)) - return map!(f, gsout, gs1, gss...) -end - -function Base.map!(f, gsout::Grads, gss::ADictOrGrads...) - all(issetequal(gsout.params, keys(gs)) for gs in gss) || - throw(ArgumentError("map! expects Grads objects with the same Params.")) - for p in gsout.params - gsout[p] = f((_getformap(gs, p) for gs in gss)...) - end - return gsout -end - -function _getformap(gs, p) - g = gs[p] - isnothing(g) ? fill!(similar(p), 0) : g -end - function pullback(f, ps::Params) cx = Context{true}(nothing) y, back = _pullback(cx, f)