From 4f9f30652c31d4d2a3454e6619f998c747d075f5 Mon Sep 17 00:00:00 2001 From: daxfo Date: Wed, 28 May 2025 10:17:20 -0700 Subject: [PATCH 1/5] Extract global phase from controlled common gates during decomposition --- cirq-core/cirq/ops/common_gates.py | 50 +++++++++++- cirq-core/cirq/ops/controlled_gate.py | 76 ++++++------------- cirq-core/cirq/ops/global_phase_op.py | 19 +++++ .../cirq/protocols/decompose_protocol.py | 7 ++ 4 files changed, 98 insertions(+), 54 deletions(-) diff --git a/cirq-core/cirq/ops/common_gates.py b/cirq-core/cirq/ops/common_gates.py index f2823d2a3bd..105efb114e1 100644 --- a/cirq-core/cirq/ops/common_gates.py +++ b/cirq-core/cirq/ops/common_gates.py @@ -37,7 +37,14 @@ from cirq import protocols, value from cirq._compat import proper_repr from cirq._doc import document -from cirq.ops import control_values as cv, controlled_gate, eigen_gate, gate_features, raw_types +from cirq.ops import ( + control_values as cv, + controlled_gate, + eigen_gate, + gate_features, + global_phase_op, + raw_types, +) from cirq.ops.measurement_gate import MeasurementGate from cirq.ops.swap_gates import ISWAP, ISwapPowGate, SWAP, SwapPowGate @@ -235,6 +242,11 @@ def controlled( return cirq.CCXPowGate(exponent=self._exponent) return result + def _decompose_with_context_( + self, qubits: tuple[cirq.Qid, ...], context: cirq.DecompositionContext + ) -> list[cirq.Operation] | NotImplementedType: + return _extract_phase(self, XPowGate, qubits, context) + def _pauli_expansion_(self) -> value.LinearDict[str]: if self._dimension != 2: return NotImplemented # pragma: no cover @@ -487,6 +499,11 @@ def __repr__(self) -> str: f'global_shift={self._global_shift!r})' ) + def _decompose_with_context_( + self, qubits: tuple[cirq.Qid, ...], context: cirq.DecompositionContext + ) -> list[cirq.Operation] | NotImplementedType: + return _extract_phase(self, YPowGate, qubits, context) + class Ry(YPowGate): r"""A gate with matrix $e^{-i Y t/2}$ that rotates around the Y axis of the Bloch sphere by $t$. @@ -699,6 +716,11 @@ def controlled( return cirq.CCZPowGate(exponent=self._exponent) return result + def _decompose_with_context_( + self, qubits: tuple[cirq.Qid, ...], context: cirq.DecompositionContext + ) -> list[cirq.Operation] | NotImplementedType: + return _extract_phase(self, ZPowGate, qubits, context) + def _qid_shape_(self) -> tuple[int, ...]: return (self._dimension,) @@ -1131,6 +1153,11 @@ def controlled( control_qid_shape=result.control_qid_shape + (2,), ) + def _decompose_with_context_( + self, qubits: tuple[cirq.Qid, ...], context: cirq.DecompositionContext + ) -> list[cirq.Operation] | NotImplementedType: + return _extract_phase(self, CZPowGate, qubits, context) + def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.CircuitDiagramInfo: return protocols.CircuitDiagramInfo( wire_symbols=('@', '@'), exponent=self._diagram_exponent(args) @@ -1486,3 +1513,24 @@ def _phased_x_or_pauli_gate( case 0.5: return YPowGate(exponent=exponent) return cirq.ops.PhasedXPowGate(exponent=exponent, phase_exponent=phase_exponent) + + +def _extract_phase( + gate: cirq.EigenGate, + gate_class: type, + qubits: tuple[cirq.Qid, ...], + context: cirq.DecompositionContext, +) -> list[cirq.Operation] | NotImplementedType: + """Extracts the global phase field to its own gate, or absorbs it if it has no effect. + + This is for use within the decompose handlers, and will return `NotImplemented` if there is no + global phase, implying it is already in its simplest form. It will return a list, with the + global phase first, and the original op minus any global phase op second. If the resulting + global phase is empty (can happen for example in `XPowGate(global_phase=2/3)**3`), then it is + excluded from the return value.""" + if not context.extract_global_phases or gate.global_shift == 0: + return NotImplemented + phase_gate = global_phase_op.from_phase_and_exponent(gate.global_shift, gate.exponent) + return ([] if phase_gate.is_identity else [phase_gate()]) + [ + gate_class(exponent=gate.exponent).on(*qubits) + ] diff --git a/cirq-core/cirq/ops/controlled_gate.py b/cirq-core/cirq/ops/controlled_gate.py index 4ed6a03b876..9813c67d3ea 100644 --- a/cirq-core/cirq/ops/controlled_gate.py +++ b/cirq-core/cirq/ops/controlled_gate.py @@ -24,7 +24,6 @@ control_values as cv, controlled_operation as cop, diagonal_gate as dg, - global_phase_op as gp, op_tree, raw_types, ) @@ -139,12 +138,9 @@ def num_controls(self) -> int: def _qid_shape_(self) -> tuple[int, ...]: return self.control_qid_shape + protocols.qid_shape(self.sub_gate) - def _decompose_(self, qubits: tuple[cirq.Qid, ...]) -> None | NotImplementedType | cirq.OP_TREE: - return self._decompose_with_context_(qubits) - def _decompose_with_context_( - self, qubits: tuple[cirq.Qid, ...], context: cirq.DecompositionContext | None = None - ) -> None | NotImplementedType | cirq.OP_TREE: + self, qubits: tuple[cirq.Qid, ...], context: cirq.DecompositionContext + ) -> NotImplementedType | cirq.OP_TREE: control_qubits = list(qubits[: self.num_controls()]) controlled_sub_gate = self.sub_gate.controlled( self.num_controls(), self.control_values, self.control_qid_shape @@ -152,6 +148,25 @@ def _decompose_with_context_( # Prefer the subgate controlled version if available if self != controlled_sub_gate: return controlled_sub_gate.on(*qubits) + + # Try decomposing the subgate next. + # Extract global phases from decomposition, as controlled phases decompose easily. + result = protocols.decompose_once_with_qubits( + self.sub_gate, + qubits[self.num_controls() :], + NotImplemented, + flatten=False, + context=context.extracting_global_phases(), + ) + if result is not NotImplemented: + return op_tree.transform_op_tree( + result, + lambda op: op.controlled_by( + *qubits[: self.num_controls()], control_values=self.control_values + ), + ) + + # Finally try brute-force on the unitary. if protocols.has_unitary(self.sub_gate) and all(q.dimension == 2 for q in qubits): n_qubits = protocols.num_qubits(self.sub_gate) # Case 1: Global Phase (1x1 Matrix) @@ -173,54 +188,9 @@ def _decompose_with_context_( protocols.unitary(self.sub_gate), control_qubits, qubits[-1] ) return invert_ops + decomposed_ops + invert_ops - if isinstance(self.sub_gate, common_gates.CZPowGate): - z_sub_gate = common_gates.ZPowGate(exponent=self.sub_gate.exponent) - num_controls = self.num_controls() + 1 - control_values = self.control_values & cv.ProductOfSums(((1,),)) - control_qid_shape = self.control_qid_shape + (2,) - controlled_z = ( - z_sub_gate.controlled( - num_controls=num_controls, - control_values=control_values, - control_qid_shape=control_qid_shape, - ) - if protocols.is_parameterized(self) - else ControlledGate( - z_sub_gate, - num_controls=num_controls, - control_values=control_values, - control_qid_shape=control_qid_shape, - ) - ) - if self != controlled_z: - result = controlled_z.on(*qubits) - if self.sub_gate.global_shift == 0: - return result - # Reconstruct the controlled global shift of the subgate. - total_shift = self.sub_gate.exponent * self.sub_gate.global_shift - phase_gate = gp.GlobalPhaseGate(1j ** (2 * total_shift)) - controlled_phase_op = phase_gate.controlled( - num_controls=self.num_controls(), - control_values=self.control_values, - control_qid_shape=self.control_qid_shape, - ).on(*control_qubits) - return [result, controlled_phase_op] - result = protocols.decompose_once_with_qubits( - self.sub_gate, - qubits[self.num_controls() :], - NotImplemented, - flatten=False, - context=context, - ) - if result is NotImplemented: - return NotImplemented - return op_tree.transform_op_tree( - result, - lambda op: op.controlled_by( - *qubits[: self.num_controls()], control_values=self.control_values - ), - ) + # If nothing works, return `NotImplemented`. + return NotImplemented def on(self, *qubits: cirq.Qid) -> cop.ControlledOperation: if len(qubits) == 0: diff --git a/cirq-core/cirq/ops/global_phase_op.py b/cirq-core/cirq/ops/global_phase_op.py index bef5789ed1d..4b997627c1d 100644 --- a/cirq-core/cirq/ops/global_phase_op.py +++ b/cirq-core/cirq/ops/global_phase_op.py @@ -92,6 +92,12 @@ def _resolve_parameters_( coefficient = protocols.resolve_parameters(self.coefficient, resolver, recursive) return GlobalPhaseGate(coefficient=coefficient) + @property + def is_identity(self) -> bool: + return ( + not protocols.is_parameterized(self._coefficient) and abs(self._coefficient - 1.0) == 0 + ) + def controlled( self, num_controls: int | None = None, @@ -122,3 +128,16 @@ def global_phase_operation( ) -> cirq.GateOperation: """Creates an operation that represents a global phase on the state.""" return GlobalPhaseGate(coefficient, atol)() + + +def from_phase_and_exponent( + half_turns: 'cirq.TParamVal', exponent: 'cirq.TParamVal' +) -> 'cirq.GlobalPhaseGate': + """Creates a GlobalPhaseGate from the global phase and exponent.""" + coefficient = 1j ** (2 * half_turns * exponent) + coefficient = ( + complex(coefficient) + if isinstance(coefficient, sympy.Expr) and coefficient.is_complex + else coefficient + ) + return GlobalPhaseGate(coefficient) diff --git a/cirq-core/cirq/protocols/decompose_protocol.py b/cirq-core/cirq/protocols/decompose_protocol.py index 3d552df311f..369105ad5bf 100644 --- a/cirq-core/cirq/protocols/decompose_protocol.py +++ b/cirq-core/cirq/protocols/decompose_protocol.py @@ -81,9 +81,16 @@ class DecompositionContext: Args: qubit_manager: A `cirq.QubitManager` instance to allocate clean / dirty ancilla qubits as part of the decompose protocol. + extract_global_phases: If set, will extract the global phases from + `DECOMPOSE_TARGET_GATESET` into independent global phase operations. """ qubit_manager: cirq.QubitManager + extract_global_phases: bool = False + + def extracting_global_phases(self) -> DecompositionContext: + """Returns a copy with the `extract_global_phases` field set.""" + return dataclasses.replace(self, extract_global_phases=True) class SupportsDecompose(Protocol): From bc34fbbce73d89fd4b1e6acc4afb838a539a89d1 Mon Sep 17 00:00:00 2001 From: daxfo Date: Wed, 28 May 2025 16:05:06 -0700 Subject: [PATCH 2/5] Add tests --- cirq-core/cirq/ops/common_gates.py | 9 +++--- cirq-core/cirq/ops/common_gates_test.py | 29 +++++++++++++++++++ cirq-core/cirq/ops/controlled_gate_test.py | 16 ++++++++++ cirq-core/cirq/ops/global_phase_op_test.py | 20 +++++++++++++ .../cirq/protocols/decompose_protocol_test.py | 9 ++++++ 5 files changed, 79 insertions(+), 4 deletions(-) diff --git a/cirq-core/cirq/ops/common_gates.py b/cirq-core/cirq/ops/common_gates.py index 105efb114e1..0ec68275122 100644 --- a/cirq-core/cirq/ops/common_gates.py +++ b/cirq-core/cirq/ops/common_gates.py @@ -1525,12 +1525,13 @@ def _extract_phase( This is for use within the decompose handlers, and will return `NotImplemented` if there is no global phase, implying it is already in its simplest form. It will return a list, with the - global phase first, and the original op minus any global phase op second. If the resulting + original op minus any global phase first, and the global phase op second. If the resulting global phase is empty (can happen for example in `XPowGate(global_phase=2/3)**3`), then it is excluded from the return value.""" if not context.extract_global_phases or gate.global_shift == 0: return NotImplemented + result = [gate_class(exponent=gate.exponent).on(*qubits)] phase_gate = global_phase_op.from_phase_and_exponent(gate.global_shift, gate.exponent) - return ([] if phase_gate.is_identity else [phase_gate()]) + [ - gate_class(exponent=gate.exponent).on(*qubits) - ] + if not phase_gate.is_identity: + result.append(phase_gate()) + return result diff --git a/cirq-core/cirq/ops/common_gates_test.py b/cirq-core/cirq/ops/common_gates_test.py index d615ddb29b1..89a0a9a625f 100644 --- a/cirq-core/cirq/ops/common_gates_test.py +++ b/cirq-core/cirq/ops/common_gates_test.py @@ -1322,3 +1322,32 @@ def test_parameterized_pauli_expansion(gate_type, exponent): gate_resolved = cirq.resolve_parameters(gate, {'s': 0.5}) pauli_resolved = cirq.resolve_parameters(pauli, {'s': 0.5}) assert cirq.approx_eq(pauli_resolved, cirq.pauli_expansion(gate_resolved)) + + +@pytest.mark.parametrize('gate_type', [cirq.XPowGate, cirq.YPowGate, cirq.ZPowGate, cirq.CZPowGate]) +@pytest.mark.parametrize('exponent', [0, 0.5, 2, 3, -0.5, -2, -3, sympy.Symbol('s')]) +def test_decompose_with_extracted_phases(gate_type: type, exponent: cirq.TParamVal) -> None: + context = cirq.DecompositionContext(cirq.SimpleQubitManager(), extract_global_phases=True) + gate = gate_type(exponent=exponent, global_shift=2 / 3) + op = gate.on(*cirq.LineQubit.range(cirq.num_qubits(gate))) + decomposed = cirq.decompose(op, context=context) + gate0 = decomposed[0].gate + assert isinstance(gate0, gate_type) + assert gate0.global_shift == 0 + assert gate0.exponent == exponent + if exponent * 2 / 3 % 2 != 0: + assert len(decomposed) == 2 + gate1 = decomposed[1].gate + assert isinstance(gate1, cirq.GlobalPhaseGate) + assert gate1.coefficient == 1j ** (exponent * (4 / 3)) + else: + assert len(decomposed) == 1 + decomposed_circuit = cirq.Circuit(decomposed) + if not cirq.is_parameterized(exponent): + np.testing.assert_allclose(cirq.unitary(op), cirq.unitary(decomposed_circuit), atol=1e-10) + else: + resolver = {'s': -1.234} + np.testing.assert_allclose( + cirq.final_state_vector(cirq.Circuit(op), param_resolver=resolver), + cirq.final_state_vector(decomposed_circuit, param_resolver=resolver), + ) diff --git a/cirq-core/cirq/ops/controlled_gate_test.py b/cirq-core/cirq/ops/controlled_gate_test.py index f970f524688..8f2a91bcc5e 100644 --- a/cirq-core/cirq/ops/controlled_gate_test.py +++ b/cirq-core/cirq/ops/controlled_gate_test.py @@ -802,3 +802,19 @@ def test_controlled_global_phase_matrix_gate_decomposes(num_controls, angle, con decomposed = cirq.decompose(cg_matrix(*all_qubits)) assert not any(isinstance(op.gate, cirq.MatrixGate) for op in decomposed) np.testing.assert_allclose(cirq.unitary(cirq.Circuit(decomposed)), cirq.unitary(cg_matrix)) + + +def test_simplified_controlled_phased_eigengate_decomposition() -> None: + q0, q1 = cirq.LineQubit.range(2) + + # Z gate + op = cirq.ZPowGate(global_shift=0.22).controlled().on(q0, q1) + ops = cirq.decompose(op) + assert ops == [cirq.CZ(q0, q1), cirq.Z(q0) ** 0.22] + np.testing.assert_allclose(cirq.unitary(op), cirq.unitary(cirq.Circuit(ops))) + + # X gate + op = cirq.XPowGate(global_shift=0.22).controlled().on(q0, q1) + ops = cirq.decompose(op) + assert ops == [cirq.Y(q1) ** -0.5, cirq.CZ(q0, q1), cirq.Y(q1) ** 0.5, cirq.Z(q0) ** 0.22] + np.testing.assert_allclose(cirq.unitary(op), cirq.unitary(cirq.Circuit(ops))) diff --git a/cirq-core/cirq/ops/global_phase_op_test.py b/cirq-core/cirq/ops/global_phase_op_test.py index 360ed866c0c..97aa4f0b38f 100644 --- a/cirq-core/cirq/ops/global_phase_op_test.py +++ b/cirq-core/cirq/ops/global_phase_op_test.py @@ -19,6 +19,7 @@ import sympy import cirq +from cirq.ops import global_phase_op def test_init(): @@ -302,3 +303,22 @@ def test_global_phase_gate_controlled(coeff, exp): assert g.controlled(control_values=xor_control_values) == cirq.ControlledGate( g, control_values=xor_control_values ) + + +def test_is_identity() -> None: + g = cirq.GlobalPhaseGate(1) + assert g.is_identity + g = cirq.GlobalPhaseGate(1j) + assert not g.is_identity + g = cirq.GlobalPhaseGate(-1) + assert not g.is_identity + + +def test_from_phase_and_exponent() -> None: + g = global_phase_op.from_phase_and_exponent(2.5, 0.5) + assert g.coefficient == np.exp(1.25j * np.pi) + a, b = sympy.symbols('a, b') + g = global_phase_op.from_phase_and_exponent(a, b) + assert g.coefficient == 1j ** (2 * a * b) + g = global_phase_op.from_phase_and_exponent(1 / a, a) + assert g.coefficient == -1 diff --git a/cirq-core/cirq/protocols/decompose_protocol_test.py b/cirq-core/cirq/protocols/decompose_protocol_test.py index b2493c4719a..b9a0c6d615e 100644 --- a/cirq-core/cirq/protocols/decompose_protocol_test.py +++ b/cirq-core/cirq/protocols/decompose_protocol_test.py @@ -445,3 +445,12 @@ def test_decompose_without_context_succeed() -> None: cirq.ops.CleanQubit(1, prefix='_decompose_protocol'), ) ] + + +def test_extracting_global_phases() -> None: + qm = cirq.SimpleQubitManager() + context = cirq.DecompositionContext(qm) + new_context = context.extracting_global_phases() + assert not context.extract_global_phases + assert new_context.extract_global_phases + assert new_context.qubit_manager is qm From 51815a6ba8b910f3063f2f9eaa2715c222b9d476 Mon Sep 17 00:00:00 2001 From: daxfo Date: Wed, 28 May 2025 21:35:58 -0700 Subject: [PATCH 3/5] A couple nits --- cirq-core/cirq/ops/common_gates_test.py | 1 + cirq-core/cirq/ops/global_phase_op.py | 4 ++-- cirq-core/cirq/ops/three_qubit_gates.py | 16 ++++------------ 3 files changed, 7 insertions(+), 14 deletions(-) diff --git a/cirq-core/cirq/ops/common_gates_test.py b/cirq-core/cirq/ops/common_gates_test.py index 89a0a9a625f..bc9c4ea2d93 100644 --- a/cirq-core/cirq/ops/common_gates_test.py +++ b/cirq-core/cirq/ops/common_gates_test.py @@ -1333,6 +1333,7 @@ def test_decompose_with_extracted_phases(gate_type: type, exponent: cirq.TParamV decomposed = cirq.decompose(op, context=context) gate0 = decomposed[0].gate assert isinstance(gate0, gate_type) + assert isinstance(gate0, cirq.EigenGate) assert gate0.global_shift == 0 assert gate0.exponent == exponent if exponent * 2 / 3 % 2 != 0: diff --git a/cirq-core/cirq/ops/global_phase_op.py b/cirq-core/cirq/ops/global_phase_op.py index 4b997627c1d..897dc8ef0e5 100644 --- a/cirq-core/cirq/ops/global_phase_op.py +++ b/cirq-core/cirq/ops/global_phase_op.py @@ -131,8 +131,8 @@ def global_phase_operation( def from_phase_and_exponent( - half_turns: 'cirq.TParamVal', exponent: 'cirq.TParamVal' -) -> 'cirq.GlobalPhaseGate': + half_turns: cirq.TParamVal, exponent: cirq.TParamVal +) -> cirq.GlobalPhaseGate: """Creates a GlobalPhaseGate from the global phase and exponent.""" coefficient = 1j ** (2 * half_turns * exponent) coefficient = ( diff --git a/cirq-core/cirq/ops/three_qubit_gates.py b/cirq-core/cirq/ops/three_qubit_gates.py index 7d1509a4756..9262f70dea5 100644 --- a/cirq-core/cirq/ops/three_qubit_gates.py +++ b/cirq-core/cirq/ops/three_qubit_gates.py @@ -107,19 +107,11 @@ def _decompose_(self, qubits): elif not b.is_adjacent(c): a, b = b, a - p = common_gates.T**self._exponent + exp = self._exponent + p = common_gates.T**exp sweep_abc = [common_gates.CNOT(a, b), common_gates.CNOT(b, c)] - global_phase = 1j ** (2 * self.global_shift * self._exponent) - global_phase = ( - complex(global_phase) - if protocols.is_parameterized(global_phase) and global_phase.is_complex - else global_phase - ) - global_phase_operation = ( - [global_phase_op.global_phase_operation(global_phase)] - if protocols.is_parameterized(global_phase) or abs(global_phase - 1.0) > 0 - else [] - ) + global_phase_gate = global_phase_op.from_phase_and_exponent(self.global_shift, exp) + global_phase_operation = [] if global_phase_gate.is_identity else [global_phase_gate()] return global_phase_operation + [ p(a), p(b), From 486e2eff50f1722859ecf00d16ba2cacfc3f4dcf Mon Sep 17 00:00:00 2001 From: daxfo Date: Mon, 16 Jun 2025 11:10:39 -0700 Subject: [PATCH 4/5] A couple nits --- cirq-core/cirq/ops/common_gates.py | 2 +- cirq-core/cirq/ops/common_gates_test.py | 7 +++-- cirq-core/cirq/ops/controlled_gate.py | 2 +- cirq-core/cirq/ops/controlled_gate_test.py | 33 +++++++++++++--------- cirq-core/cirq/ops/global_phase_op.py | 6 ---- cirq-core/cirq/ops/global_phase_op_test.py | 9 ------ cirq-core/cirq/ops/three_qubit_gates.py | 2 +- 7 files changed, 26 insertions(+), 35 deletions(-) diff --git a/cirq-core/cirq/ops/common_gates.py b/cirq-core/cirq/ops/common_gates.py index 0ec68275122..48c720cc65f 100644 --- a/cirq-core/cirq/ops/common_gates.py +++ b/cirq-core/cirq/ops/common_gates.py @@ -1532,6 +1532,6 @@ def _extract_phase( return NotImplemented result = [gate_class(exponent=gate.exponent).on(*qubits)] phase_gate = global_phase_op.from_phase_and_exponent(gate.global_shift, gate.exponent) - if not phase_gate.is_identity: + if phase_gate.coefficient != 1: result.append(phase_gate()) return result diff --git a/cirq-core/cirq/ops/common_gates_test.py b/cirq-core/cirq/ops/common_gates_test.py index bc9c4ea2d93..f8d27ece290 100644 --- a/cirq-core/cirq/ops/common_gates_test.py +++ b/cirq-core/cirq/ops/common_gates_test.py @@ -1328,7 +1328,8 @@ def test_parameterized_pauli_expansion(gate_type, exponent): @pytest.mark.parametrize('exponent', [0, 0.5, 2, 3, -0.5, -2, -3, sympy.Symbol('s')]) def test_decompose_with_extracted_phases(gate_type: type, exponent: cirq.TParamVal) -> None: context = cirq.DecompositionContext(cirq.SimpleQubitManager(), extract_global_phases=True) - gate = gate_type(exponent=exponent, global_shift=2 / 3) + test_shift = 2 / 3 # Interesting because e.g. X(shift=2/3) ** 3 == X with no phase + gate = gate_type(exponent=exponent, global_shift=test_shift) op = gate.on(*cirq.LineQubit.range(cirq.num_qubits(gate))) decomposed = cirq.decompose(op, context=context) gate0 = decomposed[0].gate @@ -1336,11 +1337,11 @@ def test_decompose_with_extracted_phases(gate_type: type, exponent: cirq.TParamV assert isinstance(gate0, cirq.EigenGate) assert gate0.global_shift == 0 assert gate0.exponent == exponent - if exponent * 2 / 3 % 2 != 0: + if exponent * test_shift % 2 != 0: assert len(decomposed) == 2 gate1 = decomposed[1].gate assert isinstance(gate1, cirq.GlobalPhaseGate) - assert gate1.coefficient == 1j ** (exponent * (4 / 3)) + assert gate1.coefficient == 1j ** (2 * exponent * test_shift) else: assert len(decomposed) == 1 decomposed_circuit = cirq.Circuit(decomposed) diff --git a/cirq-core/cirq/ops/controlled_gate.py b/cirq-core/cirq/ops/controlled_gate.py index 9813c67d3ea..bd2967dddb9 100644 --- a/cirq-core/cirq/ops/controlled_gate.py +++ b/cirq-core/cirq/ops/controlled_gate.py @@ -150,12 +150,12 @@ def _decompose_with_context_( return controlled_sub_gate.on(*qubits) # Try decomposing the subgate next. - # Extract global phases from decomposition, as controlled phases decompose easily. result = protocols.decompose_once_with_qubits( self.sub_gate, qubits[self.num_controls() :], NotImplemented, flatten=False, + # Extract global phases from decomposition, as controlled phases decompose easily. context=context.extracting_global_phases(), ) if result is not NotImplemented: diff --git a/cirq-core/cirq/ops/controlled_gate_test.py b/cirq-core/cirq/ops/controlled_gate_test.py index 8f2a91bcc5e..84d967b652f 100644 --- a/cirq-core/cirq/ops/controlled_gate_test.py +++ b/cirq-core/cirq/ops/controlled_gate_test.py @@ -804,17 +804,22 @@ def test_controlled_global_phase_matrix_gate_decomposes(num_controls, angle, con np.testing.assert_allclose(cirq.unitary(cirq.Circuit(decomposed)), cirq.unitary(cg_matrix)) -def test_simplified_controlled_phased_eigengate_decomposition() -> None: - q0, q1 = cirq.LineQubit.range(2) - - # Z gate - op = cirq.ZPowGate(global_shift=0.22).controlled().on(q0, q1) - ops = cirq.decompose(op) - assert ops == [cirq.CZ(q0, q1), cirq.Z(q0) ** 0.22] - np.testing.assert_allclose(cirq.unitary(op), cirq.unitary(cirq.Circuit(ops))) - - # X gate - op = cirq.XPowGate(global_shift=0.22).controlled().on(q0, q1) - ops = cirq.decompose(op) - assert ops == [cirq.Y(q1) ** -0.5, cirq.CZ(q0, q1), cirq.Y(q1) ** 0.5, cirq.Z(q0) ** 0.22] - np.testing.assert_allclose(cirq.unitary(op), cirq.unitary(cirq.Circuit(ops))) +@pytest.mark.parametrize('gate_type', [cirq.XPowGate, cirq.YPowGate, cirq.ZPowGate, cirq.CZPowGate]) +def test_controlled_phase_extracted_before_decomposition(gate_type) -> None: + test_shift = 0.123 # arbitrary + + shifted_gate = gate_type(global_shift=test_shift).controlled() + unshifted_gate = gate_type().controlled() + qs = cirq.LineQubit.range(cirq.num_qubits(shifted_gate)) + shifted_op = shifted_gate.on(*qs) + unshifted_op = unshifted_gate.on(*qs) + shifted_decomposition = cirq.decompose(shifted_op) + unshifted_decomposition = cirq.decompose(unshifted_op) + + # No brute-force calculation. It's the standard decomposition plus Z for the controlled shift. + assert shifted_decomposition == unshifted_decomposition + [cirq.Z(qs[0]) ** test_shift] + + # Sanity check that the decomposition is equivalent + np.testing.assert_allclose( + cirq.unitary(cirq.Circuit(shifted_decomposition)), cirq.unitary(shifted_op), atol=1e-10 + ) diff --git a/cirq-core/cirq/ops/global_phase_op.py b/cirq-core/cirq/ops/global_phase_op.py index 897dc8ef0e5..fe20e78d47a 100644 --- a/cirq-core/cirq/ops/global_phase_op.py +++ b/cirq-core/cirq/ops/global_phase_op.py @@ -92,12 +92,6 @@ def _resolve_parameters_( coefficient = protocols.resolve_parameters(self.coefficient, resolver, recursive) return GlobalPhaseGate(coefficient=coefficient) - @property - def is_identity(self) -> bool: - return ( - not protocols.is_parameterized(self._coefficient) and abs(self._coefficient - 1.0) == 0 - ) - def controlled( self, num_controls: int | None = None, diff --git a/cirq-core/cirq/ops/global_phase_op_test.py b/cirq-core/cirq/ops/global_phase_op_test.py index 97aa4f0b38f..08186728808 100644 --- a/cirq-core/cirq/ops/global_phase_op_test.py +++ b/cirq-core/cirq/ops/global_phase_op_test.py @@ -305,15 +305,6 @@ def test_global_phase_gate_controlled(coeff, exp): ) -def test_is_identity() -> None: - g = cirq.GlobalPhaseGate(1) - assert g.is_identity - g = cirq.GlobalPhaseGate(1j) - assert not g.is_identity - g = cirq.GlobalPhaseGate(-1) - assert not g.is_identity - - def test_from_phase_and_exponent() -> None: g = global_phase_op.from_phase_and_exponent(2.5, 0.5) assert g.coefficient == np.exp(1.25j * np.pi) diff --git a/cirq-core/cirq/ops/three_qubit_gates.py b/cirq-core/cirq/ops/three_qubit_gates.py index 9262f70dea5..a246de59acb 100644 --- a/cirq-core/cirq/ops/three_qubit_gates.py +++ b/cirq-core/cirq/ops/three_qubit_gates.py @@ -111,7 +111,7 @@ def _decompose_(self, qubits): p = common_gates.T**exp sweep_abc = [common_gates.CNOT(a, b), common_gates.CNOT(b, c)] global_phase_gate = global_phase_op.from_phase_and_exponent(self.global_shift, exp) - global_phase_operation = [] if global_phase_gate.is_identity else [global_phase_gate()] + global_phase_operation = [] if global_phase_gate.coefficient == 1 else [global_phase_gate()] return global_phase_operation + [ p(a), p(b), From 72de57792eed13c0259c654fff7b08172d9d7806 Mon Sep 17 00:00:00 2001 From: daxfo Date: Mon, 16 Jun 2025 11:44:05 -0700 Subject: [PATCH 5/5] Fix test with rounding error --- cirq-core/cirq/ops/common_gates_test.py | 25 ++++++++++++---------- cirq-core/cirq/ops/controlled_gate_test.py | 8 ++++++- 2 files changed, 21 insertions(+), 12 deletions(-) diff --git a/cirq-core/cirq/ops/common_gates_test.py b/cirq-core/cirq/ops/common_gates_test.py index f8d27ece290..6d50582918b 100644 --- a/cirq-core/cirq/ops/common_gates_test.py +++ b/cirq-core/cirq/ops/common_gates_test.py @@ -1332,24 +1332,27 @@ def test_decompose_with_extracted_phases(gate_type: type, exponent: cirq.TParamV gate = gate_type(exponent=exponent, global_shift=test_shift) op = gate.on(*cirq.LineQubit.range(cirq.num_qubits(gate))) decomposed = cirq.decompose(op, context=context) + + # The first gate should be the original gate, but with shift removed. gate0 = decomposed[0].gate assert isinstance(gate0, gate_type) assert isinstance(gate0, cirq.EigenGate) assert gate0.global_shift == 0 assert gate0.exponent == exponent - if exponent * test_shift % 2 != 0: + if exponent % 3 == 0: + # Since test_shift == 2/3, gate**3 nullifies the phase, leaving only the unphased gate. + assert len(decomposed) == 1 + else: + # Other exponents emit a global phase gate to compensate. assert len(decomposed) == 2 gate1 = decomposed[1].gate assert isinstance(gate1, cirq.GlobalPhaseGate) assert gate1.coefficient == 1j ** (2 * exponent * test_shift) - else: - assert len(decomposed) == 1 + + # Sanity check that the decomposition is equivalent to the original. decomposed_circuit = cirq.Circuit(decomposed) - if not cirq.is_parameterized(exponent): - np.testing.assert_allclose(cirq.unitary(op), cirq.unitary(decomposed_circuit), atol=1e-10) - else: - resolver = {'s': -1.234} - np.testing.assert_allclose( - cirq.final_state_vector(cirq.Circuit(op), param_resolver=resolver), - cirq.final_state_vector(decomposed_circuit, param_resolver=resolver), - ) + if cirq.is_parameterized(exponent): + resolver = {'s': -1.234} # arbitrary + op = cirq.resolve_parameters(op, resolver) + decomposed_circuit = cirq.resolve_parameters(decomposed_circuit, resolver) + np.testing.assert_allclose(cirq.unitary(op), cirq.unitary(decomposed_circuit), atol=1e-10) diff --git a/cirq-core/cirq/ops/controlled_gate_test.py b/cirq-core/cirq/ops/controlled_gate_test.py index 84d967b652f..71042e3eaec 100644 --- a/cirq-core/cirq/ops/controlled_gate_test.py +++ b/cirq-core/cirq/ops/controlled_gate_test.py @@ -817,7 +817,13 @@ def test_controlled_phase_extracted_before_decomposition(gate_type) -> None: unshifted_decomposition = cirq.decompose(unshifted_op) # No brute-force calculation. It's the standard decomposition plus Z for the controlled shift. - assert shifted_decomposition == unshifted_decomposition + [cirq.Z(qs[0]) ** test_shift] + assert shifted_decomposition[:-1] == unshifted_decomposition + z_op = shifted_decomposition[-1] + assert z_op.qubits == (qs[0],) + z = z_op.gate + assert isinstance(z, cirq.ZPowGate) + np.testing.assert_approx_equal(z.exponent, test_shift) + assert z.global_shift == 0 # Sanity check that the decomposition is equivalent np.testing.assert_allclose(