Skip to content

Commit cc566ee

Browse files
committed
Support merge symbolized 1 qubit gate.
1 parent 1f17390 commit cc566ee

File tree

3 files changed

+340
-11
lines changed

3 files changed

+340
-11
lines changed

cirq-core/cirq/transformers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@
101101
merge_single_qubit_gates_to_phased_x_and_z as merge_single_qubit_gates_to_phased_x_and_z,
102102
merge_single_qubit_gates_to_phxz as merge_single_qubit_gates_to_phxz,
103103
merge_single_qubit_moments_to_phxz as merge_single_qubit_moments_to_phxz,
104+
merge_single_qubit_gates_to_phxz_symbolized as merge_single_qubit_gates_to_phxz_symbolized,
104105
)
105106

106107
from cirq.transformers.qubit_management_transformers import (

cirq-core/cirq/transformers/merge_single_qubit_gates.py

Lines changed: 211 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,21 @@
1616

1717
from __future__ import annotations
1818

19-
from typing import TYPE_CHECKING
19+
from typing import Callable, cast, Hashable, List, Tuple, TYPE_CHECKING
20+
21+
import sympy
2022

2123
from cirq import circuits, ops, protocols
22-
from cirq.transformers import merge_k_qubit_gates, transformer_api, transformer_primitives
24+
from cirq.study.resolver import ParamResolver
25+
from cirq.study.sweeps import dict_to_zip_sweep, ListSweep, ProductOrZipSweepLike, Sweep, Zip
26+
from cirq.transformers import (
27+
align,
28+
merge_k_qubit_gates,
29+
transformer_api,
30+
transformer_primitives,
31+
symbolize,
32+
tag_transformers,
33+
)
2334
from cirq.transformers.analytical_decompositions import single_qubit_decompositions
2435

2536
if TYPE_CHECKING:
@@ -67,8 +78,9 @@ def merge_single_qubit_gates_to_phxz(
6778
circuit: cirq.AbstractCircuit,
6879
*,
6980
context: cirq.TransformerContext | None = None,
81+
merge_tags_fn: Callable[[cirq.CircuitOperation], List[Hashable]] | None = None,
7082
atol: float = 1e-8,
71-
) -> cirq.Circuit:
83+
) -> 'cirq.Circuit':
7284
"""Replaces runs of single qubit rotations with a single optional `cirq.PhasedXZGate`.
7385
7486
Specifically, any run of non-parameterized single-qubit unitaries will be replaced by an
@@ -77,19 +89,21 @@ def merge_single_qubit_gates_to_phxz(
7789
Args:
7890
circuit: Input circuit to transform. It will not be modified.
7991
context: `cirq.TransformerContext` storing common configurable options for transformers.
92+
merge_tags_fn: A callable returns the tags to be added to the merged operation.
8093
atol: Absolute tolerance to angle error. Larger values allow more negligible gates to be
8194
dropped, smaller values increase accuracy.
8295
8396
Returns:
8497
Copy of the transformed input circuit.
8598
"""
8699

87-
def rewriter(op: cirq.CircuitOperation) -> cirq.OP_TREE:
88-
u = protocols.unitary(op)
89-
if protocols.num_qubits(op) == 0:
100+
def rewriter(circuit_op: 'cirq.CircuitOperation') -> 'cirq.OP_TREE':
101+
u = protocols.unitary(circuit_op)
102+
if protocols.num_qubits(circuit_op) == 0:
90103
return ops.GlobalPhaseGate(u[0, 0]).on()
91-
gate = single_qubit_decompositions.single_qubit_matrix_to_phxz(u, atol)
92-
return gate(op.qubits[0]) if gate else []
104+
gate = single_qubit_decompositions.single_qubit_matrix_to_phxz(u, atol) or ops.I
105+
phxz_op = gate.on(circuit_op.qubits[0])
106+
return phxz_op.with_tags(*merge_tags_fn(circuit_op)) if merge_tags_fn else phxz_op
93107

