Skip to content

Commit b6d745a

Browse files
committed
Update eva.py
Add empirical return level/period functions, add GEV parameter spatial plot, update fit_gev for fixed params, check_gev_relative_fit and add option to drop the maximum value before fitting GEV parameters
1 parent 5289730 commit b6d745a

File tree

2 files changed

+219
-23
lines changed

2 files changed

+219
-23
lines changed

unseen/eva.py

Lines changed: 215 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
"""Extreme value analysis functions."""
22

33
import argparse
4+
from cartopy.crs import PlateCarree
5+
from cartopy.mpl.gridliner import LatitudeFormatter, LongitudeFormatter
46
from lmoments3 import distr
57
import matplotlib.pyplot as plt
68
from matplotlib import colormaps
79
from matplotlib.dates import date2num
810
from matplotlib.ticker import AutoMinorLocator
911
import numpy as np
1012
from scipy.optimize import minimize
11-
from scipy.stats import genextreme, ks_1samp, cramervonmises
13+
from scipy.stats import genextreme, ecdf, ks_1samp, cramervonmises
1214
from scipy.stats.distributions import chi2
1315
import warnings
1416
from xarray import apply_ufunc, DataArray
@@ -64,11 +66,14 @@ def fit_gev(
6466
fitstart="LMM",
6567
loc1=0,
6668
scale1=0,
67-
retry_fit=True,
69+
fshape=None,
70+
floc=None,
71+
fscale=None,
72+
retry_fit=False,
6873
assert_good_fit=False,
6974
pick_best_model=False,
7075
alpha=0.05,
71-
method="Nelder-Mead",
76+
optimizer="Nelder-Mead",
7277
goodness_of_fit_kwargs=dict(test="ks"),
7378
):
7479
"""Estimate stationary or nonstationary GEV distribution parameters.
@@ -100,7 +105,7 @@ def fit_gev(
100105
Mutually exclusive with `stationary` and/or `assert_good_fit`.
101106
alpha : float, default 0.05
102107
Fit test p-value threshold for stationary fit (relative/goodness of fit)
103-
method : {'Nelder-Mead', 'L-BFGS-B', 'TNC', 'SLSQP', 'Powell',
108+
optimizer : {'Nelder-Mead', 'L-BFGS-B', 'TNC', 'SLSQP', 'Powell',
104109
'trust-constr', 'COBYLA'}, default 'Nelder-Mead'
105110
Optimization method for nonstationary fit
106111
goodness_of_fit_kwargs : dict, optional
@@ -145,11 +150,14 @@ def _fit_1d(
145150
fitstart,
146151
loc1,
147152
scale1,
153+
fshape,
154+
floc,
155+
fscale,
148156
retry_fit,
149157
assert_good_fit,
150158
pick_best_model,
151159
alpha,
152-
method,
160+
optimizer,
153161
goodness_of_fit_kwargs,
154162
):
155163
"""Estimate distribution parameters."""
@@ -210,23 +218,38 @@ def _fit_1d(
210218
if not stationary:
211219
# Temporarily reverse shape sign (scipy uses different sign convention)
212220
dparams_ns_i = [-dparams_i[0], dparams_i[1], loc1, dparams_i[2], scale1]
213-
221+
dof = [3, 5]
222+
for fixed in [fshape, floc, fscale]:
223+
if fixed is not None:
224+
dof[0] -= 1
214225
# Optimisation bounds (scale parameter must be non-negative)
226+
215227
bounds = [(None, None)] * 5
216-
bounds[3] = (0, None) # Positive scale parameter
228+
bounds[0] = (fshape, fshape)
229+
bounds[1] = (floc, floc)
230+
231+
if fscale is None:
232+
# Allow positive scale parameter (not fixed)
233+
bounds[3] = (0, None)
234+
else:
235+
bounds[3] = (fscale, fscale)
236+
237+
# Trend parameters
217238
if loc1 is None:
218239
dparams_ns_i[2] = 0
219240
bounds[2] = (0, 0) # Only allow trend in scale
241+
dof[1] -= 1
220242
if scale1 is None:
221243
dparams_ns_i[4] = 0
222244
bounds[4] = (0, 0) # Only allow trend in location
245+
dof[1] -= 1
223246

224247
# Minimise the negative log-likelihood function to get optimal dparams
225248
res = minimize(
226249
_gev_nllf,
227250
dparams_ns_i,
228251
args=(data, covariate),
229-
method=method,
252+
method=optimizer,
230253
bounds=bounds,
231254
)
232255
dparams_ns = np.array([i for i in res.x], dtype="float64")
@@ -236,7 +259,13 @@ def _fit_1d(
236259
# Stationary and nonstationary model relative goodness of fit
237260
if pick_best_model:
238261
dparams = get_best_GEV_model_1d(
239-
data, dparams, dparams_ns, covariate, alpha, test=pick_best_model
262+
data,
263+
dparams,
264+
dparams_ns,
265+
covariate,
266+
alpha,
267+
test=pick_best_model,
268+
dof=dof,
240269
)
241270
else:
242271
dparams = dparams_ns
@@ -554,7 +583,7 @@ def _fit_test_genextreme(data, dparams, **kwargs):
554583
return pvalue
555584

556585

557-
def check_gev_relative_fit(data, L1, L2, test, alpha=0.05):
586+
def check_gev_relative_fit(data, L1, L2, test, alpha=0.05, dof=[3, 5]):
558587
"""Test relative fit of stationary and nonstationary GEV distribution.
559588
560589
Parameters
@@ -579,8 +608,7 @@ def check_gev_relative_fit(data, L1, L2, test, alpha=0.05):
579608
Hydrology, 547, 557-574. https://doi.org/10.1016/j.jhydrol.2017.02.005
580609
"""
581610

582-
dof = [3, 5] # Degrees of freedom of each model
583-
611+
# Degrees of freedom of each model
584612
if test.casefold() == "lrt":
585613
# Likelihood ratio test statistic
586614
LR = -2 * (L2 - L1)
@@ -600,7 +628,9 @@ def check_gev_relative_fit(data, L1, L2, test, alpha=0.05):
600628
return result
601629

602630

603-
def get_best_GEV_model_1d(data, dparams, dparams_ns, covariate, alpha, test):
631+
def get_best_GEV_model_1d(
632+
data, dparams, dparams_ns, covariate, alpha, test, dof=[3, 5]
633+
):
604634
"""Get the best GEV model based on a relative fit test."""
605635
# Calculate the stationary GEV parameters
606636
shape, loc, scale = dparams
@@ -609,7 +639,7 @@ def get_best_GEV_model_1d(data, dparams, dparams_ns, covariate, alpha, test):
609639
L1 = _gev_nllf([-shape, loc, scale], data)
610640
L2 = _gev_nllf([-dparams_ns[0], *dparams_ns[1:]], data, covariate)
611641

612-
result = check_gev_relative_fit(data, L1, L2, test=test, alpha=alpha)
642+
result = check_gev_relative_fit(data, L1, L2, test=test, alpha=alpha, dof=dof)
613643
if not result:
614644
# Return the stationary parameters with no trend
615645
dparams = np.array([shape, loc, 0, scale, 0], dtype="float64")
@@ -742,14 +772,87 @@ def get_return_level(return_period, dparams=None, covariate=None, **kwargs):
742772
return return_level
743773

744774

775+
def get_empirical_return_period(da, event, core_dim="time"):
    """Estimate how often an event is exceeded, using the empirical distribution.

    Parameters
    ----------
    da : xarray.DataArray
        Input data
    event : float
        Return level of the event (e.g., 100 mm of rainfall)
    core_dim : str, default "time"
        The dimension of `da` that is reduced when estimating the return period

    Returns
    -------
    ri : xarray.DataArray
        The event recurrence interval (e.g., 100 year return period),
        i.e. the reciprocal of the empirical survival function at `event`
    """

    def _return_period_1d(sample, level):
        """Reciprocal of the empirical exceedance probability (1D sample)."""
        survival = ecdf(sample).sf
        exceedance_probability = survival.evaluate(level)
        return 1 / exceedance_probability

    # Apply the 1D estimator along core_dim for every remaining grid point
    return apply_ufunc(
        _return_period_1d,
        da,
        event,
        input_core_dims=[[core_dim], []],
        output_core_dims=[[]],
        vectorize=True,
        dask="parallelized",
        output_dtypes=["float64"],
    )
809+
810+
811+
def get_empirical_return_level(da, return_period, core_dim="time"):
    """Calculate the empirical return level for a given return period.

    Parameters
    ----------
    da : xarray.DataArray
        Input data
    return_period : float
        Return period (e.g., 100 year return period)
    core_dim : str, default "time"
        The core dimension in which to estimate the return level

    Returns
    -------
    return_level : xarray.DataArray
        The event return level (e.g., 100 mm of rainfall)

    Notes
    -----
    The return level is the value whose empirical non-exceedance probability
    is ``1 - 1 / return_period``, linearly interpolated between the sample
    quantiles. Probabilities outside the range of the empirical CDF are
    clamped to the smallest/largest value in the sample (`numpy.interp`
    behaviour).
    """

    def _empirical_return_level(da, period):
        """Empirical return level of an event (1D)."""
        # Interpolate on the CDF, not the survival function: np.interp needs
        # increasing x-coordinates and the survival probabilities decrease
        # with increasing quantile (which made every lookup return the
        # sample maximum).
        cdf = ecdf(da).cdf
        probability = 1 - (1 / period)
        return np.interp(probability, cdf.probabilities, cdf.quantiles)

    return_level = apply_ufunc(
        _empirical_return_level,
        da,
        return_period,
        input_core_dims=[[core_dim], []],
        output_core_dims=[[]],
        vectorize=True,
        dask="parallelized",
        output_dtypes=["float64"],
    )
    return return_level
846+
847+
745848
def aep_to_ari(aep):
    """Convert an annual exceedance probability (AEP, %) to an ARI (years).

    The AEP is given as a percentage and must be below 100; scalars and
    numpy arrays are both accepted.

    Details: http://www.bom.gov.au/water/designRainfalls/ifd-arr87/glossary.shtml
    Source: https://github.com/climate-innovation-hub/frequency-analysis/blob/master/eva.py
    """
    assert np.all(aep < 100), "AEP to be expressed as a percentage (must be < 100)"

    # Percentage -> fraction, then ARI = 1 / -ln(1 - AEP)
    exceedance_fraction = aep / 100
    rate = -np.log(1 - exceedance_fraction)
    return 1 / rate
@@ -805,7 +908,7 @@ def gev_confidence_interval(
805908
ci_bounds : xarray.DataArray
806909
Confidence intervals with lower and upper bounds along dim 'quantile'
807910
"""
808-
# todo: add max_shape_ratio
911+
# todo: add max_shape_ratio?
809912
# Replace core dim with the one from the fit_kwargs if it exists
810913
core_dim = fit_kwargs.pop("core_dim", core_dim)
811914

