diff --git a/cirq-core/cirq/contrib/paulistring/pauli_string_measurement_with_readout_mitigation.py b/cirq-core/cirq/contrib/paulistring/pauli_string_measurement_with_readout_mitigation.py index 66c11aacad6..0928db748f3 100644 --- a/cirq-core/cirq/contrib/paulistring/pauli_string_measurement_with_readout_mitigation.py +++ b/cirq-core/cirq/contrib/paulistring/pauli_string_measurement_with_readout_mitigation.py @@ -28,7 +28,9 @@ from cirq.experiments.readout_confusion_matrix import TensoredConfusionMatrices if TYPE_CHECKING: - from cirq.experiments import SingleQubitReadoutCalibrationResult + from cirq.experiments.single_qubit_readout_calibration import ( + SingleQubitReadoutCalibrationResult, + ) from cirq.study import ResultDict @@ -217,6 +219,11 @@ def _normalize_input_paulis( return cast(dict[circuits.FrozenCircuit, list[list[ops.PauliString]]], circuits_to_pauli) +def _extract_readout_qubits(pauli_strings: list[ops.PauliString]) -> list[ops.Qid]: + """Extracts unique qubits from a list of QWC Pauli strings.""" + return sorted(set(q for ps in pauli_strings for q in ps.qubits)) + + def _pauli_strings_to_basis_change_ops( pauli_strings: list[ops.PauliString], qid_list: list[ops.Qid] ): @@ -315,16 +322,38 @@ def _process_pauli_measurement_results( for pauli_group_index, circuit_result in enumerate(circuit_results): measurement_results = circuit_result.measurements["m"] pauli_strs = pauli_string_groups[pauli_group_index] + pauli_readout_qubits = _extract_readout_qubits(pauli_strs) + + calibration_result = ( + calibration_results[tuple(pauli_readout_qubits)] + if disable_readout_mitigation is False + else None + ) for pauli_str in pauli_strs: qubits_sorted = sorted(pauli_str.qubits) qubit_indices = [qubits.index(q) for q in qubits_sorted] - confusion_matrices = ( - _build_many_one_qubits_confusion_matrix(calibration_results[tuple(qubits_sorted)]) - if disable_readout_mitigation is False - else _build_many_one_qubits_empty_confusion_matrix(len(qubits_sorted)) - ) + if disable_readout_mitigation: + pauli_str_calibration_result = None + confusion_matrices = _build_many_one_qubits_empty_confusion_matrix( + len(qubits_sorted) + ) + else: + if calibration_result is None: + # This case should be logically impossible if mitigation is on, + # so we raise an error. + raise ValueError( + f"Readout mitigation is enabled, but no calibration result was " + f"found for qubits {pauli_readout_qubits}." + ) + pauli_str_calibration_result = calibration_result.readout_result_for_qubits( + qubits_sorted + ) + confusion_matrices = _build_many_one_qubits_confusion_matrix( + pauli_str_calibration_result + ) + tensored_cm = TensoredConfusionMatrices( confusion_matrices, [[q] for q in qubits_sorted], @@ -356,11 +385,7 @@ def _process_pauli_measurement_results( mitigated_stddev=d_m_with_coefficient, unmitigated_expectation=unmitigated_value_with_coefficient, unmitigated_stddev=d_unmit_with_coefficient, - calibration_result=( - calibration_results[tuple(qubits_sorted)] - if disable_readout_mitigation is False - else None - ), + calibration_result=pauli_str_calibration_result, ) ) @@ -428,8 +453,7 @@ def measure_pauli_strings( unique_qubit_tuples = set() for pauli_string_groups in normalized_circuits_to_pauli.values(): for pauli_strings in pauli_string_groups: - for pauli_string in pauli_strings: - unique_qubit_tuples.add(tuple(sorted(pauli_string.qubits))) + unique_qubit_tuples.add(tuple(_extract_readout_qubits(pauli_strings))) # qubits_list is a list of qubit tuples qubits_list = sorted(unique_qubit_tuples) diff --git a/cirq-core/cirq/contrib/paulistring/pauli_string_measurement_with_readout_mitigation_test.py b/cirq-core/cirq/contrib/paulistring/pauli_string_measurement_with_readout_mitigation_test.py index 759872a7a57..1992ead250d 100644 --- a/cirq-core/cirq/contrib/paulistring/pauli_string_measurement_with_readout_mitigation_test.py +++ b/cirq-core/cirq/contrib/paulistring/pauli_string_measurement_with_readout_mitigation_test.py @@ -23,7 +23,10 @@ import cirq from cirq.contrib.paulistring import measure_pauli_strings -from cirq.experiments import SingleQubitReadoutCalibrationResult +from cirq.contrib.paulistring.pauli_string_measurement_with_readout_mitigation import ( + _process_pauli_measurement_results, +) +from cirq.experiments.single_qubit_readout_calibration import SingleQubitReadoutCalibrationResult from cirq.experiments.single_qubit_readout_calibration_test import NoisySingleQubitReadoutSampler @@ -867,3 +870,37 @@ def test_group_paulis_type_mismatch() -> None: measure_pauli_strings( circuits_to_pauli, cirq.Simulator(), 1000, 1000, 1000, np.random.default_rng() ) + + +def test_process_pauli_measurement_results_raises_error_on_missing_calibration() -> None: + """Test that the function raises an error if the calibration result is missing.""" + qubits: list[cirq.Qid] = [q for q in cirq.LineQubit.range(5)] + + measurement_op = cirq.measure(*qubits, key='m') + test_circuits = list[cirq.Circuit]() + for _ in range(3): + circuit_list = [] + + circuit = _create_ghz(5, qubits) + measurement_op + circuit_list.append(circuit) + test_circuits.extend(circuit_list) + + pauli_strings = [_generate_random_pauli_string(qubits, True) for _ in range(3)] + sampler = cirq.Simulator() + + circuit_results = sampler.run_batch(test_circuits, repetitions=1000) + + empty_calibration_result_dict = {tuple(qubits): None} + + with pytest.raises( + ValueError, + match="Readout mitigation is enabled, but no calibration result was found for qubits", + ): + _process_pauli_measurement_results( + qubits, + [pauli_strings], + circuit_results[0], # type: ignore[arg-type] + empty_calibration_result_dict, # type: ignore[arg-type] + 1000, + 1.0, + ) diff --git a/cirq-core/cirq/experiments/single_qubit_readout_calibration.py b/cirq-core/cirq/experiments/single_qubit_readout_calibration.py index 5c0378ee6db..eaa1ce55b41 100644 --- a/cirq-core/cirq/experiments/single_qubit_readout_calibration.py +++ b/cirq-core/cirq/experiments/single_qubit_readout_calibration.py @@ -179,6 +179,17 @@ def plot_integrated_histogram( ax.set_ylabel('Percentile') return ax + def readout_result_for_qubits( + self, readout_qubits: list[ops.Qid] + ) -> SingleQubitReadoutCalibrationResult: + """Builds a calibration result for the specific readout qubits.""" + return SingleQubitReadoutCalibrationResult( + zero_state_errors={qubit: self.zero_state_errors[qubit] for qubit in readout_qubits}, + one_state_errors={qubit: self.one_state_errors[qubit] for qubit in readout_qubits}, + timestamp=self.timestamp, + repetitions=self.repetitions, + ) + @classmethod def _from_json_dict_( cls, zero_state_errors, one_state_errors, repetitions, timestamp, **kwargs