Description
When using a Keras `SimpleRNN` layer with `stateful=True`, the documentation states that a fixed batch size is required. However, when a model is built with `batch_size=N` (e.g. 4) and later called with an input of `batch_size=1`, no error is raised. Instead, the RNN broadcasts the single input trajectory across all internal state slots, silently overwriting the state for every batch slot.

This violates the expected behavior of `stateful=True`, where the batch size must remain fixed and state slot `i` must map to input sample `i`.

The issue is not documented, leads to incorrect behavior, and can silently corrupt stateful models such as ESNs or any RNN that carries memory across batches.
Standalone code to reproduce:
```python
import tensorflow as tf
from tensorflow.keras import layers, Input, Model

# Build a stateful RNN with batch_size=4
inputs = Input(shape=(5, 3), batch_size=4)
rnn = layers.SimpleRNN(10, stateful=True, return_sequences=False, name="rnn")
x = rnn(inputs)
model = Model(inputs, x)

# Manually set the initial state to distinguishable values (row i == i)
state = rnn.states[0]
for i in range(4):
    state[i].assign(tf.ones_like(state[i]) * i)

print("Initial state:")
print(state.numpy())

# Call the model with a different batch size (1) -- no error is raised
print("\nCalling model with input of batch size 1...")
_ = model(tf.random.normal((1, 5, 3)))

# The state for all four slots has been overwritten
print("\nState after call:")
print(state.numpy())
```
Expected behaviour
An exception should be raised when the input batch size does not match the model's fixed `batch_size`, particularly when `stateful=True`.
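Until such a check exists in Keras, callers can guard stateful models themselves. The sketch below is a hypothetical workaround (the function name and its use of the batch dimension from `Model.input_shape` are assumptions, not part of the report): it raises the `ValueError` that the layer itself currently does not.

```python
def check_stateful_batch(expected_batch, input_shape):
    """Hypothetical guard: raise if an input's batch dimension does not
    match the fixed batch size a stateful model was built with."""
    actual = input_shape[0]
    if expected_batch is not None and actual != expected_batch:
        raise ValueError(
            f"stateful model expects batch_size={expected_batch}, "
            f"got batch_size={actual}"
        )

# With the repro above, one would call it before invoking the model, e.g.:
#   check_stateful_batch(model.input_shape[0], (1, 5, 3))
# which raises instead of silently broadcasting over all state slots.
```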
Additional Info
- TensorFlow version: 2.19.0
- Keras version: 3.9.2
- OS: KDE neon 6.3 (Ubuntu 24.04 based)
- GPU: NVIDIA GeForce RTX 3060