Skip to content

Commit 248be12

Browse files
authored
Merge pull request #3093 from TAlonglong/issue3092
Fix 3D masks with size 1 dimension in MaskingCompositor
2 parents 8064ddf + 3a208c0 commit 248be12

File tree

2 files changed

+75
-0
lines changed

2 files changed

+75
-0
lines changed

satpy/composites/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1908,6 +1908,9 @@ def __call__(self, projectables, *args, **kwargs):
19081908
data_in = projectables[0]
19091909
mask_in = projectables[1]
19101910

1911+
# remove "bands" dimension for single band masks (ex. "L")
1912+
mask_in = mask_in.squeeze(drop=True)
1913+
19111914
alpha_attrs = data_in.attrs.copy()
19121915
data = self._select_data_bands(data_in)
19131916

satpy/tests/test_composites.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1636,6 +1636,16 @@ def conditions_v2(self):
16361636
"value": 2,
16371637
"transparency": 50}]
16381638

1639+
@pytest.fixture
1640+
def conditions_v3(self):
1641+
"""Masking conditions with other numerical values."""
1642+
return [{"method": "equal",
1643+
"value": 0,
1644+
"transparency": 100},
1645+
{"method": "equal",
1646+
"value": 1,
1647+
"transparency": 0}]
1648+
16391649
@pytest.fixture
16401650
def test_data(self):
16411651
"""Test data to use with masking compositors."""
@@ -1654,6 +1664,30 @@ def test_ct_data(self):
16541664
ct_data.attrs["flag_values"] = flag_values
16551665
return ct_data
16561666

1667+
@pytest.fixture
1668+
def value_3d_data(self):
1669+
"""Test 3D data array."""
1670+
value_3d_data = da.array([[[1, 0, 0],
1671+
[0, 1, 0],
1672+
[0, 0, 1]]])
1673+
value_3d_data = xr.DataArray(value_3d_data, dims=["bands", "y", "x"])
1674+
return value_3d_data
1675+
1676+
@pytest.fixture
1677+
def value_3d_data_bands(self):
1678+
"""Test 3D data array."""
1679+
value_3d_data = da.array([[[1, 0, 0],
1680+
[0, 1, 0],
1681+
[0, 0, 1]],
1682+
[[1, 0, 0],
1683+
[0, 1, 0],
1684+
[0, 0, 1]],
1685+
[[1, 0, 0],
1686+
[0, 1, 0],
1687+
[0, 0, 1]]])
1688+
value_3d_data = xr.DataArray(value_3d_data, dims=["bands", "y", "x"])
1689+
return value_3d_data
1690+
16571691
@pytest.fixture
16581692
def test_ct_data_v3(self, test_ct_data):
16591693
"""Set ct data to NaN where it originally is 1."""
@@ -1736,6 +1770,44 @@ def test_call_numerical_transparency_data(
17361770
np.testing.assert_allclose(res.sel(bands=m), reference_data)
17371771
np.testing.assert_allclose(res.sel(bands="A"), reference_alpha)
17381772

1773+
@pytest.mark.parametrize("mode", ["LA", "RGBA"])
1774+
def test_call_numerical_transparency_data_with_3d_mask_data(
1775+
self, test_data, value_3d_data, conditions_v3, mode):
1776+
"""Test call the compositor with numerical transparency data.
1777+
1778+
Use parameterisation to test different image modes.
1779+
"""
1780+
from satpy.composites import MaskingCompositor
1781+
1782+
reference_data_v3 = test_data.where(value_3d_data[0] > 0)
1783+
reference_alpha_v3 = xr.DataArray([[1., 0., 0.],
1784+
[0., 1., 0.],
1785+
[0., 0., 1.]])
1786+
1787+
# Test with numerical transparency data using 3d test mask data which can be squeezed
1788+
comp = MaskingCompositor("name", conditions=conditions_v3,
1789+
mode=mode)
1790+
res = comp([test_data, value_3d_data])
1791+
assert res.mode == mode
1792+
for m in mode.rstrip("A"):
1793+
np.testing.assert_allclose(res.sel(bands=m), reference_data_v3)
1794+
np.testing.assert_allclose(res.sel(bands="A"), reference_alpha_v3)
1795+
1796+
@pytest.mark.parametrize("mode", ["LA", "RGBA"])
1797+
def test_call_numerical_transparency_data_with_3d_mask_data_exception(
1798+
self, test_data, value_3d_data_bands, conditions_v3, mode):
1799+
"""Test call the compositor with numerical transparency data, too many dimensions to squeeze.
1800+
1801+
Use parameterisation to test different image modes.
1802+
"""
1803+
from satpy.composites import MaskingCompositor
1804+
1805+
# Test with numerical transparency data using 3d test mask data which can not be squeezed
1806+
comp = MaskingCompositor("name", conditions=conditions_v3,
1807+
mode=mode)
1808+
with pytest.raises(ValueError, match=".*Received 3 dimension\\(s\\) but expected 2.*"):
1809+
comp([test_data, value_3d_data_bands])
1810+
17391811
def test_call_named_fields(self, conditions_v2, test_data, test_ct_data,
17401812
reference_data, reference_alpha):
17411813
"""Test with named fields."""

0 commit comments

Comments
 (0)