16
16
17
17
from __future__ import annotations
18
18
19
- from typing import TYPE_CHECKING
19
+ from typing import Callable , cast , Hashable , List , Tuple , TYPE_CHECKING
20
+
21
+ import sympy
20
22
21
23
from cirq import circuits , ops , protocols
22
- from cirq .transformers import merge_k_qubit_gates , transformer_api , transformer_primitives
24
+ from cirq .study .resolver import ParamResolver
25
+ from cirq .study .sweeps import dict_to_zip_sweep , ListSweep , ProductOrZipSweepLike , Sweep , Zip
26
+ from cirq .transformers import (
27
+ align ,
28
+ merge_k_qubit_gates ,
29
+ transformer_api ,
30
+ transformer_primitives ,
31
+ symbolize ,
32
+ tag_transformers ,
33
+ )
23
34
from cirq .transformers .analytical_decompositions import single_qubit_decompositions
24
35
25
36
if TYPE_CHECKING :
@@ -67,8 +78,9 @@ def merge_single_qubit_gates_to_phxz(
67
78
circuit : cirq .AbstractCircuit ,
68
79
* ,
69
80
context : cirq .TransformerContext | None = None ,
81
+ merge_tags_fn : Callable [[cirq .CircuitOperation ], List [Hashable ]] | None = None ,
70
82
atol : float = 1e-8 ,
71
- ) -> cirq .Circuit :
83
+ ) -> ' cirq.Circuit' :
72
84
"""Replaces runs of single qubit rotations with a single optional `cirq.PhasedXZGate`.
73
85
74
86
Specifically, any run of non-parameterized single-qubit unitaries will be replaced by an
@@ -77,19 +89,21 @@ def merge_single_qubit_gates_to_phxz(
77
89
Args:
78
90
circuit: Input circuit to transform. It will not be modified.
79
91
context: `cirq.TransformerContext` storing common configurable options for transformers.
92
+ merge_tags_fn: A callable returns the tags to be added to the merged operation.
80
93
atol: Absolute tolerance to angle error. Larger values allow more negligible gates to be
81
94
dropped, smaller values increase accuracy.
82
95
83
96
Returns:
84
97
Copy of the transformed input circuit.
85
98
"""
86
99
87
- def rewriter (op : cirq .CircuitOperation ) -> cirq .OP_TREE :
88
- u = protocols .unitary (op )
89
- if protocols .num_qubits (op ) == 0 :
100
+ def rewriter (circuit_op : ' cirq.CircuitOperation' ) -> ' cirq.OP_TREE' :
101
+ u = protocols .unitary (circuit_op )
102
+ if protocols .num_qubits (circuit_op ) == 0 :
90
103
return ops .GlobalPhaseGate (u [0 , 0 ]).on ()
91
- gate = single_qubit_decompositions .single_qubit_matrix_to_phxz (u , atol )
92
- return gate (op .qubits [0 ]) if gate else []
104
+ gate = single_qubit_decompositions .single_qubit_matrix_to_phxz (u , atol ) or ops .I
105
+ phxz_op = gate .on (circuit_op .qubits [0 ])
106
+ return phxz_op .with_tags (* merge_tags_fn (circuit_op )) if merge_tags_fn else phxz_op
93
107
94
108
return merge_k_qubit_gates .merge_k_qubit_unitaries (
95
109
circuit , k = 1 , context = context , rewriter = rewriter
@@ -154,3 +168,192 @@ def merge_func(m1: cirq.Moment, m2: cirq.Moment) -> cirq.Moment | None:
154
168
deep = context .deep if context else False ,
155
169
tags_to_ignore = tuple (tags_to_ignore ),
156
170
).unfreeze (copy = False )
171
+
172
+
173
+ def _all_tags_startswith (circuit : cirq .AbstractCircuit , startswith : str ):
174
+ tag_set : set [Hashable ] = set ()
175
+ for op in circuit .all_operations ():
176
+ for tag in op .tags :
177
+ if str (tag ).startswith (startswith ):
178
+ tag_set .add (tag )
179
+ return tag_set
180
+
181
+
182
+ def _sweep_on_symbols (sweep : Sweep , symbols : set [sympy .Symbol ]) -> Sweep :
183
+ new_resolvers : List [cirq .ParamResolver ] = []
184
+ for resolver in sweep :
185
+ param_dict : 'cirq.ParamMappingType' = {s : resolver .value_of (s ) for s in symbols }
186
+ new_resolvers .append (ParamResolver (param_dict ))
187
+ return ListSweep (new_resolvers )
188
+
189
+
190
+ def _parameterize_phxz_in_circuits (
191
+ circuit_list : List ['cirq.Circuit' ],
192
+ merge_tag_prefix : str ,
193
+ phxz_symbols : set [sympy .Symbol ],
194
+ remaining_symbols : set [sympy .Symbol ],
195
+ sweep : Sweep ,
196
+ ) -> Sweep :
197
+ """Parameterizes the circuits and returns a new sweep."""
198
+ values_by_params : dict [str , List [float ]] = {** {str (s ): [] for s in phxz_symbols }}
199
+
200
+ for circuit in circuit_list :
201
+ for op in circuit .all_operations ():
202
+ the_merge_tag : str | None = None
203
+ for tag in op .tags :
204
+ if str (tag ).startswith (merge_tag_prefix ):
205
+ the_merge_tag = str (tag )
206
+ if not the_merge_tag :
207
+ continue
208
+ sid = the_merge_tag .rsplit ("_" , maxsplit = - 1 )[- 1 ]
209
+ x , z , a = 0.0 , 0.0 , 0.0 # Identity gate's parameters
210
+ if isinstance (op .gate , ops .PhasedXZGate ):
211
+ x , z , a = op .gate .x_exponent , op .gate .z_exponent , op .gate .axis_phase_exponent
212
+ elif op .gate is not ops .I :
213
+ raise RuntimeError (
214
+ f"Expected the merged gate to be a PhasedXZGate or IdentityGate,"
215
+ f" but got { op .gate } ."
216
+ )
217
+ values_by_params [f"x{ sid } " ].append (x )
218
+ values_by_params [f"z{ sid } " ].append (z )
219
+ values_by_params [f"a{ sid } " ].append (a )
220
+
221
+ return Zip (
222
+ dict_to_zip_sweep (cast (ProductOrZipSweepLike , values_by_params )),
223
+ _sweep_on_symbols (sweep , remaining_symbols ),
224
+ )
225
+
226
+
227
+ def merge_single_qubit_gates_to_phxz_symbolized (
228
+ circuit : cirq .AbstractCircuit ,
229
+ * ,
230
+ context : cirq .TransformerContext | None = None ,
231
+ sweep : Sweep ,
232
+ atol : float = 1e-8 ,
233
+ ) -> Tuple [cirq .Circuit , Sweep ]:
234
+ """Merges consecutive single qubit gates as PhasedXZ Gates. Symbolizes if any of
235
+ the consecutive gates is symbolized.
236
+
237
+ Example:
238
+ >>> q0, q1 = cirq.LineQubit.range(2)
239
+ >>> c = cirq.Circuit(\
240
+ cirq.X(q0),\
241
+ cirq.CZ(q0,q1)**sympy.Symbol("cz_exp"),\
242
+ cirq.Y(q0)**sympy.Symbol("y_exp"),\
243
+ cirq.X(q0))
244
+ >>> print(c)
245
+ 0: ───X───@──────────Y^y_exp───X───
246
+ │
247
+ 1: ───────@^cz_exp─────────────────
248
+ >>> new_circuit, new_sweep = cirq.merge_single_qubit_gates_to_phxz_symbolized(\
249
+ c, sweep=cirq.Zip(cirq.Points(key="cz_exp", points=[0, 1]),\
250
+ cirq.Points(key="y_exp", points=[0, 1])))
251
+ >>> print(new_circuit)
252
+ 0: ───PhXZ(a=-1,x=1,z=0)───@──────────PhXZ(a=a0,x=x0,z=z0)───
253
+ │
254
+ 1: ────────────────────────@^cz_exp──────────────────────────
255
+ >>> assert new_sweep[0] == cirq.ParamResolver({'a0': -1, 'x0': 1, 'z0': 0, 'cz_exp': 0})
256
+ >>> assert new_sweep[1] == cirq.ParamResolver({'a0': -0.5, 'x0': 0, 'z0': -1, 'cz_exp': 1})
257
+
258
+ Args:
259
+ circuit: Input circuit to transform. It will not be modified.
260
+ context: `cirq.TransformerContext` storing common configurable options for transformers.
261
+ sweep: Sweep of the symbols in the input circuit, updated Sweep will be returned
262
+ based on the transformation.
263
+ atol: Absolute tolerance to angle error. Larger values allow more negligible gates to be
264
+ dropped, smaller values increase accuracy.
265
+
266
+ Returns:
267
+ Copy of the transformed input circuit.
268
+ """
269
+ deep = context .deep if context else False
270
+
271
+ # Tag symbolized single-qubit op.
272
+ symbolized_single_tag = "_tmp_symbolize_tag"
273
+
274
+ circuit_tagged = transformer_primitives .map_operations (
275
+ circuit ,
276
+ lambda op , _ : (
277
+ op .with_tags (symbolized_single_tag )
278
+ if protocols .is_parameterized (op ) and len (op .qubits ) == 1
279
+ else op
280
+ ),
281
+ deep = deep ,
282
+ )
283
+
284
+ # Step 0, isolate single qubit symbols and resolve the circuit on them.
285
+ single_qubit_gate_symbols : set [sympy .Symbol ] = set ().union (
286
+ * [
287
+ protocols .parameter_symbols (op ) if symbolized_single_tag in op .tags else set ()
288
+ for op in circuit_tagged .all_operations ()
289
+ ]
290
+ )
291
+ # If all single qubit gates are not parameterized, call the nonparamerized version of
292
+ # the transformer.
293
+ if not single_qubit_gate_symbols :
294
+ return (merge_single_qubit_gates_to_phxz (circuit , context = context , atol = atol ), sweep )
295
+ sweep_of_single : Sweep = _sweep_on_symbols (sweep , single_qubit_gate_symbols )
296
+ # Get all resolved circuits from all sets of resolvers in sweep_of_single.
297
+ resolved_circuits = [
298
+ protocols .resolve_parameters (circuit_tagged , resolver ) for resolver in sweep_of_single
299
+ ]
300
+
301
+ # Step 1, merge single qubit gates per resolved circuit, preserving
302
+ # the symbolized_single_tag with indexes.
303
+ merged_circuits : List ['cirq.Circuit' ] = []
304
+ for resolved_circuit in resolved_circuits :
305
+ merged_circuit = tag_transformers .index_tags (
306
+ merge_single_qubit_gates_to_phxz (
307
+ resolved_circuit ,
308
+ context = context ,
309
+ merge_tags_fn = lambda circuit_op : (
310
+ [symbolized_single_tag ]
311
+ if any (
312
+ symbolized_single_tag in set (op .tags )
313
+ for op in circuit_op .circuit .all_operations ()
314
+ )
315
+ else []
316
+ ),
317
+ atol = atol ,
318
+ ),
319
+ context = transformer_api .TransformerContext (deep = deep ),
320
+ target_tags = {symbolized_single_tag },
321
+ )
322
+ merged_circuits .append (merged_circuit )
323
+
324
+ if not all (
325
+ _all_tags_startswith (merged_circuits [0 ], startswith = symbolized_single_tag )
326
+ == _all_tags_startswith (merged_circuit , startswith = symbolized_single_tag )
327
+ for merged_circuit in merged_circuits
328
+ ):
329
+ raise RuntimeError ("Different resolvers in sweep resulted in different merged structures." )
330
+
331
+ # Step 2, get the new symbolized circuit by symbolization on indexed symbolized_single_tag.
332
+ new_circuit = align .align_right (
333
+ tag_transformers .remove_tags (
334
+ symbolize .symbolize_single_qubit_gates_by_indexed_tags (
335
+ merged_circuits [0 ],
336
+ symbolize_tag = symbolize .SymbolizeTag (prefix = symbolized_single_tag ),
337
+ ),
338
+ remove_if = lambda tag : str (tag ).startswith (symbolized_single_tag ),
339
+ )
340
+ )
341
+
342
+ # Step 3, get N sets of parameterizations as new_sweep.
343
+ phxz_symbols : set [sympy .Symbol ] = set ().union (
344
+ * [
345
+ set (
346
+ [sympy .Symbol (tag .replace (f"{ symbolized_single_tag } _" , s )) for s in ["x" , "z" , "a" ]]
347
+ )
348
+ for tag in _all_tags_startswith (merged_circuits [0 ], startswith = symbolized_single_tag )
349
+ ]
350
+ )
351
+ # Remaining symbols, e.g., 2 qubit gates' symbols. Sweep of those symbols keeps unchanged.
352
+ remaining_symbols : set [sympy .Symbol ] = set (
353
+ protocols .parameter_symbols (circuit ) - single_qubit_gate_symbols
354
+ )
355
+ new_sweep = _parameterize_phxz_in_circuits (
356
+ merged_circuits , symbolized_single_tag , phxz_symbols , remaining_symbols , sweep
357
+ )
358
+
359
+ return new_circuit , new_sweep
0 commit comments