Description
From the paper:
"we experimented with initializing the hidden states with zeros on half of the examples in the batch, and with standard Gaussian noise on the rest of the examples"
"Mixed initialization: During each training forward pass, each sample was assigned with either zero initialization (i.e. the fixed point was initialized with the 0 vector) or standard normal distribution (i.e. ...) using a Bernoulli random variable of probability 0.5 (i.e. the examples that were run with zero vs. normal initializations were roughly half-half."
Current implementation:
torchdeq/torchdeq/utils/init.py
Lines 4 to 21 in 4f6bd5f
It seems more appropriate to do this instead to match the paper.
*mask_shape, _ = z_shape
mask = torch.empty(*mask_shape, device=device).bernoulli_(0.5).unsqueeze(-1)
This form has the disadvantage of assuming that all but the last dimension are batch dimensions. But this seems to be quite a reasonable assumption, and downstream users can easily adjust to this by reshaping and rearranging the dimensions.