Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 51 additions & 4 deletions modelskill/comparison/_comparer_plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -763,7 +763,7 @@ def taylor(

def residual_hist(
self, bins=100, title=None, color=None, figsize=None, ax=None, **kwargs
) -> matplotlib.axes.Axes:
) -> matplotlib.axes.Axes | list[matplotlib.axes.Axes]:
"""plot histogram of residual values

Parameters
Expand All @@ -776,20 +776,67 @@ def residual_hist(
residual color, by default "#8B8D8E"
figsize : tuple, optional
figure size, by default None
ax : matplotlib.axes.Axes, optional
ax : matplotlib.axes.Axes | list[matplotlib.axes.Axes], optional
axes to plot on, by default None
**kwargs
other keyword arguments to plt.hist()

Returns
-------
matplotlib.axes.Axes
matplotlib.axes.Axes | list[matplotlib.axes.Axes]
"""
cmp = self.comparer

if cmp.n_models == 1:
return self._residual_hist_one_model(
bins=bins,
title=title,
color=color,
figsize=figsize,
ax=ax,
mod_name=cmp.mod_names[0],
**kwargs,
)

if ax is not None and len(ax) != len(cmp.mod_names):
raise ValueError("Number of axes must match number of models")

axs = ax if ax is not None else [None] * len(cmp.mod_names)

for i, mod_name in enumerate(cmp.mod_names):
cmp_model = cmp.sel(model=mod_name)
ax_mod = cmp_model.plot.residual_hist(
bins=bins,
title=title,
color=color,
figsize=figsize,
ax=axs[i],
**kwargs,
)
axs[i] = ax_mod

return axs

def _residual_hist_one_model(
self,
bins=100,
title=None,
color=None,
figsize=None,
ax=None,
mod_name=None,
**kwargs,
) -> matplotlib.axes.Axes:
"""Residual histogram for one model only"""
_, ax = _get_fig_ax(ax, figsize)

default_color = "#8B8D8E"
color = default_color if color is None else color
title = f"Residuals, {self.comparer.name}" if title is None else title
title = (
f"Residuals, Observation: {self.comparer.name}, Model: {mod_name}"
if title is None
else title
)
ax.hist(self.comparer._residual, bins=bins, color=color, **kwargs)
ax.set_title(title)
ax.set_xlabel(f"Residuals of {self.comparer._unit_text}")
Expand Down
14 changes: 12 additions & 2 deletions tests/test_comparer.py
Original file line number Diff line number Diff line change
Expand Up @@ -711,6 +711,8 @@ def test_to_dataframe_tc(tc):

# ======================== plotting ========================

PLOT_FUNCS_RETURNING_MANY_AX = ["scatter", "hist", "residual_hist"]


@pytest.fixture(
params=[
Expand All @@ -727,11 +729,20 @@ def test_to_dataframe_tc(tc):
def pc_plot_function(pc, request):
func = getattr(pc.plot, request.param)
# special cases requiring a model to be selected
if request.param in ["scatter", "hist", "residual_hist"]:
if request.param in PLOT_FUNCS_RETURNING_MANY_AX:
func = getattr(pc.sel(model=0).plot, request.param)
return func


@pytest.mark.parametrize("kind", PLOT_FUNCS_RETURNING_MANY_AX)
def test_plots_returning_multiple_axes(pc, kind):
n_models = 2
func = getattr(pc.plot, kind)
ax = func()
assert len(ax) == n_models
assert all(isinstance(a, plt.Axes) for a in ax)


def test_plot_returns_an_object(pc_plot_function):
obj = pc_plot_function()
assert obj is not None
Expand Down Expand Up @@ -824,7 +835,6 @@ def test_plots_directional(pt_df):


def test_from_matched_track_data():

df = pd.DataFrame(
{
"lat": [55.0, 55.1],
Expand Down