diff --git a/src/CSET/operators/cell_statistics.py b/src/CSET/operators/cell_statistics.py
new file mode 100644
index 000000000..efaab0d90
--- /dev/null
+++ b/src/CSET/operators/cell_statistics.py
@@ -0,0 +1,408 @@
+# Copyright 2022 Met Office and contributors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Operators to perform cell statistics diagnostics."""
+
+import copy
+import itertools
+import pickle
+import warnings
+
+import dask
+import dask.bag as db
+import iris
+import iris.analysis
+import iris.coord_categorisation as coord_cat
+import numpy as np
+
+import CSET.operators.diagnostics.config as config
+import CSET.operators.diagnostics.utils as utils
+from CSET.operators.diagnostics.constants import COLOURS
+from CSET.operators.diagnostics.diag_utils import compute_cell_stats
+
+warnings.filterwarnings("ignore")
+iris.FUTURE.datum_support = True
+
+
+def cell_statistics(cubelist, input_params_dict):
+    '''
+    Produces histogram plot objects of cell statistics.
+
+    For each cell statistic plot defined in :mod:`config.PLOTS`, the \
+    necessary model (and possibly observational) data is read in from the \
+    netCDF file regridded_cubes.nc in each model data directory. It is \
+    assumed :func:`regrid.regrid` will have been run previously to create \
+    this file. The netCDF file will contain the model fields (and any \
+    corresponding gridded observations) required for each cell statistic \
+    plot, on one or more common spatial grids. All of the data is read in, \
+    a threshold is applied if required, then cells are identified and \
+    histograms of the desired cell attribute (e.g. cell size, cell mean \
+    value) are constructed.
+
+    Given a particular spatial grid, a set of times is identified for which \
+    there is matching data from all of the various model and observational \
+    datasets available. The times will be either lead times, validity times \
+    or hour of day, depending on how the parameter time_grouping is set.
+
+    At each time, all of the cell statistic histograms valid at that time \
+    for a particular model or observational dataset are summed to derive a \
+    total histogram. This is repeated for all models and observational \
+    datasets. A plot of the total histograms at that time is made and saved. \
+    Similar cell statistic histogram plots are produced at all the other \
+    times.
+
+    In addition, an overall plot is produced which displays cell statistic \
+    histograms for each model and observational dataset, constructed by \
+    summing histograms across all times.
+
+    The process is then repeated for all spatial grids and threshold values.
+
+    Arguments
+    ---------
+
+    cubelist
+        Cubelist containing the regridded cubes of all models and cycles.
+    input_params_dict
+        Input parameters dictionary containing the following keys:
+        thresholds: list containing the range of thresholds
+        time_grouping: list containing the time groupings, i.e. \
+                       forecast_period, time, hour
+        cell_attribute: string, either "effective_radius_in_km" or \
+                        "mean_value"
+        plot_dir: string, output directory in which to save the pickle \
+                  file of plot objects
+    '''
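+    # A minimal sketch of the expected input (illustrative values, not
+    # defaults):
+    #
+    #     input_params_dict = {
+    #         "thresholds": [0.5, 1.0],
+    #         "time_grouping": ["forecast_period"],
+    #         "cell_attribute": "effective_radius_in_km",
+    #         "plot_dir": "/path/to/plots",
+    #     }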
+    thresholds = input_params_dict["thresholds"]
+    time_grouping = input_params_dict["time_grouping"]
+    plot_dir = input_params_dict["plot_dir"]
+    cell_attribute = input_params_dict["cell_attribute"]
+
+    if cubelist is None:
+        raise ValueError("Cubelist is empty")
+
+    if cell_attribute in ("effective_radius_in_km", "mean_value"):
+        # Logarithmically spaced bins; both attributes currently use the
+        # same binning
+        bin_edges = 10 ** (np.arange(0.0, 3.12, 0.12))
+    else:
+        raise ValueError(
+            "Unknown cell attribute: {}".format(cell_attribute))
+
+    bin_edges = np.insert(bin_edges, 0, 0)
+
+    # Set up the y-axis label
+    if config.y_axis == "frequency":
+        y_label = "Frequency"
+    elif config.y_axis == "relative_frequency":
+        y_label = "Relative frequency [%]"
+    else:
+        y_label = None
+
+    if time_grouping is None:
+        time_grouping = ["forecast_period", "time", "hour"]
+
+    colours = {}
+
+    for threshold in thresholds:
+        print("Threshold: {}".format(threshold))
+
+        lcube = compute_cell_stats(cubelist, threshold, bin_edges,
+                                   cell_attribute)
+
+        all_cubes = iris.cube.CubeList(lcube)
+
+        # Work out what models/observations we have and assign colours
+        # to them
+        labels = list(set([cube.attributes["data_source"]
+                           for cube in all_cubes]))
+        for icol, label in enumerate(labels):
+            if label not in colours:
+                colours[label] = COLOURS["brewer_paired"][icol]
+
+        # Get a list of all grid details...
+        grids = [(cube.attributes["grid_spacing"],
+                  cube.attributes["observations"]) for cube in all_cubes]
+        # ...and extract the unique ones
+        grids = list(set(grids))
+
+        # Loop over grids
+        for (grid_spacing, observations) in grids:
+            print("=> Grid: {0:s}, {1:s}".format(observations, grid_spacing))
+            if observations == "None":
+                grid_label = ""
+            else:
+                grid_label = "_{0:s}".format(observations)
+
+            # Get the cubes on this grid
+            constraint = (
+                iris.AttributeConstraint(grid_spacing=grid_spacing)
+                & iris.AttributeConstraint(observations=observations))
+            cubes = all_cubes.extract(constraint)
+
+            # Loop over different choices for how data is grouped by time
+            for group in time_grouping:
+                print("   Time grouping: {0:s}...".format(group))
+                # Take a copy of the cubes on this grid
+                cubes_group = copy.deepcopy(cubes)
+
+                cubes_group = utils.extract_overlapping(
+                    cubes_group, "forecast_period")
+
+                # Preparation of data for averaging and plotting
+                if group == "time":
+                    # Identify a unique set of times
+                    times = utils.identify_unique_times(cubes_group, group)
+
+                    # Now extract data at these times
+                    time_constraint = iris.Constraint(
+                        coord_values={group: lambda cell: cell.point in
+                                      times.points})
+                    cubes_group = cubes_group.extract(time_constraint)
+
+                    # Remove other time coordinates to allow a cube merge
+                    # later
+                    for cube in cubes_group:
+                        cube.remove_coord("forecast_reference_time")
+                        cube.remove_coord("forecast_period")
+                elif group == "forecast_period":
+                    # Identify a unique set of lead times
+                    times = utils.identify_unique_times(cubes_group, group)
+
+                    # Remove other time coordinates to allow a cube merge
+                    # later
+                    for cube in cubes_group:
+                        cube.remove_coord("forecast_reference_time")
+                        cube.remove_coord("time")
+                elif group == "hour":
+                    # Categorise the time coordinate of each cube into
+                    # hours
+                    for cube in cubes_group:
+                        coord_cat.add_categorised_coord(
+                            cube,
+                            "hour",
+                            cube.coord("time"),
+                            utils.hour_from_time,
+                            units="hour")
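+                    # (The new "hour" coordinate holds the validity time as
+                    # hours since midnight, e.g. 6.5 for 06:30Z, so data
+                    # from different days are pooled by time of day.)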
+
+                    # Identify a unique set of times of day
+                    times = utils.identify_unique_times(cubes_group, group)
+
+                    # Now extract data at these times
+                    time_constraint = iris.Constraint(
+                        coord_values={group: lambda cell: cell.point
+                                      in times.points})
+                    cubes_group = cubes_group.extract(time_constraint)
+
+                    # Remove other time coordinates to allow a cube merge
+                    # later
+                    for cube in cubes_group:
+                        cube.remove_coord("forecast_reference_time")
+                        cube.remove_coord("time")
+                        cube.remove_coord("forecast_period")
+
+                # Remove any duplicate cubes to allow a successful merge.
+                # Note this is typically because we have the same set of
+                # observations associated with more than one model
+                cubes_group = utils.remove_duplicates(cubes_group)
+
+                # Sum cell statistic histograms at each time in parallel
+                input_params = [(cubes_group, time, iris.analysis.SUM, None)
+                                for time in times]
+
+                n_proc_to_use = min(config.n_proc, len(input_params))
+                if n_proc_to_use == 1:
+                    result_list = [utils.aggregate_at_time(input_param)
+                                   for input_param in input_params]
+                else:
+                    with dask.config.set(num_workers=n_proc_to_use):
+                        result_list = db.from_sequence(input_params).map(
+                            utils.aggregate_at_time).compute()
+
+                # Gather cubes from each process into a cubelist
+                cubes_group = iris.cube.CubeList(
+                    itertools.chain.from_iterable(result_list))
+                # Merge
+                cubes_group = cubes_group.merge()
+                # If the number of cases at each time is the same, the
+                # above merge results in a scalar coordinate representing
+                # the number of cases. Replace this scalar coordinate with
+                # an auxiliary coordinate that has the same length as the
+                # time coordinate
+                cubes_group = utils.repeat_scalar_coord_along_dim_coord(
+                    cubes_group, "num_cases", group)
+
+                all_plots = []
+                for time in times:
+
+                    # Extract the histogram at this time
+                    time_constraint = iris.Constraint(
+                        coord_values={group: lambda cell: cell.point
+                                      in time.points})
+                    cubes_at_time = cubes_group.extract(time_constraint)
+
+                    # Set up the figure title, file name and annotation
+                    img_prop = utils.set_title_and_filename(
+                        group, time, config.short_name, threshold,
+                        grid_label)
+                    annotation = utils.set_annotation(cubes_at_time)
+
+                    # New plot for models/observations on this grid at
+                    # this time
+                    plot = {
+                        "title": img_prop["title"],
+                        "filename": img_prop["filename"],
+                        "annotations": annotation,
+                        "field": config.long_name,
+                        "bin_edges": list(bin_edges),
+                        "x_label": "Cell effective radius [km]",
+                        "x_limits": None,
+                        "y_label": y_label,
+                        "y_limits": None,
+                        "cell_attribute": cell_attribute,
+                        "long_name": "{0:s} histogram".format(
+                            config.long_name.lower()),
+                        "short_name": "{0:s}_hist".format(
+                            config.short_name.lower()),
+                        "plotdir": plot_dir,
+                        "linestyle": "-",
+                        "linewidth": 2,
+                        "markers": "o",
+                        "y_lower": None,
+                        "y_upper": None,
+                        "x_log": True,
+                        "y_log": True,
+                    }
+
+                    # Loop over models/observations available on this
+                    # grid at this time
+                    lines = []
+                    for cube in cubes_at_time:
+                        line = {}
+                        # Normalise the histogram
+                        if config.y_axis == "relative_frequency":
+                            cube.data = ((100.0 * cube.data)
+                                         / np.sum(cube.data,
+                                                  dtype=np.float64))
+                        # Add the histogram to the plot
+                        line["x"] = list(cube.coord(cell_attribute).points)
+                        line["y"] = list(cube.data)
+                        line["label"] = cube.attributes["data_source"]
+                        line["colour"] = colours[
+                            cube.attributes["data_source"]]
+                        lines.append(line)
+                    plot["lines"] = lines
+
+                    all_plots.append(plot)
+                # Set the same axis limits on all plots
+                all_plots = utils.set_general_axis_limits(all_plots)
+                # New plot for models/observations on this grid, all
+                # times combined
+                all_times_agg_plot = {
+                    "title": img_prop["title"],
+                    "annotations": annotation,
+                    "field": config.long_name,
+                    "bin_edges": list(bin_edges),
+                    "x_label": "Cell effective radius [km]",
+                    "x_limits": None,
+                    "y_label": y_label,
+                    "y_limits": None,
+                    "cell_attribute": cell_attribute,
+                    "long_name": "{0:s} histogram".format(
+                        config.long_name.lower()),
+                    "short_name": "{0:s}_hist".format(
+                        config.short_name.lower()),
+                    "plotdir": plot_dir,
+                    "linestyle": "-",
+                    "linewidth": 2,
+                    "markers": "o",
+                    "y_lower": None,
+                    "y_upper": None,
+                    "x_log": True,
+                    "y_log": True,
+                }
+                filename = ("{0:s}{1:s}_thresh_{2:.1f}_{3:s}"
+                            .format(config.short_name, grid_label,
+                                    threshold, group))
+                all_times_agg_plot["filename"] = filename
+
+                lines = []
+                # Loop over models/observations available on this grid
+                for cube in cubes_group:
+                    line = {}
+                    # Sum all histograms
+                    cube = cube.collapsed(group, iris.analysis.SUM)
+
+                    # Normalise the histogram
+                    if config.y_axis == "relative_frequency":
+                        cube.data = ((100.0 * cube.data)
+                                     / np.sum(cube.data, dtype=np.float64))
+
+                    line["x"] = list(cube.coord(cell_attribute).points)
+                    line["y"] = list(cube.data)
+                    line["label"] = cube.attributes["data_source"]
+                    line["colour"] = colours[cube.attributes["data_source"]]
+                    lines.append(line)
+                all_times_agg_plot["lines"] = lines
+                all_times_agg_plot["y_limits"] = None
+                all_times_agg_plot["x_limits"] = None
+                # Set the axis limits
+                all_times_agg_plot = utils.set_axis_limits(
+                    all_times_agg_plot)
+
+                all_plots.append(all_times_agg_plot)
+
+                filename = "data/{}_{}_{}_plots.pkl".format(
+                    threshold, grid_spacing, group)
+                with open(filename, "wb") as pickle_file:
+                    pickle.dump(all_plots, pickle_file)
+
+    print("Cell statistics done")
+    return 0
+
+
+def plot_cell_statistics(files, output_path):
+    '''
+    Produces histogram plots of cell statistics.
+
+    Produces a plot for each plot object written out by the \
+    cell_statistics operator.
+
+    Arguments
+    ---------
+
+    files
+        Paths of the pickle files (from cell_statistics) to plot.
+    output_path
+        Plot directory where all plots will be saved as PNG files.
+    '''
+    for file in files:
+        # Load the pickled list of plot dictionaries
+        with open(file, "rb") as pickle_file:
+            plots = pickle.load(pickle_file)
+        input_params = [(plot, output_path) for plot in plots]
+        n_proc_to_use = min(config.n_proc, len(input_params))
+        if n_proc_to_use == 1:
+            for input_param in input_params:
+                utils.plot_and_save(input_param)
+        else:
+            with dask.config.set(num_workers=n_proc_to_use):
+                db.from_sequence(input_params).map(
+                    utils.plot_and_save).compute()
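+
+# A minimal end-to-end sketch (the file paths and glob pattern here are
+# hypothetical):
+#
+#     import glob
+#     cubes = iris.load("regridded_cubes.nc")
+#     params = {"thresholds": [0.5],
+#               "time_grouping": ["forecast_period"],
+#               "cell_attribute": "effective_radius_in_km",
+#               "plot_dir": "plots"}
+#     cell_statistics(cubes, params)
+#     plot_cell_statistics(glob.glob("data/*_plots.pkl"), "plots")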
diff --git a/src/CSET/operators/diagnostics/__init__.py b/src/CSET/operators/diagnostics/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/CSET/operators/diagnostics/config.py b/src/CSET/operators/diagnostics/config.py
new file mode 100644
index 000000000..c62522b8e
--- /dev/null
+++ b/src/CSET/operators/diagnostics/config.py
@@ -0,0 +1,27 @@
+# Copyright 2022 Met Office and contributors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Cell statistics config
+
+long_name = "1-hourly mean precipitation rate"
+short_name = "1hr_mean_precip"
+thresholds = [0.5]
+n_proc = 1
+observations = ["GPM", "UK_radar_2km", "Darwin_radar_rain_2.5km"]
+x_log = True
+y_log = True
+time_grouping = ["forecast_period"]
+y_axis = "frequency"
\ No newline at end of file
diff --git a/src/CSET/operators/diagnostics/constants.py b/src/CSET/operators/diagnostics/constants.py
new file mode 100644
index 000000000..9205973e6
--- /dev/null
+++ b/src/CSET/operators/diagnostics/constants.py
@@ -0,0 +1,35 @@
+# Copyright 2022 Met Office and contributors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# The constants below are used in the cell statistics
+CELL_CONNECTIVITY = 2
+M_IN_KM = 1000
+HOUR_IN_SECONDS = 3600
+MAX_TICK_OVERRIDE = 100000
+COLOURS = {
+    "brewer_paired": ['#a6cee3', '#1f78b4', '#b2df8a', '#33a02c',
+                      '#fb9a99', '#e31a1c', '#fdbf6f', '#ff7f00',
+                      '#cab2d6', '#6a3d9a', '#ffff99', '#b15928'],
+    "brewer_set3": ['#8dd3c7', '#ffffb3', '#bebada', '#fb8072',
+                    '#80b1d3', '#fdb462', '#b3de69', '#fccde5',
+                    '#d9d9d9', '#bc80bd', '#ccebc5', '#ffed6f'],
+    # List of distinctive colours from the ADAQ toolbox.
+    # See http://rawgit.com/pelson/7248780/raw/
+    # 8e571ff02a02aeaacc021edfa7d899b5b0118ea8/colors.html for the full
+    # list of available names
+    "adaq": ['k', 'r', 'b', 'g', 'orange', 'c', 'm',
+             'lawngreen', 'slategrey', 'y', 'limegreen', 'purple',
+             'lightgrey', 'indigo', 'darkblue', 'plum',
+             'teal', 'violet', 'saddlebrown', 'lightpink'],
+}
diff --git a/src/CSET/operators/diagnostics/diag_utils.py b/src/CSET/operators/diagnostics/diag_utils.py
new file mode 100644
index 000000000..1c54bad40
--- /dev/null
+++ b/src/CSET/operators/diagnostics/diag_utils.py
@@ -0,0 +1,581 @@
+# Copyright 2022 Met Office and contributors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
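+
+"""Utilities for identifying cells and building cell-attribute histograms."""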
+
+import iris
+import iris.analysis
+import iris.analysis.cartography
+import numpy as np
+import scipy.ndimage as ndimage
+
+from CSET.operators.diagnostics.constants import CELL_CONNECTIVITY, M_IN_KM
+from CSET.operators.diagnostics.utils import (get_non_spatial_coords,
+                                              get_spatial_coords,
+                                              guess_bounds)
+
+
+def _neighbours(L, P, NPIXELS, NLINES, method=1):
+    """
+    Returns the neighbours of a pixel, either just neighbours to the side \
+    (method 0) or including diagonal neighbours (method 1).
+    """
+    # Returns neighbours of pixel (L, P) on lines L and L-1 only
+    if method == 0:
+        # Only connected by sides (4 neighbours)
+        if L == 0:
+            if P == 0:
+                A = (np.array([0]), np.array([0]))
+            else:
+                A = (np.array([0, 0]), np.array([P-1, P]))
+        elif P == 0:
+            A = (np.array([L-1, L]), np.array([0, 0]))
+        else:
+            A = (np.array([L-1, L, L]), np.array([P, P-1, P]))
+    elif method == 1:
+        # Connected through diagonals (8 neighbours)
+        if L == 0:
+            if P == 0:
+                A = (np.array([0]), np.array([0]))
+            else:
+                A = (np.array([0, 0]), np.array([P-1, P]))
+        elif P == 0:
+            A = (np.array([L-1, L-1, L]), np.array([0, P+1, 0]))
+        elif P == NPIXELS-1:
+            A = (np.array([L-1, L-1, L, L]), np.array([P-1, P, P-1, P]))
+        else:
+            A = (np.array([L-1, L-1, L-1, L, L]), np.array([P-1, P, P+1,
+                                                            P-1, P]))
+    else:
+        raise ValueError("Method must be 0 or 1")
+
+    return A
+
+
+def _neighbours_flip(L, P, NPIXELS, NLINES, method=1):
+    """
+    Returns the neighbours of a pixel, either just neighbours to the side \
+    (method 0) or including diagonal neighbours (method 1).
+    """
+    # Returns neighbours of pixel (L, P) on lines L and L+1 only
+    if method == 0:
+        # Only connected by sides (4 neighbours)
+        if L == NLINES-1:
+            if P == NPIXELS-1:
+                A = (np.array([L]), np.array([P]))
+            else:
+                A = (np.array([L, L]), np.array([P, P+1]))
+        elif P == NPIXELS-1:
+            A = (np.array([L, L+1]), np.array([P, P]))
+        else:
+            A = (np.array([L, L, L+1]), np.array([P, P+1, P]))
+    elif method == 1:
+        # Connected through diagonals (8 neighbours)
+        if L == NLINES-1:
+            if P == NPIXELS-1:
+                A = (np.array([L]), np.array([P]))
+            else:
+                A = (np.array([L, L]), np.array([P, P+1]))
+        elif P == NPIXELS-1:
+            A = (np.array([L, L+1, L+1]), np.array([P, P-1, P]))
+        elif P == 0:
+            A = (np.array([L, L, L+1, L+1]), np.array([P, P+1, P, P+1]))
+        else:
+            A = (np.array([L, L, L+1, L+1, L+1]), np.array([P, P+1,
+                                                            P-1, P, P+1]))
+    else:
+        raise ValueError("Method must be 0 or 1")
+
+    return A
+
+
+def _connected_object_labelling_reading(bt, minarea, threshold, method=1,
+                                        block_radius=3,
+                                        missing_data_value=np.nan):
+    """
+    Takes an image (e.g. SEVIRI) and makes a binary image out of \
+    brightness temperatures using threshold. The binary image is then used \
+    as input for labelling storms, following the algorithm on pages 37-39 \
+    of Haralick and Shapiro.
+    The result is a label matrix "labelbt" assigning a label to each \
+    connected region (storm). (In the original code a struct "M" also \
+    held the storms "S", the cells - local maxima/minima - "C", and a \
+    matrix "localmax" labelling each local maximum.)
+
+    bt :: 2D array of brightness temperatures
+    minarea :: minimum size threshold (number of pixels) of a feature
+    threshold :: brightness temperature threshold to distinguish \
+                 features from background
+    method :: (0) standard direct neighbours (1) include diagonals
+    block_radius :: defines a square region around a value to check \
+                    whether it is a local maximum/minimum
+    missing_data_value :: value assigned to missing data
+
+    Returns the label matrix.
+    """
+    # These commented lines were replaced with the below to deal with
+    # masked arrays (Chris Short)
+    # binbt = 0*bt
+    # binbt[np.where(bt>threshold)] = 1
+    # binbt[np.where(bt<=threshold)] = 0
+
+    # Apply the supplied threshold to the data to generate a binary array
+    binbt = bt.copy()
+    if np.ma.is_masked(binbt):
+        # If the data is masked, replace any masked values with
+        # (threshold - 1), thus guaranteeing they will be below the
+        # threshold and set to 0 below
+        binbt = np.ma.filled(binbt, fill_value=(threshold - 1))
+    # Set values above and below the threshold to 1 and 0, respectively
+    indices_above = (binbt > threshold)
+    indices_below = (binbt <= threshold)
+    binbt[indices_above] = 1
+    binbt[indices_below] = 0
+
+    labelbt = 0 * binbt
+
+    NLINES = np.size(binbt, 0)
+    NPIXELS = np.size(binbt, 1)
+
+    newlabel = 1
+
+    # First loop
+    # Top-down (left-right) scan through the binary array.
+    # First set of labels and equivalences assigned (eqtable).
+    # Local max assigned if local maximum within block_radius.
+    # Local max used to indicate individual cells within a storm.
+    for L in range(0, NLINES):
+        eqtable = np.zeros((100, 100))
+        newline = []
+        for P in range(0, NPIXELS):
+            if binbt[L, P] == 1:
+                A = _neighbours(L, P, NPIXELS, NLINES, method=method)
+                B = np.where(labelbt[A] > 0)
+                if np.size(B) == 0:
+                    NL = newlabel
+                    newlabel = newlabel + 1
+                else:
+                    NL = np.min(labelbt[A[0][B], A[1][B]])
+                    for X in range(0, np.size(B)):
+                        if labelbt[A[0][B[0][X]]][A[1][B[0][X]]] != NL:
+                            if np.max(eqtable) == 0:
+                                eqtable[0][0] = NL
+                                eqtable[0][1] = labelbt[A[0][B[0][X]], A[1][B[0][X]]]
+                            else:
+                                C = np.where(eqtable == labelbt[A[0][B[0][X]], A[1][B[0][X]]])
+                                if np.size(C) == 0:
+                                    C = np.where(eqtable == NL)
+                                    if np.size(C) == 0:
+                                        newline = np.min(np.where(eqtable[:][0] == 0))
+                                        eqtable[newline][0] = labelbt[A[0][B[0][X]], A[1][B[0][X]]]
+                                        eqtable[newline][1] = NL
+                                    else:
+                                        for Y in range(0, np.size(C, 1)):
+                                            if eqtable[C[0][Y]][C[1][Y]] == NL:
+                                                D = np.where(eqtable[C[0][Y]][:] == 0)
+                                                eqtable[C[0][Y]][np.min(D)] = labelbt[A[0][B[0][X]], A[1][B[0][X]]]
+                                else:
+                                    G = np.where(eqtable == NL)
+                                    if np.size(G) == 0:
+                                        for Y in range(0, np.size(C, 1)):
+                                            if eqtable[C[0][Y]][C[1][Y]] == labelbt[A[0][B[0][X]], A[1][B[0][X]]]:
+                                                D = np.where(eqtable[C[0][Y]][:] == 0)
+                                                eqtable[C[0][Y]][np.min(D)] = NL
+                labelbt[L, P] = NL
+
+        if np.max(eqtable) > 0:
+            eqtable[np.where(eqtable == 0)] = np.max(labelbt) + 1
+            eqlabel = np.zeros(100)
+            for E in range(0, 100):
+                eqlabel[E] = np.min(eqtable[E][:])
+            for P in range(0, NPIXELS):
+                if binbt[L, P] == 1:
+                    B = np.where(eqtable == labelbt[L, P])
+                    if np.size(B) > 0:
+                        labelbt[L, P] = eqlabel[B[0][0]]
+    # Second loop
+    # Down-up (right-left) scan through the label array to find
+    # equivalent labels and set a uniform label per region
+    for L in range(NLINES-1, -1, -1):
+        eqtable = np.zeros((100, 100))
+        newline = []
+        for P in range(NPIXELS-1, -1, -1):
+            if labelbt[L, P] != 0:
+                A = _neighbours_flip(L, P, NPIXELS, NLINES, method=method)
+                B = np.where(labelbt[A] > 0)
+                NL = labelbt[L, P]
+                for X in range(0, np.size(B)):
+                    if labelbt[A[0][B[0][X]]][A[1][B[0][X]]] != NL:
+                        if np.max(eqtable) == 0:
+                            eqtable[0][0] = NL
+                            eqtable[0][1] = labelbt[A[0][B[0][X]], A[1][B[0][X]]]
+                        else:
+                            C = np.where(eqtable == labelbt[A[0][B[0][X]], A[1][B[0][X]]])
+                            if np.size(C) == 0:
+                                C = np.where(eqtable == NL)
+                                if np.size(C) == 0:
+                                    newline = np.min(np.where(eqtable[:][0] == 0))
+                                    eqtable[newline][0] = labelbt[A[0][B[0][X]], A[1][B[0][X]]]
+                                    eqtable[newline][1] = NL
+                                else:
+                                    for Y in range(0, np.size(C, 1)):
+                                        if eqtable[C[0][Y]][C[1][Y]] == NL:
+                                            D = np.where(eqtable[C[0][Y]][:] == 0)
+                                            eqtable[C[0][Y]][np.min(D)] = labelbt[A[0][B[0][X]], A[1][B[0][X]]]
+                            else:
+                                G = np.where(eqtable == NL)
+                                if np.size(G) == 0:
+                                    for Y in range(0, np.size(C, 1)):
+                                        if eqtable[C[0][Y]][C[1][Y]] == labelbt[A[0][B[0][X]], A[1][B[0][X]]]:
+                                            D = np.where(eqtable[C[0][Y]][:] == 0)
+                                            eqtable[C[0][Y]][np.min(D)] = NL
+
+        if np.max(eqtable) > 0:
+            eqtable[np.where(eqtable == 0)] = np.max(labelbt) + 1
+            eqlabel = np.zeros(100)
+            for E in range(0, 100):
+                eqlabel[E] = np.min(eqtable[E][:])
+            for P in range(0, NPIXELS):
+                if binbt[L, P] == 1:
+                    B = np.where(eqtable == labelbt[L, P])
+                    if np.size(B) > 0:
+                        labelbt[L, P] = eqlabel[B[0][0]]
+
+    # Discard regions smaller than minarea and count the rest
+    maxnum = 0
+    for ii in range(1, int(np.max(labelbt)) + 1):
+        ind = np.where(labelbt == ii)
+        if 0 < np.size(ind) / 2 < minarea:
+            labelbt[ind] = 0
+        elif np.size(ind) / 2 >= minarea:
+            maxnum = maxnum + 1
+
+    return labelbt
+
+
+def _connected_object_labelling(data, threshold=0.0, min_size=1,
+                                connectivity=1):
+    '''
+    Finds connected objects in an input array and assigns them unique labels.
+
+    Arguments:
+
+    * **data** - a :class:`numpy.ndarray` array in which to label objects.
+
+    Keyword arguments:
+
+    * **threshold** - if supplied, only regions where the input data exceeds \
+                      the threshold will be considered when searching for \
+                      connected objects.
+    * **min_size** - minimum size in grid points for connected objects. Must \
+                     be an integer >= 1.
+    * **connectivity** - given a particular grid point, all grid points up to \
+                         a squared distance of connectivity away are \
+                         considered neighbours. Connectivity may range from \
+                         1 (only direct neighbours are considered) to \
+                         :attr:`data.ndim`.
+
+    Returns:
+
+    * **label_array** - an integer array where each unique object in the \
+                        input array has a unique label in the returned array.
+    * **num_objects** - the number of objects found.
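+
+    Example (illustrative)::
+
+        data = np.array([[1.0, 0.0],
+                         [0.0, 2.0]])
+        labels, n = _connected_object_labelling(data, threshold=0.5,
+                                                connectivity=1)
+        # The two non-zero points only touch diagonally, so n == 2 here;
+        # with connectivity=2 (diagonals included) n would be 1.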
+    '''
+    # Apply the supplied threshold to the data to generate a binary array
+    binary_data = data.copy()
+    if np.ma.is_masked(binary_data):
+        # If the data is masked, replace any masked values with
+        # (threshold - 1), thus guaranteeing they will be below the
+        # threshold and set to 0 below
+        binary_data = np.ma.filled(binary_data, fill_value=(threshold - 1))
+    # Set values above and below the threshold to 1 and 0, respectively
+    indices_above = (binary_data > threshold)
+    indices_below = (binary_data <= threshold)
+    binary_data[indices_above] = 1
+    binary_data[indices_below] = 0
+
+    # Construct a structuring element that defines how the neighbours of
+    # a grid point are assigned
+    structure_element = ndimage.generate_binary_structure(
+        data.ndim, connectivity)
+
+    # Label distinct (connected) objects in the binary array
+    label_array, num_objects = ndimage.label(
+        binary_data, structure=structure_element)
+
+    # Throw away any objects smaller than min_size
+    if min_size < 1:
+        raise ValueError('"min_size" must be 1 or greater')
+    elif min_size > 1:
+        labels = np.unique(label_array)
+        # Discard the background (which will be labelled as 0)
+        labels = labels[(labels > 0)]
+        # Loop over distinct objects
+        for label in labels:
+            # Find the indices of the grid points comprising this object
+            indices = np.where(label_array == label)
+            # If this object is smaller than min_size, set it as background
+            if indices[0].size < min_size:
+                label_array[indices] = 0
+                num_objects -= 1
+
+    return label_array, num_objects
+
+
+def _find_cells(cube, threshold=0.0, area_threshold=0.0, connectivity=1):
+    '''
+    Finds connected objects (i.e. cells) in spatial slices of a given \
+    :class:`iris.cube.Cube`.
+
+    Arguments:
+
+    * **cube** - an input :class:`iris.cube.Cube` object.
+
+    Keyword arguments:
+
+    * **threshold** - if supplied, only regions where the input data exceeds \
+                      the threshold will be considered when identifying \
+                      cells.
+    * **area_threshold** - minimum area in km^2 that cells must have.
+    * **connectivity** - given a particular grid point, all grid points up \
+                         to a squared distance of connectivity away are \
+                         considered neighbours. Connectivity may range from \
+                         1 (only direct neighbours are considered) to \
+                         :attr:`cube.data.ndim`.
+
+    Returns:
+
+    * **cells** - a :class:`iris.cube.CubeList` of :class:`iris.cube.Cube` \
+                  objects, each one corresponding to an identified cell.
+    '''
+    # Flag whether to use the old cell labelling code from Reading
+    # TODO: To be removed once testing of the new code is complete
+    use_old_code = False
+
+    # Convert the input area threshold from km^2 to m^2
+    area_threshold = (float(M_IN_KM) ** 2) * area_threshold
+
+    # Get the x, y coordinates of the input cube
+    x_coord, y_coord = get_spatial_coords(cube)
+    x, y = iris.analysis.cartography.get_xy_grids(cube)
+
+    # Guess the x, y coordinate bounds
+    cube = guess_bounds(cube)
+
+    # Loop over 2D spatial slices of the input cube and find cells in each
+    # slice
+    grid_areas = None
+    cells = iris.cube.CubeList()
+    coords = get_non_spatial_coords(cube)
+    for slc in cube.slices_over(coords):
+        if grid_areas is None:
+            # Area of grid cells, in m^2
+            grid_areas = iris.analysis.cartography.area_weights(slc)
+
+        # Store a list of the non-spatial coordinates for this slice
+        aux_coords = [(coord, []) for coord in get_non_spatial_coords(slc)]
+
+        # Find and label cells
+        if use_old_code:
+            # Call the connected object labelling function from Reading
+            cell_label_array = _connected_object_labelling_reading(
+                slc.data, 0.0, threshold, method=1)
+        else:
+            # Call the connected object labelling function based on
+            # scipy.ndimage.label
+            cell_label_array, _ = _connected_object_labelling(
+                slc.data,
+                threshold=threshold,
+                min_size=1,
+                connectivity=connectivity)
+
+        # Get a list of unique cell labels
+        cell_labels = np.unique(cell_label_array)
+        # Discard the background (which has a label of 0)
+        cell_labels = cell_labels[(cell_labels > 0)]
+        # Loop over cells and store their properties
+        for cell_label in cell_labels:
+            # Find the indices of the grid points comprising this cell
+            cell_indices = np.where(cell_label_array == cell_label)
+            cell_x = x[cell_indices]
+            cell_y = y[cell_indices]
+            cell_values = slc.data[cell_indices]
+            cell_grid_areas = grid_areas[cell_indices]
+
+            # There should not be any masked data present in cells!
+            if np.ma.is_masked(cell_values):
+                raise ValueError("Masked data found in cell {0:d}"
+                                 .format(cell_label))
+
+            # If the cell area is less than area_threshold, discard it
+            # (by setting its label to the background value)
+            cell_area = np.sum(cell_grid_areas, dtype=np.float64)
+            if cell_area < area_threshold:
+                cell_label_array[cell_indices] = 0
+                continue
+
+            # Estimate the cell centre position
+            # TODO: Is there a better way of doing this? Centre of mass?
+            cell_centre = (np.mean(cell_x, dtype=np.float64),
+                           np.mean(cell_y, dtype=np.float64))
+            # Area-weighted mean value in the cell
+            cell_mean = (np.sum((cell_grid_areas * cell_values),
+                                dtype=np.float64)
+                         / cell_area)
+            # Convert the cell area from m^2 to km^2...
+            cell_area /= (float(M_IN_KM) ** 2)
+            # ...and then derive the cell effective radius in km
+            cell_radius = np.sqrt(cell_area / np.pi)
+
+            # Create an Iris cube to store this cell
+            cell_cube = iris.cube.Cube(
+                cell_values,
+                long_name="{:s} cell".format(cube.name()),
+                units=cube.units,
+                attributes=cube.attributes,
+                cell_methods=cube.cell_methods,
+                aux_coords_and_dims=aux_coords)
+
+            # Set up x, y coordinates describing the grid points in the
+            # cell...
+            cell_x_coord = iris.coords.AuxCoord(
+                cell_x,
+                standard_name=x_coord.standard_name,
+                long_name=x_coord.long_name,
+                units=x_coord.units,
+                bounds=None,
+                attributes=x_coord.attributes,
+                coord_system=x_coord.coord_system)
+            cell_y_coord = iris.coords.AuxCoord(
+                cell_y,
+                standard_name=y_coord.standard_name,
+                long_name=y_coord.long_name,
+                units=y_coord.units,
+                bounds=None,
+                attributes=y_coord.attributes,
+                coord_system=y_coord.coord_system)
+            # ...and add them to the cell cube
+            cell_cube.add_aux_coord(cell_x_coord, 0)
+            cell_cube.add_aux_coord(cell_y_coord, 0)
+
+            # Set up a coordinate describing the areas of grid cells in
+            # the cell object...
+            cell_grid_area_coord = iris.coords.AuxCoord(
+                cell_grid_areas,
+                long_name="grid_areas",
+                units="m2")
+            # ...and add it to the cell cube
+            cell_cube.add_aux_coord(cell_grid_area_coord, 0)
+
+            # Finally, add some attributes to the cube that record useful
+            # information about the cell
+            cell_cube.attributes["centre"] = cell_centre
+            cell_cube.attributes["area_in_km2"] = cell_area
+            cell_cube.attributes["effective_radius_in_km"] = cell_radius
+            cell_cube.attributes["mean_value"] = cell_mean
+
+            cells.append(cell_cube)
+
+    return cells
+
+
+def cell_attribute_histogram(cube, attribute, bin_edges, bin_centres=None,
+                             threshold=0.0, area_threshold=0.0):
+    '''
+    Builds, for each 2D spatial slice of the input :class:`iris.cube.Cube`, \
+    a histogram of the given cell attribute (e.g. "effective_radius_in_km" \
+    or "mean_value") over all cells identified in that slice.
+    '''
+    bin_edges = np.asarray(bin_edges)
+    if bin_centres is None:
+        bin_centres = 0.5 * (bin_edges[1:] + bin_edges[:-1])
+    else:
+        bin_centres = np.asarray(bin_centres)
+
+    # Check that the number of bin edges and centres is consistent
+    if bin_edges.size != (bin_centres.size + 1):
+        raise ValueError("Number of bin edges must be one greater than the "
+                         "number of bin centres.")
+
+    # Express the histogram bins as an Iris coordinate
+    bins_as_coord = iris.coords.DimCoord(
+        bin_centres,
+        long_name=attribute,
+        units=cube.units,
+        coord_system=None,
+        bounds=np.column_stack((bin_edges[0:-1], bin_edges[1:])))
+
+    # Loop over 2D spatial slices in the cube, find cells and construct a
+    # histogram for each slice
+    data_min, data_max = None, None
+    hist_cube = iris.cube.CubeList()
+    coords = get_non_spatial_coords(cube)
+
+    for slc in cube.slices_over(coords):
+        # Identify connected cells in this spatial slice
+        cells = _find_cells(
+            slc,
+            threshold=threshold,
+            area_threshold=area_threshold,
+            connectivity=CELL_CONNECTIVITY)
+
+        if cells:
+            # Extract values of the desired cell attribute
+            cell_attributes = [cell.attributes[attribute] for cell in cells]
+
+            # Store the minimum/maximum values of the cell attribute
+            if data_min is None or np.min(cell_attributes) < data_min:
+                data_min = np.min(cell_attributes)
+            if data_max is None or np.max(cell_attributes) > data_max:
+                data_max = np.max(cell_attributes)
+
+            # Construct a histogram of the desired cell attribute
+            hist, _ = np.histogram(cell_attributes, bin_edges)
+        else:
+            # Assign zeros to all bins
+            hist = np.zeros(bin_centres.size).astype(np.int64)
+
+        # Get a list of the non-spatial coordinates for this slice
+        aux_coords = [(coord, []) for coord in get_non_spatial_coords(slc)]
+
+        # Construct a cube to hold the cell statistic histogram for this
+        # slice
+        hist_slc = iris.cube.Cube(
+            hist,
+            long_name=("{0:s} cell {1:s} histogram"
+                       .format(slc.name(), attribute)),
+            units="no_unit",
+            attributes=slc.attributes,
+            cell_methods=slc.cell_methods,
+            dim_coords_and_dims=[(bins_as_coord, 0)],
+            aux_coords_and_dims=aux_coords)
+
+        hist_cube.append(hist_slc)
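+
+    # (np.histogram counts values into half-open bins [edge[i], edge[i+1]),
+    # with the final bin closed, so the supplied edges must span the full
+    # range of the data; this is checked below.)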
+    # If the bins did not fully enclose the data, they will need adjusting
+    if data_min is not None and data_max is not None:
+        if (data_min < np.min(bin_edges)) or (data_max > np.max(bin_edges)):
+            msg = ("Bins do not fully enclose data. Adjust them. Cell {0:s} "
+                   "(min, max)=({1:.4f}, {2:.4f}), bin edge "
+                   "(min, max)=({3:.4f}, {4:.4f})".format(attribute,
+                                                          data_min,
+                                                          data_max,
+                                                          np.min(bin_edges),
+                                                          np.max(bin_edges)))
+            raise ValueError(msg)
+
+    # Merge all histograms into a single cube
+    hist_cube = hist_cube.merge_cube()
+
+    return hist_cube
+
+
+def compute_cell_stats(cubes, threshold, bin_edges, cell_attribute):
+    '''
+    Computes a cell attribute histogram for each cube in the input list \
+    and returns the histograms as a list of cubes.
+    '''
+    for cube in cubes:
+        cube.var_name = None
+        for coord in cube.coords():
+            coord.var_name = None
+
+    lcube = [cell_attribute_histogram(
+        cube,
+        cell_attribute,
+        bin_edges,
+        threshold=threshold) for cube in cubes]
+
+    return lcube
diff --git a/src/CSET/operators/diagnostics/utils.py b/src/CSET/operators/diagnostics/utils.py
new file mode 100644
index 000000000..2387a002a
--- /dev/null
+++ b/src/CSET/operators/diagnostics/utils.py
@@ -0,0 +1,789 @@
+# Copyright 2022 Met Office and contributors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Plotting and cube-handling utilities for the cell statistics diagnostics."""
+
+import collections
+import datetime
+import errno
+import os
+
+import iris
+import matplotlib
+matplotlib.use("Agg")  # Select a non-interactive backend before pyplot import
+import matplotlib.dates as dates
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+
+from CSET.operators.diagnostics.constants import (HOUR_IN_SECONDS,
+                                                  MAX_TICK_OVERRIDE)
+
+
+def get_spatial_coords(cube):
+    '''Returns the x, y coordinates of an input :class:`iris.cube.Cube`.'''
+    # Usual names for spatial coordinates
+    X_COORD_NAMES = ["longitude", "grid_longitude",
+                     "projection_x_coordinate", "x"]
+    Y_COORD_NAMES = ["latitude", "grid_latitude",
+                     "projection_y_coordinate", "y"]
+
+    # Get a list of coordinate names for the cube
+    coord_names = [coord.name() for coord in cube.coords()]
+
+    # Check which x-coordinate we have, if any
+    x_coords = [coord for coord in coord_names if coord in X_COORD_NAMES]
+    if len(x_coords) != 1:
+        raise ValueError("Could not identify a unique x-coordinate in cube")
+    x_coord = cube.coord(x_coords[0])
+
+    # Check which y-coordinate we have, if any
+    y_coords = [coord for coord in coord_names if coord in Y_COORD_NAMES]
+    if len(y_coords) != 1:
+        raise ValueError("Could not identify a unique y-coordinate in cube")
+    y_coord = cube.coord(y_coords[0])
+
+    return [x_coord, y_coord]
+
+
+def get_non_spatial_coords(cube):
+    '''
+    Returns a list of the non-spatial coordinates of an input \
+    :class:`iris.cube.Cube`.
+    '''
+    # Get a list of the cube coordinates
+    coords = cube.coords()
+    # Get the spatial coordinates of the cube
+    x_coord, y_coord = get_spatial_coords(cube)
+    # Remove the spatial coordinates from the list of coordinates
+    coords.remove(x_coord)
+    coords.remove(y_coord)
+    return coords
+
+
+def extract_overlapping(cubelist, coord_name):
+    '''
+    Extracts regions from cubes in a :class:`iris.cube.CubeList` such that \
+    the specified coordinate is the same across all cubes.
+
+    Arguments:
+
+    * **cubelist** - an input :class:`iris.cube.CubeList`.
+    * **coord_name** - a string specifying the name of the coordinate \
+                       over which to perform the extraction.
+
+    Returns a :class:`iris.cube.CubeList` where the coordinate \
+    corresponding to coord_name is the same for all cubes.
+    '''
+    # Build a list of all Cell instances for this coordinate by
+    # looping through all cubes in the supplied cubelist
+    all_cells = []
+    for cube in cubelist:
+        for cell in cube.coord(coord_name).cells():
+            all_cells.append(cell)
+
+    # Work out which coordinate Cell instances are common across
+    # all cubes in the cubelist...
+    cell_counts = collections.Counter(all_cells)
+    unique_cells = list(cell_counts.keys())
+    unique_cell_counts = list(cell_counts.values())
+    num_cubes = len(cubelist)
+    common_cells = [unique_cells[i] for i, count in
+                    enumerate(unique_cell_counts) if count == num_cubes]
+    # ...and use these to subset the cubes in the cubelist
+    constraint = iris.Constraint(
+        coord_values={coord_name: lambda cell: cell in common_cells})
+
+    cubelist = iris.cube.CubeList([cube.extract(constraint)
+                                   for cube in cubelist])
+
+    return cubelist
+
+
+def remove_duplicates(cubelist):
+    '''
+    Removes any duplicate :class:`iris.cube.Cube` objects from an \
+    :class:`iris.cube.CubeList`.
+    '''
+    # Nothing to do if the cubelist is empty
+    if not cubelist:
+        return cubelist
+    # Build up a list of indices of the cubes to remove because they are
+    # duplicated
+    indices_to_remove = []
+    for i in range(len(cubelist) - 1):
+        cube_i = cubelist[i]
+        for j in range(i + 1, len(cubelist)):
+            cube_j = cubelist[j]
+            if cube_i == cube_j:
+                if j not in indices_to_remove:
+                    indices_to_remove.append(j)
+    # Only keep unique cubes
+    cubelist = iris.cube.CubeList([cube for index, cube in
+                                   enumerate(cubelist) if index
+                                   not in indices_to_remove])
+    return cubelist
+
+
+def identify_unique_times(cubelist, time_coord_name):
+    '''
+    Given a :class:`iris.cube.CubeList`, this finds the set of unique times \
+    which occur across all cubes in the cubelist.
+
+    Arguments:
+
+    * **cubelist** - a :class:`iris.cube.CubeList` of \
+                     :class:`iris.cube.Cube` objects.
+    * **time_coord_name** - the name of the time coordinate to select, \
+                            typically "time", "forecast_period" or "hour".
+
+    Returns:
+
+    * **time_coord** - an :class:`iris.coords.Coord` instance containing \
+                       the unique times that occur across the cubes in the \
+                       input cubelist.
+    '''
+    times = []
+    time_unit = None
+    # Loop over cubes
+    for cube in cubelist:
+        # Extract the desired time coordinate from the cube
+        time_coord = cube.coord(time_coord_name)
+
+        # Get the units for the specified time coordinate
+        if time_unit is None:
+            time_unit = time_coord.units
+
+        # Store the time coordinate points
+        times.extend(time_coord.points)
+
+    # Construct a list of unique times...
+    times = sorted(list(set(times)))
+    # ...and store them in a new time coordinate
+    time_coord = iris.coords.DimCoord(times, units=time_unit)
+    time_coord.rename(time_coord_name)
+
+    return time_coord
+
+
+def hour_from_time(coord, point):
+    '''
+    Category function to calculate the hour of day given a time, for use in \
+    :func:`iris.coord_categorisation.add_categorised_coord`.
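+
+    For example, a validity time of 2022-01-01 06:30 is mapped to 6.5.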
+    '''
+    time = coord.units.num2date(point)
+    day_start = datetime.datetime(time.year, time.month, time.day)
+    seconds_since_day_start = (time - day_start).total_seconds()
+    hours_since_day_start = (seconds_since_day_start
+                             / float(HOUR_IN_SECONDS))
+    return hours_since_day_start
+
+
+def remove_cell_method(cube, cell_method):
+    '''
+    Removes the supplied :class:`iris.coords.CellMethod` from an input \
+    :class:`iris.cube.Cube`, then returns the cube.
+    '''
+    cell_methods = [cm for cm in cube.cell_methods if cm != cell_method]
+    cube.cell_methods = ()
+    for cm in cell_methods:
+        cube.add_cell_method(cm)
+    return cube
+
+
+def aggregate_at_time(input_params):
+    '''
+    Extracts data valid at a given time from each cube in a list of cubes, \
+    then performs an aggregation operation (e.g. mean) across this data.
+
+    Arguments (passed in as a tuple to allow parallelisation):
+
+    * **input_params** - a four-element tuple consisting of:
+
+      * **cubes** - a :class:`iris.cube.CubeList` holding the \
+                    :class:`iris.cube.Cube` objects to process.
+      * **time_coord** - the time at which the aggregation should be \
+                         performed, supplied as an \
+                         :class:`iris.coords.Coord` object.
+      * **aggregator** - the aggregator to use, which can be any from \
+                         :mod:`iris.analysis`.
+      * **percentile** - the value of the percentile rank at which to \
+                         extract values, if the chosen aggregator is \
+                         :class:`iris.analysis.PERCENTILE`. For other \
+                         aggregators this is not used.
+
+    Returns:
+
+    * **aggregated_cubes** - an :class:`iris.cube.CubeList` of \
+                             :class:`iris.cube.Cube` objects holding the \
+                             aggregated data.
+    '''
+    # Unpack the input parameters tuple
+    cubes = input_params[0]
+    time_coord = input_params[1]
+    aggregator = input_params[2]
+    # TODO: Can we improve the handling of this with keyword arguments?
+    percentile = input_params[3]
+
+    # Check the supplied time coordinate to make sure it corresponds to a
+    # single time only
+    if len(time_coord.points) != 1:
+        raise ValueError("Time coordinate should specify a single time only")
+
+    # Remove any duplicate cubes in the input cubelist, otherwise they
+    # will break the aggregation
+    cubes = remove_duplicates(cubes)
+
+    # Name of the supplied time coordinate
+    time_coord_name = time_coord.name()
+
+    # Extract cubes matching the time specified by the supplied time
+    # coordinate
+    time_constraint = iris.Constraint(
+        coord_values={time_coord_name: lambda cell: cell.point in
+                      time_coord.points})
+    cubes_at_time = cubes.extract(time_constraint)
+
+    # Add a temporary "number" coordinate to uniquely label the different
+    # data points at this time.
+    # An example of when there can be multiple data points at the time of
+    # interest is if the time coordinate represents the hour of day.
+    number = 0
+    numbered_cubes = iris.cube.CubeList()
+    for cube in cubes_at_time:
+        for slc in cube.slices_over(time_coord_name):
+            number_coord = iris.coords.AuxCoord(number, long_name="number")
+            slc.add_aux_coord(number_coord)
+            numbered_cubes.append(slc)
+            number += 1
+    cubes_at_time = numbered_cubes
+
+    # Merge
+    cubes_at_time = cubes_at_time.merge()
+
+    # For each cube in the cubelist, aggregate over all cases at this time
+    # using the supplied aggregator
+    aggregated_cubes = iris.cube.CubeList()
+    for cube in cubes_at_time:
+        # If there was only a single data point at this time, then
+        # "number" will be a scalar coordinate. If so, make it a dimension
+        # coordinate to allow collapsing below
+        if not cube.coord_dims("number"):
+            cube = iris.util.new_axis(cube, scalar_coord="number")
+
+        # Store the total number of data points found at this time
+        num_cases = cube.coord("number").points.size
+        num_cases_coord = iris.coords.AuxCoord(num_cases,
+                                               long_name="num_cases")
+        cube.add_aux_coord(num_cases_coord)
+
+        # Do the aggregation across the temporary "number" coordinate
+        if isinstance(aggregator, type(iris.analysis.PERCENTILE)):
+            cube = cube.collapsed("number", aggregator, percent=percentile)
+        else:
+            cube = cube.collapsed("number", aggregator)
+
+        # Now remove the "number" coordinate...
+        cube.remove_coord("number")
+        # ...and the associated cell method
+        cell_method = iris.coords.CellMethod(aggregator.name(),
+                                             coords="number")
+        cube = remove_cell_method(cube, cell_method)
+
+        aggregated_cubes.append(cube)
+
+    return aggregated_cubes
+
+
+def repeat_scalar_coord_along_dim_coord(cubelist, scalar_coord_name,
+                                        dim_coord_name):
+    '''
+    For each :class:`iris.cube.Cube` in a given \
+    :class:`iris.cube.CubeList`, this extends (by repetition) a specified \
+    scalar coordinate along the dimension corresponding to a specified \
+    dimension coordinate.
+    '''
+    for cube in cubelist:
+        scalar_coord = cube.coord(scalar_coord_name)
+
+        # Check the coordinate referenced by scalar_coord_name is indeed
+        # a scalar coordinate. Otherwise there is nothing to do.
+        if scalar_coord.points.size == 1:
+            # Get the data value held by the scalar coordinate...
+            scalar_coord_val = scalar_coord.points[0]
+
+            # ...then remove it from the cube
+            cube.remove_coord(scalar_coord)
+
+            # Extract the dimension coordinate matching dim_coord_name
+            dim_coord = cube.coord(dim_coord_name)
+            # Get the dimension spanned by this dimension coordinate...
+            dim = cube.coord_dims(dim_coord_name)[0]
+            # ...and its length
+            dim_size = dim_coord.points.size
+
+            # Construct an auxiliary coordinate by replicating the data
+            # value from the scalar coordinate to match the size of the
+            # specified dimension coordinate
+            scalar_coord = iris.coords.AuxCoord(
+                np.repeat(scalar_coord_val, dim_size),
+                long_name=scalar_coord_name)
+
+            # Add the new auxiliary coordinate to the cube
+            cube.add_aux_coord(scalar_coord, dim)
+
+    return cubelist
+
+
+def set_title_and_filename(group, time, short_name, threshold, grid_label):
+    '''Sets up the figure title and file name for a given time grouping.'''
+    img_info = {}
+
+    if group == "forecast_period":
+        img_info["title"] = "T+{0:.1f}".format(time.points[0])
+        img_info["filename"] = ("{0:s}{1:s}_thresh_{2:.1f}_T{3:.1f}"
+                                .format(short_name, grid_label, threshold,
+                                        time.points[0]))
+    elif group == "time":
+        time_unit = time.units
+        # Use a local name that does not shadow the datetime module
+        time_as_datetime = time_unit.num2date(time.points[0])
+        img_info["title"] = "{0:%Y/%m/%d} {1:%H%M}Z".format(
+            time_as_datetime, time_as_datetime)
+        img_info["filename"] = ("{0:s}{1:s}_thresh_{2:.1f}_"
+                                "{3:%Y%m%d}_{4:%H%M}Z".format(
+                                    short_name, grid_label, threshold,
+                                    time_as_datetime, time_as_datetime))
+    else:
+        img_info["title"] = "{0:.1f}Z".format(time.points[0])
+        img_info["filename"] = ("{0:s}{1:s}_thresh_{2:.1f}_{3:.1f}Z"
+                                .format(short_name, grid_label, threshold,
+                                        time.points[0]))
+
+    return img_info
+
+
+def get_valid_data(x, y, x_log=False, y_log=False):
+    '''
+    Returns the non-NaN elements of the input arrays. If the keyword log \
+    is set, only non-NaN elements greater than zero are selected.
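+
+    Example (illustrative)::
+
+        x = np.array([1.0, np.nan, 3.0])
+        y = np.array([2.0, 4.0, 0.0])
+        get_valid_data(x, y, y_log=True)  # -> (array([1.]), array([2.]))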
+    '''
+    # Check the input
+    if x is None or y is None:
+        return None, None
+    elif not isinstance(x, np.ndarray):
+        raise ValueError("Input x must be a numpy array")
+    elif not isinstance(y, np.ndarray):
+        raise ValueError("Input y must be a numpy array")
+
+    # Remove any NaNs from the input x array.
+    # We use pandas.isnull instead of numpy.isnan here in case
+    # the array contains datetime objects
+    indices = np.where(~pd.isnull(x))
+    x = x[indices]
+    y = y[indices]
+    if x_log:
+        # Only keep x values that are greater than zero
+        indices = np.where(x > 0)
+        y = y[indices]
+        x = x[indices]
+
+    # Now do the same for y
+    indices = np.where(~pd.isnull(y))
+    y = y[indices]
+    x = x[indices]
+    if y_log:
+        # Only keep y values that are greater than zero
+        indices = np.where(y > 0)
+        y = y[indices]
+        x = x[indices]
+
+    # Return None if no valid data was extracted
+    if x.size == 0 or y.size == 0:
+        x = None
+        y = None
+
+    return x, y
+
+
+def _mkdir_p(path):
+    '''Makes a directory, mimicking mkdir -p behaviour.'''
+    try:
+        os.makedirs(path)
+    except OSError as exc:
+        if exc.errno == errno.EEXIST:
+            pass
+        else:
+            raise
+
+
+def set_axis_limits(plot):
+    '''
+    Guesses appropriate limits for the x and y axes using the coordinate \
+    values of each line in the plot dictionary. Nothing is done if the x \
+    and y limits have been manually specified.
+    '''
+    # First set up the x-axis limits...
+    if plot["x_limits"] is None:
+        xmin, xmax = None, None
+        x_log_possible = False
+        for line in plot["lines"]:
+
+            # Look for valid data to plot
+            x, y = get_valid_data(
+                np.array(line["x"]),
+                np.array(line["y"]),
+                x_log=plot["x_log"],
+                y_log=plot["y_log"])
+
+            # If no data can be plotted, move on
+            if x is None or y is None:
+                continue
+
+            x_log_possible = True
+            # This is useful to trim off lots of leading/trailing zeros
+            # (e.g. for histograms with a wider range of bins than
+            # necessary)
+            x = x[np.nonzero(y)]
+
+            # Update the minimum x value
+            if xmin is None or np.min(x) < xmin:
+                xmin = np.min(x)
+
+            # Update the maximum x value
+            if xmax is None or np.max(x) > xmax:
+                xmax = np.max(x)
+
+        # Turn off the logarithmic x axis if it is not possible
+        # if plot["x_log"] and not x_log_possible:
+        #     plot["x_log"] = False
+
+        if xmin == xmax:
+            x_limits = None
+        elif xmin is None or xmax is None:
+            x_limits = None
+        else:
+            x_limits = (xmin, xmax)
+
+        # Store the x axis limits
+        plot["x_limits"] = x_limits
+
+    # ...then set up the y-axis limits
+    if plot["y_limits"] is None:
+        ymin, ymax = None, None
+        for line in plot["lines"]:
+            _, y = get_valid_data(
+                np.array(line["x"]),
+                np.array(line["y"]),
+                x_log=plot["x_log"],
+                y_log=plot["y_log"])
+            _, y_lower = get_valid_data(
+                np.array(line["x"]),
+                plot["y_lower"],
+                x_log=plot["x_log"],
+                y_log=plot["y_log"])
+            _, y_upper = get_valid_data(
+                np.array(line["x"]),
+                plot["y_upper"],
+                x_log=plot["x_log"],
+                y_log=plot["y_log"])
+
+            # If no data can be plotted, move on
+            if y is None:
+                continue
+
+            # Update the minimum y value, using any lower bounds on y
+            if y_lower is not None:
+                if ymin is None or np.min(y_lower) < ymin:
+                    ymin = np.min(y_lower)
+            else:
+                if ymin is None or np.min(y) < ymin:
+                    ymin = np.min(y)
+
+            # Update the maximum y value, using any upper bounds on y
+            if y_upper is not None:
+                if ymax is None or np.max(y_upper) > ymax:
+                    ymax = np.max(y_upper)
+            else:
+                if ymax is None or np.max(y) > ymax:
+                    ymax = np.max(y)
+
+        if ymin == ymax:
+            y_limits = None
+        elif ymin is None or ymax is None:
+            y_limits = None
+        else:
+            y_limits = (ymin, ymax)
+
+        # Store the y axis limits
+        plot["y_limits"] = y_limits
+
+    return plot
+
+
+def plotfig(plot):
+    '''
+    Produces the line plot described by the supplied plot dictionary. \
+    Returns an instance of :class:`matplotlib.figure.Figure` for the plot, \
+    or None if no plot was made.
+    '''
+    if len(plot["lines"]) == 0:
+        # No lines to plot
+        return None
+
+    # Open a new figure and set up the axes
+    fig, axes = plt.subplots()
+    # Loop over the lines on this plot
+    for line in plot["lines"]:
+        # Take a copy of the data for this line
+        x = np.copy(line["x"])
+        y = np.copy(line["y"]).astype(np.float64)
+
+        if plot["x_log"]:
+            x[np.where(x <= 0)] = np.nan
+        if plot["y_log"]:
+            y[np.where(y <= 0)] = np.nan
+
+        axes.plot(
+            x,
+            y,
+            linestyle=plot["linestyle"],
+            marker=plot["markers"],
+            linewidth=plot["linewidth"],
+            color=line["colour"],
+            label=line["label"])
+
+    # Add user-specified text annotations
+    if plot["annotations"] is not None:
+        for annotation in plot["annotations"]:
+            axes.annotate(annotation["text"],
+                          xy=annotation["xy"],
+                          xycoords="axes fraction",
+                          fontsize=7)
+
+    # Set up logarithmic axes if requested
+    if plot["x_log"]:
+        axes.set_xscale("log")
+    if plot["y_log"]:
+        axes.set_yscale("log")
+
+    # Apply the axis limits
+    if plot["x_limits"] is not None:
+        axes.set_xlim(plot["x_limits"])
+    if plot["y_limits"] is not None:
+        axes.set_ylim(plot["y_limits"])
+    # Apply tick formatting
+    x_is_time = True
+    for line in plot["lines"]:
+        if not all([isinstance(x, datetime.datetime) for x in line["x"]]):
+            x_is_time = False
+
+    if x_is_time:
+        # Work out the length of the x-axis in days
+        first_date, last_date = None, None
+        for line in plot["lines"]:
+            if first_date is None or np.min(line["x"]) < first_date:
+                first_date = np.min(line["x"])
+            if last_date is None or np.max(line["x"]) > last_date:
+                last_date = np.max(line["x"])
+        num_days = (last_date - first_date).days
+        # Set the spacing of ticks on the x-axis in such a way as to
+        # prevent it becoming too crowded as the x-axis gets longer
+        # (i.e. when analysing longer trials)
+        x_tick_interval = (num_days / 10) + 1
+        locator = dates.DayLocator(interval=int(x_tick_interval))
+
+        locator.MAXTICKS = MAX_TICK_OVERRIDE
+        axes.xaxis.set_major_locator(locator)
+        axes.xaxis.set_major_formatter(dates.DateFormatter("%HZ\n%d/%m"))
+        if x_tick_interval < 3:
+            locator = dates.HourLocator(byhour=[0, 6, 12, 18])
+        elif x_tick_interval < 6:
+            locator = dates.HourLocator(byhour=[0, 12])
+        else:
+            locator = dates.DayLocator(interval=1)
+        locator.MAXTICKS = MAX_TICK_OVERRIDE
+        axes.xaxis.set_minor_locator(locator)
+
+    # Add the legend
+    legend = axes.legend(loc="best")
+    legend.draw_frame(False)
+
+    # Add the title and axis labels
+    if plot["title"] is not None:
+        axes.set_title(plot["title"])
+    if plot["x_label"] is not None:
+        axes.set_xlabel(plot["x_label"])
+    if plot["y_label"] is not None:
+        axes.set_ylabel(plot["y_label"])
+
+    return fig
+
+
+def save(plot, plotdir=None, db_file=None):
+    '''
+    Saves the plot to a file.
+
+    Keyword arguments:
+
+    * **plotdir** - the name of the directory where the plot will be saved.
+    * **db_file** - the name of the database file used to store image \
+                    metadata for this plot.
+
+    Returns:
+
+    * The name of the file where the plot was saved.
+    '''
+    # Set a default plot directory if none given
+    if plotdir is None:
+        plotdir = "."
+
+    # Remove any trailing slash
+    if plotdir[-1] == "/":
+        plotdir = plotdir[:-1]
+
+    # Create the plot directory if not present
+    _mkdir_p(plotdir)
+
+    # Generate a standard file name if none supplied
+    if plot["filename"] is None:
+        plot["filename"] = "lineplot"
+    filename = "{0:s}/{1:s}.png".format(plotdir, plot["filename"])
+
+    # Database file to store all image metadata in.
+    # Note: this path is currently only constructed, not written to.
+    if db_file is None:
+        db_file = "{0:s}/imt_db.db".format(plotdir)
+
+    # Save the figure...
+    plt.savefig(filename)
+    # ...and close it
+    plt.close()
+
+    return filename
+
+
+def plot_and_save(input_params):
+    '''
+    Creates and saves a plot.
+
+    Arguments (passed in as a tuple to allow parallelisation):
+
+    * **input_params** - a two-element tuple consisting of:
+
+      * **plot** - a plot dictionary (as produced by cell_statistics) \
+                   defining the plot to be made.
+      * **plot_dir** - the name of the directory where the plot will be \
+                       saved.
+
+    Returns:
+
+    * **db_file** - the name of the database file used to store image \
+                    metadata for this plot, or None if no plot could be \
+                    made.
+    '''
+    # Unpack the input parameters tuple
+    plot = input_params[0]
+    plot_dir = input_params[1]
+    # Produce the plot
+    fig = plotfig(plot)
+
+    if fig is None:
+        db_file = None
+    else:
+        # Name of the database file to store all image metadata
+        proc_id = os.getpid()
+        db_file = "{0:s}/imt_tmp_db_{1:d}.db".format(plot_dir, proc_id)
+        # Save the plot
+        save(plot, plotdir=plot_dir, db_file=db_file)
+
+    return db_file
+
+
+def guess_bounds(cube):
+    '''
+    Takes an input :class:`iris.cube.Cube`, guesses bounds on the x, y \
+    coordinates, then returns the cube. Such bounds are often required by \
+    regridding algorithms.
+    '''
+    # Loop over the spatial coordinates
+    for axis in ["x", "y"]:
+        coord = cube.coord(axis=axis)
+        # Check that this is not a variable resolution grid
+        # TODO: Does guessing bounds really not work with variable
+        # resolution? AVD seem to think it won't...
+        try:
+            _ = iris.util.regular_step(coord)
+        except Exception:
+            raise ValueError("Cannot guess bounds for a variable "
+                             "resolution grid")
+        # Guess bounds if there aren't any
+        if coord.bounds is None:
+            coord.guess_bounds()
+    return cube
+
+
+def set_annotation(cubes):
+    '''
+    Sets up a text annotation showing the number of cases used to build \
+    the histograms, or None if the number of cases differs between cubes.
+    '''
+    num_cases = []
+    for cube in cubes:
+        num_cases.extend(cube.coord("num_cases").points)
+    num_cases = list(set(num_cases))
+    if len(num_cases) != 1:
+        annotation = None
+    else:
+        # Set up a text annotation showing the number of cases that went
+        # into constructing the histograms at this time
+        annotation = [
+            {"text": "Number of cases: {:d}".format(num_cases[0]),
+             "xy": (0.0, 1.02)}
+        ]
+    return annotation
+
+
+def set_general_axis_limits(plots):
+    '''
+    Guesses appropriate limits for the x and y axes using the extent of \
+    each plot in the supplied list, then imposes the same limits on every \
+    plot.
+
+    Arguments:
+
+    * **plots** - the list of plot dictionaries to process.
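+
+    Returns the same list with identical "x_limits" and "y_limits" values \
+    set on every plot.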
+    '''
+    # Set the axis limits for each plot
+    for plot in plots:
+        plot = set_axis_limits(plot)
+
+    x_limits, y_limits = None, None
+
+    # Work out the minimum/maximum x values based on all of the plots
+    xmin, xmax = None, None
+    for plot in plots:
+        if plot["x_limits"] is not None:
+            if xmin is None or np.min(plot["x_limits"]) < xmin:
+                xmin = np.min(plot["x_limits"])
+            if xmax is None or np.max(plot["x_limits"]) > xmax:
+                xmax = np.max(plot["x_limits"])
+
+    if xmin is None or xmax is None:
+        x_limits = None
+    else:
+        x_limits = (xmin, xmax)
+
+    # Work out the minimum/maximum y values based on all of the plots
+    ymin, ymax = None, None
+    for plot in plots:
+        if plot["y_limits"] is not None:
+            if ymin is None or np.min(plot["y_limits"]) < ymin:
+                ymin = np.min(plot["y_limits"])
+            if ymax is None or np.max(plot["y_limits"]) > ymax:
+                ymax = np.max(plot["y_limits"])
+    if ymin is None or ymax is None:
+        y_limits = None
+    else:
+        y_limits = (ymin, ymax)
+
+    # Now impose the same axis limits on all of the plots
+    for plot in plots:
+        plot["x_limits"] = x_limits
+        plot["y_limits"] = y_limits
+
+    return plots
diff --git a/tests/test_cell_statistics.py b/tests/test_cell_statistics.py
new file mode 100644
index 000000000..bc7ef12d5
--- /dev/null
+++ b/tests/test_cell_statistics.py
@@ -0,0 +1,73 @@
+# Copyright 2023 Met Office and contributors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +"""Tests for cell statistics functionality across CSET.""" + +from pathlib import Path +import pytest +import CSET.operators.cell_statistics as cs +import glob +import iris + +def test_cell_effective_radius(): + """Happy case for cell effective raduis.""" + files = glob.glob("/tests/test_data/precip_cubes*.nc") + cubelist = iris.load(files) + input_params = { + "thresholds": [0.5], + "time_grouping": "forecast_period", + "cell_attribute": "effective_radius_in_km", + "plot_dir": "/test/test_data/" + } + + assert cs.cell_statistics(cubelist,input_params) == 0 + + +def test_cell_mean_size(): + """Happy case for cell mean size.""" + files = glob.glob("/tests/test_data/precip_cubes*.nc") + cubelist = iris.load(files) + input_params = { + "thresholds": [0.5], + "time_grouping": "forecast_period", + "cell_attribute": "mean_value", + "plot_dir": "/test/test_data/" + } + + assert cs.cell_statistics(cubelist,input_params) == 0 + +def test_cubelist_is_none(): + """Test exception for null cubelist""" + cubelist = None + input_params = { + "thresholds": [0.5], + "time_grouping": "forecast_period", + "cell_attribute": "mean", + "plot_dir": "/test/test_data/" + } + with pytest.raises(Exception): + cs.cell_statistics(cubelist, input_params) + +def test_cell_attribute_unknown(): + """Test exception for unknown cell attribute""" + files = glob.glob("/tests/test_data/precip_cubes*.nc") + cubelist = iris.load(files) + input_params = { + "thresholds": [0.5], + "time_grouping": "forecast_period", + "cell_attribute": "mean_cell", + "plot_dir": "/test/test_data/" + } + with pytest.raises(Exception): + cs.cell_statistics(cubelist, input_params) \ No newline at end of file diff --git a/tests/test_data/precip_cubes_mibd567.nc b/tests/test_data/precip_cubes_mibd567.nc new file mode 100644 index 000000000..bbad45c16 Binary files /dev/null and b/tests/test_data/precip_cubes_mibd567.nc differ diff --git a/tests/test_data/precip_cubes_mibe419.nc b/tests/test_data/precip_cubes_mibe419.nc new file mode 100644 index 000000000..60613ecab Binary files /dev/null and b/tests/test_data/precip_cubes_mibe419.nc differ