1616
1717from __future__ import annotations
1818
19- from typing import List
19+ from typing import cast , List
2020
2121import numpy as np
2222
@@ -156,12 +156,12 @@ class _PhasedXYAndRz:
156156 The order is --(X|Y|I)--Rz(rad)--phase--.
157157 """
158158
159- pauli : ops .X | ops .Y | ops . I
159+ pauli : ops .Pauli | ops .IdentityGate # X|Y| I
160160 rz_rads : float
161161 phase_exp : float # phase of the qubit is e^{i*phase_exp*pi}
162162
163163 def __init__ (
164- self , pauli : ops .Pauli | ops .I = ops .I , rz_rads : float = 0 , phase_exp : float = 0
164+ self , pauli : ops .Pauli | ops .IdentityGate = ops .I , rz_rads : float = 0 , phase_exp : float = 0
165165 ) -> None :
166166 if pauli == ops .Z : # Merge Z gates to Rz where Z = Rz(π) * e^{iπ/2}
167167 self .pauli = ops .I
@@ -183,7 +183,7 @@ def _merge_right_rz(self, rads: float):
183183 """Merges Rz(rads) from right."""
184184 self .rz_rads += rads
185185
186- def _merge_left_xy (self , other : ops .X | ops . Y ):
186+ def _merge_left_xy (self , other : ops .Pauli ):
187187 """Merges --(X|Y)--self--."""
188188 if self .pauli == other :
189189 self .pauli = ops .I
@@ -202,7 +202,7 @@ def _merge_left_xy(self, other: ops.X | ops.Y):
202202 self .rz_rads -= np .pi
203203 return
204204
205- def _merge_right_xy (self , other : ops .X | ops . Y ):
205+ def _merge_right_xy (self , other : ops .Pauli ):
206206 """Merges --self--(X|Y)--."""
207207 self .rz_rads *= - 1
208208 if self .pauli == other :
@@ -227,13 +227,13 @@ def merge_left(self, other: _PhasedXYAndRz) -> None:
227227 self ._merge_left_rz (other .rz_rads )
228228 self .phase_exp += other .phase_exp
229229 if other .pauli != ops .I :
230- self ._merge_left_xy (other .pauli )
230+ self ._merge_left_xy (cast ( ops . Pauli , other .pauli ) )
231231
232232 def merge_right (self , other : _PhasedXYAndRz ) -> None :
233233 """Inplace merge other from right."""
234234 self .phase_exp += other .phase_exp
235235 if other .pauli != ops .I :
236- self ._merge_right_xy (other .pauli )
236+ self ._merge_right_xy (cast ( ops . Pauli , other .pauli ) )
237237 self ._merge_right_rz (other .rz_rads )
238238
239239 def after_cphase (
@@ -259,41 +259,47 @@ def after_cphase(
259259 # Similarly for X|Y on qubit 0/1, the result is always flipping cphase and
260260 # add an extra Rz rotation on the other qubit.
261261 return (
262- cphase ** - 1 ,
262+ cast ( ops . CZPowGate , cphase ** - 1 ) ,
263263 self ,
264264 _PhasedXYAndRz (rz_rads = cphase .exponent * np .pi , phase_exp = - cphase .exponent / 2 ),
265265 )
266266
267267 def __str__ (self ) -> str :
268268 return f"─{ self .pauli } ──Rz({ self .rz_rads } )──phase(e^{{i{ self .phase_exp } π}})─"
269269
270- def __eq__ (self , other : _PhasedXYAndRz ) -> bool :
270+ def __eq__ (self , other : object ) -> bool :
271+ if not isinstance (other , _PhasedXYAndRz ):
272+ raise NotImplementedError
271273 return (
272274 self .pauli == other .pauli
273- and np .isclose (self .rz_rads , other .rz_rads , atol = 1e-10 )
274- and np .isclose (self .phase_exp , other .phase_exp , atol = 1e-10 )
275+ and bool ( np .isclose (self .rz_rads , other .rz_rads , atol = 1e-10 ) )
276+ and bool ( np .isclose (self .phase_exp , other .phase_exp , atol = 1e-10 ) )
275277 )
276278
277279 def to_single_gate (self ) -> ops .PhasedXZGate | ops .ZPowGate :
278- if self .pauli == ops .I :
279- rz_rads = self .rz_rads
280- if np .isclose (self .rz_rads , 0 , atol = 1e-2 ):
281- rz_rads = self .rz_rads + 4 * np .pi
282- return ops .ZPowGate (
283- exponent = rz_rads / np .pi , global_shift = np .pi * self .phase_exp / rz_rads - 0.5
284- )
285- if self .pauli == ops .X :
286- return ops .PhasedXZGate (
287- x_exponent = 1 ,
288- z_exponent = 2 * self .phase_exp ,
289- axis_phase_exponent = self .rz_rads / 2 / np .pi - self .phase_exp ,
290- )
291- if self .pauli == ops .Y :
292- return ops .PhasedXZGate (
293- x_exponent = 1 ,
294- z_exponent = 2 * self .phase_exp ,
295- axis_phase_exponent = 1 / 2 - self .phase_exp + self .rz_rads / 2 / np .pi ,
296- )
280+ """Converts the _PhasedXYAndRz to a single-qubit gate."""
281+ match self .pauli :
282+ case ops .I :
283+ rz_rads = self .rz_rads
284+ if np .isclose (self .rz_rads , 0 , atol = 1e-2 ):
285+ rz_rads = self .rz_rads + 4 * np .pi
286+ return ops .ZPowGate (
287+ exponent = rz_rads / np .pi , global_shift = np .pi * self .phase_exp / rz_rads - 0.5
288+ )
289+ case ops .X :
290+ return ops .PhasedXZGate (
291+ x_exponent = 1 ,
292+ z_exponent = 2 * self .phase_exp ,
293+ axis_phase_exponent = self .rz_rads / 2 / np .pi - self .phase_exp ,
294+ )
295+ case ops .Y :
296+ return ops .PhasedXZGate (
297+ x_exponent = 1 ,
298+ z_exponent = 2 * self .phase_exp ,
299+ axis_phase_exponent = 1 / 2 - self .phase_exp + self .rz_rads / 2 / np .pi ,
300+ )
301+ case _:
302+ raise ValueError ("Invalid self.pauli." )
297303
298304
299305def _pull_through_single_cphase (
@@ -327,26 +333,31 @@ def _pull_through_single_cphase(
327333 return output_cphase , output0 , output1
328334
329335
330- def _multi_moment_pull_through (
336+ def _multi_moment_gauge_fn (
331337 moments : List [circuits .Moment ], rng : np .random .Generator
332338) -> List [circuits .Moment ]:
333- """TO FILL ."""
339+ """Generates a left layer with random generator, then pulling through all the moments ."""
334340 all_qubits = [q for q in circuits .Circuit (moments ).all_qubits ()]
335341 if not all_qubits :
336342 return moments
337343 if not any (isinstance (op .gate , ops .CZPowGate ) for moment in moments for op in moment ):
338344 return moments
339345
340- left_moment = circuits .Moment (
341- [rng .choice ([ops .I , ops .X , ops .Y , ops .Z ]).on (q ) for q in all_qubits ]
342- )
343- prev : map [ops .Qid , ops .Gate ] = {
344- op .qubits [0 ]: _PhasedXYAndRz (pauli = op .gate ) for op in left_moment
346+ left_moment : List [ops .Operation ] = [
347+ rng .choice (
348+ np .array ([ops .I , ops .X , ops .Y , ops .Z ], dtype = ops .Gate ), p = [0.25 , 0.25 , 0.25 , 0.25 ]
349+ ).on (q )
350+ for q in all_qubits
351+ ]
352+ prev : dict [ops .Qid , _PhasedXYAndRz ] = {
353+ op .qubits [0 ]: _PhasedXYAndRz (pauli = cast (ops .Pauli | ops .IdentityGate , op .gate ))
354+ for op in left_moment
355+ if op .gate
345356 }
346357
347- new_moments : List [circuits .Moment ] = [left_moment ]
358+ new_moments : List [circuits .Moment ] = [circuits . Moment ( left_moment ) ]
348359
349- pulled : map [ops .Qid , ops . Gate ]
360+ pulled : dict [ops .Qid , _PhasedXYAndRz ]
350361 for moment in moments :
351362 pulled = {}
352363 new_moment : List [ops .Operation ] = []
@@ -368,7 +379,7 @@ def _multi_moment_pull_through(
368379 if q not in pulled :
369380 pulled [q ] = prev [q ]
370381 prev = pulled
371- new_moments .append (new_moment )
382+ new_moments .append (circuits . Moment ( new_moment ) )
372383
373384 last_moment = circuits .Moment ([pulled [q ].to_single_gate ().on (q ) for q in all_qubits ])
374385
@@ -381,5 +392,5 @@ def _multi_moment_pull_through(
381392CPhaseGaugeTransformerMM = GaugeTransformer (
382393 target = ops .Gateset (ops .CZPowGate , ops .ZPowGate ),
383394 gauge_selector = CPhaseGaugeSelector ,
384- multi_moment_pull_thourgh_fn = _multi_moment_pull_through ,
395+ multi_moment_gauge_fn = _multi_moment_gauge_fn ,
385396)
0 commit comments