Skip to content

Commit d034e0b

Browse files
committed
black
1 parent 224f4c5 commit d034e0b

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

pytorch_forecasting/data/timeseries.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1569,9 +1569,9 @@ def __getitem__(self, idx: int) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
15691569

15701570
# switch some variables to nan if encode length is 0
15711571
if encoder_length == 0 and len(self.dropout_categoricals) > 0:
1572-
data_cat[:, [self.flat_categoricals.index(c) for c in self.dropout_categoricals]] = (
1573-
0 # zero is encoded nan
1574-
)
1572+
fc = self.flat_categoricals
1573+
dc = self.dropout_categoricals
1574+
data_cat[:, [fc.index(c) for c in dc]] = 0 # zero is encoded nan
15751575

15761576
assert decoder_length > 0, "Decoder length should be greater than 0"
15771577
assert encoder_length >= 0, "Encoder length should be at least 0"

0 commit comments

Comments
 (0)