From 1cfbf811e27c2a495aa1ac158b0c375b3771abda Mon Sep 17 00:00:00 2001 From: Adarshkumar712 Date: Sat, 7 Mar 2020 15:38:33 +0530 Subject: [PATCH 1/3] Updated doc and unicode(alpha,beta) --- src/activation.jl | 41 ++++++++++---------------------- src/conv.jl | 6 ++--- src/gemm.jl | 18 +++++++------- src/impl/conv_direct.jl | 32 ++++++++++++------------- src/impl/conv_im2col.jl | 37 +++++++++++++--------------- src/impl/depthwiseconv_direct.jl | 20 ++++++++-------- src/impl/depthwiseconv_im2col.jl | 22 ++++++++--------- src/impl/pooling_direct.jl | 30 +++++++++++------------ 8 files changed, 93 insertions(+), 113 deletions(-) diff --git a/src/activation.jl b/src/activation.jl index 55cca2b9b..651ce7ef9 100644 --- a/src/activation.jl +++ b/src/activation.jl @@ -23,10 +23,9 @@ end Segment-wise linear approximation of sigmoid See: [BinaryConnect: Training Deep Neural Networks withbinary weights during propagations](https://arxiv.org/pdf/1511.00363.pdf) """ -hardσ(x::Real, a=0.2) = oftype(x/1, max(zero(x/1), min(one(x/1), oftype(x/1,a) * x + oftype(x/1,0.5)))) +hardσ(x::Real, a=0.2) = oftype(x / 1, max(zero(x / 1), min(one(x / 1), oftype(x / 1, a) * x + oftype(x / 1, 0.5)))) const hardsigmoid = hardσ - """ logσ(x) @@ -43,16 +42,14 @@ Return `log(σ(x))` which is computed in a numerically stable way. logσ(x::Real) = -softplus(-x) const logsigmoid = logσ - """ hardtanh(x) = max(-1, min(1, x)) -Segment-wise linear approximation of tanh. Cheaper and more computational efficient version of tanh. +Segment-wise linear approximation of tanh. Cheaper and more computational efficient version of tanh See: (http://ronan.collobert.org/pub/matos/2004_phdthesis_lip6.pdf) """ hardtanh(x::Real) = max(-one(x), min( one(x), x)) - """ relu(x) = max(0, x) @@ -61,7 +58,6 @@ activation function. """ relu(x::Real) = max(zero(x), x) - """ leakyrelu(x, a=0.01) = max(a*x, x) @@ -69,7 +65,7 @@ Leaky [Rectified Linear Unit](https://en.wikipedia.org/wiki/Rectifier_(neural_ne activation function. You can also specify the coefficient explicitly, e.g. `leakyrelu(x, 0.01)`. """ -leakyrelu(x::Real, a = oftype(x / 1, 0.01)) = max(a * x, x / one(x)) +leakyrelu(x::Real, a=0.01) = max(oftype(x / 1, a) * x, x / 1) """ relu6(x) = min(max(0, x), 6) @@ -102,8 +98,7 @@ Exponential Linear Unit activation function. See [Fast and Accurate Deep Network Learning by Exponential Linear Units](https://arxiv.org/abs/1511.07289). You can also specify the coefficient explicitly, e.g. `elu(x, 1)`. """ -elu(x::Real, α = one(x)) = ifelse(x ≥ 0, x / one(x), α * (exp(x) - one(x))) - +elu(x::Real, α = one(x)) = ifelse(x ≥ 0, x / 1, α * (exp(x) - one(x))) """ gelu(x) = 0.5x * (1 + tanh(√(2/π) * (x + 0.044715x^3))) @@ -119,7 +114,6 @@ function gelu(x::Real) h * x * (one(x) + tanh(λ * (x + α * x^3))) end - """ swish(x) = x * σ(x) @@ -128,16 +122,14 @@ See [Swish: a Self-Gated Activation Function](https://arxiv.org/pdf/1710.05941.p """ swish(x::Real) = x * σ(x) - """ lisht(x) = x * tanh(x) -Non-Parametric Linearly Scaled Hyperbolic Tangent Activation Function +Non-Parametric Linearly Scaled Hyperbolic Tangent Activation Function. See [LiSHT](https://arxiv.org/abs/1901.05894) """ lisht(x::Real) = x * tanh(x) - """ selu(x) = λ * (x ≥ 0 ? x : α * (exp(x) - 1)) @@ -150,29 +142,25 @@ See [Self-Normalizing Neural Networks](https://arxiv.org/pdf/1706.02515.pdf). 
function selu(x::Real) λ = oftype(x / 1, 1.0507009873554804934193349852946) α = oftype(x / 1, 1.6732632423543772848170429916717) - λ * ifelse(x > 0, x / one(x), α * (exp(x) - one(x))) + λ * ifelse(x > 0, x / 1, α * (exp(x) - one(x))) end """ celu(x, α=1) = (x ≥ 0 ? x : α * (exp(x/α) - 1)) -Continuously Differentiable Exponential Linear Units See [Continuously Differentiable Exponential Linear Units](https://arxiv.org/pdf/1704.07483.pdf). """ -celu(x::Real, α::Real = one(x)) = ifelse(x ≥ 0, x / one(x), α * (exp(x/α) - one(x))) - +celu(x::Real, α::Real = one(x)) = ifelse(x ≥ 0, x / 1, α * (exp(x/α) - one(x))) """ - trelu(x, theta = 1.0) = x > theta ? x : 0 + trelu(x, θ=1.0) = x > θ ? x : 0 -Threshold Gated Rectified Linear -See [ThresholdRelu](https://arxiv.org/pdf/1402.3337.pdf) +See [Threshold Gated Rectified Linear Unit](https://arxiv.org/pdf/1402.3337.pdf) """ -trelu(x::Real,theta = one(x)) = ifelse(x> theta, x, zero(x)) +trelu(x::Real,θ = one(x)) = ifelse(x> θ, x, zero(x)) const thresholdrelu = trelu - """ softsign(x) = x / (1 + |x|) @@ -180,7 +168,6 @@ See [Quadratic Polynomials Learn Better Image Features](http://www.iro.umontreal """ softsign(x::Real) = x / (one(x) + abs(x)) - """ softplus(x) = log(exp(x) + 1) @@ -188,19 +175,17 @@ See [Deep Sparse Rectifier Neural Networks](http://proceedings.mlr.press/v15/glo """ softplus(x::Real) = ifelse(x > 0, x + log1p(exp(-x)), log1p(exp(x))) - """ - logcosh(x) + logcosh(x) = x + softplus(-2x) - log(2) Return `log(cosh(x))` which is computed in a numerically stable way. """ logcosh(x::Real) = x + softplus(-2x) - log(oftype(x, 2)) - """ mish(x) = x * tanh(softplus(x)) -Self Regularized Non-Monotonic Neural Activation Function +Self Regularized Non-Monotonic Neural Activation Function. See [Mish: A Self Regularized Non-Monotonic Neural Activation Function](https://arxiv.org/abs/1908.08681). """ mish(x::Real) = x * tanh(softplus(x)) @@ -218,7 +203,7 @@ tanhshrink(x::Real) = x - tanh(x) See [Softshrink Activation Function](https://www.gabormelli.com/RKB/Softshrink_Activation_Function) """ -softshrink(x::Real, λ = oftype(x/1, 0.5)) = min(max(zero(x), x - λ), x + λ) +softshrink(x::Real, λ = oftype(x / 1, 0.5)) = min(max(zero(x), x - λ), x + λ) # Provide an informative error message if activation functions are called with an array for f in (:σ, :σ_stable, :hardσ, :logσ, :hardtanh, :relu, :leakyrelu, :relu6, :rrelu, :elu, :gelu, :swish, :lisht, :selu, :celu, :trelu, :softsign, :softplus, :logcosh, :mish, :tanhshrink, :softshrink) diff --git a/src/conv.jl b/src/conv.jl index 3a5d83d56..880ed5e6c 100644 --- a/src/conv.jl +++ b/src/conv.jl @@ -29,7 +29,7 @@ export conv, conv!, ∇conv_data, ∇conv_data!, ∇conv_filter, ∇conv_filter! # First, we will define mappings from the generic API names to our accelerated backend -# implementations. For homogeneous-datatype 1, 2 and 3d convolutions, we default to using +# implementations. For homogeneous-datatype 1d, 2d and 3d convolutions, we default to using # im2col + GEMM. Do so in a loop, here: for (front_name, backend) in ( # This maps from public, front-facing name, to internal backend name @@ -86,7 +86,7 @@ end # We always support a fallback, non-accelerated path, where we use the direct, but # slow, implementations. 
These should not typically be used, hence the `@debug`, -# but let's ggo ahead and define them first: +# but let's go ahead and define them first: for front_name in (:conv, :∇conv_data, :∇conv_filter, :depthwiseconv, :∇depthwiseconv_data, :∇depthwiseconv_filter) @eval begin @@ -179,8 +179,6 @@ function conv(x, w::AbstractArray{T, N}; stride=1, pad=0, dilation=1, flipped=fa end - - """ depthwiseconv(x, w; stride=1, pad=0, dilation=1, flipped=false) diff --git a/src/gemm.jl b/src/gemm.jl index 3a66b3651..e80aa1908 100644 --- a/src/gemm.jl +++ b/src/gemm.jl @@ -9,9 +9,9 @@ using LinearAlgebra.BLAS: libblas, BlasInt, @blasfunc Low-level gemm!() call with pointers, borrowed from Knet.jl -Calculates `C = alpha*op(A)*op(B) + beta*C`, where: +Calculates `C = α*op(A)*op(B) + β*C`, where: - `transA` and `transB` set `op(X)` to be either `identity()` or `transpose()` - - alpha and beta are scalars + - α and β are scalars - op(A) is an (M, K) matrix - op(B) is a (K, N) matrix - C is an (M, N) matrix. @@ -29,8 +29,8 @@ for (gemm, elt) in gemm_datatype_mappings @eval begin @inline function gemm!(transA::Val, transB::Val, M::Int, N::Int, K::Int, - alpha::$(elt), A::Ptr{$elt}, B::Ptr{$elt}, - beta::$(elt), C::Ptr{$elt}) + α::$(elt), A::Ptr{$elt}, B::Ptr{$elt}, + β::$(elt), C::Ptr{$elt}) # Convert our compile-time transpose marker to a char for BLAS convtrans(V::Val{false}) = 'N' convtrans(V::Val{true}) = 'T' @@ -52,7 +52,7 @@ for (gemm, elt) in gemm_datatype_mappings Ptr{$elt}, Ref{BlasInt}, Ref{$elt}, Ptr{$elt}, Ref{BlasInt}), convtrans(transA), convtrans(transB), M, N, K, - alpha, A, lda, B, ldb, beta, C, ldc) + α, A, lda, B, ldb, β, C, ldc) end end end @@ -61,10 +61,10 @@ for (gemm, elt) in gemm_datatype_mappings @eval begin @inline function batched_gemm!(transA::AbstractChar, transB::AbstractChar, - alpha::($elt), + α::($elt), A::AbstractArray{$elt, 3}, B::AbstractArray{$elt, 3}, - beta::($elt), + β::($elt), C::AbstractArray{$elt, 3}) @assert !Base.has_offset_axes(A, B, C) @assert size(A, 3) == size(B, 3) == size(C, 3) "batch size mismatch" @@ -90,8 +90,8 @@ for (gemm, elt) in gemm_datatype_mappings Ptr{$elt}, Ref{BlasInt}, Ref{$elt}, Ptr{$elt}, Ref{BlasInt}), transA, transB, m, n, - ka, alpha, ptrA, max(1,Base.stride(A,2)), - ptrB, max(1,Base.stride(B,2)), beta, ptrC, + ka, α, ptrA, max(1,Base.stride(A,2)), + ptrB, max(1,Base.stride(B,2)), β, ptrC, max(1,Base.stride(C,2))) ptrA += size(A, 1) * size(A, 2) * sizeof($elt) diff --git a/src/impl/conv_direct.jl b/src/impl/conv_direct.jl index 2e2dada2f..d02fbebf2 100644 --- a/src/impl/conv_direct.jl +++ b/src/impl/conv_direct.jl @@ -18,7 +18,7 @@ function clamp_hi(x, w, L) end """ - conv_direct!(y, x, w, cdims; alpha=1, beta=0) + conv_direct!(y, x, w, cdims; α=1, β=0) Direct convolution implementation; used for debugging, tests, and mixing/matching of strange datatypes within a single convolution. Uses naive nested for loop implementation @@ -29,14 +29,14 @@ so that if the user really wants to convolve an image of `UInt8`'s with a `Float kernel, storing the result in a `Float32` output, there is at least a function call for that madness. -The keyword arguments `alpha` and `beta` control accumulation behavior; this function -calculates `y = alpha * x * w + beta * y`, therefore by setting `beta` to a nonzero -value, the user is able to accumulate values into a preallocated `y` buffer, or by -setting `alpha` to a nonunitary value, an arbitrary gain factor can be applied. 
+The keyword arguments `α` and `β` control accumulation behavior; this function +calculates `y = α * x * w + β * y`, therefore by setting `β` to a non-zero +value, the user is able to accumulate values into a pre-allocated `y` buffer, or by +setting `α` to a non-unitary value, an arbitrary gain factor can be applied. -By defaulting `beta` to `false`, we make use of the Bradbury promotion trick to override +By defaulting `β` to `false`, we make use of the Bradbury promotion trick to override `NaN`'s that may pre-exist within our output buffer, as `false*NaN == 0.0`, whereas -`0.0*NaN == NaN`. Only set `beta` if you are certain that none of the elements within +`0.0*NaN == NaN`. Only set `β` if you are certain that none of the elements within `y` are `NaN`. The basic implementation performs 3-dimensional convolution; 1-dimensional and 2- @@ -47,7 +47,7 @@ conv_direct! function conv_direct!(y::AbstractArray{yT,5}, x::AbstractArray{xT,5}, w::AbstractArray{wT,5}, cdims::DenseConvDims; - alpha::yT = yT(1), beta = false) where {yT, xT, wT} + α::yT = yT(1), β = false) where {yT, xT, wT} check_dims(size(x), size(w), size(y), cdims) width, height, depth = input_size(cdims) @@ -95,7 +95,7 @@ function conv_direct!(y::AbstractArray{yT,5}, x::AbstractArray{xT,5}, c_in, c_out] dotprod = muladd(x_val, w_val, dotprod) end - y[w_idx, h_idx, d_idx, c_out, batch] = alpha*dotprod + beta*y[w_idx, h_idx, d_idx, c_out, batch] + y[w_idx, h_idx, d_idx, c_out, batch] = α*dotprod + β*y[w_idx, h_idx, d_idx, c_out, batch] end # Next, do potentially-padded regions: @@ -138,7 +138,7 @@ function conv_direct!(y::AbstractArray{yT,5}, x::AbstractArray{xT,5}, end end - y[w_idx, h_idx, d_idx, c_out, batch] = alpha*dotprod + beta*y[w_idx, h_idx, d_idx, c_out, batch] + y[w_idx, h_idx, d_idx, c_out, batch] = α*dotprod + β*y[w_idx, h_idx, d_idx, c_out, batch] end return y @@ -146,7 +146,7 @@ end ## Gradient definitions """ - ∇conv_data_direct!(dx, dy, w, cdims; alpha=1, beta=0) + ∇conv_data_direct!(dx, dy, w, cdims; α=1, β=0) Calculate the gradient imposed upon `x` in the convolution `y = x * w`. """ @@ -154,18 +154,18 @@ Calculate the gradient imposed upon `x` in the convolution `y = x * w`. function ∇conv_data_direct!(dx::AbstractArray{xT,5}, dy::AbstractArray{yT,5}, w::AbstractArray{wT,5}, cdims::DenseConvDims; - alpha::xT=xT(1), beta=false) where {xT, yT, wT} + α::xT=xT(1), β=false) where {xT, yT, wT} w = transpose_swapbatch(w[end:-1:1, end:-1:1, end:-1:1, :, :]) dy = predilate(dy, stride(cdims)) ctdims = DenseConvDims(dy, w; padding=transpose_pad(cdims), dilation=dilation(cdims), flipkernel=flipkernel(cdims)) - dx = conv_direct!(dx, dy, w, ctdims; alpha=alpha, beta=beta) + dx = conv_direct!(dx, dy, w, ctdims; α=α, β=β) return dx end """ - ∇conv_filter_direct!(dw, x, dy, cdims; alpha=1, beta=0) + ∇conv_filter_direct!(dw, x, dy, cdims; α=1, β=0) Calculate the gradient imposed upon `w` in the convolution `y = x * w`. """ @@ -173,12 +173,12 @@ Calculate the gradient imposed upon `w` in the convolution `y = x * w`. 
function ∇conv_filter_direct!(dw::AbstractArray{wT,5}, x::AbstractArray{xT,5}, dy::AbstractArray{yT,5}, cdims::DenseConvDims; - alpha::wT=wT(1), beta=false) where {xT, yT, wT} + α::wT=wT(1), β=false) where {xT, yT, wT} x = transpose_swapbatch(x[end:-1:1, end:-1:1, end:-1:1, :, :]) dy = transpose_swapbatch(predilate(dy, stride(cdims))) ctdims = DenseConvDims(dy, x; padding=transpose_pad(cdims), stride=dilation(cdims)) - conv_direct!(dw, dy, x, ctdims; alpha=alpha, beta=beta) + conv_direct!(dw, dy, x, ctdims; α=α, β=β) if flipkernel(cdims) dw .= dw[end:-1:1, end:-1:1, end:-1:1, :, :] end diff --git a/src/impl/conv_im2col.jl b/src/impl/conv_im2col.jl index e06231325..8f4f83492 100644 --- a/src/impl/conv_im2col.jl +++ b/src/impl/conv_im2col.jl @@ -11,13 +11,13 @@ end end """ - conv_im2col!(y, x, w, cdims, col=similar(x); alpha=1, beta=0) + conv_im2col!(y, x, w, cdims, col=similar(x); α=1, β=0) Perform a convolution using im2col and GEMM, store the result in `y`. The kwargs -`alpha` and `beta` control accumulation behavior; internally this operation is -implemented as a matrix multiply that boils down to `y = alpha * x * w + beta * y`, thus -by setting `beta` to a nonzero value, multiple results can be accumulated into `y`, or -by setting `alpha` to a nonunitary value, various gain factors can be applied. +`α` and `β` control accumulation behavior; internally this operation is +implemented as a matrix multiply that boils down to `y = α * x * w + β * y`, thus +by setting `β` to a non-zero value, multiple results can be accumulated into `y`, or +by setting `α` to a non-unitary value, various gain factors can be applied. Note for the particularly performance-minded, you can provide a pre-allocated `col`, which should eliminate any need for large allocations within this method. @@ -26,7 +26,7 @@ function conv_im2col!( y::AbstractArray{T,5}, x::AbstractArray{T,5}, w::AbstractArray{T,5}, cdims::DenseConvDims; col::AbstractArray{T,3}=similar(x, im2col_dims(cdims)), - alpha::T=T(1), beta::T=T(0)) where {T} + α::T=T(1), β::T=T(0)) where {T} check_dims(size(x), size(w), size(y), cdims) # COL * W -> Y @@ -39,7 +39,7 @@ function conv_im2col!( # In english, we're grabbing each input patch and laying them out along # the M dimension in `col`, so that the GEMM call below multiplies each # kernel (which is kernel_h * kernel_w * channels_in elments long) is - # dotproducted with that input patch, effectively computing a convolution + # dot-producted with that input patch, effectively computing a convolution # in a somewhat memory-wasteful but easily-computed way (since we already # have an extremely highly-optimized GEMM call available in BLAS). M = prod(output_size(cdims)) @@ -55,14 +55,14 @@ function conv_im2col!( col_ptr = pointer(col_slice) w_ptr = pointer(w) y_ptr = pointer(y, (batch_idx - 1)*M*N + 1) - gemm!(Val(false), Val(false), M, N, K, alpha, col_ptr, w_ptr, beta, y_ptr) + gemm!(Val(false), Val(false), M, N, K, α, col_ptr, w_ptr, β, y_ptr) end end return y end """ - ∇conv_filter_im2col!(dw, x, dy, cdims, col=similar(dw); alpha=1, beta=0) + ∇conv_filter_im2col!(dw, x, dy, cdims, col=similar(dw); α=1, β=0) Conv backward pass onto the weights using im2col and GEMM; stores the result in `dw`. See the documentation for `conv_im2col!()` for explanation of optional parameters. 
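# --- Illustrative sketch, not part of this patch ---------------------------
# A minimal, assumed usage example for the im2col path documented above,
# re-using a pre-allocated `col` buffer across calls as the `conv_im2col!`
# docstring suggests.  The array sizes and the direct call to the internal
# NNlib.conv_im2col! entry point are assumptions for illustration only.
using NNlib

x = rand(Float32, 8, 8, 1, 2, 1)     # (width, height, depth, channels_in, batch)
w = rand(Float32, 3, 3, 1, 2, 4)     # (kw, kh, kd, channels_in, channels_out)
cdims = DenseConvDims(x, w)
y = zeros(Float32, NNlib.output_size(cdims)..., 4, 1)

col = similar(x, NNlib.im2col_dims(cdims))   # allocated once, reused below
for _ in 1:10
    NNlib.conv_im2col!(y, x, w, cdims; col=col)
end
# ---------------------------------------------------------------------------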
@@ -71,7 +71,7 @@ function ∇conv_filter_im2col!( dw::AbstractArray{T,5}, x::AbstractArray{T,5}, dy::AbstractArray{T,5}, cdims::DenseConvDims; col::AbstractArray{T,3} = similar(dw, im2col_dims(cdims)), - alpha::T=T(1), beta::T=T(0)) where {T} + α::T=T(1), β::T=T(0)) where {T} check_dims(size(x), size(dw), size(dy), cdims) # COL' * dY -> dW @@ -104,18 +104,18 @@ function ∇conv_filter_im2col!( col_ptr = pointer(col_slice) dy_ptr = pointer(dy,(batch_idx - 1)*K*N + 1) dw_ptr = pointer(dw) - gemm!(Val(true), Val(false), M, N, K, alpha, col_ptr, dy_ptr, beta, dw_ptr) + gemm!(Val(true), Val(false), M, N, K, α, col_ptr, dy_ptr, β, dw_ptr) end - # Because we accumulate over batches in this loop, we must set `beta` equal + # Because we accumulate over batches in this loop, we must set `β` equal # to `1.0` from this point on. - beta = T(1) + β = T(1) end return dw end """ - ∇conv_data_im2col!(dx, w, dy, cdims, col=similar(dx); alpha=1, beta=0) + ∇conv_data_im2col!(dx, w, dy, cdims, col=similar(dx); α=1, β=0) Conv2d backward pass onto the input using im2col and GEMM; stores the result in `dx`. See the documentation for `conv_im2col!()` for explanation of other parameters. @@ -124,7 +124,7 @@ function ∇conv_data_im2col!( dx::AbstractArray{T,5}, dy::AbstractArray{T,5}, w::AbstractArray{T,5}, cdims::DenseConvDims; col::AbstractArray{T,3} = similar(dx, im2col_dims(cdims)), - alpha::T=T(1), beta::T=T(0)) where {T} + α::T=T(1), β::T=T(0)) where {T} check_dims(size(dx), size(w), size(dy), cdims) # dY W' -> dX @@ -154,7 +154,7 @@ function ∇conv_data_im2col!( dy_ptr = pointer(dy, (batch_idx - 1)*M*K + 1) w_ptr = pointer(w) col_ptr = pointer(col_slice) - gemm!(Val(false), Val(true), M, N, K, alpha, dy_ptr, w_ptr, T(0), col_ptr) + gemm!(Val(false), Val(true), M, N, K, α, dy_ptr, w_ptr, T(0), col_ptr) end col2im!(view(dx, :, :, :, :, batch_idx), col_slice, cdims) end @@ -162,9 +162,6 @@ function ∇conv_data_im2col!( end - - - """ im2col!(col, x, cdims) @@ -233,7 +230,7 @@ function im2col!(col::AbstractArray{T,2}, x::AbstractArray{T,4}, end end - + # For each "padded region", we run the fully general version @inbounds for (w_region, h_region, d_region) in padded_regions for c in 1:C_in, diff --git a/src/impl/depthwiseconv_direct.jl b/src/impl/depthwiseconv_direct.jl index b6822a488..40197f386 100644 --- a/src/impl/depthwiseconv_direct.jl +++ b/src/impl/depthwiseconv_direct.jl @@ -1,7 +1,7 @@ ## This file contains direct Julia implementations of depwthwise convolutions """ - depthwiseconv_direct!(y, x, w, cdims; alpha=1, beta=0) + depthwiseconv_direct!(y, x, w, cdims; α=1, β=0) Direct depthwise convolution implementation; used for debugging, tests, and mixing/ matching of strange datatypes within a single convolution. Uses naive nested for loop @@ -20,7 +20,7 @@ See the docstring for `conv_direct!()` for more on the optional parameters. 
""" function depthwiseconv_direct!(y::AbstractArray{yT,5}, x::AbstractArray{xT,5}, w::AbstractArray{wT,5}, cdims::DepthwiseConvDims; - alpha::yT=yT(1), beta=false) where {yT, xT, wT} + α::yT=yT(1), β=false) where {yT, xT, wT} check_dims(size(x), size(w), size(y), cdims) width, height, depth = input_size(cdims) @@ -69,7 +69,7 @@ function depthwiseconv_direct!(y::AbstractArray{yT,5}, x::AbstractArray{xT,5}, c_mult, c_in] dotprod = muladd(x_val, w_val, dotprod) end - y[w_idx, h_idx, d_idx, c_out, batch] = alpha*dotprod + beta*y[w_idx, h_idx, d_idx, c_out, batch] + y[w_idx, h_idx, d_idx, c_out, batch] = α*dotprod + β*y[w_idx, h_idx, d_idx, c_out, batch] end # Next, do potentially-padded regions: @@ -114,14 +114,14 @@ function depthwiseconv_direct!(y::AbstractArray{yT,5}, x::AbstractArray{xT,5}, end end - y[w_idx, h_idx, d_idx, c_out, batch] = alpha*dotprod + beta*y[w_idx, h_idx, d_idx, c_out, batch] + y[w_idx, h_idx, d_idx, c_out, batch] = α*dotprod + β*y[w_idx, h_idx, d_idx, c_out, batch] end return y end """ - ∇depthwiseconv_data_direct!(dx, dy, w, cdims; alpha=1, beta=0) + ∇depthwiseconv_data_direct!(dx, dy, w, cdims; α=1, β=0) Calculate the gradient imposed upon `x` in the depthwise convolution `y = x * w`. We make use of the fact that a depthwise convolution is equivalent to `C_in` separate @@ -135,7 +135,7 @@ for each batch and channel independently. function ∇depthwiseconv_data_direct!( dx::AbstractArray{xT,5}, dy::AbstractArray{yT,5}, w::AbstractArray{wT,5}, cdims::DepthwiseConvDims; - alpha::xT=xT(1), beta=false) where {xT, yT, wT} + α::xT=xT(1), β=false) where {xT, yT, wT} # We do a separate convolution for each channel in x @inbounds for cidx in 1:channels_in(cdims) # For this batch and in-channel, we have a normal transposed convolution @@ -153,13 +153,13 @@ function ∇depthwiseconv_data_direct!( ) ∇conv_data_direct!(dx_slice, dy_slice, w_slice, cdims_slice; - alpha=alpha, beta=beta) + α=α, β=β) end return dx end """ - ∇depthwiseconv_filter_direct!(dw, x, dy, cdims; alpha=1, beta=0) + ∇depthwiseconv_filter_direct!(dw, x, dy, cdims; α=1, β=0) Calculate the gradient imposed upon `w` in the depthwise convolution `y = x * w`. """ @@ -168,7 +168,7 @@ Calculate the gradient imposed upon `w` in the depthwise convolution `y = x * w` function ∇depthwiseconv_filter_direct!( dw::AbstractArray{wT,5}, x::AbstractArray{xT,5}, dy::AbstractArray{yT,5}, cdims::DepthwiseConvDims; - alpha::wT=wT(1),beta=false) where {xT, yT, wT} + α::wT=wT(1),β=false) where {xT, yT, wT} # We do a separate convolution for each channel in x @inbounds for cidx in 1:channels_in(cdims) # For this batch and in-channel, we have a normal transposed convolution @@ -186,7 +186,7 @@ function ∇depthwiseconv_filter_direct!( ) ∇conv_filter_direct!(dw_slice, x_slice, dy_slice, cdims_slice; - alpha=alpha, beta=beta) + α=α, β=β) dw[:, :, :, :, cidx:cidx] .= permutedims(dw_slice, (1, 2, 3, 5, 4)) end return dw diff --git a/src/impl/depthwiseconv_im2col.jl b/src/impl/depthwiseconv_im2col.jl index 145dc9961..ec4b42aaf 100644 --- a/src/impl/depthwiseconv_im2col.jl +++ b/src/impl/depthwiseconv_im2col.jl @@ -2,7 +2,7 @@ """ - depthwiseconv_im2col!(y, x, w, cdims, col=similar(x); alpha=1, beta=0) + depthwiseconv_im2col!(y, x, w, cdims, col=similar(x); α=1, β=0) Perform a depthwise convolution using im2col and GEMM, store the result in `y`. 
@@ -14,7 +14,7 @@ function depthwiseconv_im2col!( y::AbstractArray{T,5}, x::AbstractArray{T,5}, w::AbstractArray{T,5}, cdims::DepthwiseConvDims; col::AbstractArray{T,3} = similar(x, im2col_dims(cdims)), - alpha::T=T(1), beta::T=T(0)) where T + α::T=T(1), β::T=T(0)) where T check_dims(size(x), size(w), size(y), cdims) # This functions exactly the same as conv_im2col!(), except that we shard the @@ -40,7 +40,7 @@ function depthwiseconv_im2col!( col_ptr = pointer(col_slice, (c_in-1)*M*K+1) w_ptr = pointer(w, (c_in-1)*K*N+1) y_ptr = pointer(y, ((batch_idx - 1)*channels_in(cdims) + c_in - 1)*M*N + 1) - gemm!(Val(false), Val(false), M, N, K, alpha, col_ptr, w_ptr, beta, y_ptr) + gemm!(Val(false), Val(false), M, N, K, α, col_ptr, w_ptr, β, y_ptr) end end end @@ -48,7 +48,7 @@ function depthwiseconv_im2col!( end """ - ∇depthwiseconv_filter_im2col!(dw, w, dy, cdims, col=similar(dw); alpha=1, beta) + ∇depthwiseconv_filter_im2col!(dw, w, dy, cdims, col=similar(dw); α=1, β) Depthwise conv2d backward pass onto the weights using im2col and GEMM. See the documentation for `conv_im2col!()` for explanation of optional parameters. @@ -59,7 +59,7 @@ function ∇depthwiseconv_filter_im2col!( dw::AbstractArray{T,5}, x::AbstractArray{T,5}, dy::AbstractArray{T,5}, cdims::DepthwiseConvDims; col::AbstractArray{T,3} = similar(dw, im2col_dims(cdims)), - alpha::T=T(1), beta::T=T(0)) where T + α::T=T(1), β::T=T(0)) where T check_dims(size(x), size(dw), size(dy), cdims) M = prod(kernel_size(cdims)) @@ -78,19 +78,19 @@ function ∇depthwiseconv_filter_im2col!( col_ptr = pointer(col_slice, (c_in - 1)*M*K + 1) dy_ptr = pointer(dy, (batch_idx - 1)*N*K*channels_in(cdims) + (c_in - 1)*K*N + 1) dw_ptr = pointer(dw, (c_in - 1)*M*N + 1) - gemm!(Val(true), Val(false), M, N, K, alpha, col_ptr, dy_ptr, beta, dw_ptr) + gemm!(Val(true), Val(false), M, N, K, α, col_ptr, dy_ptr, β, dw_ptr) end end - # Because we accumulate over batches in this loop, we must set `beta` equal + # Because we accumulate over batches in this loop, we must set `β` equal # to `1.0` from this point on. - beta = T(1) + β = T(1) end return dw end """ - depthwiseconv2d_Δx_im2col!(dx, w, dy, cdims, col=similar(dx); alpha=1, beta=0) + depthwiseconv2d_Δx_im2col!(dx, w, dy, cdims, col=similar(dx); α=1, β=0) Depwthwise conv2d backward pass onto the input using im2col and GEMM. See the documentation for `conv_im2col!()` for explanation of optional parameters. 
@@ -101,7 +101,7 @@ function ∇depthwiseconv_data_im2col!( dx::AbstractArray{T,5}, dy::AbstractArray{T,5}, w::AbstractArray{T,5}, cdims::DepthwiseConvDims; col::AbstractArray{T,3} = similar(dx, im2col_dims(cdims)), - alpha::T=T(1), beta::T=T(0)) where T + α::T=T(1), β::T=T(0)) where T check_dims(size(dx), size(w), size(dy), cdims) M = prod(output_size(cdims)) @@ -119,7 +119,7 @@ function ∇depthwiseconv_data_im2col!( dy_ptr = pointer(dy, (batch_idx - 1)*M*K*channels_in(cdims)+(cidx - 1)*K*M + 1) w_ptr = pointer(w, (cidx - 1)*K*N + 1) col_ptr = pointer(col_slice, (cidx - 1)*M*N + 1) - gemm!(Val(false), Val(true), M, N, K, alpha, dy_ptr, w_ptr, T(0), col_ptr) + gemm!(Val(false), Val(true), M, N, K, α, dy_ptr, w_ptr, T(0), col_ptr) end end col2im!(view(dx, :, :, :, :, batch_idx), col_slice, cdims) diff --git a/src/impl/pooling_direct.jl b/src/impl/pooling_direct.jl index f95ab32f5..e5bc838ae 100644 --- a/src/impl/pooling_direct.jl +++ b/src/impl/pooling_direct.jl @@ -5,7 +5,7 @@ using Statistics for name in (:max, :mean) @eval function $((Symbol("$(name)pool_direct!")))( y::AbstractArray{T,5}, x::AbstractArray{T,5}, - pdims::PoolDims; alpha::T = T(1), beta::T = T(0)) where {T} + pdims::PoolDims; α::T = T(1), β::T = T(0)) where {T} check_dims(size(x), size(y), pdims) width, height, depth = input_size(pdims) @@ -24,9 +24,9 @@ for name in (:max, :mean) @inline project(idx, stride, pad) = (idx - 1)*stride - pad + 1 # If we're doing mean pooling, we represent division by kernel size by rolling it - # into the `alpha` multiplier. + # into the `α` multiplier. if $(name == :mean) - alpha = alpha/prod(kernel_size(pdims)) + α = α/prod(kernel_size(pdims)) end # Each loop, we initialize `m` to something, set that here. @@ -66,7 +66,7 @@ for name in (:max, :mean) error("Unimplemented codegen path") end end - y[w, h, d, c, batch_idx] = alpha*m + beta*y[w, h, d, c, batch_idx] + y[w, h, d, c, batch_idx] = α*m + β*y[w, h, d, c, batch_idx] end # Next, the padded regions @@ -111,7 +111,7 @@ for name in (:max, :mean) end end end - y[w, h, d, c, batch_idx] = alpha*m + beta*y[w, h, d, c, batch_idx] + y[w, h, d, c, batch_idx] = α*m + β*y[w, h, d, c, batch_idx] end end @@ -124,7 +124,7 @@ for name in (:max, :mean) @eval function $((Symbol("∇$(name)pool_direct!")))( dx::AbstractArray{T,5}, dy::AbstractArray{T,5}, y::AbstractArray{T,5}, x::AbstractArray{T,5}, - pdims::PoolDims; alpha::T = T(1), beta::T = T(0)) where {T} + pdims::PoolDims; α::T = T(1), β::T = T(0)) where {T} check_dims(size(x), size(dy), pdims) width, height, depth = input_size(pdims) @@ -143,9 +143,9 @@ for name in (:max, :mean) @inline project(idx, stride, pad) = (idx - 1)*stride - pad + 1 # If we're doing mean pooling, we represent division by kernel size by rolling - # it into the `alpha` multiplier. + # it into the `α` multiplier. if $(name == :mean) - alpha = alpha/prod(kernel_size(pdims)) + α = α/prod(kernel_size(pdims)) end # Start with the central region @@ -176,15 +176,15 @@ for name in (:max, :mean) # If it's equal; this is the one we chose. We only choose one per # kernel window, all other elements of dx must be zero. if y_idx == x[x_idxs...] && !maxpool_already_chose - dx[x_idxs...] = dy_idx*alpha + beta*dx[x_idxs...] + dx[x_idxs...] = dy_idx*α + β*dx[x_idxs...] maxpool_already_chose = true - # Maxpooling does not support `beta` right now. :( + # Maxpooling does not support `β` right now. :( #else - # dx[x_idxs...] = T(0) + beta*dx[x_idxs...] + # dx[x_idxs...] = T(0) + β*dx[x_idxs...] 
end elseif $(name == :mean) # Either does meanpool :( - dx[x_idxs...] = dy_idx*alpha + dx[x_idxs...] + dx[x_idxs...] = dy_idx*α + dx[x_idxs...] else error("Unimplemented codegen path") end @@ -228,13 +228,13 @@ for name in (:max, :mean) x_idxs = (input_kw, input_kh, input_kd, c, batch_idx) if $(name == :max) if y_idx == x[x_idxs...] && !maxpool_already_chose - dx[x_idxs...] = dy_idx*alpha + beta*dx[x_idxs...] + dx[x_idxs...] = dy_idx*α + β*dx[x_idxs...] maxpool_already_chose = true #else - # dx[x_idxs...] = T(0) + beta*dx[x_idxs...] + # dx[x_idxs...] = T(0) + β*dx[x_idxs...] end elseif $(name == :mean) - dx[x_idxs...] += dy_idx*alpha + beta*dx[x_idxs...] + dx[x_idxs...] += dy_idx*α + β*dx[x_idxs...] else error("Unimplemented codegen path") end From d74860d8ea97a6cb90d6ee7e3cf33f34d6e72c24 Mon Sep 17 00:00:00 2001 From: Adarshkumar712 Date: Mon, 6 Apr 2020 14:41:51 +0530 Subject: [PATCH 2/3] remove unicode for alpha and beta --- src/gemm.jl | 18 +++++++++--------- src/impl/conv_direct.jl | 30 +++++++++++++++--------------- src/impl/conv_im2col.jl | 30 +++++++++++++++--------------- src/impl/depthwiseconv_direct.jl | 20 ++++++++++---------- src/impl/depthwiseconv_im2col.jl | 22 +++++++++++----------- src/impl/pooling_direct.jl | 30 +++++++++++++++--------------- 6 files changed, 75 insertions(+), 75 deletions(-) diff --git a/src/gemm.jl b/src/gemm.jl index e80aa1908..3a66b3651 100644 --- a/src/gemm.jl +++ b/src/gemm.jl @@ -9,9 +9,9 @@ using LinearAlgebra.BLAS: libblas, BlasInt, @blasfunc Low-level gemm!() call with pointers, borrowed from Knet.jl -Calculates `C = α*op(A)*op(B) + β*C`, where: +Calculates `C = alpha*op(A)*op(B) + beta*C`, where: - `transA` and `transB` set `op(X)` to be either `identity()` or `transpose()` - - α and β are scalars + - alpha and beta are scalars - op(A) is an (M, K) matrix - op(B) is a (K, N) matrix - C is an (M, N) matrix. 
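# --- Illustrative sketch, not part of this patch ---------------------------
# An assumed usage example for the low-level gemm! wrapper documented above,
# showing the `C = alpha*op(A)*op(B) + beta*C` contract on plain column-major
# matrices.  Calling the internal NNlib.gemm! with raw pointers is an
# assumption for illustration; it is not public API.
using NNlib, LinearAlgebra

M, K, N = 3, 4, 2
A = rand(Float32, M, K)
B = rand(Float32, K, N)
C = ones(Float32, M, N)

GC.@preserve A B C begin
    # No transposition, alpha = 2, beta = 1:  C <- 2*A*B + C
    NNlib.gemm!(Val(false), Val(false), M, N, K,
                2.0f0, pointer(A), pointer(B), 1.0f0, pointer(C))
end
@assert C ≈ 2 .* (A * B) .+ 1
# ---------------------------------------------------------------------------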
@@ -29,8 +29,8 @@ for (gemm, elt) in gemm_datatype_mappings @eval begin @inline function gemm!(transA::Val, transB::Val, M::Int, N::Int, K::Int, - α::$(elt), A::Ptr{$elt}, B::Ptr{$elt}, - β::$(elt), C::Ptr{$elt}) + alpha::$(elt), A::Ptr{$elt}, B::Ptr{$elt}, + beta::$(elt), C::Ptr{$elt}) # Convert our compile-time transpose marker to a char for BLAS convtrans(V::Val{false}) = 'N' convtrans(V::Val{true}) = 'T' @@ -52,7 +52,7 @@ for (gemm, elt) in gemm_datatype_mappings Ptr{$elt}, Ref{BlasInt}, Ref{$elt}, Ptr{$elt}, Ref{BlasInt}), convtrans(transA), convtrans(transB), M, N, K, - α, A, lda, B, ldb, β, C, ldc) + alpha, A, lda, B, ldb, beta, C, ldc) end end end @@ -61,10 +61,10 @@ for (gemm, elt) in gemm_datatype_mappings @eval begin @inline function batched_gemm!(transA::AbstractChar, transB::AbstractChar, - α::($elt), + alpha::($elt), A::AbstractArray{$elt, 3}, B::AbstractArray{$elt, 3}, - β::($elt), + beta::($elt), C::AbstractArray{$elt, 3}) @assert !Base.has_offset_axes(A, B, C) @assert size(A, 3) == size(B, 3) == size(C, 3) "batch size mismatch" @@ -90,8 +90,8 @@ for (gemm, elt) in gemm_datatype_mappings Ptr{$elt}, Ref{BlasInt}, Ref{$elt}, Ptr{$elt}, Ref{BlasInt}), transA, transB, m, n, - ka, α, ptrA, max(1,Base.stride(A,2)), - ptrB, max(1,Base.stride(B,2)), β, ptrC, + ka, alpha, ptrA, max(1,Base.stride(A,2)), + ptrB, max(1,Base.stride(B,2)), beta, ptrC, max(1,Base.stride(C,2))) ptrA += size(A, 1) * size(A, 2) * sizeof($elt) diff --git a/src/impl/conv_direct.jl b/src/impl/conv_direct.jl index d02fbebf2..617d69103 100644 --- a/src/impl/conv_direct.jl +++ b/src/impl/conv_direct.jl @@ -18,7 +18,7 @@ function clamp_hi(x, w, L) end """ - conv_direct!(y, x, w, cdims; α=1, β=0) + conv_direct!(y, x, w, cdims; alpha=1, beta=0) Direct convolution implementation; used for debugging, tests, and mixing/matching of strange datatypes within a single convolution. Uses naive nested for loop implementation @@ -29,14 +29,14 @@ so that if the user really wants to convolve an image of `UInt8`'s with a `Float kernel, storing the result in a `Float32` output, there is at least a function call for that madness. -The keyword arguments `α` and `β` control accumulation behavior; this function -calculates `y = α * x * w + β * y`, therefore by setting `β` to a non-zero +The keyword arguments `alpha` and `beta` control accumulation behavior; this function +calculates `y = alpha * x * w + beta * y`, therefore by setting `beta` to a non-zero value, the user is able to accumulate values into a pre-allocated `y` buffer, or by -setting `α` to a non-unitary value, an arbitrary gain factor can be applied. +setting `alpha` to a non-unitary value, an arbitrary gain factor can be applied. -By defaulting `β` to `false`, we make use of the Bradbury promotion trick to override +By defaulting `beta` to `false`, we make use of the Bradbury promotion trick to override `NaN`'s that may pre-exist within our output buffer, as `false*NaN == 0.0`, whereas -`0.0*NaN == NaN`. Only set `β` if you are certain that none of the elements within +`0.0*NaN == NaN`. Only set `beta` if you are certain that none of the elements within `y` are `NaN`. The basic implementation performs 3-dimensional convolution; 1-dimensional and 2- @@ -47,7 +47,7 @@ conv_direct! 
function conv_direct!(y::AbstractArray{yT,5}, x::AbstractArray{xT,5}, w::AbstractArray{wT,5}, cdims::DenseConvDims; - α::yT = yT(1), β = false) where {yT, xT, wT} + alpha::yT = yT(1), beta = false) where {yT, xT, wT} check_dims(size(x), size(w), size(y), cdims) width, height, depth = input_size(cdims) @@ -95,7 +95,7 @@ function conv_direct!(y::AbstractArray{yT,5}, x::AbstractArray{xT,5}, c_in, c_out] dotprod = muladd(x_val, w_val, dotprod) end - y[w_idx, h_idx, d_idx, c_out, batch] = α*dotprod + β*y[w_idx, h_idx, d_idx, c_out, batch] + y[w_idx, h_idx, d_idx, c_out, batch] = alpha*dotprod + beta*y[w_idx, h_idx, d_idx, c_out, batch] end # Next, do potentially-padded regions: @@ -138,7 +138,7 @@ function conv_direct!(y::AbstractArray{yT,5}, x::AbstractArray{xT,5}, end end - y[w_idx, h_idx, d_idx, c_out, batch] = α*dotprod + β*y[w_idx, h_idx, d_idx, c_out, batch] + y[w_idx, h_idx, d_idx, c_out, batch] = alpha*dotprod + beta*y[w_idx, h_idx, d_idx, c_out, batch] end return y @@ -146,7 +146,7 @@ end ## Gradient definitions """ - ∇conv_data_direct!(dx, dy, w, cdims; α=1, β=0) + ∇conv_data_direct!(dx, dy, w, cdims; alpha=1, beta=0) Calculate the gradient imposed upon `x` in the convolution `y = x * w`. """ @@ -154,18 +154,18 @@ Calculate the gradient imposed upon `x` in the convolution `y = x * w`. function ∇conv_data_direct!(dx::AbstractArray{xT,5}, dy::AbstractArray{yT,5}, w::AbstractArray{wT,5}, cdims::DenseConvDims; - α::xT=xT(1), β=false) where {xT, yT, wT} + alpha::xT=xT(1), beta=false) where {xT, yT, wT} w = transpose_swapbatch(w[end:-1:1, end:-1:1, end:-1:1, :, :]) dy = predilate(dy, stride(cdims)) ctdims = DenseConvDims(dy, w; padding=transpose_pad(cdims), dilation=dilation(cdims), flipkernel=flipkernel(cdims)) - dx = conv_direct!(dx, dy, w, ctdims; α=α, β=β) + dx = conv_direct!(dx, dy, w, ctdims; alpha=alpha, beta=beta) return dx end """ - ∇conv_filter_direct!(dw, x, dy, cdims; α=1, β=0) + ∇conv_filter_direct!(dw, x, dy, cdims; alpha=1, beta=0) Calculate the gradient imposed upon `w` in the convolution `y = x * w`. """ @@ -173,12 +173,12 @@ Calculate the gradient imposed upon `w` in the convolution `y = x * w`. function ∇conv_filter_direct!(dw::AbstractArray{wT,5}, x::AbstractArray{xT,5}, dy::AbstractArray{yT,5}, cdims::DenseConvDims; - α::wT=wT(1), β=false) where {xT, yT, wT} + alpha::wT=wT(1), beta=false) where {xT, yT, wT} x = transpose_swapbatch(x[end:-1:1, end:-1:1, end:-1:1, :, :]) dy = transpose_swapbatch(predilate(dy, stride(cdims))) ctdims = DenseConvDims(dy, x; padding=transpose_pad(cdims), stride=dilation(cdims)) - conv_direct!(dw, dy, x, ctdims; α=α, β=β) + conv_direct!(dw, dy, x, ctdims; alpha=alpha, beta=beta) if flipkernel(cdims) dw .= dw[end:-1:1, end:-1:1, end:-1:1, :, :] end diff --git a/src/impl/conv_im2col.jl b/src/impl/conv_im2col.jl index 8f4f83492..eb8f36ad5 100644 --- a/src/impl/conv_im2col.jl +++ b/src/impl/conv_im2col.jl @@ -11,13 +11,13 @@ end end """ - conv_im2col!(y, x, w, cdims, col=similar(x); α=1, β=0) + conv_im2col!(y, x, w, cdims, col=similar(x); alpha=1, beta=0) Perform a convolution using im2col and GEMM, store the result in `y`. The kwargs -`α` and `β` control accumulation behavior; internally this operation is -implemented as a matrix multiply that boils down to `y = α * x * w + β * y`, thus -by setting `β` to a non-zero value, multiple results can be accumulated into `y`, or -by setting `α` to a non-unitary value, various gain factors can be applied. 
+`alpha` and `beta` control accumulation behavior; internally this operation is +implemented as a matrix multiply that boils down to `y = alpha * x * w + beta * y`, thus +by setting `beta` to a non-zero value, multiple results can be accumulated into `y`, or +by setting `alpha` to a non-unitary value, various gain factors can be applied. Note for the particularly performance-minded, you can provide a pre-allocated `col`, which should eliminate any need for large allocations within this method. @@ -26,7 +26,7 @@ function conv_im2col!( y::AbstractArray{T,5}, x::AbstractArray{T,5}, w::AbstractArray{T,5}, cdims::DenseConvDims; col::AbstractArray{T,3}=similar(x, im2col_dims(cdims)), - α::T=T(1), β::T=T(0)) where {T} + alpha::T=T(1), beta::T=T(0)) where {T} check_dims(size(x), size(w), size(y), cdims) # COL * W -> Y @@ -55,14 +55,14 @@ function conv_im2col!( col_ptr = pointer(col_slice) w_ptr = pointer(w) y_ptr = pointer(y, (batch_idx - 1)*M*N + 1) - gemm!(Val(false), Val(false), M, N, K, α, col_ptr, w_ptr, β, y_ptr) + gemm!(Val(false), Val(false), M, N, K, alpha, col_ptr, w_ptr, beta, y_ptr) end end return y end """ - ∇conv_filter_im2col!(dw, x, dy, cdims, col=similar(dw); α=1, β=0) + ∇conv_filter_im2col!(dw, x, dy, cdims, col=similar(dw); alpha=1, beta=0) Conv backward pass onto the weights using im2col and GEMM; stores the result in `dw`. See the documentation for `conv_im2col!()` for explanation of optional parameters. @@ -71,7 +71,7 @@ function ∇conv_filter_im2col!( dw::AbstractArray{T,5}, x::AbstractArray{T,5}, dy::AbstractArray{T,5}, cdims::DenseConvDims; col::AbstractArray{T,3} = similar(dw, im2col_dims(cdims)), - α::T=T(1), β::T=T(0)) where {T} + alpha::T=T(1), beta::T=T(0)) where {T} check_dims(size(x), size(dw), size(dy), cdims) # COL' * dY -> dW @@ -104,18 +104,18 @@ function ∇conv_filter_im2col!( col_ptr = pointer(col_slice) dy_ptr = pointer(dy,(batch_idx - 1)*K*N + 1) dw_ptr = pointer(dw) - gemm!(Val(true), Val(false), M, N, K, α, col_ptr, dy_ptr, β, dw_ptr) + gemm!(Val(true), Val(false), M, N, K, alpha, col_ptr, dy_ptr, beta, dw_ptr) end - # Because we accumulate over batches in this loop, we must set `β` equal + # Because we accumulate over batches in this loop, we must set `beta` equal # to `1.0` from this point on. - β = T(1) + beta = T(1) end return dw end """ - ∇conv_data_im2col!(dx, w, dy, cdims, col=similar(dx); α=1, β=0) + ∇conv_data_im2col!(dx, w, dy, cdims, col=similar(dx); alpha=1, beta=0) Conv2d backward pass onto the input using im2col and GEMM; stores the result in `dx`. See the documentation for `conv_im2col!()` for explanation of other parameters. 
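# --- Illustrative sketch, not part of this patch ---------------------------
# An assumed example of the `alpha`/`beta` accumulation behaviour described
# above: with `beta = 1` the existing contents of `y` are kept and the new
# result is added on top, while `alpha` applies a gain factor.  The direct
# call to the internal NNlib.conv_im2col! entry point is an assumption for
# illustration only.
using NNlib

x = rand(Float32, 8, 8, 1, 2, 1)
w = rand(Float32, 3, 3, 1, 2, 4)
cdims = DenseConvDims(x, w)

y = NNlib.conv_im2col!(zeros(Float32, NNlib.output_size(cdims)..., 4, 1), x, w, cdims)
y0 = copy(y)

# Accumulate a second pass: y <- 2*(x * w) + y
NNlib.conv_im2col!(y, x, w, cdims; alpha=2.0f0, beta=1.0f0)
@assert y ≈ 3 .* y0
# ---------------------------------------------------------------------------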
@@ -124,7 +124,7 @@ function ∇conv_data_im2col!( dx::AbstractArray{T,5}, dy::AbstractArray{T,5}, w::AbstractArray{T,5}, cdims::DenseConvDims; col::AbstractArray{T,3} = similar(dx, im2col_dims(cdims)), - α::T=T(1), β::T=T(0)) where {T} + alpha::T=T(1), beta::T=T(0)) where {T} check_dims(size(dx), size(w), size(dy), cdims) # dY W' -> dX @@ -154,7 +154,7 @@ function ∇conv_data_im2col!( dy_ptr = pointer(dy, (batch_idx - 1)*M*K + 1) w_ptr = pointer(w) col_ptr = pointer(col_slice) - gemm!(Val(false), Val(true), M, N, K, α, dy_ptr, w_ptr, T(0), col_ptr) + gemm!(Val(false), Val(true), M, N, K, alpha, dy_ptr, w_ptr, T(0), col_ptr) end col2im!(view(dx, :, :, :, :, batch_idx), col_slice, cdims) end diff --git a/src/impl/depthwiseconv_direct.jl b/src/impl/depthwiseconv_direct.jl index 40197f386..b6822a488 100644 --- a/src/impl/depthwiseconv_direct.jl +++ b/src/impl/depthwiseconv_direct.jl @@ -1,7 +1,7 @@ ## This file contains direct Julia implementations of depwthwise convolutions """ - depthwiseconv_direct!(y, x, w, cdims; α=1, β=0) + depthwiseconv_direct!(y, x, w, cdims; alpha=1, beta=0) Direct depthwise convolution implementation; used for debugging, tests, and mixing/ matching of strange datatypes within a single convolution. Uses naive nested for loop @@ -20,7 +20,7 @@ See the docstring for `conv_direct!()` for more on the optional parameters. """ function depthwiseconv_direct!(y::AbstractArray{yT,5}, x::AbstractArray{xT,5}, w::AbstractArray{wT,5}, cdims::DepthwiseConvDims; - α::yT=yT(1), β=false) where {yT, xT, wT} + alpha::yT=yT(1), beta=false) where {yT, xT, wT} check_dims(size(x), size(w), size(y), cdims) width, height, depth = input_size(cdims) @@ -69,7 +69,7 @@ function depthwiseconv_direct!(y::AbstractArray{yT,5}, x::AbstractArray{xT,5}, c_mult, c_in] dotprod = muladd(x_val, w_val, dotprod) end - y[w_idx, h_idx, d_idx, c_out, batch] = α*dotprod + β*y[w_idx, h_idx, d_idx, c_out, batch] + y[w_idx, h_idx, d_idx, c_out, batch] = alpha*dotprod + beta*y[w_idx, h_idx, d_idx, c_out, batch] end # Next, do potentially-padded regions: @@ -114,14 +114,14 @@ function depthwiseconv_direct!(y::AbstractArray{yT,5}, x::AbstractArray{xT,5}, end end - y[w_idx, h_idx, d_idx, c_out, batch] = α*dotprod + β*y[w_idx, h_idx, d_idx, c_out, batch] + y[w_idx, h_idx, d_idx, c_out, batch] = alpha*dotprod + beta*y[w_idx, h_idx, d_idx, c_out, batch] end return y end """ - ∇depthwiseconv_data_direct!(dx, dy, w, cdims; α=1, β=0) + ∇depthwiseconv_data_direct!(dx, dy, w, cdims; alpha=1, beta=0) Calculate the gradient imposed upon `x` in the depthwise convolution `y = x * w`. We make use of the fact that a depthwise convolution is equivalent to `C_in` separate @@ -135,7 +135,7 @@ for each batch and channel independently. function ∇depthwiseconv_data_direct!( dx::AbstractArray{xT,5}, dy::AbstractArray{yT,5}, w::AbstractArray{wT,5}, cdims::DepthwiseConvDims; - α::xT=xT(1), β=false) where {xT, yT, wT} + alpha::xT=xT(1), beta=false) where {xT, yT, wT} # We do a separate convolution for each channel in x @inbounds for cidx in 1:channels_in(cdims) # For this batch and in-channel, we have a normal transposed convolution @@ -153,13 +153,13 @@ function ∇depthwiseconv_data_direct!( ) ∇conv_data_direct!(dx_slice, dy_slice, w_slice, cdims_slice; - α=α, β=β) + alpha=alpha, beta=beta) end return dx end """ - ∇depthwiseconv_filter_direct!(dw, x, dy, cdims; α=1, β=0) + ∇depthwiseconv_filter_direct!(dw, x, dy, cdims; alpha=1, beta=0) Calculate the gradient imposed upon `w` in the depthwise convolution `y = x * w`. 
""" @@ -168,7 +168,7 @@ Calculate the gradient imposed upon `w` in the depthwise convolution `y = x * w` function ∇depthwiseconv_filter_direct!( dw::AbstractArray{wT,5}, x::AbstractArray{xT,5}, dy::AbstractArray{yT,5}, cdims::DepthwiseConvDims; - α::wT=wT(1),β=false) where {xT, yT, wT} + alpha::wT=wT(1),beta=false) where {xT, yT, wT} # We do a separate convolution for each channel in x @inbounds for cidx in 1:channels_in(cdims) # For this batch and in-channel, we have a normal transposed convolution @@ -186,7 +186,7 @@ function ∇depthwiseconv_filter_direct!( ) ∇conv_filter_direct!(dw_slice, x_slice, dy_slice, cdims_slice; - α=α, β=β) + alpha=alpha, beta=beta) dw[:, :, :, :, cidx:cidx] .= permutedims(dw_slice, (1, 2, 3, 5, 4)) end return dw diff --git a/src/impl/depthwiseconv_im2col.jl b/src/impl/depthwiseconv_im2col.jl index ec4b42aaf..145dc9961 100644 --- a/src/impl/depthwiseconv_im2col.jl +++ b/src/impl/depthwiseconv_im2col.jl @@ -2,7 +2,7 @@ """ - depthwiseconv_im2col!(y, x, w, cdims, col=similar(x); α=1, β=0) + depthwiseconv_im2col!(y, x, w, cdims, col=similar(x); alpha=1, beta=0) Perform a depthwise convolution using im2col and GEMM, store the result in `y`. @@ -14,7 +14,7 @@ function depthwiseconv_im2col!( y::AbstractArray{T,5}, x::AbstractArray{T,5}, w::AbstractArray{T,5}, cdims::DepthwiseConvDims; col::AbstractArray{T,3} = similar(x, im2col_dims(cdims)), - α::T=T(1), β::T=T(0)) where T + alpha::T=T(1), beta::T=T(0)) where T check_dims(size(x), size(w), size(y), cdims) # This functions exactly the same as conv_im2col!(), except that we shard the @@ -40,7 +40,7 @@ function depthwiseconv_im2col!( col_ptr = pointer(col_slice, (c_in-1)*M*K+1) w_ptr = pointer(w, (c_in-1)*K*N+1) y_ptr = pointer(y, ((batch_idx - 1)*channels_in(cdims) + c_in - 1)*M*N + 1) - gemm!(Val(false), Val(false), M, N, K, α, col_ptr, w_ptr, β, y_ptr) + gemm!(Val(false), Val(false), M, N, K, alpha, col_ptr, w_ptr, beta, y_ptr) end end end @@ -48,7 +48,7 @@ function depthwiseconv_im2col!( end """ - ∇depthwiseconv_filter_im2col!(dw, w, dy, cdims, col=similar(dw); α=1, β) + ∇depthwiseconv_filter_im2col!(dw, w, dy, cdims, col=similar(dw); alpha=1, beta) Depthwise conv2d backward pass onto the weights using im2col and GEMM. See the documentation for `conv_im2col!()` for explanation of optional parameters. @@ -59,7 +59,7 @@ function ∇depthwiseconv_filter_im2col!( dw::AbstractArray{T,5}, x::AbstractArray{T,5}, dy::AbstractArray{T,5}, cdims::DepthwiseConvDims; col::AbstractArray{T,3} = similar(dw, im2col_dims(cdims)), - α::T=T(1), β::T=T(0)) where T + alpha::T=T(1), beta::T=T(0)) where T check_dims(size(x), size(dw), size(dy), cdims) M = prod(kernel_size(cdims)) @@ -78,19 +78,19 @@ function ∇depthwiseconv_filter_im2col!( col_ptr = pointer(col_slice, (c_in - 1)*M*K + 1) dy_ptr = pointer(dy, (batch_idx - 1)*N*K*channels_in(cdims) + (c_in - 1)*K*N + 1) dw_ptr = pointer(dw, (c_in - 1)*M*N + 1) - gemm!(Val(true), Val(false), M, N, K, α, col_ptr, dy_ptr, β, dw_ptr) + gemm!(Val(true), Val(false), M, N, K, alpha, col_ptr, dy_ptr, beta, dw_ptr) end end - # Because we accumulate over batches in this loop, we must set `β` equal + # Because we accumulate over batches in this loop, we must set `beta` equal # to `1.0` from this point on. - β = T(1) + beta = T(1) end return dw end """ - depthwiseconv2d_Δx_im2col!(dx, w, dy, cdims, col=similar(dx); α=1, β=0) + depthwiseconv2d_Δx_im2col!(dx, w, dy, cdims, col=similar(dx); alpha=1, beta=0) Depwthwise conv2d backward pass onto the input using im2col and GEMM. 
See the documentation for `conv_im2col!()` for explanation of optional parameters. @@ -101,7 +101,7 @@ function ∇depthwiseconv_data_im2col!( dx::AbstractArray{T,5}, dy::AbstractArray{T,5}, w::AbstractArray{T,5}, cdims::DepthwiseConvDims; col::AbstractArray{T,3} = similar(dx, im2col_dims(cdims)), - α::T=T(1), β::T=T(0)) where T + alpha::T=T(1), beta::T=T(0)) where T check_dims(size(dx), size(w), size(dy), cdims) M = prod(output_size(cdims)) @@ -119,7 +119,7 @@ function ∇depthwiseconv_data_im2col!( dy_ptr = pointer(dy, (batch_idx - 1)*M*K*channels_in(cdims)+(cidx - 1)*K*M + 1) w_ptr = pointer(w, (cidx - 1)*K*N + 1) col_ptr = pointer(col_slice, (cidx - 1)*M*N + 1) - gemm!(Val(false), Val(true), M, N, K, α, dy_ptr, w_ptr, T(0), col_ptr) + gemm!(Val(false), Val(true), M, N, K, alpha, dy_ptr, w_ptr, T(0), col_ptr) end end col2im!(view(dx, :, :, :, :, batch_idx), col_slice, cdims) diff --git a/src/impl/pooling_direct.jl b/src/impl/pooling_direct.jl index e5bc838ae..f95ab32f5 100644 --- a/src/impl/pooling_direct.jl +++ b/src/impl/pooling_direct.jl @@ -5,7 +5,7 @@ using Statistics for name in (:max, :mean) @eval function $((Symbol("$(name)pool_direct!")))( y::AbstractArray{T,5}, x::AbstractArray{T,5}, - pdims::PoolDims; α::T = T(1), β::T = T(0)) where {T} + pdims::PoolDims; alpha::T = T(1), beta::T = T(0)) where {T} check_dims(size(x), size(y), pdims) width, height, depth = input_size(pdims) @@ -24,9 +24,9 @@ for name in (:max, :mean) @inline project(idx, stride, pad) = (idx - 1)*stride - pad + 1 # If we're doing mean pooling, we represent division by kernel size by rolling it - # into the `α` multiplier. + # into the `alpha` multiplier. if $(name == :mean) - α = α/prod(kernel_size(pdims)) + alpha = alpha/prod(kernel_size(pdims)) end # Each loop, we initialize `m` to something, set that here. @@ -66,7 +66,7 @@ for name in (:max, :mean) error("Unimplemented codegen path") end end - y[w, h, d, c, batch_idx] = α*m + β*y[w, h, d, c, batch_idx] + y[w, h, d, c, batch_idx] = alpha*m + beta*y[w, h, d, c, batch_idx] end # Next, the padded regions @@ -111,7 +111,7 @@ for name in (:max, :mean) end end end - y[w, h, d, c, batch_idx] = α*m + β*y[w, h, d, c, batch_idx] + y[w, h, d, c, batch_idx] = alpha*m + beta*y[w, h, d, c, batch_idx] end end @@ -124,7 +124,7 @@ for name in (:max, :mean) @eval function $((Symbol("∇$(name)pool_direct!")))( dx::AbstractArray{T,5}, dy::AbstractArray{T,5}, y::AbstractArray{T,5}, x::AbstractArray{T,5}, - pdims::PoolDims; α::T = T(1), β::T = T(0)) where {T} + pdims::PoolDims; alpha::T = T(1), beta::T = T(0)) where {T} check_dims(size(x), size(dy), pdims) width, height, depth = input_size(pdims) @@ -143,9 +143,9 @@ for name in (:max, :mean) @inline project(idx, stride, pad) = (idx - 1)*stride - pad + 1 # If we're doing mean pooling, we represent division by kernel size by rolling - # it into the `α` multiplier. + # it into the `alpha` multiplier. if $(name == :mean) - α = α/prod(kernel_size(pdims)) + alpha = alpha/prod(kernel_size(pdims)) end # Start with the central region @@ -176,15 +176,15 @@ for name in (:max, :mean) # If it's equal; this is the one we chose. We only choose one per # kernel window, all other elements of dx must be zero. if y_idx == x[x_idxs...] && !maxpool_already_chose - dx[x_idxs...] = dy_idx*α + β*dx[x_idxs...] + dx[x_idxs...] = dy_idx*alpha + beta*dx[x_idxs...] maxpool_already_chose = true - # Maxpooling does not support `β` right now. :( + # Maxpooling does not support `beta` right now. :( #else - # dx[x_idxs...] = T(0) + β*dx[x_idxs...] + # dx[x_idxs...] 
= T(0) + beta*dx[x_idxs...] end elseif $(name == :mean) # Either does meanpool :( - dx[x_idxs...] = dy_idx*α + dx[x_idxs...] + dx[x_idxs...] = dy_idx*alpha + dx[x_idxs...] else error("Unimplemented codegen path") end @@ -228,13 +228,13 @@ for name in (:max, :mean) x_idxs = (input_kw, input_kh, input_kd, c, batch_idx) if $(name == :max) if y_idx == x[x_idxs...] && !maxpool_already_chose - dx[x_idxs...] = dy_idx*α + β*dx[x_idxs...] + dx[x_idxs...] = dy_idx*alpha + beta*dx[x_idxs...] maxpool_already_chose = true #else - # dx[x_idxs...] = T(0) + β*dx[x_idxs...] + # dx[x_idxs...] = T(0) + beta*dx[x_idxs...] end elseif $(name == :mean) - dx[x_idxs...] += dy_idx*α + β*dx[x_idxs...] + dx[x_idxs...] += dy_idx*alpha + beta*dx[x_idxs...] else error("Unimplemented codegen path") end From 527eb5ebedad7d7f59315c7d9c2c3e1ab89c73c8 Mon Sep 17 00:00:00 2001 From: Adarshkumar712 Date: Mon, 6 Apr 2020 17:04:21 +0530 Subject: [PATCH 3/3] Moved Mathematical defn to docstring --- src/activation.jl | 89 ++++++++++++++++++++++++----------------------- 1 file changed, 45 insertions(+), 44 deletions(-) diff --git a/src/activation.jl b/src/activation.jl index 54ea36218..a075335a6 100644 --- a/src/activation.jl +++ b/src/activation.jl @@ -7,10 +7,10 @@ export σ, sigmoid, hardσ, hardsigmoid, hardtanh, relu, leakyrelu, relu6, rrelu # https://github.com/JuliaGPU/CuArrays.jl/issues/614 """ - σ(x) = 1 / (1 + exp(-x)) + σ(x) Classic [sigmoid](https://en.wikipedia.org/wiki/Sigmoid_function) activation -function. +function. Return `1 / (1 + exp(-x))`. """ σ(x::Real) = one(x) / (one(x) + exp(-x)) const sigmoid = σ @@ -23,9 +23,9 @@ const sigmoid = σ end """ - hardσ(x, a=0.2) = max(0, min(1.0, a * x + 0.5)) + hardσ(x, a=0.2) -Segment-wise linear approximation of sigmoid. +Segment-wise linear approximation of sigmoid. Return `max(0, min(1.0, a * x + 0.5))`. See [BinaryConnect: Training Deep Neural Networks withbinary weights during propagations](https://arxiv.org/pdf/1511.00363.pdf). """ hardσ(x::Real, a=0.2) = oftype(x / 1, max(zero(x / 1), min(one(x / 1), oftype(x / 1, a) * x + oftype(x / 1, 0.5)))) @@ -48,46 +48,45 @@ logσ(x::Real) = -softplus(-x) const logsigmoid = logσ """ - hardtanh(x) = max(-1, min(1, x)) + hardtanh(x) -Segment-wise linear approximation of tanh. Cheaper and more computational efficient version of tanh. +Segment-wise linear approximation of tanh. Return `max(-1, min(1, x))`. +Cheaper and more computational efficient version of tanh. See [Large Scale Machine Learning](http://ronan.collobert.org/pub/matos/2004_phdthesis_lip6.pdf). """ hardtanh(x::Real) = max(-one(x), min( one(x), x)) """ - relu(x) = max(0, x) + relu(x) [Rectified Linear Unit](https://en.wikipedia.org/wiki/Rectifier_(neural_networks)) -activation function. +activation function. Return `max(0, x)`. """ relu(x::Real) = max(zero(x), x) """ - leakyrelu(x, a=0.01) = max(a*x, x) + leakyrelu(x, a=0.01) Leaky [Rectified Linear Unit](https://en.wikipedia.org/wiki/Rectifier_(neural_networks)) -activation function. +activation function. Return `max(a*x, x)`. You can also specify the coefficient explicitly, e.g. `leakyrelu(x, 0.01)`. """ leakyrelu(x::Real, a=0.01) = max(oftype(x / 1, a) * x, x / 1) """ - relu6(x) = min(max(0, x), 6) + relu6(x) [Rectified Linear Unit](https://en.wikipedia.org/wiki/Rectifier_(neural_networks)) -activation function capped at 6. +activation function capped at 6. Return `min(max(0, x), 6)`. 
See [Convolutional Deep Belief Networks on CIFAR-10](http://www.cs.utoronto.ca/%7Ekriz/conv-cifar10-aug2010.pdf) """ relu6(x::Real) = min(relu(x), oftype(x, 6)) """ - rrelu(x, l=1/8, u=1/3) = max(a*x, x) - - a = randomly sampled from uniform distribution U(l, u) + rrelu(x, l=1/8, u=1/3) Randomized Leaky [Rectified Linear Unit](https://arxiv.org/pdf/1505.00853.pdf) -activation function. +activation function. Return `max(a*x, x)` where `a` is randomly sampled from uniform distribution U(l, u). You can also specify the bound explicitly, e.g. `rrelu(x, 0.0, 1.0)`. """ function rrelu(x::Real, l::Real = 1 / 8.0, u::Real = 1 / 3.0) @@ -96,20 +95,19 @@ function rrelu(x::Real, l::Real = 1 / 8.0, u::Real = 1 / 3.0) end """ - elu(x, α=1) = - x > 0 ? x : α * (exp(x) - 1) + elu(x, α=1) -Exponential Linear Unit activation function. +Exponential Linear Unit activation function. Return `x > 0 ? x : α * (exp(x) - 1)`. See [Fast and Accurate Deep Network Learning by Exponential Linear Units](https://arxiv.org/abs/1511.07289). You can also specify the coefficient explicitly, e.g. `elu(x, 1)`. """ -elu(x::Real, α = one(x)) = ifelse(x ≥ 0, x / 1, α * (exp(x) - one(x))) +elu(x::Real, α=one(x)) = ifelse(x ≥ 0, x / 1, α * (exp(x) - one(x))) """ - gelu(x) = 0.5x * (1 + tanh(√(2/π) * (x + 0.044715x^3))) + gelu(x) [Gaussian Error Linear Unit](https://arxiv.org/pdf/1606.08415.pdf) -activation function. +activation function. Return `0.5x * (1 + tanh(√(2/π) * (x + 0.044715x^3)))`. """ function gelu(x::Real) p = oftype(x / 1, π) @@ -120,28 +118,28 @@ function gelu(x::Real) end """ - swish(x) = x * σ(x) + swish(x) -Self-gated activation function. +Self-gated activation function. Return `x * σ(x)`. See [Swish: a Self-Gated Activation Function](https://arxiv.org/pdf/1710.05941.pdf). """ swish(x::Real) = x * σ(x) """ - lisht(x) = x * tanh(x) + lisht(x) -Non-Parametric Linearly Scaled Hyperbolic Tangent Activation Function. +Non-Parametric Linearly Scaled Hyperbolic Tangent Activation Function. Return `x * tanh(x)`. See [LiSHT](https://arxiv.org/abs/1901.05894) """ lisht(x::Real) = x * tanh(x) """ - selu(x) = λ * (x ≥ 0 ? x : α * (exp(x) - 1)) - + selu(x) + λ ≈ 1.0507 α ≈ 1.6733 -Scaled exponential linear units. +Scaled exponential linear units. Return `λ * (x ≥ 0 ? x : α * (exp(x) - 1))`. See [Self-Normalizing Neural Networks](https://arxiv.org/pdf/1706.02515.pdf). """ function selu(x::Real) @@ -151,62 +149,65 @@ function selu(x::Real) end """ - celu(x, α=1) = - (x ≥ 0 ? x : α * (exp(x/α) - 1)) + celu(x, α=1) +Return `(x ≥ 0 ? x : α * (exp(x/α) - 1))`. See [Continuously Differentiable Exponential Linear Units](https://arxiv.org/pdf/1704.07483.pdf). """ -celu(x::Real, α::Real = one(x)) = ifelse(x ≥ 0, x / 1, α * (exp(x/α) - one(x))) +celu(x::Real, α::Real=one(x)) = ifelse(x ≥ 0, x / 1, α * (exp(x/α) - one(x))) """ - trelu(x, θ=1.0) = x > θ ? x : 0 + trelu(x, θ=1.0) -Threshold Gated Rectified Linear. +Threshold Gated Rectified Linear. Return `x > θ ? x : 0`. See [ThresholdRelu](https://arxiv.org/pdf/1402.3337.pdf) """ -trelu(x::Real,θ = one(x)) = ifelse(x> θ, x, zero(x)) +trelu(x::Real, θ=one(x)) = ifelse(x> θ, x, zero(x)) const thresholdrelu = trelu """ - softsign(x) = x / (1 + |x|) + softsign(x) +Return `x / (1 + |x|)`. See [Quadratic Polynomials Learn Better Image Features](http://www.iro.umontreal.ca/~lisa/publications2/index.php/attachments/single/205). """ softsign(x::Real) = x / (one(x) + abs(x)) """ - softplus(x) = log(exp(x) + 1) + softplus(x) +Return `log(exp(x) + 1)`. 
See [Deep Sparse Rectifier Neural Networks](http://proceedings.mlr.press/v15/glorot11a/glorot11a.pdf). """ softplus(x::Real) = ifelse(x > 0, x + log1p(exp(-x)), log1p(exp(x))) """ - logcosh(x) = x + softplus(-2x) - log(2) + logcosh(x) -Return `log(cosh(x))` which is computed in a numerically stable way. +Return `log(cosh(x))` which is computed in a numerically stable way as `x + softplus(-2x) - log(2)`. """ logcosh(x::Real) = x + softplus(-2x) - log(oftype(x, 2)) """ - mish(x) = x * tanh(softplus(x)) + mish(x) -Self Regularized Non-Monotonic Neural Activation Function. +Self Regularized Non-Monotonic Neural Activation Function. Return `x * tanh(softplus(x))`. See [Mish: A Self Regularized Non-Monotonic Neural Activation Function](https://arxiv.org/abs/1908.08681). """ mish(x::Real) = x * tanh(softplus(x)) """ - tanhshrink(x) = x - tanh(x) + tanhshrink(x) +Return `x - tanh(x)`. See [Tanhshrink Activation Function](https://www.gabormelli.com/RKB/Tanhshrink_Activation_Function). """ tanhshrink(x::Real) = x - tanh(x) """ - softshrink(x, λ=0.5) = - (x ≥ λ ? x - λ : (-λ ≥ x ? x + λ : 0)) + softshrink(x, λ=0.5) +Return `(x ≥ λ ? x - λ : (-λ ≥ x ? x + λ : 0))`. See [Softshrink Activation Function](https://www.gabormelli.com/RKB/Softshrink_Activation_Function). """ softshrink(x::Real, λ = oftype(x / 1, 0.5)) = min(max(zero(x), x - λ), x + λ)
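# --- Illustrative sketch, not part of this patch ---------------------------
# Assumed usage of a few of the activation functions documented above.  They
# are defined on scalars (calling them on an array raises the informative
# error defined at the bottom of activation.jl), so apply them elementwise
# with broadcasting.
using NNlib

x = Float32[-2.0, -0.5, 0.0, 0.5, 2.0]

relu.(x)              # max(0, x), elementwise
leakyrelu.(x, 0.02)   # explicit negative-side slope
hardσ.(x)             # same as clamp.(0.2f0 .* x .+ 0.5f0, 0, 1)
softplus.(x)          # smooth approximation to relu
# ---------------------------------------------------------------------------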