Skip to content

Commit 99a697e

Browse files
authored
Add support for converting OBSERVABLE_INCLUDE instructions with pauli targets to stimcirq (#952)
1 parent c6e96b7 commit 99a697e

File tree

4 files changed

+65
-10
lines changed

4 files changed

+65
-10
lines changed

glue/cirq/stimcirq/_obs_annotation.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ def __init__(
1616
*,
1717
parity_keys: Iterable[str] = (),
1818
relative_keys: Iterable[int] = (),
19+
pauli_keys: Iterable[str] = (),
1920
observable_index: int,
2021
):
2122
"""
@@ -28,6 +29,7 @@ def __init__(
2829
"""
2930
self.parity_keys = frozenset(parity_keys)
3031
self.relative_keys = frozenset(relative_keys)
32+
self.pauli_keys = frozenset(pauli_keys)
3133
self.observable_index = observable_index
3234

3335
@property
@@ -38,11 +40,12 @@ def with_qubits(self, *new_qubits) -> 'CumulativeObservableAnnotation':
3840
return self
3941

4042
def _value_equality_values_(self) -> Any:
41-
return self.parity_keys, self.relative_keys, self.observable_index
43+
return self.parity_keys, self.relative_keys, self.pauli_keys, self.observable_index
4244

4345
def _circuit_diagram_info_(self, args: Any) -> str:
4446
items: List[str] = [repr(e) for e in sorted(self.parity_keys)]
4547
items += [f'rec[{e}]' for e in sorted(self.relative_keys)]
48+
items += sorted(self.pauli_keys)
4649
k = ",".join(str(e) for e in items)
4750
return f"Obs{self.observable_index}({k})"
4851

@@ -51,6 +54,7 @@ def __repr__(self) -> str:
5154
f'stimcirq.CumulativeObservableAnnotation('
5255
f'parity_keys={sorted(self.parity_keys)}, '
5356
f'relative_keys={sorted(self.relative_keys)}, '
57+
f'pauli_keys={sorted(self.pauli_keys)}, '
5458
f'observable_index={self.observable_index!r})'
5559
)
5660

