Skip to content

Commit 422a1d0

Browse files
authored
Increase flexibility of stim.Circuit.diagram filter_coords args (#618)
- Fix the fact that it doesn't take DemTarget instances - Allow passing strings like "D5" and "L0" - Allowing passing an individual filter instead of a list - Allowing saying 'detslice' instead of 'detslice-svg' Fixes #616
1 parent 4b14127 commit 422a1d0

File tree

4 files changed

+95
-21
lines changed

4 files changed

+95
-21
lines changed

src/stim/circuit/circuit_pybind_test.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1547,3 +1547,29 @@ def test_shortest_graphlike_error_many_obs():
15471547
OBSERVABLE_INCLUDE(1200) rec[-1]
15481548
""")
15491549
assert len(c.shortest_graphlike_error()) == 5
1550+
1551+
1552+
def test_detslice_filter_coords_flexibility():
1553+
c = stim.Circuit.generated("repetition_code:memory", distance=3, rounds=3)
1554+
d1 = c.diagram("detslice", filter_coords=[stim.DemTarget.relative_detector_id(1)])
1555+
d2 = c.diagram("detslice-svg", filter_coords=stim.DemTarget.relative_detector_id(1))
1556+
d3 = c.diagram("detslice", filter_coords=["D1"])
1557+
d4 = c.diagram("detslice", filter_coords="D1")
1558+
d5 = c.diagram("detector-slice-svg", filter_coords=[3, 0])
1559+
d6 = c.diagram("detslice-svg", filter_coords=[[3, 0]])
1560+
assert str(d1) == str(d2)
1561+
assert str(d1) == str(d3)
1562+
assert str(d1) == str(d4)
1563+
assert str(d1) == str(d5)
1564+
assert str(d1) == str(d6)
1565+
assert str(d1) != str(c.diagram("detslice", filter_coords="L0"))
1566+
1567+
d1 = c.diagram("detslice", filter_coords=[stim.DemTarget.relative_detector_id(1), stim.DemTarget.relative_detector_id(3), stim.DemTarget.relative_detector_id(5), "D7"])
1568+
d2 = c.diagram("detslice", filter_coords=["D1", "D3", "D5", "D7"])
1569+
d3 = c.diagram("detslice-svg", filter_coords=[3,])
1570+
d4 = c.diagram("detslice-svg", filter_coords=[[3,]])
1571+
d5 = c.diagram("detslice-svg", filter_coords=[[3, 0], [3, 1], [3, 2], [3, 3]])
1572+
assert str(d1) == str(d2)
1573+
assert str(d1) == str(d3)
1574+
assert str(d1) == str(d4)
1575+
assert str(d1) == str(d5)

src/stim/cmd/command_diagram.pybind.cc

Lines changed: 60 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@
1414

1515
#include "stim/cmd/command_diagram.pybind.h"
1616

17+
#include "stim/arg_parse.h"
1718
#include "stim/cmd/command_help.h"
19+
#include "stim/dem/detector_error_model_target.pybind.h"
1820
#include "stim/diagram/base64.h"
1921
#include "stim/diagram/crumble.h"
2022
#include "stim/diagram/detector_slice/detector_slice_set.h"
@@ -143,31 +145,68 @@ DiagramHelper stim_pybind::dem_diagram(const DetectorErrorModel &dem, const std:
143145
throw std::invalid_argument("Unrecognized diagram type: " + type);
144146
}
145147
}
148+
149+
CoordFilter item_to_filter_single(const pybind11::handle &obj) {
150+
if (pybind11::isinstance<ExposedDemTarget>(obj)) {
151+
CoordFilter filter;
152+
filter.exact_target = pybind11::cast<ExposedDemTarget>(obj).internal();
153+
filter.use_target = true;
154+
return filter;
155+
}
156+
157+
try {
158+
std::string text = pybind11::cast<std::string>(obj);
159+
if (text.size() > 1 && text[0] == 'D') {
160+
CoordFilter filter;
161+
filter.exact_target = DemTarget::relative_detector_id(parse_exact_uint64_t_from_string(text.substr(1)));
162+
filter.use_target = true;
163+
return filter;
164+
}
165+
if (text.size() > 1 && text[0] == 'L') {
166+
CoordFilter filter;
167+
filter.exact_target = DemTarget::observable_id(parse_exact_uint64_t_from_string(text.substr(1)));
168+
filter.use_target = true;
169+
return filter;
170+
}
171+
} catch (const pybind11::cast_error &) {
172+
} catch (const std::invalid_argument &) {
173+
}
174+
175+
CoordFilter filter;
176+
for (const auto &c : obj) {
177+
filter.coordinates.push_back(pybind11::cast<double>(c));
178+
}
179+
return filter;
180+
}
181+
182+
std::vector<CoordFilter> item_to_filter_multi(const pybind11::object &obj) {
183+
if (obj.is_none()) {
184+
return {CoordFilter{}};
185+
}
186+
187+
try {
188+
return {item_to_filter_single(obj)};
189+
} catch (const pybind11::cast_error &) {
190+
} catch (const std::invalid_argument &) {
191+
}
192+
193+
std::vector<CoordFilter> filters;
194+
for (const auto &filter_case : obj) {
195+
filters.push_back(item_to_filter_single(filter_case));
196+
}
197+
return filters;
198+
}
199+
146200
DiagramHelper stim_pybind::circuit_diagram(
147201
const Circuit &circuit,
148202
const std::string &type,
149203
const pybind11::object &tick,
150204
const pybind11::object &filter_coords_obj) {
151205
std::vector<CoordFilter> filter_coords;
152206
try {
153-
if (filter_coords_obj.is_none()) {
154-
filter_coords.push_back({});
155-
} else {
156-
for (const auto &filter_case : filter_coords_obj) {
157-
CoordFilter filter;
158-
if (pybind11::isinstance<DemTarget>(filter_case)) {
159-
filter.exact_target = pybind11::cast<DemTarget>(filter_case);
160-
filter.use_target = true;
161-
} else {
162-
for (const auto &c : filter_case) {
163-
filter.coordinates.push_back(pybind11::cast<double>(c));
164-
}
165-
}
166-
filter_coords.push_back(std::move(filter));
167-
}
168-
}
207+
filter_coords = item_to_filter_multi(filter_coords_obj);
169208
} catch (const std::exception &_) {
170-
throw std::invalid_argument("filter_coords wasn't a list of list of floats.");
209+
throw std::invalid_argument("filter_coords wasn't an Iterable[stim.DemTarget | Iterable[float]].");
171210
}
172211

173212
uint64_t tick_min;
@@ -198,21 +237,21 @@ DiagramHelper stim_pybind::circuit_diagram(
198237
std::stringstream out;
199238
out << DiagramTimelineAsciiDrawer::make_diagram(circuit);
200239
return DiagramHelper{DIAGRAM_TYPE_TEXT, out.str()};
201-
} else if (type == "timeline-svg") {
240+
} else if (type == "timeline-svg" || type == "timeline") {
202241
std::stringstream out;
203242
DiagramTimelineSvgDrawer::make_diagram_write_to(
204243
circuit, out, tick_min, num_ticks, SVG_MODE_TIMELINE, filter_coords);
205244
return DiagramHelper{DIAGRAM_TYPE_SVG, out.str()};
206-
} else if (type == "time-slice-svg" || type == "timeslice-svg") {
245+
} else if (type == "time-slice-svg" || type == "timeslice-svg" || type == "timeslice" || type == "time-slice") {
207246
std::stringstream out;
208247
DiagramTimelineSvgDrawer::make_diagram_write_to(
209248
circuit, out, tick_min, num_ticks, SVG_MODE_TIME_SLICE, filter_coords);
210249
return DiagramHelper{DIAGRAM_TYPE_SVG, out.str()};
211-
} else if (type == "detslice-svg" || type == "detector-slice-svg") {
250+
} else if (type == "detslice-svg" || type == "detslice" || type == "detector-slice-svg" || type == "detector-slice") {
212251
std::stringstream out;
213252
DetectorSliceSet::from_circuit_ticks(circuit, tick_min, num_ticks, filter_coords).write_svg_diagram_to(out);
214253
return DiagramHelper{DIAGRAM_TYPE_SVG, out.str()};
215-
} else if (type == "detslice-with-ops-svg" || type == "time+detector-slice-svg") {
254+
} else if (type == "detslice-with-ops" || type == "detslice-with-ops-svg" || type == "time+detector-slice-svg") {
216255
std::stringstream out;
217256
DiagramTimelineSvgDrawer::make_diagram_write_to(
218257
circuit, out, tick_min, num_ticks, SVG_MODE_TIME_DETECTOR_SLICE, filter_coords);

src/stim/diagram/detector_slice/detector_slice_set.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,14 @@ std::ostream &stim_draw_internal::operator<<(std::ostream &out, const DetectorSl
9898
slice.write_text_diagram_to(out);
9999
return out;
100100
}
101+
std::ostream &stim_draw_internal::operator<<(std::ostream &out, const CoordFilter &filter) {
102+
if (filter.use_target) {
103+
out << filter.exact_target;
104+
} else {
105+
out << comma_sep(filter.coordinates);
106+
}
107+
return out;
108+
}
101109

102110
void DetectorSliceSet::write_text_diagram_to(std::ostream &out) const {
103111
DiagramTimelineAsciiDrawer drawer(num_qubits, false);

src/stim/diagram/detector_slice/detector_slice_set.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ Coord<2> pick_polygon_center(stim::SpanRef<const Coord<2>> coords);
8585
bool is_colinear(Coord<2> a, Coord<2> b, Coord<2> c, float atol);
8686

8787
std::ostream &operator<<(std::ostream &out, const DetectorSliceSet &slice);
88+
std::ostream &operator<<(std::ostream &out, const CoordFilter &filter);
8889

8990
} // namespace stim_draw_internal
9091

0 commit comments

Comments
 (0)