@@ -38,6 +38,7 @@ julia> mean(dropout(ones(10^4, 5), 0.3, dims=1), dims=1)
 dropout(A::AbstractArray, p::Real; dims = :) = dropout(_rng_from_array(A), A, p; dims)
 
 function dropout(rng::AbstractRNG, A::AbstractArray, p::Real; dims = :)
+    _rng_compat_array(rng, A)
     T = float(eltype(A))
     0 <= p <= 1 || throw(ArgumentError("dropout expects a probability 0 <= p <= 1"))
     if p > 0
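For illustration (not part of this diff), the new RNG-first method composes with the doctest shown in the hunk header: with `dims=1` one mask entry is drawn per row, and kept entries are scaled by 1/(1 - p), so each column mean stays near 1.

using NNlib, Random, Statistics

# dims=1 draws one mask entry per row, zeroing whole rows of the matrix;
# kept rows are scaled by 1/(1 - 0.3), so each column mean stays ≈ 1.
A = dropout(Xoshiro(1), ones(10^4, 5), 0.3; dims=1)
mean(A; dims=1)  # 1×5 Matrix, each entry ≈ 1.0 up to sampling noise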
@@ -52,7 +53,7 @@ function dropout(rng::AbstractRNG, A::AbstractArray, p::Real; dims = :)
 end
 
 """
-    dropout!(B, A, p; dims=:)
+    dropout!(B, A, p; [dims])
 
 This does exactly `B .= dropout(A, p; dims)`,
 or rather, it's the implementation of out-of-place [`dropout`](@ref).
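A brief usage sketch of the in-place form described by this docstring (illustrative, not part of the diff); it assumes a preallocated output buffer of the same size:

using NNlib

src = randn(Float32, 100, 10)
dst = similar(src)
dropout!(dst, src, 0.5)          # same result as dst .= dropout(src, 0.5)
dropout!(dst, src, 0.5; dims=2)  # one draw per column: whole columns are zeroed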
@@ -62,6 +63,7 @@ dropout!(B::AbstractArray, A::AbstractArray, p::Real; dims = :) = dropout!(_rng_
 function dropout!(rng::AbstractRNG, dst::AbstractArray, src::AbstractArray, p::Real; dims = :)
     size(dst) == size(src) || throw(DimensionMismatch("dropout! expects output array the same size as input"))
     0 <= p <= 1 || throw(ArgumentError("dropout expects a probability 0 <= p <= 1"))
+    _rng_compat_array(rng, dst)
     if p > 0
         pT = convert(real(eltype(dst)), p)
         _dropout!(rng, dst, src, pT, dims)
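As a sketch of the RNG-explicit method (not in the diff; CPU arrays, where `_rng_compat_array` is a no-op), an explicit seed makes the mask reproducible:

using NNlib, Random

src = ones(Float32, 4, 4)
dst1, dst2 = similar(src), similar(src)
dropout!(Xoshiro(42), dst1, src, 0.25)
dropout!(Xoshiro(42), dst2, src, 0.25)
dst1 == dst2  # true: the same seed reproduces the same mask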
@@ -155,3 +157,6 @@ _rng_from_array(::AbstractArray) = Random.default_rng()
 
 @non_differentiable _rng_from_array(::Any)
 
+# This exists because `rand!(default_rng(), CUDA.rand(3))` ignores the RNG,
+# and Flux would prefer an error. NNlibCUDA will overload it to produce that.
+_rng_compat_array(::AbstractRNG, ::AbstractArray) = nothing
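A hypothetical sketch (an assumption, not code from this PR) of the NNlibCUDA overload that the comment above anticipates: CUDA's own RNG passes the check, while a CPU RNG paired with a CuArray raises instead of being silently ignored.

using CUDA, Random  # assumes a CUDA-capable environment

# Hypothetical methods; the real overload would live in NNlibCUDA:
_rng_compat_array(::CUDA.RNG, ::CUDA.CuArray) = nothing
_rng_compat_array(rng::Random.AbstractRNG, ::CUDA.CuArray) =
    throw(ArgumentError("cannot use rng::$(typeof(rng)) with a CuArray; use CUDA's own RNG"))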