16
16
17
17
from __future__ import annotations
18
18
19
- from typing import TYPE_CHECKING
19
+ from typing import Callable , cast , Hashable , TYPE_CHECKING
20
20
21
21
from cirq import circuits , ops , protocols
22
- from cirq .transformers import merge_k_qubit_gates , transformer_api , transformer_primitives
22
+ from cirq .study .resolver import ParamResolver
23
+ from cirq .study .sweeps import dict_to_zip_sweep , ListSweep , ProductOrZipSweepLike , Sweep , Zip
24
+ from cirq .transformers import (
25
+ align ,
26
+ merge_k_qubit_gates ,
27
+ symbolize ,
28
+ tag_transformers ,
29
+ transformer_api ,
30
+ transformer_primitives ,
31
+ )
23
32
from cirq .transformers .analytical_decompositions import single_qubit_decompositions
24
33
25
34
if TYPE_CHECKING :
35
+ import sympy
36
+
26
37
import cirq
27
38
28
39
@@ -67,6 +78,7 @@ def merge_single_qubit_gates_to_phxz(
67
78
circuit : cirq .AbstractCircuit ,
68
79
* ,
69
80
context : cirq .TransformerContext | None = None ,
81
+ merge_tags_fn : Callable [[cirq .CircuitOperation ], list [Hashable ]] | None = None ,
70
82
atol : float = 1e-8 ,
71
83
) -> cirq .Circuit :
72
84
"""Replaces runs of single qubit rotations with a single optional `cirq.PhasedXZGate`.
@@ -77,19 +89,21 @@ def merge_single_qubit_gates_to_phxz(
77
89
Args:
78
90
circuit: Input circuit to transform. It will not be modified.
79
91
context: `cirq.TransformerContext` storing common configurable options for transformers.
92
+ merge_tags_fn: A callable returns the tags to be added to the merged operation.
80
93
atol: Absolute tolerance to angle error. Larger values allow more negligible gates to be
81
94
dropped, smaller values increase accuracy.
82
95
83
96
Returns:
84
97
Copy of the transformed input circuit.
85
98
"""
86
99
87
- def rewriter (op : cirq .CircuitOperation ) -> cirq .OP_TREE :
88
- u = protocols .unitary (op )
89
- if protocols .num_qubits (op ) == 0 :
100
+ def rewriter (circuit_op : cirq .CircuitOperation ) -> cirq .OP_TREE :
101
+ u = protocols .unitary (circuit_op )
102
+ if protocols .num_qubits (circuit_op ) == 0 :
90
103
return ops .GlobalPhaseGate (u [0 , 0 ]).on ()
91
- gate = single_qubit_decompositions .single_qubit_matrix_to_phxz (u , atol )
92
- return gate (op .qubits [0 ]) if gate else []
104
+ gate = single_qubit_decompositions .single_qubit_matrix_to_phxz (u , atol ) or ops .I
105
+ phxz_op = gate .on (circuit_op .qubits [0 ])
106
+ return phxz_op .with_tags (* merge_tags_fn (circuit_op )) if merge_tags_fn else phxz_op
93
107
94
108
return merge_k_qubit_gates .merge_k_qubit_unitaries (
95
109
circuit , k = 1 , context = context , rewriter = rewriter
@@ -158,3 +172,160 @@ def merge_func(m1: cirq.Moment, m2: cirq.Moment) -> cirq.Moment | None:
158
172
deep = context .deep if context else False ,
159
173
tags_to_ignore = tuple (tags_to_ignore ),
160
174
).unfreeze (copy = False )
175
+
176
+
177
+ def _sweep_on_symbols (sweep : Sweep , symbols : set [sympy .Symbol ]) -> Sweep :
178
+ new_resolvers : list [cirq .ParamResolver ] = []
179
+ for resolver in sweep :
180
+ param_dict : cirq .ParamMappingType = {s : resolver .value_of (s ) for s in symbols }
181
+ new_resolvers .append (ParamResolver (param_dict ))
182
+ return ListSweep (new_resolvers )
183
+
184
+
185
+ def _calc_phxz_sweeps (
186
+ symbolized_circuit : cirq .Circuit , resolved_circuits : list [cirq .Circuit ]
187
+ ) -> Sweep :
188
+ """Return the phxz sweep of the symbolized_circuit on resolved_circuits.
189
+
190
+ Raises:
191
+ ValueError: Structural mismatch: A `resolved_circuit` contains an unexpected gate type.
192
+ Expected a `PhasedXZGate` or `IdentityGate` at a position corresponding to a
193
+ symbolic `PhasedXZGate` in the `symbolized_circuit`.
194
+ """
195
+
196
+ def _extract_axz (op : ops .Operation | None ) -> tuple [float , float , float ]:
197
+ if not op or not op .gate or not isinstance (op .gate , ops .IdentityGate | ops .PhasedXZGate ):
198
+ raise ValueError (f"Expect a PhasedXZGate or IdentityGate on op { op } ." )
199
+ if isinstance (op .gate , ops .IdentityGate ):
200
+ return 0.0 , 0.0 , 0.0 # Identity gate's a, x, z in PhasedXZ
201
+ return op .gate .axis_phase_exponent , op .gate .x_exponent , op .gate .z_exponent
202
+
203
+ values_by_params : dict [sympy .Symbol , tuple [float , ...]] = {}
204
+ for mid , moment in enumerate (symbolized_circuit ):
205
+ for op in moment .operations :
206
+ if op .gate and isinstance (op .gate , ops .PhasedXZGate ) and protocols .is_parameterized (op ):
207
+ sa , sx , sz = op .gate .axis_phase_exponent , op .gate .x_exponent , op .gate .z_exponent
208
+ values_by_params [sa ], values_by_params [sx ], values_by_params [sz ] = zip (
209
+ * [_extract_axz (c [mid ].operation_at (op .qubits [0 ])) for c in resolved_circuits ]
210
+ )
211
+
212
+ return dict_to_zip_sweep (cast (ProductOrZipSweepLike , values_by_params ))
213
+
214
+
215
+ def merge_single_qubit_gates_to_phxz_symbolized (
216
+ circuit : cirq .AbstractCircuit ,
217
+ * ,
218
+ context : cirq .TransformerContext | None = None ,
219
+ sweep : Sweep ,
220
+ atol : float = 1e-8 ,
221
+ ) -> tuple [cirq .Circuit , Sweep ]:
222
+ """Merges consecutive single qubit gates as PhasedXZ Gates. Symbolizes if any of
223
+ the consecutive gates is symbolized.
224
+
225
+ Example:
226
+ >>> q0, q1 = cirq.LineQubit.range(2)
227
+ >>> c = cirq.Circuit(\
228
+ cirq.X(q0),\
229
+ cirq.CZ(q0,q1)**sympy.Symbol("cz_exp"),\
230
+ cirq.Y(q0)**sympy.Symbol("y_exp"),\
231
+ cirq.X(q0))
232
+ >>> print(c)
233
+ 0: ───X───@──────────Y^y_exp───X───
234
+ │
235
+ 1: ───────@^cz_exp─────────────────
236
+ >>> new_circuit, new_sweep = cirq.merge_single_qubit_gates_to_phxz_symbolized(\
237
+ c, sweep=cirq.Zip(cirq.Points(key="cz_exp", points=[0, 1]),\
238
+ cirq.Points(key="y_exp", points=[0, 1])))
239
+ >>> print(new_circuit)
240
+ 0: ───PhXZ(a=-1,x=1,z=0)───@──────────PhXZ(a=a0,x=x0,z=z0)───
241
+ │
242
+ 1: ────────────────────────@^cz_exp──────────────────────────
243
+ >>> assert new_sweep[0] == cirq.ParamResolver({'a0': -1, 'x0': 1, 'z0': 0, 'cz_exp': 0})
244
+ >>> assert new_sweep[1] == cirq.ParamResolver({'a0': -0.5, 'x0': 0, 'z0': -1, 'cz_exp': 1})
245
+
246
+ Args:
247
+ circuit: Input circuit to transform. It will not be modified.
248
+ context: `cirq.TransformerContext` storing common configurable options for transformers.
249
+ sweep: Sweep of the symbols in the input circuit, updated Sweep will be returned
250
+ based on the transformation.
251
+ atol: Absolute tolerance to angle error. Larger values allow more negligible gates to be
252
+ dropped, smaller values increase accuracy.
253
+
254
+ Returns:
255
+ Copy of the transformed input circuit.
256
+ """
257
+ deep = context .deep if context else False
258
+
259
+ # Tag symbolized single-qubit op.
260
+ symbolized_single_tag = "_tmp_symbolize_tag"
261
+
262
+ circuit_tagged = transformer_primitives .map_operations (
263
+ circuit ,
264
+ lambda op , _ : (
265
+ op .with_tags (symbolized_single_tag )
266
+ if protocols .is_parameterized (op ) and len (op .qubits ) == 1
267
+ else op
268
+ ),
269
+ deep = deep ,
270
+ )
271
+
272
+ # Step 0, isolate single qubit symbols and resolve the circuit on them.
273
+ single_qubit_gate_symbols : set [sympy .Symbol ] = set ().union (
274
+ * [
275
+ protocols .parameter_symbols (op ) if symbolized_single_tag in op .tags else set ()
276
+ for op in circuit_tagged .all_operations ()
277
+ ]
278
+ )
279
+ # Remaining symbols, e.g., 2 qubit gates' symbols. Sweep of those symbols keeps unchanged.
280
+ remaining_symbols : set [sympy .Symbol ] = set (
281
+ protocols .parameter_symbols (circuit ) - single_qubit_gate_symbols
282
+ )
283
+ # If all single qubit gates are not parameterized, call the nonparamerized version of
284
+ # the transformer.
285
+ if not single_qubit_gate_symbols :
286
+ return (merge_single_qubit_gates_to_phxz (circuit , context = context , atol = atol ), sweep )
287
+ sweep_of_single : Sweep = _sweep_on_symbols (sweep , single_qubit_gate_symbols )
288
+ # Get all resolved circuits from all sets of resolvers in sweep_of_single.
289
+ resolved_circuits = [
290
+ protocols .resolve_parameters (circuit_tagged , resolver ) for resolver in sweep_of_single
291
+ ]
292
+
293
+ # Step 1, merge single qubit gates per resolved circuit, preserving
294
+ # the symbolized_single_tag to indicate the operator is a merged one.
295
+ merged_circuits : list [cirq .Circuit ] = [
296
+ merge_single_qubit_gates_to_phxz (
297
+ c ,
298
+ context = context ,
299
+ merge_tags_fn = lambda circuit_op : (
300
+ [symbolized_single_tag ]
301
+ if any (
302
+ symbolized_single_tag in set (op .tags )
303
+ for op in circuit_op .circuit .all_operations ()
304
+ )
305
+ else []
306
+ ),
307
+ atol = atol ,
308
+ )
309
+ for c in resolved_circuits
310
+ ]
311
+
312
+ # Step 2, get the new symbolized circuit by symbolizing on indexed symbolized_single_tag.
313
+ new_circuit = tag_transformers .remove_tags ( # remove the temp tags used to track merges
314
+ symbolize .symbolize_single_qubit_gates_by_indexed_tags (
315
+ tag_transformers .index_tags ( # index all 1-qubit-ops merged from ops with symbols
316
+ merged_circuits [0 ],
317
+ context = transformer_api .TransformerContext (deep = deep ),
318
+ target_tags = {symbolized_single_tag },
319
+ ),
320
+ symbolize_tag = symbolize .SymbolizeTag (prefix = symbolized_single_tag ),
321
+ ),
322
+ remove_if = lambda tag : str (tag ).startswith (symbolized_single_tag ),
323
+ )
324
+
325
+ # Step 3, get N sets of parameterizations as new_sweep.
326
+ new_sweep = Zip (
327
+ _calc_phxz_sweeps (new_circuit , merged_circuits ), # phxz sweeps
328
+ _sweep_on_symbols (sweep , remaining_symbols ), # remaining sweeps
329
+ )
330
+
331
+ return align .align_right (new_circuit ), new_sweep
0 commit comments