
Commit 4832bb8

mcabbott and ToucheSir authored
add _rng_compat_array (#458)
* add _rng_compat_array
* Update src/dropout.jl

Co-authored-by: Brian Chen <ToucheSir@users.noreply.github.com>
1 parent 5f63dbf commit 4832bb8

File tree

2 files changed (+7, -2 lines)

Project.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 name = "NNlib"
 uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
-version = "0.8.14"
+version = "0.8.15"

 [deps]
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/dropout.jl

Lines changed: 6 additions & 1 deletion
@@ -38,6 +38,7 @@ julia> mean(dropout(ones(10^4, 5), 0.3, dims=1), dims=1)
 dropout(A::AbstractArray, p::Real; dims = :) = dropout(_rng_from_array(A), A, p; dims)

 function dropout(rng::AbstractRNG, A::AbstractArray, p::Real; dims = :)
+    _rng_compat_array(rng, A)
     T = float(eltype(A))
     0 <= p <= 1 || throw(ArgumentError("dropout expects a probability 0 <= p <= 1"))
     if p > 0
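For reference, a minimal usage sketch of the RNG-taking method that now runs this check, assuming NNlib 0.8.15 and the standard-library Random; the seed and array shape are illustrative only:

    using Random, NNlib
    rng = Random.MersenneTwister(1)     # explicit CPU RNG
    A = ones(Float32, 4, 3)
    B = NNlib.dropout(rng, A, 0.5)      # new first step: _rng_compat_array(rng, A), a no-op for CPU arrays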
@@ -52,7 +53,7 @@ function dropout(rng::AbstractRNG, A::AbstractArray, p::Real; dims = :)
 end

 """
-    dropout!(B, A, p; dims=:)
+    dropout!(B, A, p; [dims])

 This does exactly `B .= dropout(A, p; dims)`,
 or rather, it's the implementation of out-of-place [`dropout`](@ref).
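A short sketch of the in-place form that the updated docstring signature describes, again with illustrative shapes:

    using NNlib
    A = ones(Float32, 4, 3)
    B = similar(A)
    NNlib.dropout!(B, A, 0.5)   # equivalent to B .= NNlib.dropout(A, 0.5) up to RNG state, per the docstring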
@@ -62,6 +63,7 @@ dropout!(B::AbstractArray, A::AbstractArray, p::Real; dims = :) = dropout!(_rng_
 function dropout!(rng::AbstractRNG, dst::AbstractArray, src::AbstractArray, p::Real; dims=:)
     size(dst) == size(src) || throw(DimensionMismatch("dropout! expects output array the same size as input"))
     0 <= p <= 1 || throw(ArgumentError("dropout expects a probability 0 <= p <= 1"))
+    _rng_compat_array(rng, A)
     if p > 0
         pT = convert(real(eltype(dst)), p)
         _dropout!(rng, dst, src, pT, dims)
@@ -155,3 +157,6 @@ _rng_from_array(::AbstractArray) = Random.default_rng()

 @non_differentiable _rng_from_array(::Any)

+# This exists because `rand!(default_rng(), CUDA.rand(3))` ignores the RNG,
+# and Flux would prefer an error. NNlibCUDA will overload it to produce that.
+_rng_compat_array(::AbstractRNG, ::AbstractArray) = nothing
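The comment on the new fallback points at a downstream hook; a hypothetical sketch of the kind of overload NNlibCUDA could add (the CuArray methods and error message below are assumptions, not part of this commit):

    using Random, CUDA, NNlib
    # Hypothetical: CUDA's own RNG passes, while a CPU RNG paired with a CuArray raises the error Flux wants.
    NNlib._rng_compat_array(::CUDA.RNG, ::CUDA.CuArray) = nothing
    NNlib._rng_compat_array(rng::Random.AbstractRNG, ::CUDA.CuArray) =
        throw(ArgumentError("cannot use $(typeof(rng)) to fill a CuArray; pass CUDA.default_rng() instead"))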

0 commit comments
