From bd09052d15d30709d0a37954ed678033d7a41757 Mon Sep 17 00:00:00 2001 From: Keenan Eves Date: Fri, 21 Feb 2025 20:06:12 +0000 Subject: [PATCH 01/16] air quality dataset --- tests/data/air_quality/data.csv | 4 ++ tests/data/air_quality/data.py | 33 +++++++++++ tests/datasets/test_air_quality.py | 33 +++++++++++ torchgeo/datasets/__init__.py | 2 + torchgeo/datasets/air_quality.py | 90 ++++++++++++++++++++++++++++++ 5 files changed, 162 insertions(+) create mode 100644 tests/data/air_quality/data.csv create mode 100644 tests/data/air_quality/data.py create mode 100644 tests/datasets/test_air_quality.py create mode 100644 torchgeo/datasets/air_quality.py diff --git a/tests/data/air_quality/data.csv b/tests/data/air_quality/data.csv new file mode 100644 index 00000000000..03cee34b70e --- /dev/null +++ b/tests/data/air_quality/data.csv @@ -0,0 +1,4 @@ +Date,Time,CO(GT),PT08.S1(CO),NMHC(GT),C6H6(GT),PT08.S2(NMHC),NOx(GT),PT08.S3(NOx),NO2(GT),PT08.S4(NO2),PT08.S5(O3),T,RH,AH +0.40370411941917805,0.3606056015367948,0.21516497242102428,0.24387342841491255,0.5329766435406001,0.9236204533706107,0.7028490254579889,0.29018687794493325,0.7725653430232418,0.1983454164719889,0.9298872870584386,0.263218948551846,0.8384843804704847,0.7555438109971452,0.11478955051245954 +0.4552266119955529,0.9118213335477183,0.5596848487112112,0.4916726172107061,0.8801002686394488,0.015590845376925677,0.6772789111200282,0.5028828842386969,0.8247776505970225,0.24194183991712026,0.7979902725334804,0.9492633578318296,0.8506183666485868,0.9907493688105276,0.9855916371135124 +0.8556168349254485,0.9603650305139044,0.005755231837792918,0.3680571428086127,0.49204221555453187,0.667089058846023,0.10871963316348399,0.09555947744260707,0.8056372516222555,0.16947203544585387,0.9094430927361661,0.917200867710488,0.10438373241272314,0.29439189601566407,0.247751757906712 diff --git a/tests/data/air_quality/data.py b/tests/data/air_quality/data.py new file mode 100644 index 00000000000..be55a065299 --- /dev/null +++ b/tests/data/air_quality/data.py @@ -0,0 +1,33 @@ +#!/usr/bin/env python3 + +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import numpy as np +import pandas as pd + +columns = [ + 'Date', + 'Time', + 'CO(GT)', + 'PT08.S1(CO)', + 'NMHC(GT)', + 'C6H6(GT)', + 'PT08.S2(NMHC)', + 'NOx(GT)', + 'PT08.S3(NOx)', + 'NO2(GT)', + 'PT08.S4(NO2)', + 'PT08.S5(O3)', + 'T', + 'RH', + 'AH', +] + +nrows = 3 +data = np.random.rand(nrows, len(columns)) + +df = pd.DataFrame(data, columns=columns) + + +df.to_csv('data.csv', index=False) diff --git a/tests/datasets/test_air_quality.py b/tests/datasets/test_air_quality.py new file mode 100644 index 00000000000..067f1c56519 --- /dev/null +++ b/tests/datasets/test_air_quality.py @@ -0,0 +1,33 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +from pathlib import Path + +import pandas as pd +import pytest +from _pytest.fixtures import SubRequest +from pytest import MonkeyPatch + +from torchgeo.datasets import AirQuality, DatasetNotFoundError + + +class TestAirQuality: + @pytest.fixture() + def dataset( + self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest + ) -> AirQuality: + url = 'tests/data/air_quality/data.csv' + monkeypatch.setattr(AirQuality, 'url', url) + return AirQuality(tmp_path, download=True) + + def test_getitem(self, dataset: AirQuality) -> None: + x = dataset[0] + assert isinstance(x, pd.Series) + assert len(x) == 15 + + def test_len(self, dataset: AirQuality) -> None: + assert len(dataset) == 3 + + def test_not_downloaded(self, tmp_path: Path) -> None: + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): + AirQuality(tmp_path) diff --git a/torchgeo/datasets/__init__.py b/torchgeo/datasets/__init__.py index 0e522c09976..89dfca1dc0b 100644 --- a/torchgeo/datasets/__init__.py +++ b/torchgeo/datasets/__init__.py @@ -6,6 +6,7 @@ from .advance import ADVANCE from .agb_live_woody_density import AbovegroundLiveWoodyBiomassDensity from .agrifieldnet import AgriFieldNet +from .air_quality import AirQuality from .airphen import Airphen from .astergdem import AsterGDEM from .benin_cashews import BeninSmallHolderCashews @@ -176,6 +177,7 @@ 'VHR10', 'AbovegroundLiveWoodyBiomassDensity', 'AgriFieldNet', + 'AirQuality', 'Airphen', 'AsterGDEM', 'BeninSmallHolderCashews', diff --git a/torchgeo/datasets/air_quality.py b/torchgeo/datasets/air_quality.py new file mode 100644 index 00000000000..04d90912fa2 --- /dev/null +++ b/torchgeo/datasets/air_quality.py @@ -0,0 +1,90 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""Air Quality dataset.""" + +import os + +import pandas as pd + +from .errors import DatasetNotFoundError +from .geo import NonGeoDataset +from .utils import Path + + +class AirQuality(NonGeoDataset): + """Air Quality dataset. + + The `Air Quality dataset `_ + from the UCI Machine Learning Repository is a multivariate time + series dataset containing air quality measurements from an Italian + city. + + Dataset Format: + + * .csv file containing date, time and air quality measurements + + Dataset Features: + + * hourly averaged sensor responses and reference analyzer ground truth over one year (2004-2005) + * has missing features + + If you use this dataset in your research, please cite: + + * https://doi.org/10.1016/J.SNB.2007.09.060 + + .. versionadded:: 0.7 + """ + + url = 'https://archive.ics.uci.edu/static/public/360/data.csv' + data_file_name = 'data.csv' + + def __init__(self, root: Path = 'data', download: bool = False) -> None: + """Initialize a new Dataset instance. + + Args: + root: root directory where dataset can be found + download: if True, download dataset and store it in the root directory + Raises: + DatasetNotFoundError: If dataset is not found and *download* is False. + """ + self.root = root + self.download = download + self.data = self._load_data() + + def __len__(self) -> int: + """Return the number of data points in the dataset. + + Returns: + length of the dataset + """ + return len(self.data) + + def __getitem__(self, index: int) -> pd.Series: + """Return an index within the dataset. + + Args: + index: index to return + + Returns: + data at that index + """ + return self.data.iloc[index] + + def _load_data(self) -> pd.DataFrame: + """Load the dataset into a pandas dataframe. + + Returns: + Dataframe containing the data. + """ + # Check if the file already exists + pathname = os.path.join(self.root, self.data_file_name) + if os.path.exists(pathname): + return pd.read_csv(pathname) + + # Check if the user requested to download the dataset + if not self.download: + raise DatasetNotFoundError(self) + + # Download the dataset + return pd.read_csv(self.url) From ee876385cd232096a9328dbe94d14e7516902434 Mon Sep 17 00:00:00 2001 From: Keenan Eves Date: Fri, 28 Feb 2025 18:56:53 +0000 Subject: [PATCH 02/16] seq2seq model and tests initial commit --- tests/models/test_seq2seq.py | 33 +++++++++++++++ torchgeo/models/__init__.py | 2 + torchgeo/models/seq2seq.py | 80 ++++++++++++++++++++++++++++++++++++ 3 files changed, 115 insertions(+) create mode 100644 tests/models/test_seq2seq.py create mode 100644 torchgeo/models/seq2seq.py diff --git a/tests/models/test_seq2seq.py b/tests/models/test_seq2seq.py new file mode 100644 index 00000000000..d73cbe4231f --- /dev/null +++ b/tests/models/test_seq2seq.py @@ -0,0 +1,33 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import pytest +import torch + +from torchgeo.models import LSTMSeq2Seq + +BATCH_SIZE = [1, 2] +INPUT_SIZE_ENCODER = [1, 3] +INPUT_SIZE_DECODER = [1] +OUTPUT_SIZE = [1] + + +class TestLSTMSeq2Seq: + @torch.no_grad() + @pytest.mark.parametrize('b', BATCH_SIZE) + @pytest.mark.parametrize('e', INPUT_SIZE_ENCODER) + @pytest.mark.parametrize('d', INPUT_SIZE_DECODER) + def test_input_size(self, b: int, e: int, d: int) -> None: + sequence_length = 3 + output_sequence_length = 1 + output_size = 1 + model = LSTMSeq2Seq( + input_size_encoder=e, + input_size_decoder=d, + output_size=output_size, + output_seq_length=output_sequence_length, + ) + inputs_encoder = torch.randn(b, sequence_length, e) + inputs_decoder = torch.randn(b, output_sequence_length, d) + y = model(inputs_encoder, inputs_decoder) + assert y.shape == (b, output_sequence_length, output_size) diff --git a/torchgeo/models/__init__.py b/torchgeo/models/__init__.py index 539be67180a..d920cda66e9 100644 --- a/torchgeo/models/__init__.py +++ b/torchgeo/models/__init__.py @@ -28,6 +28,7 @@ resnet152, ) from .scale_mae import ScaleMAE, ScaleMAELarge16_Weights, scalemae_large_patch16 +from .seq2seq import LSTMSeq2Seq from .swin import Swin_V2_B_Weights, Swin_V2_T_Weights, swin_v2_b, swin_v2_t from .vit import ViTSmall16_Weights, vit_small_patch16_224 @@ -46,6 +47,7 @@ 'FCSiamConc', 'FCSiamDiff', 'FarSeg', + 'LSTMSeq2Seq', 'ResNet18_Weights', 'ResNet50_Weights', 'ResNet152_Weights', diff --git a/torchgeo/models/seq2seq.py b/torchgeo/models/seq2seq.py new file mode 100644 index 00000000000..719e169dec5 --- /dev/null +++ b/torchgeo/models/seq2seq.py @@ -0,0 +1,80 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""LSTM Sequence to Sequence (Seq2Seq) Model.""" + +import torch +import torch.nn as nn +from torch import Tensor + + +class LSTMEncoder(nn.Module): + def __init__(self, input_size: int, hidden_size: int, num_layers: int = 1) -> None: + super().__init__() + self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True) + + def forward(self, x: Tensor): + # Only keep the last hidden and cell states + _, (hidden, cell) = self.lstm(x) + return hidden, cell + + +class LSTMDecoder(nn.Module): + def __init__( + self, + input_size: int, + hidden_size: int, + output_size: int, + target_indices: list[int] | None, + num_layers: int = 1, + output_sequence_len: int = 1, + ) -> None: + super().__init__() + self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True) + self.fc = nn.Linear(hidden_size, output_size) + self.output_size = output_size + self.target_indices = target_indices + self.output_sequence_len = output_sequence_len + + def forward(self, inputs: Tensor, hidden: Tensor, cell: Tensor) -> Tensor: + # shouldn't this be shape[0] since batch_first = True? + batch_size = inputs.shape[0] + outputs = torch.zeros(batch_size, self.output_sequence_len, self.output_size) + + curr_input = inputs[:, 0:1, :] + + for t in range(self.output_sequence_len): + print(f'input_t: {curr_input.shape}') + _, (hidden, cell) = self.lstm(curr_input, (hidden, cell)) + output = self.fc(hidden) # Predict next step + outputs[:, t, :] = output + curr_input = output + + return outputs + + +class LSTMSeq2Seq(nn.Module): + def __init__( + self, + input_size_encoder: int, + input_size_decoder: int, + hidden_size: int = 1, + output_size: int = 1, + output_seq_length: int = 1, + num_layers: int = 1, + ) -> None: + super().__init__() + self.encoder = LSTMEncoder(input_size_encoder, hidden_size, num_layers) + self.decoder = LSTMDecoder( + input_size_decoder, + hidden_size, + output_size, + target_indices=None, + num_layers=num_layers, + output_sequence_len=output_seq_length, + ) + + def forward(self, inputs_encoder: Tensor, inputs_decoder: Tensor) -> Tensor: + hidden, cell = self.encoder(inputs_encoder) + outputs = self.decoder(inputs_decoder, hidden, cell) + return outputs From 3b754ada3ec5db39d387ebac049a188fe987e66b Mon Sep 17 00:00:00 2001 From: Keenan Eves Date: Sat, 8 Mar 2025 06:16:05 +0000 Subject: [PATCH 03/16] additions to seq2seq model --- tests/models/test_seq2seq.py | 18 ++++++----- torchgeo/models/seq2seq.py | 58 ++++++++++++++++++++++++++++-------- 2 files changed, 56 insertions(+), 20 deletions(-) diff --git a/tests/models/test_seq2seq.py b/tests/models/test_seq2seq.py index d73cbe4231f..aa4e0159dae 100644 --- a/tests/models/test_seq2seq.py +++ b/tests/models/test_seq2seq.py @@ -6,9 +6,9 @@ from torchgeo.models import LSTMSeq2Seq -BATCH_SIZE = [1, 2] +BATCH_SIZE = [1, 2, 7] INPUT_SIZE_ENCODER = [1, 3] -INPUT_SIZE_DECODER = [1] +INPUT_SIZE_DECODER = [2, 3] OUTPUT_SIZE = [1] @@ -19,15 +19,19 @@ class TestLSTMSeq2Seq: @pytest.mark.parametrize('d', INPUT_SIZE_DECODER) def test_input_size(self, b: int, e: int, d: int) -> None: sequence_length = 3 - output_sequence_length = 1 - output_size = 1 + output_sequence_length = 3 + n_features = 5 + output_size = 2 model = LSTMSeq2Seq( input_size_encoder=e, input_size_decoder=d, + target_indices=list(range(0, output_size)), + encoder_indices=list(range(0, e)), + decoder_indices=list(range(0, d)), output_size=output_size, output_seq_length=output_sequence_length, ) - inputs_encoder = torch.randn(b, sequence_length, e) - inputs_decoder = torch.randn(b, output_sequence_length, d) - y = model(inputs_encoder, inputs_decoder) + past_steps = torch.randn(b, sequence_length, n_features) + future_steps = torch.randn(b, output_sequence_length, n_features) + y = model(past_steps, future_steps) assert y.shape == (b, output_sequence_length, output_size) diff --git a/torchgeo/models/seq2seq.py b/torchgeo/models/seq2seq.py index 719e169dec5..f48f415e666 100644 --- a/torchgeo/models/seq2seq.py +++ b/torchgeo/models/seq2seq.py @@ -3,6 +3,9 @@ """LSTM Sequence to Sequence (Seq2Seq) Model.""" +import random +from typing import cast + import torch import torch.nn as nn from torch import Tensor @@ -13,8 +16,7 @@ def __init__(self, input_size: int, hidden_size: int, num_layers: int = 1) -> No super().__init__() self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True) - def forward(self, x: Tensor): - # Only keep the last hidden and cell states + def forward(self, x: Tensor) -> tuple[Tensor, Tensor]: _, (hidden, cell) = self.lstm(x) return hidden, cell @@ -25,9 +27,10 @@ def __init__( input_size: int, hidden_size: int, output_size: int, - target_indices: list[int] | None, + target_indices: list[int], num_layers: int = 1, output_sequence_len: int = 1, + teacher_force_prob: float | None = None, ) -> None: super().__init__() self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True) @@ -35,20 +38,27 @@ def __init__( self.output_size = output_size self.target_indices = target_indices self.output_sequence_len = output_sequence_len + self.teacher_force_prob = teacher_force_prob def forward(self, inputs: Tensor, hidden: Tensor, cell: Tensor) -> Tensor: - # shouldn't this be shape[0] since batch_first = True? batch_size = inputs.shape[0] outputs = torch.zeros(batch_size, self.output_sequence_len, self.output_size) - curr_input = inputs[:, 0:1, :] + current_input = inputs[:, 0:1, :] for t in range(self.output_sequence_len): - print(f'input_t: {curr_input.shape}') - _, (hidden, cell) = self.lstm(curr_input, (hidden, cell)) - output = self.fc(hidden) # Predict next step - outputs[:, t, :] = output - curr_input = output + _, (hidden, cell) = self.lstm(current_input, (hidden, cell)) + output = self.fc(hidden) + output = output.permute(1, 0, 2) + outputs[:, t : t + 1, :] = output + current_input = inputs[:, t : t + 1, :].clone() + teacher_force = ( + random.random() < self.teacher_force_prob + if self.teacher_force_prob is not None + else False + ) + if not teacher_force: + current_input[:, :, self.target_indices] = output return outputs @@ -58,23 +68,45 @@ def __init__( self, input_size_encoder: int, input_size_decoder: int, + target_indices: list[int], + encoder_indices: list[int] | None = None, + decoder_indices: list[int] | None = None, hidden_size: int = 1, output_size: int = 1, output_seq_length: int = 1, num_layers: int = 1, ) -> None: super().__init__() + # Target indices need to be mapped to the subset of inputs for decoder + mapped_target_indices = ( + torch.nonzero( + torch.isin(torch.tensor(decoder_indices), torch.tensor(target_indices)) + ) + .squeeze() + .tolist() + ) self.encoder = LSTMEncoder(input_size_encoder, hidden_size, num_layers) self.decoder = LSTMDecoder( input_size_decoder, hidden_size, output_size, - target_indices=None, + mapped_target_indices, num_layers=num_layers, output_sequence_len=output_seq_length, ) + self.encoder_indices = encoder_indices + self.decoder_indices = decoder_indices - def forward(self, inputs_encoder: Tensor, inputs_decoder: Tensor) -> Tensor: + def forward(self, past_steps: Tensor, future_steps: Tensor) -> Tensor: + if self.encoder_indices: + inputs_encoder = past_steps[:, :, self.encoder_indices] + else: + inputs_encoder = past_steps + inputs_decoder = torch.cat( + [past_steps[:, -1, :].unsqueeze(1), future_steps], dim=1 + ) + if self.decoder_indices: + inputs_decoder = inputs_decoder[:, :, self.decoder_indices] hidden, cell = self.encoder(inputs_encoder) outputs = self.decoder(inputs_decoder, hidden, cell) - return outputs + return cast(Tensor, outputs) From 30407069f1f042f280799b39575fa82dd2bccf48 Mon Sep 17 00:00:00 2001 From: Keenan Eves Date: Sat, 8 Mar 2025 06:22:52 +0000 Subject: [PATCH 04/16] autoregression trainer initial commit --- torchgeo/trainers/autoregression.py | 186 ++++++++++++++++++++++++++++ 1 file changed, 186 insertions(+) create mode 100644 torchgeo/trainers/autoregression.py diff --git a/torchgeo/trainers/autoregression.py b/torchgeo/trainers/autoregression.py new file mode 100644 index 00000000000..719ce5854f9 --- /dev/null +++ b/torchgeo/trainers/autoregression.py @@ -0,0 +1,186 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""Trainers for autoregression.""" + +from typing import Any + +import torch +import torch.nn as nn +from torch import Tensor +from torchmetrics import MetricCollection +from torchmetrics.regression import MeanAbsoluteError, MeanSquaredError +from torchvision.models import LSTMSeq2Seq +from torchvision.models._api import WeightsEnum + +from .base import BaseTask + + +class AutoregressionTask(BaseTask): + """Autoregression.""" + + def __init__( + self, + model: str = 'lstm', + weights: WeightsEnum | str | bool | None = None, + input_size: int = 1, + input_size_decoder: int = 1, + hidden_size: int = 1, + output_size: int = 1, + lookback: int = 3, + timesteps_ahead: int = 1, + num_layers: int = 1, + loss: str = 'mse', + lr: float = 1e-3, + patience: int = 10, + ) -> None: + """Initialize a new AutoregressionTask instance. + + Args: + model: Name of the model to use, currently supports 'lstm' or 'seq2seq'. + weights: Initial model weights. Either a weight enum, the string + representation of a weight enum, True for ImageNet weights, False + or None for random weights, or the path to a saved model state dict. + loss: One of 'mse' or 'mae'. + lr: Learning rate for optimizer. + patience: Patience for learning rate scheduler. + + .. versionadded: 0.7 + """ + super().__init__() + + def configure_models(self) -> None: + """Initialize the model.""" + model: str = self.hparams['model'] + input_size = self.hparams['input_size'] + input_size_decoder = self.hparams['input_size_decoder'] + hidden_size = self.hparams['hidden_size'] + output_size = self.hparams['output_size'] + lookback = self.hparams['lookback'] + timesteps_ahead = self.hparams['timesteps_ahead'] + num_layers = self.hparams['num_layers'] + + if model == 'lstm': + assert timesteps_ahead == 1, ( + f'LSTM only supports 1 timestep ahead, got timesteps_ahead={timesteps_ahead}.' + ) + self.model = torch.nn.LSTM( + input_size=input_size, + hidden_size=hidden_size, + num_layers=num_layers, + batch_first=True, + ) + elif model == 'seq2seq': + self.model = LSTMSeq2Seq( + input_size_encoder=input_size, + input_size_decoder=input_size_decoder, + hidden_size=hidden_size, + output_size=output_size, + output_seq_length=timesteps_ahead, + num_layers=num_layers, + ) + else: + raise ValueError( + f"Model type '{model}' is not valid. " + "Currently, only supports 'lstm' and 'seq2seq'." + ) + + def configure_losses(self) -> None: + """Initialize the loss criterion. + + Raises: + ValueError: If *loss* is invalid. + """ + loss: str = self.hparams['loss'] + if loss == 'mse': + self.criterion = nn.MSELoss() + else: + raise ValueError( + f"Loss type '{loss}' is not valid. Currently, supports 'mse' loss." + ) + + def configure_metrics(self) -> None: + """Initialize the performance metrics.""" + output_size = self.hparams['output_size'] + metrics = MetricCollection( + { + 'rmse': MeanSquaredError(num_outputs=output_size, squared=False), + 'mae': MeanAbsoluteError(num_outputs=output_size), + } + ) + self.train_metrics = metrics.clone(prefix='train_') + self.val_metrics = metrics.clone(prefix='val_') + self.test_metrics = metrics.clone(prefix='test_') + + def _shared_step(self, batch: Any, batch_idx: int, stage: str) -> Tensor: + """Compute the loss and additional metrics for the given stage. + + Args: + batch: The output of your DataLoader._ + batch_idx: Integer displaying index of this batch._ + stage: The current stage. + + Returns: + The loss tensor. + """ + x, y = batch + y_hat = self(x) + + loss: Tensor = self.criterion(y_hat, y) + self.log(f'{stage}_loss', loss) + + # Retrieve the correct metrics based on the stage + metrics = getattr(self, f'{stage}_metrics', None) + if metrics: + metrics(y_hat, y) + self.log_dict({f'{k}': v for k, v in metrics.compute().items()}) + + return loss + + def training_step(self, batch: Any, batch_idx: int) -> Tensor: + """Compute the training loss and additional metrics. + + Args: + batch: The output of your DataLoader. + batch_idx: Integer displaying index of this batch. + + Returns: + The loss tensor. + """ + loss = self._shared_step(batch, batch_idx, 'train') + return loss + + def validation_step(self, batch: Any, batch_idx: int) -> None: + """Compute the validation loss and additional metrics. + + Args: + batch: The output of your DataLoader. + batch_idx: Integer displaying index of this batch. + """ + self._shared_step(batch, batch_idx, 'val') + + def test_step(self, batch: Any, batch_idx: int) -> None: + """Compute the test loss and additional metrics. + + Args: + batch: The output of your DataLoader. + batch_idx: Integer displaying index of this batch. + """ + self._shared_step(batch, batch_idx, 'test') + + def predict_step( + self, batch: Any, batch_idx: int, dataloader_idx: int = 0 + ) -> Tensor: + """Compute the predicted regression values. + + Args: + batch: The output of your DataLoader. + batch_idx: Integer displaying index of this batch. + dataloader_idx: Index of the current dataloader. + + Returns: + Output predicted values. + """ + x = batch + y_hat: Tensor = self(x) + return y_hat From a7284437caadf45bbcd1e09afd6d6f5d602c853d Mon Sep 17 00:00:00 2001 From: Keenan Eves Date: Tue, 11 Mar 2025 17:48:11 +0000 Subject: [PATCH 05/16] further additions to trainer, seq2seq, and air quality dataset and datamodule. --- tests/conf/air_quality.yaml | 16 ++++++ tests/data/air_quality/data.csv | 53 +++++++++++++++++-- tests/data/air_quality/data.py | 2 +- tests/datasets/test_air_quality.py | 12 +++-- tests/models/test_seq2seq.py | 52 +++++++++++++++++++ tests/trainers/test_autoregression.py | 38 ++++++++++++++ torchgeo/datamodules/__init__.py | 3 +- torchgeo/datamodules/air_quality.py | 75 +++++++++++++++++++++++++++ torchgeo/datasets/air_quality.py | 29 +++++++++-- torchgeo/models/seq2seq.py | 17 +++--- torchgeo/trainers/__init__.py | 2 + torchgeo/trainers/autoregression.py | 53 ++++++++++--------- 12 files changed, 309 insertions(+), 43 deletions(-) create mode 100644 tests/conf/air_quality.yaml create mode 100644 tests/trainers/test_autoregression.py create mode 100644 torchgeo/datamodules/air_quality.py diff --git a/tests/conf/air_quality.yaml b/tests/conf/air_quality.yaml new file mode 100644 index 00000000000..5ef4da2c2c2 --- /dev/null +++ b/tests/conf/air_quality.yaml @@ -0,0 +1,16 @@ +model: + class_path: AutoregressionTask + init_args: + loss: 'mse' + model: 'lstm_seq2seq' + input_size: 3 + input_size_decoder: 1 + target_indices: [2] + encoder_indices: [2, 12, 13] + decoder_indices: [2] +data: + class_path: AirQualityDataModule + init_args: + batch_size: 2 + dict_kwargs: + root: 'tests/data/air_quality' \ No newline at end of file diff --git a/tests/data/air_quality/data.csv b/tests/data/air_quality/data.csv index 03cee34b70e..789142c6274 100644 --- a/tests/data/air_quality/data.csv +++ b/tests/data/air_quality/data.csv @@ -1,4 +1,51 @@ Date,Time,CO(GT),PT08.S1(CO),NMHC(GT),C6H6(GT),PT08.S2(NMHC),NOx(GT),PT08.S3(NOx),NO2(GT),PT08.S4(NO2),PT08.S5(O3),T,RH,AH -0.40370411941917805,0.3606056015367948,0.21516497242102428,0.24387342841491255,0.5329766435406001,0.9236204533706107,0.7028490254579889,0.29018687794493325,0.7725653430232418,0.1983454164719889,0.9298872870584386,0.263218948551846,0.8384843804704847,0.7555438109971452,0.11478955051245954 -0.4552266119955529,0.9118213335477183,0.5596848487112112,0.4916726172107061,0.8801002686394488,0.015590845376925677,0.6772789111200282,0.5028828842386969,0.8247776505970225,0.24194183991712026,0.7979902725334804,0.9492633578318296,0.8506183666485868,0.9907493688105276,0.9855916371135124 -0.8556168349254485,0.9603650305139044,0.005755231837792918,0.3680571428086127,0.49204221555453187,0.667089058846023,0.10871963316348399,0.09555947744260707,0.8056372516222555,0.16947203544585387,0.9094430927361661,0.917200867710488,0.10438373241272314,0.29439189601566407,0.247751757906712 +0.428844960315107,0.9866608945950183,0.3705128349758361,0.06370825832902971,0.9984071519066181,0.23601028428833615,0.6711387944567729,0.7165114438855709,0.5840507533339877,0.25990180009319086,0.571315476202064,0.02832021978464172,0.037552099008017814,0.9816954186918593,0.6350429315448495 +0.7775475150998155,0.6659969353072417,0.39932933609779664,0.2622298457856952,0.9443611682964963,0.20875986148029524,0.32574791761079014,0.7862509708476118,0.6908314932342879,0.8361516998619839,0.6121843012493646,0.3798949072266635,0.9568554958054666,0.02607922923897077,0.635026275140384 +0.9039645671894101,0.897744879713676,0.2881787242258511,0.002372804476408641,0.9730300845182906,0.09010979656394669,0.09385852352265911,0.8314209560340037,0.3113554033328697,0.35221348466336844,0.019181890067545115,0.7108240805112648,0.6325030745529099,0.2536998130045216,0.695415273113407 +0.3059428706018087,0.857795345507438,0.16818645291766787,0.6665326676398685,0.34171307776576454,0.028318962518144142,0.966118039297472,0.6900530586626799,0.44300982245907516,0.8091369572525287,0.11479260604867825,0.6455647318157319,0.5324770086205174,0.0481917296021982,0.633141760683114 +0.35765564697096386,0.7858053743605675,0.3104016522107257,0.493416228005473,0.1483762812702656,0.7241500129800955,0.015896064761218742,0.44323625781500997,0.22656208552155188,0.5764696354888218,0.7836345188253819,0.395272511563161,0.8211200761625238,0.8825035849054379,0.7022012756248662 +0.5223180133243563,0.19032196352865638,0.2376350265360937,0.02801444954355481,0.7262114396783,0.4855480945003384,0.22914947905749994,0.6105985226756249,0.4299963391209827,0.016140740391492603,0.6927601581422852,0.34532464557082765,0.7054955160261106,0.6635612542374121,0.950276802833115 +0.6335493777298301,0.9990439551776903,0.8576504193950082,0.10799687272757574,0.6728234503646858,0.6984180171198154,0.36591757674293135,0.3927919402783018,0.48583164599457185,0.27880177963877484,0.8086519930928119,0.958232169659469,0.046577978803069686,0.9229631695109645,0.7761290479705539 +0.7534312680544961,0.17968083857978556,0.9164541915202609,0.9642421245611333,0.9165621403348179,0.28883554395073907,0.542156663510998,0.7912965980953718,0.892003147870078,0.17467182599762165,0.10036895690809988,0.292028304693698,0.41626796824868517,0.3291941225945081,0.004793952836596116 +0.4030229293020441,0.6241895804028355,0.8027489708177322,0.6061090924552509,0.49517547297881426,0.842015882546139,0.7996783050301155,0.7231551575961693,0.5605444237938936,0.725715256328619,0.8383988129619468,0.23793853519650288,0.7475719718283232,0.5186515383181854,0.8065088529876593 +0.9588023269759194,0.1754985026827054,0.5489916656547322,0.0422015379054006,0.33580950564507894,0.060679127163351776,0.8765016570612999,0.6035308449183396,0.24078472469497103,0.7597171836488161,0.48912582023988704,0.2453476539171492,0.7102981668858297,0.8772506412475322,0.8534851164967973 +0.4846229396514261,0.9436762481111945,0.06741844190243473,0.4618005148984925,0.5900581686288878,0.16831152439252683,0.7878945078391524,0.5360733064490945,0.6211485399827885,0.17500709594134145,0.697960750115879,0.13581109878852693,0.8156566690374971,0.242506637688075,0.48633996290588954 +0.3483524375314333,0.04300540733448299,0.9647797491796986,0.4613579175699183,0.9693808467706223,0.46749022608964597,0.5655973453407719,0.9873213674954243,0.8736119423601583,0.4237783071664204,0.5080632099038275,0.8632283365136613,0.5264136613838126,0.0434876036352726,0.3081429927207596 +0.6326611204832632,0.09147711621722876,0.6406961675831314,0.5248086249876812,0.772523573112128,0.7113592876716782,0.24915772018411464,0.9654421446013526,0.0032685908574943134,0.9364324001873182,0.9941463757396615,0.4556472971512966,0.9401087495394862,0.56161863415598,0.8121018570128568 +0.6063620622055907,0.09601965532476875,0.5631920459131398,0.9597996401070554,0.805223309981179,0.9110155339375118,0.90635556246564,0.629883065213367,0.6992874629069337,0.02755995522976451,0.6764407089152157,0.5147771063157597,0.08980091589916439,0.6468489058489089,0.4778744096276797 +0.7124181751859203,0.38112767817384985,0.06180169829266957,0.644533493507947,0.9736381656059738,0.2217158561329341,0.5807146315430849,0.29137729741005947,0.6000551984650088,0.11915249772454506,0.06507451919960028,0.9144070859628508,0.463730190931789,0.7364119603627168,0.8299907778984987 +0.7673507410857925,0.09015260397470759,0.3565636743608366,0.30149039227455776,0.5823222840186881,0.31774569361446403,0.31262666726144317,0.8919040873509467,0.9483652005692642,0.20486222460576364,0.5907118699600977,0.39700028234544493,0.8806661531424751,0.9096550586000391,0.9926732255621498 +0.28179392984634344,0.13208638219433366,0.46650838734796995,0.9693186846197664,0.9116492020343615,0.011169400499041582,0.7921130594859067,0.01010787552165282,0.23871477464776114,0.5039327493738484,0.4694618944757426,0.3320929088055071,0.9953005407830204,0.41721458831109315,0.6219705979188263 +0.5155778085824996,0.43864518002279873,0.3823433993193788,0.7205316487464971,0.6883140093469334,0.7174177831551652,0.6644014203675569,0.7320462334354929,0.44604977554236236,0.5925488379113077,0.053334404808935476,0.20635224895832338,0.9983571366293664,0.2462670173747903,0.32992146523553945 +0.19311061229302073,0.49177543376099964,0.8803946805135382,0.9942427746247304,0.4127104725455135,0.2855126643140573,0.20041244706800454,0.8332085753072233,0.5162171465890477,0.41722641817201556,0.2889597261508776,0.6453434231229928,0.048873272538541124,0.9352022597778659,0.7490170642965864 +0.023804537011409388,0.452960514401231,0.5316856419919115,0.8474423398747712,0.3212592954043607,0.04995160092661344,0.5335128741135348,0.9731243839111817,0.5646818375999015,0.2342174425568424,0.33523863282203825,0.6017411408124446,0.24632459841924303,0.9761407637655526,0.22339309515335026 +0.7710725412919975,0.1081915279745359,0.18287157220528694,0.03897648619204641,0.19564388113623565,0.7173265323695658,0.4534584642536357,0.25599394289870114,0.5055224980383046,0.5337342367427621,0.6434627637452554,0.395112816109344,0.29283240378518904,0.6734206006274873,0.5477890808804488 +0.5810843118756917,0.24701492777967693,0.021457590754809575,0.07930388248379416,0.8573797694607035,0.5849167765719225,0.20752144651399862,0.30720474523817765,0.901116555870858,0.7310787439353938,0.04149558398673425,0.45048315860983934,0.5355237069576362,0.935614982156612,0.14628978197266007 +0.9650881158768077,0.24393127671354498,0.023551931104118684,0.9276372334757385,0.646480637788412,0.8768325666870731,0.5031423256890551,0.6703405625328099,0.4240248411230091,0.8750823470850183,0.1521392487440415,0.8360195677538244,0.0029591011663731015,0.48700328371501844,0.5271555877737016 +0.4435095507554402,0.2580945634985652,0.05043531826353309,0.7485412853110855,0.36737530655511386,0.6279603473688744,0.7233713190729659,0.2873250341885155,0.9586053373211195,0.1919087197200391,0.5004314129592653,0.6978240080356658,0.5577517576652488,0.6386739524021738,0.9649284721294775 +0.9376060555213641,0.4731932398724885,0.5707952101562018,0.6772188964951886,0.9326699033686338,0.670545660770911,0.9382693295337488,0.9703174731409802,0.9330110811684368,0.14772715375852952,0.428214789686726,0.6993029816523798,0.5437099013249579,0.6446790166959826,0.3573838928746298 +0.5707099320299277,0.6105390935028076,0.8931108905714683,0.025771783679303994,0.7635038685554232,0.8565736240255146,0.7800324842310027,0.9429786595592813,0.8059731070278511,0.5879019395339954,0.607668827673365,0.9277821169731242,0.6723523734532479,0.0614473211053469,0.5299835659114396 +0.6542688458611743,0.41464460991830965,0.7729402924763508,0.34850320480829167,0.6262491120093998,0.2155710985275111,0.030447723696382156,0.4262638797185796,0.1566159218170904,0.04011593602983754,0.7468913828855264,0.35360126642453826,0.7406503827000928,0.362892362180792,0.39937108720089 +0.34814888816674316,0.371317694189513,0.5816652706959554,0.019843056042100238,0.23720370216161712,0.984938638285703,0.7292516592931112,0.902860667541096,0.30361097435474227,0.07043995932401148,0.00043250785765291955,0.3067321666735895,0.15503703704008376,0.4508939658276898,0.04422635705802691 +0.8013707630547462,0.4543895849721282,0.0878161993910207,0.7105661220120499,0.04291495984566385,0.03504390871446594,0.6709211024477577,0.22647570810062134,0.11262041875102746,0.043594084742591854,0.007320695592324067,0.2994001194857965,0.08934592081853354,0.6230726999829022,0.37832109880054654 +0.5926074012008012,0.21392691953315235,0.8311965562385258,0.20794116094371973,0.9196929513858264,0.21814349042248193,0.5073730808244645,0.6900149374774187,0.5554936110040637,0.0750903743733945,0.047030666967787904,0.9136331372467751,0.5485772707843245,0.5408508424428177,0.024359667511356986 +0.09520205248950353,0.5738344823681878,0.2949582515448169,0.48167057889012976,0.016506296051865488,0.9089393162718294,0.2566553321319921,0.5007944393018618,0.448992249441877,0.13231445877495374,0.6322972779427781,0.9733466533581665,0.2165501262249061,0.9568213178510563,0.28110652348139475 +0.8237066669097532,0.8471165746933589,0.18099919523734642,0.49240791194427347,0.06259737672750787,0.8687695915389871,0.4103852765282282,0.43051340316415043,0.4450826317417602,0.33431413370180796,0.7366931702446939,0.37023667782102987,0.36997671769412943,0.1211605644295436,0.45067417487640593 +0.9634590416861293,0.7858559518135985,0.74374331201771,0.4007820629942831,0.12760838775072902,0.4873951988693582,0.8382617560242147,0.25853447078049074,0.7062310003012338,0.6529408405729801,0.015096677702215677,0.9834982226646097,0.9790279136470549,0.6189204851779193,0.09843686931495055 +0.20756016119873177,0.5255535033514623,0.922524228081515,0.9524167195114229,0.7208207959866236,0.6263358071157747,0.9987651489461044,0.14252637081286645,0.3206156120113489,0.016239850655199173,0.4041479123760445,0.4801394852977854,0.7873993000519485,0.5789247169843984,0.908330878080155 +0.9449101753483371,0.9141091946660237,0.7353219908683188,0.6186834256394931,0.08208508579970453,0.10415261874375192,0.33033775801745013,0.8759942133922273,0.0754632293146652,0.39202430509741515,0.84948765696903,0.9369570239708863,0.7085089305846813,0.6394840001424527,0.5575563370878096 +0.8226919222111803,0.13187541224619426,0.4029478235262133,0.37927278505901063,0.4878771082623834,0.6476135656960138,0.4725905999089074,0.4820870534393096,0.394659598909027,0.33975182879399257,0.18084391036773562,0.2644746877299159,0.5847563968324464,0.5488046147210136,0.361241923043361 +0.6410311124370097,0.3497973529679822,0.002029420252289582,0.057120609545241785,0.8233519178931198,0.1486615196663975,0.3452357562689956,0.12855397342294383,0.1523238518367891,0.3637841435095427,0.37601442526675244,0.3466294898364505,0.46350955018604933,0.40981396586484764,0.9547491276934735 +0.7557995782420134,0.13938166018540565,0.8980747918079001,0.7475132607772889,0.9518481897309593,0.2792674103972418,0.24421651437977954,0.7440461333508422,0.3618842528331422,0.6656265553762212,0.2523816472721431,0.867421348926042,0.14146798447125164,0.034574521224249755,0.7648777623459267 +0.3842031181002018,0.9259401519433742,0.6038248606248895,0.18101150278973477,0.18493999595610444,0.7445627591713837,0.8246774369943273,0.9219681128803421,0.03163403931095876,0.48852950665712,0.9047015771978811,0.8365982676726381,0.9837997154674301,0.4431720967406755,0.3712699198717241 +0.1154678303876755,0.9748501368471806,0.8908722432221715,0.39002206687271523,0.648388907054587,0.6516268144864886,0.5060133488896946,0.6812514789142452,0.8579070634182451,0.19584068247011188,0.28662709488402704,0.7846868939029139,0.4955990004056613,0.5354020453253707,0.6188554810393168 +0.9483077066658636,0.4430884055243558,0.6924713594940829,0.48956882061910645,0.521336203627758,0.6476969423503891,0.7709252172905698,0.12381864993945102,0.026355814158331103,0.9150658700590858,0.8965855290857291,0.21033156502625427,0.812113859794264,0.8670513564729575,0.9334025575416065 +0.29003386891587546,0.002583257565877517,0.128343356016126,0.941302293442971,0.39124347787999947,0.5549173319887247,0.0523094288640068,0.07392732927434464,0.038144671304869204,0.5707768906320418,0.45911900889634394,0.6613425009683342,0.24759411985870616,0.31908555157715823,0.11721534069373796 +0.5664316020835152,0.14334773364182296,0.4529295572898214,0.1046958910061575,0.7494496144447053,0.08789050290634126,0.746795744424649,0.9591124916206604,0.1223501447804266,0.6307163040483119,0.39982050599826846,0.9892032960572019,0.5186771981689458,0.5762913360600099,0.5272328682003795 +0.7346780087212786,0.051447703355716246,0.8449316507290252,0.04418023093970436,0.8744012917193588,0.551478879696151,0.1914546183636482,0.6772930171190966,0.3107967304317856,0.47135893371458604,0.5030495004674282,0.5105814795423487,0.23040462851757537,0.20946394444589456,0.34397232932608846 +0.15980578559329262,0.645800371650627,0.493747966754033,0.41152140690252215,0.4279594977912953,0.8509963752455887,0.15749499359118324,0.11218835004422356,0.3666121752387994,0.3591375235539226,0.4638508315516717,0.5099300554121715,0.016464417215011795,0.6627269289769087,0.5112681809851871 +0.779073990645791,0.4034909145553488,0.4774375138263377,0.5364693360954615,0.47411902912267956,0.29412485942284317,0.2985117246367841,0.3930305441634453,0.31864073685611316,0.11431637997626176,0.9447337192997828,0.05837467683321207,0.47897997101741263,0.5216883647649706,0.9523317097755524 +0.5708313136977081,0.5742312134571957,0.6467799212747986,0.26731132999338336,0.8056737477239665,0.7766798565832207,0.4594552702403971,0.26420384174541645,0.7533243665885266,0.16023329741639725,0.006439386329007868,0.23928410420792245,0.2709107787699411,0.36257420433434073,0.22043537858865736 +0.9258046269753054,0.8631371265262906,0.6474672740726625,0.30004985635223236,0.8571535229305768,0.984310741040911,0.6689661764006155,0.694040170137931,0.4876996001989966,0.008091810167080049,0.4299861408970993,0.8733992599893273,0.13061136568760556,0.5850557847815535,0.7766135999214235 +0.4670476593108611,0.1372024937083567,0.7493084740727655,0.41298192320573635,0.47019633646249914,0.7402456268409223,0.3539403902466617,0.4622598186710626,0.90891617425517,0.4541346695071904,0.7501724599076772,0.7935496845388171,0.607718305225681,0.5712793025483863,0.6083405615117734 +0.3579651339033103,0.3931528213665786,0.694308850254078,0.0888443615941219,0.5256240051990227,0.6624479206726026,0.9155075607574713,0.19755433532146505,0.9663940494204895,0.6553927210697587,0.22126275306401544,0.16903787409372562,0.3172952633902072,0.5486777306291436,0.9080486931469579 diff --git a/tests/data/air_quality/data.py b/tests/data/air_quality/data.py index be55a065299..d5bad03e65e 100644 --- a/tests/data/air_quality/data.py +++ b/tests/data/air_quality/data.py @@ -24,7 +24,7 @@ 'AH', ] -nrows = 3 +nrows = 50 data = np.random.rand(nrows, len(columns)) df = pd.DataFrame(data, columns=columns) diff --git a/tests/datasets/test_air_quality.py b/tests/datasets/test_air_quality.py index 067f1c56519..86ef3b45cf8 100644 --- a/tests/datasets/test_air_quality.py +++ b/tests/datasets/test_air_quality.py @@ -21,12 +21,16 @@ def dataset( return AirQuality(tmp_path, download=True) def test_getitem(self, dataset: AirQuality) -> None: - x = dataset[0] - assert isinstance(x, pd.Series) - assert len(x) == 15 + x, y = dataset[0] + assert isinstance(x, pd.DataFrame) + assert len(x.columns) == 15 + assert len(x) == dataset.past_steps + assert isinstance(y, pd.DataFrame) + assert len(y.columns) == 15 + assert len(y) == dataset.future_steps def test_len(self, dataset: AirQuality) -> None: - assert len(dataset) == 3 + assert len(dataset) == 46 def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): diff --git a/tests/models/test_seq2seq.py b/tests/models/test_seq2seq.py index aa4e0159dae..90b729ca40b 100644 --- a/tests/models/test_seq2seq.py +++ b/tests/models/test_seq2seq.py @@ -10,6 +10,8 @@ INPUT_SIZE_ENCODER = [1, 3] INPUT_SIZE_DECODER = [2, 3] OUTPUT_SIZE = [1] +NUM_LAYERS = [1, 2, 3] +HIDDEN_SIZE = [1, 2, 3] class TestLSTMSeq2Seq: @@ -35,3 +37,53 @@ def test_input_size(self, b: int, e: int, d: int) -> None: future_steps = torch.randn(b, output_sequence_length, n_features) y = model(past_steps, future_steps) assert y.shape == (b, output_sequence_length, output_size) + + @torch.no_grad() + @pytest.mark.parametrize('n', NUM_LAYERS) + def test_num_layers(self, n: int) -> None: + batch_size = 5 + input_size_encoder = 3 + input_size_decoder = 2 + sequence_length = 3 + output_sequence_length = 3 + n_features = 5 + output_size = 2 + model = LSTMSeq2Seq( + input_size_encoder=input_size_encoder, + input_size_decoder=input_size_decoder, + target_indices=list(range(0, output_size)), + encoder_indices=list(range(0, input_size_encoder)), + decoder_indices=list(range(0, input_size_decoder)), + output_size=output_size, + output_seq_length=output_sequence_length, + num_layers=n, + ) + past_steps = torch.randn(batch_size, sequence_length, n_features) + future_steps = torch.randn(batch_size, output_sequence_length, n_features) + y = model(past_steps, future_steps) + assert y.shape == (batch_size, output_sequence_length, output_size) + + @torch.no_grad() + @pytest.mark.parametrize('h', HIDDEN_SIZE) + def test_hidden_size(self, h: int) -> None: + batch_size = 5 + input_size_encoder = 3 + input_size_decoder = 2 + sequence_length = 3 + output_sequence_length = 3 + n_features = 5 + output_size = 2 + model = LSTMSeq2Seq( + input_size_encoder=input_size_encoder, + input_size_decoder=input_size_decoder, + target_indices=list(range(0, output_size)), + encoder_indices=list(range(0, input_size_encoder)), + decoder_indices=list(range(0, input_size_decoder)), + output_size=output_size, + output_seq_length=output_sequence_length, + hidden_size=h, + ) + past_steps = torch.randn(batch_size, sequence_length, n_features) + future_steps = torch.randn(batch_size, output_sequence_length, n_features) + y = model(past_steps, future_steps) + assert y.shape == (batch_size, output_sequence_length, output_size) diff --git a/tests/trainers/test_autoregression.py b/tests/trainers/test_autoregression.py new file mode 100644 index 00000000000..a054a124fd9 --- /dev/null +++ b/tests/trainers/test_autoregression.py @@ -0,0 +1,38 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os + +import pytest + +from torchgeo.datamodules import MisconfigurationException +from torchgeo.main import main + + +class TestAutoregressionTask: + @pytest.mark.parametrize('name', ['air_quality']) + def test_trainer(self, name: str, fast_dev_run: bool) -> None: + config = os.path.join('tests', 'conf', name + '.yaml') + + args = [ + '--config', + config, + '--trainer.accelerator', + 'cpu', + '--trainer.fast_dev_run', + str(fast_dev_run), + '--trainer.max_epochs', + '1', + '--trainer.log_every_n_steps', + '1', + ] + + main(['fit', *args]) + try: + main(['test', *args]) + except MisconfigurationException: + pass + try: + main(['predict', *args]) + except MisconfigurationException: + pass diff --git a/torchgeo/datamodules/__init__.py b/torchgeo/datamodules/__init__.py index 6dd7231e3df..86f2d792535 100644 --- a/torchgeo/datamodules/__init__.py +++ b/torchgeo/datamodules/__init__.py @@ -4,6 +4,7 @@ """TorchGeo datamodules.""" from .agrifieldnet import AgriFieldNetDataModule +from .air_quality import AirQualityDataModule from .bigearthnet import BigEarthNetDataModule from .cabuar import CaBuArDataModule from .caffe import CaFFeDataModule @@ -57,7 +58,7 @@ from .xview import XView2DataModule __all__ = ( - 'AgriFieldNetDataModule', + 'AirQualityDataModuleAgriFieldNetDataModule', 'BaseDataModule', 'BigEarthNetDataModule', 'COWCCountingDataModule', diff --git a/torchgeo/datamodules/air_quality.py b/torchgeo/datamodules/air_quality.py new file mode 100644 index 00000000000..f4c9127be54 --- /dev/null +++ b/torchgeo/datamodules/air_quality.py @@ -0,0 +1,75 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""Air Quality datamodule.""" + +from typing import Any + +from torch import Tensor +from torch.utils.data import Subset + +from ..datasets import AirQuality +from .geo import NonGeoDataModule + + +class AirQualityDataModule(NonGeoDataModule): + """LightningDataModule implementation for the AirQuality dataset. + + Uses the user provided splits to divide the dataset into + train/val/test sets. + + .. versionadded:: 0.7 + """ + + def __init__( + self, + batch_size: int = 64, + val_split_pct: float = 0.2, + test_split_pct: float = 0.2, + num_workers: int = 0, + **kwargs: Any, + ) -> None: + """Initialize a new AirQualityDataModule instance. + + Args: + batch_size: Size of each mini-batch. + val_split_pct: Percentage of the dataset to use as a validation set. + test_split_pct: Percentage of the dataset to use as a testing set. + num_workers: Number of workers for parallel data loading. + **kwargs: Additional keyword arguments passed to + :class:`~torchgeo.datasets.AirQuality`. + """ + super().__init__(AirQuality, batch_size, num_workers, **kwargs) + self.val_split_pct = val_split_pct + self.test_split_pct = test_split_pct + + def setup(self, stage: str) -> None: + """Set up datasets and samplers. + + Args: + stage: Either 'fit', 'validate', 'test', or 'predict'. + """ + dataset = AirQuality(**self.kwargs) + train_split_pct = 1 - (self.val_split_pct + self.test_split_pct) + train_size = int(train_split_pct * len(dataset)) + val_size = int(self.val_split_pct * len(dataset)) + train_indices = range(train_size) + val_indices = range(train_size, train_size + val_size) + test_indices = range(train_size + val_size, len(dataset)) + self.train_dataset = Subset(dataset, train_indices) + self.val_dataset = Subset(dataset, val_indices) + self.test_dataset = Subset(dataset, test_indices) + + def on_after_batch_transfer( + self, batch: dict[str, Tensor], dataloader_idx: int + ) -> dict[str, Tensor]: + """Apply batch augmentations to the batch after it is transferred to the device. + + Args: + batch: A batch of data that needs to be altered or augmented. + dataloader_idx: The index of the dataloader to which the batch belongs. + + Returns: + A batch of data. + """ + return batch diff --git a/torchgeo/datasets/air_quality.py b/torchgeo/datasets/air_quality.py index 04d90912fa2..8d1f77e5e9b 100644 --- a/torchgeo/datasets/air_quality.py +++ b/torchgeo/datasets/air_quality.py @@ -6,6 +6,7 @@ import os import pandas as pd +import torch from .errors import DatasetNotFoundError from .geo import NonGeoDataset @@ -39,7 +40,13 @@ class AirQuality(NonGeoDataset): url = 'https://archive.ics.uci.edu/static/public/360/data.csv' data_file_name = 'data.csv' - def __init__(self, root: Path = 'data', download: bool = False) -> None: + def __init__( + self, + root: Path = 'data', + download: bool = False, + num_past_steps: int = 3, + num_future_steps: int = 1, + ) -> None: """Initialize a new Dataset instance. Args: @@ -50,6 +57,8 @@ def __init__(self, root: Path = 'data', download: bool = False) -> None: """ self.root = root self.download = download + self.num_past_steps = num_past_steps + self.num_future_steps = num_future_steps self.data = self._load_data() def __len__(self) -> int: @@ -58,7 +67,7 @@ def __len__(self) -> int: Returns: length of the dataset """ - return len(self.data) + return len(self.data) - (self.num_past_steps + self.num_future_steps) def __getitem__(self, index: int) -> pd.Series: """Return an index within the dataset. @@ -69,7 +78,21 @@ def __getitem__(self, index: int) -> pd.Series: Returns: data at that index """ - return self.data.iloc[index] + past_steps = self.data.iloc[index : index + self.num_past_steps] + future_steps = self.data.iloc[ + index + self.num_past_steps : index + + self.num_past_steps + + self.num_future_steps + ] + past_steps = torch.tensor(past_steps.values, dtype=torch.float32) + future_steps = torch.tensor(future_steps.values, dtype=torch.float32) + + mean = past_steps.mean(dim=0, keepdim=True) + std = past_steps.std(dim=0, keepdim=True) + past_steps_normalized = (past_steps - mean) / (std + 1e-12) + future_steps_normalized = (future_steps - mean) / (std + 1e-12) + + return past_steps_normalized, future_steps_normalized def _load_data(self) -> pd.DataFrame: """Load the dataset into a pandas dataframe. diff --git a/torchgeo/models/seq2seq.py b/torchgeo/models/seq2seq.py index f48f415e666..f79d5f93c8d 100644 --- a/torchgeo/models/seq2seq.py +++ b/torchgeo/models/seq2seq.py @@ -48,8 +48,9 @@ def forward(self, inputs: Tensor, hidden: Tensor, cell: Tensor) -> Tensor: for t in range(self.output_sequence_len): _, (hidden, cell) = self.lstm(current_input, (hidden, cell)) - output = self.fc(hidden) - output = output.permute(1, 0, 2) + last_layer_hidden = hidden[-1:] + output = self.fc(last_layer_hidden) + output = output.permute(1, 0, 2) # put batch dimension first outputs[:, t : t + 1, :] = output current_input = inputs[:, t : t + 1, :].clone() teacher_force = ( @@ -75,6 +76,7 @@ def __init__( output_size: int = 1, output_seq_length: int = 1, num_layers: int = 1, + teacher_force_prob: int | None = None, ) -> None: super().__init__() # Target indices need to be mapped to the subset of inputs for decoder @@ -85,14 +87,17 @@ def __init__( .squeeze() .tolist() ) + if not isinstance(mapped_target_indices, list): + mapped_target_indices = [mapped_target_indices] self.encoder = LSTMEncoder(input_size_encoder, hidden_size, num_layers) self.decoder = LSTMDecoder( - input_size_decoder, - hidden_size, - output_size, - mapped_target_indices, + input_size=input_size_decoder, + hidden_size=hidden_size, + output_size=output_size, + target_indices=mapped_target_indices, num_layers=num_layers, output_sequence_len=output_seq_length, + teacher_force_prob=teacher_force_prob, ) self.encoder_indices = encoder_indices self.decoder_indices = decoder_indices diff --git a/torchgeo/trainers/__init__.py b/torchgeo/trainers/__init__.py index ee69bff0021..8f999a92101 100644 --- a/torchgeo/trainers/__init__.py +++ b/torchgeo/trainers/__init__.py @@ -3,6 +3,7 @@ """TorchGeo trainers.""" +from .autoregression import AutoregressionTask from .base import BaseTask from .byol import BYOLTask from .classification import ClassificationTask, MultiLabelClassificationTask @@ -14,6 +15,7 @@ from .simclr import SimCLRTask __all__ = ( + 'AutoregressionTask', 'BYOLTask', 'BaseTask', 'ClassificationTask', diff --git a/torchgeo/trainers/autoregression.py b/torchgeo/trainers/autoregression.py index 719ce5854f9..43ebe10f9b3 100644 --- a/torchgeo/trainers/autoregression.py +++ b/torchgeo/trainers/autoregression.py @@ -5,14 +5,14 @@ from typing import Any -import torch import torch.nn as nn from torch import Tensor from torchmetrics import MetricCollection from torchmetrics.regression import MeanAbsoluteError, MeanSquaredError -from torchvision.models import LSTMSeq2Seq from torchvision.models._api import WeightsEnum +from torchgeo.models import LSTMSeq2Seq + from .base import BaseTask @@ -21,12 +21,15 @@ class AutoregressionTask(BaseTask): def __init__( self, - model: str = 'lstm', + model: str = 'lstm_seq2seq', weights: WeightsEnum | str | bool | None = None, input_size: int = 1, input_size_decoder: int = 1, hidden_size: int = 1, output_size: int = 1, + target_indices: list[int] | None = None, # change this in the model + encoder_indices: list[int] | None = None, + decoder_indices: list[int] | None = None, lookback: int = 3, timesteps_ahead: int = 1, num_layers: int = 1, @@ -37,7 +40,7 @@ def __init__( """Initialize a new AutoregressionTask instance. Args: - model: Name of the model to use, currently supports 'lstm' or 'seq2seq'. + model: Name of the model to use, currently supports 'lstm_seq2seq'. weights: Initial model weights. Either a weight enum, the string representation of a weight enum, True for ImageNet weights, False or None for random weights, or the path to a saved model state dict. @@ -56,24 +59,19 @@ def configure_models(self) -> None: input_size_decoder = self.hparams['input_size_decoder'] hidden_size = self.hparams['hidden_size'] output_size = self.hparams['output_size'] - lookback = self.hparams['lookback'] timesteps_ahead = self.hparams['timesteps_ahead'] num_layers = self.hparams['num_layers'] + target_indices = self.hparams['target_indices'] + encoder_indices = self.hparams['encoder_indices'] + decoder_indices = self.hparams['decoder_indices'] - if model == 'lstm': - assert timesteps_ahead == 1, ( - f'LSTM only supports 1 timestep ahead, got timesteps_ahead={timesteps_ahead}.' - ) - self.model = torch.nn.LSTM( - input_size=input_size, - hidden_size=hidden_size, - num_layers=num_layers, - batch_first=True, - ) - elif model == 'seq2seq': + if model == 'lstm_seq2seq': self.model = LSTMSeq2Seq( input_size_encoder=input_size, input_size_decoder=input_size_decoder, + target_indices=target_indices, + encoder_indices=encoder_indices, + decoder_indices=decoder_indices, hidden_size=hidden_size, output_size=output_size, output_seq_length=timesteps_ahead, @@ -82,7 +80,7 @@ def configure_models(self) -> None: else: raise ValueError( f"Model type '{model}' is not valid. " - "Currently, only supports 'lstm' and 'seq2seq'." + "Currently, only supports 'lstm_seq2seq'." ) def configure_losses(self) -> None: @@ -123,16 +121,18 @@ def _shared_step(self, batch: Any, batch_idx: int, stage: str) -> Tensor: Returns: The loss tensor. """ - x, y = batch - y_hat = self(x) - - loss: Tensor = self.criterion(y_hat, y) + target_indices = self.hparams['target_indices'] + past_steps, future_steps = batch + y_hat = self(past_steps, future_steps) + if target_indices: + future_steps = future_steps[:, :, target_indices] + loss: Tensor = self.criterion(y_hat, future_steps) self.log(f'{stage}_loss', loss) # Retrieve the correct metrics based on the stage metrics = getattr(self, f'{stage}_metrics', None) if metrics: - metrics(y_hat, y) + metrics(y_hat, future_steps) self.log_dict({f'{k}': v for k, v in metrics.compute().items()}) return loss @@ -181,6 +181,9 @@ def predict_step( Returns: Output predicted values. """ - x = batch - y_hat: Tensor = self(x) - return y_hat + past_steps, future_steps = batch + y_hat = self(past_steps, future_steps) + mean = past_steps.mean(dim=0, keepdim=True) + std = past_steps.std(dim=0, keepdim=True) + y_hat_denormalize: Tensor = y_hat*std+mean + return y_hat_denormalize From 537062dcb4e97dd8fa48e156f1727e9f984a341c Mon Sep 17 00:00:00 2001 From: Keenan Eves Date: Tue, 11 Mar 2025 22:04:55 +0000 Subject: [PATCH 06/16] more autoregression tests --- tests/trainers/test_autoregression.py | 11 +++++++++++ torchgeo/models/seq2seq.py | 28 +++++++++++++-------------- torchgeo/trainers/autoregression.py | 9 +++++---- 3 files changed, 29 insertions(+), 19 deletions(-) diff --git a/tests/trainers/test_autoregression.py b/tests/trainers/test_autoregression.py index a054a124fd9..b08ee490705 100644 --- a/tests/trainers/test_autoregression.py +++ b/tests/trainers/test_autoregression.py @@ -7,6 +7,7 @@ from torchgeo.datamodules import MisconfigurationException from torchgeo.main import main +from torchgeo.trainers import AutoregressionTask class TestAutoregressionTask: @@ -36,3 +37,13 @@ def test_trainer(self, name: str, fast_dev_run: bool) -> None: main(['predict', *args]) except MisconfigurationException: pass + + def test_invalid_model(self) -> None: + match = "Model type 'invalid_model' is not valid." + with pytest.raises(ValueError, match=match): + AutoregressionTask(model='invalid_model') + + def test_invalid_loss(self) -> None: + match = "Loss type 'invalid_loss' is not valid." + with pytest.raises(ValueError, match=match): + AutoregressionTask(loss='invalid_loss') diff --git a/torchgeo/models/seq2seq.py b/torchgeo/models/seq2seq.py index f79d5f93c8d..b57491029a4 100644 --- a/torchgeo/models/seq2seq.py +++ b/torchgeo/models/seq2seq.py @@ -27,7 +27,7 @@ def __init__( input_size: int, hidden_size: int, output_size: int, - target_indices: list[int], + target_indices: list[int] | None = None, num_layers: int = 1, output_sequence_len: int = 1, teacher_force_prob: float | None = None, @@ -59,7 +59,10 @@ def forward(self, inputs: Tensor, hidden: Tensor, cell: Tensor) -> Tensor: else False ) if not teacher_force: - current_input[:, :, self.target_indices] = output + if self.target_indices: + current_input[:, :, self.target_indices] = output + else: + current_input = output return outputs @@ -69,32 +72,27 @@ def __init__( self, input_size_encoder: int, input_size_decoder: int, - target_indices: list[int], + target_indices: list[int] | None = None, encoder_indices: list[int] | None = None, decoder_indices: list[int] | None = None, hidden_size: int = 1, output_size: int = 1, output_seq_length: int = 1, num_layers: int = 1, - teacher_force_prob: int | None = None, + teacher_force_prob: float | None = None, ) -> None: super().__init__() - # Target indices need to be mapped to the subset of inputs for decoder - mapped_target_indices = ( - torch.nonzero( - torch.isin(torch.tensor(decoder_indices), torch.tensor(target_indices)) - ) - .squeeze() - .tolist() - ) - if not isinstance(mapped_target_indices, list): - mapped_target_indices = [mapped_target_indices] + if decoder_indices and isinstance(target_indices, list): + # Target indices need to be mapped to the subset of inputs for decoder + target_indices = [ + i for i, val in enumerate(decoder_indices) if val in target_indices + ] self.encoder = LSTMEncoder(input_size_encoder, hidden_size, num_layers) self.decoder = LSTMDecoder( input_size=input_size_decoder, hidden_size=hidden_size, output_size=output_size, - target_indices=mapped_target_indices, + target_indices=target_indices, num_layers=num_layers, output_sequence_len=output_seq_length, teacher_force_prob=teacher_force_prob, diff --git a/torchgeo/trainers/autoregression.py b/torchgeo/trainers/autoregression.py index 43ebe10f9b3..05f9288d993 100644 --- a/torchgeo/trainers/autoregression.py +++ b/torchgeo/trainers/autoregression.py @@ -9,7 +9,6 @@ from torch import Tensor from torchmetrics import MetricCollection from torchmetrics.regression import MeanAbsoluteError, MeanSquaredError -from torchvision.models._api import WeightsEnum from torchgeo.models import LSTMSeq2Seq @@ -22,12 +21,11 @@ class AutoregressionTask(BaseTask): def __init__( self, model: str = 'lstm_seq2seq', - weights: WeightsEnum | str | bool | None = None, input_size: int = 1, input_size_decoder: int = 1, hidden_size: int = 1, output_size: int = 1, - target_indices: list[int] | None = None, # change this in the model + target_indices: list[int] | None = None, encoder_indices: list[int] | None = None, decoder_indices: list[int] | None = None, lookback: int = 3, @@ -36,6 +34,7 @@ def __init__( loss: str = 'mse', lr: float = 1e-3, patience: int = 10, + teacher_force_prob: float | None = None, ) -> None: """Initialize a new AutoregressionTask instance. @@ -64,6 +63,7 @@ def configure_models(self) -> None: target_indices = self.hparams['target_indices'] encoder_indices = self.hparams['encoder_indices'] decoder_indices = self.hparams['decoder_indices'] + teacher_force_prob = self.hparams['teacher_force_prob'] if model == 'lstm_seq2seq': self.model = LSTMSeq2Seq( @@ -76,6 +76,7 @@ def configure_models(self) -> None: output_size=output_size, output_seq_length=timesteps_ahead, num_layers=num_layers, + teacher_force_prob=teacher_force_prob, ) else: raise ValueError( @@ -185,5 +186,5 @@ def predict_step( y_hat = self(past_steps, future_steps) mean = past_steps.mean(dim=0, keepdim=True) std = past_steps.std(dim=0, keepdim=True) - y_hat_denormalize: Tensor = y_hat*std+mean + y_hat_denormalize: Tensor = y_hat * std + mean return y_hat_denormalize From 03c43720889a299af0758a8c537c8ec01f8e94c6 Mon Sep 17 00:00:00 2001 From: Keenan Eves Date: Tue, 11 Mar 2025 23:02:03 +0000 Subject: [PATCH 07/16] added more seq2seq tests --- tests/models/test_seq2seq.py | 32 +++++++++++++++++++++++++++++++- torchgeo/models/seq2seq.py | 15 ++++++++++++--- 2 files changed, 43 insertions(+), 4 deletions(-) diff --git a/tests/models/test_seq2seq.py b/tests/models/test_seq2seq.py index 90b729ca40b..847deba5cce 100644 --- a/tests/models/test_seq2seq.py +++ b/tests/models/test_seq2seq.py @@ -9,7 +9,7 @@ BATCH_SIZE = [1, 2, 7] INPUT_SIZE_ENCODER = [1, 3] INPUT_SIZE_DECODER = [2, 3] -OUTPUT_SIZE = [1] +OUTPUT_SIZE = [1, 2, 3] NUM_LAYERS = [1, 2, 3] HIDDEN_SIZE = [1, 2, 3] @@ -87,3 +87,33 @@ def test_hidden_size(self, h: int) -> None: future_steps = torch.randn(batch_size, output_sequence_length, n_features) y = model(past_steps, future_steps) assert y.shape == (batch_size, output_sequence_length, output_size) + + @torch.no_grad() + def test_none_indices(self) -> None: + batch_size = 5 + sequence_length = 3 + output_sequence_length = 1 + input_size = 5 + output_size = 1 + model = LSTMSeq2Seq( + input_size_encoder=input_size, input_size_decoder=input_size + ) + past_steps = torch.randn(batch_size, sequence_length, input_size) + future_steps = torch.randn(batch_size, output_sequence_length, input_size) + y = model(past_steps, future_steps) + assert y.shape == (batch_size, output_sequence_length, output_size) + + @torch.no_grad() + @pytest.mark.parametrize('o', OUTPUT_SIZE) + def test_output_size(self, o: int) -> None: + batch_size = 5 + sequence_length = 3 + output_sequence_length = 1 + input_size = 5 + model = LSTMSeq2Seq( + input_size_encoder=input_size, input_size_decoder=input_size, output_size=o + ) + past_steps = torch.randn(batch_size, sequence_length, input_size) + future_steps = torch.randn(batch_size, output_sequence_length, input_size) + y = model(past_steps, future_steps) + assert y.shape == (batch_size, output_sequence_length, o) diff --git a/torchgeo/models/seq2seq.py b/torchgeo/models/seq2seq.py index b57491029a4..e2a55a0b6a5 100644 --- a/torchgeo/models/seq2seq.py +++ b/torchgeo/models/seq2seq.py @@ -4,7 +4,6 @@ """LSTM Sequence to Sequence (Seq2Seq) Model.""" import random -from typing import cast import torch import torch.nn as nn @@ -82,7 +81,17 @@ def __init__( teacher_force_prob: float | None = None, ) -> None: super().__init__() + for indices, size, name in [ + (encoder_indices, input_size_encoder, 'encoder_indices'), + (decoder_indices, input_size_decoder, 'decoder_indices'), + (target_indices, output_size, 'target_indices'), + ]: + if indices: + assert len(indices) == size, f'Length of {name} should match {size}.' if decoder_indices and isinstance(target_indices, list): + assert set(target_indices).issubset(set(decoder_indices)), ( + 'target_indices should be in decoder_indices.' + ) # Target indices need to be mapped to the subset of inputs for decoder target_indices = [ i for i, val in enumerate(decoder_indices) if val in target_indices @@ -111,5 +120,5 @@ def forward(self, past_steps: Tensor, future_steps: Tensor) -> Tensor: if self.decoder_indices: inputs_decoder = inputs_decoder[:, :, self.decoder_indices] hidden, cell = self.encoder(inputs_encoder) - outputs = self.decoder(inputs_decoder, hidden, cell) - return cast(Tensor, outputs) + outputs: Tensor = self.decoder(inputs_decoder, hidden, cell) + return outputs From a1d55b43004c31217a5a9b92dddfb8456dc1833b Mon Sep 17 00:00:00 2001 From: Keenan Eves Date: Wed, 12 Mar 2025 21:46:28 +0000 Subject: [PATCH 08/16] added docstrings. --- torchgeo/datamodules/__init__.py | 3 +- torchgeo/datamodules/air_quality.py | 2 +- torchgeo/datasets/air_quality.py | 3 ++ torchgeo/models/seq2seq.py | 68 +++++++++++++++++++++++++++++ torchgeo/trainers/autoregression.py | 33 +++++++++----- 5 files changed, 96 insertions(+), 13 deletions(-) diff --git a/torchgeo/datamodules/__init__.py b/torchgeo/datamodules/__init__.py index 86f2d792535..4879d87e259 100644 --- a/torchgeo/datamodules/__init__.py +++ b/torchgeo/datamodules/__init__.py @@ -58,7 +58,8 @@ from .xview import XView2DataModule __all__ = ( - 'AirQualityDataModuleAgriFieldNetDataModule', + 'AgriFieldNetDataModule', + 'AirQualityDataModule', 'BaseDataModule', 'BigEarthNetDataModule', 'COWCCountingDataModule', diff --git a/torchgeo/datamodules/air_quality.py b/torchgeo/datamodules/air_quality.py index f4c9127be54..3a43b578887 100644 --- a/torchgeo/datamodules/air_quality.py +++ b/torchgeo/datamodules/air_quality.py @@ -63,7 +63,7 @@ def setup(self, stage: str) -> None: def on_after_batch_transfer( self, batch: dict[str, Tensor], dataloader_idx: int ) -> dict[str, Tensor]: - """Apply batch augmentations to the batch after it is transferred to the device. + """Override base class to avoid applying Kornia augmentations to non-image data. Args: batch: A batch of data that needs to be altered or augmented. diff --git a/torchgeo/datasets/air_quality.py b/torchgeo/datasets/air_quality.py index 8d1f77e5e9b..a78a3fc017f 100644 --- a/torchgeo/datasets/air_quality.py +++ b/torchgeo/datasets/air_quality.py @@ -52,6 +52,9 @@ def __init__( Args: root: root directory where dataset can be found download: if True, download dataset and store it in the root directory + num_past_steps: Number of past time steps to use. + num_future_steps: Number of future time steps to use. + Raises: DatasetNotFoundError: If dataset is not found and *download* is False. """ diff --git a/torchgeo/models/seq2seq.py b/torchgeo/models/seq2seq.py index e2a55a0b6a5..ef5d38022ed 100644 --- a/torchgeo/models/seq2seq.py +++ b/torchgeo/models/seq2seq.py @@ -11,16 +11,35 @@ class LSTMEncoder(nn.Module): + """Encoder for LSTM Seq2Seq.""" + def __init__(self, input_size: int, hidden_size: int, num_layers: int = 1) -> None: + """Initialize a new LSTMEncoder. + + Args: + input_size: The number of features in the input. + hidden_size: The number of features in the hidden state. + num_layers: The number of LSTM layers. + """ super().__init__() self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True) def forward(self, x: Tensor) -> tuple[Tensor, Tensor]: + """Forward pass of the encoder. + + Args: + x: Input sequence of shape (b, sequence length, input_size). + + Returns: + Hidden and cell states. + """ _, (hidden, cell) = self.lstm(x) return hidden, cell class LSTMDecoder(nn.Module): + """Decoder for LSTM Seq2Seq.""" + def __init__( self, input_size: int, @@ -31,6 +50,19 @@ def __init__( output_sequence_len: int = 1, teacher_force_prob: float | None = None, ) -> None: + """Initialize a new LSTMDecoder. + + Args: + input_size: The number of features in the input. + hidden_size: The number of features in the hidden state. + output_size: The number of features output by the decoder. + target_indices: Indices of the target features in the dataset. + If None, uses all features passed to the decoder. Defaults to None. + num_layers: Number of LSTM layers. Defaults to 1. + output_sequence_len: The number of steps to predict forward. Defaults to 1. + teacher_force_prob: Probability of using teacher forcing. If None, does not + use teacher forcing. Defaults to None. + """ super().__init__() self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True) self.fc = nn.Linear(hidden_size, output_size) @@ -40,6 +72,16 @@ def __init__( self.teacher_force_prob = teacher_force_prob def forward(self, inputs: Tensor, hidden: Tensor, cell: Tensor) -> Tensor: + """Forward pass of the decoder. + + Args: + inputs: Input sequence of shape (b, sequence length, input_size). + hidden: hidden state from the encoder. + cell: cell state from the encoder. + + Returns: + Output sequence of shape (b, output_sequence_len, output_size). + """ batch_size = inputs.shape[0] outputs = torch.zeros(batch_size, self.output_sequence_len, self.output_size) @@ -67,6 +109,8 @@ def forward(self, inputs: Tensor, hidden: Tensor, cell: Tensor) -> Tensor: class LSTMSeq2Seq(nn.Module): + """LSTM Sequence-to-Sequence (Seq2Seq).""" + def __init__( self, input_size_encoder: int, @@ -80,6 +124,21 @@ def __init__( num_layers: int = 1, teacher_force_prob: float | None = None, ) -> None: + """Initialize a new LSTMSeq2Seq model. + + Args: + input_size_encoder: The number of features in the encoder input. + input_size_decoder: The number of features in the decoder input. + target_indices: The indices of the target(s) in the dataset. If None, uses all features. Defaults to None. + encoder_indices: The indices of the encoder inputs. If None, uses all features. Defaults to None. + decoder_indices: The indices of the decoder inputs. If None, uses all features. Defaults to None. + hidden_size: The number of features in the hidden states of the encoder and decoder. Defaults to 1. + output_size: The number of features output by the model. Defaults to 1. + output_seq_length: The number of steps to predict forward. Defaults to 1. + num_layers: Number of LSTM layers in the encoder and decoder. Defaults to 1. + teacher_force_prob: Probability of using teacher forcing. If None, does not + use teacher forcing. Defaults to None. + """ super().__init__() for indices, size, name in [ (encoder_indices, input_size_encoder, 'encoder_indices'), @@ -110,6 +169,15 @@ def __init__( self.decoder_indices = decoder_indices def forward(self, past_steps: Tensor, future_steps: Tensor) -> Tensor: + """Forward pass of the model. + + Args: + past_steps: Past time steps. + future_steps: Future time steps. + + Returns: + Output sequence of shape (b, output_seq_length, output_size). + """ if self.encoder_indices: inputs_encoder = past_steps[:, :, self.encoder_indices] else: diff --git a/torchgeo/trainers/autoregression.py b/torchgeo/trainers/autoregression.py index 05f9288d993..817c7f70df9 100644 --- a/torchgeo/trainers/autoregression.py +++ b/torchgeo/trainers/autoregression.py @@ -28,7 +28,6 @@ def __init__( target_indices: list[int] | None = None, encoder_indices: list[int] | None = None, decoder_indices: list[int] | None = None, - lookback: int = 3, timesteps_ahead: int = 1, num_layers: int = 1, loss: str = 'mse', @@ -40,14 +39,23 @@ def __init__( Args: model: Name of the model to use, currently supports 'lstm_seq2seq'. - weights: Initial model weights. Either a weight enum, the string - representation of a weight enum, True for ImageNet weights, False - or None for random weights, or the path to a saved model state dict. - loss: One of 'mse' or 'mae'. - lr: Learning rate for optimizer. - patience: Patience for learning rate scheduler. - - .. versionadded: 0.7 + Defaults to 'lstm_seq2seq'. + input_size: The number of features in the input. Defaults to 1. + input_size_decoder: The number of features in the decoder input. + Defaults to 1. + hidden_size: The number of features in the hidden states of the encoder + and decoder. Defaults to 1. + output_size: The number of features output by the model. Defaults to 1. + target_indices: The indices of the target(s) in the dataset. If None, uses all features. Defaults to None. + encoder_indices: The indices of the encoder inputs. If None, uses all features. Defaults to None. + decoder_indices: The indices of the decoder inputs. If None, uses all features. Defaults to None. + timesteps_ahead: Number of time steps to predict. Defaults to 1. + num_layers: Number of LSTM layers in the encoder and decoder. Defaults to 1. + loss: One of 'mse' or 'mae'. Defaults to 'mse'. + lr: Learning rate for optimizer. Defaults to 1e-3. + patience: Patience for learning rate scheduler. Defaults to 10. + teacher_force_prob: Probability of using teacher forcing. If None, does not + use teacher forcing. Defaults to None. """ super().__init__() @@ -92,10 +100,13 @@ def configure_losses(self) -> None: """ loss: str = self.hparams['loss'] if loss == 'mse': - self.criterion = nn.MSELoss() + self.criterion: nn.Module = nn.MSELoss() + elif loss == 'mae': + self.criterion = nn.L1Loss() else: raise ValueError( - f"Loss type '{loss}' is not valid. Currently, supports 'mse' loss." + f"Loss type '{loss}' is not valid. " + "Currently, supports 'mse' or 'mae' loss." ) def configure_metrics(self) -> None: From 382f72f62231e7c1f416234ab8806b487608e79d Mon Sep 17 00:00:00 2001 From: Keenan Eves Date: Wed, 12 Mar 2025 22:01:45 +0000 Subject: [PATCH 09/16] added air quality dataset to docs --- docs/api/datasets.rst | 5 +++++ docs/api/datasets/non_geo_datasets.csv | 1 + 2 files changed, 6 insertions(+) diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index c60b08f6666..68d53a80008 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -202,6 +202,11 @@ ADVANCE .. autoclass:: ADVANCE +Air Quality +^^^^^^^^^^^ + +.. autoclass:: AirQuality + Benin Cashew Plantations ^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/docs/api/datasets/non_geo_datasets.csv b/docs/api/datasets/non_geo_datasets.csv index f91f6b0e967..7b622b8e969 100644 --- a/docs/api/datasets/non_geo_datasets.csv +++ b/docs/api/datasets/non_geo_datasets.csv @@ -1,5 +1,6 @@ Dataset,Task,Source,License,# Samples,# Classes,Size (px),Resolution (m),Bands `ADVANCE`_,C,"Google Earth, Freesound","CC-BY-4.0","5,075",13,512x512,0.5,RGB +`Air Quality`_,"R,T","UCI Machine Learning Repository","CC-BY-4.0","9,358",,,, `Benin Cashew Plantations`_,S,Airbus Pléiades,"CC-BY-4.0",70,6,"1,122x1,186",10,MSI `BigEarthNet`_,C,Sentinel-1/2,"CDLA-Permissive-1.0","590,326",19--43,120x120,10,"SAR, MSI" `BioMassters`_,R,Sentinel-1/2 and Lidar,"CC-BY-4.0",,,256x256, 10, "SAR, MSI" From 075be68d93791e754c4812e4b04771404cf79d31 Mon Sep 17 00:00:00 2001 From: Keenan Eves Date: Wed, 12 Mar 2025 22:34:02 +0000 Subject: [PATCH 10/16] make variable name consistent --- tests/models/test_seq2seq.py | 6 +++--- torchgeo/models/seq2seq.py | 8 ++++---- torchgeo/trainers/autoregression.py | 2 +- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/models/test_seq2seq.py b/tests/models/test_seq2seq.py index 847deba5cce..7b1ea34ffb2 100644 --- a/tests/models/test_seq2seq.py +++ b/tests/models/test_seq2seq.py @@ -31,7 +31,7 @@ def test_input_size(self, b: int, e: int, d: int) -> None: encoder_indices=list(range(0, e)), decoder_indices=list(range(0, d)), output_size=output_size, - output_seq_length=output_sequence_length, + output_sequence_len=output_sequence_length, ) past_steps = torch.randn(b, sequence_length, n_features) future_steps = torch.randn(b, output_sequence_length, n_features) @@ -55,7 +55,7 @@ def test_num_layers(self, n: int) -> None: encoder_indices=list(range(0, input_size_encoder)), decoder_indices=list(range(0, input_size_decoder)), output_size=output_size, - output_seq_length=output_sequence_length, + output_sequence_len=output_sequence_length, num_layers=n, ) past_steps = torch.randn(batch_size, sequence_length, n_features) @@ -80,7 +80,7 @@ def test_hidden_size(self, h: int) -> None: encoder_indices=list(range(0, input_size_encoder)), decoder_indices=list(range(0, input_size_decoder)), output_size=output_size, - output_seq_length=output_sequence_length, + output_sequence_len=output_sequence_length, hidden_size=h, ) past_steps = torch.randn(batch_size, sequence_length, n_features) diff --git a/torchgeo/models/seq2seq.py b/torchgeo/models/seq2seq.py index ef5d38022ed..b3ccb79ff05 100644 --- a/torchgeo/models/seq2seq.py +++ b/torchgeo/models/seq2seq.py @@ -120,7 +120,7 @@ def __init__( decoder_indices: list[int] | None = None, hidden_size: int = 1, output_size: int = 1, - output_seq_length: int = 1, + output_sequence_len: int = 1, num_layers: int = 1, teacher_force_prob: float | None = None, ) -> None: @@ -134,7 +134,7 @@ def __init__( decoder_indices: The indices of the decoder inputs. If None, uses all features. Defaults to None. hidden_size: The number of features in the hidden states of the encoder and decoder. Defaults to 1. output_size: The number of features output by the model. Defaults to 1. - output_seq_length: The number of steps to predict forward. Defaults to 1. + output_sequence_len: The number of steps to predict forward. Defaults to 1. num_layers: Number of LSTM layers in the encoder and decoder. Defaults to 1. teacher_force_prob: Probability of using teacher forcing. If None, does not use teacher forcing. Defaults to None. @@ -162,7 +162,7 @@ def __init__( output_size=output_size, target_indices=target_indices, num_layers=num_layers, - output_sequence_len=output_seq_length, + output_sequence_len=output_sequence_len, teacher_force_prob=teacher_force_prob, ) self.encoder_indices = encoder_indices @@ -176,7 +176,7 @@ def forward(self, past_steps: Tensor, future_steps: Tensor) -> Tensor: future_steps: Future time steps. Returns: - Output sequence of shape (b, output_seq_length, output_size). + Output sequence of shape (b, output_sequence_len, output_size). """ if self.encoder_indices: inputs_encoder = past_steps[:, :, self.encoder_indices] diff --git a/torchgeo/trainers/autoregression.py b/torchgeo/trainers/autoregression.py index 817c7f70df9..8ec91d44c48 100644 --- a/torchgeo/trainers/autoregression.py +++ b/torchgeo/trainers/autoregression.py @@ -82,7 +82,7 @@ def configure_models(self) -> None: decoder_indices=decoder_indices, hidden_size=hidden_size, output_size=output_size, - output_seq_length=timesteps_ahead, + output_sequence_len=timesteps_ahead, num_layers=num_layers, teacher_force_prob=teacher_force_prob, ) From 6ed328506e6265dd286dba7fa367c157897d9aae Mon Sep 17 00:00:00 2001 From: Keenan Eves Date: Wed, 12 Mar 2025 23:05:41 +0000 Subject: [PATCH 11/16] yaml format --- tests/conf/air_quality.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/conf/air_quality.yaml b/tests/conf/air_quality.yaml index 5ef4da2c2c2..2723681b299 100644 --- a/tests/conf/air_quality.yaml +++ b/tests/conf/air_quality.yaml @@ -13,4 +13,4 @@ data: init_args: batch_size: 2 dict_kwargs: - root: 'tests/data/air_quality' \ No newline at end of file + root: 'tests/data/air_quality' From 46bed025cfecbfc487f5dead48c9fb762ca1726a Mon Sep 17 00:00:00 2001 From: Keenan Eves Date: Wed, 12 Mar 2025 23:12:25 +0000 Subject: [PATCH 12/16] fixed air quality dataset tests --- tests/datasets/test_air_quality.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/datasets/test_air_quality.py b/tests/datasets/test_air_quality.py index 86ef3b45cf8..e8181739266 100644 --- a/tests/datasets/test_air_quality.py +++ b/tests/datasets/test_air_quality.py @@ -3,10 +3,10 @@ from pathlib import Path -import pandas as pd import pytest from _pytest.fixtures import SubRequest from pytest import MonkeyPatch +from torch import Tensor from torchgeo.datasets import AirQuality, DatasetNotFoundError @@ -22,12 +22,12 @@ def dataset( def test_getitem(self, dataset: AirQuality) -> None: x, y = dataset[0] - assert isinstance(x, pd.DataFrame) - assert len(x.columns) == 15 - assert len(x) == dataset.past_steps - assert isinstance(y, pd.DataFrame) - assert len(y.columns) == 15 - assert len(y) == dataset.future_steps + assert isinstance(x, Tensor) + assert x.shape[1] == 15 + assert x.shape[0] == dataset.num_past_steps + assert isinstance(y, Tensor) + assert y.shape[1] == 15 + assert y.shape[0] == dataset.num_future_steps def test_len(self, dataset: AirQuality) -> None: assert len(dataset) == 46 From 3de841c3ee13e1009e3b9826df00c5d68eddc151 Mon Sep 17 00:00:00 2001 From: Keenan Eves <31701650+keves1@users.noreply.github.com> Date: Fri, 14 Mar 2025 15:29:04 -0600 Subject: [PATCH 13/16] Apply suggestions from code review Co-authored-by: Adam J. Stewart --- tests/datasets/test_air_quality.py | 2 +- torchgeo/datasets/air_quality.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/datasets/test_air_quality.py b/tests/datasets/test_air_quality.py index e8181739266..6c20cc84866 100644 --- a/tests/datasets/test_air_quality.py +++ b/tests/datasets/test_air_quality.py @@ -16,7 +16,7 @@ class TestAirQuality: def dataset( self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest ) -> AirQuality: - url = 'tests/data/air_quality/data.csv' + url = os.path.join('tests', 'data', 'air_quality', 'data.csv') monkeypatch.setattr(AirQuality, 'url', url) return AirQuality(tmp_path, download=True) diff --git a/torchgeo/datasets/air_quality.py b/torchgeo/datasets/air_quality.py index a78a3fc017f..5d08d2c9e8a 100644 --- a/torchgeo/datasets/air_quality.py +++ b/torchgeo/datasets/air_quality.py @@ -113,4 +113,4 @@ def _load_data(self) -> pd.DataFrame: raise DatasetNotFoundError(self) # Download the dataset - return pd.read_csv(self.url) + return pd.read_csv(self.url, na_values=-200) From 022000b1d3a85521c60dea1693920c9d29b9e14c Mon Sep 17 00:00:00 2001 From: Keenan Eves Date: Wed, 9 Apr 2025 16:21:18 +0000 Subject: [PATCH 14/16] change __getitem__ to have same return type as base class --- tests/datasets/test_air_quality.py | 5 ++++- torchgeo/datasets/air_quality.py | 5 +++-- torchgeo/trainers/autoregression.py | 3 ++- 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/tests/datasets/test_air_quality.py b/tests/datasets/test_air_quality.py index 6c20cc84866..18294fe3842 100644 --- a/tests/datasets/test_air_quality.py +++ b/tests/datasets/test_air_quality.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. +import os from pathlib import Path import pytest @@ -21,7 +22,9 @@ def dataset( return AirQuality(tmp_path, download=True) def test_getitem(self, dataset: AirQuality) -> None: - x, y = dataset[0] + item = dataset[0] + x = item['past'] + y = item['future'] assert isinstance(x, Tensor) assert x.shape[1] == 15 assert x.shape[0] == dataset.num_past_steps diff --git a/torchgeo/datasets/air_quality.py b/torchgeo/datasets/air_quality.py index 5d08d2c9e8a..27c1425e2f0 100644 --- a/torchgeo/datasets/air_quality.py +++ b/torchgeo/datasets/air_quality.py @@ -4,6 +4,7 @@ """Air Quality dataset.""" import os +from typing import Any import pandas as pd import torch @@ -72,7 +73,7 @@ def __len__(self) -> int: """ return len(self.data) - (self.num_past_steps + self.num_future_steps) - def __getitem__(self, index: int) -> pd.Series: + def __getitem__(self, index: int) -> dict[str, Any]: """Return an index within the dataset. Args: @@ -95,7 +96,7 @@ def __getitem__(self, index: int) -> pd.Series: past_steps_normalized = (past_steps - mean) / (std + 1e-12) future_steps_normalized = (future_steps - mean) / (std + 1e-12) - return past_steps_normalized, future_steps_normalized + return {'past': past_steps_normalized, 'future': future_steps_normalized} def _load_data(self) -> pd.DataFrame: """Load the dataset into a pandas dataframe. diff --git a/torchgeo/trainers/autoregression.py b/torchgeo/trainers/autoregression.py index 8ec91d44c48..e00349e8b8d 100644 --- a/torchgeo/trainers/autoregression.py +++ b/torchgeo/trainers/autoregression.py @@ -134,7 +134,8 @@ def _shared_step(self, batch: Any, batch_idx: int, stage: str) -> Tensor: The loss tensor. """ target_indices = self.hparams['target_indices'] - past_steps, future_steps = batch + past_steps = batch['past'] + future_steps = batch['future'] y_hat = self(past_steps, future_steps) if target_indices: future_steps = future_steps[:, :, target_indices] From 40676ab886e42fbe490c2ded71a712f881501dd4 Mon Sep 17 00:00:00 2001 From: Keenan Eves Date: Thu, 10 Apr 2025 17:58:04 +0000 Subject: [PATCH 15/16] using kwargs for additional model arguments --- torchgeo/trainers/autoregression.py | 38 +++-------------------------- 1 file changed, 4 insertions(+), 34 deletions(-) diff --git a/torchgeo/trainers/autoregression.py b/torchgeo/trainers/autoregression.py index e00349e8b8d..4967746a649 100644 --- a/torchgeo/trainers/autoregression.py +++ b/torchgeo/trainers/autoregression.py @@ -23,17 +23,10 @@ def __init__( model: str = 'lstm_seq2seq', input_size: int = 1, input_size_decoder: int = 1, - hidden_size: int = 1, - output_size: int = 1, - target_indices: list[int] | None = None, - encoder_indices: list[int] | None = None, - decoder_indices: list[int] | None = None, - timesteps_ahead: int = 1, - num_layers: int = 1, loss: str = 'mse', lr: float = 1e-3, patience: int = 10, - teacher_force_prob: float | None = None, + **kwargs: dict[str, Any], ) -> None: """Initialize a new AutoregressionTask instance. @@ -43,20 +36,12 @@ def __init__( input_size: The number of features in the input. Defaults to 1. input_size_decoder: The number of features in the decoder input. Defaults to 1. - hidden_size: The number of features in the hidden states of the encoder - and decoder. Defaults to 1. - output_size: The number of features output by the model. Defaults to 1. - target_indices: The indices of the target(s) in the dataset. If None, uses all features. Defaults to None. - encoder_indices: The indices of the encoder inputs. If None, uses all features. Defaults to None. - decoder_indices: The indices of the decoder inputs. If None, uses all features. Defaults to None. - timesteps_ahead: Number of time steps to predict. Defaults to 1. - num_layers: Number of LSTM layers in the encoder and decoder. Defaults to 1. loss: One of 'mse' or 'mae'. Defaults to 'mse'. lr: Learning rate for optimizer. Defaults to 1e-3. patience: Patience for learning rate scheduler. Defaults to 10. - teacher_force_prob: Probability of using teacher forcing. If None, does not - use teacher forcing. Defaults to None. + **kwargs: Additional keyword arguments passed to the model. """ + self.kwargs: dict[str, Any] = kwargs super().__init__() def configure_models(self) -> None: @@ -64,27 +49,12 @@ def configure_models(self) -> None: model: str = self.hparams['model'] input_size = self.hparams['input_size'] input_size_decoder = self.hparams['input_size_decoder'] - hidden_size = self.hparams['hidden_size'] - output_size = self.hparams['output_size'] - timesteps_ahead = self.hparams['timesteps_ahead'] - num_layers = self.hparams['num_layers'] - target_indices = self.hparams['target_indices'] - encoder_indices = self.hparams['encoder_indices'] - decoder_indices = self.hparams['decoder_indices'] - teacher_force_prob = self.hparams['teacher_force_prob'] if model == 'lstm_seq2seq': self.model = LSTMSeq2Seq( input_size_encoder=input_size, input_size_decoder=input_size_decoder, - target_indices=target_indices, - encoder_indices=encoder_indices, - decoder_indices=decoder_indices, - hidden_size=hidden_size, - output_size=output_size, - output_sequence_len=timesteps_ahead, - num_layers=num_layers, - teacher_force_prob=teacher_force_prob, + **self.kwargs, ) else: raise ValueError( From 0be181bfd3058450d1dca762f2f580956db518eb Mon Sep 17 00:00:00 2001 From: Keenan Eves Date: Thu, 10 Apr 2025 22:43:25 +0000 Subject: [PATCH 16/16] denormalize data before calculating metrics and for predictions --- tests/trainers/test_autoregression.py | 10 ++++++++++ torchgeo/datasets/air_quality.py | 7 ++++++- torchgeo/trainers/autoregression.py | 21 +++++++++++++++++---- 3 files changed, 33 insertions(+), 5 deletions(-) diff --git a/tests/trainers/test_autoregression.py b/tests/trainers/test_autoregression.py index b08ee490705..fbd91f791eb 100644 --- a/tests/trainers/test_autoregression.py +++ b/tests/trainers/test_autoregression.py @@ -4,6 +4,7 @@ import os import pytest +import torch from torchgeo.datamodules import MisconfigurationException from torchgeo.main import main @@ -47,3 +48,12 @@ def test_invalid_loss(self) -> None: match = "Loss type 'invalid_loss' is not valid." with pytest.raises(ValueError, match=match): AutoregressionTask(loss='invalid_loss') + + def test_denormalize(self) -> None: + data = torch.rand(1, 3, 1) + mean = data.mean(dim=1, keepdim=True) + std = data.std(dim=1, keepdim=True) + data_normalized = (data - mean) / std + trainer = AutoregressionTask() + denorm = trainer._denormalize(data_normalized, mean, std) + assert torch.equal(data, denorm) diff --git a/torchgeo/datasets/air_quality.py b/torchgeo/datasets/air_quality.py index 27c1425e2f0..d1101fa93f4 100644 --- a/torchgeo/datasets/air_quality.py +++ b/torchgeo/datasets/air_quality.py @@ -96,7 +96,12 @@ def __getitem__(self, index: int) -> dict[str, Any]: past_steps_normalized = (past_steps - mean) / (std + 1e-12) future_steps_normalized = (future_steps - mean) / (std + 1e-12) - return {'past': past_steps_normalized, 'future': future_steps_normalized} + return { + 'past': past_steps_normalized, + 'future': future_steps_normalized, + 'mean': mean, + 'std': std, + } def _load_data(self) -> pd.DataFrame: """Load the dataset into a pandas dataframe. diff --git a/torchgeo/trainers/autoregression.py b/torchgeo/trainers/autoregression.py index 4967746a649..fbbebb42516 100644 --- a/torchgeo/trainers/autoregression.py +++ b/torchgeo/trainers/autoregression.py @@ -23,6 +23,7 @@ def __init__( model: str = 'lstm_seq2seq', input_size: int = 1, input_size_decoder: int = 1, + output_size: int = 1, loss: str = 'mse', lr: float = 1e-3, patience: int = 10, @@ -36,6 +37,7 @@ def __init__( input_size: The number of features in the input. Defaults to 1. input_size_decoder: The number of features in the decoder input. Defaults to 1. + output_size: The number of features output by the model. Defaults to 1. loss: One of 'mse' or 'mae'. Defaults to 'mse'. lr: Learning rate for optimizer. Defaults to 1e-3. patience: Patience for learning rate scheduler. Defaults to 10. @@ -112,6 +114,12 @@ def _shared_step(self, batch: Any, batch_idx: int, stage: str) -> Tensor: loss: Tensor = self.criterion(y_hat, future_steps) self.log(f'{stage}_loss', loss) + # Denormalize the data before computing metrics + if all(key in batch for key in ['mean', 'std']): + mean = batch['mean'][:, :, target_indices] + std = batch['std'][:, :, target_indices] + y_hat = self._denormalize(y_hat, mean, std) + future_steps = self._denormalize(future_steps, mean, std) # Retrieve the correct metrics based on the stage metrics = getattr(self, f'{stage}_metrics', None) if metrics: @@ -164,9 +172,14 @@ def predict_step( Returns: Output predicted values. """ - past_steps, future_steps = batch + target_indices = self.hparams['target_indices'] + past_steps = batch['past'] + future_steps = batch['future'] y_hat = self(past_steps, future_steps) - mean = past_steps.mean(dim=0, keepdim=True) - std = past_steps.std(dim=0, keepdim=True) - y_hat_denormalize: Tensor = y_hat * std + mean + mean = batch['mean'][:, :, target_indices] + std = batch['std'][:, :, target_indices] + y_hat_denormalize: Tensor = self._denormalize(y_hat, mean, std) return y_hat_denormalize + + def _denormalize(self, data: Tensor, mean: Tensor, std: Tensor) -> Tensor: + return data * std + mean