@@ -62,6 +66,7 @@ def _json_dict_(self) -> Dict[str, Any]:
6266
result = {
6367
'parity_keys': sorted(self.parity_keys),
6468
'observable_index': self.observable_index,
69+
'pauli_keys': sorted(self.pauli_keys),
6570
}
6671
if self.relative_keys:
6772
result['relative_keys'] = sorted(self.relative_keys)
@@ -104,6 +109,12 @@ def _stim_conversion_(
104109
rec_targets.append(stim.target_rec(-1 - offset))
105110
if not remaining:
106111
break
112+
rec_targets.extend(
113+
[
114+
stim.target_pauli(qubit_index=int(k[1:]), pauli=k[0])
115+
for k in sorted(self.pauli_keys)
116+
]
117+
)
107118
if remaining:
108119
raise ValueError(
109120
f"{self!r} was processed before measurements it referenced ({sorted(remaining)!r})."

glue/cirq/stimcirq/_obs_annotation_test.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -191,9 +191,30 @@ def test_json_serialization():
191191
c2 = cirq.read_json(json_text=json, resolvers=[*cirq.DEFAULT_RESOLVERS, stimcirq.JSON_RESOLVER])
192192
assert c == c2
193193

194+
def test_json_serialization_with_pauli_keys():
195+
c = cirq.Circuit(
196+
stimcirq.CumulativeObservableAnnotation(parity_keys=["a", "b"], observable_index=5, pauli_keys=["X0", "Y1", "Z2"]),
197+
stimcirq.CumulativeObservableAnnotation(
198+
parity_keys=["a", "b"], relative_keys=[-1, -3], observable_index=5, pauli_keys=["X0", "Y1", "Z2"]
199+
),
200+
stimcirq.CumulativeObservableAnnotation(observable_index=2, pauli_keys=["X0", "Y1", "Z2"]),
201+
stimcirq.CumulativeObservableAnnotation(parity_keys=["d", "c"], observable_index=5, pauli_keys=["X0", "Y1", "Z2"]),
202+
)
203+
json = cirq.to_json(c)
204+
c2 = cirq.read_json(json_text=json, resolvers=[*cirq.DEFAULT_RESOLVERS, stimcirq.JSON_RESOLVER])
205+
assert c == c2
206+
194207

195208
def test_json_backwards_compat_exact():
196209
raw = stimcirq.CumulativeObservableAnnotation(parity_keys=['z'], relative_keys=[-2], observable_index=5)
197-
packed = '{\n "cirq_type": "CumulativeObservableAnnotation",\n "parity_keys": [\n "z"\n ],\n "observable_index": 5,\n "relative_keys": [\n -2\n ]\n}'
198-
assert cirq.read_json(json_text=packed, resolvers=[*cirq.DEFAULT_RESOLVERS, stimcirq.JSON_RESOLVER]) == raw
199-
assert cirq.to_json(raw) == packed
210+
packed_v1 = '{\n "cirq_type": "CumulativeObservableAnnotation",\n "parity_keys": [\n "z"\n ],\n "observable_index": 5,\n "relative_keys": [\n -2\n ]\n}'
211+
packed_v2 ='{\n "cirq_type": "CumulativeObservableAnnotation",\n "parity_keys": [\n "z"\n ],\n "observable_index": 5,\n "pauli_keys": [],\n "relative_keys": [\n -2\n ]\n}'
212+
assert cirq.read_json(json_text=packed_v1, resolvers=[*cirq.DEFAULT_RESOLVERS, stimcirq.JSON_RESOLVER]) == raw
213+
assert cirq.read_json(json_text=packed_v2, resolvers=[*cirq.DEFAULT_RESOLVERS, stimcirq.JSON_RESOLVER]) == raw
214+
assert cirq.to_json(raw) == packed_v2
215+
216+
# With pauli_keys
217+
raw = stimcirq.CumulativeObservableAnnotation(parity_keys=['z'], relative_keys=[-2], observable_index=5, pauli_keys=["X0", "Y1", "Z2"])
218+
packed_v2 ='{\n "cirq_type": "CumulativeObservableAnnotation",\n "parity_keys": [\n "z"\n ],\n "observable_index": 5,\n "pauli_keys": [\n "X0",\n "Y1",\n "Z2"\n ],\n "relative_keys": [\n -2\n ]\n}'
219+
assert cirq.read_json(json_text=packed_v2, resolvers=[*cirq.DEFAULT_RESOLVERS, stimcirq.JSON_RESOLVER]) == raw
220+
assert cirq.to_json(raw) == packed_v2

glue/cirq/stimcirq/_stim_to_cirq.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -340,19 +340,26 @@ def coords_after_offset(
340340

341341
def resolve_measurement_record_keys(
342342
self, targets: Iterable[stim.GateTarget]
343-
) -> Tuple[List[str], List[int]]:
343+
) -> Tuple[List[str], List[int], List[str]]:
344+
pauli_targets, meas_targets = [], []
345+
for t in targets:
346+
if t.is_measurement_record_target:
347+
meas_targets.append(t)
348+
else:
349+
pauli_targets.append(f'{t.pauli_type}{t.value}')
350+
344351
if self.have_seen_loop:
345-
return [], [t.value for t in targets]
352+
return [], [t.value for t in meas_targets], pauli_targets
346353
else:
347-
return [str(self.num_measurements_seen + t.value) for t in targets], []
354+
return [str(self.num_measurements_seen + t.value) for t in meas_targets], [], pauli_targets
348355

349356
def process_detector(self, instruction: stim.CircuitInstruction) -> None:
350357
if instruction.tag:
351358
tags = [instruction.tag]
352359
else:
353360
tags = ()
354361
coords = self.coords_after_offset(instruction.gate_args_copy())
355-
keys, rels = self.resolve_measurement_record_keys(instruction.targets_copy())
362+
keys, rels, _ = self.resolve_measurement_record_keys(instruction.targets_copy())
356363
self.append_operation(
357364
DetAnnotation(parity_keys=keys, relative_keys=rels, coordinate_metadata=coords).with_tags(*tags)
358365
)
@@ -364,10 +371,10 @@ def process_observable_include(self, instruction: stim.CircuitInstruction) -> No
364371
tags = ()
365372
args = instruction.gate_args_copy()
366373
index = 0 if not args else int(args[0])
367-
keys, rels = self.resolve_measurement_record_keys(instruction.targets_copy())
374+
keys, rels, paulis = self.resolve_measurement_record_keys(instruction.targets_copy())
368375
self.append_operation(
369376
CumulativeObservableAnnotation(
370-
parity_keys=keys, relative_keys=rels, observable_index=index
377+
parity_keys=keys, relative_keys=rels, pauli_keys=paulis, observable_index=index
371378
).with_tags(*tags)
372379
)
373380

glue/cirq/stimcirq/_stim_to_cirq_test.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -763,3 +763,19 @@ def test_id_error_round_trip():
763763
cirq_circuit = stimcirq.stim_circuit_to_cirq_circuit(stim_circuit)
764764
restored_circuit = stimcirq.cirq_circuit_to_stim_circuit(cirq_circuit)
765765
assert restored_circuit == stim_circuit
766+
767+
def test_round_trip_with_pauli_obs():
768+
stim_circuit = stim.Circuit("""
769+
QUBIT_COORDS(5, 5) 0
770+
R 0
771+
OBSERVABLE_INCLUDE(0) X0
772+
TICK
773+
H 0
774+
TICK
775+
M 0
776+
OBSERVABLE_INCLUDE(0) rec[-1]
777+
TICK
778+
""")
779+
cirq_circuit = stimcirq.stim_circuit_to_cirq_circuit(stim_circuit)
780+
restored_circuit = stimcirq.cirq_circuit_to_stim_circuit(cirq_circuit)
781+
assert restored_circuit == stim_circuit

0 commit comments

Comments
 (0)