Skip to content

Commit 2dbb039

Browse files
authored
EuroSAT: redistribute split files on Hugging Face (#2432)
1 parent 04cfff1 commit 2dbb039

File tree

9 files changed, with 34 additions and 51 deletions.

tests/data/eurosat/EuroSAT100.zip

180 KB
Binary file not shown.
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
AnnualCrop_1.tif
2+
Forest_1.tif
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
AnnualCrop_1.tif
2+
Forest_1.tif
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
AnnualCrop_1.tif
2+
Forest_1.tif
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
AnnualCrop_1.tif
2+
Forest_1.tif
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
AnnualCrop_1.tif
2+
Forest_1.tif
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
AnnualCrop_1.tif
2+
Forest_1.tif

tests/datasets/test_eurosat.py

Lines changed: 6 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -32,34 +32,10 @@ def dataset(
3232
) -> EuroSAT:
3333
base_class: type[EuroSAT] = request.param[0]
3434
split: str = request.param[1]
35-
md5 = 'aa051207b0547daba0ac6af57808d68e'
36-
monkeypatch.setattr(base_class, 'md5', md5)
37-
url = os.path.join('tests', 'data', 'eurosat', 'EuroSATallBands.zip')
35+
url = os.path.join('tests', 'data', 'eurosat') + os.sep
3836
monkeypatch.setattr(base_class, 'url', url)
39-
monkeypatch.setattr(base_class, 'filename', 'EuroSATallBands.zip')
40-
monkeypatch.setattr(
41-
base_class,
42-
'split_urls',
43-
{
44-
'train': os.path.join('tests', 'data', 'eurosat', 'eurosat-train.txt'),
45-
'val': os.path.join('tests', 'data', 'eurosat', 'eurosat-val.txt'),
46-
'test': os.path.join('tests', 'data', 'eurosat', 'eurosat-test.txt'),
47-
},
48-
)
49-
monkeypatch.setattr(
50-
base_class,
51-
'split_md5s',
52-
{
53-
'train': '4af60a00fdfdf8500572ae5360694b71',
54-
'val': '4af60a00fdfdf8500572ae5360694b71',
55-
'test': '4af60a00fdfdf8500572ae5360694b71',
56-
},
57-
)
58-
root = tmp_path
5937
transforms = nn.Identity()
60-
return base_class(
61-
root=root, split=split, transforms=transforms, download=True, checksum=True
62-
)
38+
return base_class(tmp_path, split=split, transforms=transforms, download=True)
6339

6440
def test_getitem(self, dataset: EuroSAT) -> None:
6541
x = dataset[0]
@@ -84,14 +60,14 @@ def test_add(self, dataset: EuroSAT) -> None:
8460
assert len(ds) == 4
8561

8662
def test_already_downloaded(self, dataset: EuroSAT, tmp_path: Path) -> None:
87-
EuroSAT(root=tmp_path, download=True)
63+
type(dataset)(tmp_path)
8864

8965
def test_already_downloaded_not_extracted(
9066
self, dataset: EuroSAT, tmp_path: Path
9167
) -> None:
9268
shutil.rmtree(dataset.root)
93-
shutil.copy(dataset.url, tmp_path)
94-
EuroSAT(root=tmp_path, download=False)
69+
shutil.copy(dataset.url + dataset.filename, tmp_path)
70+
type(dataset)(tmp_path)
9571

9672
def test_not_downloaded(self, tmp_path: Path) -> None:
9773
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
@@ -108,7 +84,7 @@ def test_plot(self, dataset: EuroSAT) -> None:
10884
plt.close()
10985

11086
def test_plot_rgb(self, dataset: EuroSAT, tmp_path: Path) -> None:
111-
dataset = EuroSAT(root=tmp_path, bands=('B03',))
87+
dataset = type(dataset)(tmp_path, bands=('B03',))
11288
with pytest.raises(
11389
RGBBandsMissingError, match='Dataset does not contain some of the RGB bands'
11490
):

torchgeo/datasets/eurosat.py

Lines changed: 16 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ class EuroSAT(NonGeoClassificationDataset):
5454
* https://ieeexplore.ieee.org/document/8519248
5555
"""
5656

57-
url = 'https://hf.co/datasets/torchgeo/eurosat/resolve/06fd1b090bceecc0ce724cd21578ba7a6664fe8d/EuroSATallBands.zip'
57+
url = 'https://hf.co/datasets/torchgeo/eurosat/resolve/1ce6f1bfb56db63fd91b6ecc466ea67f2509774c/'
5858
filename = 'EuroSATallBands.zip'
5959
md5 = '5ac12b3b2557aa56e1826e981e8e200e'
6060

@@ -64,10 +64,10 @@ class EuroSAT(NonGeoClassificationDataset):
6464
)
6565

6666
splits = ('train', 'val', 'test')
67-
split_urls: ClassVar[dict[str, str]] = {
68-
'train': 'https://storage.googleapis.com/remote_sensing_representations/eurosat-train.txt',
69-
'val': 'https://storage.googleapis.com/remote_sensing_representations/eurosat-val.txt',
70-
'test': 'https://storage.googleapis.com/remote_sensing_representations/eurosat-test.txt',
67+
split_filenames: ClassVar[dict[str, str]] = {
68+
'train': 'eurosat-train.txt',
69+
'val': 'eurosat-val.txt',
70+
'test': 'eurosat-test.txt',
7171
}
7272
split_md5s: ClassVar[dict[str, str]] = {
7373
'train': '908f142e73d6acdf3f482c5e80d851b1',
@@ -141,7 +141,7 @@ def __init__(
141141
self._verify()
142142

143143
valid_fns = set()
144-
with open(os.path.join(self.root, f'eurosat-{split}.txt')) as f:
144+
with open(os.path.join(self.root, self.split_filenames[split])) as f:
145145
for fn in f:
146146
valid_fns.add(fn.strip().replace('.jpg', '.tif'))
147147

@@ -207,16 +207,12 @@ def _verify(self) -> None:
207207
def _download(self) -> None:
208208
"""Download the dataset."""
209209
download_url(
210-
self.url,
211-
self.root,
212-
filename=self.filename,
213-
md5=self.md5 if self.checksum else None,
210+
self.url + self.filename, self.root, md5=self.md5 if self.checksum else None
214211
)
215212
for split in self.splits:
216213
download_url(
217-
self.split_urls[split],
214+
self.url + self.split_filenames[split],
218215
self.root,
219-
filename=f'eurosat-{split}.txt',
220216
md5=self.split_md5s[split] if self.checksum else None,
221217
)
222218

@@ -305,10 +301,10 @@ class EuroSATSpatial(EuroSAT):
305301
.. versionadded:: 0.6
306302
"""
307303

308-
split_urls: ClassVar[dict[str, str]] = {
309-
'train': 'https://hf.co/datasets/torchgeo/eurosat/resolve/1c11c73a87b40b0485d103231a97829991b8e22f/eurosat-spatial-train.txt',
310-
'val': 'https://hf.co/datasets/torchgeo/eurosat/resolve/1c11c73a87b40b0485d103231a97829991b8e22f/eurosat-spatial-val.txt',
311-
'test': 'https://hf.co/datasets/torchgeo/eurosat/resolve/1c11c73a87b40b0485d103231a97829991b8e22f/eurosat-spatial-test.txt',
304+
split_filenames: ClassVar[dict[str, str]] = {
305+
'train': 'eurosat-spatial-train.txt',
306+
'val': 'eurosat-spatial-val.txt',
307+
'test': 'eurosat-spatial-test.txt',
312308
}
313309
split_md5s: ClassVar[dict[str, str]] = {
314310
'train': '7be3254be39f23ce4d4d144290c93292',
@@ -328,14 +324,13 @@ class EuroSAT100(EuroSAT):
328324
.. versionadded:: 0.5
329325
"""
330326

331-
url = 'https://hf.co/datasets/torchgeo/eurosat/resolve/06fd1b090bceecc0ce724cd21578ba7a6664fe8d/EuroSAT100.zip'
332327
filename = 'EuroSAT100.zip'
333328
md5 = 'c21c649ba747e86eda813407ef17d596'
334329

335-
split_urls: ClassVar[dict[str, str]] = {
336-
'train': 'https://hf.co/datasets/torchgeo/eurosat/resolve/06fd1b090bceecc0ce724cd21578ba7a6664fe8d/eurosat-train.txt',
337-
'val': 'https://hf.co/datasets/torchgeo/eurosat/resolve/06fd1b090bceecc0ce724cd21578ba7a6664fe8d/eurosat-val.txt',
338-
'test': 'https://hf.co/datasets/torchgeo/eurosat/resolve/06fd1b090bceecc0ce724cd21578ba7a6664fe8d/eurosat-test.txt',
330+
split_filenames: ClassVar[dict[str, str]] = {
331+
'train': 'eurosat-100-train.txt',
332+
'val': 'eurosat-100-val.txt',
333+
'test': 'eurosat-100-test.txt',
339334
}
340335
split_md5s: ClassVar[dict[str, str]] = {
341336
'train': '033d0c23e3a75e3fa79618b0e35fe1c7',

0 commit comments

Comments
 (0)