Dimension order should be set by input_dims #126

@maxrjones

What is your issue?

In most cases, the batch generator will permute the dimension order to agree with the order specified in input_dims. Here is an example:

>>> import numpy as np
>>> import xarray as xr
>>> import xbatcher

>>> shape = (10, 50, 100, 200)
>>> ds = xr.Dataset(
...     {
...         "foo": (["time", "y", "x", "z"], np.random.rand(*shape)),
...         "bar": (["time", "y", "x", "z"], np.random.randint(0, 10, shape)),
...     },
...     {
...         "x": (["x"], np.arange(shape[-2])),
...         "y": (["y"], np.arange(shape[-3])),
...     },
... )
>>> print(ds)
<xarray.Dataset>
Dimensions:  (time: 10, y: 50, x: 100, z: 200)
Coordinates:
  * x        (x) int64 0 1 2 3 4 5 6 7 8 9 10 ... 90 91 92 93 94 95 96 97 98 99
  * y        (y) int64 0 1 2 3 4 5 6 7 8 9 10 ... 40 41 42 43 44 45 46 47 48 49
Dimensions without coordinates: time, z
Data variables:
    foo      (time, y, x, z) float64 0.6615 0.04028 0.8633 ... 0.4632 0.6561
    bar      (time, y, x, z) int64 8 0 9 4 8 9 2 6 7 7 5 ... 3 2 0 7 2 3 2 1 3 6
>>> print(ds['foo'].shape)
(10, 50, 100, 200)
>>> bg = xbatcher.BatchGenerator(ds, input_dims={'x': 10, 'y': 5})
>>> print(bg[0])
<xarray.Dataset>
Dimensions:  (y: 5, x: 10, sample: 2000)
Coordinates:
  * x        (x) int64 0 1 2 3 4 5 6 7 8 9
  * y        (y) int64 0 1 2 3 4
  * sample   (sample) object MultiIndex
  * time     (sample) int64 0 0 0 0 0 0 0 0 0 0 0 0 ... 9 9 9 9 9 9 9 9 9 9 9 9
  * z        (sample) int64 0 1 2 3 4 5 6 7 ... 192 193 194 195 196 197 198 199
Data variables:
    foo      (sample, x, y) float64 0.6615 0.8259 0.09629 ... 0.2105 0.09571
    bar      (sample, x, y) int64 8 4 0 6 0 4 4 0 5 4 5 ... 2 3 8 3 4 1 6 1 9 4
>>> print(bg[0]['foo'].shape)
(2000, 10, 5)

In at least one case, the original dimension order is retained:

>>> import numpy as np
>>> import xarray as xr
>>> import xbatcher
>>> shape = (10, 50, 100)
>>> ds = xr.Dataset(
...     {
...         "foo": (["time", "y", "x"], np.random.rand(*shape)),
...         "bar": (["time", "y", "x"], np.random.randint(0, 10, shape)),
...     },
...     {
...         "x": (["x"], np.arange(shape[-1])),
...         "y": (["y"], np.arange(shape[-2])),
...     },
... )
>>> # Original dimensions permuted
>>> bg = xbatcher.BatchGenerator(
...     ds,
...     input_dims={"x": 5, "y": 10},
...     batch_dims={"time": 2},
...     concat_input_dims=True,
... )
>>> print(bg[0])
<xarray.Dataset>
Dimensions:      (y_input: 10, x_input: 5, sample: 1000)
Coordinates:
    x            (sample, x_input) int64 0 1 2 3 4 0 1 ... 98 99 95 96 97 98 99
    y            (sample, y_input) int64 0 1 2 3 4 5 6 ... 43 44 45 46 47 48 49
  * sample       (sample) object MultiIndex
  * input_batch  (sample) int64 0 0 0 0 0 0 0 0 0 ... 99 99 99 99 99 99 99 99 99
  * time         (sample) int64 0 1 2 3 4 5 6 7 8 9 0 ... 9 0 1 2 3 4 5 6 7 8 9
Dimensions without coordinates: y_input, x_input
Data variables:
    foo          (sample, x_input, y_input) float64 0.3198 0.3109 ... 0.5785
    bar          (sample, x_input, y_input) int64 1 8 5 6 9 8 7 ... 6 0 9 4 8 5
>>> print(bg[0]['foo'].shape)
(1000, 5, 10)
>>> # Original dimension order retained
>>> bg = xbatcher.BatchGenerator(
...     ds,
...     input_dims={"x": 5, "y": 10},
...     batch_dims={"time": 2},
...     concat_input_dims=False,
... )
>>> print(bg[0])
<xarray.Dataset>
Dimensions:  (time: 10, y: 10, x: 5)
Coordinates:
  * x        (x) int64 0 1 2 3 4
  * y        (y) int64 0 1 2 3 4 5 6 7 8 9
Dimensions without coordinates: time
Data variables:
    foo      (time, y, x) float64 0.3198 0.5306 0.3465 ... 0.7873 0.5106 0.9177
    bar      (time, y, x) int64 1 0 2 6 5 8 0 1 2 0 5 ... 1 2 0 2 0 7 5 6 4 8 3
>>> print(bg[0]['foo'].shape)
(10, 10, 5)

We should document the intended behavior for ordering dimensions and test that the shape is consistent. I would have expected the original dimension order to be retained, in contrast to what the batch generator does in most cases. @jhamman, can you provide insight into the original intended behavior?
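
As a starting point for such a test, here is a minimal sketch of a check that reports whether a batch's dimensions follow the order given in input_dims. The helpers input_dim_order and follows_input_dims are hypothetical names, not part of xbatcher. The handling of the "_input" suffix is an assumption based only on the concat_input_dims=True output shown above.

```python
def input_dim_order(batch_dims, input_dims):
    """Return the input dimensions in the order they appear in a batch.

    Hypothetical helper, not part of xbatcher.
    batch_dims: tuple of dimension names, e.g. bg[0]["foo"].dims
    input_dims: the dict passed to BatchGenerator(input_dims=...)
    """
    # Map each possible batch dimension name back to its input dimension.
    # Assumption: with concat_input_dims=True the names carry an "_input"
    # suffix, as in the output shown in this issue.
    lookup = {}
    for name in input_dims:
        lookup[name] = name
        lookup[f"{name}_input"] = name
    return tuple(lookup[d] for d in batch_dims if d in lookup)


def follows_input_dims(batch_dims, input_dims):
    """True if the batch orders its input dimensions as input_dims does."""
    # Relies on dicts preserving insertion order (Python 3.7+).
    return input_dim_order(batch_dims, input_dims) == tuple(input_dims)


# Dimension tuples copied from the three outputs in this issue.
print(follows_input_dims(("sample", "x", "y"), {"x": 10, "y": 5}))
print(follows_input_dims(("sample", "x_input", "y_input"), {"x": 5, "y": 10}))
print(follows_input_dims(("time", "y", "x"), {"x": 5, "y": 10}))
```

In a real test, batch_dims would come from bg[0]["foo"].dims for each combination of options. Whichever convention is chosen, the check should give the same answer for every combination.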


Labels: bug (Something isn't working)
