Skip to content

Simplify decomposition of controlled eigengates with global phase #7383

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 50 additions & 1 deletion cirq-core/cirq/ops/common_gates.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,14 @@
from cirq import protocols, value
from cirq._compat import proper_repr
from cirq._doc import document
from cirq.ops import control_values as cv, controlled_gate, eigen_gate, gate_features, raw_types
from cirq.ops import (
control_values as cv,
controlled_gate,
eigen_gate,
gate_features,
global_phase_op,
raw_types,
)
from cirq.ops.measurement_gate import MeasurementGate
from cirq.ops.swap_gates import ISWAP, ISwapPowGate, SWAP, SwapPowGate

Expand Down Expand Up @@ -235,6 +242,11 @@ def controlled(
return cirq.CCXPowGate(exponent=self._exponent)
return result

def _decompose_with_context_(
self, qubits: tuple[cirq.Qid, ...], context: cirq.DecompositionContext
) -> list[cirq.Operation] | NotImplementedType:
return _extract_phase(self, XPowGate, qubits, context)

def _pauli_expansion_(self) -> value.LinearDict[str]:
if self._dimension != 2:
return NotImplemented # pragma: no cover
Expand Down Expand Up @@ -487,6 +499,11 @@ def __repr__(self) -> str:
f'global_shift={self._global_shift!r})'
)

def _decompose_with_context_(
self, qubits: tuple[cirq.Qid, ...], context: cirq.DecompositionContext
) -> list[cirq.Operation] | NotImplementedType:
return _extract_phase(self, YPowGate, qubits, context)


class Ry(YPowGate):
r"""A gate with matrix $e^{-i Y t/2}$ that rotates around the Y axis of the Bloch sphere by $t$.
Expand Down Expand Up @@ -699,6 +716,11 @@ def controlled(
return cirq.CCZPowGate(exponent=self._exponent)
return result

def _decompose_with_context_(
self, qubits: tuple[cirq.Qid, ...], context: cirq.DecompositionContext
) -> list[cirq.Operation] | NotImplementedType:
return _extract_phase(self, ZPowGate, qubits, context)

def _qid_shape_(self) -> tuple[int, ...]:
return (self._dimension,)

Expand Down Expand Up @@ -1131,6 +1153,11 @@ def controlled(
control_qid_shape=result.control_qid_shape + (2,),
)

def _decompose_with_context_(
self, qubits: tuple[cirq.Qid, ...], context: cirq.DecompositionContext
) -> list[cirq.Operation] | NotImplementedType:
return _extract_phase(self, CZPowGate, qubits, context)

def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.CircuitDiagramInfo:
return protocols.CircuitDiagramInfo(
wire_symbols=('@', '@'), exponent=self._diagram_exponent(args)
Expand Down Expand Up @@ -1486,3 +1513,25 @@ def _phased_x_or_pauli_gate(
case 0.5:
return YPowGate(exponent=exponent)
return cirq.ops.PhasedXPowGate(exponent=exponent, phase_exponent=phase_exponent)


def _extract_phase(
gate: cirq.EigenGate,
gate_class: type,
qubits: tuple[cirq.Qid, ...],
context: cirq.DecompositionContext,
) -> list[cirq.Operation] | NotImplementedType:
"""Extracts the global phase field to its own gate, or absorbs it if it has no effect.

