Skip to content

Combining readout circuits for QWC pauli strings in the measure_pauli_strings method #7416

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Jun 24, 2025
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@
from cirq.experiments.readout_confusion_matrix import TensoredConfusionMatrices

if TYPE_CHECKING:
from cirq.experiments import SingleQubitReadoutCalibrationResult
from cirq.experiments.single_qubit_readout_calibration import (
SingleQubitReadoutCalibrationResult,
)
from cirq.study import ResultDict


Expand Down Expand Up @@ -217,6 +219,11 @@ def _normalize_input_paulis(
return cast(dict[circuits.FrozenCircuit, list[list[ops.PauliString]]], circuits_to_pauli)


def _extract_readout_qubits(pauli_strings: list[ops.PauliString]) -> list[ops.Qid]:
"""Extracts unique qubits from a list of QWC Pauli strings."""
return sorted(set(q for ps in pauli_strings for q in ps.qubits))


def _pauli_strings_to_basis_change_ops(
pauli_strings: list[ops.PauliString], qid_list: list[ops.Qid]
):
Expand Down Expand Up @@ -315,16 +322,38 @@ def _process_pauli_measurement_results(
for pauli_group_index, circuit_result in enumerate(circuit_results):
measurement_results = circuit_result.measurements["m"]
pauli_strs = pauli_string_groups[pauli_group_index]
pauli_readout_qubits = _extract_readout_qubits(pauli_strs)

calibration_result = (
calibration_results[tuple(pauli_readout_qubits)]
if disable_readout_mitigation is False
else None
)

for pauli_str in pauli_strs:
qubits_sorted = sorted(pauli_str.qubits)
qubit_indices = [qubits.index(q) for q in qubits_sorted]

confusion_matrices = (
_build_many_one_qubits_confusion_matrix(calibration_results[tuple(qubits_sorted)])
if disable_readout_mitigation is False
else _build_many_one_qubits_empty_confusion_matrix(len(qubits_sorted))
)
if disable_readout_mitigation:
pauli_str_calibration_result = None
confusion_matrices = _build_many_one_qubits_empty_confusion_matrix(
len(qubits_sorted)
)
else:
if calibration_result is None:
# This case should be logically impossible if mitigation is on,
# so we raise an error.
raise ValueError(
f"Readout mitigation is enabled, but no calibration result was "
f"found for qubits {pauli_readout_qubits}."
)
pauli_str_calibration_result = calibration_result.readout_result_for_qubits(
qubits_sorted
)
confusion_matrices = _build_many_one_qubits_confusion_matrix(
pauli_str_calibration_result
)

tensored_cm = TensoredConfusionMatrices(
confusion_matrices,
[[q] for q in qubits_sorted],
Expand Down Expand Up @@ -356,11 +385,7 @@ def _process_pauli_measurement_results(
mitigated_stddev=d_m_with_coefficient,
unmitigated_expectation=unmitigated_value_with_coefficient,
unmitigated_stddev=d_unmit_with_coefficient,
calibration_result=(
calibration_results[tuple(qubits_sorted)]
if disable_readout_mitigation is False
else None
),
calibration_result=pauli_str_calibration_result,
)
)

Expand Down Expand Up @@ -428,8 +453,7 @@ def measure_pauli_strings(
unique_qubit_tuples = set()
for pauli_string_groups in normalized_circuits_to_pauli.values():
for pauli_strings in pauli_string_groups:
for pauli_string in pauli_strings:
unique_qubit_tuples.add(tuple(sorted(pauli_string.qubits)))
unique_qubit_tuples.add(tuple(_extract_readout_qubits(pauli_strings)))
# qubits_list is a list of qubit tuples
qubits_list = sorted(unique_qubit_tuples)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@

import cirq
from cirq.contrib.paulistring import measure_pauli_strings
from cirq.experiments import SingleQubitReadoutCalibrationResult
from cirq.contrib.paulistring.pauli_string_measurement_with_readout_mitigation import (
_process_pauli_measurement_results,
)
from cirq.experiments.single_qubit_readout_calibration import SingleQubitReadoutCalibrationResult
from cirq.experiments.single_qubit_readout_calibration_test import NoisySingleQubitReadoutSampler


Expand Down Expand Up @@ -867,3 +870,37 @@ def test_group_paulis_type_mismatch() -> None:
measure_pauli_strings(
circuits_to_pauli, cirq.Simulator(), 1000, 1000, 1000, np.random.default_rng()
)


def test_process_pauli_measurement_results_raises_error_on_missing_calibration() -> None:
"""Test that the function raises an error if the calibration result is missing."""
qubits: list[cirq.Qid] = [q for q in cirq.LineQubit.range(5)]

measurement_op = cirq.measure(*qubits, key='m')
test_circuits = list[cirq.Circuit]()
for _ in range(3):
circuit_list = []

circuit = _create_ghz(5, qubits) + measurement_op
circuit_list.append(circuit)
test_circuits.extend(circuit_list)

pauli_strings = [_generate_random_pauli_string(qubits, True) for _ in range(3)]
sampler = cirq.Simulator()

circuit_results = sampler.run_batch(test_circuits, repetitions=1000)

empty_calibration_result_dict = {tuple(qubits): None}

with pytest.raises(
ValueError,
match="Readout mitigation is enabled, but no calibration result was found for qubits",
):
_process_pauli_measurement_results(
qubits,
[pauli_strings],
circuit_results[0], # type: ignore[arg-type]
empty_calibration_result_dict, # type: ignore[arg-type]
1000,
1.0,
)
11 changes: 11 additions & 0 deletions cirq-core/cirq/experiments/single_qubit_readout_calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,17 @@ def plot_integrated_histogram(
ax.set_ylabel('Percentile')
return ax

def readout_result_for_qubits(
self, readout_qubits: list[ops.Qid]
) -> SingleQubitReadoutCalibrationResult:
"""Builds a calibration result for the specific readout qubits."""
return SingleQubitReadoutCalibrationResult(
zero_state_errors={qubit: self.zero_state_errors[qubit] for qubit in readout_qubits},
one_state_errors={qubit: self.one_state_errors[qubit] for qubit in readout_qubits},
timestamp=self.timestamp,
repetitions=self.repetitions,
)

@classmethod
def _from_json_dict_(
cls, zero_state_errors, one_state_errors, repetitions, timestamp, **kwargs
Expand Down