Skip to content

Commit d4e0381

Browse files
committed
fix checks
1 parent 8e95b07 commit d4e0381

File tree

4 files changed

+85
-73
lines changed

4 files changed

+85
-73
lines changed

cirq-core/cirq/transformers/gauge_compiling/cphase_gauge.py

Lines changed: 53 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
from __future__ import annotations
1818

19-
from typing import List
19+
from typing import cast, List
2020

2121
import numpy as np
2222

@@ -156,12 +156,12 @@ class _PhasedXYAndRz:
156156
The order is --(X|Y|I)--Rz(rad)--phase--.
157157
"""
158158

159-
pauli: ops.X | ops.Y | ops.I
159+
pauli: ops.Pauli | ops.IdentityGate # X|Y|I
160160
rz_rads: float
161161
phase_exp: float # phase of the qubit is e^{i*phase_exp*pi}
162162

163163
def __init__(
164-
self, pauli: ops.Pauli | ops.I = ops.I, rz_rads: float = 0, phase_exp: float = 0
164+
self, pauli: ops.Pauli | ops.IdentityGate = ops.I, rz_rads: float = 0, phase_exp: float = 0
165165
) -> None:
166166
if pauli == ops.Z: # Merge Z gates to Rz where Z = Rz(π) * e^{iπ/2}
167167
self.pauli = ops.I
@@ -183,7 +183,7 @@ def _merge_right_rz(self, rads: float):
183183
"""Merges Rz(rads) from right."""
184184
self.rz_rads += rads
185185

186-
def _merge_left_xy(self, other: ops.X | ops.Y):
186+
def _merge_left_xy(self, other: ops.Pauli):
187187
"""Merges --(X|Y)--self--."""
188188
if self.pauli == other:
189189
self.pauli = ops.I
@@ -202,7 +202,7 @@ def _merge_left_xy(self, other: ops.X | ops.Y):
202202
self.rz_rads -= np.pi
203203
return
204204

205-
def _merge_right_xy(self, other: ops.X | ops.Y):
205+
def _merge_right_xy(self, other: ops.Pauli):
206206
"""Merges --self--(X|Y)--."""
207207
self.rz_rads *= -1
208208
if self.pauli == other:
@@ -227,13 +227,13 @@ def merge_left(self, other: _PhasedXYAndRz) -> None:
227227
self._merge_left_rz(other.rz_rads)
228228
self.phase_exp += other.phase_exp
229229
if other.pauli != ops.I:
230-
self._merge_left_xy(other.pauli)
230+
self._merge_left_xy(cast(ops.Pauli, other.pauli))
231231

232232
def merge_right(self, other: _PhasedXYAndRz) -> None:
233233
"""Inplace merge other from right."""
234234
self.phase_exp += other.phase_exp
235235
if other.pauli != ops.I:
236-
self._merge_right_xy(other.pauli)
236+
self._merge_right_xy(cast(ops.Pauli, other.pauli))
237237
self._merge_right_rz(other.rz_rads)
238238

239239
def after_cphase(
@@ -259,41 +259,47 @@ def after_cphase(
259259
# Similarly for X|Y on qubit 0/1, the result is always flipping cphase and
260260
# add an extra Rz rotation on the other qubit.
261261
return (
262-
cphase**-1,
262+
cast(ops.CZPowGate, cphase**-1),
263263
self,
264264
_PhasedXYAndRz(rz_rads=cphase.exponent * np.pi, phase_exp=-cphase.exponent / 2),
265265
)
266266

267267
def __str__(self) -> str:
268268
return f"─{self.pauli}──Rz({self.rz_rads})──phase(e^{{i{self.phase_exp}π}})─"
269269

270-
def __eq__(self, other: _PhasedXYAndRz) -> bool:
270+
def __eq__(self, other: object) -> bool:
271+
if not isinstance(other, _PhasedXYAndRz):
272+
raise NotImplementedError
271273
return (
272274
self.pauli == other.pauli
273-
and np.isclose(self.rz_rads, other.rz_rads, atol=1e-10)
274-
and np.isclose(self.phase_exp, other.phase_exp, atol=1e-10)
275+
and bool(np.isclose(self.rz_rads, other.rz_rads, atol=1e-10))
276+
and bool(np.isclose(self.phase_exp, other.phase_exp, atol=1e-10))
275277
)
276278

277279
def to_single_gate(self) -> ops.PhasedXZGate | ops.ZPowGate:
278-
if self.pauli == ops.I:
279-
rz_rads = self.rz_rads
280-
if np.isclose(self.rz_rads, 0, atol=1e-2):
281-
rz_rads = self.rz_rads + 4 * np.pi
282-
return ops.ZPowGate(
283-
exponent=rz_rads / np.pi, global_shift=np.pi * self.phase_exp / rz_rads - 0.5
284-
)
285-
if self.pauli == ops.X:
286-
return ops.PhasedXZGate(
287-
x_exponent=1,
288-
z_exponent=2 * self.phase_exp,
289-
axis_phase_exponent=self.rz_rads / 2 / np.pi - self.phase_exp,
290-
)
291-
if self.pauli == ops.Y:
292-
return ops.PhasedXZGate(
293-
x_exponent=1,
294-
z_exponent=2 * self.phase_exp,
295-
axis_phase_exponent=1 / 2 - self.phase_exp + self.rz_rads / 2 / np.pi,
296-
)
280+
"""Converts the _PhasedXYAndRz to a single-qubit gate."""
281+
match self.pauli:
282+
case ops.I:
283+
rz_rads = self.rz_rads
284+
if np.isclose(self.rz_rads, 0, atol=1e-2):
285+
rz_rads = self.rz_rads + 4 * np.pi
286+
return ops.ZPowGate(
287+
exponent=rz_rads / np.pi, global_shift=np.pi * self.phase_exp / rz_rads - 0.5
288+
)
289+
case ops.X:
290+
return ops.PhasedXZGate(
291+
x_exponent=1,
292+
z_exponent=2 * self.phase_exp,
293+
axis_phase_exponent=self.rz_rads / 2 / np.pi - self.phase_exp,
294+
)
295+
case ops.Y:
296+
return ops.PhasedXZGate(
297+
x_exponent=1,
298+
z_exponent=2 * self.phase_exp,
299+
axis_phase_exponent=1 / 2 - self.phase_exp + self.rz_rads / 2 / np.pi,
300+
)
301+
case _:
302+
raise ValueError("Invalid self.pauli.")
297303

298304

299305
def _pull_through_single_cphase(
@@ -327,26 +333,31 @@ def _pull_through_single_cphase(
327333
return output_cphase, output0, output1
328334

329335

330-
def _multi_moment_pull_through(
336+
def _multi_moment_gauge_fn(
331337
moments: List[circuits.Moment], rng: np.random.Generator
332338
) -> List[circuits.Moment]:
333-
"""TO FILL."""
339+
"""Generates a left layer with random generator, then pulling through all the moments."""
334340
all_qubits = [q for q in circuits.Circuit(moments).all_qubits()]
335341
if not all_qubits:
336342
return moments
337343
if not any(isinstance(op.gate, ops.CZPowGate) for moment in moments for op in moment):
338344
return moments
339345

340-
left_moment = circuits.Moment(
341-
[rng.choice([ops.I, ops.X, ops.Y, ops.Z]).on(q) for q in all_qubits]
342-
)
343-
prev: map[ops.Qid, ops.Gate] = {
344-
op.qubits[0]: _PhasedXYAndRz(pauli=op.gate) for op in left_moment
346+
left_moment: List[ops.Operation] = [
347+
rng.choice(
348+
np.array([ops.I, ops.X, ops.Y, ops.Z], dtype=ops.Gate), p=[0.25, 0.25, 0.25, 0.25]
349+
).on(q)
350+
for q in all_qubits
351+
]
352+
prev: dict[ops.Qid, _PhasedXYAndRz] = {
353+
op.qubits[0]: _PhasedXYAndRz(pauli=cast(ops.Pauli | ops.IdentityGate, op.gate))
354+
for op in left_moment
355+
if op.gate
345356
}
346357

347-
new_moments: List[circuits.Moment] = [left_moment]
358+
new_moments: List[circuits.Moment] = [circuits.Moment(left_moment)]
348359

349-
pulled: map[ops.Qid, ops.Gate]
360+
pulled: dict[ops.Qid, _PhasedXYAndRz]
350361
for moment in moments:
351362
pulled = {}
352363
new_moment: List[ops.Operation] = []
@@ -368,7 +379,7 @@ def _multi_moment_pull_through(
368379
if q not in pulled:
369380
pulled[q] = prev[q]
370381
prev = pulled
371-
new_moments.append(new_moment)
382+
new_moments.append(circuits.Moment(new_moment))
372383

373384
last_moment = circuits.Moment([pulled[q].to_single_gate().on(q) for q in all_qubits])
374385

@@ -377,9 +388,9 @@ def _multi_moment_pull_through(
377388
return new_moments
378389

379390

380-
# Multi-moments pull through version of CZGaugeTransformer
391+
# Multi-moments pull through version of CPhaseGaugeTransformer
381392
CPhaseGaugeTransformerMM = GaugeTransformer(
382393
target=ops.Gateset(ops.CZPowGate, ops.ZPowGate),
383394
gauge_selector=CPhaseGaugeSelector,
384-
multi_moment_pull_thourgh_fn=_multi_moment_pull_through,
395+
multi_moment_gauge_fn=_multi_moment_gauge_fn,
385396
)

cirq-core/cirq/transformers/gauge_compiling/cphase_gauge_test.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,14 @@
1313
# limitations under the License.
1414

1515
from copy import deepcopy
16+
1617
import numpy as np
1718

1819
import cirq
1920
from cirq.transformers.gauge_compiling.cphase_gauge import (
21+
_PhasedXYAndRz,
2022
CPhaseGaugeTransformer,
2123
CPhaseGaugeTransformerMM,
22-
_PhasedXYAndRz,
2324
)
2425
from cirq.transformers.gauge_compiling.gauge_compiling_test_utils import GaugeTester
2526

@@ -115,23 +116,23 @@ def test_multi_layer_pull_through():
115116
"""Test case.
116117
Input:
117118
┌──┐
118-
0: ───@────────@─────H───────────@───────@───────
119-
│ │ │ │
120-
1: ───@^0.2────┼@────────────────@^0.1───@───────
119+
0: ───@────────@─────H───Rz(-0.255π)───@───────@───────
120+
│ │ │ │
121+
1: ───@^0.2────┼@──────────────────────@^0.1───@───────
121122
││
122-
2: ───@────────@┼────────@───────@───────@───────
123-
│ │ │ │ │
124-
3: ───@─────────@────────@^0.2───@───────@^0.2───
123+
2: ───@────────@┼────────@─────────────@───────@───────
124+
│ │ │ │ │
125+
3: ───@─────────@────────@^0.2─────────@───────@^0.2───
125126
└──┘
126127
Example output:
127128
┌──┐
128-
0: ───Y───@─────────@─────PhXZ(a=0.5,x=1,z=0)───H───Y────────────@────────@────────PhXZ(a=0.5,x=1,z=0)───
129-
│ │ │ │
130-
1: ───Z───@^-0.2────┼@────Z^-0.8────────────────────I────────────@^-0.1───@────────Z^-0.9────────────────
129+
0: ───Z───@─────────@─────Z^0.2────────────────H───I───────────@────────@───────Z^0.845────────────────── # pylint: disable=line-too-long
130+
│ │ │ │
131+
1: ───X───@^-0.2────┼@────PhXZ(a=0,x=1,z=1)────────X───────────@^-0.1───@───────PhXZ(a=0,x=1,z=0)──────── # pylint: disable=line-too-long
131132
││
132-
2: ───Z───@─────────@┼────Z^0───────────────────────Y───@───────@────────@───────PhXZ(a=0.5,x=1,z=0)───
133-
│ │ │ │
134-
3: ───I───@──────────@────Z^0───────────────────────I───@^-0.2───@────────@^-0.2───Z^-0.6────────────────
133+
2: ───X───@─────────@┼────PhXZ(a=0,x=1,z=1)────────Y───@───────@────────@───────PhXZ(a=0.5,x=1,z=1.4)──── # pylint: disable=line-too-long
134+
│ │ │ │
135+
3: ───X───@──────────@────PhXZ(a=2,x=1,z=-2)───────Y───@^0.2───@────────@^0.2───PhXZ(a=1.9,x=1,z=-1.4)─── # pylint: disable=line-too-long
135136
└──┘
136137
"""
137138
q0, q1, q2, q3 = cirq.LineQubit.range(4)
@@ -140,7 +141,7 @@ def test_multi_layer_pull_through():
140141
cirq.Moment(cirq.CZ(q0, q1) ** 0.2, cirq.CZ(q2, q3)),
141142
cirq.Moment(cirq.CZ(q0, q2), cirq.CZ(q1, q3)),
142143
cirq.Moment(cirq.H(q0)),
143-
cirq.Moment(cirq.CZ(q2, q3) ** 0.2),
144+
cirq.Moment(cirq.CZ(q2, q3) ** 0.2, cirq.Rz(rads=-0.8).on(q0)),
144145
cirq.Moment(cirq.CZ(q0, q1) ** 0.1, cirq.CZ(q2, q3)),
145146
cirq.Moment(cirq.CZ(q0, q1), cirq.CZ(q2, q3) ** 0.2),
146147
)

cirq-core/cirq/transformers/gauge_compiling/cz_gauge.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,15 +48,17 @@
4848
)
4949

5050

51-
def _multi_moment_pull_through(
51+
def _multi_moment_gauge_fn(
5252
moments: List[circuits.Moment], rng: np.random.Generator
5353
) -> List[circuits.Moment]:
5454
# Check all the ops are CZ first
5555
if not all(op.gate == CZ for moment in moments for op in moment):
5656
raise ValueError(f"Input moments must only contain CZ gates:\nmoments = {moments}.")
5757

5858
left: List[ops.Operation] = [
59-
rng.choice([ops.I, ops.X, ops.Y, ops.Z]).on(q)
59+
rng.choice(
60+
np.array([ops.I, ops.X, ops.Y, ops.Z], dtype=ops.Gate), p=[0.25, 0.25, 0.25, 0.25]
61+
).on(q)
6062
for q in circuits.Circuit(moments).all_qubits()
6163
]
6264
if not left:
@@ -73,7 +75,5 @@ def _multi_moment_pull_through(
7375

7476
# Multi-moments pull through version of CZGaugeTransformer
7577
CZGaugeTransformerMM = GaugeTransformer(
76-
target=CZ,
77-
gauge_selector=CZGaugeSelector,
78-
multi_moment_pull_thourgh_fn=_multi_moment_pull_through,
78+
target=CZ, gauge_selector=CZGaugeSelector, multi_moment_gauge_fn=_multi_moment_gauge_fn
7979
)

cirq-core/cirq/transformers/gauge_compiling/gauge_compiling.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -196,8 +196,8 @@ def __init__(
196196
target: Union[ops.Gate, ops.Gateset, ops.GateFamily],
197197
gauge_selector: Callable[[np.random.Generator], Gauge],
198198
two_qubit_gate_symbolizer: Optional[TwoQubitGateSymbolizer] = None,
199-
multi_moment_pull_thourgh_fn: Optional[
200-
Callable[[List[circuits.Moment], List[circuits.Moment]], List[circuits.Moment]]
199+
multi_moment_gauge_fn: Optional[
200+
Callable[[List[circuits.Moment], np.random.Generator], List[circuits.Moment]]
201201
] = None,
202202
) -> None:
203203
"""Constructs a GaugeTransformer.
@@ -211,7 +211,7 @@ def __init__(
211211
self.target = ops.GateFamily(target) if isinstance(target, ops.Gate) else target
212212
self.gauge_selector = gauge_selector
213213
self.two_qubit_gate_symbolizer = two_qubit_gate_symbolizer
214-
self.multi_moment_pull_thourgh_fn = multi_moment_pull_thourgh_fn
214+
self.multi_moment_gauge_fn = multi_moment_gauge_fn
215215

216216
def __call__(
217217
self,
@@ -225,22 +225,22 @@ def __call__(
225225
context = transformer_api.TransformerContext(deep=False)
226226
if context.deep:
227227
raise ValueError('GaugeTransformer cannot be used with deep=True')
228-
new_moments = []
228+
new_moments: List[circuits.Moment] = []
229229
left: List[List[ops.Operation]] = []
230230
right: List[List[ops.Operation]] = []
231231
all_target_moments: List[circuits.Moment] = []
232232

233233
for moment in circuit:
234-
if self.multi_moment_pull_thourgh_fn and all(
234+
if self.multi_moment_gauge_fn and all(
235235
[
236236
op in self.target and not set(op.tags).intersection(context.tags_to_ignore)
237237
for op in moment
238238
]
239239
): # all ops are target 2-qubit gates
240240
all_target_moments.append(moment)
241241
continue
242-
if all_target_moments:
243-
new_moments.extend(self.multi_moment_pull_thourgh_fn(all_target_moments, rng))
242+
if all_target_moments and self.multi_moment_gauge_fn:
243+
new_moments.extend(self.multi_moment_gauge_fn(all_target_moments, rng))
244244
all_target_moments.clear()
245245

246246
left.clear()
@@ -262,11 +262,11 @@ def __call__(
262262
center.append(op)
263263
if left:
264264
new_moments.extend(_build_moments(left))
265-
new_moments.append(center)
265+
new_moments.append(circuits.Moment(center))
266266
if right:
267267
new_moments.extend(_build_moments(right))
268-
if all_target_moments:
269-
new_moments.extend(self.multi_moment_pull_thourgh_fn(all_target_moments, rng))
268+
if all_target_moments and self.multi_moment_gauge_fn:
269+
new_moments.extend(self.multi_moment_gauge_fn(all_target_moments, rng))
270270
return circuits.Circuit.from_moments(*new_moments)
271271

272272
def as_sweep(
@@ -412,15 +412,15 @@ def two_qubit_gate_next_symbol_list(n: int) -> List[sympy.Symbol]:
412412
return circuits.Circuit.from_moments(*new_moments), Zip(*sweeps)
413413

414414

415-
def _build_moments(operation_by_qubits: List[List[ops.Operation]]) -> List[List[ops.Operation]]:
415+
def _build_moments(operation_by_qubits: List[List[ops.Operation]]) -> List[circuits.Moment]:
416416
"""Builds moments from a list of operations grouped by qubits.
417417
418418
Returns a list of moments from a list whose ith element is a list of operations applied
419419
to qubit i.
420420
"""
421421
moments = []
422422
for moment in itertools.zip_longest(*operation_by_qubits):
423-
moments.append([op for op in moment if op is not None])
423+
moments.append(circuits.Moment([op for op in moment if op is not None]))
424424
return moments
425425

426426

0 commit comments

Comments
 (0)