Skip to content

Commit c283401

Browse files
committed
Refactor
1 parent cc566ee commit c283401

File tree

3 files changed

+80
-114
lines changed

3 files changed

+80
-114
lines changed

cirq-core/cirq/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,7 @@
380380
merge_operations_to_circuit_op as merge_operations_to_circuit_op,
381381
merge_single_qubit_gates_to_phased_x_and_z as merge_single_qubit_gates_to_phased_x_and_z,
382382
merge_single_qubit_gates_to_phxz as merge_single_qubit_gates_to_phxz,
383+
merge_single_qubit_gates_to_phxz_symbolized as merge_single_qubit_gates_to_phxz_symbolized,
383384
merge_single_qubit_moments_to_phxz as merge_single_qubit_moments_to_phxz,
384385
optimize_for_target_gateset as optimize_for_target_gateset,
385386
parameterized_2q_op_to_sqrt_iswap_operations as parameterized_2q_op_to_sqrt_iswap_operations,

cirq-core/cirq/transformers/merge_single_qubit_gates.py

Lines changed: 65 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -16,25 +16,24 @@
1616

1717
from __future__ import annotations
1818

19-
from typing import Callable, cast, Hashable, List, Tuple, TYPE_CHECKING
20-
21-
import sympy
19+
from typing import Callable, cast, Hashable, TYPE_CHECKING
2220

2321
from cirq import circuits, ops, protocols
2422
from cirq.study.resolver import ParamResolver
2523
from cirq.study.sweeps import dict_to_zip_sweep, ListSweep, ProductOrZipSweepLike, Sweep, Zip
2624
from cirq.transformers import (
2725
align,
2826
merge_k_qubit_gates,
29-
transformer_api,
30-
transformer_primitives,
3127
symbolize,
3228
tag_transformers,
29+
transformer_api,
30+
transformer_primitives,
3331
)
3432
from cirq.transformers.analytical_decompositions import single_qubit_decompositions
3533

3634
if TYPE_CHECKING:
3735
import cirq
36+
import sympy
3837

3938

4039
@transformer_api.transformer
@@ -78,9 +77,9 @@ def merge_single_qubit_gates_to_phxz(
7877
circuit: cirq.AbstractCircuit,
7978
*,
8079
context: cirq.TransformerContext | None = None,
81-
merge_tags_fn: Callable[[cirq.CircuitOperation], List[Hashable]] | None = None,
80+
merge_tags_fn: Callable[[cirq.CircuitOperation], list[Hashable]] | None = None,
8281
atol: float = 1e-8,
83-
) -> 'cirq.Circuit':
82+
) -> cirq.Circuit:
8483
"""Replaces runs of single qubit rotations with a single optional `cirq.PhasedXZGate`.
8584
8685
Specifically, any run of non-parameterized single-qubit unitaries will be replaced by an
@@ -97,7 +96,7 @@ def merge_single_qubit_gates_to_phxz(
9796
Copy of the transformed input circuit.
9897
"""
9998

100-
def rewriter(circuit_op: 'cirq.CircuitOperation') -> 'cirq.OP_TREE':
99+
def rewriter(circuit_op: cirq.CircuitOperation) -> cirq.OP_TREE:
101100
u = protocols.unitary(circuit_op)
102101
if protocols.num_qubits(circuit_op) == 0:
103102
return ops.GlobalPhaseGate(u[0, 0]).on()
@@ -170,58 +169,43 @@ def merge_func(m1: cirq.Moment, m2: cirq.Moment) -> cirq.Moment | None:
170169
).unfreeze(copy=False)
171170

172171

173-
def _all_tags_startswith(circuit: cirq.AbstractCircuit, startswith: str):
174-
tag_set: set[Hashable] = set()
175-
for op in circuit.all_operations():
176-
for tag in op.tags:
177-
if str(tag).startswith(startswith):
178-
tag_set.add(tag)
179-
return tag_set
180-
181-
182172
def _sweep_on_symbols(sweep: Sweep, symbols: set[sympy.Symbol]) -> Sweep:
183-
new_resolvers: List[cirq.ParamResolver] = []
173+
new_resolvers: list[cirq.ParamResolver] = []
184174
for resolver in sweep:
185-
param_dict: 'cirq.ParamMappingType' = {s: resolver.value_of(s) for s in symbols}
175+
param_dict: cirq.ParamMappingType = {s: resolver.value_of(s) for s in symbols}
186176
new_resolvers.append(ParamResolver(param_dict))
187177
return ListSweep(new_resolvers)
188178

