@@ -2,9 +2,14 @@ istraining() = false
22
# Reverse-mode rule for `istraining`: the primal value reported under AD is
# `true`, and the pullback yields a single `NoTangent()` (for the function itself).
function ChainRulesCore.rrule(::typeof(istraining))
    istraining_pullback(_) = (NoTangent(),)
    return true, istraining_pullback
end
44
# Decide whether layer `m` is active: use `m.active` when it is set, and fall
# back to `istraining()` when `m.active` is `nothing`. The result is always a `Bool`.
function _isactive(m)
    flag = something(m.active, istraining())
    return Bool(flag)
end
66
# Avoids instabilities from differentiating through getproperty(m, :active).
# The primal mirrors `_isactive`, except that the fallback flag is taken from
# `rrule(istraining)` rather than from a plain `istraining()` call.
function ChainRulesCore.rrule(::typeof(_isactive), m)
    # Only the primal half of the `istraining` rule is needed here.
    training = first(rrule(istraining))
    active = Bool(something(m.active, training))
    # Neither `_isactive` itself nor `m` receives a tangent.
    _isactive_pullback(_) = (NoTangent(), NoTangent())
    return active, _isactive_pullback
end
813
# With `dims = :` the result is simply `size(s)`.
_dropout_shape(s, ::Colon) = size(s)

# Otherwise keep the extent of every dimension listed in `dims` and collapse
# all remaining dimensions to length 1.
function _dropout_shape(s, dims)
    sz = size(s)
    return ntuple(d -> d in dims ? sz[d] : 1, length(sz))
end
5964
# Calling a `DropoutPullback` on the output cotangent `dy` applies the stored
# mask via `_apply_mask`, passes the result through the stored projector, and
# returns it in the third slot of a six-element tangent tuple; every other
# slot is `NoTangent()`.
# NOTE(review): the six slots presumably line up with the function plus five
# arguments of the matching rrule, which is not visible here — confirm.
function (pb::DropoutPullback)(dy)
    masked = _apply_mask(dy, pb.mask)
    dx = pb.project(masked)
    return (NoTangent(), NoTangent(), dx, NoTangent(), NoTangent(), NoTangent())
end
6469
# A `nothing` mask means there is nothing to apply: hand `x` back unchanged.
function _apply_mask(x, ::Nothing)
    return x
end
0 commit comments