diff --git a/cirq-core/cirq/__init__.py b/cirq-core/cirq/__init__.py index 6e6a9f28f67..d262212cb10 100644 --- a/cirq-core/cirq/__init__.py +++ b/cirq-core/cirq/__init__.py @@ -380,6 +380,7 @@ merge_operations_to_circuit_op as merge_operations_to_circuit_op, merge_single_qubit_gates_to_phased_x_and_z as merge_single_qubit_gates_to_phased_x_and_z, merge_single_qubit_gates_to_phxz as merge_single_qubit_gates_to_phxz, + merge_single_qubit_gates_to_phxz_symbolized as merge_single_qubit_gates_to_phxz_symbolized, merge_single_qubit_moments_to_phxz as merge_single_qubit_moments_to_phxz, optimize_for_target_gateset as optimize_for_target_gateset, parameterized_2q_op_to_sqrt_iswap_operations as parameterized_2q_op_to_sqrt_iswap_operations, diff --git a/cirq-core/cirq/transformers/__init__.py b/cirq-core/cirq/transformers/__init__.py index 826798f6697..e3d1c9a0d35 100644 --- a/cirq-core/cirq/transformers/__init__.py +++ b/cirq-core/cirq/transformers/__init__.py @@ -100,6 +100,7 @@ merge_single_qubit_gates_to_phased_x_and_z as merge_single_qubit_gates_to_phased_x_and_z, merge_single_qubit_gates_to_phxz as merge_single_qubit_gates_to_phxz, merge_single_qubit_moments_to_phxz as merge_single_qubit_moments_to_phxz, + merge_single_qubit_gates_to_phxz_symbolized as merge_single_qubit_gates_to_phxz_symbolized, ) from cirq.transformers.qubit_management_transformers import ( diff --git a/cirq-core/cirq/transformers/merge_single_qubit_gates.py b/cirq-core/cirq/transformers/merge_single_qubit_gates.py index 795856f2b12..c7a26abfb64 100644 --- a/cirq-core/cirq/transformers/merge_single_qubit_gates.py +++ b/cirq-core/cirq/transformers/merge_single_qubit_gates.py @@ -16,13 +16,24 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import Callable, cast, Hashable, TYPE_CHECKING from cirq import circuits, ops, protocols -from cirq.transformers import merge_k_qubit_gates, transformer_api, transformer_primitives +from cirq.study.resolver import ParamResolver +from cirq.study.sweeps import dict_to_zip_sweep, ListSweep, ProductOrZipSweepLike, Sweep, Zip +from cirq.transformers import ( + align, + merge_k_qubit_gates, + symbolize, + tag_transformers, + transformer_api, + transformer_primitives, +) from cirq.transformers.analytical_decompositions import single_qubit_decompositions if TYPE_CHECKING: + import sympy + import cirq @@ -67,6 +78,7 @@ def merge_single_qubit_gates_to_phxz( circuit: cirq.AbstractCircuit, *, context: cirq.TransformerContext | None = None, + merge_tags_fn: Callable[[cirq.CircuitOperation], list[Hashable]] | None = None, atol: float = 1e-8, ) -> cirq.Circuit: """Replaces runs of single qubit rotations with a single optional `cirq.PhasedXZGate`. @@ -77,6 +89,7 @@ def merge_single_qubit_gates_to_phxz( Args: circuit: Input circuit to transform. It will not be modified. context: `cirq.TransformerContext` storing common configurable options for transformers. + merge_tags_fn: A callable returns the tags to be added to the merged operation. atol: Absolute tolerance to angle error. Larger values allow more negligible gates to be dropped, smaller values increase accuracy. @@ -84,12 +97,13 @@ def merge_single_qubit_gates_to_phxz( Copy of the transformed input circuit. """ - def rewriter(op: cirq.CircuitOperation) -> cirq.OP_TREE: - u = protocols.unitary(op) - if protocols.num_qubits(op) == 0: + def rewriter(circuit_op: cirq.CircuitOperation) -> cirq.OP_TREE: + u = protocols.unitary(circuit_op) + if protocols.num_qubits(circuit_op) == 0: return ops.GlobalPhaseGate(u[0, 0]).on() - gate = single_qubit_decompositions.single_qubit_matrix_to_phxz(u, atol) - return gate(op.qubits[0]) if gate else [] + gate = single_qubit_decompositions.single_qubit_matrix_to_phxz(u, atol) or ops.I + phxz_op = gate.on(circuit_op.qubits[0]) + return phxz_op.with_tags(*merge_tags_fn(circuit_op)) if merge_tags_fn else phxz_op return merge_k_qubit_gates.merge_k_qubit_unitaries( circuit, k=1, context=context, rewriter=rewriter @@ -158,3 +172,160 @@ def merge_func(m1: cirq.Moment, m2: cirq.Moment) -> cirq.Moment | None: deep=context.deep if context else False, tags_to_ignore=tuple(tags_to_ignore), ).unfreeze(copy=False) + + +def _sweep_on_symbols(sweep: Sweep, symbols: set[sympy.Symbol]) -> Sweep: + new_resolvers: list[cirq.ParamResolver] = [] + for resolver in sweep: + param_dict: cirq.ParamMappingType = {s: resolver.value_of(s) for s in symbols} + new_resolvers.append(ParamResolver(param_dict)) + return ListSweep(new_resolvers) + + +def _calc_phxz_sweeps( + symbolized_circuit: cirq.Circuit, resolved_circuits: list[cirq.Circuit] +) -> Sweep: + """Return the phxz sweep of the symbolized_circuit on resolved_circuits. + + Raises: + ValueError: Structural mismatch: A `resolved_circuit` contains an unexpected gate type. + Expected a `PhasedXZGate` or `IdentityGate` at a position corresponding to a + symbolic `PhasedXZGate` in the `symbolized_circuit`. + """ + + def _extract_axz(op: ops.Operation | None) -> tuple[float, float, float]: + if not op or not op.gate or not isinstance(op.gate, ops.IdentityGate | ops.PhasedXZGate): + raise ValueError(f"Expect a PhasedXZGate or IdentityGate on op {op}.") + if isinstance(op.gate, ops.IdentityGate): + return 0.0, 0.0, 0.0 # Identity gate's a, x, z in PhasedXZ + return op.gate.axis_phase_exponent, op.gate.x_exponent, op.gate.z_exponent + + values_by_params: dict[sympy.Symbol, tuple[float, ...]] = {} + for mid, moment in enumerate(symbolized_circuit): + for op in moment.operations: + if op.gate and isinstance(op.gate, ops.PhasedXZGate) and protocols.is_parameterized(op): + sa, sx, sz = op.gate.axis_phase_exponent, op.gate.x_exponent, op.gate.z_exponent + values_by_params[sa], values_by_params[sx], values_by_params[sz] = zip( + *[_extract_axz(c[mid].operation_at(op.qubits[0])) for c in resolved_circuits] + ) + + return dict_to_zip_sweep(cast(ProductOrZipSweepLike, values_by_params)) + + +def merge_single_qubit_gates_to_phxz_symbolized( + circuit: cirq.AbstractCircuit, + *, + context: cirq.TransformerContext | None = None, + sweep: Sweep, + atol: float = 1e-8, +) -> tuple[cirq.Circuit, Sweep]: + """Merges consecutive single qubit gates as PhasedXZ Gates. Symbolizes if any of + the consecutive gates is symbolized. + + Example: + >>> q0, q1 = cirq.LineQubit.range(2) + >>> c = cirq.Circuit(\ + cirq.X(q0),\ + cirq.CZ(q0,q1)**sympy.Symbol("cz_exp"),\ + cirq.Y(q0)**sympy.Symbol("y_exp"),\ + cirq.X(q0)) + >>> print(c) + 0: ───X───@──────────Y^y_exp───X─── + │ + 1: ───────@^cz_exp───────────────── + >>> new_circuit, new_sweep = cirq.merge_single_qubit_gates_to_phxz_symbolized(\ + c, sweep=cirq.Zip(cirq.Points(key="cz_exp", points=[0, 1]),\ + cirq.Points(key="y_exp", points=[0, 1]))) + >>> print(new_circuit) + 0: ───PhXZ(a=-1,x=1,z=0)───@──────────PhXZ(a=a0,x=x0,z=z0)─── + │ + 1: ────────────────────────@^cz_exp────────────────────────── + >>> assert new_sweep[0] == cirq.ParamResolver({'a0': -1, 'x0': 1, 'z0': 0, 'cz_exp': 0}) + >>> assert new_sweep[1] == cirq.ParamResolver({'a0': -0.5, 'x0': 0, 'z0': -1, 'cz_exp': 1}) + + Args: + circuit: Input circuit to transform. It will not be modified. + context: `cirq.TransformerContext` storing common configurable options for transformers. + sweep: Sweep of the symbols in the input circuit, updated Sweep will be returned + based on the transformation. + atol: Absolute tolerance to angle error. Larger values allow more negligible gates to be + dropped, smaller values increase accuracy. + + Returns: + Copy of the transformed input circuit. + """ + deep = context.deep if context else False + + # Tag symbolized single-qubit op. + symbolized_single_tag = "_tmp_symbolize_tag" + + circuit_tagged = transformer_primitives.map_operations( + circuit, + lambda op, _: ( + op.with_tags(symbolized_single_tag) + if protocols.is_parameterized(op) and len(op.qubits) == 1 + else op + ), + deep=deep, + ) + + # Step 0, isolate single qubit symbols and resolve the circuit on them. + single_qubit_gate_symbols: set[sympy.Symbol] = set().union( + *[ + protocols.parameter_symbols(op) if symbolized_single_tag in op.tags else set() + for op in circuit_tagged.all_operations() + ] + ) + # Remaining symbols, e.g., 2 qubit gates' symbols. Sweep of those symbols keeps unchanged. + remaining_symbols: set[sympy.Symbol] = set( + protocols.parameter_symbols(circuit) - single_qubit_gate_symbols + ) + # If all single qubit gates are not parameterized, call the nonparamerized version of + # the transformer. + if not single_qubit_gate_symbols: + return (merge_single_qubit_gates_to_phxz(circuit, context=context, atol=atol), sweep) + sweep_of_single: Sweep = _sweep_on_symbols(sweep, single_qubit_gate_symbols) + # Get all resolved circuits from all sets of resolvers in sweep_of_single. + resolved_circuits = [ + protocols.resolve_parameters(circuit_tagged, resolver) for resolver in sweep_of_single + ] + + # Step 1, merge single qubit gates per resolved circuit, preserving + # the symbolized_single_tag to indicate the operator is a merged one. + merged_circuits: list[cirq.Circuit] = [ + merge_single_qubit_gates_to_phxz( + c, + context=context, + merge_tags_fn=lambda circuit_op: ( + [symbolized_single_tag] + if any( + symbolized_single_tag in set(op.tags) + for op in circuit_op.circuit.all_operations() + ) + else [] + ), + atol=atol, + ) + for c in resolved_circuits + ] + + # Step 2, get the new symbolized circuit by symbolizing on indexed symbolized_single_tag. + new_circuit = tag_transformers.remove_tags( # remove the temp tags used to track merges + symbolize.symbolize_single_qubit_gates_by_indexed_tags( + tag_transformers.index_tags( # index all 1-qubit-ops merged from ops with symbols + merged_circuits[0], + context=transformer_api.TransformerContext(deep=deep), + target_tags={symbolized_single_tag}, + ), + symbolize_tag=symbolize.SymbolizeTag(prefix=symbolized_single_tag), + ), + remove_if=lambda tag: str(tag).startswith(symbolized_single_tag), + ) + + # Step 3, get N sets of parameterizations as new_sweep. + new_sweep = Zip( + _calc_phxz_sweeps(new_circuit, merged_circuits), # phxz sweeps + _sweep_on_symbols(sweep, remaining_symbols), # remaining sweeps + ) + + return align.align_right(new_circuit), new_sweep diff --git a/cirq-core/cirq/transformers/merge_single_qubit_gates_test.py b/cirq-core/cirq/transformers/merge_single_qubit_gates_test.py index 8c22a40a4aa..dec9a85dc6e 100644 --- a/cirq-core/cirq/transformers/merge_single_qubit_gates_test.py +++ b/cirq-core/cirq/transformers/merge_single_qubit_gates_test.py @@ -12,7 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations +from unittest import TestCase +from unittest.mock import Mock, patch + +import pytest +import sympy import cirq @@ -81,7 +85,7 @@ def test_merge_single_qubit_gates_to_phased_x_and_z_deep(): cirq.testing.assert_same_circuits(c_new, c_expected) -def _phxz(a: float, x: float, z: float): +def _phxz(a: float | sympy.Symbol, x: float | sympy.Symbol, z: float | sympy.Symbol): return cirq.PhasedXZGate(axis_phase_exponent=a, x_exponent=x, z_exponent=z) @@ -233,6 +237,142 @@ def test_merge_single_qubit_gates_to_phased_x_and_z_global_phase(): assert c == c2 +class TestMergeSingleQubitGatesSymbolized(TestCase): + """Test suite for merge_single_qubit_gates_to_phxz_symbolized.""" + + def test_case1(self): + """Test case diagram. + Input circuit: + 0: ───X─────────@────────H[ignore]─H──X──PhXZ(a=a0,x=x0,z=z0)──X──PhXZ(a=a1,x=x1,z=z1)─── + │ + 1: ───H^h_exp───@^cz_exp───────────────────────────────────────────────────────────────── + Expected output: + 0: ───PhXZ(a=-1,x=1,z=0)─────@──────────H[ignore]───PhXZ(a=a1,x=x1,z=z1)─── + │ + 1: ───PhXZ(a=a0,x=x0,z=z0)───@^cz_exp────────────────────────────────────── + """ + a, b = cirq.LineQubit.range(2) + sa0, sa1 = [sympy.Symbol(a) for a in ["a0", "a1"]] + sx0, sx1 = [sympy.Symbol(x) for x in ["x0", "x1"]] + sz0, sz1 = [sympy.Symbol(z) for z in ["z0", "z1"]] + input_circuit = cirq.Circuit( + cirq.Moment(cirq.X(a), cirq.H(b) ** sympy.Symbol("h_exp")), + cirq.Moment(cirq.CZ(a, b) ** sympy.Symbol("cz_exp")), + cirq.Moment(cirq.H(a).with_tags("ignore")), + cirq.Moment(cirq.H(a)), + cirq.Moment(cirq.X(a)), + cirq.Moment(_phxz(sa0, sx0, sz0).on(a)), + cirq.Moment(cirq.X(a)), + cirq.Moment(_phxz(sa1, sx1, sz1).on(a)), + ) + context = cirq.TransformerContext(tags_to_ignore=["ignore"]) + sweep = cirq.Zip( + cirq.Points(key="h_exp", points=[0, 1]), + cirq.Points(key="cz_exp", points=[0, 1]), + cirq.Points(key="a0", points=[0, 1]), + cirq.Points(key="x0", points=[0, 1]), + cirq.Points(key="z0", points=[0, 1]), + cirq.Points(key="a1", points=[0, 1]), + cirq.Points(key="x1", points=[0, 1]), + cirq.Points(key="z1", points=[0, 1]), + ) + output_circuit, new_sweep = cirq.merge_single_qubit_gates_to_phxz_symbolized( + input_circuit, context=context, sweep=sweep + ) + expected = cirq.Circuit( + cirq.Moment(_phxz(-1, 1, 0).on(a), _phxz(sa0, sx0, sz0).on(b)), + cirq.Moment(cirq.CZ(a, b) ** sympy.Symbol("cz_exp")), + cirq.Moment(cirq.H(a).with_tags("ignore")), + cirq.Moment(_phxz(sa1, sx1, sz1).on(a)), + ) + assert_optimizes(output_circuit, expected) + + # Check the unitaries are preserved for each set of sweep paramerization. + for old_resolver, new_resolver in zip(sweep, new_sweep): + cirq.testing.assert_circuits_have_same_unitary_given_final_permutation( + cirq.resolve_parameters(input_circuit, old_resolver), + cirq.resolve_parameters(output_circuit, new_resolver), + {q: q for q in input_circuit.all_qubits()}, + ) + + def test_with_gauge_compiling_as_sweep_success(self): + qubits = cirq.LineQubit.range(7) + c = cirq.Circuit( + cirq.Moment(cirq.H(qubits[0]), cirq.H(qubits[3])), + cirq.Moment(cirq.CZ(qubits[0], qubits[2]), cirq.CZ(qubits[3], qubits[5])), + cirq.Moment(cirq.CZ(qubits[0], qubits[1]), cirq.CZ(qubits[3], qubits[4])), + cirq.Moment(cirq.CZ(qubits[1], qubits[3]), cirq.CZ(qubits[4], qubits[6])), + cirq.Moment(cirq.M(*qubits, key='m')), + ) + old_circuit, old_sweep = cirq.transformers.gauge_compiling.CZGaugeTransformer.as_sweep( + c, N=50 + ) + new_circuit, new_sweep = cirq.merge_single_qubit_gates_to_phxz_symbolized( + old_circuit, sweep=old_sweep + ) + # Check the unitaries are preserved for each set of sweep paramerization. + for old_resolver, new_resolver in zip(old_sweep, new_sweep): + cirq.testing.assert_circuits_have_same_unitary_given_final_permutation( + cirq.resolve_parameters(old_circuit[0:-1], old_resolver), + cirq.resolve_parameters(new_circuit[0:-1], new_resolver), + {q: q for q in qubits}, + ) + + def test_case_non_parameterized_singles(self): + """Test merge_single_qubit_gates_to_phxz_symbolized when all single qubit gates are not + parameterized.""" + + a, b = cirq.LineQubit.range(2) + input_circuit = cirq.Circuit(cirq.H(a), cirq.H(a), cirq.CZ(a, b) ** sympy.Symbol("exp")) + expected_circuit = cirq.merge_single_qubit_gates_to_phxz(input_circuit) + output_circuit, _ = cirq.merge_single_qubit_gates_to_phxz_symbolized( + input_circuit, sweep=cirq.Points(key="exp", points=[0.1, 0.2, 0.5]) + ) + assert_optimizes(output_circuit, expected_circuit) + + def test_fail_different_structures_error(self): + """Tests that the function raises a ValueError if merged structures of the circuit differ + for different parameterizations.""" + q0, q1 = cirq.LineQubit.range(2) + circuit = cirq.Circuit(cirq.H(q0) ** sympy.Symbol("exp")) + sweep = cirq.Points(key="exp", points=[0.1, 0.2]) + + with patch( + "cirq.protocols.resolve_parameters", + side_effect=[ # Mock the return values of resolve_parameters + cirq.Circuit(cirq.I(q0).with_tags("_tmp_symbolize_tag")), + cirq.Circuit(cirq.CZ(q0, q1)), + ], + ): + with pytest.raises(ValueError, match="Expect a PhasedXZGate or IdentityGate.*"): + cirq.merge_single_qubit_gates_to_phxz_symbolized(circuit, sweep=sweep) + + def test_fail_unexpected_gate_error(self): + """Tests that the function raises a RuntimeError of unexpected gate.""" + a, b = cirq.LineQubit.range(2) + circuit = cirq.Circuit( + cirq.H(a) ** sympy.Symbol("exp1"), + cirq.X(a), + cirq.CZ(a, b), + cirq.Y(a), + cirq.H(a) ** sympy.Symbol("exp2"), + ) + sweep = cirq.Points(key="exp1", points=[0.1, 0.2]) * cirq.Points( + key="exp2", points=[0.1, 0.2] + ) + + mock_iter = Mock() + mock_iter.__next__ = Mock(return_value=2) + + with patch( + "cirq.transformers.analytical_decompositions" + ".single_qubit_decompositions.single_qubit_matrix_to_phxz", + return_value=cirq.H, + ): + with pytest.raises(ValueError, match="Expect a PhasedXZGate or IdentityGate.*"): + cirq.merge_single_qubit_gates_to_phxz_symbolized(circuit, sweep=sweep) + + def test_merge_single_qubit_moments_to_phxz_with_global_phase_in_first_moment(): q0 = cirq.LineQubit(0) c_orig = cirq.Circuit(