Are there any examples of how to run an S4 block in recurrent mode at inference time? I tried using the `step` function, but it raises errors. My script is below. What could be the problem?
```python
import torch
from s4 import S4
from sashimi import ResidualBlock


def s4_block(dim):
    layer = S4(
        d_model=dim,
        d_state=16,
        bidirectional=False,
        dropout=0.0,
        transposed=True,
    )
    return ResidualBlock(
        d_model=dim,
        layer=layer,
        dropout=0.0,
    )


model = s4_block(16)
for module in model.modules():
    if hasattr(module, 'setup_step'): module.setup_step(mode="diagonal")
model.eval()

input_seg = torch.randn(1, 16, 100)
full_out, _ = model(input_seg)
print(full_out)

s4_state = model.default_state()
stream_res = []
for i in range(input_seg.shape[-1]):
    part_input = input_seg[:, :, i]
    print(part_input.shape)
    part_res, s4_state = model.step(part_input, s4_state)
    stream_res.append(part_res)
stream_res = torch.cat(stream_res, dim=2)
print(stream_res)
print(torch.allclose(full_out, stream_res))
```
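One likely problem, independent of the S4 internals: each `step` call consumes a slice `input_seg[:, :, i]` of shape `(1, 16)` with no length axis, so the collected outputs are presumably 2-D as well, and `torch.cat(stream_res, dim=2)` has no third dimension to concatenate along. Stacking along a new last axis (`torch.stack(stream_res, dim=-1)`) should restore the `(1, 16, 100)` layout so that `torch.allclose` compares like with like. The sketch below illustrates why the recurrent and convolutional outputs should match. It is a toy NumPy diagonal state space model (S4D-style, zero-order-hold discretization, arbitrary initialization), not the `S4` module from the repo; `make_params`, `forward_conv`, `init_state`, and `ssm_step` are hypothetical helpers written only for this illustration.

```python
import numpy as np


def make_params(H, N, seed=0):
    """Hypothetical toy diagonal SSM: H channels, N complex state dims each."""
    rng = np.random.default_rng(seed)
    # Assumed S4D-Lin-like init: A_n = -1/2 + i*pi*n (arbitrary choice for the demo)
    A = np.tile(-0.5 + 1j * np.pi * np.arange(N), (H, 1))            # (H, N)
    B = np.ones((H, N), dtype=complex)                               # (H, N)
    C = rng.standard_normal((H, N)) + 1j * rng.standard_normal((H, N))
    D = rng.standard_normal(H)                                       # skip term
    dt = np.exp(rng.uniform(np.log(1e-3), np.log(1e-1), size=H))[:, None]
    # Zero-order-hold discretization of the diagonal system
    dA = np.exp(dt * A)
    dB = (dA - 1.0) / A * B
    return dA, dB, C, D


def forward_conv(u, dA, dB, C, D):
    """Convolutional mode: u has shape (B, H, L), output has the same shape."""
    L = u.shape[-1]
    powers = dA[..., None] ** np.arange(L)                           # (H, N, L)
    K = np.einsum("hn,hnl->hl", C * dB, powers).real                 # kernel (H, L)
    y = np.zeros_like(u)
    for k in range(L):
        # Causal convolution: y_k = sum_{j<=k} K_j * u_{k-j}
        y[..., k] = np.einsum("hj,bhj->bh", K[:, : k + 1], u[..., k::-1])
    return y + D[None, :, None] * u


def init_state(batch, dA):
    """Zero initial state with an explicit batch axis: (B, H, N)."""
    return np.zeros((batch,) + dA.shape, dtype=complex)


def ssm_step(u_k, state, dA, dB, C, D):
    """Recurrent mode: one time step, u_k has shape (B, H) with no length axis."""
    state = dA * state + dB * u_k[..., None]                         # (B, H, N)
    y_k = np.einsum("hn,bhn->bh", C, state).real + D * u_k           # (B, H)
    return y_k, state


Bsz, H, N, L = 1, 16, 8, 100
dA, dB, C, D = make_params(H, N)
u = np.random.default_rng(1).standard_normal((Bsz, H, L))
y_full = forward_conv(u, dA, dB, C, D)

state = init_state(Bsz, dA)
outs = []
for k in range(L):
    y_k, state = ssm_step(u[:, :, k], state, dA, dB, C, D)
    outs.append(y_k)                                                 # each is (B, H)

# Each output is 2-D, so stack along a new last axis; concatenating on axis=2 would fail
y_stream = np.stack(outs, axis=-1)                                   # (B, H, L)
print(y_stream.shape, np.allclose(y_full, y_stream))                 # prints (1, 16, 100) True
```

The same pattern should carry over to the PyTorch script: keep the per-step outputs as `(batch, d_model)` tensors and stack them on the last dimension before comparing with the full forward pass. If the shapes line up and the values still differ, the remaining suspects are in the layer itself (for example, how the state is initialized or which step mode is set up), which this toy model does not exercise.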