2
2
3
3
import argparse
4
4
from cartopy .crs import PlateCarree
5
+ from cartopy .mpl .gridliner import LatitudeFormatter , LongitudeFormatter
5
6
import matplotlib .pyplot as plt
6
7
from matplotlib .colors import TwoSlopeNorm
8
+ from matplotlib .ticker import AutoMinorLocator
9
+ import numpy as np
7
10
import xarray as xr
8
11
import xstatstests
9
12
@@ -33,6 +36,25 @@ def ks_test(obs_ds, fcst_ds):
33
36
return ks
34
37
35
38
39
+ def _mask_invalid (da ):
40
+ """Get a mask where there is less than one distinct sample."""
41
+
42
+ def count_unique_1d (x ):
43
+ return np .unique (x ).size
44
+
45
+ count = xr .apply_ufunc (
46
+ count_unique_1d ,
47
+ da ,
48
+ input_core_dims = [["sample" ]],
49
+ output_core_dims = [[]],
50
+ vectorize = True ,
51
+ dask = "parallelized" ,
52
+ output_dtypes = ["int" ],
53
+ )
54
+ mask = count > 1
55
+ return mask
56
+
57
+
36
58
def anderson_darling_test (obs_ds , fcst_ds ):
37
59
"""Calculate Anderson Darling test statistic and p-value.
38
60
@@ -45,9 +67,20 @@ def anderson_darling_test(obs_ds, fcst_ds):
45
67
-------
46
68
ad : xarray.Dataset
47
69
Dataset with Anderson Darling statistic and p-value variables
70
+
48
71
"""
72
+ # Temporarily replace non-unique samples with unique integers
73
+ obs_mask = _mask_invalid (obs_ds [list (obs_ds .data_vars )[0 ]])
74
+ fcst_mask = _mask_invalid (fcst_ds [list (fcst_ds .data_vars )[0 ]])
75
+
76
+ x = xr .where (obs_mask , obs_ds , np .arange (obs_ds .sample .size ))
77
+ y = xr .where (fcst_mask , fcst_ds , np .arange (fcst_ds .sample .size ))
78
+
79
+ ad = xstatstests .anderson_ksamp (x , y , dim = "sample" )
80
+
81
+ # Mask dummy results
82
+ ad = xr .where (obs_mask | fcst_mask , ad , np .nan )
49
83
50
- ad = xstatstests .anderson_ksamp (obs_ds , fcst_ds , dim = "sample" )
51
84
ad = ad .rename ({"statistic" : "ad_statistic" })
52
85
ad = ad .rename ({"pvalue" : "ad_pval" })
53
86
@@ -180,21 +213,38 @@ def similarity_spatial_plot(ds, dataset_name=None, outfile=None, alpha=0.05):
180
213
for ax , var in zip (axes .flat , ds .data_vars ):
181
214
if "statistic" in var :
182
215
long_name = ds [var ].attrs ["long_name" ].replace ("_" , " " ).title ()
183
- kwargs = {}
216
+ if ds [var ].min () < 0 :
217
+ kwargs = dict (cmap = plt .cm .coolwarm )
218
+ else :
219
+ kwargs = dict (cmap = plt .cm .viridis )
184
220
elif "pval" in var :
185
221
long_name = f"{ long_name } p-value"
186
222
kwargs = dict (
187
223
cmap = plt .cm .seismic ,
188
224
norm = TwoSlopeNorm (vcenter = alpha , vmin = 0 , vmax = 0.4 ),
189
225
)
226
+ kwargs ["cmap" ].set_bad ("gray" )
190
227
191
228
ds [var ].plot (
192
- ax = ax , transform = PlateCarree (), cbar_kwargs = dict (label = long_name ), ** kwargs
229
+ ax = ax ,
230
+ transform = PlateCarree (),
231
+ cbar_kwargs = dict (label = long_name ),
232
+ robust = True ,
233
+ ** kwargs ,
193
234
)
194
235
ax .coastlines ()
195
236
ax .set_title (long_name )
237
+ ax .set_xlabel (None )
238
+ ax .set_ylabel (None )
239
+ ax .xaxis .set_major_formatter (LongitudeFormatter ())
240
+ ax .yaxis .set_major_formatter (LatitudeFormatter ())
241
+ ax .xaxis .set_minor_locator (AutoMinorLocator ())
242
+ ax .yaxis .set_minor_locator (AutoMinorLocator ())
243
+
244
+ subplotspec = ax .get_subplotspec ()
196
245
ax .xaxis .set_visible (True )
197
- ax .yaxis .set_visible (True )
246
+ if subplotspec .is_first_col ():
247
+ ax .yaxis .set_visible (True )
198
248
199
249
if dataset_name :
200
250
fig .suptitle (dataset_name , y = 1.02 )
0 commit comments