Skip to content

Commit 5289730

Browse files
committed
Update similarity.py plot formatting and fix anderson_ksamp error for grid points without data
1 parent a910895 commit 5289730

File tree

1 file changed

+54
-4
lines changed

1 file changed

+54
-4
lines changed

unseen/similarity.py

Lines changed: 54 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,11 @@
22

33
import argparse
44
from cartopy.crs import PlateCarree
5+
from cartopy.mpl.gridliner import LatitudeFormatter, LongitudeFormatter
56
import matplotlib.pyplot as plt
67
from matplotlib.colors import TwoSlopeNorm
8+
from matplotlib.ticker import AutoMinorLocator
9+
import numpy as np
710
import xarray as xr
811
import xstatstests
912

@@ -33,6 +36,25 @@ def ks_test(obs_ds, fcst_ds):
3336
return ks
3437

3538

39+
def _mask_invalid(da):
40+
"""Get a mask where there is less than one distinct sample."""
41+
42+
def count_unique_1d(x):
43+
return np.unique(x).size
44+
45+
count = xr.apply_ufunc(
46+
count_unique_1d,
47+
da,
48+
input_core_dims=[["sample"]],
49+
output_core_dims=[[]],
50+
vectorize=True,
51+
dask="parallelized",
52+
output_dtypes=["int"],
53+
)
54+
mask = count > 1
55+
return mask
56+
57+
3658
def anderson_darling_test(obs_ds, fcst_ds):
3759
"""Calculate Anderson Darling test statistic and p-value.
3860
@@ -45,9 +67,20 @@ def anderson_darling_test(obs_ds, fcst_ds):
4567
-------
4668
ad : xarray.Dataset
4769
Dataset with Anderson Darling statistic and p-value variables
70+
4871
"""
72+
# Temporarily replace non-unique samples with unique integers
73+
obs_mask = _mask_invalid(obs_ds[list(obs_ds.data_vars)[0]])
74+
fcst_mask = _mask_invalid(fcst_ds[list(fcst_ds.data_vars)[0]])
75+
76+
x = xr.where(obs_mask, obs_ds, np.arange(obs_ds.sample.size))
77+
y = xr.where(fcst_mask, fcst_ds, np.arange(fcst_ds.sample.size))
78+
79+
ad = xstatstests.anderson_ksamp(x, y, dim="sample")
80+
81+
# Mask dummy results
82+
ad = xr.where(obs_mask | fcst_mask, ad, np.nan)
4983

50-
ad = xstatstests.anderson_ksamp(obs_ds, fcst_ds, dim="sample")
5184
ad = ad.rename({"statistic": "ad_statistic"})
5285
ad = ad.rename({"pvalue": "ad_pval"})
5386

@@ -180,21 +213,38 @@ def similarity_spatial_plot(ds, dataset_name=None, outfile=None, alpha=0.05):
180213
for ax, var in zip(axes.flat, ds.data_vars):
181214
if "statistic" in var:
182215
long_name = ds[var].attrs["long_name"].replace("_", " ").title()
183-
kwargs = {}
216+
if ds[var].min() < 0:
217+
kwargs = dict(cmap=plt.cm.coolwarm)
218+
else:
219+
kwargs = dict(cmap=plt.cm.viridis)
184220
elif "pval" in var:
185221
long_name = f"{long_name} p-value"
186222
kwargs = dict(
187223
cmap=plt.cm.seismic,
188224
norm=TwoSlopeNorm(vcenter=alpha, vmin=0, vmax=0.4),
189225
)
226+
kwargs["cmap"].set_bad("gray")
190227

191228
ds[var].plot(
192-
ax=ax, transform=PlateCarree(), cbar_kwargs=dict(label=long_name), **kwargs
229+
ax=ax,
230+
transform=PlateCarree(),
231+
cbar_kwargs=dict(label=long_name),
232+
robust=True,
233+
**kwargs,
193234
)
194235
ax.coastlines()
195236
ax.set_title(long_name)
237+
ax.set_xlabel(None)
238+
ax.set_ylabel(None)
239+
ax.xaxis.set_major_formatter(LongitudeFormatter())
240+
ax.yaxis.set_major_formatter(LatitudeFormatter())
241+
ax.xaxis.set_minor_locator(AutoMinorLocator())
242+
ax.yaxis.set_minor_locator(AutoMinorLocator())
243+
244+
subplotspec = ax.get_subplotspec()
196245
ax.xaxis.set_visible(True)
197-
ax.yaxis.set_visible(True)
246+
if subplotspec.is_first_col():
247+
ax.yaxis.set_visible(True)
198248

199249
if dataset_name:
200250
fig.suptitle(dataset_name, y=1.02)

0 commit comments

Comments
 (0)