From 1ba64c59802df387dc502672fdfd8703a8d01e38 Mon Sep 17 00:00:00 2001 From: Naoki Kanazawa Date: Wed, 21 Feb 2024 23:48:00 +0900 Subject: [PATCH 1/4] Fix curve plotter bug and add more reliable test --- .../visualization/plotters/base_plotter.py | 2 +- .../visualization/plotters/curve_plotter.py | 71 ++++---- test/visualization/test_plotter.py | 162 +++++++++++++++++- test/visualization/test_utils.py | 81 +++++++++ 4 files changed, 276 insertions(+), 40 deletions(-) diff --git a/qiskit_experiments/visualization/plotters/base_plotter.py b/qiskit_experiments/visualization/plotters/base_plotter.py index 2810184b8d..13afd64996 100644 --- a/qiskit_experiments/visualization/plotters/base_plotter.py +++ b/qiskit_experiments/visualization/plotters/base_plotter.py @@ -198,7 +198,7 @@ def data_keys_for(self, series_name: SeriesName) -> List[str]: def data_for( self, series_name: SeriesName, data_keys: Union[str, List[str]] - ) -> Tuple[Optional[Any]]: + ) -> Tuple[Optional[Any], ...]: """Returns data associated with the given series. The returned tuple contains the data, associated with ``data_keys``, in the same diff --git a/qiskit_experiments/visualization/plotters/curve_plotter.py b/qiskit_experiments/visualization/plotters/curve_plotter.py index 033ba81a54..9fba02a92f 100644 --- a/qiskit_experiments/visualization/plotters/curve_plotter.py +++ b/qiskit_experiments/visualization/plotters/curve_plotter.py @@ -12,11 +12,11 @@ """Plotter for curve fits, specifically from :class:`.CurveAnalysis`.""" from typing import List +import numpy as np from uncertainties import UFloat from qiskit_experiments.curve_analysis.utils import analysis_result_to_repr from qiskit_experiments.framework import Options - from .base_plotter import BasePlotter @@ -119,49 +119,44 @@ def _plot_figure(self): plotted_formatted_data = False if self.data_exists_for(ser, ["x_formatted", "y_formatted", "y_formatted_err"]): x, y, yerr = self.data_for(ser, ["x_formatted", "y_formatted", "y_formatted_err"]) - self.drawer.scatter(x, y, y_err=yerr, name=ser, zorder=2, legend=True) - plotted_formatted_data = True + if x is not None and y is not None: + self.drawer.scatter(x, y, y_err=yerr, name=ser, zorder=2, legend=True) + plotted_formatted_data = True # Scatter plot if self.data_exists_for(ser, ["x", "y"]): x, y = self.data_for(ser, ["x", "y"]) - options = { - "zorder": 1, - } - # If we plotted formatted data, differentiate scatter points by setting normal X-Y - # markers to gray. - if plotted_formatted_data: - options["color"] = "gray" - # If we didn't plot formatted data, the X-Y markers should be used for the legend. We add - # it to ``options`` so it's easier to pass to ``scatter``. - if not plotted_formatted_data: - options["legend"] = True - self.drawer.scatter( - x, - y, - name=ser, - **options, - ) - - # Line plot for fit + if x is not None and y is not None: + options = { + "zorder": 1, + } + # If we plotted formatted data, differentiate scatter points + # by setting normal X-Y markers to gray. + if plotted_formatted_data: + options["color"] = "gray" + # If we didn't plot formatted data, the X-Y markers should be used for the legend. + # We add it to ``options`` so it's easier to pass to ``scatter``. + if not plotted_formatted_data: + options["legend"] = True + self.drawer.scatter(x, y, name=ser, **options) + + # Line and confidence interval plot for fit if self.data_exists_for(ser, ["x_interp", "y_interp"]): x, y = self.data_for(ser, ["x_interp", "y_interp"]) - self.drawer.line(x, y, name=ser, zorder=3) - - # Confidence interval plot - if self.data_exists_for(ser, ["x_interp", "y_interp", "y_interp_err"]): - x, y_interp, y_interp_err = self.data_for( - ser, ["x_interp", "y_interp", "y_interp_err"] - ) - for n_sigma, alpha in self.options.plot_sigma: - self.drawer.filled_y_area( - x, - y_interp + n_sigma * y_interp_err, - y_interp - n_sigma * y_interp_err, - name=ser, - alpha=alpha, - zorder=5, - ) + if x is not None and y is not None: + self.drawer.line(x, y, name=ser, zorder=3) + if self.data_exists_for(ser, ["y_interp_err"]): + if (y_err := self.data_for(ser, ["y_interp_err"])[0]) is not None: + y_err = np.array(y_err, dtype=float) + for n_sigma, alpha in self.options.plot_sigma: + self.drawer.filled_y_area( + x, + y + n_sigma * y_err, + y - n_sigma * y_err, + name=ser, + alpha=alpha, + zorder=5, + ) # Fit report report = self._write_report() diff --git a/test/visualization/test_plotter.py b/test/visualization/test_plotter.py index 4ca71e990c..070794ef3b 100644 --- a/test/visualization/test_plotter.py +++ b/test/visualization/test_plotter.py @@ -12,9 +12,15 @@ """ Test integration of plotter. """ +from test.base import QiskitExperimentsTestCase +from test.visualization.test_utils import LoggingTestCase from copy import deepcopy -from test.base import QiskitExperimentsTestCase + +import numpy as np +from uncertainties import ufloat +from qiskit_experiments.framework import AnalysisResultData +from qiskit_experiments.visualization import CurvePlotter, MplDrawer from .mock_drawer import MockDrawer from .mock_plotter import MockPlotter @@ -85,3 +91,157 @@ def test_supplementary_data_end_to_end(self): msg=f"Actual figure data value for {key} data-key is not as expected: {actual_value} " f"(actual) vs {expected_value} (expected)", ) + + +class TestCurvePlotter(LoggingTestCase): + """Test case for Qiskit Experiments curve plotter based on logging.""" + + def test_all_data(self): + """Visualize all curve information.""" + plotter = CurvePlotter(drawer=MplDrawer()) + plotter.set_series_data( + series_name="test", + x=[0, 1], + y=[1, 1], + x_formatted=[2, 3], + y_formatted=[2, 2], + y_formatted_err=[0.1, 0.1], + x_interp=[4, 5], + y_interp=[3, 3], + y_interp_err=[0.2, 0.2], + ) + self.assertDrawerAPICallEqual( + plotter, + expected=[ + "Calling initialize_canvas", + "Calling scatter with x_data=[2, 3], y_data=[2, 2], x_err=None, y_err=[0.1, 0.1], " + "name='test', label=None, legend=True, options={'zorder': 2}", + "Calling scatter with x_data=[0, 1], y_data=[1, 1], x_err=None, y_err=None, " + "name='test', label=None, legend=False, options={'zorder': 1, 'color': 'gray'}", + "Calling line with x_data=[4, 5], y_data=[3, 3], " + "name='test', label=None, legend=False, options={'zorder': 3}", + "Calling filled_y_area with x_data=[4, 5], y_ub=[3.2, 3.2], y_lb=[2.8, 2.8], " + "name='test', label=None, legend=False, options={'alpha': 0.7, 'zorder': 5}", + "Calling filled_y_area with x_data=[4, 5], y_ub=[3.6, 3.6], y_lb=[2.4, 2.4], " + "name='test', label=None, legend=False, options={'alpha': 0.3, 'zorder': 5}", + "Calling format_canvas", + ], + ) + + def test_supplementary(self): + """Visualize with fitting report.""" + test_result = AnalysisResultData(name="test", value=ufloat(1, 0.2)) + + plotter = CurvePlotter(drawer=MplDrawer()) + plotter.set_series_data( + series_name="test", + x=[0, 1], + y=[1, 1], + ) + plotter.set_supplementary_data( + fit_red_chi=3.0, + primary_results=[test_result], + ) + self.assertDrawerAPICallEqual( + plotter, + expected=[ + "Calling initialize_canvas", + "Calling scatter with x_data=[0, 1], y_data=[1, 1], x_err=None, y_err=None, " + "name='test', label=None, legend=True, options={'zorder': 1}", + r"Calling textbox with description='test = 1 ± 0.2\n" + r"reduced-$\chi^2$ = 3', rel_pos=None, options={}", + "Calling format_canvas", + ], + ) + + def test_fit_y_error_missing(self): + """Visualize curve that fitting doesn't work well, i.e. cov-matrix diverges.""" + plotter = CurvePlotter(drawer=MplDrawer()) + plotter.set_series_data( + series_name="test", + x=[0, 1], + y=[1, 1], + x_formatted=[2, 3], + y_formatted=[2, 2], + y_formatted_err=[0.1, 0.1], + x_interp=[4, 5], + y_interp=[3, 3], # y_interp_err is gone + ) + self.assertDrawerAPICallEqual( + plotter, + expected=[ + "Calling initialize_canvas", + "Calling scatter with x_data=[2, 3], y_data=[2, 2], x_err=None, y_err=[0.1, 0.1], " + "name='test', label=None, legend=True, options={'zorder': 2}", + "Calling scatter with x_data=[0, 1], y_data=[1, 1], x_err=None, y_err=None, " + "name='test', label=None, legend=False, options={'zorder': 1, 'color': 'gray'}", + "Calling line with x_data=[4, 5], y_data=[3, 3], " + "name='test', label=None, legend=False, options={'zorder': 3}", + "Calling format_canvas", + ], + ) + + def test_fit_fails(self): + """Visualize curve only contains formatted data, i.e. fit completely fails.""" + plotter = CurvePlotter(drawer=MplDrawer()) + plotter.set_series_data( + series_name="test", + x_formatted=[2, 3], + y_formatted=[2, 2], + y_formatted_err=[0.1, 0.1], + ) + self.assertDrawerAPICallEqual( + plotter, + expected=[ + "Calling initialize_canvas", + "Calling scatter with x_data=[2, 3], y_data=[2, 2], x_err=None, y_err=[0.1, 0.1], " + "name='test', label=None, legend=True, options={'zorder': 2}", + "Calling format_canvas", + ], + ) + + def test_two_series(self): + """Visualize curve with two series.""" + plotter = CurvePlotter(drawer=MplDrawer()) + plotter.set_series_data( + series_name="test1", + x_formatted=[2, 3], + y_formatted=[2, 2], + y_formatted_err=[0.1, 0.1], + ) + plotter.set_series_data( + series_name="test2", + x_formatted=[2, 3], + y_formatted=[4, 4], + y_formatted_err=[0.2, 0.2], + ) + self.assertDrawerAPICallEqual( + plotter, + expected=[ + "Calling initialize_canvas", + "Calling scatter with x_data=[2, 3], y_data=[2, 2], x_err=None, y_err=[0.1, 0.1], " + "name='test1', label=None, legend=True, options={'zorder': 2}", + "Calling scatter with x_data=[2, 3], y_data=[4, 4], x_err=None, y_err=[0.2, 0.2], " + "name='test2', label=None, legend=True, options={'zorder': 2}", + "Calling format_canvas", + ], + ) + + def test_scatter_partly_missing(self): + """Visualize curve include some defect.""" + plotter = CurvePlotter(drawer=MplDrawer()) + plotter.set_series_data( + series_name="test", + x_formatted=[2, 3], + y_formatted=[np.nan, 2], + y_formatted_err=[np.nan, 0.1], + ) + self.assertDrawerAPICallEqual( + plotter, + expected=[ + "Calling initialize_canvas", + "Calling scatter with x_data=[2, 3], y_data=[nan, 2], x_err=None, y_err=[nan, 0.1], " + "name='test', label=None, legend=True, options={'zorder': 2}", + "Calling format_canvas", + ], + ) diff --git a/test/visualization/test_utils.py b/test/visualization/test_utils.py index 9e8a593f52..a9db9f861c 100644 --- a/test/visualization/test_utils.py +++ b/test/visualization/test_utils.py @@ -14,6 +14,9 @@ """ import itertools as it +import logging +import inspect + from test.base import QiskitExperimentsTestCase from typing import List, Tuple @@ -21,6 +24,7 @@ from ddt import data, ddt from qiskit.exceptions import QiskitError +from qiskit_experiments.visualization import BasePlotter, BaseDrawer from qiskit_experiments.visualization.utils import DataExtentCalculator from qiskit_experiments.framework.package_deps import numpy_version @@ -119,3 +123,80 @@ def test_no_data_error(self): ext_calc = DataExtentCalculator() with self.assertRaises(QiskitError): ext_calc.extent() + + +class LoggingTestCase(QiskitExperimentsTestCase): + """Experiments visualization test case for integration test. + + This test case provides a test function assertDrawerAPICallEqual that embeds + a local logger to record internal drawer API calls, + instead of validating the plotter by actually comparing the generated image file. + """ + + class LoggingWrapper: + """Internal drawer wrapper for logging API call.""" + + def __init__( + self, + target: BaseDrawer, + to_record: list[str], + ): + self._target = target + self._to_record = to_record + self._logger = logging.getLogger("LocalLogger") + + def __getattr__(self, name): + method = getattr(self._target, name) + if name in self._to_record: + return self._record_call(method) + return method + + def _record_call(self, method): + """A drawer's method wrapper to record the call details.""" + signature = inspect.signature(method) + + def _format_arg(key_value_tuple): + key, value = key_value_tuple + # to make uniform representation for the ease of unittest + if isinstance(value, (tuple, np.ndarray)): + return f"{key}={list(value)}" + if isinstance(value, str): + return f"{key}='{value}'" + return f"{key}={value}" + + def _wrapped(*args, **kwargs): + full_args = signature.bind(*args, **kwargs) + full_args.apply_defaults() + msg = f"Calling {method.__name__}" + if log_kwargs := ", ".join(map(_format_arg, full_args.arguments.items())): + msg += f" with {log_kwargs}" + self._logger.info(msg) + return method(*args, **kwargs) + + return _wrapped + + def assertDrawerAPICallEqual( + self, + plotter: BasePlotter, + expected: list[str], + ): + """Test if drawer APIs are called with expected arguments via the plotter.figure() call.""" + plotter.drawer = LoggingTestCase.LoggingWrapper( + plotter.drawer, + [ + "line", + "scatter", + "hline", + "filled_y_area", + "filled_x_area", + "textbox", + "image", + "initialize_canvas", + "format_canvas", + ], + ) + logger = "LocalLogger" + with self.assertLogs(logger, level="INFO") as cm: + plotter.figure() + + self.assertListEqual([record.message for record in cm.records], expected) From 5114dc36122143ff6794dfcc0ba80052daad6091 Mon Sep 17 00:00:00 2001 From: Naoki Kanazawa Date: Thu, 22 Feb 2024 00:54:42 +0900 Subject: [PATCH 2/4] Add release note --- ...curve-plotter-missing-interp-y-err-4d7b2ab4611603d0.yaml | 6 ++++++ 1 file changed, 6 insertions(+) create mode 100644 releasenotes/notes/fix-curve-plotter-missing-interp-y-err-4d7b2ab4611603d0.yaml diff --git a/releasenotes/notes/fix-curve-plotter-missing-interp-y-err-4d7b2ab4611603d0.yaml b/releasenotes/notes/fix-curve-plotter-missing-interp-y-err-4d7b2ab4611603d0.yaml new file mode 100644 index 0000000000..c4721926b4 --- /dev/null +++ b/releasenotes/notes/fix-curve-plotter-missing-interp-y-err-4d7b2ab4611603d0.yaml @@ -0,0 +1,6 @@ +--- +fixes: + - | + Fixed a bug that crashes the curve analysis when the covariance matrix in + the least square fit diverges. This bug is caused by lacking exception handling + in the :class:`.CurvePlotter`. See qiskit-experiments/#1413 for details. From 77c7a6b751a3ad6269fec6446cadf47416dcbdee Mon Sep 17 00:00:00 2001 From: Naoki Kanazawa Date: Thu, 22 Feb 2024 01:02:20 +0900 Subject: [PATCH 3/4] fix typehint --- test/visualization/test_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/visualization/test_utils.py b/test/visualization/test_utils.py index a9db9f861c..94208ec71b 100644 --- a/test/visualization/test_utils.py +++ b/test/visualization/test_utils.py @@ -12,6 +12,7 @@ """ Test visualization utilities. """ +from __future__ import annotations import itertools as it import logging From 4253d2e0b1fd54993dddb58584986b9b9406b5da Mon Sep 17 00:00:00 2001 From: Naoki Kanazawa Date: Thu, 22 Feb 2024 02:00:22 +0900 Subject: [PATCH 4/4] disable lint --- test/visualization/test_plotter.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/test/visualization/test_plotter.py b/test/visualization/test_plotter.py index 070794ef3b..cc87d39109 100644 --- a/test/visualization/test_plotter.py +++ b/test/visualization/test_plotter.py @@ -142,14 +142,15 @@ def test_supplementary(self): fit_red_chi=3.0, primary_results=[test_result], ) + # pylint: disable=anomalous-backslash-in-string self.assertDrawerAPICallEqual( plotter, expected=[ "Calling initialize_canvas", "Calling scatter with x_data=[0, 1], y_data=[1, 1], x_err=None, y_err=None, " "name='test', label=None, legend=True, options={'zorder': 1}", - r"Calling textbox with description='test = 1 ± 0.2\n" - r"reduced-$\chi^2$ = 3', rel_pos=None, options={}", + "Calling textbox with description='test = 1 ± 0.2\n" + "reduced-$\chi^2$ = 3', rel_pos=None, options={}", "Calling format_canvas", ], )