Labels: bug (Something isn't working)
Description
What is your issue?
In most cases, the batch generator permutes the dimension order to agree with the order specified in input_dims. Here is an example:
>>> import numpy as np
>>> import xarray as xr
>>> import xbatcher
>>> shape = (10, 50, 100, 200)
>>> ds = xr.Dataset(
... {
... "foo": (["time", "y", "x", "z"], np.random.rand(*shape)),
... "bar": (["time", "y", "x", "z"], np.random.randint(0, 10, shape)),
... },
... {
... "x": (["x"], np.arange(shape[-2])),
... "y": (["y"], np.arange(shape[-3])),
... },
... )
>>> print(ds)
<xarray.Dataset>
Dimensions: (time: 10, y: 50, x: 100, z: 200)
Coordinates:
* x (x) int64 0 1 2 3 4 5 6 7 8 9 10 ... 90 91 92 93 94 95 96 97 98 99
* y (y) int64 0 1 2 3 4 5 6 7 8 9 10 ... 40 41 42 43 44 45 46 47 48 49
Dimensions without coordinates: time, z
Data variables:
foo (time, y, x, z) float64 0.6615 0.04028 0.8633 ... 0.4632 0.6561
bar (time, y, x, z) int64 8 0 9 4 8 9 2 6 7 7 5 ... 3 2 0 7 2 3 2 1 3 6
>>> print(ds['foo'].shape)
(10, 50, 100, 200)
>>> bg = xbatcher.BatchGenerator(ds, input_dims={'x': 10, 'y': 5})
>>> print(bg[0])
<xarray.Dataset>
Dimensions: (y: 5, x: 10, sample: 2000)
Coordinates:
* x (x) int64 0 1 2 3 4 5 6 7 8 9
* y (y) int64 0 1 2 3 4
* sample (sample) object MultiIndex
* time (sample) int64 0 0 0 0 0 0 0 0 0 0 0 0 ... 9 9 9 9 9 9 9 9 9 9 9 9
* z (sample) int64 0 1 2 3 4 5 6 7 ... 192 193 194 195 196 197 198 199
Data variables:
foo (sample, x, y) float64 0.6615 0.8259 0.09629 ... 0.2105 0.09571
bar (sample, x, y) int64 8 4 0 6 0 4 4 0 5 4 5 ... 2 3 8 3 4 1 6 1 9 4
>>> print(bg[0]['foo'].shape)
(2000, 10, 5)
In at least one case, when batch_dims is specified and concat_input_dims=False, the original dimension order is retained:
>>> import numpy as np
>>> import xarray as xr
>>> import xbatcher
>>> shape = (10, 50, 100)
>>> ds = xr.Dataset(
... {
... "foo": (["time", "y", "x"], np.random.rand(*shape)),
... "bar": (["time", "y", "x"], np.random.randint(0, 10, shape)),
... },
... {
... "x": (["x"], np.arange(shape[-1])),
... "y": (["y"], np.arange(shape[-2])),
... },
... )
# Original dimensions permuted
>>> bg = xbatcher.BatchGenerator(
... ds,
... input_dims={"x": 5, "y": 10},
... batch_dims={"time": 2},
... concat_input_dims=True,
... )
>>> print(bg[0])
<xarray.Dataset>
Dimensions: (y_input: 10, x_input: 5, sample: 1000)
Coordinates:
x (sample, x_input) int64 0 1 2 3 4 0 1 ... 98 99 95 96 97 98 99
y (sample, y_input) int64 0 1 2 3 4 5 6 ... 43 44 45 46 47 48 49
* sample (sample) object MultiIndex
* input_batch (sample) int64 0 0 0 0 0 0 0 0 0 ... 99 99 99 99 99 99 99 99 99
* time (sample) int64 0 1 2 3 4 5 6 7 8 9 0 ... 9 0 1 2 3 4 5 6 7 8 9
Dimensions without coordinates: y_input, x_input
Data variables:
foo (sample, x_input, y_input) float64 0.3198 0.3109 ... 0.5785
bar (sample, x_input, y_input) int64 1 8 5 6 9 8 7 ... 6 0 9 4 8 5
>>> print(bg[0]['foo'].shape)
(1000, 5, 10)
# Original dimension order retained
>>> bg = xbatcher.BatchGenerator(
... ds,
... input_dims={"x": 5, "y": 10},
... batch_dims={"time": 2},
... concat_input_dims=False,
... )
>>> print(bg[0])
<xarray.Dataset>
Dimensions: (time: 10, y: 10, x: 5)
Coordinates:
* x (x) int64 0 1 2 3 4
* y (y) int64 0 1 2 3 4 5 6 7 8 9
Dimensions without coordinates: time
Data variables:
foo (time, y, x) float64 0.3198 0.5306 0.3465 ... 0.7873 0.5106 0.9177
bar (time, y, x) int64 1 0 2 6 5 8 0 1 2 0 5 ... 1 2 0 2 0 7 5 6 4 8 3
>>> print(bg[0]['foo'].shape)
(10, 10, 5)
We should document the intended behavior for ordering dimensions and test that the resulting shape is consistent. I would have expected the original dimension order to be retained, in contrast to the most common behavior of the batch generator. @jhamman, can you provide insight into the originally intended behavior?
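To make the proposed test concrete, here is a minimal sketch of how the expected dimension order could be computed. It assumes the intended behavior is "retain the source order", which is only my expectation and not confirmed. The helper name `dims_in_source_order` is hypothetical and not part of xbatcher. It works on plain tuples of dimension names, so it runs without xarray installed.

```python
def dims_in_source_order(batch_dims, source_dims):
    """Reorder batch_dims so that dims shared with source_dims follow source order.

    Dims that exist only in the batch (e.g. the stacked "sample" dim) are kept
    first, in their existing relative order. This encodes an ASSUMED convention,
    not documented xbatcher behavior.
    """
    extra = [d for d in batch_dims if d not in source_dims]
    shared = [d for d in source_dims if d in batch_dims]
    return tuple(extra + shared)


# Dimension names copied from the first example in this issue.
source = ("time", "y", "x", "z")
batch = ("sample", "x", "y")

order = dims_in_source_order(batch, source)
print(order)  # ('sample', 'y', 'x')
```

A test could then assert that `bg[0]["foo"].dims == order`, and until the behavior is settled a user could normalize a batch with `bg[0]["foo"].transpose(*order)`. Note that this sketch matches dimensions by name only, so the renamed `x_input` / `y_input` dims produced with `concat_input_dims=True` would be left in their existing order.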