From 551b87dc2edabb47e554d78d0564de38aa75b0e4 Mon Sep 17 00:00:00 2001 From: Codrut Date: Sun, 13 Apr 2025 18:28:16 +0200 Subject: [PATCH 1/4] Attempt canonicalization first when decomposing controlled gates. --- cirq-core/cirq/ops/controlled_gate.py | 6 +++++ cirq-core/cirq/ops/controlled_gate_test.py | 30 ++++++++++++++++++++++ 2 files changed, 36 insertions(+) diff --git a/cirq-core/cirq/ops/controlled_gate.py b/cirq-core/cirq/ops/controlled_gate.py index e47602ac942..cef3e88a4a9 100644 --- a/cirq-core/cirq/ops/controlled_gate.py +++ b/cirq-core/cirq/ops/controlled_gate.py @@ -159,6 +159,12 @@ def _decompose_with_context_( self, qubits: Tuple['cirq.Qid', ...], context: Optional['cirq.DecompositionContext'] = None ) -> Union[None, NotImplementedType, 'cirq.OP_TREE']: control_qubits = list(qubits[: self.num_controls()]) + controlled_sub_gate = self.sub_gate.controlled( + self.num_controls(), self.control_values, self.control_qid_shape + ) + # Prefer the subgate controlled version if available + if controlled_sub_gate.__class__ != self.__class__: + return controlled_sub_gate.on(*qubits) if ( protocols.has_unitary(self.sub_gate) and protocols.num_qubits(self.sub_gate) == 1 diff --git a/cirq-core/cirq/ops/controlled_gate_test.py b/cirq-core/cirq/ops/controlled_gate_test.py index ebff6b9c709..ac591bbf2d8 100644 --- a/cirq-core/cirq/ops/controlled_gate_test.py +++ b/cirq-core/cirq/ops/controlled_gate_test.py @@ -494,6 +494,36 @@ def _test_controlled_gate_is_consistent( np.testing.assert_allclose(cirq.unitary(cgate), cirq.unitary(circuit), atol=1e-13) +@pytest.mark.parametrize( + 'sub_gate, expected_decomposition', + [ + (cirq.X, [cirq.CX]), + (cirq.CX, [cirq.CCX]), + (cirq.XPowGate(), [cirq.CXPowGate()]), + (cirq.CXPowGate(), [cirq.CCXPowGate()]), + (cirq.Z, [cirq.CZ]), + (cirq.CZ, [cirq.CCZ]), + (cirq.ZPowGate(), [cirq.CZPowGate()]), + (cirq.CZPowGate(), [cirq.CCZPowGate()]), + ], +) +def test_controlled_gate_decomposition_uses_canonical_version(sub_gate, expected_decomposition): + cgate = cirq.ControlledGate(sub_gate, num_controls=1) + qubits = cirq.LineQubit.range(1 + sub_gate.num_qubits()) + dec = cirq.decompose_once(cgate.on(*qubits)) + assert [op.gate for op in dec] == expected_decomposition + + +@pytest.mark.parametrize( + 'sub_gate, expected_decomposition', [(cirq.Z, [cirq.CZ]), (cirq.ZPowGate(), [cirq.CZPowGate()])] +) +def test_controlled_gate_full_decomposition(sub_gate, expected_decomposition): + cgate = cirq.ControlledGate(sub_gate, num_controls=1) + qubits = cirq.LineQubit.range(1 + sub_gate.num_qubits()) + dec = cirq.decompose(cgate.on(*qubits)) + assert [op.gate for op in dec] == expected_decomposition + + def test_pow_inverse(): assert cirq.inverse(CRestricted, None) is None assert cirq.pow(CRestricted, 1.5, None) is None From 416bafd6704241662d9e670881ef4d9305fb4e08 Mon Sep 17 00:00:00 2001 From: Codrut Date: Sun, 13 Apr 2025 21:08:29 +0200 Subject: [PATCH 2/4] Be precise which 2-cycle needs to be avoided in the decomposition. --- cirq-core/cirq/ops/controlled_gate.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/cirq-core/cirq/ops/controlled_gate.py b/cirq-core/cirq/ops/controlled_gate.py index cef3e88a4a9..3a782eb94c2 100644 --- a/cirq-core/cirq/ops/controlled_gate.py +++ b/cirq-core/cirq/ops/controlled_gate.py @@ -163,8 +163,12 @@ def _decompose_with_context_( self.num_controls(), self.control_values, self.control_qid_shape ) # Prefer the subgate controlled version if available - if controlled_sub_gate.__class__ != self.__class__: - return controlled_sub_gate.on(*qubits) + if self != controlled_sub_gate: + # Prevent 2-cycle from appearing in recursive decomposition + if not isinstance(controlled_sub_gate, ControlledGate) or not isinstance( + controlled_sub_gate.sub_gate, common_gates.CZPowGate + ): + return controlled_sub_gate.on(*qubits) if ( protocols.has_unitary(self.sub_gate) and protocols.num_qubits(self.sub_gate) == 1 From 8e385810562d21b333db946ae586a7cf73e75fbd Mon Sep 17 00:00:00 2001 From: Codrut Date: Sun, 20 Apr 2025 19:19:04 +0300 Subject: [PATCH 3/4] Add a comment that the if condition can be removed once #7241 is resolved. --- cirq-core/cirq/ops/controlled_gate.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/cirq-core/cirq/ops/controlled_gate.py b/cirq-core/cirq/ops/controlled_gate.py index 3a782eb94c2..278e33038e1 100644 --- a/cirq-core/cirq/ops/controlled_gate.py +++ b/cirq-core/cirq/ops/controlled_gate.py @@ -164,7 +164,8 @@ def _decompose_with_context_( ) # Prefer the subgate controlled version if available if self != controlled_sub_gate: - # Prevent 2-cycle from appearing in recursive decomposition + # Prevent 2-cycle from appearing in the recursive decomposition + # TODO: Remove after #7241 is resolved if not isinstance(controlled_sub_gate, ControlledGate) or not isinstance( controlled_sub_gate.sub_gate, common_gates.CZPowGate ): From 6ef79563213e978b54d9d3384ffa2a4d7235b0be Mon Sep 17 00:00:00 2001 From: Pavol Juhas Date: Tue, 20 May 2025 17:25:31 -0700 Subject: [PATCH 4/4] Compare decomposition result with qubits And add type annotations for the new tests. --- cirq-core/cirq/ops/controlled_gate_test.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/cirq-core/cirq/ops/controlled_gate_test.py b/cirq-core/cirq/ops/controlled_gate_test.py index b66bf0a0b20..e0c11be56a4 100644 --- a/cirq-core/cirq/ops/controlled_gate_test.py +++ b/cirq-core/cirq/ops/controlled_gate_test.py @@ -509,21 +509,25 @@ def _test_controlled_gate_is_consistent( (cirq.CZPowGate(), [cirq.CCZPowGate()]), ], ) -def test_controlled_gate_decomposition_uses_canonical_version(sub_gate, expected_decomposition): +def test_controlled_gate_decomposition_uses_canonical_version( + sub_gate: cirq.Gate, expected_decomposition: list[cirq.Gate] +): cgate = cirq.ControlledGate(sub_gate, num_controls=1) qubits = cirq.LineQubit.range(1 + sub_gate.num_qubits()) dec = cirq.decompose_once(cgate.on(*qubits)) - assert [op.gate for op in dec] == expected_decomposition + assert dec == [gate.on(*qubits) for gate in expected_decomposition] @pytest.mark.parametrize( 'sub_gate, expected_decomposition', [(cirq.Z, [cirq.CZ]), (cirq.ZPowGate(), [cirq.CZPowGate()])] ) -def test_controlled_gate_full_decomposition(sub_gate, expected_decomposition): +def test_controlled_gate_full_decomposition( + sub_gate: cirq.Gate, expected_decomposition: list[cirq.Gate] +): cgate = cirq.ControlledGate(sub_gate, num_controls=1) qubits = cirq.LineQubit.range(1 + sub_gate.num_qubits()) dec = cirq.decompose(cgate.on(*qubits)) - assert [op.gate for op in dec] == expected_decomposition + assert dec == [gate.on(*qubits) for gate in expected_decomposition] def test_pow_inverse():