diff --git a/docs/src/reference.md b/docs/src/reference.md
index f4ae5fdbf..c790c806a 100644
--- a/docs/src/reference.md
+++ b/docs/src/reference.md
@@ -44,12 +44,13 @@ logsoftmax
 
 ## Pooling
 
-`Flux`'s `AdaptiveMaxPool`, `AdaptiveMeanPool`, `GlobalMaxPool`, `GlobalMeanPool`, `MaxPool`, and `MeanPool` use `NNlib.PoolDims`, `NNlib.maxpool`, and `NNlib.meanpool` as their backend.
+`Flux`'s `AdaptiveMaxPool`, `AdaptiveMeanPool`, `GlobalMaxPool`, `GlobalMeanPool`, `MaxPool`, `MeanPool`, and `lpnormpool` use `NNlib.PoolDims`, `NNlib.maxpool`, `NNlib.meanpool`, and `NNlib.lpnormpool` as their backend.
 
 ```@docs
 PoolDims
 maxpool
 meanpool
+lpnormpool
 ```
 
 ## Padding
diff --git a/src/NNlib.jl b/src/NNlib.jl
index 19658d2be..c406cb8a6 100644
--- a/src/NNlib.jl
+++ b/src/NNlib.jl
@@ -67,8 +67,8 @@ include("ctc.jl")
 export ctc_loss
 
 include("pooling.jl")
-export maxpool, maxpool!, meanpool, meanpool!,
-    ∇maxpool, ∇maxpool!, ∇meanpool, ∇meanpool!
+export maxpool, maxpool!, meanpool, meanpool!, lpnormpool, lpnormpool!,
+    ∇maxpool, ∇maxpool!, ∇meanpool, ∇meanpool!, ∇lpnormpool, ∇lpnormpool!
 
 include("padding.jl")
 export pad_constant, pad_repeat, pad_reflect, pad_zeros
diff --git a/src/dim_helpers/PoolDims.jl b/src/dim_helpers/PoolDims.jl
index e7219ff72..75d56b8cd 100644
--- a/src/dim_helpers/PoolDims.jl
+++ b/src/dim_helpers/PoolDims.jl
@@ -25,6 +25,12 @@ function PoolDims(
     _check_kernel(k::NTuple, ::Int) = k
 
     kernel = _check_kernel(k, M - 2)
+    length(x_size) == length(kernel) + 2 || error(
+        "PoolDims expects ndims(x) == length(k) + 2: " *
+        "the input has ndims(x) = $(length(x_size)), " *
+        "so the kernel should have length $(length(x_size) - 2), " *
+        "but it has length $(length(kernel))."
+    )
     spdf_kernel = NTuple{M, Int}([kernel..., 1, 1])
 
     sstride, ppadding, ddilation = check_spdf(
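For context, the new `PoolDims` check enforces the usual NNlib layout of spatial dimensions followed by a channel and a batch dimension. A minimal sketch of what it accepts and rejects; the concrete array size and kernel here are illustrative only, not taken from the patch:

```julia
using NNlib

x = rand(Float32, 16, 16, 3, 8)   # (W, H, C, N): two spatial dims, channels, batch
PoolDims(x, (2, 2))               # accepted: length(k) == ndims(x) - 2
# PoolDims(x, (2, 2, 2))          # would now throw the descriptive error added above
```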
diff --git a/src/impl/pooling_direct.jl b/src/impl/pooling_direct.jl
index 566406eb2..55c1b9b23 100644
--- a/src/impl/pooling_direct.jl
+++ b/src/impl/pooling_direct.jl
@@ -1,14 +1,14 @@
 # Pooling is so similar, we abstract over meanpooling and maxpooling, simply replacing
 # the inner loop operation and a few initialization parameters.
-for name in (:max, :mean)
+for name in (:max, :mean, :lpnorm)
     @eval function $((Symbol("$(name)pool_direct!")))(
                     y::AbstractArray{T, 5}, x::AbstractArray{T, 5},
-                    pdims::PoolDims; alpha::T=T(1), beta::T=T(0)) where T
+                    pdims::PoolDims; alpha::T=T(1), beta::T=T(0), kwargs...) where T
         $((Symbol("$(name)pool_direct!")))(
             y, x, pdims,
             Val(kernel_size(pdims)), Val(channels_out(pdims)),
             Val(padding(pdims)), Val(dilation(pdims)), Val(stride(pdims));
-            alpha, beta)
+            alpha, beta, kwargs...)
         return y
     end
 
@@ -17,7 +17,7 @@ for name in (:max, :mean)
                     pdims::PoolDims,
                     # kernel size, channels out, padding, dilation, stride
                     ::Val{K}, ::Val{C}, ::Val{P}, ::Val{D}, ::Val{S};
-                    alpha::T=T(1), beta::T=T(0),
+                    alpha::T=T(1), beta::T=T(0), kwargs...
                     ) where {T, K, C, P, D, S}
         @assert beta == T(0) "beta not supported yet"
         check_dims(size(x), size(y), pdims)
@@ -41,10 +41,15 @@ for name in (:max, :mean)
             alpha = alpha / prod(K)
         end
 
+        p = if $(name != :lpnorm) 0 else
+            !haskey(kwargs, :p) && error("lpnormpool needs keyword argument `p`")
+            kwargs[:p]
+        end
+
         # Each loop, we initialize `m` to something, set that here.
         m_init = if $(name == :max)
            T <: AbstractFloat ? nextfloat(typemin(T)) : typemin(T)
-        elseif $(name == :mean)
+        elseif $(name == :mean) || $(name == :lpnorm)
            T(0)
         else
            error("Unimplemented codegen path")
@@ -78,11 +83,17 @@ for name in (:max, :mean)
                             end
                         elseif $(name == :mean)
                             m += x[input_kw, input_kh, input_kd, c, batch_idx]
+                        elseif $(name == :lpnorm)
+                            # y = (∑ᵢ xᵢ^p)^(1 / p); accumulate ∑ᵢ xᵢ^p here
+                            m += x[input_kw, input_kh, input_kd, c, batch_idx]^p
                         else
                             error("Unimplemented codegen path")
                         end
                     end
 
+                    # for lpnormpool, y = (∑ᵢ xᵢ^p)^(1 / p)
+                    m = $(name == :lpnorm) ? m^(T(1) / p) : m
+
                     y[w, h, d, c, batch_idx] = alpha * m # + beta * y[w, h, d, c, batch_idx]
                 end
             end
@@ -128,12 +139,15 @@ for name in (:max, :mean)
                                 end
                             elseif $(name == :mean)
                                 m += x[input_kw, input_kh, input_kd, c, batch_idx]
+                            elseif $(name == :lpnorm)
+                                m += x[input_kw, input_kh, input_kd, c, batch_idx]^p
                             else
                                 error("Unimplemented codegen path")
                             end
                         end
                     end
                 end
             end
+            $(name == :lpnorm) && (m = m^(T(1) / p))
             y[w, h, d, c, batch_idx] = alpha * m # + beta * y[w, h, d, c, batch_idx]
         end
     end
@@ -159,7 +173,7 @@ for name in (:max, :mean)
                     dx::AbstractArray{T,5}, dy::AbstractArray{T,5},
                     y::AbstractArray{T,5}, x::AbstractArray{T,5},
                     pdims::PoolDims, ::Val{K}; # == kernel_size(pdims)
-                    alpha::T=T(1), beta::T=T(0)) where {T, K}
+                    alpha::T=T(1), beta::T=T(0), kwargs...) where {T, K}
         check_dims(size(x), size(dy), pdims)
 
         width, height, depth = input_size(pdims)
@@ -182,6 +196,11 @@ for name in (:max, :mean)
             alpha = alpha / prod(K)
         end
 
+        p = if $(name != :lpnorm) 0 else
+            !haskey(kwargs, :p) && error("lpnormpool needs keyword argument `p`")
+            kwargs[:p]
+        end
+
         # Start with the central region
         w_region, h_region, d_region = central_region
         @inbounds for batch_idx in 1:size(x, 5), c in 1:out_c
@@ -226,6 +245,10 @@ for name in (:max, :mean)
                         elseif $(name == :mean)
                             # Either does meanpool :(
                             dx[input_kw, input_kh, input_kd, c, batch_idx] += dy_idx * alpha
+                        elseif $(name == :lpnorm)
+                            # y = (∑ᵢ xᵢ^p)^(1 / p), ∂y/∂xᵢ = xᵢ^(p-1) × y^(1-p)
+                            grad = x[input_kw, input_kh, input_kd, c, batch_idx]^(p-1) * y_idx^(1-p)
+                            dx[input_kw, input_kh, input_kd, c, batch_idx] += dy_idx * grad
                         else
                             error("Unimplemented codegen path")
                         end
@@ -286,6 +309,9 @@ for name in (:max, :mean)
                                 end
                             elseif $(name == :mean)
                                 dx[input_kw, input_kh, input_kd, c, batch_idx] += dy_idx * alpha #+ beta * dx[x_idxs...]
+                            elseif $(name == :lpnorm)
+                                grad = x[input_kw, input_kh, input_kd, c, batch_idx]^(p-1) * y_idx^(1-p)
+                                dx[input_kw, input_kh, input_kd, c, batch_idx] += dy_idx * grad
                             else
                                 error("Unimplemented codegen path")
                             end
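The backward rule used above follows from `y = (∑ᵢ xᵢ^p)^(1/p)`, which gives `∂y/∂xᵢ = xᵢ^(p-1) * y^(1-p)`. A small finite-difference comparison against the `∇lpnormpool` added in this patch can be used as a sanity check; the array size, `p` value, and tolerances below are illustrative choices, not part of the patch:

```julia
using NNlib

x = rand(Float64, 6, 1, 1) .+ 0.1          # keep x > 0 so x^(p-1) is well defined
p = 2.5
pdims = PoolDims(x, 3)
y  = lpnormpool(x, pdims; p=p)
dx = ∇lpnormpool(ones(size(y)), y, x, pdims; p=p)

# Forward finite difference with respect to the first input element
ϵ  = 1e-6
xp = copy(x); xp[1] += ϵ
fd = (lpnormpool(xp, pdims; p=p)[1] - y[1]) / ϵ
isapprox(dx[1], fd; atol=1e-4)             # expected to hold
```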
diff --git a/src/pooling.jl b/src/pooling.jl
index f13390342..704d7d0fa 100644
--- a/src/pooling.jl
+++ b/src/pooling.jl
@@ -8,11 +8,15 @@
 #   - maxpool!(y, x, pdims)
 #   - meanpool(x, pdims)
 #   - meanpool!(y, x, pdims)
+#   - lpnormpool(x, pdims)
+#   - lpnormpool!(y, x, pdims)
 #  - Pooling input backprop
 #   - ∇maxpool(dy, y, x, pdims)
 #   - ∇maxpool!(dx, dy, y, x, pdims)
 #   - ∇meanpool(dy, y, x, pdims)
 #   - ∇meanpool!(dx, dy, y, x pdims)
+#   - ∇lpnormpool(dy, y, x, pdims)
+#   - ∇lpnormpool!(dx, dy, y, x, pdims)
 #
 # All methods require a `PoolDims` object to define the dimensions and optional
 # elements of the convolution (stride, dilation, etc...), which is easily constructable
@@ -26,6 +30,7 @@ for (front_name, backend) in (
         # This maps from public, front-facing name, to internal backend name
         :maxpool => :direct,
         :meanpool => :direct,
+        :lpnormpool => :direct,
     )
 
     # We only define 3d pooling primitives, we reshape lower down to get 1d and 2d pooling
@@ -42,6 +47,7 @@ end
 for (front_name, backend) in (
     :∇maxpool => :direct,
     :∇meanpool => :direct,
+    :∇lpnormpool => :direct,
 )
     @eval begin
         function $(Symbol("$(front_name)!"))(
@@ -57,7 +63,7 @@ end
 # Our strategy for pooling is to reshape to an array with three spatial dimensions, which
 # makes things MUCH EASIER for us on the backend side, and is in general pretty fast,
 # since we can specialize on sizes.
-for front_name in (:maxpool, :meanpool)
+for front_name in (:maxpool, :meanpool, :lpnormpool)
     for backend in (Symbol(), :_direct)
         for N in (3, 4)
             @eval begin
@@ -103,7 +109,7 @@ end
 # Finally, let's generate auto-allocating versions of all our functions, for all backends:
 for backend in (Symbol(), :_direct, :_nnpack)
     # First make auto-allocating versions of the basic pooling calls:
-    for name in (:maxpool, :meanpool)
+    for name in (:maxpool, :meanpool, :lpnormpool)
        @eval begin
             function $(Symbol("$(name)$(backend)"))(
                     x::AbstractArray{xT,N},
@@ -141,9 +147,15 @@
 expand(N, i::Integer) = ntuple(_ -> i, N)
 
 """
-    maxpool(x, k::NTuple; pad=0, stride=k)
+    maxpool(x, k::NTuple{N, Integer}; pad=0, stride=k)
 
 Perform max pool operation with window size `k` on input tensor `x`.
+
+Arguments:
+
+* `x` and `k`: Expects `ndims(x) ∈ 3:5`, and always `length(k) == ndims(x) - 2`.
+* `pad`: See [`pad_zeros`](@ref) for details.
+* `stride`: Either a tuple with the same length as `k`, or one integer for all directions. Default is `k`.
 """
 function maxpool(x, k::NTuple{N, Integer}; pad=0, stride=k) where N
     pad = expand(Val(N), pad)
@@ -154,9 +166,15 @@
 end
 
 """
-    meanpool(x, k::NTuple; pad=0, stride=k)
+    meanpool(x, k::NTuple{N, Integer}; pad=0, stride=k)
 
 Perform mean pool operation with window size `k` on input tensor `x`.
+
+Arguments:
+
+* `x` and `k`: Expects `ndims(x) ∈ 3:5`, and always `length(k) == ndims(x) - 2`.
+* `pad`: See [`pad_zeros`](@ref) for details.
+* `stride`: Either a tuple with the same length as `k`, or one integer for all directions. Default is `k`.
""" function meanpool(x, k::NTuple{N, Integer}; pad=0, stride=k) where N pad = expand(Val(N), pad) @@ -166,7 +184,33 @@ function meanpool(x, k::NTuple{N, Integer}; pad=0, stride=k) where N end -for pool in [:maxpool, :meanpool] +""" + lpnormpool(x, p::Number, k::NTuple{N, Integer}; pad=0, stride=k) + +Perform Lp pool operation with value of the Lp norm `p` and window size `k` on input tensor `x`, also known as LPPool in pytorch. +This pooling operator from [Learned-Norm Pooling for Deep Feedforward and Recurrent Neural Networks](https://arxiv.org/abs/1311.1780). + +Arguments: + +* `x` and `k`: Expects `ndim(x) ∈ 3:5``, and always `length(k) == ndim(x) - 2` +* `p` is restricted to `0 < p < Inf`. +* `pad`: See [`pad_zeros`](@ref) for details. +* `stride`: Either a tuple with the same length as `k`, or one integer for all directions. Default is `k`. + +For all elements `x` in a size `k` window, lpnormpool computes `(∑ᵢ xᵢ^p)^(1 / p)` as an element of the output. + +Thus `lpnormpool(x, 1, k) ./ prod(k) ≈ meanpool(x, k)` and `lpnormpool(x, 2, k).^2 ./ prod(k) ≈ meanpool(x.^2, k)`. +""" +function lpnormpool(x, p::Number, k::NTuple{N, Integer}; pad=0, stride=k) where N + (isinf(p) || p < 0) && error("p value of Lp norm pool expects `0 < p < Inf`, but p is $(p) now.") + pad = expand(Val(N), pad) + stride = expand(Val(N), stride) + pdims = PoolDims(x, k; padding=pad, stride=stride) + return lpnormpool(x, pdims; p=p) +end + + +for pool in [:maxpool, :meanpool, :lpnormpool] ∇pool = Symbol(:∇, pool) pullback = Symbol(pool, :_pullback) @eval function rrule(::typeof($pool), x, pdims::PoolDims; kw...) diff --git a/test/perf/perf_report.jl b/test/perf/perf_report.jl index b1010f416..5c06515eb 100644 --- a/test/perf/perf_report.jl +++ b/test/perf/perf_report.jl @@ -93,6 +93,7 @@ for rank in (2,), for (pool, ∇pool, name) in ( (NNlib.maxpool!, NNlib.∇maxpool!, "maxpool"), (NNlib.meanpool!, NNlib.∇meanpool!, "meanpool"), + (NNlib.lpnormpool!, NNlib.∇lpnormpool!, "lpnormpool"), ) t_fwd = @benchmark $(pool)( $y, $x, $pdims) diff --git a/test/pooling.jl b/test/pooling.jl index d1d26c620..b7952295f 100644 --- a/test/pooling.jl +++ b/test/pooling.jl @@ -248,6 +248,70 @@ meanpool_answer_dict = Dict( ) ) +lpnormpool_answer_dict = Dict( + 1 => Dict( + "y" => [2.019312856150994, 4.221163518110637], + "y_nostride" => [ + 2.080083823051904, 3.2710663101885897, + 4.497941445275415, 5.738793548317167 + ], + "y_pad" => [1.0, 3.605551275463989, 6.4031242374328485], + "dx" => [ + 0.17258020254042603, 1.9525221042381296, + 1.2774501198988355, 3.496467771732918, 0.0 + ], + "dx_nostride" => [ + 0.48074985676913606, 3.1458422620080637, + 4.752311710531486, 6.345225258061685, 4.356316321455918 + ], + "dx_pad" => [1.0, 2.0, 3.0, 4.0, 5.0], + "p" => 4.5, + "p_nostride" => 3.0, + "p_pad" => 2.0 + ), + 2 => Dict( + "y" => [ + 8.71909 24.9703; + 11.7336 28.3804 + ], + "y_nostride" => [ + 11.1128 23.134 35.5704; + 13.4219 25.6082 38.0707; + 15.8033 28.0907 40.5735; + 18.2249 30.5795 43.0782 + ], + "y_pad" => [ + 1.0 11.3616 16.0; + 3.19158 15.9662 21.3545; + 5.56869 18.7771 23.7903 + ], + "dx" => [ + 0.33866 4.97727 7.30092 12.8076; + 0.957876 6.27208 8.31879 14.0269; + 1.51693 6.6057 8.79844 14.3351; + 2.33547 7.8822 9.83293 15.5461; + 0.0 0.0 0.0 0.0 + ], + "dx_nostride" => [ + 3.33359 19.9471 35.7329 23.8564; + 9.89551 44.627 76.2257 50.0307; + 13.231 50.9101 82.5686 53.2022; + 16.4888 57.223 88.9133 56.3742; + 9.54591 30.9869 46.8371 29.3524 + ], + "dx_pad" => [ + 1.0 2.30261 10.4791 16.0; + 0.992125 2.0321 7.81903 12.075; + 
diff --git a/test/perf/perf_report.jl b/test/perf/perf_report.jl
index b1010f416..5c06515eb 100644
--- a/test/perf/perf_report.jl
+++ b/test/perf/perf_report.jl
@@ -93,6 +93,7 @@ for rank in (2,),
     for (pool, ∇pool, name) in (
         (NNlib.maxpool!, NNlib.∇maxpool!, "maxpool"),
         (NNlib.meanpool!, NNlib.∇meanpool!, "meanpool"),
+        (NNlib.lpnormpool!, NNlib.∇lpnormpool!, "lpnormpool"),
     )
 
         t_fwd = @benchmark $(pool)( $y, $x, $pdims)
diff --git a/test/pooling.jl b/test/pooling.jl
index d1d26c620..b7952295f 100644
--- a/test/pooling.jl
+++ b/test/pooling.jl
@@ -248,6 +248,70 @@ meanpool_answer_dict = Dict(
     )
 )
 
+lpnormpool_answer_dict = Dict(
+    1 => Dict(
+        "y" => [2.019312856150994, 4.221163518110637],
+        "y_nostride" => [
+            2.080083823051904, 3.2710663101885897,
+            4.497941445275415, 5.738793548317167
+        ],
+        "y_pad" => [1.0, 3.605551275463989, 6.4031242374328485],
+        "dx" => [
+            0.17258020254042603, 1.9525221042381296,
+            1.2774501198988355, 3.496467771732918, 0.0
+        ],
+        "dx_nostride" => [
+            0.48074985676913606, 3.1458422620080637,
+            4.752311710531486, 6.345225258061685, 4.356316321455918
+        ],
+        "dx_pad" => [1.0, 2.0, 3.0, 4.0, 5.0],
+        "p" => 4.5,
+        "p_nostride" => 3.0,
+        "p_pad" => 2.0
+    ),
+    2 => Dict(
+        "y" => [
+            8.71909 24.9703;
+            11.7336 28.3804
+        ],
+        "y_nostride" => [
+            11.1128 23.134 35.5704;
+            13.4219 25.6082 38.0707;
+            15.8033 28.0907 40.5735;
+            18.2249 30.5795 43.0782
+        ],
+        "y_pad" => [
+            1.0 11.3616 16.0;
+            3.19158 15.9662 21.3545;
+            5.56869 18.7771 23.7903
+        ],
+        "dx" => [
+            0.33866 4.97727 7.30092 12.8076;
+            0.957876 6.27208 8.31879 14.0269;
+            1.51693 6.6057 8.79844 14.3351;
+            2.33547 7.8822 9.83293 15.5461;
+            0.0 0.0 0.0 0.0
+        ],
+        "dx_nostride" => [
+            3.33359 19.9471 35.7329 23.8564;
+            9.89551 44.627 76.2257 50.0307;
+            13.231 50.9101 82.5686 53.2022;
+            16.4888 57.223 88.9133 56.3742;
+            9.54591 30.9869 46.8371 29.3524
+        ],
+        "dx_pad" => [
+            1.0 2.30261 10.4791 16.0;
+            0.992125 2.0321 7.81903 12.075;
+            2.73398 2.83743 9.5512 13.9299;
+            2.43512 2.98652 9.0132 13.5608;
+            4.25398 3.8865 10.7099 15.4161
+        ],
+        "p" => 2.5,
+        "p_nostride" => 1.5,
+        "p_pad" => 3.5
+    )
+)
+
 for rank in (1, 2, 3)
     @testset "pool$(rank)d" begin
         for (pool, ∇pool, answer_dict) in (
@@ -297,6 +361,48 @@
     end
 end
 
+for rank in (1, 2)
+    for (pool, ∇pool, answer_dict) in (
+        (lpnormpool, ∇lpnormpool, lpnormpool_answer_dict),
+        (NNlib.lpnormpool_direct, NNlib.∇lpnormpool_direct, lpnormpool_answer_dict),)
+        @testset "$(pool)$(rank)d" begin
+            y = answer_dict[rank]["y"]
+            y_nostride = answer_dict[rank]["y_nostride"]
+            y_pad = answer_dict[rank]["y_pad"]
+            dx = answer_dict[rank]["dx"]
+            dx_nostride = answer_dict[rank]["dx_nostride"]
+            dx_pad = answer_dict[rank]["dx_pad"]
+            p = answer_dict[rank]["p"]
+            p_nostride = answer_dict[rank]["p_nostride"]
+            p_pad = answer_dict[rank]["p_pad"]
+
+            x = reshape(Float64[1:prod(size(dx));], size(dx)..., 1, 1)
+
+            ddims(x) = dropdims(x, dims=(rank + 1, rank + 2))
+
+            @test pool(x, PoolDims(x, 1); p=p) ≈ x atol = 1e-3
+
+            # Test vanilla pooling
+            pdims = PoolDims(x, 2)
+            y_hat = pool(x, pdims; p=p)
+            @test ddims(y_hat) ≈ y atol = 1e-3
+            @test ddims(∇pool(y_hat, y_hat, x, pdims; p=p)) ≈ dx atol = 1e-3
+
+            # Strided pooling
+            pdims = PoolDims(x, 2; stride=1)
+            y_hat = pool(x, pdims; p=p_nostride)
+            @test ddims(y_hat) ≈ y_nostride atol = 1e-3
+            @test ddims(∇pool(y_hat, y_hat, x, pdims; p=p_nostride)) ≈ dx_nostride atol = 1e-3
+
+            # Padded pooling
+            pdims = PoolDims(x, 2; padding=1)
+            y_hat = pool(x, pdims; p=p_pad)
+            @test ddims(y_hat) ≈ y_pad atol = 1e-3
+            @test ddims(∇pool(y_hat, y_hat, x, pdims; p=p_pad)) ≈ dx_pad atol = 1e-3
+        end
+    end
+end
+
 @testset "Pooling - Check Sizes" begin
     x = rand(10, 10, 3, 10)
     @test size(maxpool(x, (2, 2))) == (5, 5, 3, 10)
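Because the `rrule` loop in `src/pooling.jl` now also covers `lpnormpool`, the new op should compose with automatic differentiation. A rough usage sketch, assuming Zygote is available (Zygote, the shapes, and `p = 2` are assumptions made only for this example):

```julia
using NNlib, Zygote

x = rand(Float32, 8, 8, 3, 2)
g = gradient(x -> sum(lpnormpool(x, 2, (2, 2))), x)[1]
size(g) == size(x)   # true
```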