Skip to content

Commit 8e95b07

Browse files
committed
cphase multi moments gauge transformer
1 parent fd2e0e2 commit 8e95b07

File tree

2 files changed

+393
-2
lines changed

2 files changed

+393
-2
lines changed

cirq-core/cirq/transformers/gauge_compiling/cphase_gauge.py

Lines changed: 238 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,12 @@
1616

1717
from __future__ import annotations
1818

19+
from typing import List
20+
1921
import numpy as np
2022

2123
import cirq.transformers.gauge_compiling.sqrt_cz_gauge as sqrt_cz_gauge
22-
from cirq import ops
24+
from cirq import circuits, ops
2325
from cirq.transformers.gauge_compiling.gauge_compiling import (
2426
ConstantGauge,
2527
Gauge,
@@ -146,3 +148,238 @@ def sample(self, gate: ops.Gate, prng: np.random.Generator) -> ConstantGauge:
146148
symbolizer_fn=sqrt_cz_gauge._symbolize_as_cz_pow, n_symbols=1
147149
),
148150
)
151+
152+
153+
class _PhasedXYAndRz:
154+
"""In pulling through, one qubit gate can be represented by a Pauli and an Rz gate.
155+
156+
The order is --(X|Y|I)--Rz(rad)--phase--.
157+
"""
158+
159+
pauli: ops.X | ops.Y | ops.I
160+
rz_rads: float
161+
phase_exp: float # phase of the qubit is e^{i*phase_exp*pi}
162+
163+
def __init__(
164+
self, pauli: ops.Pauli | ops.I = ops.I, rz_rads: float = 0, phase_exp: float = 0
165+
) -> None:
166+
if pauli == ops.Z: # Merge Z gates to Rz where Z = Rz(π) * e^{iπ/2}
167+
self.pauli = ops.I
168+
self.rz_rads = rz_rads + np.pi
169+
self.phase_exp = phase_exp + 0.5
170+
else:
171+
self.pauli = pauli
172+
self.rz_rads = rz_rads
173+
self.phase_exp = phase_exp
174+
175+
def _merge_left_rz(self, rads: float):
176+
"""Merges Rz(rad) from left."""
177+
if self.pauli == ops.I:
178+
self.rz_rads += rads
179+
else:
180+
self.rz_rads -= rads
181+
182+
def _merge_right_rz(self, rads: float):
183+
"""Merges Rz(rads) from right."""
184+
self.rz_rads += rads
185+
186+
def _merge_left_xy(self, other: ops.X | ops.Y):
187+
"""Merges --(X|Y)--self--."""
188+
if self.pauli == other:
189+
self.pauli = ops.I
190+
return
191+
if self.pauli == ops.I:
192+
self.pauli = other
193+
return
194+
if (other, self.pauli) == (ops.X, ops.Y):
195+
# -X--Y- ==> --Rz(pi)--
196+
self.pauli = ops.I
197+
self.rz_rads += np.pi
198+
return
199+
if (other, self.pauli) == (ops.Y, ops.X):
200+
# -Y--X- ==> --Rz(-pi)--
201+
self.pauli = ops.I
202+
self.rz_rads -= np.pi
203+
return
204+
205+
def _merge_right_xy(self, other: ops.X | ops.Y):
206+
"""Merges --self--(X|Y)--."""
207+
self.rz_rads *= -1
208+
if self.pauli == other:
209+
self.pauli = ops.I
210+
return
211+
if self.pauli == ops.I:
212+
self.pauli = other
213+
return
214+
if (self.pauli, other) == (ops.X, ops.Y):
215+
# -X--Y- ==> --Rz(pi)--
216+
self.pauli = ops.I
217+
self.rz_rads += np.pi
218+
return
219+
if (self.pauli, other) == (ops.Y, ops.X):
220+
# -X--Y- ==> --Rz(-pi)--
221+
self.pauli = ops.I
222+
self.rz_rads -= np.pi
223+
return
224+
225+
def merge_left(self, other: _PhasedXYAndRz) -> None:
226+
"""Inplace merge other from left."""
227+
self._merge_left_rz(other.rz_rads)
228+
self.phase_exp += other.phase_exp
229+
if other.pauli != ops.I:
230+
self._merge_left_xy(other.pauli)
231+
232+
def merge_right(self, other: _PhasedXYAndRz) -> None:
233+
"""Inplace merge other from right."""
234+
self.phase_exp += other.phase_exp
235+
if other.pauli != ops.I:
236+
self._merge_right_xy(other.pauli)
237+
self._merge_right_rz(other.rz_rads)
238+
239+
def after_cphase(
240+
self, cphase: ops.CZPowGate
241+
) -> tuple[ops.CZPowGate, _PhasedXYAndRz, _PhasedXYAndRz]:
242+
"""Pull self through cphase.
243+
244+
Returns:
245+
updated cphase gate, pull_through of this qubit, pull_through of the other qubit.
246+
"""
247+
match self.pauli:
248+
case ops.I:
249+
return cphase, self, _PhasedXYAndRz()
250+
case _: # ops.X | ops.Y:
251+
# Taking input0 with X gate as an example:
252+
# 0: ─X─Rz(t)─phase─@────── 0: ─X──@─────Rz(t)──phase─
253+
# │ ==> │
254+
# 1: ───────────────@^exp── 1: ────@^exp──────────────
255+
# 0: ─@──────X────Rz(t)───phase─────────
256+
# ==> │
257+
# 1: ─@^-exp─Rz(exp pi)─e^{-exp pi/2 i}─
258+
# where rad = -exp * pi.
259+
# Similarly for X|Y on qubit 0/1, the result is always flipping cphase and
260+
# add an extra Rz rotation on the other qubit.
261+
return (
262+
cphase**-1,
263+
self,
264+
_PhasedXYAndRz(rz_rads=cphase.exponent * np.pi, phase_exp=-cphase.exponent / 2),
265+
)
266+
267+
def __str__(self) -> str:
268+
return f"─{self.pauli}──Rz({self.rz_rads})──phase(e^{{i{self.phase_exp}π}})─"
269+
270+
def __eq__(self, other: _PhasedXYAndRz) -> bool:
271+
return (
272+
self.pauli == other.pauli
273+
and np.isclose(self.rz_rads, other.rz_rads, atol=1e-10)
274+
and np.isclose(self.phase_exp, other.phase_exp, atol=1e-10)
275+
)
276+
277+
def to_single_gate(self) -> ops.PhasedXZGate | ops.ZPowGate:
278+
if self.pauli == ops.I:
279+
rz_rads = self.rz_rads
280+
if np.isclose(self.rz_rads, 0, atol=1e-2):
281+
rz_rads = self.rz_rads + 4 * np.pi
282+
return ops.ZPowGate(
283+
exponent=rz_rads / np.pi, global_shift=np.pi * self.phase_exp / rz_rads - 0.5
284+
)
285+
if self.pauli == ops.X:
286+
return ops.PhasedXZGate(
287+
x_exponent=1,
288+
z_exponent=2 * self.phase_exp,
289+
axis_phase_exponent=self.rz_rads / 2 / np.pi - self.phase_exp,
290+
)
291+
if self.pauli == ops.Y:
292+
return ops.PhasedXZGate(
293+
x_exponent=1,
294+
z_exponent=2 * self.phase_exp,
295+
axis_phase_exponent=1 / 2 - self.phase_exp + self.rz_rads / 2 / np.pi,
296+
)
297+
298+
299+
def _pull_through_single_cphase(
300+
cphase: ops.CZPowGate, input0: _PhasedXYAndRz, input1: _PhasedXYAndRz
301+
) -> tuple[ops.CZPowGate, _PhasedXYAndRz, _PhasedXYAndRz]:
302+
"""Pulls input0 and input1 through a CZPowGate.
303+
Input:
304+
0: ─(input0=P0──Rz0──phase0)─@─────
305+
306+
1: ─(input1=P1──Rz1──phase1)─@^exp─
307+
Output:
308+
0: ─@────────(output0=P0'──Rz0'──phase0')─
309+
310+
1: ─@^+/-exp─(output1=P1'──Rz1'──phase1')─
311+
"""
312+
313+
# Step 1; pull input0 through CZPowGate.
314+
# 0: ─input0─@───── 0: ────────@─────────output0─
315+
# │ ==> │
316+
# 1: ─input1─@^exp─ 1: ─input1─@^+/-exp──output1─
317+
output_cphase, output0, output1 = input0.after_cphase(cphase)
318+
319+
# Step 2; similar to step 1, pull input1 through CZPowGate.
320+
# 0: ─@──────────pulled0────output0─ 0: ─@────────output0─
321+
# ==> │ ==> │
322+
# 1: ─@^+/-exp───pulled1────output1─ 1: ─@^+/-exp─output1─
323+
output_cphase, pulled1, pulled0 = input1.after_cphase(output_cphase)
324+
output0.merge_left(pulled0)
325+
output1.merge_left(pulled1)
326+
327+
return output_cphase, output0, output1
328+
329+
330+
def _multi_moment_pull_through(
331+
moments: List[circuits.Moment], rng: np.random.Generator
332+
) -> List[circuits.Moment]:
333+
"""TO FILL."""
334+
all_qubits = [q for q in circuits.Circuit(moments).all_qubits()]
335+
if not all_qubits:
336+
return moments
337+
if not any(isinstance(op.gate, ops.CZPowGate) for moment in moments for op in moment):
338+
return moments
339+
340+
left_moment = circuits.Moment(
341+
[rng.choice([ops.I, ops.X, ops.Y, ops.Z]).on(q) for q in all_qubits]
342+
)
343+
prev: map[ops.Qid, ops.Gate] = {
344+
op.qubits[0]: _PhasedXYAndRz(pauli=op.gate) for op in left_moment
345+
}
346+
347+
new_moments: List[circuits.Moment] = [left_moment]
348+
349+
pulled: map[ops.Qid, ops.Gate]
350+
for moment in moments:
351+
pulled = {}
352+
new_moment: List[ops.Operation] = []
353+
for op in moment:
354+
if op.gate and (isinstance(op.gate, ops.CZPowGate)):
355+
q0, q1 = op.qubits
356+
cphase_gate, pulled[q0], pulled[q1] = _pull_through_single_cphase(
357+
op.gate, prev[q0], prev[q1]
358+
)
359+
new_moment.append(cphase_gate.on(q0, q1))
360+
elif op.gate and isinstance(op.gate, ops.ZPowGate):
361+
q = op.qubits[0]
362+
pulled[q] = prev[q]
363+
pulled[q].merge_right(_PhasedXYAndRz(rz_rads=op.gate.exponent * np.pi))
364+
# Don't need to add the op in the new_moment as it is already merged into pulled.
365+
else:
366+
new_moment.append(op)
367+
for q in all_qubits:
368+
if q not in pulled:
369+
pulled[q] = prev[q]
370+
prev = pulled
371+
new_moments.append(new_moment)
372+
373+
last_moment = circuits.Moment([pulled[q].to_single_gate().on(q) for q in all_qubits])
374+
375+
new_moments.append(last_moment)
376+
377+
return new_moments
378+
379+
380+
# Multi-moments pull through version of CZGaugeTransformer
381+
CPhaseGaugeTransformerMM = GaugeTransformer(
382+
target=ops.Gateset(ops.CZPowGate, ops.ZPowGate),
383+
gauge_selector=CPhaseGaugeSelector,
384+
multi_moment_pull_thourgh_fn=_multi_moment_pull_through,
385+
)

0 commit comments

Comments
 (0)