Skip to content

Commit 9a58a8b

Browse files
Support merge symbolized 1 qubit gate transformer (#7393)
Merge single qubit gates for symbolized circuits. It is a updated version of #7149 where tag_transformers and symbolize transformers are already merged. --------- Co-authored-by: eliottrosenberg <61400172+eliottrosenberg@users.noreply.github.com>
1 parent b1ff3ae commit 9a58a8b

File tree

4 files changed

+322
-9
lines changed

4 files changed

+322
-9
lines changed

cirq-core/cirq/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,7 @@
380380
merge_operations_to_circuit_op as merge_operations_to_circuit_op,
381381
merge_single_qubit_gates_to_phased_x_and_z as merge_single_qubit_gates_to_phased_x_and_z,
382382
merge_single_qubit_gates_to_phxz as merge_single_qubit_gates_to_phxz,
383+
merge_single_qubit_gates_to_phxz_symbolized as merge_single_qubit_gates_to_phxz_symbolized,
383384
merge_single_qubit_moments_to_phxz as merge_single_qubit_moments_to_phxz,
384385
optimize_for_target_gateset as optimize_for_target_gateset,
385386
parameterized_2q_op_to_sqrt_iswap_operations as parameterized_2q_op_to_sqrt_iswap_operations,

cirq-core/cirq/transformers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@
100100
merge_single_qubit_gates_to_phased_x_and_z as merge_single_qubit_gates_to_phased_x_and_z,
101101
merge_single_qubit_gates_to_phxz as merge_single_qubit_gates_to_phxz,
102102
merge_single_qubit_moments_to_phxz as merge_single_qubit_moments_to_phxz,
103+
merge_single_qubit_gates_to_phxz_symbolized as merge_single_qubit_gates_to_phxz_symbolized,
103104
)
104105