This is for use within the decompose handlers, and will return `NotImplemented` if there is no
global phase, implying it is already in its simplest form. It will return a list, with the
original op minus any global phase first, and the global phase op second. If the resulting
global phase is empty (can happen for example in `XPowGate(global_phase=2/3)**3`), then it is
excluded from the return value."""
if not context.extract_global_phases or gate.global_shift == 0:
return NotImplemented
result = [gate_class(exponent=gate.exponent).on(*qubits)]
phase_gate = global_phase_op.from_phase_and_exponent(gate.global_shift, gate.exponent)
if not phase_gate.is_identity:
result.append(phase_gate())
return result
30 changes: 30 additions & 0 deletions cirq-core/cirq/ops/common_gates_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1322,3 +1322,33 @@ def test_parameterized_pauli_expansion(gate_type, exponent):
gate_resolved = cirq.resolve_parameters(gate, {'s': 0.5})
pauli_resolved = cirq.resolve_parameters(pauli, {'s': 0.5})
assert cirq.approx_eq(pauli_resolved, cirq.pauli_expansion(gate_resolved))


@pytest.mark.parametrize('gate_type', [cirq.XPowGate, cirq.YPowGate, cirq.ZPowGate, cirq.CZPowGate])
@pytest.mark.parametrize('exponent', [0, 0.5, 2, 3, -0.5, -2, -3, sympy.Symbol('s')])
def test_decompose_with_extracted_phases(gate_type: type, exponent: cirq.TParamVal) -> None:
context = cirq.DecompositionContext(cirq.SimpleQubitManager(), extract_global_phases=True)
gate = gate_type(exponent=exponent, global_shift=2 / 3)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Make 2/3 a variable since it's used multiple times, and refer back to the docstring of _extract_phase to explain why it's interesting.

op = gate.on(*cirq.LineQubit.range(cirq.num_qubits(gate)))
decomposed = cirq.decompose(op, context=context)
gate0 = decomposed[0].gate
assert isinstance(gate0, gate_type)
assert isinstance(gate0, cirq.EigenGate)
assert gate0.global_shift == 0
assert gate0.exponent == exponent
if exponent * 2 / 3 % 2 != 0:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could be simplified to if exponent % 3:

assert len(decomposed) == 2
gate1 = decomposed[1].gate
assert isinstance(gate1, cirq.GlobalPhaseGate)
assert gate1.coefficient == 1j ** (exponent * (4 / 3))
else:
assert len(decomposed) == 1
decomposed_circuit = cirq.Circuit(decomposed)
if not cirq.is_parameterized(exponent):
np.testing.assert_allclose(cirq.unitary(op), cirq.unitary(decomposed_circuit), atol=1e-10)
else:
resolver = {'s': -1.234}
np.testing.assert_allclose(
cirq.final_state_vector(cirq.Circuit(op), param_resolver=resolver),
cirq.final_state_vector(decomposed_circuit, param_resolver=resolver),
)
76 changes: 23 additions & 53 deletions cirq-core/cirq/ops/controlled_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
control_values as cv,
controlled_operation as cop,
diagonal_gate as dg,
global_phase_op as gp,
op_tree,
raw_types,
)
Expand Down Expand Up @@ -139,19 +138,35 @@ def num_controls(self) -> int:
def _qid_shape_(self) -> tuple[int, ...]:
return self.control_qid_shape + protocols.qid_shape(self.sub_gate)

def _decompose_(self, qubits: tuple[cirq.Qid, ...]) -> None | NotImplementedType | cirq.OP_TREE:
return self._decompose_with_context_(qubits)

def _decompose_with_context_(
self, qubits: tuple[cirq.Qid, ...], context: cirq.DecompositionContext | None = None
) -> None | NotImplementedType | cirq.OP_TREE:
self, qubits: tuple[cirq.Qid, ...], context: cirq.DecompositionContext
) -> NotImplementedType | cirq.OP_TREE:
control_qubits = list(qubits[: self.num_controls()])
controlled_sub_gate = self.sub_gate.controlled(
self.num_controls(), self.control_values, self.control_qid_shape
)
# Prefer the subgate controlled version if available
if self != controlled_sub_gate:
return controlled_sub_gate.on(*qubits)

# Try decomposing the subgate next.
# Extract global phases from decomposition, as controlled phases decompose easily.
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Move this comment five lines down.

result = protocols.decompose_once_with_qubits(
self.sub_gate,
qubits[self.num_controls() :],
NotImplemented,
flatten=False,
context=context.extracting_global_phases(),
)
if result is not NotImplemented:
return op_tree.transform_op_tree(
result,
lambda op: op.controlled_by(
*qubits[: self.num_controls()], control_values=self.control_values
),
)

# Finally try brute-force on the unitary.
if protocols.has_unitary(self.sub_gate) and all(q.dimension == 2 for q in qubits):
n_qubits = protocols.num_qubits(self.sub_gate)
# Case 1: Global Phase (1x1 Matrix)
Expand All @@ -173,54 +188,9 @@ def _decompose_with_context_(
protocols.unitary(self.sub_gate), control_qubits, qubits[-1]
)
return invert_ops + decomposed_ops + invert_ops
if isinstance(self.sub_gate, common_gates.CZPowGate):
z_sub_gate = common_gates.ZPowGate(exponent=self.sub_gate.exponent)
num_controls = self.num_controls() + 1
control_values = self.control_values & cv.ProductOfSums(((1,),))
control_qid_shape = self.control_qid_shape + (2,)
controlled_z = (
z_sub_gate.controlled(
num_controls=num_controls,
control_values=control_values,
control_qid_shape=control_qid_shape,
)
if protocols.is_parameterized(self)
else ControlledGate(
z_sub_gate,
num_controls=num_controls,
control_values=control_values,
control_qid_shape=control_qid_shape,
)
)
if self != controlled_z:
result = controlled_z.on(*qubits)
if self.sub_gate.global_shift == 0:
return result
# Reconstruct the controlled global shift of the subgate.
total_shift = self.sub_gate.exponent * self.sub_gate.global_shift
phase_gate = gp.GlobalPhaseGate(1j ** (2 * total_shift))
controlled_phase_op = phase_gate.controlled(
num_controls=self.num_controls(),
control_values=self.control_values,
control_qid_shape=self.control_qid_shape,
).on(*control_qubits)
return [result, controlled_phase_op]
result = protocols.decompose_once_with_qubits(
self.sub_gate,
qubits[self.num_controls() :],
NotImplemented,
flatten=False,
context=context,
)
if result is NotImplemented:
return NotImplemented

return op_tree.transform_op_tree(
result,
lambda op: op.controlled_by(
*qubits[: self.num_controls()], control_values=self.control_values
),
)
# If nothing works, return `NotImplemented`.
return NotImplemented

def on(self, *qubits: cirq.Qid) -> cop.ControlledOperation:
if len(qubits) == 0:
Expand Down
16 changes: 16 additions & 0 deletions cirq-core/cirq/ops/controlled_gate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -802,3 +802,19 @@ def test_controlled_global_phase_matrix_gate_decomposes(num_controls, angle, con
decomposed = cirq.decompose(cg_matrix(*all_qubits))
assert not any(isinstance(op.gate, cirq.MatrixGate) for op in decomposed)
np.testing.assert_allclose(cirq.unitary(cirq.Circuit(decomposed)), cirq.unitary(cg_matrix))


def test_simplified_controlled_phased_eigengate_decomposition() -> None:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add the other affected gates?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually it could be made more generic and parameterized: the decomposition should always be equal to the decomposition of the equivalent gate without global shift, plus a final Z gate.

q0, q1 = cirq.LineQubit.range(2)

# Z gate
op = cirq.ZPowGate(global_shift=0.22).controlled().on(q0, q1)
ops = cirq.decompose(op)
assert ops == [cirq.CZ(q0, q1), cirq.Z(q0) ** 0.22]
np.testing.assert_allclose(cirq.unitary(op), cirq.unitary(cirq.Circuit(ops)))

# X gate
op = cirq.XPowGate(global_shift=0.22).controlled().on(q0, q1)
ops = cirq.decompose(op)
assert ops == [cirq.Y(q1) ** -0.5, cirq.CZ(q0, q1), cirq.Y(q1) ** 0.5, cirq.Z(q0) ** 0.22]
np.testing.assert_allclose(cirq.unitary(op), cirq.unitary(cirq.Circuit(ops)))
19 changes: 19 additions & 0 deletions cirq-core/cirq/ops/global_phase_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,12 @@ def _resolve_parameters_(
coefficient = protocols.resolve_parameters(self.coefficient, resolver, recursive)
return GlobalPhaseGate(coefficient=coefficient)

@property
def is_identity(self) -> bool:
return (
not protocols.is_parameterized(self._coefficient) and abs(self._coefficient - 1.0) == 0
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was moved from three_qubit_gates.py line 120, but now that I look at it again, it seems equivalent to self.coefficient == 1. Will try that and see if it works. If so, that's simple enough that the method probably doesn't even need to exist, and can just be inlined.

)

def controlled(
self,
num_controls: int | None = None,
Expand Down Expand Up @@ -122,3 +128,16 @@ def global_phase_operation(
) -> cirq.GateOperation:
"""Creates an operation that represents a global phase on the state."""
return GlobalPhaseGate(coefficient, atol)()


def from_phase_and_exponent(
half_turns: cirq.TParamVal, exponent: cirq.TParamVal
) -> cirq.GlobalPhaseGate:
"""Creates a GlobalPhaseGate from the global phase and exponent."""
coefficient = 1j ** (2 * half_turns * exponent)
coefficient = (
complex(coefficient)
if isinstance(coefficient, sympy.Expr) and coefficient.is_complex
else coefficient
)
return GlobalPhaseGate(coefficient)
20 changes: 20 additions & 0 deletions cirq-core/cirq/ops/global_phase_op_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import sympy

import cirq
from cirq.ops import global_phase_op


def test_init():
Expand Down Expand Up @@ -302,3 +303,22 @@ def test_global_phase_gate_controlled(coeff, exp):
assert g.controlled(control_values=xor_control_values) == cirq.ControlledGate(
g, control_values=xor_control_values
)


def test_is_identity() -> None:
g = cirq.GlobalPhaseGate(1)
assert g.is_identity
g = cirq.GlobalPhaseGate(1j)
assert not g.is_identity
g = cirq.GlobalPhaseGate(-1)
assert not g.is_identity


def test_from_phase_and_exponent() -> None:
g = global_phase_op.from_phase_and_exponent(2.5, 0.5)
assert g.coefficient == np.exp(1.25j * np.pi)
a, b = sympy.symbols('a, b')
g = global_phase_op.from_phase_and_exponent(a, b)
assert g.coefficient == 1j ** (2 * a * b)
g = global_phase_op.from_phase_and_exponent(1 / a, a)
assert g.coefficient == -1
16 changes: 4 additions & 12 deletions cirq-core/cirq/ops/three_qubit_gates.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,19 +107,11 @@ def _decompose_(self, qubits):
elif not b.is_adjacent(c):
a, b = b, a

p = common_gates.T**self._exponent
exp = self._exponent
p = common_gates.T**exp
sweep_abc = [common_gates.CNOT(a, b), common_gates.CNOT(b, c)]
global_phase = 1j ** (2 * self.global_shift * self._exponent)
global_phase = (
complex(global_phase)
if protocols.is_parameterized(global_phase) and global_phase.is_complex
else global_phase
)
global_phase_operation = (
[global_phase_op.global_phase_operation(global_phase)]
if protocols.is_parameterized(global_phase) or abs(global_phase - 1.0) > 0
else []
)
global_phase_gate = global_phase_op.from_phase_and_exponent(self.global_shift, exp)
global_phase_operation = [] if global_phase_gate.is_identity else [global_phase_gate()]
return global_phase_operation + [
p(a),
p(b),
Expand Down
7 changes: 7 additions & 0 deletions cirq-core/cirq/protocols/decompose_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,16 @@ class DecompositionContext:
Args:
qubit_manager: A `cirq.QubitManager` instance to allocate clean / dirty ancilla qubits as
part of the decompose protocol.
extract_global_phases: If set, will extract the global phases from
`DECOMPOSE_TARGET_GATESET` into independent global phase operations.
"""

qubit_manager: cirq.QubitManager
extract_global_phases: bool = False

def extracting_global_phases(self) -> DecompositionContext:
"""Returns a copy with the `extract_global_phases` field set."""
return dataclasses.replace(self, extract_global_phases=True)


class SupportsDecompose(Protocol):
Expand Down
9 changes: 9 additions & 0 deletions cirq-core/cirq/protocols/decompose_protocol_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,3 +445,12 @@ def test_decompose_without_context_succeed() -> None:
cirq.ops.CleanQubit(1, prefix='_decompose_protocol'),
)
]


def test_extracting_global_phases() -> None:
qm = cirq.SimpleQubitManager()
context = cirq.DecompositionContext(qm)
new_context = context.extracting_global_phases()
assert not context.extract_global_phases
assert new_context.extract_global_phases
assert new_context.qubit_manager is qm