@@ -4,7 +4,7 @@ function scan(cell, x, state)
         yt, state = cell(x_t, state)
         y = vcat(y, [yt])
     end
-    return stack(y, dims = 2)
+    return stack(y, dims = 2), state
 end

 """
@@ -58,16 +58,27 @@ julia> x = rand(Float32, 2, 3, 4); # in x len x batch_size
 julia> y = rnn(x); # out x len x batch_size
 ```
 """
-struct Recurrence{M}
+struct Recurrence{S, M}
     cell::M
 end

 @layer Recurrence

 initialstates(rnn::Recurrence) = initialstates(rnn.cell)

+function Recurrence(cell; return_state = false)
+    return Recurrence{return_state, typeof(cell)}(cell)
+end
+
 (rnn::Recurrence)(x::AbstractArray) = rnn(x, initialstates(rnn))
-(rnn::Recurrence)(x::AbstractArray, state) = scan(rnn.cell, x, state)
+
+function (rnn::Recurrence{false})(x::AbstractArray, state)
+    first(scan(rnn.cell, x, state))
+end
+
+function (rnn::Recurrence{true})(x::AbstractArray, state)
+    scan(rnn.cell, x, state)
+end

 # Vanilla RNN

 @doc raw """
@@ -193,8 +204,8 @@ function Base.show(io::IO, m::RNNCell)
 end

 @doc raw """
-    RNN(in => out, σ = tanh; init_kernel = glorot_uniform,
-        init_recurrent_kernel = glorot_uniform, bias = true)
+    RNN(in => out, σ = tanh; return_state = false,
+        init_kernel = glorot_uniform, init_recurrent_kernel = glorot_uniform, bias = true)

 The most basic recurrent layer. Essentially acts as a `Dense` layer, but with the
 output fed back into the input each time step.
@@ -212,6 +223,7 @@ See [`RNNCell`](@ref) for a layer that processes a single time step.

 - `in => out`: The input and output dimensions of the layer.
 - `σ`: The non-linearity to apply to the output. Default is `tanh`.
+- `return_state`: Option to return the last state together with the output. Default is `false`.
 - `init_kernel`: The initialization function to use for the input to hidden connection weights. Default is `glorot_uniform`.
 - `init_recurrent_kernel`: The initialization function to use for the hidden to hidden connection weights. Default is `glorot_uniform`.
 - `bias`: Whether to include a bias term initialized to zero. Default is `true`.
@@ -227,7 +239,8 @@ The arguments of the forward pass are:
   If given, it is a vector of size `out` or a matrix of size `out x batch_size`.
   If not provided, it is assumed to be a vector of zeros, initialized by [`initialstates`](@ref).

-Returns all new hidden states `h_t` as an array of size `out x len x batch_size`.
+Returns all new hidden states `h_t` as an array of size `out x len x batch_size`. When `return_state = true`,
+it returns a tuple of the hidden states `h_t` and the last state of the iteration.

 # Examples
@@ -260,26 +273,43 @@ Flux.@layer Model
 model = Model(RNN(32 => 64), zeros(Float32, 64))
 ```
 """
-struct RNN{M}
+struct RNN{S, M}
     cell::M
 end

 @layer :noexpand RNN

 initialstates(rnn::RNN) = initialstates(rnn.cell)

-function RNN((in, out)::Pair, σ = tanh; cell_kwargs...)
+function RNN((in, out)::Pair, σ = tanh; return_state = false, cell_kwargs...)
     cell = RNNCell(in => out, σ; cell_kwargs...)
-    return RNN(cell)
+    return RNN{return_state, typeof(cell)}(cell)
+end
+
+function RNN(cell::RNNCell; return_state::Bool = false)
+    RNN{return_state, typeof(cell)}(cell)
 end

 (rnn::RNN)(x::AbstractArray) = rnn(x, initialstates(rnn))

-function (m::RNN)(x::AbstractArray, h)
+function (rnn::RNN{false})(x::AbstractArray, h)
     @assert ndims(x) == 2 || ndims(x) == 3
     # [x] = [in, L] or [in, L, B]
     # [h] = [out] or [out, B]
-    return scan(m.cell, x, h)
+    return first(scan(rnn.cell, x, h))
+end
+
+function (rnn::RNN{true})(x::AbstractArray, h)
+    @assert ndims(x) == 2 || ndims(x) == 3
+    # [x] = [in, L] or [in, L, B]
+    # [h] = [out] or [out, B]
+    return scan(rnn.cell, x, h)
+end
+
+function Functors.functor(rnn::RNN{S}) where {S}
+    params = (cell = rnn.cell,)
+    reconstruct = p -> RNN{S, typeof(p.cell)}(p.cell)
+    return params, reconstruct
 end

 function Base.show(io::IO, m::RNN)
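
The `RNN{false}`/`RNN{true}` methods dispatch on the flag, and the manual `Functors.functor` overload is needed because `S` is stored only in the type parameter, not in a field, so default reconstruction could not recover it. A hedged sketch of the two call paths (sizes illustrative):

```julia
using Flux

x = rand(Float32, 4, 10, 16)                  # in x len x batch_size

rnn = RNN(4 => 8)                             # RNN{false, …}: outputs only
h = rnn(x)                                    # out x len x batch_size

rnn_s = RNN(4 => 8; return_state = true)      # RNN{true, …}: outputs and state
h, h_last = rnn_s(x)
size(h_last)                                  # (8, 16): out x batch_size
```
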
@@ -391,7 +421,7 @@ Base.show(io::IO, m::LSTMCell) =


 @doc raw """
-    LSTM(in => out; init_kernel = glorot_uniform,
+    LSTM(in => out; return_state = false, init_kernel = glorot_uniform,
         init_recurrent_kernel = glorot_uniform, bias = true)

 [Long Short Term Memory](https://www.researchgate.net/publication/13853244_Long_Short-term_Memory)
@@ -415,6 +445,7 @@ See [`LSTMCell`](@ref) for a layer that processes a single time step.
 # Arguments

 - `in => out`: The input and output dimensions of the layer.
+- `return_state`: Option to return the last state together with the output. Default is `false`.
 - `init_kernel`: The initialization function to use for the input to hidden connection weights. Default is `glorot_uniform`.
 - `init_recurrent_kernel`: The initialization function to use for the hidden to hidden connection weights. Default is `glorot_uniform`.
 - `bias`: Whether to include a bias term initialized to zero. Default is `true`.
@@ -430,7 +461,8 @@ The arguments of the forward pass are:
   They should be vectors of size `out` or matrices of size `out x batch_size`.
   If not provided, they are assumed to be vectors of zeros, initialized by [`initialstates`](@ref).

-Returns all new hidden states `h_t` as an array of size `out x len` or `out x len x batch_size`.
+Returns all new hidden states `h_t` as an array of size `out x len` or `out x len x batch_size`. When `return_state = true`,
+it returns a tuple of the hidden states `h_t` and the last state of the iteration.

 # Examples
@@ -452,24 +484,39 @@ h = model(x)
 size(h) # out x len x batch_size
 ```
 """
-struct LSTM{M}
+struct LSTM{S, M}
     cell::M
 end

 @layer :noexpand LSTM

 initialstates(lstm::LSTM) = initialstates(lstm.cell)

-function LSTM((in, out)::Pair; cell_kwargs...)
+function LSTM((in, out)::Pair; return_state = false, cell_kwargs...)
     cell = LSTMCell(in => out; cell_kwargs...)
-    return LSTM(cell)
+    return LSTM{return_state, typeof(cell)}(cell)
+end
+
+function LSTM(cell::LSTMCell; return_state::Bool = false)
+    LSTM{return_state, typeof(cell)}(cell)
 end

 (lstm::LSTM)(x::AbstractArray) = lstm(x, initialstates(lstm))

-function (m::LSTM)(x::AbstractArray, state0)
+function (lstm::LSTM{false})(x::AbstractArray, state0)
     @assert ndims(x) == 2 || ndims(x) == 3
-    return scan(m.cell, x, state0)
+    return first(scan(lstm.cell, x, state0))
+end
+
+function (lstm::LSTM{true})(x::AbstractArray, state0)
+    @assert ndims(x) == 2 || ndims(x) == 3
+    return scan(lstm.cell, x, state0)
+end
+
+function Functors.functor(lstm::LSTM{S}) where {S}
+    params = (cell = lstm.cell,)
+    reconstruct = p -> LSTM{S, typeof(p.cell)}(p.cell)
+    return params, reconstruct
 end

 function Base.show(io::IO, m::LSTM)
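
For `LSTM` the carried state is an `(h, c)` tuple, so under `return_state = true` the second element of the returned tuple is itself a tuple. A hedged sketch (sizes illustrative):

```julia
using Flux

lstm = LSTM(3 => 5; return_state = true)
x = rand(Float32, 3, 7, 2)             # in x len x batch_size
h, (h_last, c_last) = lstm(x)          # outputs plus final hidden and cell state
size(h)                                # (5, 7, 2)
size(h_last), size(c_last)             # both (5, 2)
```
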
@@ -578,7 +625,7 @@ Base.show(io::IO, m::GRUCell) =
     print(io, "GRUCell(", size(m.Wi, 2), " => ", size(m.Wi, 1) ÷ 3, ")")

 @doc raw """
-    GRU(in => out; init_kernel = glorot_uniform,
+    GRU(in => out; return_state = false, init_kernel = glorot_uniform,
         init_recurrent_kernel = glorot_uniform, bias = true)

 [Gated Recurrent Unit](https://arxiv.org/abs/1406.1078v1) layer. Behaves like an
@@ -599,6 +646,7 @@ See [`GRUCell`](@ref) for a layer that processes a single time step.
 # Arguments

 - `in => out`: The input and output dimensions of the layer.
+- `return_state`: Option to return the last state together with the output. Default is `false`.
 - `init_kernel`: The initialization function to use for the input to hidden connection weights. Default is `glorot_uniform`.
 - `init_recurrent_kernel`: The initialization function to use for the hidden to hidden connection weights. Default is `glorot_uniform`.
 - `bias`: Whether to include a bias term initialized to zero. Default is `true`.
@@ -613,7 +661,8 @@ The arguments of the forward pass are:
 - `h`: The initial hidden state of the GRU. It should be a vector of size `out` or a matrix of size `out x batch_size`.
   If not provided, it is assumed to be a vector of zeros, initialized by [`initialstates`](@ref).

-Returns all new hidden states `h_t` as an array of size `out x len x batch_size`.
+Returns all new hidden states `h_t` as an array of size `out x len x batch_size`. When `return_state = true`,
+it returns a tuple of the hidden states `h_t` and the last state of the iteration.

 # Examples
@@ -625,24 +674,39 @@ h0 = zeros(Float32, d_out)
 h = gru(x, h0) # out x len x batch_size
 ```
 """
-struct GRU{M}
+struct GRU{S, M}
     cell::M
 end

 @layer :noexpand GRU

 initialstates(gru::GRU) = initialstates(gru.cell)

-function GRU((in, out)::Pair; cell_kwargs...)
+function GRU((in, out)::Pair; return_state = false, cell_kwargs...)
     cell = GRUCell(in => out; cell_kwargs...)
-    return GRU(cell)
+    return GRU{return_state, typeof(cell)}(cell)
+end
+
+function GRU(cell::GRUCell; return_state::Bool = false)
+    GRU{return_state, typeof(cell)}(cell)
 end

 (gru::GRU)(x::AbstractArray) = gru(x, initialstates(gru))

-function (m::GRU)(x::AbstractArray, h)
+function (gru::GRU{false})(x::AbstractArray, h)
+    @assert ndims(x) == 2 || ndims(x) == 3
+    return first(scan(gru.cell, x, h))
+end
+
+function (gru::GRU{true})(x::AbstractArray, h)
     @assert ndims(x) == 2 || ndims(x) == 3
-    return scan(m.cell, x, h)
+    return scan(gru.cell, x, h)
+end
+
+function Functors.functor(gru::GRU{S}) where {S}
+    params = (cell = gru.cell,)
+    reconstruct = p -> GRU{S, typeof(p.cell)}(p.cell)
+    return params, reconstruct
 end

 function Base.show(io::IO, m::GRU)
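
The new `GRU(cell; return_state)` method wraps a pre-built cell rather than constructing one from dimensions; a hedged sketch (sizes illustrative):

```julia
using Flux

cell = GRUCell(3 => 5)
gru = GRU(cell; return_state = true)   # wrap an existing cell

x = rand(Float32, 3, 7)                # in x len (unbatched)
h, h_last = gru(x)                     # (5, 7) outputs and a (5,) final state
```
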
@@ -739,7 +803,7 @@ Base.show(io::IO, m::GRUv3Cell) =


 @doc raw """
-    GRUv3(in => out; init_kernel = glorot_uniform,
+    GRUv3(in => out; return_state = false, init_kernel = glorot_uniform,
         init_recurrent_kernel = glorot_uniform, bias = true)

 [Gated Recurrent Unit](https://arxiv.org/abs/1406.1078v3) layer. Behaves like an
@@ -764,6 +828,7 @@ but only a less popular variant.
 # Arguments

 - `in => out`: The input and output dimensions of the layer.
+- `return_state`: Option to return the last state together with the output. Default is `false`.
 - `init_kernel`: The initialization function to use for the input to hidden connection weights. Default is `glorot_uniform`.
 - `init_recurrent_kernel`: The initialization function to use for the hidden to hidden connection weights. Default is `glorot_uniform`.
 - `bias`: Whether to include a bias term initialized to zero. Default is `true`.
@@ -778,7 +843,8 @@ The arguments of the forward pass are:
 - `h`: The initial hidden state of the GRU. It should be a vector of size `out` or a matrix of size `out x batch_size`.
   If not provided, it is assumed to be a vector of zeros, initialized by [`initialstates`](@ref).

-Returns all new hidden states `h_t` as an array of size `out x len x batch_size`.
+Returns all new hidden states `h_t` as an array of size `out x len x batch_size`. When `return_state = true`,
+it returns a tuple of the hidden states `h_t` and the last state of the iteration.

 # Examples
@@ -790,24 +856,39 @@ h0 = zeros(Float32, d_out)
 h = gruv3(x, h0) # out x len x batch_size
 ```
 """
-struct GRUv3{M}
+struct GRUv3{S, M}
     cell::M
 end

 @layer :noexpand GRUv3

 initialstates(gru::GRUv3) = initialstates(gru.cell)

-function GRUv3((in, out)::Pair; cell_kwargs...)
+function GRUv3((in, out)::Pair; return_state = false, cell_kwargs...)
     cell = GRUv3Cell(in => out; cell_kwargs...)
-    return GRUv3(cell)
+    return GRUv3{return_state, typeof(cell)}(cell)
+end
+
+function GRUv3(cell::GRUv3Cell; return_state::Bool = false)
+    GRUv3{return_state, typeof(cell)}(cell)
 end

 (gru::GRUv3)(x::AbstractArray) = gru(x, initialstates(gru))

-function (m::GRUv3)(x::AbstractArray, h)
+function (gru::GRUv3{false})(x::AbstractArray, h)
     @assert ndims(x) == 2 || ndims(x) == 3
-    return scan(m.cell, x, h)
+    return first(scan(gru.cell, x, h))
+end
+
+function (gru::GRUv3{true})(x::AbstractArray, h)
+    @assert ndims(x) == 2 || ndims(x) == 3
+    return scan(gru.cell, x, h)
+end
+
+function Functors.functor(gru::GRUv3{S}) where {S}
+    params = (cell = gru.cell,)
+    reconstruct = p -> GRUv3{S, typeof(p.cell)}(p.cell)
+    return params, reconstruct
 end

 function Base.show(io::IO, m::GRUv3)
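
All four `Functors.functor` overloads follow the same pattern: the `reconstruct` closure reinstates the `return_state` type parameter `S` that `fmap`-based transforms would otherwise lose. A hedged sketch of the behaviour this preserves (conversion choice illustrative):

```julia
using Flux

gru = GRUv3(3 => 5; return_state = true)   # GRUv3{true, …}

# `f64` rebuilds the layer through Functors; the custom `reconstruct`
# keeps the `true` parameter, so the converted layer still returns
# its final state.
gru64 = f64(gru)
gru64 isa GRUv3{true}                      # expected to hold under this sketch
```
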