16
16
17
17
from __future__ import annotations
18
18
19
- from typing import Callable , cast , Hashable , List , Tuple , TYPE_CHECKING
20
-
21
- import sympy
19
+ from typing import Callable , cast , Hashable , TYPE_CHECKING
22
20
23
21
from cirq import circuits , ops , protocols
24
22
from cirq .study .resolver import ParamResolver
25
23
from cirq .study .sweeps import dict_to_zip_sweep , ListSweep , ProductOrZipSweepLike , Sweep , Zip
26
24
from cirq .transformers import (
27
25
align ,
28
26
merge_k_qubit_gates ,
29
- transformer_api ,
30
- transformer_primitives ,
31
27
symbolize ,
32
28
tag_transformers ,
29
+ transformer_api ,
30
+ transformer_primitives ,
33
31
)
34
32
from cirq .transformers .analytical_decompositions import single_qubit_decompositions
35
33
36
34
if TYPE_CHECKING :
37
35
import cirq
36
+ import sympy
38
37
39
38
40
39
@transformer_api .transformer
@@ -78,9 +77,9 @@ def merge_single_qubit_gates_to_phxz(
78
77
circuit : cirq .AbstractCircuit ,
79
78
* ,
80
79
context : cirq .TransformerContext | None = None ,
81
- merge_tags_fn : Callable [[cirq .CircuitOperation ], List [Hashable ]] | None = None ,
80
+ merge_tags_fn : Callable [[cirq .CircuitOperation ], list [Hashable ]] | None = None ,
82
81
atol : float = 1e-8 ,
83
- ) -> ' cirq.Circuit' :
82
+ ) -> cirq .Circuit :
84
83
"""Replaces runs of single qubit rotations with a single optional `cirq.PhasedXZGate`.
85
84
86
85
Specifically, any run of non-parameterized single-qubit unitaries will be replaced by an
@@ -97,7 +96,7 @@ def merge_single_qubit_gates_to_phxz(
97
96
Copy of the transformed input circuit.
98
97
"""
99
98
100
- def rewriter (circuit_op : ' cirq.CircuitOperation' ) -> ' cirq.OP_TREE' :
99
+ def rewriter (circuit_op : cirq .CircuitOperation ) -> cirq .OP_TREE :
101
100
u = protocols .unitary (circuit_op )
102
101
if protocols .num_qubits (circuit_op ) == 0 :
103
102
return ops .GlobalPhaseGate (u [0 , 0 ]).on ()
@@ -170,58 +169,43 @@ def merge_func(m1: cirq.Moment, m2: cirq.Moment) -> cirq.Moment | None:
170
169
).unfreeze (copy = False )
171
170
172
171
173
- def _all_tags_startswith (circuit : cirq .AbstractCircuit , startswith : str ):
174
- tag_set : set [Hashable ] = set ()
175
- for op in circuit .all_operations ():
176
- for tag in op .tags :
177
- if str (tag ).startswith (startswith ):
178
- tag_set .add (tag )
179
- return tag_set
180
-
181
-
182
172
def _sweep_on_symbols (sweep : Sweep , symbols : set [sympy .Symbol ]) -> Sweep :
183
- new_resolvers : List [cirq .ParamResolver ] = []
173
+ new_resolvers : list [cirq .ParamResolver ] = []
184
174
for resolver in sweep :
185
- param_dict : ' cirq.ParamMappingType' = {s : resolver .value_of (s ) for s in symbols }
175
+ param_dict : cirq .ParamMappingType = {s : resolver .value_of (s ) for s in symbols }
186
176
new_resolvers .append (ParamResolver (param_dict ))
187
177
return ListSweep (new_resolvers )
188
178
189
179
190
- def _parameterize_phxz_in_circuits (
191
- circuit_list : List ['cirq.Circuit' ],
192
- merge_tag_prefix : str ,
193
- phxz_symbols : set [sympy .Symbol ],
194
- remaining_symbols : set [sympy .Symbol ],
195
- sweep : Sweep ,
180
+ def _calc_phxz_sweeps (
181
+ symbolized_circuit : cirq .Circuit , resolved_circuits : list [cirq .Circuit ]
196
182
) -> Sweep :
197
- """Parameterizes the circuits and returns a new sweep."""
198
- values_by_params : dict [str , List [float ]] = {** {str (s ): [] for s in phxz_symbols }}
199
-
200
- for circuit in circuit_list :
201
- for op in circuit .all_operations ():
202
- the_merge_tag : str | None = None
203
- for tag in op .tags :
204
- if str (tag ).startswith (merge_tag_prefix ):
205
- the_merge_tag = str (tag )
206
- if not the_merge_tag :
207
- continue
208
- sid = the_merge_tag .rsplit ("_" , maxsplit = - 1 )[- 1 ]
209
- x , z , a = 0.0 , 0.0 , 0.0 # Identity gate's parameters
210
- if isinstance (op .gate , ops .PhasedXZGate ):
211
- x , z , a = op .gate .x_exponent , op .gate .z_exponent , op .gate .axis_phase_exponent
212
- elif op .gate is not ops .I :
213
- raise RuntimeError (
214
- f"Expected the merged gate to be a PhasedXZGate or IdentityGate,"
215
- f" but got { op .gate } ."
183
+ """Return the phxz sweep of the symbolized_circuit on resolved_circuits.
184
+
185
+ Raises:
186
+ ValueError: Structural mismatch: A `resolved_circuit` contains an unexpected gate type.
187
+ Expected a `PhasedXZGate` or `IdentityGate` at a position corresponding to a
188
+ symbolic `PhasedXZGate` in the `symbolized_circuit`.
189
+ """
190
+
191
+ def _extract_axz (op : ops .Operation ) -> tuple [float , float , float ]:
192
+ if not op .gate or not isinstance (op .gate , ops .IdentityGate | ops .PhasedXZGate ):
193
+ raise ValueError (f"Expect a PhasedXZGate or IdentityGate on op { op } ." )
194
+ if isinstance (op .gate , ops .IdentityGate ):
195
+ return 0.0 , 0.0 , 0.0 # Identity gate's a, x, z in PhasedXZ
196
+ phxz = cast (ops .PhasedXZGate , op .gate )
197
+ return phxz .axis_phase_exponent , phxz .x_exponent , phxz .z_exponent
198
+
199
+ values_by_params : dict [str , list [float ]] = {}
200
+ for mid , moment in enumerate (symbolized_circuit ):
201
+ for op in moment .operations :
202
+ if op .gate and isinstance (op .gate , ops .PhasedXZGate ) and protocols .is_parameterized (op ):
203
+ sa , sx , sz = op .gate .axis_phase_exponent , op .gate .x_exponent , op .gate .z_exponent
204
+ values_by_params [sa ], values_by_params [sx ], values_by_params [sz ] = zip (
205
+ * [_extract_axz (c [mid ].operation_at (op .qubits [0 ])) for c in resolved_circuits ]
216
206
)
217
- values_by_params [f"x{ sid } " ].append (x )
218
- values_by_params [f"z{ sid } " ].append (z )
219
- values_by_params [f"a{ sid } " ].append (a )
220
207
221
- return Zip (
222
- dict_to_zip_sweep (cast (ProductOrZipSweepLike , values_by_params )),
223
- _sweep_on_symbols (sweep , remaining_symbols ),
224
- )
208
+ return dict_to_zip_sweep (cast (ProductOrZipSweepLike , values_by_params ))
225
209
226
210
227
211
def merge_single_qubit_gates_to_phxz_symbolized (
@@ -230,7 +214,7 @@ def merge_single_qubit_gates_to_phxz_symbolized(
230
214
context : cirq .TransformerContext | None = None ,
231
215
sweep : Sweep ,
232
216
atol : float = 1e-8 ,
233
- ) -> Tuple [cirq .Circuit , Sweep ]:
217
+ ) -> tuple [cirq .Circuit , Sweep ]:
234
218
"""Merges consecutive single qubit gates as PhasedXZ Gates. Symbolizes if any of
235
219
the consecutive gates is symbolized.
236
220
@@ -288,6 +272,10 @@ def merge_single_qubit_gates_to_phxz_symbolized(
288
272
for op in circuit_tagged .all_operations ()
289
273
]
290
274
)
275
+ # Remaining symbols, e.g., 2 qubit gates' symbols. Sweep of those symbols keeps unchanged.
276
+ remaining_symbols : set [sympy .Symbol ] = set (
277
+ protocols .parameter_symbols (circuit ) - single_qubit_gate_symbols
278
+ )
291
279
# If all single qubit gates are not parameterized, call the nonparamerized version of
292
280
# the transformer.
293
281
if not single_qubit_gate_symbols :
@@ -299,61 +287,43 @@ def merge_single_qubit_gates_to_phxz_symbolized(
299
287
]
300
288
301
289
# Step 1, merge single qubit gates per resolved circuit, preserving
302
- # the symbolized_single_tag with indexes.
303
- merged_circuits : List ['cirq.Circuit' ] = []
304
- for resolved_circuit in resolved_circuits :
305
- merged_circuit = tag_transformers .index_tags (
306
- merge_single_qubit_gates_to_phxz (
307
- resolved_circuit ,
308
- context = context ,
309
- merge_tags_fn = lambda circuit_op : (
310
- [symbolized_single_tag ]
311
- if any (
312
- symbolized_single_tag in set (op .tags )
313
- for op in circuit_op .circuit .all_operations ()
314
- )
315
- else []
316
- ),
317
- atol = atol ,
290
+ # the symbolized_single_tag to indicate the operator is a merged one.
291
+ merged_circuits : list [cirq .Circuit ] = [
292
+ merge_single_qubit_gates_to_phxz (
293
+ c ,
294
+ context = context ,
295
+ merge_tags_fn = lambda circuit_op : (
296
+ [symbolized_single_tag ]
297
+ if any (
298
+ symbolized_single_tag in set (op .tags )
299
+ for op in circuit_op .circuit .all_operations ()
300
+ )
301
+ else []
318
302
),
319
- context = transformer_api .TransformerContext (deep = deep ),
320
- target_tags = {symbolized_single_tag },
303
+ atol = atol ,
321
304
)
322
- merged_circuits .append (merged_circuit )
323
-
324
- if not all (
325
- _all_tags_startswith (merged_circuits [0 ], startswith = symbolized_single_tag )
326
- == _all_tags_startswith (merged_circuit , startswith = symbolized_single_tag )
327
- for merged_circuit in merged_circuits
328
- ):
329
- raise RuntimeError ("Different resolvers in sweep resulted in different merged structures." )
305
+ for c in resolved_circuits
306
+ ]
330
307
331
- # Step 2, get the new symbolized circuit by symbolization on indexed symbolized_single_tag.
308
+ # Step 2, get the new symbolized circuit by symbolizing on indexed symbolized_single_tag.
332
309
new_circuit = align .align_right (
333
- tag_transformers .remove_tags (
310
+ tag_transformers .remove_tags ( # remove the temp tags used to track merges
334
311
symbolize .symbolize_single_qubit_gates_by_indexed_tags (
335
- merged_circuits [0 ],
312
+ tag_transformers .index_tags ( # index all single qubit ops merged from ops with symbols
313
+ merged_circuits [0 ],
314
+ context = transformer_api .TransformerContext (deep = deep ),
315
+ target_tags = {symbolized_single_tag },
316
+ ),
336
317
symbolize_tag = symbolize .SymbolizeTag (prefix = symbolized_single_tag ),
337
318
),
338
319
remove_if = lambda tag : str (tag ).startswith (symbolized_single_tag ),
339
320
)
340
321
)
341
322
342
323
# Step 3, get N sets of parameterizations as new_sweep.
343
- phxz_symbols : set [sympy .Symbol ] = set ().union (
344
- * [
345
- set (
346
- [sympy .Symbol (tag .replace (f"{ symbolized_single_tag } _" , s )) for s in ["x" , "z" , "a" ]]
347
- )
348
- for tag in _all_tags_startswith (merged_circuits [0 ], startswith = symbolized_single_tag )
349
- ]
350
- )
351
- # Remaining symbols, e.g., 2 qubit gates' symbols. Sweep of those symbols keeps unchanged.
352
- remaining_symbols : set [sympy .Symbol ] = set (
353
- protocols .parameter_symbols (circuit ) - single_qubit_gate_symbols
354
- )
355
- new_sweep = _parameterize_phxz_in_circuits (
356
- merged_circuits , symbolized_single_tag , phxz_symbols , remaining_symbols , sweep
324
+ new_sweep = Zip (
325
+ _calc_phxz_sweeps (new_circuit , merged_circuits ), # phxz sweeps
326
+ _sweep_on_symbols (sweep , remaining_symbols ), # remaining sweeps
357
327
)
358
328
359
329
return new_circuit , new_sweep
0 commit comments