189179

190-
def _parameterize_phxz_in_circuits(
191-
circuit_list: List['cirq.Circuit'],
192-
merge_tag_prefix: str,
193-
phxz_symbols: set[sympy.Symbol],
194-
remaining_symbols: set[sympy.Symbol],
195-
sweep: Sweep,
180+
def _calc_phxz_sweeps(
181+
symbolized_circuit: cirq.Circuit, resolved_circuits: list[cirq.Circuit]
196182
) -> Sweep:
197-
"""Parameterizes the circuits and returns a new sweep."""
198-
values_by_params: dict[str, List[float]] = {**{str(s): [] for s in phxz_symbols}}
199-
200-
for circuit in circuit_list:
201-
for op in circuit.all_operations():
202-
the_merge_tag: str | None = None
203-
for tag in op.tags:
204-
if str(tag).startswith(merge_tag_prefix):
205-
the_merge_tag = str(tag)
206-
if not the_merge_tag:
207-
continue
208-
sid = the_merge_tag.rsplit("_", maxsplit=-1)[-1]
209-
x, z, a = 0.0, 0.0, 0.0 # Identity gate's parameters
210-
if isinstance(op.gate, ops.PhasedXZGate):
211-
x, z, a = op.gate.x_exponent, op.gate.z_exponent, op.gate.axis_phase_exponent
212-
elif op.gate is not ops.I:
213-
raise RuntimeError(
214-
f"Expected the merged gate to be a PhasedXZGate or IdentityGate,"
215-
f" but got {op.gate}."
183+
"""Return the phxz sweep of the symbolized_circuit on resolved_circuits.
184+
185+
Raises:
186+
ValueError: Structural mismatch: A `resolved_circuit` contains an unexpected gate type.
187+
Expected a `PhasedXZGate` or `IdentityGate` at a position corresponding to a
188+
symbolic `PhasedXZGate` in the `symbolized_circuit`.
189+
"""
190+
191+
def _extract_axz(op: ops.Operation) -> tuple[float, float, float]:
192+
if not op.gate or not isinstance(op.gate, ops.IdentityGate | ops.PhasedXZGate):
193+
raise ValueError(f"Expect a PhasedXZGate or IdentityGate on op {op}.")
194+
if isinstance(op.gate, ops.IdentityGate):
195+
return 0.0, 0.0, 0.0 # Identity gate's a, x, z in PhasedXZ
196+
phxz = cast(ops.PhasedXZGate, op.gate)
197+
return phxz.axis_phase_exponent, phxz.x_exponent, phxz.z_exponent
198+
199+
values_by_params: dict[str, list[float]] = {}
200+
for mid, moment in enumerate(symbolized_circuit):
201+
for op in moment.operations:
202+
if op.gate and isinstance(op.gate, ops.PhasedXZGate) and protocols.is_parameterized(op):
203+
sa, sx, sz = op.gate.axis_phase_exponent, op.gate.x_exponent, op.gate.z_exponent
204+
values_by_params[sa], values_by_params[sx], values_by_params[sz] = zip(
205+
*[_extract_axz(c[mid].operation_at(op.qubits[0])) for c in resolved_circuits]
216206
)
217-
values_by_params[f"x{sid}"].append(x)
218-
values_by_params[f"z{sid}"].append(z)
219-
values_by_params[f"a{sid}"].append(a)
220207

221-
return Zip(
222-
dict_to_zip_sweep(cast(ProductOrZipSweepLike, values_by_params)),
223-
_sweep_on_symbols(sweep, remaining_symbols),
224-
)
208+
return dict_to_zip_sweep(cast(ProductOrZipSweepLike, values_by_params))
225209

226210

227211
def merge_single_qubit_gates_to_phxz_symbolized(
@@ -230,7 +214,7 @@ def merge_single_qubit_gates_to_phxz_symbolized(
230214
context: cirq.TransformerContext | None = None,
231215
sweep: Sweep,
232216
atol: float = 1e-8,
233-
) -> Tuple[cirq.Circuit, Sweep]:
217+
) -> tuple[cirq.Circuit, Sweep]:
234218
"""Merges consecutive single qubit gates as PhasedXZ Gates. Symbolizes if any of
235219
the consecutive gates is symbolized.
236220
@@ -288,6 +272,10 @@ def merge_single_qubit_gates_to_phxz_symbolized(
288272
for op in circuit_tagged.all_operations()
289273
]
290274
)
275+
# Remaining symbols, e.g., 2 qubit gates' symbols. Sweep of those symbols keeps unchanged.
276+
remaining_symbols: set[sympy.Symbol] = set(
277+
protocols.parameter_symbols(circuit) - single_qubit_gate_symbols
278+
)
291279
# If all single qubit gates are not parameterized, call the nonparamerized version of
292280
# the transformer.
293281
if not single_qubit_gate_symbols:
@@ -299,61 +287,43 @@ def merge_single_qubit_gates_to_phxz_symbolized(
299287
]
300288

301289
# Step 1, merge single qubit gates per resolved circuit, preserving
302-
# the symbolized_single_tag with indexes.
303-
merged_circuits: List['cirq.Circuit'] = []
304-
for resolved_circuit in resolved_circuits:
305-
merged_circuit = tag_transformers.index_tags(
306-
merge_single_qubit_gates_to_phxz(
307-
resolved_circuit,
308-
context=context,
309-
merge_tags_fn=lambda circuit_op: (
310-
[symbolized_single_tag]
311-
if any(
312-
symbolized_single_tag in set(op.tags)
313-
for op in circuit_op.circuit.all_operations()
314-
)
315-
else []
316-
),
317-
atol=atol,
290+
# the symbolized_single_tag to indicate the operator is a merged one.
291+
merged_circuits: list[cirq.Circuit] = [
292+
merge_single_qubit_gates_to_phxz(
293+
c,
294+
context=context,
295+
merge_tags_fn=lambda circuit_op: (
296+
[symbolized_single_tag]
297+
if any(
298+
symbolized_single_tag in set(op.tags)
299+
for op in circuit_op.circuit.all_operations()
300+
)
301+
else []
318302
),
319-
context=transformer_api.TransformerContext(deep=deep),
320-
target_tags={symbolized_single_tag},
303+
atol=atol,
321304
)
322-
merged_circuits.append(merged_circuit)
323-
324-
if not all(
325-
_all_tags_startswith(merged_circuits[0], startswith=symbolized_single_tag)
326-
== _all_tags_startswith(merged_circuit, startswith=symbolized_single_tag)
327-
for merged_circuit in merged_circuits
328-
):
329-
raise RuntimeError("Different resolvers in sweep resulted in different merged structures.")
305+
for c in resolved_circuits
306+
]
330307

331-
# Step 2, get the new symbolized circuit by symbolization on indexed symbolized_single_tag.
308+
# Step 2, get the new symbolized circuit by symbolizing on indexed symbolized_single_tag.
332309
new_circuit = align.align_right(
333-
tag_transformers.remove_tags(
310+
tag_transformers.remove_tags( # remove the temp tags used to track merges
334311
symbolize.symbolize_single_qubit_gates_by_indexed_tags(
335-
merged_circuits[0],
312+
tag_transformers.index_tags( # index all single qubit ops merged from ops with symbols
313+
merged_circuits[0],
314+
context=transformer_api.TransformerContext(deep=deep),
315+
target_tags={symbolized_single_tag},
316+
),
336317
symbolize_tag=symbolize.SymbolizeTag(prefix=symbolized_single_tag),
337318
),
338319
remove_if=lambda tag: str(tag).startswith(symbolized_single_tag),
339320
)
340321
)
341322

342323
# Step 3, get N sets of parameterizations as new_sweep.
343-
phxz_symbols: set[sympy.Symbol] = set().union(
344-
*[
345-
set(
346-
[sympy.Symbol(tag.replace(f"{symbolized_single_tag}_", s)) for s in ["x", "z", "a"]]
347-
)
348-
for tag in _all_tags_startswith(merged_circuits[0], startswith=symbolized_single_tag)
349-
]
350-
)
351-
# Remaining symbols, e.g., 2 qubit gates' symbols. Sweep of those symbols keeps unchanged.
352-
remaining_symbols: set[sympy.Symbol] = set(
353-
protocols.parameter_symbols(circuit) - single_qubit_gate_symbols
354-
)
355-
new_sweep = _parameterize_phxz_in_circuits(
356-
merged_circuits, symbolized_single_tag, phxz_symbols, remaining_symbols, sweep
324+
new_sweep = Zip(
325+
_calc_phxz_sweeps(new_circuit, merged_circuits), # phxz sweeps
326+
_sweep_on_symbols(sweep, remaining_symbols), # remaining sweeps
357327
)
358328

359329
return new_circuit, new_sweep

cirq-core/cirq/transformers/merge_single_qubit_gates_test.py

Lines changed: 14 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@
1313
# limitations under the License.
1414

1515
from typing import List
16-
from unittest.mock import Mock, patch
1716
from unittest import TestCase
17+
from unittest.mock import Mock, patch
1818

1919
import pytest
2020
import sympy
@@ -24,7 +24,7 @@
2424

2525
def assert_optimizes(optimized: cirq.AbstractCircuit, expected: cirq.AbstractCircuit):
2626
# Ignore differences that would be caught by follow-up optimizations.
27-
followup_transformers: List[cirq.TRANSFORMER] = [
27+
followup_transformers: list[cirq.TRANSFORMER] = [
2828
cirq.drop_negligible_operations,
2929
cirq.drop_empty_moments,
3030
]
@@ -241,7 +241,7 @@ def test_merge_single_qubit_gates_to_phased_x_and_z_global_phase():
241241
class TestMergeSingleQubitGatesSymbolized(TestCase):
242242
"""Test suite for merge_single_qubit_gates_to_phxz_symbolized."""
243243

244-
def case1(self):
244+
def test_case1(self):
245245
"""Test case diagram.
246246
Input circuit:
247247
# pylint: disable=line-too-long
@@ -298,7 +298,7 @@ def case1(self):
298298
{q: q for q in input_circuit.all_qubits()},
299299
)
300300

301-
def case_non_parameterized_singles(self):
301+
def test_case_non_parameterized_singles(self):
302302
"""Test merge_single_qubit_gates_to_phxz_symbolized when all single qubit gates are not
303303
parameterized."""
304304

@@ -310,27 +310,24 @@ def case_non_parameterized_singles(self):
310310
)
311311
assert_optimizes(output_circuit, expected_circuit)
312312

