Commit 1675794

Add toy dataset
1 parent d439b80 commit 1675794

File tree

1 file changed: +52 −52 lines changed

docs/tutorials/torchgeo.ipynb

Lines changed: 52 additions & 52 deletions
@@ -70,11 +70,13 @@
 "source": [
 "import os\n",
 "import tempfile\n",
+"from datetime import datetime\n",
 "\n",
+"from matplotlib import pyplot as plt\n",
 "from torch.utils.data import DataLoader\n",
 "\n",
 "from torchgeo.datasets import CDL, BoundingBox, Landsat7, Landsat8, stack_samples\n",
-"from torchgeo.datasets.utils import download_url\n",
+"from torchgeo.datasets.utils import download_and_extract_archive\n",
 "from torchgeo.samplers import GridGeoSampler, RandomGeoSampler"
 ]
 },
@@ -102,7 +104,7 @@
 "\n",
 "Traditionally, people either performed classification on a single pixel at a time or curated their own benchmark dataset. This works fine for training, but isn't really useful for inference. What we would really like to be able to do is sample small pixel-aligned pairs of input images and output masks from the region of overlap between both datasets. This exact situation is illustrated in the following figure:\n",
 "\n",
-"![Landsat CDL intersection]()\n",
+"![Landsat CDL intersection](https://github.com/microsoft/torchgeo/blob/main/images/geodataset.png?raw=true)\n",
 "\n",
 "Now, let's see what features TorchGeo has to support this kind of use case."
 ]
@@ -141,18 +143,24 @@
 "source": [
 "landsat_root = os.path.join(tempfile.gettempdir(), 'landsat')\n",
 "\n",
-"download_url()\n",
-"download_url()\n",
+"url = 'https://hf.co/datasets/torchgeo/tutorials/resolve/ff30b729e3cbf906148d69a4441cc68023898924/'\n",
+"landsat7_url = url + 'LE07_L2SP_022032_20230725_20230820_02_T1.tar.gz'\n",
+"landsat8_url = url + 'LC08_L2SP_023032_20230831_20230911_02_T1.tar.gz'\n",
 "\n",
-"landsat7 = Landsat7(\n",
-"    paths=landsat_root, bands=['B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7']\n",
-")\n",
-"landsat8 = Landsat8(\n",
-"    paths=landsat_root, bands=['B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8']\n",
-")\n",
+"download_and_extract_archive(landsat7_url, landsat_root)\n",
+"download_and_extract_archive(landsat8_url, landsat_root)\n",
+"\n",
+"landsat7_bands = ['SR_B1', 'SR_B2', 'SR_B3', 'SR_B4', 'SR_B5', 'SR_B7']\n",
+"landsat8_bands = ['SR_B2', 'SR_B3', 'SR_B4', 'SR_B5', 'SR_B6', 'SR_B7']\n",
+"\n",
+"landsat7 = Landsat7(paths=landsat_root, bands=landsat7_bands)\n",
+"landsat8 = Landsat8(paths=landsat_root, bands=landsat8_bands)\n",
 "\n",
 "print(landsat7)\n",
-"print(landsat8)"
+"print(landsat8)\n",
+"\n",
+"print(landsat7.crs)\n",
+"print(landsat8.crs)"
 ]
 },
 {
@@ -186,11 +194,14 @@
 "source": [
 "cdl_root = os.path.join(tempfile.gettempdir(), 'cdl')\n",
 "\n",
-"download_url()\n",
+"cdl_url = url + '2023_30m_cdls.zip'\n",
+"\n",
+"download_and_extract_archive(cdl_url, cdl_root)\n",
 "\n",
 "cdl = CDL(paths=cdl_root)\n",
 "\n",
-"print(cdl)"
+"print(cdl)\n",
+"print(cdl.crs)"
 ]
 },
 {
@@ -201,8 +212,8 @@
 "Again, the following details are worth noting:\n",
 "\n",
 "* We could actually ask the `CDL` dataset to download our data for us by adding `download=True`\n",
-"* All three datasets have different spatial extends\n",
-"* All three datasets have different CRSs"
+"* All datasets have different spatial extents\n",
+"* All datasets have different CRSs"
 ]
 },
 {
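The first bullet above mentions that `CDL` can fetch its own data. As a reference, a minimal sketch of that variant; the `years` and `checksum` arguments are assumptions about the current API, not part of this commit:

```python
# Hypothetical alternative to the manual download above: let the CDL
# dataset fetch and verify its own files (years/checksum are assumed args)
cdl = CDL(paths=cdl_root, years=[2023], download=True, checksum=True)
```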
@@ -223,7 +234,8 @@
 "outputs": [],
 "source": [
 "landsat = landsat7 | landsat8\n",
-"print(landsat)"
+"print(landsat)\n",
+"print(landsat.crs)"
 ]
 },
 {
@@ -242,7 +254,8 @@
 "outputs": [],
 "source": [
 "dataset = landsat & cdl\n",
-"print(dataset)"
+"print(dataset)\n",
+"print(dataset.crs)"
 ]
 },
 {
@@ -262,7 +275,7 @@
 "\n",
 "How did we do this? TorchGeo uses a data structure called an *R-tree* to store the spatiotemporal bounding box of every file in the dataset.\n",
 "\n",
-"![R-tree]()\n",
+"![R-tree](https://raw.githubusercontent.com/davidmoten/davidmoten.github.io/master/resources/rtree-3d/plot2.png)\n",
 "\n",
 "TorchGeo extracts the spatial bounding box from the metadata of each file, and the timestamp from the filename. This geospatial and geotemporal metadata allows us to efficiently compute the intersection or union of two datasets. It also lets us quickly retrieve an image and corresponding mask for a particular location in space and time."
 ]
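To make the R-tree idea concrete, here is a minimal sketch using the `rtree` package that TorchGeo builds its index on; the bounding-box and timestamp values are made up for illustration:

```python
from rtree import index

# Three dimensions: x, y, and time; interleaved=False means coordinates
# are given as (xmin, xmax, ymin, ymax, tmin, tmax)
properties = index.Property()
properties.dimension = 3
idx = index.Index(interleaved=False, properties=properties)

# Insert one file's spatiotemporal bounding box (illustrative values)
idx.insert(0, (925000, 935000, 4470000, 4480000, 1672531200, 1703980800))

# Query every file overlapping a small region and time window
hits = list(idx.intersection((926000, 927000, 4471000, 4472000, 1680000000, 1690000000)))
print(hits)  # [0]
```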
@@ -274,11 +287,21 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"bbox = BoundingBox()\n",
-"sample = dataset[sample]\n",
+"size = 256\n",
 "\n",
-"landsat.plot(sample)\n",
-"cdl.plot(sample)"
+"xmin = 925000\n",
+"xmax = xmin + size * 30\n",
+"ymin = 4470000\n",
+"ymax = ymin + size * 30\n",
+"tmin = datetime(2023, 1, 1).timestamp()\n",
+"tmax = datetime(2023, 12, 31).timestamp()\n",
+"\n",
+"bbox = BoundingBox(xmin, xmax, ymin, ymax, tmin, tmax)\n",
+"sample = dataset[bbox]\n",
+"\n",
+"landsat8.plot(sample)\n",
+"cdl.plot(sample)\n",
+"plt.show()"
 ]
 },
 {
@@ -289,15 +312,6 @@
 "TorchGeo uses *windowed-reading* to only read the blocks of memory needed to load a small patch from a large raster tile. It also automatically reprojects all data to the same CRS and resolution (from the first dataset). This can be controlled by explicitly passing `crs` or `res` to the dataset."
 ]
 },
-{
-"cell_type": "markdown",
-"id": "02368e20-3391-4be7-bbe5-5a3c367ab398",
-"metadata": {},
-"source": [
-"### Geospatial splitting\n",
-"\n"
-]
-},
 {
 "cell_type": "markdown",
 "id": "e2e4221e-dfb7-4966-96a6-e52400ae266c",
@@ -327,8 +341,8 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"train_sampler = RandomGeoSampler(dataset, size=256, length=1000)\n",
-"print(next(train_sampler))"
+"train_sampler = RandomGeoSampler(dataset, size=size, length=1000)\n",
+"next(iter(train_sampler))"
 ]
 },
 {
@@ -338,7 +352,7 @@
 "source": [
 "### Gridded sampling\n",
 "\n",
-"At evaluation time, this actually becomes a problem. We want to make sure we aren't making multiple predictions for the same location. We also want to make sure we don't miss any locations. To achieve this, TorchGeo also provides a `GridGeoSampler`. We can tell the sampler the size of each image patch and the stride of our sliding window (defaults to patch size)."
+"At evaluation time, this actually becomes a problem. We want to make sure we aren't making multiple predictions for the same location. We also want to make sure we don't miss any locations. To achieve this, TorchGeo also provides a `GridGeoSampler`. We can tell the sampler the size of each image patch and the stride of our sliding window."
 ]
 },
 {
@@ -348,8 +362,8 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"test_sampler = GridGeoSampler(dataset, size=256)\n",
-"print(next(test_sampler))"
+"test_sampler = GridGeoSampler(dataset, size=size, stride=size)\n",
+"next(iter(test_sampler))"
 ]
 },
 {
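With both samplers defined, the data loaders follow the standard TorchGeo pattern: samples are dictionaries, so `stack_samples` (imported in the first cell) is passed as the collate function. A sketch, with an assumed batch size:

```python
from torch.utils.data import DataLoader
from torchgeo.datasets import stack_samples

# GeoDataset samples are dicts ('image', 'mask', ...), so stack_samples
# collates a list of dicts into a dict of batched tensors
train_dataloader = DataLoader(
    dataset, batch_size=8, sampler=train_sampler, collate_fn=stack_samples
)
test_dataloader = DataLoader(
    dataset, batch_size=8, sampler=test_sampler, collate_fn=stack_samples
)

batch = next(iter(train_dataloader))
print(batch['image'].shape)  # e.g. (8, 6, 256, 256) with six Landsat bands
print(batch['mask'].shape)
```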
@@ -379,16 +393,10 @@
 },
 {
 "cell_type": "markdown",
-"id": "3518c7d9-1bb3-4bc2-8216-53044d0b4009",
+"id": "e46e8453-df25-4265-a85b-75dce7dea047",
 "metadata": {},
 "source": [
-"\n",
-"* Transforms?\n",
-"* Models\n",
-"  * U-Net + pre-trained ResNet\n",
-"  * Model pre-trained directly on satellite imagery\n",
-"* Training and evaluation\n",
-"  * Copy everything else from "
+"Now that we have working data loaders, we can copy-n-paste our training code from the Introduction to PyTorch tutorial. We only need to change our model to one designed for semantic segmentation, such as a U-Net. Every other line of code would be identical to how you would do this in your normal PyTorch workflow."
 ]
 },
 {
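For completeness, a minimal training-loop sketch consistent with that description. The U-Net here comes from the segmentation_models_pytorch package as an assumption, and the class count and learning rate are placeholders:

```python
import segmentation_models_pytorch as smp
import torch
from torch import nn

# Assumed model choice: U-Net with a ResNet-50 backbone; in_channels=6
# matches the six Landsat bands, num_classes is a placeholder for CDL
num_classes = 256  # placeholder; use the actual number of CDL classes
model = smp.Unet(encoder_name='resnet50', in_channels=6, classes=num_classes)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for batch in train_dataloader:
    optimizer.zero_grad()
    prediction = model(batch['image'])
    # mask dtype/shape may need adjusting depending on the dataset
    loss = criterion(prediction, batch['mask'].long())
    loss.backward()
    optimizer.step()
```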
@@ -403,14 +411,6 @@
 "* [TorchGeo: Deep Learning With Geospatial Data](https://arxiv.org/abs/2111.08872)\n",
 "* [Geospatial deep learning with TorchGeo](https://pytorch.org/blog/geospatial-deep-learning-with-torchgeo/)"
 ]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"id": "38e60635-69b2-47c9-8df2-fd7c872abdd9",
-"metadata": {},
-"outputs": [],
-"source": []
 }
 ],
 "metadata": {