105106
from cirq.transformers.qubit_management_transformers import (

cirq-core/cirq/transformers/merge_single_qubit_gates.py

Lines changed: 178 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,24 @@
1616

1717
from __future__ import annotations
1818

19-
from typing import TYPE_CHECKING
19+
from typing import Callable, cast, Hashable, TYPE_CHECKING
2020

2121
from cirq import circuits, ops, protocols
22-
from cirq.transformers import merge_k_qubit_gates, transformer_api, transformer_primitives
22+
from cirq.study.resolver import ParamResolver
23+
from cirq.study.sweeps import dict_to_zip_sweep, ListSweep, ProductOrZipSweepLike, Sweep, Zip
24+
from cirq.transformers import (
25+
align,
26+
merge_k_qubit_gates,
27+
symbolize,
28+
tag_transformers,
29+
transformer_api,
30+
transformer_primitives,
31+
)
2332
from cirq.transformers.analytical_decompositions import single_qubit_decompositions
2433

2534
if TYPE_CHECKING:
35+
import sympy
36+
2637
import cirq
2738

2839

@@ -67,6 +78,7 @@ def merge_single_qubit_gates_to_phxz(
6778
circuit: cirq.AbstractCircuit,
6879
*,
6980
context: cirq.TransformerContext | None = None,
81+
merge_tags_fn: Callable[[cirq.CircuitOperation], list[Hashable]] | None = None,
7082
atol: float = 1e-8,
7183
) -> cirq.Circuit:
7284
"""Replaces runs of single qubit rotations with a single optional `cirq.PhasedXZGate`.
@@ -77,19 +89,21 @@ def merge_single_qubit_gates_to_phxz(
7789
Args:
7890
circuit: Input circuit to transform. It will not be modified.
7991
context: `cirq.TransformerContext` storing common configurable options for transformers.
92+
merge_tags_fn: A callable returns the tags to be added to the merged operation.
8093
atol: Absolute tolerance to angle error. Larger values allow more negligible gates to be
8194
dropped, smaller values increase accuracy.
8295
8396
Returns:
8497
Copy of the transformed input circuit.
8598
"""
8699

87-
def rewriter(op: cirq.CircuitOperation) -> cirq.OP_TREE:
88-
u = protocols.unitary(op)
89-
if protocols.num_qubits(op) == 0:
100+
def rewriter(circuit_op: cirq.CircuitOperation) -> cirq.OP_TREE:
101+
u = protocols.unitary(circuit_op)
102+
if protocols.num_qubits(circuit_op) == 0:
90103
return ops.GlobalPhaseGate(u[0, 0]).on()
91-
gate = single_qubit_decompositions.single_qubit_matrix_to_phxz(u, atol)
92-
return gate(op.qubits[0]) if gate else []
104+
gate = single_qubit_decompositions.single_qubit_matrix_to_phxz(u, atol) or ops.I
105+
phxz_op = gate.on(circuit_op.qubits[0])
106+
return phxz_op.with_tags(*merge_tags_fn(circuit_op)) if merge_tags_fn else phxz_op
93107

94108
return merge_k_qubit_gates.merge_k_qubit_unitaries(
95109
circuit, k=1, context=context, rewriter=rewriter
@@ -158,3 +172,160 @@ def merge_func(m1: cirq.Moment, m2: cirq.Moment) -> cirq.Moment | None:
158172
deep=context.deep if context else False,
159173
tags_to_ignore=tuple(tags_to_ignore),
160174
).unfreeze(copy=False)
175+
176+
177+
def _sweep_on_symbols(sweep: Sweep, symbols: set[sympy.Symbol]) -> Sweep:
178+
new_resolvers: list[cirq.ParamResolver] = []
179+
for resolver in sweep:
180+
param_dict: cirq.ParamMappingType = {s: resolver.value_of(s) for s in symbols}
181+
new_resolvers.append(ParamResolver(param_dict))
182+
return ListSweep(new_resolvers)
183+
184+
185+
def _calc_phxz_sweeps(
186+
symbolized_circuit: cirq.Circuit, resolved_circuits: list[cirq.Circuit]
187+
) -> Sweep:
188+
"""Return the phxz sweep of the symbolized_circuit on resolved_circuits.
189+
190+
Raises:
191+
ValueError: Structural mismatch: A `resolved_circuit` contains an unexpected gate type.
192+
Expected a `PhasedXZGate` or `IdentityGate` at a position corresponding to a
193+
symbolic `PhasedXZGate` in the `symbolized_circuit`.
194+
"""
195+
196+
def _extract_axz(op: ops.Operation | None) -> tuple[float, float, float]:
197+
if not op or not op.gate or not isinstance(op.gate, ops.IdentityGate | ops.PhasedXZGate):
198+
raise ValueError(f"Expect a PhasedXZGate or IdentityGate on op {op}.")
199+
if isinstance(op.gate, ops.IdentityGate):
200+
return 0.0, 0.0, 0.0 # Identity gate's a, x, z in PhasedXZ
201+
return op.gate.axis_phase_exponent, op.gate.x_exponent, op.gate.z_exponent
202+
203+
values_by_params: dict[sympy.Symbol, tuple[float, ...]] = {}
204+
for mid, moment in enumerate(symbolized_circuit):
205+
for op in moment.operations:
206+
if op.gate and isinstance(op.gate, ops.PhasedXZGate) and protocols.is_parameterized(op):
207+
sa, sx, sz = op.gate.axis_phase_exponent, op.gate.x_exponent, op.gate.z_exponent
208+
values_by_params[sa], values_by_params[sx], values_by_params[sz] = zip(
209+
*[_extract_axz(c[mid].operation_at(op.qubits[0])) for c in resolved_circuits]
210+
)
211+
212+
return dict_to_zip_sweep(cast(ProductOrZipSweepLike, values_by_params))
213+
214+
215+
def merge_single_qubit_gates_to_phxz_symbolized(
216+
circuit: cirq.AbstractCircuit,
217+
*,
218+
context: cirq.TransformerContext | None = None,
219+
sweep: Sweep,
220+
atol: float = 1e-8,
221+
) -> tuple[cirq.Circuit, Sweep]:
222+
"""Merges consecutive single qubit gates as PhasedXZ Gates. Symbolizes if any of
223+
the consecutive gates is symbolized.
224+
225+
Example:
226+
>>> q0, q1 = cirq.LineQubit.range(2)
227+
>>> c = cirq.Circuit(\
228+
cirq.X(q0),\
229+
cirq.CZ(q0,q1)**sympy.Symbol("cz_exp"),\
230+
cirq.Y(q0)**sympy.Symbol("y_exp"),\
231+
cirq.X(q0))
232+
>>> print(c)
233+
0: ───X───@──────────Y^y_exp───X───
234+
235+
1: ───────@^cz_exp─────────────────
236+
>>> new_circuit, new_sweep = cirq.merge_single_qubit_gates_to_phxz_symbolized(\
237+
c, sweep=cirq.Zip(cirq.Points(key="cz_exp", points=[0, 1]),\
238+
cirq.Points(key="y_exp", points=[0, 1])))
239+
>>> print(new_circuit)
240+
0: ───PhXZ(a=-1,x=1,z=0)───@──────────PhXZ(a=a0,x=x0,z=z0)───
241+
242+
1: ────────────────────────@^cz_exp──────────────────────────
243+
>>> assert new_sweep[0] == cirq.ParamResolver({'a0': -1, 'x0': 1, 'z0': 0, 'cz_exp': 0})
244+
>>> assert new_sweep[1] == cirq.ParamResolver({'a0': -0.5, 'x0': 0, 'z0': -1, 'cz_exp': 1})
245+
246+
Args:
247+
circuit: Input circuit to transform. It will not be modified.
248+
context: `cirq.TransformerContext` storing common configurable options for transformers.
249+
sweep: Sweep of the symbols in the input circuit, updated Sweep will be returned
250+
based on the transformation.
251+
atol: Absolute tolerance to angle error. Larger values allow more negligible gates to be
252+
dropped, smaller values increase accuracy.
253+
254+
Returns:
255+
Copy of the transformed input circuit.
256+
"""
257+
deep = context.deep if context else False
258+
259+
# Tag symbolized single-qubit op.
260+
symbolized_single_tag = "_tmp_symbolize_tag"
261+
262+
circuit_tagged = transformer_primitives.map_operations(
263+
circuit,
264+
lambda op, _: (
265+
op.with_tags(symbolized_single_tag)
266+
if protocols.is_parameterized(op) and len(op.qubits) == 1
267+
else op
268+
),
269+
deep=deep,
270+
)
271+
272+
# Step 0, isolate single qubit symbols and resolve the circuit on them.
273+
single_qubit_gate_symbols: set[sympy.Symbol] = set().union(
274+
*[
275+
protocols.parameter_symbols(op) if symbolized_single_tag in op.tags else set()
276+
for op in circuit_tagged.all_operations()
277+
]
278+
)
279+
# Remaining symbols, e.g., 2 qubit gates' symbols. Sweep of those symbols keeps unchanged.
280+
remaining_symbols: set[sympy.Symbol] = set(
281+
protocols.parameter_symbols(circuit) - single_qubit_gate_symbols
282+
)
283+
# If all single qubit gates are not parameterized, call the nonparamerized version of
284+
# the transformer.
285+
if not single_qubit_gate_symbols:
286+
return (merge_single_qubit_gates_to_phxz(circuit, context=context, atol=atol), sweep)
287+
sweep_of_single: Sweep = _sweep_on_symbols(sweep, single_qubit_gate_symbols)
288+
# Get all resolved circuits from all sets of resolvers in sweep_of_single.
289+
resolved_circuits = [
290+
protocols.resolve_parameters(circuit_tagged, resolver) for resolver in sweep_of_single
291+
]
292+
293+
# Step 1, merge single qubit gates per resolved circuit, preserving
294+
# the symbolized_single_tag to indicate the operator is a merged one.
295+
merged_circuits: list[cirq.Circuit] = [
296+
merge_single_qubit_gates_to_phxz(
297+
c,
298+
context=context,
299+
merge_tags_fn=lambda circuit_op: (
300+
[symbolized_single_tag]
301+
if any(
302+
symbolized_single_tag in set(op.tags)
303+
for op in circuit_op.circuit.all_operations()
304+
)
305+
else []
306+
),
307+
atol=atol,
308+
)
309+
for c in resolved_circuits
310+
]
311+
312+
# Step 2, get the new symbolized circuit by symbolizing on indexed symbolized_single_tag.
313+
new_circuit = tag_transformers.remove_tags( # remove the temp tags used to track merges
314+
symbolize.symbolize_single_qubit_gates_by_indexed_tags(
315+
tag_transformers.index_tags( # index all 1-qubit-ops merged from ops with symbols
316+
merged_circuits[0],
317+
context=transformer_api.TransformerContext(deep=deep),
318+
target_tags={symbolized_single_tag},
319+
),
320+
symbolize_tag=symbolize.SymbolizeTag(prefix=symbolized_single_tag),
321+
),
322+
remove_if=lambda tag: str(tag).startswith(symbolized_single_tag),
323+
)
324+
325+
# Step 3, get N sets of parameterizations as new_sweep.
326+
new_sweep = Zip(
327+
_calc_phxz_sweeps(new_circuit, merged_circuits), # phxz sweeps
328+
_sweep_on_symbols(sweep, remaining_symbols), # remaining sweeps
329+
)
330+
331+
return align.align_right(new_circuit), new_sweep

0 commit comments

Comments
 (0)