Open
Description
Just create this dataset:
import numpy as np
import pandas as pd
# Synthetic frame for the reproduction: two random target columns plus a
# group id and a per-group time index (3 groups x 10 steps = 30 rows).
multi_target_test_data = pd.DataFrame(
    {
        "target1": np.random.rand(30),
        "target2": np.random.rand(30),
        "group": np.repeat(np.arange(3), 10),
        "time_idx": np.tile(np.arange(10), 3),
    }
)
from pytorch_forecasting import TimeSeriesDataSet
from pytorch_forecasting.data.encoders import EncoderNormalizer, MultiNormalizer, TorchNormalizer
# Create the dataset from the pandas dataframe: two targets, series keyed by
# "group", a fixed encoder length of 5 and a fixed prediction length of 2
# (min and max are set to the same value in both cases).
dataset = TimeSeriesDataSet(
multi_target_test_data,
group_ids=["group"],
target=["target1", "target2"], # USING two targets
time_idx="time_idx",
min_encoder_length=5,
max_encoder_length=5,
min_prediction_length=2,
max_prediction_length=2,
time_varying_unknown_reals=["target1", "target2"],
target_normalizer=MultiNormalizer(
[EncoderNormalizer(), TorchNormalizer()]
), # two normalizers for the two targets (presumably matched in order: EncoderNormalizer -> target1, TorchNormalizer -> target2; confirm)
)
And pass it to the current `LSTMModel` from the tutorials:
# Build the model from the dataset; the loss wraps two MAE instances
# (one for each of the two targets declared on the dataset).
model = LSTMModel.from_dataset(
dataset,
n_layers=2,
hidden_size=10,
loss=MultiLoss([MAE() for _ in range(2)]),
)
# Take the first batch produced by the dataset's dataloader.
x, y = next(iter(dataset.to_dataloader()))
# NOTE(review): this first forward call is the line the traceback in this
# report points at.
print(
"prediction shape in training:", model(x)["prediction"].size()
) # batch_size x decoder time steps x 1 (1 for one target dimension) -- NOTE(review): describes one target dimension; this dataset has two targets, confirm expected shape
model.eval() # set model into eval mode to use autoregressive prediction
print("prediction shape in inference:", model(x)["prediction"].size()) # should be the same as in training
And you'll get:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-13-037e20e20342> in <cell line: 3>()
2
3 print(
----> 4 "prediction shape in training:", model(x)["prediction"].size()
5 ) # batch_size x decoder time steps x 1 (1 for one target dimension)
6 model.eval() # set model into eval mode to use autoregressive prediction
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
1530 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1531 else:
-> 1532 return self._call_impl(*args, **kwargs)
1533
1534 def _call_impl(self, *args, **kwargs):
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
1539 or _global_backward_pre_hooks or _global_backward_hooks
1540 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541 return forward_call(*args, **kwargs)
1542
1543 try:
<ipython-input-11-af0ffbd16c05> in forward(self, x)
105
106 def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
--> 107 hidden_state = self.encode(x) # encode to hidden state
108 output = self.decode(x, hidden_state) # decode leveraging hidden state
109
<ipython-input-11-af0ffbd16c05> in encode(self, x)
51 effective_encoder_lengths = x["encoder_lengths"] - 1
52 # run through LSTM network
---> 53 _, hidden_state = self.lstm(
54 input_vector, lengths=effective_encoder_lengths, enforce_sorted=False # passing the lengths directly
55 ) # second ouput is not needed (hidden state)
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
1530 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1531 else:
-> 1532 return self._call_impl(*args, **kwargs)
1533
1534 def _call_impl(self, *args, **kwargs):
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
1539 or _global_backward_pre_hooks or _global_backward_hooks
1540 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541 return forward_call(*args, **kwargs)
1542
1543 try:
/usr/local/lib/python3.10/dist-packages/pytorch_forecasting/models/nn/rnn.py in forward(self, x, hx, lengths, enforce_sorted)
105 else:
106 pack_lengths = lengths.where(lengths > 0, torch.ones_like(lengths))
--> 107 packed_out, hidden_state = super().forward(
108 rnn.pack_padded_sequence(
109 x, pack_lengths.cpu(), enforce_sorted=enforce_sorted, batch_first=self.batch_first
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/rnn.py in forward(self, input, hx)
912 self.dropout, self.training, self.bidirectional, self.batch_first)
913 else:
--> 914 result = _VF.lstm(input, batch_sizes, hx, self._flat_weights, self.bias,
915 self.num_layers, self.dropout, self.training, self.bidirectional)
916 output = result[0]
RuntimeError: mat1 and mat2 shapes cannot be multiplied (48x2 and 1x40)
It's just weird that there is still no fix for this, and no LSTM model available out of the box. I even made a fix; there is a PR.
Why does no one care about fixing this?
It is totally obscure how `pytorch_forecasting` handles single- versus multi-target setups. I've also noticed that if you pass `target=["target"]` to `TimeSeriesDataSet`, it behaves very differently than if you pass `target="target"`.
Please just review that PR and merge it, or fix the problem some other way.
Metadata
Metadata
Assignees
Type
Projects
Status
Needs triage & validation