diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index 8a3f501ae82..a801e6659ed 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -216,6 +216,11 @@ ADVANCE .. autoclass:: ADVANCE +Air Quality +^^^^^^^^^^^ + +.. autoclass:: AirQuality + Benin Cashew Plantations ^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/docs/api/datasets/non_geo_datasets.csv b/docs/api/datasets/non_geo_datasets.csv index 2cb2607c53c..5f1c7091388 100644 --- a/docs/api/datasets/non_geo_datasets.csv +++ b/docs/api/datasets/non_geo_datasets.csv @@ -1,5 +1,6 @@ Dataset,Task,Source,License,# Samples,# Classes,Size (px),Resolution (m),Bands `ADVANCE`_,C,"Google Earth, Freesound","CC-BY-4.0","5,075",13,512x512,0.5,RGB +`Air Quality`_,"R,T","UCI Machine Learning Repository","CC-BY-4.0","9,358",,,, `Benin Cashew Plantations`_,S,Airbus Pléiades,"CC-BY-4.0",70,6,"1,122x1,186",10,MSI `BigEarthNet`_,C,Sentinel-1/2,"CDLA-Permissive-1.0","590,326",19--43,120x120,10,"SAR, MSI" `BioMassters`_,R,Sentinel-1/2 and Lidar,"CC-BY-4.0",,,256x256, 10, "SAR, MSI" diff --git a/tests/conf/air_quality.yaml b/tests/conf/air_quality.yaml new file mode 100644 index 00000000000..2723681b299 --- /dev/null +++ b/tests/conf/air_quality.yaml @@ -0,0 +1,16 @@ +model: + class_path: AutoregressionTask + init_args: + loss: 'mse' + model: 'lstm_seq2seq' + input_size: 3 + input_size_decoder: 1 + target_indices: [2] + encoder_indices: [2, 12, 13] + decoder_indices: [2] +data: + class_path: AirQualityDataModule + init_args: + batch_size: 2 + dict_kwargs: + root: 'tests/data/air_quality' diff --git a/tests/data/air_quality/data.csv b/tests/data/air_quality/data.csv new file mode 100644 index 00000000000..789142c6274 --- /dev/null +++ b/tests/data/air_quality/data.csv @@ -0,0 +1,51 @@ +Date,Time,CO(GT),PT08.S1(CO),NMHC(GT),C6H6(GT),PT08.S2(NMHC),NOx(GT),PT08.S3(NOx),NO2(GT),PT08.S4(NO2),PT08.S5(O3),T,RH,AH 
+0.428844960315107,0.9866608945950183,0.3705128349758361,0.06370825832902971,0.9984071519066181,0.23601028428833615,0.6711387944567729,0.7165114438855709,0.5840507533339877,0.25990180009319086,0.571315476202064,0.02832021978464172,0.037552099008017814,0.9816954186918593,0.6350429315448495 +0.7775475150998155,0.6659969353072417,0.39932933609779664,0.2622298457856952,0.9443611682964963,0.20875986148029524,0.32574791761079014,0.7862509708476118,0.6908314932342879,0.8361516998619839,0.6121843012493646,0.3798949072266635,0.9568554958054666,0.02607922923897077,0.635026275140384 +0.9039645671894101,0.897744879713676,0.2881787242258511,0.002372804476408641,0.9730300845182906,0.09010979656394669,0.09385852352265911,0.8314209560340037,0.3113554033328697,0.35221348466336844,0.019181890067545115,0.7108240805112648,0.6325030745529099,0.2536998130045216,0.695415273113407 +0.3059428706018087,0.857795345507438,0.16818645291766787,0.6665326676398685,0.34171307776576454,0.028318962518144142,0.966118039297472,0.6900530586626799,0.44300982245907516,0.8091369572525287,0.11479260604867825,0.6455647318157319,0.5324770086205174,0.0481917296021982,0.633141760683114 +0.35765564697096386,0.7858053743605675,0.3104016522107257,0.493416228005473,0.1483762812702656,0.7241500129800955,0.015896064761218742,0.44323625781500997,0.22656208552155188,0.5764696354888218,0.7836345188253819,0.395272511563161,0.8211200761625238,0.8825035849054379,0.7022012756248662 +0.5223180133243563,0.19032196352865638,0.2376350265360937,0.02801444954355481,0.7262114396783,0.4855480945003384,0.22914947905749994,0.6105985226756249,0.4299963391209827,0.016140740391492603,0.6927601581422852,0.34532464557082765,0.7054955160261106,0.6635612542374121,0.950276802833115 
+0.6335493777298301,0.9990439551776903,0.8576504193950082,0.10799687272757574,0.6728234503646858,0.6984180171198154,0.36591757674293135,0.3927919402783018,0.48583164599457185,0.27880177963877484,0.8086519930928119,0.958232169659469,0.046577978803069686,0.9229631695109645,0.7761290479705539 +0.7534312680544961,0.17968083857978556,0.9164541915202609,0.9642421245611333,0.9165621403348179,0.28883554395073907,0.542156663510998,0.7912965980953718,0.892003147870078,0.17467182599762165,0.10036895690809988,0.292028304693698,0.41626796824868517,0.3291941225945081,0.004793952836596116 +0.4030229293020441,0.6241895804028355,0.8027489708177322,0.6061090924552509,0.49517547297881426,0.842015882546139,0.7996783050301155,0.7231551575961693,0.5605444237938936,0.725715256328619,0.8383988129619468,0.23793853519650288,0.7475719718283232,0.5186515383181854,0.8065088529876593 +0.9588023269759194,0.1754985026827054,0.5489916656547322,0.0422015379054006,0.33580950564507894,0.060679127163351776,0.8765016570612999,0.6035308449183396,0.24078472469497103,0.7597171836488161,0.48912582023988704,0.2453476539171492,0.7102981668858297,0.8772506412475322,0.8534851164967973 +0.4846229396514261,0.9436762481111945,0.06741844190243473,0.4618005148984925,0.5900581686288878,0.16831152439252683,0.7878945078391524,0.5360733064490945,0.6211485399827885,0.17500709594134145,0.697960750115879,0.13581109878852693,0.8156566690374971,0.242506637688075,0.48633996290588954 +0.3483524375314333,0.04300540733448299,0.9647797491796986,0.4613579175699183,0.9693808467706223,0.46749022608964597,0.5655973453407719,0.9873213674954243,0.8736119423601583,0.4237783071664204,0.5080632099038275,0.8632283365136613,0.5264136613838126,0.0434876036352726,0.3081429927207596 
+0.6326611204832632,0.09147711621722876,0.6406961675831314,0.5248086249876812,0.772523573112128,0.7113592876716782,0.24915772018411464,0.9654421446013526,0.0032685908574943134,0.9364324001873182,0.9941463757396615,0.4556472971512966,0.9401087495394862,0.56161863415598,0.8121018570128568 +0.6063620622055907,0.09601965532476875,0.5631920459131398,0.9597996401070554,0.805223309981179,0.9110155339375118,0.90635556246564,0.629883065213367,0.6992874629069337,0.02755995522976451,0.6764407089152157,0.5147771063157597,0.08980091589916439,0.6468489058489089,0.4778744096276797 +0.7124181751859203,0.38112767817384985,0.06180169829266957,0.644533493507947,0.9736381656059738,0.2217158561329341,0.5807146315430849,0.29137729741005947,0.6000551984650088,0.11915249772454506,0.06507451919960028,0.9144070859628508,0.463730190931789,0.7364119603627168,0.8299907778984987 +0.7673507410857925,0.09015260397470759,0.3565636743608366,0.30149039227455776,0.5823222840186881,0.31774569361446403,0.31262666726144317,0.8919040873509467,0.9483652005692642,0.20486222460576364,0.5907118699600977,0.39700028234544493,0.8806661531424751,0.9096550586000391,0.9926732255621498 +0.28179392984634344,0.13208638219433366,0.46650838734796995,0.9693186846197664,0.9116492020343615,0.011169400499041582,0.7921130594859067,0.01010787552165282,0.23871477464776114,0.5039327493738484,0.4694618944757426,0.3320929088055071,0.9953005407830204,0.41721458831109315,0.6219705979188263 +0.5155778085824996,0.43864518002279873,0.3823433993193788,0.7205316487464971,0.6883140093469334,0.7174177831551652,0.6644014203675569,0.7320462334354929,0.44604977554236236,0.5925488379113077,0.053334404808935476,0.20635224895832338,0.9983571366293664,0.2462670173747903,0.32992146523553945 
+0.19311061229302073,0.49177543376099964,0.8803946805135382,0.9942427746247304,0.4127104725455135,0.2855126643140573,0.20041244706800454,0.8332085753072233,0.5162171465890477,0.41722641817201556,0.2889597261508776,0.6453434231229928,0.048873272538541124,0.9352022597778659,0.7490170642965864 +0.023804537011409388,0.452960514401231,0.5316856419919115,0.8474423398747712,0.3212592954043607,0.04995160092661344,0.5335128741135348,0.9731243839111817,0.5646818375999015,0.2342174425568424,0.33523863282203825,0.6017411408124446,0.24632459841924303,0.9761407637655526,0.22339309515335026 +0.7710725412919975,0.1081915279745359,0.18287157220528694,0.03897648619204641,0.19564388113623565,0.7173265323695658,0.4534584642536357,0.25599394289870114,0.5055224980383046,0.5337342367427621,0.6434627637452554,0.395112816109344,0.29283240378518904,0.6734206006274873,0.5477890808804488 +0.5810843118756917,0.24701492777967693,0.021457590754809575,0.07930388248379416,0.8573797694607035,0.5849167765719225,0.20752144651399862,0.30720474523817765,0.901116555870858,0.7310787439353938,0.04149558398673425,0.45048315860983934,0.5355237069576362,0.935614982156612,0.14628978197266007 +0.9650881158768077,0.24393127671354498,0.023551931104118684,0.9276372334757385,0.646480637788412,0.8768325666870731,0.5031423256890551,0.6703405625328099,0.4240248411230091,0.8750823470850183,0.1521392487440415,0.8360195677538244,0.0029591011663731015,0.48700328371501844,0.5271555877737016 +0.4435095507554402,0.2580945634985652,0.05043531826353309,0.7485412853110855,0.36737530655511386,0.6279603473688744,0.7233713190729659,0.2873250341885155,0.9586053373211195,0.1919087197200391,0.5004314129592653,0.6978240080356658,0.5577517576652488,0.6386739524021738,0.9649284721294775 
+0.9376060555213641,0.4731932398724885,0.5707952101562018,0.6772188964951886,0.9326699033686338,0.670545660770911,0.9382693295337488,0.9703174731409802,0.9330110811684368,0.14772715375852952,0.428214789686726,0.6993029816523798,0.5437099013249579,0.6446790166959826,0.3573838928746298 +0.5707099320299277,0.6105390935028076,0.8931108905714683,0.025771783679303994,0.7635038685554232,0.8565736240255146,0.7800324842310027,0.9429786595592813,0.8059731070278511,0.5879019395339954,0.607668827673365,0.9277821169731242,0.6723523734532479,0.0614473211053469,0.5299835659114396 +0.6542688458611743,0.41464460991830965,0.7729402924763508,0.34850320480829167,0.6262491120093998,0.2155710985275111,0.030447723696382156,0.4262638797185796,0.1566159218170904,0.04011593602983754,0.7468913828855264,0.35360126642453826,0.7406503827000928,0.362892362180792,0.39937108720089 +0.34814888816674316,0.371317694189513,0.5816652706959554,0.019843056042100238,0.23720370216161712,0.984938638285703,0.7292516592931112,0.902860667541096,0.30361097435474227,0.07043995932401148,0.00043250785765291955,0.3067321666735895,0.15503703704008376,0.4508939658276898,0.04422635705802691 +0.8013707630547462,0.4543895849721282,0.0878161993910207,0.7105661220120499,0.04291495984566385,0.03504390871446594,0.6709211024477577,0.22647570810062134,0.11262041875102746,0.043594084742591854,0.007320695592324067,0.2994001194857965,0.08934592081853354,0.6230726999829022,0.37832109880054654 +0.5926074012008012,0.21392691953315235,0.8311965562385258,0.20794116094371973,0.9196929513858264,0.21814349042248193,0.5073730808244645,0.6900149374774187,0.5554936110040637,0.0750903743733945,0.047030666967787904,0.9136331372467751,0.5485772707843245,0.5408508424428177,0.024359667511356986 
+0.09520205248950353,0.5738344823681878,0.2949582515448169,0.48167057889012976,0.016506296051865488,0.9089393162718294,0.2566553321319921,0.5007944393018618,0.448992249441877,0.13231445877495374,0.6322972779427781,0.9733466533581665,0.2165501262249061,0.9568213178510563,0.28110652348139475 +0.8237066669097532,0.8471165746933589,0.18099919523734642,0.49240791194427347,0.06259737672750787,0.8687695915389871,0.4103852765282282,0.43051340316415043,0.4450826317417602,0.33431413370180796,0.7366931702446939,0.37023667782102987,0.36997671769412943,0.1211605644295436,0.45067417487640593 +0.9634590416861293,0.7858559518135985,0.74374331201771,0.4007820629942831,0.12760838775072902,0.4873951988693582,0.8382617560242147,0.25853447078049074,0.7062310003012338,0.6529408405729801,0.015096677702215677,0.9834982226646097,0.9790279136470549,0.6189204851779193,0.09843686931495055 +0.20756016119873177,0.5255535033514623,0.922524228081515,0.9524167195114229,0.7208207959866236,0.6263358071157747,0.9987651489461044,0.14252637081286645,0.3206156120113489,0.016239850655199173,0.4041479123760445,0.4801394852977854,0.7873993000519485,0.5789247169843984,0.908330878080155 +0.9449101753483371,0.9141091946660237,0.7353219908683188,0.6186834256394931,0.08208508579970453,0.10415261874375192,0.33033775801745013,0.8759942133922273,0.0754632293146652,0.39202430509741515,0.84948765696903,0.9369570239708863,0.7085089305846813,0.6394840001424527,0.5575563370878096 +0.8226919222111803,0.13187541224619426,0.4029478235262133,0.37927278505901063,0.4878771082623834,0.6476135656960138,0.4725905999089074,0.4820870534393096,0.394659598909027,0.33975182879399257,0.18084391036773562,0.2644746877299159,0.5847563968324464,0.5488046147210136,0.361241923043361 
+0.6410311124370097,0.3497973529679822,0.002029420252289582,0.057120609545241785,0.8233519178931198,0.1486615196663975,0.3452357562689956,0.12855397342294383,0.1523238518367891,0.3637841435095427,0.37601442526675244,0.3466294898364505,0.46350955018604933,0.40981396586484764,0.9547491276934735 +0.7557995782420134,0.13938166018540565,0.8980747918079001,0.7475132607772889,0.9518481897309593,0.2792674103972418,0.24421651437977954,0.7440461333508422,0.3618842528331422,0.6656265553762212,0.2523816472721431,0.867421348926042,0.14146798447125164,0.034574521224249755,0.7648777623459267 +0.3842031181002018,0.9259401519433742,0.6038248606248895,0.18101150278973477,0.18493999595610444,0.7445627591713837,0.8246774369943273,0.9219681128803421,0.03163403931095876,0.48852950665712,0.9047015771978811,0.8365982676726381,0.9837997154674301,0.4431720967406755,0.3712699198717241 +0.1154678303876755,0.9748501368471806,0.8908722432221715,0.39002206687271523,0.648388907054587,0.6516268144864886,0.5060133488896946,0.6812514789142452,0.8579070634182451,0.19584068247011188,0.28662709488402704,0.7846868939029139,0.4955990004056613,0.5354020453253707,0.6188554810393168 +0.9483077066658636,0.4430884055243558,0.6924713594940829,0.48956882061910645,0.521336203627758,0.6476969423503891,0.7709252172905698,0.12381864993945102,0.026355814158331103,0.9150658700590858,0.8965855290857291,0.21033156502625427,0.812113859794264,0.8670513564729575,0.9334025575416065 +0.29003386891587546,0.002583257565877517,0.128343356016126,0.941302293442971,0.39124347787999947,0.5549173319887247,0.0523094288640068,0.07392732927434464,0.038144671304869204,0.5707768906320418,0.45911900889634394,0.6613425009683342,0.24759411985870616,0.31908555157715823,0.11721534069373796 
+0.5664316020835152,0.14334773364182296,0.4529295572898214,0.1046958910061575,0.7494496144447053,0.08789050290634126,0.746795744424649,0.9591124916206604,0.1223501447804266,0.6307163040483119,0.39982050599826846,0.9892032960572019,0.5186771981689458,0.5762913360600099,0.5272328682003795 +0.7346780087212786,0.051447703355716246,0.8449316507290252,0.04418023093970436,0.8744012917193588,0.551478879696151,0.1914546183636482,0.6772930171190966,0.3107967304317856,0.47135893371458604,0.5030495004674282,0.5105814795423487,0.23040462851757537,0.20946394444589456,0.34397232932608846 +0.15980578559329262,0.645800371650627,0.493747966754033,0.41152140690252215,0.4279594977912953,0.8509963752455887,0.15749499359118324,0.11218835004422356,0.3666121752387994,0.3591375235539226,0.4638508315516717,0.5099300554121715,0.016464417215011795,0.6627269289769087,0.5112681809851871 +0.779073990645791,0.4034909145553488,0.4774375138263377,0.5364693360954615,0.47411902912267956,0.29412485942284317,0.2985117246367841,0.3930305441634453,0.31864073685611316,0.11431637997626176,0.9447337192997828,0.05837467683321207,0.47897997101741263,0.5216883647649706,0.9523317097755524 +0.5708313136977081,0.5742312134571957,0.6467799212747986,0.26731132999338336,0.8056737477239665,0.7766798565832207,0.4594552702403971,0.26420384174541645,0.7533243665885266,0.16023329741639725,0.006439386329007868,0.23928410420792245,0.2709107787699411,0.36257420433434073,0.22043537858865736 +0.9258046269753054,0.8631371265262906,0.6474672740726625,0.30004985635223236,0.8571535229305768,0.984310741040911,0.6689661764006155,0.694040170137931,0.4876996001989966,0.008091810167080049,0.4299861408970993,0.8733992599893273,0.13061136568760556,0.5850557847815535,0.7766135999214235 
#!/usr/bin/env python3

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""Generate the synthetic ``data.csv`` fixture used by the Air Quality tests."""

import numpy as np
import pandas as pd

# Header of the generated CSV, in column order.
columns = [
    'Date',
    'Time',
    'CO(GT)',
    'PT08.S1(CO)',
    'NMHC(GT)',
    'C6H6(GT)',
    'PT08.S2(NMHC)',
    'NOx(GT)',
    'PT08.S3(NOx)',
    'NO2(GT)',
    'PT08.S4(NO2)',
    'PT08.S5(O3)',
    'T',
    'RH',
    'AH',
]

nrows = 50

# Seed the generator so that re-running this script reproduces the exact same
# fixture instead of silently changing the committed test data.
rng = np.random.default_rng(0)
data = rng.random((nrows, len(columns)))

df = pd.DataFrame(data, columns=columns)

# Written relative to the current working directory (run from this folder).
df.to_csv('data.csv', index=False)
class TestAirQuality:
    """Tests for the AirQuality dataset, driven by the local fixture CSV."""

    @pytest.fixture()
    def dataset(
        self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest
    ) -> AirQuality:
        """Return a dataset whose download URL is redirected to the test CSV."""
        fixture_csv = os.path.join('tests', 'data', 'air_quality', 'data.csv')
        monkeypatch.setattr(AirQuality, 'url', fixture_csv)
        return AirQuality(tmp_path, download=True)

    def test_getitem(self, dataset: AirQuality) -> None:
        """Both windows are tensors with 15 features and the configured length."""
        sample = dataset[0]
        past, future = sample['past'], sample['future']
        windows = (
            (past, dataset.num_past_steps),
            (future, dataset.num_future_steps),
        )
        for window, expected_steps in windows:
            assert isinstance(window, Tensor)
            assert window.shape[1] == 15
            assert window.shape[0] == expected_steps

    def test_len(self, dataset: AirQuality) -> None:
        """The 50-row fixture yields 46 samples with the default window sizes."""
        expected = 46
        assert len(dataset) == expected

    def test_not_downloaded(self, tmp_path: Path) -> None:
        """Without a local file and without ``download`` the dataset must raise."""
        with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
            AirQuality(tmp_path)
# Parametrisation tables for the shape tests below.
BATCH_SIZE = [1, 2, 7]
INPUT_SIZE_ENCODER = [1, 3]
INPUT_SIZE_DECODER = [2, 3]
OUTPUT_SIZE = [1, 2, 3]
NUM_LAYERS = [1, 2, 3]
HIDDEN_SIZE = [1, 2, 3]


class TestLSTMSeq2Seq:
    """Output-shape checks for LSTMSeq2Seq under different constructor arguments."""

    @staticmethod
    def _output_shape(
        model: LSTMSeq2Seq, batch: int, past_len: int, future_len: int, features: int
    ) -> torch.Size:
        """Feed random past/future tensors through *model* and return the output shape."""
        past = torch.randn(batch, past_len, features)
        future = torch.randn(batch, future_len, features)
        return model(past, future).shape

    @torch.no_grad()
    @pytest.mark.parametrize('b', BATCH_SIZE)
    @pytest.mark.parametrize('e', INPUT_SIZE_ENCODER)
    @pytest.mark.parametrize('d', INPUT_SIZE_DECODER)
    def test_input_size(self, b: int, e: int, d: int) -> None:
        """Encoder/decoder input sizes and batch size do not affect the output shape."""
        steps_in, steps_out, features, targets = 3, 3, 5, 2
        net = LSTMSeq2Seq(
            input_size_encoder=e,
            input_size_decoder=d,
            target_indices=list(range(0, targets)),
            encoder_indices=list(range(0, e)),
            decoder_indices=list(range(0, d)),
            output_size=targets,
            output_sequence_len=steps_out,
        )
        shape = self._output_shape(net, b, steps_in, steps_out, features)
        assert shape == (b, steps_out, targets)

    @torch.no_grad()
    @pytest.mark.parametrize('n', NUM_LAYERS)
    def test_num_layers(self, n: int) -> None:
        """The number of stacked LSTM layers does not affect the output shape."""
        batch, enc_size, dec_size = 5, 3, 2
        steps_in, steps_out, features, targets = 3, 3, 5, 2
        net = LSTMSeq2Seq(
            input_size_encoder=enc_size,
            input_size_decoder=dec_size,
            target_indices=list(range(0, targets)),
            encoder_indices=list(range(0, enc_size)),
            decoder_indices=list(range(0, dec_size)),
            output_size=targets,
            output_sequence_len=steps_out,
            num_layers=n,
        )
        shape = self._output_shape(net, batch, steps_in, steps_out, features)
        assert shape == (batch, steps_out, targets)

    @torch.no_grad()
    @pytest.mark.parametrize('h', HIDDEN_SIZE)
    def test_hidden_size(self, h: int) -> None:
        """The hidden state width does not affect the output shape."""
        batch, enc_size, dec_size = 5, 3, 2
        steps_in, steps_out, features, targets = 3, 3, 5, 2
        net = LSTMSeq2Seq(
            input_size_encoder=enc_size,
            input_size_decoder=dec_size,
            target_indices=list(range(0, targets)),
            encoder_indices=list(range(0, enc_size)),
            decoder_indices=list(range(0, dec_size)),
            output_size=targets,
            output_sequence_len=steps_out,
            hidden_size=h,
        )
        shape = self._output_shape(net, batch, steps_in, steps_out, features)
        assert shape == (batch, steps_out, targets)

    @torch.no_grad()
    def test_none_indices(self) -> None:
        """Leaving all index lists at their defaults yields a single-feature output."""
        batch, steps_in, steps_out, features, targets = 5, 3, 1, 5, 1
        net = LSTMSeq2Seq(input_size_encoder=features, input_size_decoder=features)
        shape = self._output_shape(net, batch, steps_in, steps_out, features)
        assert shape == (batch, steps_out, targets)

    @torch.no_grad()
    @pytest.mark.parametrize('o', OUTPUT_SIZE)
    def test_output_size(self, o: int) -> None:
        """The last output dimension follows ``output_size``."""
        batch, steps_in, steps_out, features = 5, 3, 1, 5
        net = LSTMSeq2Seq(
            input_size_encoder=features, input_size_decoder=features, output_size=o
        )
        shape = self._output_shape(net, batch, steps_in, steps_out, features)
        assert shape == (batch, steps_out, o)
class TestAutoregressionTask:
    """Tests for AutoregressionTask: CLI smoke run, argument validation, denormalisation."""

    @pytest.mark.parametrize('name', ['air_quality'])
    def test_trainer(self, name: str, fast_dev_run: bool) -> None:
        """Run fit/test/predict through the CLI entry point with the named config."""
        config = os.path.join('tests', 'conf', name + '.yaml')

        args = [
            '--config',
            config,
            '--trainer.accelerator',
            'cpu',
            '--trainer.fast_dev_run',
            str(fast_dev_run),
            '--trainer.max_epochs',
            '1',
            '--trainer.log_every_n_steps',
            '1',
        ]

        main(['fit', *args])
        # The 'test' and 'predict' stages are deliberately best-effort: a
        # MisconfigurationException from either stage is tolerated, while any
        # other exception still fails the test.
        try:
            main(['test', *args])
        except MisconfigurationException:
            pass
        try:
            main(['predict', *args])
        except MisconfigurationException:
            pass

    def test_invalid_model(self) -> None:
        """An unknown model name is rejected with a ValueError."""
        match = "Model type 'invalid_model' is not valid."
        with pytest.raises(ValueError, match=match):
            AutoregressionTask(model='invalid_model')

    def test_invalid_loss(self) -> None:
        """An unknown loss name is rejected with a ValueError."""
        match = "Loss type 'invalid_loss' is not valid."
        with pytest.raises(ValueError, match=match):
            AutoregressionTask(loss='invalid_loss')

    def test_denormalize(self) -> None:
        """Denormalising normalised data recovers the original values."""
        data = torch.rand(1, 3, 1)
        mean = data.mean(dim=1, keepdim=True)
        std = data.std(dim=1, keepdim=True)
        data_normalized = (data - mean) / std
        trainer = AutoregressionTask()
        denorm = trainer._denormalize(data_normalized, mean, std)
        # A float32 normalise/denormalise round trip is not guaranteed to be
        # bit-exact, so compare with a tolerance instead of torch.equal.
        assert torch.allclose(data, denorm, atol=1e-6)
+ + Args: + batch_size: Size of each mini-batch. + val_split_pct: Percentage of the dataset to use as a validation set. + test_split_pct: Percentage of the dataset to use as a testing set. + num_workers: Number of workers for parallel data loading. + **kwargs: Additional keyword arguments passed to + :class:`~torchgeo.datasets.AirQuality`. + """ + super().__init__(AirQuality, batch_size, num_workers, **kwargs) + self.val_split_pct = val_split_pct + self.test_split_pct = test_split_pct + + def setup(self, stage: str) -> None: + """Set up datasets and samplers. + + Args: + stage: Either 'fit', 'validate', 'test', or 'predict'. + """ + dataset = AirQuality(**self.kwargs) + train_split_pct = 1 - (self.val_split_pct + self.test_split_pct) + train_size = int(train_split_pct * len(dataset)) + val_size = int(self.val_split_pct * len(dataset)) + train_indices = range(train_size) + val_indices = range(train_size, train_size + val_size) + test_indices = range(train_size + val_size, len(dataset)) + self.train_dataset = Subset(dataset, train_indices) + self.val_dataset = Subset(dataset, val_indices) + self.test_dataset = Subset(dataset, test_indices) + + def on_after_batch_transfer( + self, batch: dict[str, Tensor], dataloader_idx: int + ) -> dict[str, Tensor]: + """Override base class to avoid applying Kornia augmentations to non-image data. + + Args: + batch: A batch of data that needs to be altered or augmented. + dataloader_idx: The index of the dataloader to which the batch belongs. + + Returns: + A batch of data. 
class AirQuality(NonGeoDataset):
    """Air Quality dataset.

    The `Air Quality dataset <https://archive.ics.uci.edu/dataset/360/air+quality>`_
    from the UCI Machine Learning Repository is a multivariate time
    series dataset containing air quality measurements from an Italian
    city.

    Dataset Format:

    * .csv file containing date, time and air quality measurements

    Dataset Features:

    * hourly averaged sensor responses and reference analyzer ground truth over one year (2004-2005)
    * has missing features

    If you use this dataset in your research, please cite:

    * https://doi.org/10.1016/J.SNB.2007.09.060

    .. versionadded:: 0.7
    """

    url = 'https://archive.ics.uci.edu/static/public/360/data.csv'
    data_file_name = 'data.csv'

    def __init__(
        self,
        root: Path = 'data',
        download: bool = False,
        num_past_steps: int = 3,
        num_future_steps: int = 1,
    ) -> None:
        """Initialize a new Dataset instance.

        Args:
            root: root directory where dataset can be found
            download: if True, download dataset and store it in the root directory
            num_past_steps: Number of past time steps to use.
            num_future_steps: Number of future time steps to use.

        Raises:
            DatasetNotFoundError: If dataset is not found and *download* is False.
        """
        self.root = root
        self.download = download
        self.num_past_steps = num_past_steps
        self.num_future_steps = num_future_steps
        self.data = self._load_data()

    def __len__(self) -> int:
        """Return the number of data points in the dataset.

        Returns:
            length of the dataset (never negative, even when the table has
            fewer rows than one past + future window)
        """
        # NOTE(review): this count excludes the last complete window (the one
        # starting at row ``len(data) - past - future``); kept as-is because
        # the existing tests pin this value.
        window = self.num_past_steps + self.num_future_steps
        return max(0, len(self.data) - window)

    def __getitem__(self, index: int) -> dict[str, Any]:
        """Return the past/future window starting at *index*.

        Both windows are standardised with the per-column mean and standard
        deviation of the past window only.

        Args:
            index: index to return; negative values count from the end

        Returns:
            dict with the normalised ``'past'`` and ``'future'`` windows and
            the ``'mean'`` and ``'std'`` used to normalise them

        Raises:
            IndexError: If *index* is outside the dataset.
        """
        length = len(self)
        if index < 0:
            index += length
        # Without this check, out-of-range slices silently produce empty or
        # truncated windows, and plain iteration over the dataset never stops.
        if not 0 <= index < length:
            raise IndexError(f'index out of range for dataset of length {length}')

        split = index + self.num_past_steps
        stop = split + self.num_future_steps
        past_steps = torch.tensor(
            self.data.iloc[index:split].values, dtype=torch.float32
        )
        future_steps = torch.tensor(
            self.data.iloc[split:stop].values, dtype=torch.float32
        )

        mean = past_steps.mean(dim=0, keepdim=True)
        std = past_steps.std(dim=0, keepdim=True)
        # The small epsilon avoids division by zero for constant columns.
        past_steps_normalized = (past_steps - mean) / (std + 1e-12)
        future_steps_normalized = (future_steps - mean) / (std + 1e-12)

        return {
            'past': past_steps_normalized,
            'future': future_steps_normalized,
            'mean': mean,
            'std': std,
        }

    def _load_data(self) -> pd.DataFrame:
        """Load the dataset into a pandas dataframe.

        A file already present in the root directory is used as-is. Otherwise,
        if *download* was requested, the data is fetched from ``url`` and
        stored in the root directory so later runs can reuse it.

        Returns:
            Dataframe containing the data, with the ``-200`` sentinel parsed
            as missing values.

        Raises:
            DatasetNotFoundError: If dataset is not found and *download* is False.
        """
        pathname = os.path.join(self.root, self.data_file_name)

        # Use the same missing-value sentinel for cached and downloaded data
        # so both code paths return identical frames.
        if os.path.exists(pathname):
            return pd.read_csv(pathname, na_values=-200)

        # Check if the user requested to download the dataset
        if not self.download:
            raise DatasetNotFoundError(self)

        # Download the dataset and store it in the root directory
        data = pd.read_csv(self.url, na_values=-200)
        os.makedirs(self.root, exist_ok=True)
        data.to_csv(pathname, index=False)
        return data
+ + Returns: + Hidden and cell states. + """ + _, (hidden, cell) = self.lstm(x) + return hidden, cell + + +class LSTMDecoder(nn.Module): + """Decoder for LSTM Seq2Seq.""" + + def __init__( + self, + input_size: int, + hidden_size: int, + output_size: int, + target_indices: list[int] | None = None, + num_layers: int = 1, + output_sequence_len: int = 1, + teacher_force_prob: float | None = None, + ) -> None: + """Initialize a new LSTMDecoder. + + Args: + input_size: The number of features in the input. + hidden_size: The number of features in the hidden state. + output_size: The number of features output by the decoder. + target_indices: Indices of the target features in the dataset. + If None, uses all features passed to the decoder. Defaults to None. + num_layers: Number of LSTM layers. Defaults to 1. + output_sequence_len: The number of steps to predict forward. Defaults to 1. + teacher_force_prob: Probability of using teacher forcing. If None, does not + use teacher forcing. Defaults to None. + """ + super().__init__() + self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True) + self.fc = nn.Linear(hidden_size, output_size) + self.output_size = output_size + self.target_indices = target_indices + self.output_sequence_len = output_sequence_len + self.teacher_force_prob = teacher_force_prob + + def forward(self, inputs: Tensor, hidden: Tensor, cell: Tensor) -> Tensor: + """Forward pass of the decoder. + + Args: + inputs: Input sequence of shape (b, sequence length, input_size). + hidden: hidden state from the encoder. + cell: cell state from the encoder. + + Returns: + Output sequence of shape (b, output_sequence_len, output_size). 
+ """ + batch_size = inputs.shape[0] + outputs = torch.zeros(batch_size, self.output_sequence_len, self.output_size) + + current_input = inputs[:, 0:1, :] + + for t in range(self.output_sequence_len): + _, (hidden, cell) = self.lstm(current_input, (hidden, cell)) + last_layer_hidden = hidden[-1:] + output = self.fc(last_layer_hidden) + output = output.permute(1, 0, 2) # put batch dimension first + outputs[:, t : t + 1, :] = output + current_input = inputs[:, t : t + 1, :].clone() + teacher_force = ( + random.random() < self.teacher_force_prob + if self.teacher_force_prob is not None + else False + ) + if not teacher_force: + if self.target_indices: + current_input[:, :, self.target_indices] = output + else: + current_input = output + + return outputs + + +class LSTMSeq2Seq(nn.Module): + """LSTM Sequence-to-Sequence (Seq2Seq).""" + + def __init__( + self, + input_size_encoder: int, + input_size_decoder: int, + target_indices: list[int] | None = None, + encoder_indices: list[int] | None = None, + decoder_indices: list[int] | None = None, + hidden_size: int = 1, + output_size: int = 1, + output_sequence_len: int = 1, + num_layers: int = 1, + teacher_force_prob: float | None = None, + ) -> None: + """Initialize a new LSTMSeq2Seq model. + + Args: + input_size_encoder: The number of features in the encoder input. + input_size_decoder: The number of features in the decoder input. + target_indices: The indices of the target(s) in the dataset. If None, uses all features. Defaults to None. + encoder_indices: The indices of the encoder inputs. If None, uses all features. Defaults to None. + decoder_indices: The indices of the decoder inputs. If None, uses all features. Defaults to None. + hidden_size: The number of features in the hidden states of the encoder and decoder. Defaults to 1. + output_size: The number of features output by the model. Defaults to 1. + output_sequence_len: The number of steps to predict forward. Defaults to 1. 
+ num_layers: Number of LSTM layers in the encoder and decoder. Defaults to 1. + teacher_force_prob: Probability of using teacher forcing. If None, does not + use teacher forcing. Defaults to None. + """ + super().__init__() + for indices, size, name in [ + (encoder_indices, input_size_encoder, 'encoder_indices'), + (decoder_indices, input_size_decoder, 'decoder_indices'), + (target_indices, output_size, 'target_indices'), + ]: + if indices: + assert len(indices) == size, f'Length of {name} should match {size}.' + if decoder_indices and isinstance(target_indices, list): + assert set(target_indices).issubset(set(decoder_indices)), ( + 'target_indices should be in decoder_indices.' + ) + # Target indices need to be mapped to the subset of inputs for decoder + target_indices = [ + i for i, val in enumerate(decoder_indices) if val in target_indices + ] + self.encoder = LSTMEncoder(input_size_encoder, hidden_size, num_layers) + self.decoder = LSTMDecoder( + input_size=input_size_decoder, + hidden_size=hidden_size, + output_size=output_size, + target_indices=target_indices, + num_layers=num_layers, + output_sequence_len=output_sequence_len, + teacher_force_prob=teacher_force_prob, + ) + self.encoder_indices = encoder_indices + self.decoder_indices = decoder_indices + + def forward(self, past_steps: Tensor, future_steps: Tensor) -> Tensor: + """Forward pass of the model. + + Args: + past_steps: Past time steps. + future_steps: Future time steps. + + Returns: + Output sequence of shape (b, output_sequence_len, output_size). 
class AutoregressionTask(BaseTask):
    """Autoregression."""

    def __init__(
        self,
        model: str = 'lstm_seq2seq',
        input_size: int = 1,
        input_size_decoder: int = 1,
        output_size: int = 1,
        loss: str = 'mse',
        lr: float = 1e-3,
        patience: int = 10,
        **kwargs: Any,
    ) -> None:
        """Initialize a new AutoregressionTask instance.

        Args:
            model: Name of the model to use, currently supports 'lstm_seq2seq'.
                Defaults to 'lstm_seq2seq'.
            input_size: The number of features in the input. Defaults to 1.
            input_size_decoder: The number of features in the decoder input.
                Defaults to 1.
            output_size: The number of features output by the model. Defaults to 1.
            loss: One of 'mse' or 'mae'. Defaults to 'mse'.
            lr: Learning rate for optimizer. Defaults to 1e-3.
            patience: Patience for learning rate scheduler. Defaults to 10.
            **kwargs: Additional keyword arguments passed to the model.
        """
        # Stored before super().__init__(): configure_models() reads self.kwargs
        # and is presumably invoked by BaseTask.__init__ -- verify against BaseTask.
        self.kwargs: dict[str, Any] = kwargs
        super().__init__()

    def configure_models(self) -> None:
        """Initialize the model.

        Raises:
            ValueError: If *model* is invalid.
        """
        model: str = self.hparams['model']
        input_size = self.hparams['input_size']
        input_size_decoder = self.hparams['input_size_decoder']
        output_size = self.hparams['output_size']

        if model == 'lstm_seq2seq':
            # output_size is a named argument of this task, so it can never arrive
            # through **kwargs; forward it explicitly so the model and the metrics
            # agree on the number of outputs.
            self.model = LSTMSeq2Seq(
                input_size_encoder=input_size,
                input_size_decoder=input_size_decoder,
                output_size=output_size,
                **self.kwargs,
            )
        else:
            raise ValueError(
                f"Model type '{model}' is not valid. "
                "Currently, only supports 'lstm_seq2seq'."
            )

    def configure_losses(self) -> None:
        """Initialize the loss criterion.

        Raises:
            ValueError: If *loss* is invalid.
        """
        loss: str = self.hparams['loss']
        if loss == 'mse':
            self.criterion: nn.Module = nn.MSELoss()
        elif loss == 'mae':
            self.criterion = nn.L1Loss()
        else:
            raise ValueError(
                f"Loss type '{loss}' is not valid. "
                "Currently, supports 'mse' or 'mae' loss."
            )

    def configure_metrics(self) -> None:
        """Initialize the performance metrics.

        Creates RMSE and MAE collections for the train, val and test stages.
        """
        output_size = self.hparams['output_size']
        metrics = MetricCollection(
            {
                'rmse': MeanSquaredError(num_outputs=output_size, squared=False),
                'mae': MeanAbsoluteError(num_outputs=output_size),
            }
        )
        self.train_metrics = metrics.clone(prefix='train_')
        self.val_metrics = metrics.clone(prefix='val_')
        self.test_metrics = metrics.clone(prefix='test_')

    @staticmethod
    def _select_targets(data: Tensor, target_indices: list[int] | None) -> Tensor:
        """Keep only the target features along the last dimension.

        Args:
            data: Tensor whose last dimension indexes features.
            target_indices: Feature indices to keep; None or empty keeps everything.

        Returns:
            The selected features, or *data* unchanged when no indices are given.
        """
        # Indexing with None would insert a new axis instead of selecting all.
        if target_indices:
            return data[:, :, target_indices]
        return data

    def _shared_step(self, batch: Any, batch_idx: int, stage: str) -> Tensor:
        """Compute the loss and additional metrics for the given stage.

        The loss is computed on the values as given in the batch. Metrics are
        computed after denormalizing with the batch's 'mean' and 'std' when both
        are present.

        Args:
            batch: The output of your DataLoader.
            batch_idx: Integer displaying index of this batch.
            stage: The current stage.

        Returns:
            The loss tensor.
        """
        # target_indices only exists in hparams when it was passed as a kwarg.
        target_indices = self.hparams.get('target_indices')
        past_steps = batch['past']
        future_steps = batch['future']
        batch_size = past_steps.shape[0]
        y_hat = self(past_steps, future_steps)
        future_steps = self._select_targets(future_steps, target_indices)
        loss: Tensor = self.criterion(y_hat, future_steps)
        self.log(f'{stage}_loss', loss, batch_size=batch_size)

        # Denormalize the data before computing metrics
        if all(key in batch for key in ['mean', 'std']):
            mean = self._select_targets(batch['mean'], target_indices)
            std = self._select_targets(batch['std'], target_indices)
            y_hat = self._denormalize(y_hat, mean, std)
            future_steps = self._denormalize(future_steps, mean, std)
        # Retrieve the correct metrics based on the stage
        metrics = getattr(self, f'{stage}_metrics', None)
        if metrics:
            # Fold the time axis into the sample axis so the metrics receive
            # (samples, outputs).
            metrics(
                y_hat.reshape(-1, y_hat.shape[-1]),
                future_steps.reshape(-1, future_steps.shape[-1]),
            )
            # Log the collection itself so Lightning handles epoch aggregation and
            # reset; logging ``metrics.compute()`` every step never resets state.
            self.log_dict(metrics, batch_size=batch_size)

        return loss

    def training_step(self, batch: Any, batch_idx: int) -> Tensor:
        """Compute the training loss and additional metrics.

        Args:
            batch: The output of your DataLoader.
            batch_idx: Integer displaying index of this batch.

        Returns:
            The loss tensor.
        """
        loss = self._shared_step(batch, batch_idx, 'train')
        return loss

    def validation_step(self, batch: Any, batch_idx: int) -> None:
        """Compute the validation loss and additional metrics.

        Args:
            batch: The output of your DataLoader.
            batch_idx: Integer displaying index of this batch.
        """
        self._shared_step(batch, batch_idx, 'val')

    def test_step(self, batch: Any, batch_idx: int) -> None:
        """Compute the test loss and additional metrics.

        Args:
            batch: The output of your DataLoader.
            batch_idx: Integer displaying index of this batch.
        """
        self._shared_step(batch, batch_idx, 'test')

    def predict_step(
        self, batch: Any, batch_idx: int, dataloader_idx: int = 0
    ) -> Tensor:
        """Compute the predicted regression values.

        Args:
            batch: The output of your DataLoader.
            batch_idx: Integer displaying index of this batch.
            dataloader_idx: Index of the current dataloader.

        Returns:
            Output predicted values, denormalized when the batch provides both
            'mean' and 'std'.
        """
        target_indices = self.hparams.get('target_indices')
        past_steps = batch['past']
        future_steps = batch['future']
        y_hat: Tensor = self(past_steps, future_steps)
        # Mirror _shared_step: statistics are optional.
        if all(key in batch for key in ['mean', 'std']):
            mean = self._select_targets(batch['mean'], target_indices)
            std = self._select_targets(batch['std'], target_indices)
            y_hat = self._denormalize(y_hat, mean, std)
        return y_hat

    def _denormalize(self, data: Tensor, mean: Tensor, std: Tensor) -> Tensor:
        """Undo standardization.

        Args:
            data: Standardized values.
            mean: Mean that was subtracted.
            std: Standard deviation that was divided by.

        Returns:
            ``data * std + mean``.
        """
        return data * std + mean