 istraining() = false

-@adjoint istraining() = true, _ -> nothing
+ChainRulesCore.rrule(::typeof(istraining)) = true, _ -> (NoTangent(),)

 _isactive(m) = isnothing(m.active) ? istraining() : m.active

@@ -38,12 +38,6 @@ function dropout(rng, x, p; dims=:, active::Bool=true)
 end
 dropout(x, p; kwargs...) = dropout(rng_from_array(x), x, p; kwargs...)

-@adjoint function dropout(rng, x, p; dims=:, active::Bool=true)
-  active || return x, Δ -> (Δ, nothing)
-  y = dropout_mask(rng, x, p, dims=dims)
-  return x .* y, Δ -> (nothing, Δ .* y, nothing)
-end
-
 dropout_mask(rng::CUDA.RNG, x::CuArray, p; kwargs...) = _dropout_mask(rng, x, p; kwargs...)
 dropout_mask(rng, x::CuArray, p; kwargs...) =
   throw(ArgumentError("x isa CuArray, but rng isa $(typeof(rng)). dropout_mask only support CUDA.RNG for CuArrays."))
@@ -56,7 +50,7 @@ function _dropout_mask(rng, x, p; dims=:)
 end

 # TODO move this to NNlib
-Zygote.ChainRulesCore.@non_differentiable dropout_mask(rng, x, p)
+ChainRulesCore.@non_differentiable dropout_mask(::Any, ::Any, ::Any)

 """
     Dropout(p; dims=:, rng = rng_from_array())
@@ -234,7 +228,8 @@ function _track_stats!(
   bn.σ² = res_mtm .* bn.σ² .+ mtm .* (m / (m - one(V))) .* σ²new
   return nothing
 end
-Zygote.@nograd _track_stats!
+
+ChainRulesCore.@non_differentiable _track_stats!(::Any...)

 """
     BatchNorm(channels::Integer, λ=identity;
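
A minimal standalone sketch (not part of the diff above) of the two ChainRulesCore patterns this change switches to, assuming Zygote as the AD backend; `_flag` and `_mask` are hypothetical stand-ins for `istraining` and `dropout_mask`:

    using ChainRulesCore, Zygote

    _flag() = false
    # Under AD, report `true` and contribute no gradient, mirroring the `istraining` rrule.
    ChainRulesCore.rrule(::typeof(_flag)) = true, _ -> (NoTangent(),)

    _mask(x, p) = rand(Float32, size(x)) .> p
    # Treat the mask as a constant during differentiation, mirroring
    # `dropout_mask` and `_track_stats!`.
    ChainRulesCore.@non_differentiable _mask(::Any, ::Any)

    first(Zygote.pullback(_flag))  # true inside AD, false when called directly
    # Gradients flow only through the multiplication, never into the mask.
    Zygote.gradient(x -> sum(x .* _mask(x, 0.5)), ones(Float32, 4))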