
Commit 16b7486

Add lppool implementation (#447)
* add: init my lppool impl
* add: impl gradient calculate
* test: for lppool (1d & 2d)
* doc: juliadoc & formula comment
* doc: update for meanpool/maxpool, add params details and warning for lppool-maxpool
* update: 1 => T(1); remove unnecessary if-condition; clear error msg and doc
* update: move dim checker from exported pool function to PoolDims
* rename: lppool => normpool
* add: normpool p value checker, p must be in (0, Inf)
* rename: normpool => lpnormpool
* doc: add paper reference to lp pooling
* doc: usage and parameter description of pool functions
1 parent 9425e50 commit 16b7486

File tree: 7 files changed (+198 −14 lines)

docs/src/reference.md

Lines changed: 2 additions & 1 deletion
@@ -44,12 +44,13 @@ logsoftmax
 
 ## Pooling
 
-`Flux`'s `AdaptiveMaxPool`, `AdaptiveMeanPool`, `GlobalMaxPool`, `GlobalMeanPool`, `MaxPool`, and `MeanPool` use `NNlib.PoolDims`, `NNlib.maxpool`, and `NNlib.meanpool` as their backend.
+`Flux`'s `AdaptiveMaxPool`, `AdaptiveMeanPool`, `GlobalMaxPool`, `GlobalMeanPool`, `MaxPool`, and `MeanPool` use `NNlib.PoolDims`, `NNlib.maxpool`, `NNlib.meanpool`, and `NNlib.lpnormpool` as their backend.
 
 ```@docs
 PoolDims
 maxpool
 meanpool
+lpnormpool
 ```
 
 ## Padding

src/NNlib.jl

Lines changed: 2 additions & 2 deletions
@@ -71,8 +71,8 @@ include("ctc.jl")
 export ctc_loss
 
 include("pooling.jl")
-export maxpool, maxpool!, meanpool, meanpool!,
-    ∇maxpool, ∇maxpool!, ∇meanpool, ∇meanpool!
+export maxpool, maxpool!, meanpool, meanpool!, lpnormpool, lpnormpool!,
+    ∇maxpool, ∇maxpool!, ∇meanpool, ∇meanpool!, ∇lpnormpool, ∇lpnormpool!
 
 include("padding.jl")
 export pad_constant, pad_repeat, pad_reflect, pad_zeros

src/dim_helpers/PoolDims.jl

Lines changed: 6 additions & 0 deletions
@@ -25,6 +25,12 @@ function PoolDims(
     _check_kernel(k::NTuple, ::Int) = k
 
     kernel = _check_kernel(k, M - 2)
+    length(x_size) == length(kernel) + 2 || error(
+        "PoolDims expects ndims(x) == length(k) + 2, " *
+        "i.e. length(size(x)) == length(kernel) + 2; " *
+        "ndims(x) is $(length(x_size)), " *
+        "so length(k) must be $(length(x_size) - 2), " *
+        "but it is $(length(kernel))")
     spdf_kernel = NTuple{M, Int}([kernel..., 1, 1])
 
     sstride, ppadding, ddilation = check_spdf(
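To illustrate the new check, a minimal sketch (hypothetical REPL session; the input size is arbitrary):

```julia
using NNlib

x = rand(Float32, 28, 28, 3, 1)   # 4-d input: width × height × channels × batch

PoolDims(x, (2, 2))      # ok: length(k) == ndims(x) - 2
PoolDims(x, (2, 2, 2))   # kernel has one dimension too many; now fails fast with the error above
```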

src/impl/pooling_direct.jl

Lines changed: 32 additions & 6 deletions
@@ -1,14 +1,14 @@
 # Pooling is so similar, we abstract over meanpooling and maxpooling, simply replacing
 # the inner loop operation and a few initialization parameters.
-for name in (:max, :mean)
+for name in (:max, :mean, :lpnorm)
     @eval function $((Symbol("$(name)pool_direct!")))(
             y::AbstractArray{T, 5}, x::AbstractArray{T, 5},
-            pdims::PoolDims; alpha::T=T(1), beta::T=T(0)) where T
+            pdims::PoolDims; alpha::T=T(1), beta::T=T(0), kwargs...) where T
         $((Symbol("$(name)pool_direct!")))(
             y, x, pdims,
             Val(kernel_size(pdims)), Val(channels_out(pdims)),
             Val(padding(pdims)), Val(dilation(pdims)), Val(stride(pdims));
-            alpha, beta)
+            alpha, beta, kwargs...)
         return y
     end
@@ -17,7 +17,7 @@ for name in (:max, :mean)
             pdims::PoolDims,
             # kernel size, channels out, padding, dilation, stride
             ::Val{K}, ::Val{C}, ::Val{P}, ::Val{D}, ::Val{S};
-            alpha::T=T(1), beta::T=T(0),
+            alpha::T=T(1), beta::T=T(0), kwargs...
             ) where {T, K, C, P, D, S}
         @assert beta == T(0) "beta not supported yet"
         check_dims(size(x), size(y), pdims)
@@ -41,10 +41,15 @@ for name in (:max, :mean)
             alpha = alpha / prod(K)
         end
 
+        p = if $(name != :lpnorm) 0 else
+            !haskey(kwargs, :p) && error("lpnormpool needs keyword argument `p`")
+            kwargs[:p]
+        end
+
         # Each loop, we initialize `m` to something, set that here.
         m_init = if $(name == :max)
             T <: AbstractFloat ? nextfloat(typemin(T)) : typemin(T)
-        elseif $(name == :mean)
+        elseif $(name == :mean) || $(name == :lpnorm)
             T(0)
         else
             error("Unimplemented codegen path")
@@ -78,11 +83,17 @@ for name in (:max, :mean)
                 end
             elseif $(name == :mean)
                 m += x[input_kw, input_kh, input_kd, c, batch_idx]
+            elseif $(name == :lpnorm)
+                # y = (∑ᵢ xᵢ^p)^(1 / p); here we accumulate ∑ᵢ xᵢ^p
+                m += x[input_kw, input_kh, input_kd, c, batch_idx]^p
             else
                 error("Unimplemented codegen path")
             end
         end
 
+        # for lpnormpool, finish with the outer root: y = (∑ᵢ xᵢ^p)^(1 / p)
+        m = $(name == :lpnorm) ? m^(T(1) / p) : m
+
         y[w, h, d, c, batch_idx] = alpha * m # + beta * y[w, h, d, c, batch_idx]
     end
 end
@@ -128,12 +139,15 @@ for name in (:max, :mean)
                     end
                 elseif $(name == :mean)
                     m += x[input_kw, input_kh, input_kd, c, batch_idx]
+                elseif $(name == :lpnorm)
+                    m += x[input_kw, input_kh, input_kd, c, batch_idx]^p
                 else
                     error("Unimplemented codegen path")
                 end
             end
         end
     end
+    $(name == :lpnorm) && (m = m^(T(1) / p))
     y[w, h, d, c, batch_idx] = alpha * m # + beta * y[w, h, d, c, batch_idx]
 end
@@ -159,7 +173,7 @@ for name in (:max, :mean)
             dx::AbstractArray{T,5}, dy::AbstractArray{T,5},
             y::AbstractArray{T,5}, x::AbstractArray{T,5},
             pdims::PoolDims, ::Val{K}; # == kernel_size(pdims)
-            alpha::T=T(1), beta::T=T(0)) where {T, K}
+            alpha::T=T(1), beta::T=T(0), kwargs...) where {T, K}
         check_dims(size(x), size(dy), pdims)
 
         width, height, depth = input_size(pdims)
@@ -182,6 +196,11 @@ for name in (:max, :mean)
             alpha = alpha / prod(K)
         end
 
+        p = if $(name != :lpnorm) 0 else
+            !haskey(kwargs, :p) && error("lpnormpool needs keyword argument `p`")
+            kwargs[:p]
+        end
+
         # Start with the central region
         w_region, h_region, d_region = central_region
         @inbounds for batch_idx in 1:size(x, 5), c in 1:out_c
@@ -226,6 +245,10 @@ for name in (:max, :mean)
                 elseif $(name == :mean)
                     # Either does meanpool :(
                     dx[input_kw, input_kh, input_kd, c, batch_idx] += dy_idx * alpha
+                elseif $(name == :lpnorm)
+                    # y = (∑ᵢ xᵢ^p)^(1 / p), so ∂y/∂xᵢ = xᵢ^(p-1) * y^(1-p)
+                    grad = x[input_kw, input_kh, input_kd, c, batch_idx]^(p-1) * y_idx^(1-p)
+                    dx[input_kw, input_kh, input_kd, c, batch_idx] += dy_idx * grad
                 else
                     error("Unimplemented codegen path")
                 end
@@ -286,6 +309,9 @@ for name in (:max, :mean)
                 end
             elseif $(name == :mean)
                 dx[input_kw, input_kh, input_kd, c, batch_idx] += dy_idx * alpha #+ beta * dx[x_idxs...]
+            elseif $(name == :lpnorm)
+                grad = x[input_kw, input_kh, input_kd, c, batch_idx]^(p-1) * y_idx^(1-p)
+                dx[input_kw, input_kh, input_kd, c, batch_idx] += dy_idx * grad
             else
                 error("Unimplemented codegen path")
             end
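As a sanity check on the `grad` expression used in both backprop branches: differentiating `y = (∑ᵢ xᵢ^p)^(1 / p)` with respect to a single element `xⱼ` gives

`∂y/∂xⱼ = (1/p) · (∑ᵢ xᵢ^p)^(1/p − 1) · p · xⱼ^(p−1) = xⱼ^(p−1) · (∑ᵢ xᵢ^p)^((1−p)/p) = xⱼ^(p−1) · y^(1−p)`,

which is exactly the factor `x[...]^(p-1) * y_idx^(1-p)` multiplied into `dy_idx`.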

src/pooling.jl

Lines changed: 49 additions & 5 deletions
@@ -8,11 +8,15 @@
 #  - maxpool!(y, x, pdims)
 #  - meanpool(x, pdims)
 #  - meanpool!(y, x, pdims)
+#  - lpnormpool(x, pdims)
+#  - lpnormpool!(y, x, pdims)
 # - Pooling input backprop
 #  - ∇maxpool(dy, y, x, pdims)
 #  - ∇maxpool!(dx, dy, y, x, pdims)
 #  - ∇meanpool(dy, y, x, pdims)
 #  - ∇meanpool!(dx, dy, y, x, pdims)
+#  - ∇lpnormpool(dy, y, x, pdims)
+#  - ∇lpnormpool!(dx, dy, y, x, pdims)
 #
 # All methods require a `PoolDims` object to define the dimensions and optional
 # elements of the convolution (stride, dilation, etc...), which is easily constructable
@@ -26,6 +30,7 @@ for (front_name, backend) in (
         # This maps from public, front-facing name, to internal backend name
         :maxpool => :direct,
         :meanpool => :direct,
+        :lpnormpool => :direct,
     )
 
 # We only define 3d pooling primitives, we reshape lower down to get 1d and 2d pooling
@@ -42,6 +47,7 @@ end
 for (front_name, backend) in (
         :∇maxpool => :direct,
         :∇meanpool => :direct,
+        :∇lpnormpool => :direct,
     )
     @eval begin
         function $(Symbol("$(front_name)!"))(
@@ -57,7 +63,7 @@ end
 # Our strategy for pooling is to reshape to an array with three spatial dimensions, which
 # makes things MUCH EASIER for us on the backend side, and is in general pretty fast,
 # since we can specialize on sizes.
-for front_name in (:maxpool, :meanpool)
+for front_name in (:maxpool, :meanpool, :lpnormpool)
     for backend in (Symbol(), :_direct)
         for N in (3, 4)
             @eval begin
@@ -103,7 +109,7 @@ end
 # Finally, let's generate auto-allocating versions of all our functions, for all backends:
 for backend in (Symbol(), :_direct, :_nnpack)
     # First make auto-allocating versions of the basic pooling calls:
-    for name in (:maxpool, :meanpool)
+    for name in (:maxpool, :meanpool, :lpnormpool)
         @eval begin
             function $(Symbol("$(name)$(backend)"))(
                     x::AbstractArray{xT,N},
@@ -141,9 +147,15 @@ expand(N, i::Integer) = ntuple(_ -> i, N)
 
 
 """
-    maxpool(x, k::NTuple; pad=0, stride=k)
+    maxpool(x, k::NTuple{N, Integer}; pad=0, stride=k)
 
 Perform max pool operation with window size `k` on input tensor `x`.
+
+Arguments:
+
+* `x` and `k`: Expects `ndims(x) ∈ 3:5`, and always `length(k) == ndims(x) - 2`.
+* `pad`: See [`pad_zeros`](@ref) for details.
+* `stride`: Either a tuple with the same length as `k`, or one integer for all directions. Default is `k`.
 """
 function maxpool(x, k::NTuple{N, Integer}; pad=0, stride=k) where N
     pad = expand(Val(N), pad)
@@ -154,9 +166,15 @@ end
 
 
 """
-    meanpool(x, k::NTuple; pad=0, stride=k)
+    meanpool(x, k::NTuple{N, Integer}; pad=0, stride=k)
 
 Perform mean pool operation with window size `k` on input tensor `x`.
+
+Arguments:
+
+* `x` and `k`: Expects `ndims(x) ∈ 3:5`, and always `length(k) == ndims(x) - 2`.
+* `pad`: See [`pad_zeros`](@ref) for details.
+* `stride`: Either a tuple with the same length as `k`, or one integer for all directions. Default is `k`.
 """
 function meanpool(x, k::NTuple{N, Integer}; pad=0, stride=k) where N
     pad = expand(Val(N), pad)
@@ -166,7 +184,33 @@ function meanpool(x, k::NTuple{N, Integer}; pad=0, stride=k) where N
 end
 
 
-for pool in [:maxpool, :meanpool]
+"""
+    lpnormpool(x, p::Number, k::NTuple{N, Integer}; pad=0, stride=k)
+
+Perform Lp pooling with exponent `p` and window size `k` on input tensor `x`, also known as `LPPool` in PyTorch.
+This pooling operator comes from [Learned-Norm Pooling for Deep Feedforward and Recurrent Neural Networks](https://arxiv.org/abs/1311.1780).
+
+Arguments:
+
+* `x` and `k`: Expects `ndims(x) ∈ 3:5`, and always `length(k) == ndims(x) - 2`.
+* `p`: Restricted to `0 < p < Inf`.
+* `pad`: See [`pad_zeros`](@ref) for details.
+* `stride`: Either a tuple with the same length as `k`, or one integer for all directions. Default is `k`.
+
+For all elements `xᵢ` in a size-`k` window, lpnormpool computes `(∑ᵢ xᵢ^p)^(1 / p)` as an element of the output.
+
+Thus `lpnormpool(x, 1, k) ./ prod(k) ≈ meanpool(x, k)` and `lpnormpool(x, 2, k).^2 ./ prod(k) ≈ meanpool(x.^2, k)`.
+"""
+function lpnormpool(x, p::Number, k::NTuple{N, Integer}; pad=0, stride=k) where N
+    (isinf(p) || p <= 0) && error("p value of Lp norm pool expects `0 < p < Inf`, but got p = $(p).")
+    pad = expand(Val(N), pad)
+    stride = expand(Val(N), stride)
+    pdims = PoolDims(x, k; padding=pad, stride=stride)
+    return lpnormpool(x, pdims; p=p)
+end
+
+
+for pool in [:maxpool, :meanpool, :lpnormpool]
     ∇pool = Symbol(:∇, pool)
     pullback = Symbol(pool, :_pullback)
     @eval function rrule(::typeof($pool), x, pdims::PoolDims; kw...)
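The identities at the end of the `lpnormpool` docstring can be checked directly; a minimal sketch (array sizes are arbitrary):

```julia
using NNlib

x = rand(Float32, 8, 8, 3, 1)   # W × H × C × N
k = (2, 2)

# p = 1 sums each window, so dividing by the window size recovers meanpool
@assert lpnormpool(x, 1, k) ./ prod(k) ≈ meanpool(x, k)

# p = 2 takes the root of the sum of squares per window
@assert lpnormpool(x, 2, k) .^ 2 ./ prod(k) ≈ meanpool(x .^ 2, k)
```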

test/perf/perf_report.jl

Lines changed: 1 addition & 0 deletions
@@ -93,6 +93,7 @@ for rank in (2,),
     for (pool, ∇pool, name) in (
         (NNlib.maxpool!, NNlib.∇maxpool!, "maxpool"),
         (NNlib.meanpool!, NNlib.∇meanpool!, "meanpool"),
+        (NNlib.lpnormpool!, NNlib.∇lpnormpool!, "lpnormpool"),
     )
 
     t_fwd = @benchmark $(pool)( $y, $x, $pdims)
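For a quick standalone run outside the perf harness, a minimal sketch using the public auto-allocating API (assumes `BenchmarkTools` is installed; the input size is arbitrary):

```julia
using BenchmarkTools, NNlib

x = rand(Float32, 128, 128, 3, 1)   # W × H × C × N

# forward pass of Lp pooling with p = 2 and a 2×2 window
t_fwd = @benchmark lpnormpool($x, 2, (2, 2))
display(t_fwd)
```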