@@ -814,8 +917,8 @@ def gev_confidence_interval(
814917
dparams = fit_gev(data, core_dim=core_dim, **fit_kwargs)
815918
shape, loc, scale = unpack_gev_params(dparams)
816919

817-
# Generate random indices for resampling
818920
if bootstrap_method == "parametric":
921+
# Generate bootstrapped data using the GEV distribution
819922
boot_data = apply_ufunc(
820923
genextreme.rvs,
821924
shape,
@@ -830,9 +933,12 @@ def gev_confidence_interval(
830933
boot_data = boot_data.transpose("k", core_dim, ...)
831934

832935
elif bootstrap_method == "non-parametric":
833-
# todo: replace with rng.choice
834-
resample_indices = rng.integers(
835-
0, data[core_dim].size, (n_resamples, data[core_dim].size)
936+
# Resample data with replacement
937+
resample_indices = np.array(
938+
[
939+
rng.choice(data[core_dim].size, size=data[core_dim].size, replace=True)
940+
for i in range(n_resamples)
941+
]
836942
)
837943
indexer = DataArray(resample_indices, dims=("k", core_dim))
838944
boot_data = data.isel({core_dim: indexer})
@@ -1304,6 +1410,80 @@ def plot_stacked_histogram(
13041410
return ax, bins
13051411

13061412

1413+
def spatial_plot_gev_parameters(
    dparams,
    dataset_name=None,
    outfile=None,
):
    """Plot spatial maps of GEV parameters (shape, loc, scale).

    One map is drawn per parameter: a single row of three panels for
    stationary parameters, or two rows (with the unused sixth panel hidden)
    for nonstationary parameters.

    Parameters
    ----------
    dparams : xarray.DataArray
        Stationary or nonstationary GEV parameters, with the parameters
        along the "dparams" dimension. The input object is not modified.
    dataset_name : str, optional
        Name of the dataset (used in the figure title)
    outfile : str, optional
        Output file name. If not given, the figure is shown instead.
    """
    # Relabel the first parameter as "shape". assign_coords returns a new
    # object, so the caller's DataArray keeps its original coordinates.
    params = ["shape", *dparams.dparams.values[1:]]
    dparams = dparams.assign_coords(dparams=params)
    params = dparams.dparams.values
    stationary = params.size == 3

    # Adjust subplots for stationary or nonstationary parameters
    nrows = 1 if stationary else 2
    figsize = (14, 6) if stationary else (12, 6)

    fig, axes = plt.subplots(
        nrows, 3, figsize=figsize, subplot_kw={"projection": PlateCarree()}
    )
    axes = axes.flat

    # Plot each parameter
    for i, ax in enumerate(axes[: params.size]):
        name = params[i]
        dparams.sel(dparams=name).plot(
            ax=ax,
            transform=PlateCarree(),
            robust=True,
            cbar_kwargs=dict(label=name, fraction=0.038, pad=0.04),
        )
        # Add coastlines and lat/lon labels
        ax.coastlines()
        ax.set_title(f"{name}")
        ax.set_xlabel(None)
        ax.set_ylabel(None)
        ax.xaxis.set_major_formatter(LongitudeFormatter())
        ax.yaxis.set_major_formatter(LatitudeFormatter())
        ax.xaxis.set_minor_locator(AutoMinorLocator())
        ax.yaxis.set_minor_locator(AutoMinorLocator())

        # Fix xarray facet plot with cartopy bug (v2024.6.0)
        subplotspec = ax.get_subplotspec()
        ax.xaxis.set_visible(True)
        if subplotspec.is_first_col():
            ax.yaxis.set_visible(True)

    if dataset_name:
        fig.suptitle(f"{dataset_name} GEV parameters", y=0.8 if stationary else 0.99)

    if not stationary:
        # Hide the empty subplot
        axes[-1].set_visible(False)

    plt.tight_layout()
    if outfile:
        plt.savefig(outfile, bbox_inches="tight", dpi=200)
    else:
        plt.show()
1485+
1486+
13071487
def _parse_command_line():
13081488
"""Parse the command line for input arguments"""
13091489

@@ -1348,6 +1528,7 @@ def _parse_command_line():
13481528
),
13491529
help="Initial guess method (or estimate) of the GEV parameters",
13501530
)
1531+
13511532
parser.add_argument(
13521533
"--retry_fit",
13531534
action="store_true",
@@ -1389,6 +1570,12 @@ def _parse_command_line():
13891570
action=general_utils.store_dict,
13901571
help="Keyword arguments for opening min_lead file",
13911572
)
1573+
parser.add_argument(
1574+
"--drop_max",
1575+
action="store_true",
1576+
default=False,
1577+
help="Drop the maximum value before fitting",
1578+
)
13921579
parser.add_argument(
13931580
"--ensemble_dim",
13941581
type=str,
@@ -1454,8 +1641,16 @@ def _main():
14541641
# Stack dimensions along new "sample" dimension
14551642
if all([dim in ds[args.var].dims for dim in args.stack_dims]):
14561643
ds = ds.stack(**{"sample": args.stack_dims}, create_index=False)
1644+
ds = ds.chunk(dict(sample=-1)) # fixes CAFE large chunk error
14571645
args.core_dim = "sample"
14581646

1647+
# Drop the maximum value before fitting
1648+
if args.drop_max:
1649+
# Drop the maximum value along each non-core dimension
1650+
ds[args.var] = ds[args.var].where(
1651+
ds[args.var].load() != ds[args.var].max(dim=args.core_dim).load()
1652+
)
1653+
14591654
if args.nonstationary:
14601655
covariate = _format_covariate(ds[args.var], ds[args.covariate], args.core_dim)
14611656
else:

0 commit comments

Comments
 (0)