Skip to content

Commit 9d511e9

Browse files
authored
Extend include_tags to support by-type selection (#7425)
implements #2905 based on the comments [here](#2905 (comment)). closed previous PR #7415
1 parent bb7b5c1 commit 9d511e9

File tree

7 files changed

+119
-23
lines changed

7 files changed

+119
-23
lines changed

cirq-core/cirq/circuits/circuit.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1172,7 +1172,7 @@ def to_text_diagram(
11721172
*,
11731173
use_unicode_characters: bool = True,
11741174
transpose: bool = False,
1175-
include_tags: bool = True,
1175+
include_tags: bool | Iterable[type] = True,
11761176
precision: int | None = 3,
11771177
qubit_order: cirq.QubitOrderOrList = ops.QubitOrder.DEFAULT,
11781178
) -> str:
@@ -1182,7 +1182,10 @@ def to_text_diagram(
11821182
use_unicode_characters: Determines if unicode characters are
11831183
allowed (as opposed to ascii-only diagrams).
11841184
transpose: Arranges qubit wires vertically instead of horizontally.
1185-
include_tags: Whether tags on TaggedOperations should be printed
1185+
include_tags: Controls which tags attached to operations are
1186+
included. ``True`` includes all tags, ``False`` includes none,
1187+
or a collection of tag classes may be specified to include only
1188+
those tags.
11861189
precision: Number of digits to display in text diagram
11871190
qubit_order: Determines how qubits are ordered in the diagram.
11881191
@@ -1209,7 +1212,7 @@ def to_text_diagram_drawer(
12091212
use_unicode_characters: bool = True,
12101213
qubit_namer: Callable[[cirq.Qid], str] | None = None,
12111214
transpose: bool = False,
1212-
include_tags: bool = True,
1215+
include_tags: bool | Iterable[type] = True,
12131216
draw_moment_groups: bool = True,
12141217
precision: int | None = 3,
12151218
qubit_order: cirq.QubitOrderOrList = ops.QubitOrder.DEFAULT,
@@ -1224,7 +1227,10 @@ def to_text_diagram_drawer(
12241227
allowed (as opposed to ascii-only diagrams).
12251228
qubit_namer: Names qubits in diagram. Defaults to using _circuit_diagram_info_ or str.
12261229
transpose: Arranges qubit wires vertically instead of horizontally.
1227-
include_tags: Whether to include tags in the operation.
1230+
include_tags: Controls which tags attached to operations are
1231+
included. ``True`` includes all tags, ``False`` includes none,
1232+
or a collection of tag classes may be specified to include only
1233+
those tags.
12281234
draw_moment_groups: Whether to draw moment symbol or not
12291235
precision: Number of digits to use when representing numbers.
12301236
qubit_order: Determines how qubits are ordered in the diagram.
@@ -2534,7 +2540,7 @@ def _draw_moment_annotations(
25342540
get_circuit_diagram_info: Callable[
25352541
[cirq.Operation, cirq.CircuitDiagramInfoArgs], cirq.CircuitDiagramInfo
25362542
],
2537-
include_tags: bool,
2543+
include_tags: bool | Iterable[type],
25382544
first_annotation_row: int,
25392545
transpose: bool,
25402546
):
@@ -2566,7 +2572,7 @@ def _draw_moment_in_diagram(
25662572
get_circuit_diagram_info: (
25672573
Callable[[cirq.Operation, cirq.CircuitDiagramInfoArgs], cirq.CircuitDiagramInfo] | None
25682574
),
2569-
include_tags: bool,
2575+
include_tags: bool | Iterable[type],
25702576
first_annotation_row: int,
25712577
transpose: bool,
25722578
):
@@ -2637,8 +2643,16 @@ def _draw_moment_in_diagram(
26372643
desc = _formatted_phase(global_phase, use_unicode_characters, precision)
26382644
if desc:
26392645
y = max(label_map.values(), default=0) + 1
2640-
if tags and include_tags:
2641-
desc = desc + f"[{', '.join(map(str, tags))}]"
2646+
visible_tags = protocols.CircuitDiagramInfoArgs(
2647+
known_qubits=None,
2648+
known_qubit_count=None,
2649+
use_unicode_characters=True,
2650+
precision=None,
2651+
label_map=None,
2652+
include_tags=include_tags,
2653+
).tags_to_include(tags)
2654+
if visible_tags:
2655+
desc = desc + f"[{', '.join(map(str, visible_tags))}]"
26422656
out_diagram.write(x0, y, desc)
26432657

26442658
if not non_global_ops:

cirq-core/cirq/circuits/moment.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -565,7 +565,7 @@ def to_text_diagram(
565565
extra_qubits: Iterable[cirq.Qid] = (),
566566
use_unicode_characters: bool = True,
567567
precision: int | None = None,
568-
include_tags: bool = True,
568+
include_tags: bool | Iterable[type] = True,
569569
) -> str:
570570
"""Create a text diagram for the moment.
571571
@@ -583,8 +583,10 @@ def to_text_diagram(
583583
precision: How precise numbers, such as angles, should be. Use None
584584
for infinite precision, or an integer for a certain number of
585585
digits of precision.
586-
include_tags: Whether or not to include operation tags in the
587-
diagram.
586+
include_tags: Controls which tags attached to operations are
587+
included. ``True`` includes all tags, ``False`` includes none,
588+
or a collection of tag classes may be specified to include only
589+
those tags.
588590
589591
Returns:
590592
The text diagram rendered into text.

cirq-core/cirq/contrib/quantum_volume/quantum_volume.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ def sample_heavy_set(
127127
# Add measure gates to the end of (a copy of) the circuit. Ensure that those
128128
# gates measure those in the given mapping, preserving this order.
129129
qubits = circuit.all_qubits()
130-
key = None
130+
key: Callable[[cirq.Qid], cirq.Qid] | None = None
131131
if mapping:
132132
# Add any qubits that were not explicitly mapped, so they aren't lost in
133133
# the sorting.
@@ -137,7 +137,7 @@ def sample_heavy_set(
137137
# Don't do a single large measurement gate because then the key will be one
138138
# large string. Instead, do a bunch of single-qubit measurement gates so we
139139
# preserve the qubit keys.
140-
sorted_qubits = sorted(qubits, key=key) # type: ignore[arg-type]
140+
sorted_qubits = sorted(qubits, key=key)
141141
circuit_copy = circuit + [cirq.measure(q) for q in sorted_qubits]
142142

143143
# Run the sampler to compare each output against the Heavy Set.

cirq-core/cirq/ops/raw_types.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -913,11 +913,12 @@ def _resolve_parameters_(
913913

914914
def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.CircuitDiagramInfo:
915915
sub_op_info = protocols.circuit_diagram_info(self.sub_operation, args, NotImplemented)
916-
# Add tag to wire symbol if it exists.
917-
if sub_op_info is not NotImplemented and args.include_tags and sub_op_info.wire_symbols:
918-
sub_op_info.wire_symbols = (
919-
sub_op_info.wire_symbols[0] + f"[{', '.join(map(str, self._tags))}]",
920-
) + sub_op_info.wire_symbols[1:]
916+
if sub_op_info is not NotImplemented and sub_op_info.wire_symbols:
917+
visible_tags = args.tags_to_include(self._tags)
918+
if visible_tags:
919+
sub_op_info.wire_symbols = (
920+
sub_op_info.wire_symbols[0] + f"[{', '.join(map(str, visible_tags))}]",
921+
) + sub_op_info.wire_symbols[1:]
921922
return sub_op_info
922923

923924
@cached_method

cirq-core/cirq/ops/raw_types_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -554,6 +554,8 @@ def __str__(self):
554554
diagram_with_non_string_tag = "(1, 1): ───H[<taggy>]───"
555555
assert c.to_text_diagram() == diagram_with_non_string_tag
556556
assert c.to_text_diagram(include_tags=False) == diagram_without_tags
557+
assert c.to_text_diagram(include_tags={str}) == diagram_without_tags
558+
assert c.to_text_diagram(include_tags={TaggyTag}) == diagram_with_non_string_tag
557559

558560

559561
def test_circuit_diagram_tagged_global_phase() -> None:

cirq-core/cirq/protocols/circuit_diagram_info_protocol.py

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,10 @@ class CircuitDiagramInfoArgs:
181181
precision: The number of digits after the decimal to show for numbers in
182182
the text diagram. None means use full precision.
183183
label_map: The map from label entities to diagram positions.
184-
include_tags: Whether to print tags from TaggedOperations.
184+
include_tags: If ``True`` all tags from ``TaggedOperations`` will be
185+
printed. If ``False`` no tags will be printed. Alternatively a
186+
collection of tag classes can be provided. In this case only tags
187+
whose type is contained in the collection will be shown.
185188
transpose: Whether the circuit is to be drawn with time from left to
186189
right (transpose is False), or from top to bottom.
187190
"""
@@ -195,15 +198,19 @@ def __init__(
195198
use_unicode_characters: bool,
196199
precision: int | None,
197200
label_map: dict[LabelEntity, int] | None,
198-
include_tags: bool = True,
201+
include_tags: bool | Iterable[type] = True,
199202
transpose: bool = False,
200203
) -> None:
201204
self.known_qubits = None if known_qubits is None else tuple(known_qubits)
202205
self.known_qubit_count = known_qubit_count
203206
self.use_unicode_characters = use_unicode_characters
204207
self.precision = precision
205208
self.label_map = label_map
206-
self.include_tags = include_tags
209+
self.include_tags: bool | frozenset[type]
210+
if isinstance(include_tags, bool):
211+
self.include_tags = include_tags
212+
else:
213+
self.include_tags = frozenset(include_tags)
207214
self.transpose = transpose
208215

209216
def _value_equality_values_(self) -> Any:
@@ -217,7 +224,11 @@ def _value_equality_values_(self) -> Any:
217224
if self.label_map is None
218225
else tuple(sorted(self.label_map.items(), key=lambda e: e[0]))
219226
),
220-
self.include_tags,
227+
(
228+
self.include_tags
229+
if isinstance(self.include_tags, bool)
230+
else tuple(sorted(self.include_tags, key=lambda c: c.__name__))
231+
),
221232
self.transpose,
222233
)
223234

@@ -229,10 +240,27 @@ def __repr__(self) -> str:
229240
f'use_unicode_characters={self.use_unicode_characters!r}, '
230241
f'precision={self.precision!r}, '
231242
f'label_map={self.label_map!r}, '
232-
f'include_tags={self.include_tags!r}, '
243+
f'include_tags={self._include_tags_repr()}, '
233244
f'transpose={self.transpose!r})'
234245
)
235246

247+
def _include_tags_repr(self) -> str:
248+
if isinstance(self.include_tags, bool):
249+
return repr(self.include_tags)
250+
items = []
251+
for cls in self.include_tags:
252+
if cls.__module__ == 'builtins':
253+
items.append(cls.__qualname__)
254+
else:
255+
items.append(f"{cls.__module__}.{cls.__qualname__}")
256+
joined = ', '.join(items)
257+
return f'frozenset({{{joined}}})'
258+
259+
def tags_to_include(self, tags: Iterable[Any]) -> list[Any]:
260+
if isinstance(self.include_tags, bool):
261+
return list(tags) if self.include_tags else []
262+
return [t for t in tags if any(isinstance(t, cls) for cls in self.include_tags)]
263+
236264
def format_real(self, val: sympy.Basic | int | float) -> str:
237265
if isinstance(val, sympy.Basic):
238266
return str(val)
@@ -279,6 +307,7 @@ def copy(self):
279307
use_unicode_characters=self.use_unicode_characters,
280308
precision=self.precision,
281309
label_map=self.label_map,
310+
include_tags=self.include_tags,
282311
transpose=self.transpose,
283312
)
284313

cirq-core/cirq/protocols/circuit_diagram_info_protocol_test.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@
2121
import cirq
2222

2323

24+
class CustomTag:
25+
pass
26+
27+
2428
def test_circuit_diagram_info_value_wrapping() -> None:
2529
single_info = cirq.CircuitDiagramInfo(('Single',))
2630

@@ -174,6 +178,26 @@ def test_circuit_diagram_info_args_eq() -> None:
174178
include_tags=False,
175179
)
176180
)
181+
eq.add_equality_group(
182+
cirq.CircuitDiagramInfoArgs(
183+
known_qubits=cirq.LineQubit.range(2),
184+
known_qubit_count=2,
185+
use_unicode_characters=False,
186+
precision=None,
187+
label_map=None,
188+
include_tags={str},
189+
)
190+
)
191+
eq.add_equality_group(
192+
cirq.CircuitDiagramInfoArgs(
193+
known_qubits=cirq.LineQubit.range(2),
194+
known_qubit_count=2,
195+
use_unicode_characters=False,
196+
precision=None,
197+
label_map=None,
198+
include_tags={CustomTag},
199+
)
200+
)
177201
eq.add_equality_group(
178202
cirq.CircuitDiagramInfoArgs(
179203
known_qubits=cirq.LineQubit.range(2),
@@ -208,6 +232,30 @@ def test_circuit_diagram_info_args_repr() -> None:
208232
)
209233
)
210234

235+
cirq.testing.assert_equivalent_repr(
236+
cirq.CircuitDiagramInfoArgs(
237+
known_qubits=cirq.LineQubit.range(1),
238+
known_qubit_count=1,
239+
use_unicode_characters=False,
240+
precision=None,
241+
label_map=None,
242+
include_tags={str},
243+
transpose=False,
244+
)
245+
)
246+
247+
cirq.testing.assert_equivalent_repr(
248+
cirq.CircuitDiagramInfoArgs(
249+
known_qubits=cirq.LineQubit.range(1),
250+
known_qubit_count=1,
251+
use_unicode_characters=False,
252+
precision=None,
253+
label_map=None,
254+
include_tags={CustomTag},
255+
transpose=False,
256+
)
257+
)
258+
211259

212260
def test_format_real() -> None:
213261
args = cirq.CircuitDiagramInfoArgs.UNINFORMED_DEFAULT.copy()

0 commit comments

Comments
 (0)