-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Simplify decomposition of controlled eigengates with global phase #7383
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1322,3 +1322,33 @@ def test_parameterized_pauli_expansion(gate_type, exponent): | |
gate_resolved = cirq.resolve_parameters(gate, {'s': 0.5}) | ||
pauli_resolved = cirq.resolve_parameters(pauli, {'s': 0.5}) | ||
assert cirq.approx_eq(pauli_resolved, cirq.pauli_expansion(gate_resolved)) | ||
|
||
|
||
@pytest.mark.parametrize('gate_type', [cirq.XPowGate, cirq.YPowGate, cirq.ZPowGate, cirq.CZPowGate]) | ||
@pytest.mark.parametrize('exponent', [0, 0.5, 2, 3, -0.5, -2, -3, sympy.Symbol('s')]) | ||
def test_decompose_with_extracted_phases(gate_type: type, exponent: cirq.TParamVal) -> None: | ||
context = cirq.DecompositionContext(cirq.SimpleQubitManager(), extract_global_phases=True) | ||
gate = gate_type(exponent=exponent, global_shift=2 / 3) | ||
op = gate.on(*cirq.LineQubit.range(cirq.num_qubits(gate))) | ||
decomposed = cirq.decompose(op, context=context) | ||
gate0 = decomposed[0].gate | ||
assert isinstance(gate0, gate_type) | ||
assert isinstance(gate0, cirq.EigenGate) | ||
assert gate0.global_shift == 0 | ||
assert gate0.exponent == exponent | ||
if exponent * 2 / 3 % 2 != 0: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could be simplified to |
||
assert len(decomposed) == 2 | ||
gate1 = decomposed[1].gate | ||
assert isinstance(gate1, cirq.GlobalPhaseGate) | ||
assert gate1.coefficient == 1j ** (exponent * (4 / 3)) | ||
else: | ||
assert len(decomposed) == 1 | ||
decomposed_circuit = cirq.Circuit(decomposed) | ||
if not cirq.is_parameterized(exponent): | ||
np.testing.assert_allclose(cirq.unitary(op), cirq.unitary(decomposed_circuit), atol=1e-10) | ||
else: | ||
resolver = {'s': -1.234} | ||
np.testing.assert_allclose( | ||
cirq.final_state_vector(cirq.Circuit(op), param_resolver=resolver), | ||
cirq.final_state_vector(decomposed_circuit, param_resolver=resolver), | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -24,7 +24,6 @@ | |
control_values as cv, | ||
controlled_operation as cop, | ||
diagonal_gate as dg, | ||
global_phase_op as gp, | ||
op_tree, | ||
raw_types, | ||
) | ||
|
@@ -139,19 +138,35 @@ def num_controls(self) -> int: | |
def _qid_shape_(self) -> tuple[int, ...]: | ||
return self.control_qid_shape + protocols.qid_shape(self.sub_gate) | ||
|
||
def _decompose_(self, qubits: tuple[cirq.Qid, ...]) -> None | NotImplementedType | cirq.OP_TREE: | ||
return self._decompose_with_context_(qubits) | ||
|
||
def _decompose_with_context_( | ||
self, qubits: tuple[cirq.Qid, ...], context: cirq.DecompositionContext | None = None | ||
) -> None | NotImplementedType | cirq.OP_TREE: | ||
self, qubits: tuple[cirq.Qid, ...], context: cirq.DecompositionContext | ||
) -> NotImplementedType | cirq.OP_TREE: | ||
control_qubits = list(qubits[: self.num_controls()]) | ||
controlled_sub_gate = self.sub_gate.controlled( | ||
self.num_controls(), self.control_values, self.control_qid_shape | ||
) | ||
# Prefer the subgate controlled version if available | ||
if self != controlled_sub_gate: | ||
return controlled_sub_gate.on(*qubits) | ||
|
||
# Try decomposing the subgate next. | ||
# Extract global phases from decomposition, as controlled phases decompose easily. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Move this comment five lines down. |
||
result = protocols.decompose_once_with_qubits( | ||
self.sub_gate, | ||
qubits[self.num_controls() :], | ||
NotImplemented, | ||
flatten=False, | ||
context=context.extracting_global_phases(), | ||
) | ||
if result is not NotImplemented: | ||
return op_tree.transform_op_tree( | ||
result, | ||
lambda op: op.controlled_by( | ||
*qubits[: self.num_controls()], control_values=self.control_values | ||
), | ||
) | ||
|
||
# Finally try brute-force on the unitary. | ||
if protocols.has_unitary(self.sub_gate) and all(q.dimension == 2 for q in qubits): | ||
n_qubits = protocols.num_qubits(self.sub_gate) | ||
# Case 1: Global Phase (1x1 Matrix) | ||
|
@@ -173,54 +188,9 @@ def _decompose_with_context_( | |
protocols.unitary(self.sub_gate), control_qubits, qubits[-1] | ||
) | ||
return invert_ops + decomposed_ops + invert_ops | ||
if isinstance(self.sub_gate, common_gates.CZPowGate): | ||
z_sub_gate = common_gates.ZPowGate(exponent=self.sub_gate.exponent) | ||
num_controls = self.num_controls() + 1 | ||
control_values = self.control_values & cv.ProductOfSums(((1,),)) | ||
control_qid_shape = self.control_qid_shape + (2,) | ||
controlled_z = ( | ||
z_sub_gate.controlled( | ||
num_controls=num_controls, | ||
control_values=control_values, | ||
control_qid_shape=control_qid_shape, | ||
) | ||
if protocols.is_parameterized(self) | ||
else ControlledGate( | ||
z_sub_gate, | ||
num_controls=num_controls, | ||
control_values=control_values, | ||
control_qid_shape=control_qid_shape, | ||
) | ||
) | ||
if self != controlled_z: | ||
result = controlled_z.on(*qubits) | ||
if self.sub_gate.global_shift == 0: | ||
return result | ||
# Reconstruct the controlled global shift of the subgate. | ||
total_shift = self.sub_gate.exponent * self.sub_gate.global_shift | ||
phase_gate = gp.GlobalPhaseGate(1j ** (2 * total_shift)) | ||
controlled_phase_op = phase_gate.controlled( | ||
num_controls=self.num_controls(), | ||
control_values=self.control_values, | ||
control_qid_shape=self.control_qid_shape, | ||
).on(*control_qubits) | ||
return [result, controlled_phase_op] | ||
result = protocols.decompose_once_with_qubits( | ||
self.sub_gate, | ||
qubits[self.num_controls() :], | ||
NotImplemented, | ||
flatten=False, | ||
context=context, | ||
) | ||
if result is NotImplemented: | ||
return NotImplemented | ||
|
||
return op_tree.transform_op_tree( | ||
result, | ||
lambda op: op.controlled_by( | ||
*qubits[: self.num_controls()], control_values=self.control_values | ||
), | ||
) | ||
# If nothing works, return `NotImplemented`. | ||
return NotImplemented | ||
|
||
def on(self, *qubits: cirq.Qid) -> cop.ControlledOperation: | ||
if len(qubits) == 0: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -802,3 +802,19 @@ def test_controlled_global_phase_matrix_gate_decomposes(num_controls, angle, con | |
decomposed = cirq.decompose(cg_matrix(*all_qubits)) | ||
assert not any(isinstance(op.gate, cirq.MatrixGate) for op in decomposed) | ||
np.testing.assert_allclose(cirq.unitary(cirq.Circuit(decomposed)), cirq.unitary(cg_matrix)) | ||
|
||
|
||
def test_simplified_controlled_phased_eigengate_decomposition() -> None: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add the other affected gates? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Actually it could be made more generic and parameterized: the decomposition should always be equal to the decomposition of the equivalent gate without global shift, plus a final Z gate. |
||
q0, q1 = cirq.LineQubit.range(2) | ||
|
||
# Z gate | ||
op = cirq.ZPowGate(global_shift=0.22).controlled().on(q0, q1) | ||
ops = cirq.decompose(op) | ||
assert ops == [cirq.CZ(q0, q1), cirq.Z(q0) ** 0.22] | ||
np.testing.assert_allclose(cirq.unitary(op), cirq.unitary(cirq.Circuit(ops))) | ||
|
||
# X gate | ||
op = cirq.XPowGate(global_shift=0.22).controlled().on(q0, q1) | ||
ops = cirq.decompose(op) | ||
assert ops == [cirq.Y(q1) ** -0.5, cirq.CZ(q0, q1), cirq.Y(q1) ** 0.5, cirq.Z(q0) ** 0.22] | ||
np.testing.assert_allclose(cirq.unitary(op), cirq.unitary(cirq.Circuit(ops))) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -92,6 +92,12 @@ def _resolve_parameters_( | |
coefficient = protocols.resolve_parameters(self.coefficient, resolver, recursive) | ||
return GlobalPhaseGate(coefficient=coefficient) | ||
|
||
@property | ||
def is_identity(self) -> bool: | ||
return ( | ||
not protocols.is_parameterized(self._coefficient) and abs(self._coefficient - 1.0) == 0 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This was moved from |
||
) | ||
|
||
def controlled( | ||
self, | ||
num_controls: int | None = None, | ||
|
@@ -122,3 +128,16 @@ def global_phase_operation( | |
) -> cirq.GateOperation: | ||
"""Creates an operation that represents a global phase on the state.""" | ||
return GlobalPhaseGate(coefficient, atol)() | ||
|
||
|
||
def from_phase_and_exponent( | ||
half_turns: cirq.TParamVal, exponent: cirq.TParamVal | ||
) -> cirq.GlobalPhaseGate: | ||
"""Creates a GlobalPhaseGate from the global phase and exponent.""" | ||
coefficient = 1j ** (2 * half_turns * exponent) | ||
coefficient = ( | ||
complex(coefficient) | ||
if isinstance(coefficient, sympy.Expr) and coefficient.is_complex | ||
else coefficient | ||
) | ||
return GlobalPhaseGate(coefficient) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Make
2/3
a variable since it's used multiple times, and refer back to the docstring of_extract_phase
to explain why it's interesting.