94108
return merge_k_qubit_gates.merge_k_qubit_unitaries(
95109
circuit, k=1, context=context, rewriter=rewriter
@@ -154,3 +168,192 @@ def merge_func(m1: cirq.Moment, m2: cirq.Moment) -> cirq.Moment | None:
154168
deep=context.deep if context else False,
155169
tags_to_ignore=tuple(tags_to_ignore),
156170
).unfreeze(copy=False)
171+
172+
173+
def _all_tags_startswith(circuit: cirq.AbstractCircuit, startswith: str):
174+
tag_set: set[Hashable] = set()
175+
for op in circuit.all_operations():
176+
for tag in op.tags:
177+
if str(tag).startswith(startswith):
178+
tag_set.add(tag)
179+
return tag_set
180+
181+
182+
def _sweep_on_symbols(sweep: Sweep, symbols: set[sympy.Symbol]) -> Sweep:
183+
new_resolvers: List[cirq.ParamResolver] = []
184+
for resolver in sweep:
185+
param_dict: 'cirq.ParamMappingType' = {s: resolver.value_of(s) for s in symbols}
186+
new_resolvers.append(ParamResolver(param_dict))
187+
return ListSweep(new_resolvers)
188+
189+
190+
def _parameterize_phxz_in_circuits(
191+
circuit_list: List['cirq.Circuit'],
192+
merge_tag_prefix: str,
193+
phxz_symbols: set[sympy.Symbol],
194+
remaining_symbols: set[sympy.Symbol],
195+
sweep: Sweep,
196+
) -> Sweep:
197+
"""Parameterizes the circuits and returns a new sweep."""
198+
values_by_params: dict[str, List[float]] = {**{str(s): [] for s in phxz_symbols}}
199+
200+
for circuit in circuit_list:
201+
for op in circuit.all_operations():
202+
the_merge_tag: str | None = None
203+
for tag in op.tags:
204+
if str(tag).startswith(merge_tag_prefix):
205+
the_merge_tag = str(tag)
206+
if not the_merge_tag:
207+
continue
208+
sid = the_merge_tag.rsplit("_", maxsplit=-1)[-1]
209+
x, z, a = 0.0, 0.0, 0.0 # Identity gate's parameters
210+
if isinstance(op.gate, ops.PhasedXZGate):
211+
x, z, a = op.gate.x_exponent, op.gate.z_exponent, op.gate.axis_phase_exponent
212+
elif op.gate is not ops.I:
213+
raise RuntimeError(
214+
f"Expected the merged gate to be a PhasedXZGate or IdentityGate,"
215+
f" but got {op.gate}."
216+
)
217+
values_by_params[f"x{sid}"].append(x)
218+
values_by_params[f"z{sid}"].append(z)
219+
values_by_params[f"a{sid}"].append(a)
220+
221+
return Zip(
222+
dict_to_zip_sweep(cast(ProductOrZipSweepLike, values_by_params)),
223+
_sweep_on_symbols(sweep, remaining_symbols),
224+
)
225+
226+
227+
def merge_single_qubit_gates_to_phxz_symbolized(
228+
circuit: cirq.AbstractCircuit,
229+
*,
230+
context: cirq.TransformerContext | None = None,
231+
sweep: Sweep,
232+
atol: float = 1e-8,
233+
) -> Tuple[cirq.Circuit, Sweep]:
234+
"""Merges consecutive single qubit gates as PhasedXZ Gates. Symbolizes if any of
235+
the consecutive gates is symbolized.
236+
237+
Example:
238+
>>> q0, q1 = cirq.LineQubit.range(2)
239+
>>> c = cirq.Circuit(\
240+
cirq.X(q0),\
241+
cirq.CZ(q0,q1)**sympy.Symbol("cz_exp"),\
242+
cirq.Y(q0)**sympy.Symbol("y_exp"),\
243+
cirq.X(q0))
244+
>>> print(c)
245+
0: ───X───@──────────Y^y_exp───X───
246+
247+
1: ───────@^cz_exp─────────────────
248+
>>> new_circuit, new_sweep = cirq.merge_single_qubit_gates_to_phxz_symbolized(\
249+
c, sweep=cirq.Zip(cirq.Points(key="cz_exp", points=[0, 1]),\
250+
cirq.Points(key="y_exp", points=[0, 1])))
251+
>>> print(new_circuit)
252+
0: ───PhXZ(a=-1,x=1,z=0)───@──────────PhXZ(a=a0,x=x0,z=z0)───
253+
254+
1: ────────────────────────@^cz_exp──────────────────────────
255+
>>> assert new_sweep[0] == cirq.ParamResolver({'a0': -1, 'x0': 1, 'z0': 0, 'cz_exp': 0})
256+
>>> assert new_sweep[1] == cirq.ParamResolver({'a0': -0.5, 'x0': 0, 'z0': -1, 'cz_exp': 1})
257+
258+
Args:
259+
circuit: Input circuit to transform. It will not be modified.
260+
context: `cirq.TransformerContext` storing common configurable options for transformers.
261+
sweep: Sweep of the symbols in the input circuit, updated Sweep will be returned
262+
based on the transformation.
263+
atol: Absolute tolerance to angle error. Larger values allow more negligible gates to be
264+
dropped, smaller values increase accuracy.
265+
266+
Returns:
267+
Copy of the transformed input circuit.
268+
"""
269+
deep = context.deep if context else False
270+
271+
# Tag symbolized single-qubit op.
272+
symbolized_single_tag = "_tmp_symbolize_tag"
273+
274+
circuit_tagged = transformer_primitives.map_operations(
275+
circuit,
276+
lambda op, _: (
277+
op.with_tags(symbolized_single_tag)
278+
if protocols.is_parameterized(op) and len(op.qubits) == 1
279+
else op
280+
),
281+
deep=deep,
282+
)
283+
284+
# Step 0, isolate single qubit symbols and resolve the circuit on them.
285+
single_qubit_gate_symbols: set[sympy.Symbol] = set().union(
286+
*[
287+
protocols.parameter_symbols(op) if symbolized_single_tag in op.tags else set()
288+
for op in circuit_tagged.all_operations()
289+
]
290+
)
291+
# If all single qubit gates are not parameterized, call the nonparamerized version of
292+
# the transformer.
293+
if not single_qubit_gate_symbols:
294+
return (merge_single_qubit_gates_to_phxz(circuit, context=context, atol=atol), sweep)
295+
sweep_of_single: Sweep = _sweep_on_symbols(sweep, single_qubit_gate_symbols)
296+
# Get all resolved circuits from all sets of resolvers in sweep_of_single.
297+
resolved_circuits = [
298+
protocols.resolve_parameters(circuit_tagged, resolver) for resolver in sweep_of_single
299+
]
300+
301+
# Step 1, merge single qubit gates per resolved circuit, preserving
302+
# the symbolized_single_tag with indexes.
303+
merged_circuits: List['cirq.Circuit'] = []
304+
for resolved_circuit in resolved_circuits:
305+
merged_circuit = tag_transformers.index_tags(
306+
merge_single_qubit_gates_to_phxz(
307+
resolved_circuit,
308+
context=context,
309+
merge_tags_fn=lambda circuit_op: (
310+
[symbolized_single_tag]
311+
if any(
312+
symbolized_single_tag in set(op.tags)
313+
for op in circuit_op.circuit.all_operations()
314+
)
315+
else []
316+
),
317+
atol=atol,
318+
),
319+
context=transformer_api.TransformerContext(deep=deep),
320+
target_tags={symbolized_single_tag},
321+
)
322+
merged_circuits.append(merged_circuit)
323+
324+
if not all(
325+
_all_tags_startswith(merged_circuits[0], startswith=symbolized_single_tag)
326+
== _all_tags_startswith(merged_circuit, startswith=symbolized_single_tag)
327+
for merged_circuit in merged_circuits
328+
):
329+
raise RuntimeError("Different resolvers in sweep resulted in different merged structures.")
330+
331+
# Step 2, get the new symbolized circuit by symbolization on indexed symbolized_single_tag.
332+
new_circuit = align.align_right(
333+
tag_transformers.remove_tags(
334+
symbolize.symbolize_single_qubit_gates_by_indexed_tags(
335+
merged_circuits[0],
336+
symbolize_tag=symbolize.SymbolizeTag(prefix=symbolized_single_tag),
337+
),
338+
remove_if=lambda tag: str(tag).startswith(symbolized_single_tag),
339+
)
340+
)
341+
342+
# Step 3, get N sets of parameterizations as new_sweep.
343+
phxz_symbols: set[sympy.Symbol] = set().union(
344+
*[
345+
set(
346+
[sympy.Symbol(tag.replace(f"{symbolized_single_tag}_", s)) for s in ["x", "z", "a"]]
347+
)
348+
for tag in _all_tags_startswith(merged_circuits[0], startswith=symbolized_single_tag)
349+
]
350+
)
351+
# Remaining symbols, e.g., 2 qubit gates' symbols. Sweep of those symbols keeps unchanged.
352+
remaining_symbols: set[sympy.Symbol] = set(
353+
protocols.parameter_symbols(circuit) - single_qubit_gate_symbols
354+
)
355+
new_sweep = _parameterize_phxz_in_circuits(
356+
merged_circuits, symbolized_single_tag, phxz_symbols, remaining_symbols, sweep
357+
)
358+
359+
return new_circuit, new_sweep

0 commit comments

Comments
 (0)