Skip to content

Commit d79811a

Browse files
Adding return state option to recurrent layers (#2557)
1 parent 4eb4454 commit d79811a

File tree

2 files changed

+169
-33
lines changed

2 files changed

+169
-33
lines changed

src/layers/recurrent.jl

Lines changed: 113 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ function scan(cell, x, state)
44
yt, state = cell(x_t, state)
55
y = vcat(y, [yt])
66
end
7-
return stack(y, dims = 2)
7+
return stack(y, dims = 2), state
88
end
99

1010
"""
@@ -58,16 +58,27 @@ julia> x = rand(Float32, 2, 3, 4); # in x len x batch_size
5858
julia> y = rnn(x); # out x len x batch_size
5959
```
6060
"""
61-
struct Recurrence{M}
61+
struct Recurrence{S,M}
6262
cell::M
6363
end
6464

6565
@layer Recurrence
6666

6767
initialstates(rnn::Recurrence) = initialstates(rnn.cell)
6868

69+
function Recurrence(cell; return_state = false)
70+
return Recurrence{return_state, typeof(cell)}(cell)
71+
end
72+
6973
(rnn::Recurrence)(x::AbstractArray) = rnn(x, initialstates(rnn))
70-
(rnn::Recurrence)(x::AbstractArray, state) = scan(rnn.cell, x, state)
74+
75+
function (rnn::Recurrence{false})(x::AbstractArray, state)
76+
first(scan(rnn.cell, x, state))
77+
end
78+
79+
function (rnn::Recurrence{true})(x::AbstractArray, state)
80+
scan(rnn.cell, x, state)
81+
end
7182

7283
# Vanilla RNN
7384
@doc raw"""
@@ -193,8 +204,8 @@ function Base.show(io::IO, m::RNNCell)
193204
end
194205

195206
@doc raw"""
196-
RNN(in => out, σ = tanh; init_kernel = glorot_uniform,
197-
init_recurrent_kernel = glorot_uniform, bias = true)
207+
RNN(in => out, σ = tanh; return_state = false,
208+
init_kernel = glorot_uniform, init_recurrent_kernel = glorot_uniform, bias = true)
198209
199210
The most basic recurrent layer. Essentially acts as a `Dense` layer, but with the
200211
output fed back into the input each time step.
@@ -212,6 +223,7 @@ See [`RNNCell`](@ref) for a layer that processes a single time step.
212223
213224
- `in => out`: The input and output dimensions of the layer.
214225
- `σ`: The non-linearity to apply to the output. Default is `tanh`.
226+
- `return_state`: Option to return the last state together with the output. Default is `false`.
215227
- `init_kernel`: The initialization function to use for the input to hidden connection weights. Default is `glorot_uniform`.
216228
- `init_recurrent_kernel`: The initialization function to use for the hidden to hidden connection weights. Default is `glorot_uniform`.
217229
- `bias`: Whether to include a bias term initialized to zero. Default is `true`.
@@ -227,7 +239,8 @@ The arguments of the forward pass are:
227239
If given, it is a vector of size `out` or a matrix of size `out x batch_size`.
228240
If not provided, it is assumed to be a vector of zeros, initialized by [`initialstates`](@ref).
229241
230-
Returns all new hidden states `h_t` as an array of size `out x len x batch_size`.
242+
Returns all new hidden states `h_t` as an array of size `out x len x batch_size`. When `return_state = true` it returns
243+
a tuple of the hidden states `h_t` and the last state of the iteration.
231244
232245
# Examples
233246
@@ -260,26 +273,43 @@ Flux.@layer Model
260273
model = Model(RNN(32 => 64), zeros(Float32, 64))
261274
```
262275
"""
263-
struct RNN{M}
276+
struct RNN{S,M}
264277
cell::M
265278
end
266279

267280
@layer :noexpand RNN
268281

269282
initialstates(rnn::RNN) = initialstates(rnn.cell)
270283

271-
function RNN((in, out)::Pair, σ = tanh; cell_kwargs...)
284+
function RNN((in, out)::Pair, σ = tanh; return_state = false, cell_kwargs...)
272285
cell = RNNCell(in => out, σ; cell_kwargs...)
273-
return RNN(cell)
286+
return RNN{return_state, typeof(cell)}(cell)
287+
end
288+
289+
function RNN(cell::RNNCell; return_state::Bool=false)
290+
RNN{return_state, typeof(cell)}(cell)
274291
end
275292

276293
(rnn::RNN)(x::AbstractArray) = rnn(x, initialstates(rnn))
277294

278-
function (m::RNN)(x::AbstractArray, h)
295+
function (rnn::RNN{false})(x::AbstractArray, h)
279296
@assert ndims(x) == 2 || ndims(x) == 3
280297
# [x] = [in, L] or [in, L, B]
281298
# [h] = [out] or [out, B]
282-
return scan(m.cell, x, h)
299+
return first(scan(rnn.cell, x, h))
300+
end
301+
302+
function (rnn::RNN{true})(x::AbstractArray, h)
303+
@assert ndims(x) == 2 || ndims(x) == 3
304+
# [x] = [in, L] or [in, L, B]
305+
# [h] = [out] or [out, B]
306+
return scan(rnn.cell, x, h)
307+
end
308+
309+
function Functors.functor(rnn::RNN{S}) where {S}
310+
params = (cell = rnn.cell,)
311+
reconstruct = p -> RNN{S, typeof(p.cell)}(p.cell)
312+
return params, reconstruct
283313
end
284314

285315
function Base.show(io::IO, m::RNN)
@@ -391,7 +421,7 @@ Base.show(io::IO, m::LSTMCell) =
391421

392422

393423
@doc raw"""
394-
LSTM(in => out; init_kernel = glorot_uniform,
424+
LSTM(in => out; return_state = false, init_kernel = glorot_uniform,
395425
init_recurrent_kernel = glorot_uniform, bias = true)
396426
397427
[Long Short Term Memory](https://www.researchgate.net/publication/13853244_Long_Short-term_Memory)
@@ -415,6 +445,7 @@ See [`LSTMCell`](@ref) for a layer that processes a single time step.
415445
# Arguments
416446
417447
- `in => out`: The input and output dimensions of the layer.
448+
- `return_state`: Option to return the last state together with the output. Default is `false`.
418449
- `init_kernel`: The initialization function to use for the input to hidden connection weights. Default is `glorot_uniform`.
419450
- `init_recurrent_kernel`: The initialization function to use for the hidden to hidden connection weights. Default is `glorot_uniform`.
420451
- `bias`: Whether to include a bias term initialized to zero. Default is `true`.
@@ -430,7 +461,8 @@ The arguments of the forward pass are:
430461
They should be vectors of size `out` or matrices of size `out x batch_size`.
431462
If not provided, they are assumed to be vectors of zeros, initialized by [`initialstates`](@ref).
432463
433-
Returns all new hidden states `h_t` as an array of size `out x len` or `out x len x batch_size`.
464+
Returns all new hidden states `h_t` as an array of size `out x len` or `out x len x batch_size`. When `return_state = true` it returns
465+
a tuple of the hidden states `h_t` and the last state of the iteration.
434466
435467
# Examples
436468
@@ -452,24 +484,39 @@ h = model(x)
452484
size(h) # out x len x batch_size
453485
```
454486
"""
455-
struct LSTM{M}
487+
struct LSTM{S,M}
456488
cell::M
457489
end
458490

459491
@layer :noexpand LSTM
460492

461493
initialstates(lstm::LSTM) = initialstates(lstm.cell)
462494

463-
function LSTM((in, out)::Pair; cell_kwargs...)
495+
function LSTM((in, out)::Pair; return_state = false, cell_kwargs...)
464496
cell = LSTMCell(in => out; cell_kwargs...)
465-
return LSTM(cell)
497+
return LSTM{return_state, typeof(cell)}(cell)
498+
end
499+
500+
function LSTM(cell::LSTMCell; return_state::Bool=false)
501+
LSTM{return_state, typeof(cell)}(cell)
466502
end
467503

468504
(lstm::LSTM)(x::AbstractArray) = lstm(x, initialstates(lstm))
469505

470-
function (m::LSTM)(x::AbstractArray, state0)
506+
function (lstm::LSTM{false})(x::AbstractArray, state0)
471507
@assert ndims(x) == 2 || ndims(x) == 3
472-
return scan(m.cell, x, state0)
508+
return first(scan(lstm.cell, x, state0))
509+
end
510+
511+
function (lstm::LSTM{true})(x::AbstractArray, state0)
512+
@assert ndims(x) == 2 || ndims(x) == 3
513+
return scan(lstm.cell, x, state0)
514+
end
515+
516+
function Functors.functor(lstm::LSTM{S}) where {S}
517+
params = (cell = lstm.cell,)
518+
reconstruct = p -> LSTM{S, typeof(p.cell)}(p.cell)
519+
return params, reconstruct
473520
end
474521

475522
function Base.show(io::IO, m::LSTM)
@@ -578,7 +625,7 @@ Base.show(io::IO, m::GRUCell) =
578625
print(io, "GRUCell(", size(m.Wi, 2), " => ", size(m.Wi, 1) ÷ 3, ")")
579626

580627
@doc raw"""
581-
GRU(in => out; init_kernel = glorot_uniform,
628+
GRU(in => out; return_state = false, init_kernel = glorot_uniform,
582629
init_recurrent_kernel = glorot_uniform, bias = true)
583630
584631
[Gated Recurrent Unit](https://arxiv.org/abs/1406.1078v1) layer. Behaves like an
@@ -599,6 +646,7 @@ See [`GRUCell`](@ref) for a layer that processes a single time step.
599646
# Arguments
600647
601648
- `in => out`: The input and output dimensions of the layer.
649+
- `return_state`: Option to return the last state together with the output. Default is `false`.
602650
- `init_kernel`: The initialization function to use for the input to hidden connection weights. Default is `glorot_uniform`.
603651
- `init_recurrent_kernel`: The initialization function to use for the hidden to hidden connection weights. Default is `glorot_uniform`.
604652
- `bias`: Whether to include a bias term initialized to zero. Default is `true`.
@@ -613,7 +661,8 @@ The arguments of the forward pass are:
613661
- `h`: The initial hidden state of the GRU. It should be a vector of size `out` or a matrix of size `out x batch_size`.
614662
If not provided, it is assumed to be a vector of zeros, initialized by [`initialstates`](@ref).
615663
616-
Returns all new hidden states `h_t` as an array of size `out x len x batch_size`.
664+
Returns all new hidden states `h_t` as an array of size `out x len x batch_size`. When `return_state = true` it returns
665+
a tuple of the hidden states `h_t` and the last state of the iteration.
617666
618667
# Examples
619668
@@ -625,24 +674,39 @@ h0 = zeros(Float32, d_out)
625674
h = gru(x, h0) # out x len x batch_size
626675
```
627676
"""
628-
struct GRU{M}
677+
struct GRU{S,M}
629678
cell::M
630679
end
631680

632681
@layer :noexpand GRU
633682

634683
initialstates(gru::GRU) = initialstates(gru.cell)
635684

636-
function GRU((in, out)::Pair; cell_kwargs...)
685+
function GRU((in, out)::Pair; return_state = false, cell_kwargs...)
637686
cell = GRUCell(in => out; cell_kwargs...)
638-
return GRU(cell)
687+
return GRU{return_state, typeof(cell)}(cell)
688+
end
689+
690+
function GRU(cell::GRUCell; return_state::Bool=false)
691+
GRU{return_state, typeof(cell)}(cell)
639692
end
640693

641694
(gru::GRU)(x::AbstractArray) = gru(x, initialstates(gru))
642695

643-
function (m::GRU)(x::AbstractArray, h)
696+
function (gru::GRU{false})(x::AbstractArray, h)
697+
@assert ndims(x) == 2 || ndims(x) == 3
698+
return first(scan(gru.cell, x, h))
699+
end
700+
701+
function (gru::GRU{true})(x::AbstractArray, h)
644702
@assert ndims(x) == 2 || ndims(x) == 3
645-
return scan(m.cell, x, h)
703+
return scan(gru.cell, x, h)
704+
end
705+
706+
function Functors.functor(gru::GRU{S}) where {S}
707+
params = (cell = gru.cell,)
708+
reconstruct = p -> GRU{S, typeof(p.cell)}(p.cell)
709+
return params, reconstruct
646710
end
647711

648712
function Base.show(io::IO, m::GRU)
@@ -739,7 +803,7 @@ Base.show(io::IO, m::GRUv3Cell) =
739803

740804

741805
@doc raw"""
742-
GRUv3(in => out; init_kernel = glorot_uniform,
806+
GRUv3(in => out; return_state = false, init_kernel = glorot_uniform,
743807
init_recurrent_kernel = glorot_uniform, bias = true)
744808
745809
[Gated Recurrent Unit](https://arxiv.org/abs/1406.1078v3) layer. Behaves like an
@@ -764,6 +828,7 @@ but only a less popular variant.
764828
# Arguments
765829
766830
- `in => out`: The input and output dimensions of the layer.
831+
- `return_state`: Option to return the last state together with the output. Default is `false`.
767832
- `init_kernel`: The initialization function to use for the input to hidden connection weights. Default is `glorot_uniform`.
768833
- `init_recurrent_kernel`: The initialization function to use for the hidden to hidden connection weights. Default is `glorot_uniform`.
769834
- `bias`: Whether to include a bias term initialized to zero. Default is `true`.
@@ -778,7 +843,8 @@ The arguments of the forward pass are:
778843
- `h`: The initial hidden state of the GRU. It should be a vector of size `out` or a matrix of size `out x batch_size`.
779844
If not provided, it is assumed to be a vector of zeros, initialized by [`initialstates`](@ref).
780845
781-
Returns all new hidden states `h_t` as an array of size `out x len x batch_size`.
846+
Returns all new hidden states `h_t` as an array of size `out x len x batch_size`. When `return_state = true` it returns
847+
a tuple of the hidden states `h_t` and the last state of the iteration.
782848
783849
# Examples
784850
@@ -790,24 +856,39 @@ h0 = zeros(Float32, d_out)
790856
h = gruv3(x, h0) # out x len x batch_size
791857
```
792858
"""
793-
struct GRUv3{M}
859+
struct GRUv3{S,M}
794860
cell::M
795861
end
796862

797863
@layer :noexpand GRUv3
798864

799865
initialstates(gru::GRUv3) = initialstates(gru.cell)
800866

801-
function GRUv3((in, out)::Pair; cell_kwargs...)
867+
function GRUv3((in, out)::Pair; return_state = false, cell_kwargs...)
802868
cell = GRUv3Cell(in => out; cell_kwargs...)
803-
return GRUv3(cell)
869+
return GRUv3{return_state, typeof(cell)}(cell)
870+
end
871+
872+
function GRUv3(cell::GRUv3Cell; return_state::Bool=false)
873+
GRUv3{return_state, typeof(cell)}(cell)
804874
end
805875

806876
(gru::GRUv3)(x::AbstractArray) = gru(x, initialstates(gru))
807877

808-
function (m::GRUv3)(x::AbstractArray, h)
878+
function (gru::GRUv3{false})(x::AbstractArray, h)
809879
@assert ndims(x) == 2 || ndims(x) == 3
810-
return scan(m.cell, x, h)
880+
return first(scan(gru.cell, x, h))
881+
end
882+
883+
function (gru::GRUv3{true})(x::AbstractArray, h)
884+
@assert ndims(x) == 2 || ndims(x) == 3
885+
return scan(gru.cell, x, h)
886+
end
887+
888+
function Functors.functor(gru::GRUv3{S}) where {S}
889+
params = (cell = gru.cell,)
890+
reconstruct = p -> GRUv3{S, typeof(p.cell)}(p.cell)
891+
return params, reconstruct
811892
end
812893

813894
function Base.show(io::IO, m::GRUv3)

0 commit comments

Comments
 (0)