313-
def fail_different_structures_error(self):
314-
"""Tests that the function raises a RuntimeError if merged structures of the circuit differ
313+
def test_fail_different_structures_error(self):
314+
"""Tests that the function raises a ValueError if merged structures of the circuit differ
315315
for different parameterizations."""
316-
a = cirq.NamedQubit("a")
317-
circuit = cirq.Circuit(cirq.H(a) ** sympy.Symbol("exp"))
316+
q0, q1 = cirq.LineQubit.range(2)
317+
circuit = cirq.Circuit(cirq.H(q0) ** sympy.Symbol("exp"))
318318
sweep = cirq.Points(key="exp", points=[0.1, 0.2])
319319

320320
with patch(
321321
"cirq.protocols.resolve_parameters",
322-
side_effect=[
323-
cirq.Circuit(cirq.H(a).with_tags("_temp_symbolize_tag")),
324-
cirq.Circuit(cirq.H(a)),
322+
side_effect=[ # Mock the return values of resolve_parameters
323+
cirq.Circuit(cirq.I(q0).with_tags("_tmp_symbolize_tag")),
324+
cirq.Circuit(cirq.CZ(q0, q1)),
325325
],
326326
):
327-
with pytest.raises(
328-
RuntimeError,
329-
match="Different resolvers in sweep resulted in different merged structures.",
330-
):
327+
with pytest.raises(ValueError, match="Expect a PhasedXZGate or IdentityGate.*"):
331328
cirq.merge_single_qubit_gates_to_phxz_symbolized(circuit, sweep=sweep)
332329

333-
def fail_unexpected_gate_error(self):
330+
def test_fail_unexpected_gate_error(self):
334331
"""Tests that the function raises a RuntimeError of unexpected gate."""
335332
a, b = cirq.LineQubit.range(2)
336333
circuit = cirq.Circuit(
@@ -352,7 +349,5 @@ def fail_unexpected_gate_error(self):
352349
".single_qubit_decompositions.single_qubit_matrix_to_phxz",
353350
return_value=cirq.H,
354351
):
355-
with pytest.raises(
356-
RuntimeError, match="Expected the merged gate to be a PhasedXZGate or IdentityGate."
357-
):
352+
with pytest.raises(ValueError, match="Expect a PhasedXZGate or IdentityGate.*"):
358353
cirq.merge_single_qubit_gates_to_phxz_symbolized(circuit, sweep=sweep)

0 commit comments

Comments
 (0)