@@ -2,9 +2,14 @@ istraining() = false
22
# Reverse-mode rule for `istraining`: the primal value reported through this rule is `true`
# (the hunk header above shows the plain definition `istraining() = false`), and the
# pullback ignores its input and returns a 1-tuple holding `NoTangent()`.
33ChainRulesCore. rrule (:: typeof (istraining)) = true , _ -> (NoTangent (),)
44
# `_isactive(m)` resolves the activity flag of `m` to a `Bool`.
# Removed (`-`) version: calls `istraining()` only when `m.active` is `nothing`, otherwise
# converts `m.active` with `Bool`.
# Added (`+`) version: `something` yields `m.active` unless it is `nothing`, in which case
# the `istraining()` result is used; the chosen value is then converted with `Bool`.
# NOTE(review): in the added version `istraining()` is evaluated as a call argument even
# when `m.active` is set, and `Bool(...)` now also wraps the `istraining()` result.
5- _isactive (m) = isnothing ( m. active) ? istraining () : Bool (m . active )
5+ _isactive (m) = Bool ( something ( m. active, istraining ()) )
66
# Derivative handling for `_isactive`.
# Removed (`-`) line: `_isactive` was declared non-differentiable for any argument.
# Added (`+`) lines: an explicit `rrule` replaces that declaration. It obtains `training`
# as the first element of `rrule(istraining)`, returns
# `Bool(something(m.active, training))` as the primal value, and pairs it with a pullback
# that ignores its input and returns two `NoTangent()` entries (the rule takes two
# arguments: the function itself and `m`).
7- ChainRulesCore. @non_differentiable _isactive (:: Any )
7+ # Avoids instabilities from differentiating through getproperty(m, :active)
8+ function ChainRulesCore. rrule (:: typeof (_isactive), m)
9+ training, _ = rrule (istraining)
10+ _isactive_pullback (_) = (NoTangent (), NoTangent ())
11+ return Bool (something (m. active, training)), _isactive_pullback
12+ end
813
# `_dropout_shape(s, dims)` computes a shape tuple derived from `size(s)`.
# With `dims` given as a `Colon`, the result is `size(s)` unchanged.
# Otherwise each dimension index `i` listed in `dims` keeps its extent `si`, and every
# dimension not in `dims` is replaced by 1.
914_dropout_shape (s, :: Colon ) = size (s)
1015_dropout_shape (s, dims) = tuple ((i ∉ dims ? 1 : si for (i, si) ∈ enumerate (size (s))). .. )
5762
# Call method that makes a `DropoutPullback` usable as a pullback function.
# Given the incoming cotangent `dy`, it applies `pb.mask` through `_apply_mask`, passes the
# result to `pb.project`, and returns that value `dx` in the third slot of a tuple whose
# other entries are all `NoTangent()`.
# The removed (`-`) return line has 4 entries; the added (`+`) one has 6.
# NOTE(review): presumably the tuple length has to match the number of arguments of the
# rule that constructs this pullback — that rule is not visible here, confirm.
5863function (pb:: DropoutPullback )(dy)
5964 dx = pb. project (_apply_mask (dy, pb. mask))
60- return (NoTangent (), NoTangent (), dx, NoTangent ())
65+ return (NoTangent (), NoTangent (), dx, NoTangent (), NoTangent (), NoTangent () )
6166end
6267
# `_apply_mask` method for a mask of `nothing`: returns `x` unchanged.
# Methods for other mask types are not visible in this excerpt.
6368_apply_mask (x, :: Nothing ) = x
0 commit comments