Skip to content

Commit aaa0292

Browse files
Compile detector method inside Simplex and Tesseract decoder (#76)
Address issue #75
1 parent 9844046 commit aaa0292

File tree

5 files changed

+67
-3
lines changed

5 files changed

+67
-3
lines changed

src/py/shared_decoding_tests.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,3 +202,17 @@ def shared_test_decode_from_detection_events(decoder_class, config_class):
202202
assert isinstance(predicted3, np.ndarray)
203203
assert predicted3.dtype.type == np.bool_
204204
assert np.array_equal(predicted3, np.array([False], dtype=bool))
205+
206+
def shared_test_compile_decoder(config_class, decoder_class):
207+
"""
208+
Tests the `compile_decoder` method on a config class.
209+
"""
210+
dem_string = "error(0.1) D0 D1 L0"
211+
dem = stim.DetectorErrorModel(dem_string)
212+
config = config_class(dem)
213+
214+
decoder = config.compile_decoder()
215+
216+
assert isinstance(decoder, decoder_class)
217+
assert decoder.config.dem == config.dem
218+
assert decoder.num_observables == dem.num_observables

src/py/simplex_test.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
shared_test_decode_complex_dem,
2727
shared_test_decode_batch,
2828
shared_test_decode_from_detection_events,
29+
shared_test_compile_decoder,
2930
)
3031

3132
_DETECTOR_ERROR_MODEL = stim.DetectorErrorModel(
@@ -55,6 +56,14 @@ def test_create_simplex_decoder():
5556
assert decoder.get_observables_from_errors([1]) == []
5657
assert decoder.cost_from_errors([2]) == pytest.approx(1.0986123)
5758

59+
@pytest.mark.parametrize(
60+
"config_class, decoder_class",
61+
[(tesseract_decoder.simplex.SimplexConfig, tesseract_decoder.simplex.SimplexDecoder)]
62+
)
63+
def test_simplex_compile_decoder(config_class, decoder_class):
64+
shared_test_compile_decoder(config_class, decoder_class)
65+
66+
5867
@pytest.mark.parametrize(
5968
"decoder_class, config_class",
6069
[(tesseract_decoder.simplex.SimplexDecoder, tesseract_decoder.simplex.SimplexConfig)]

src/py/tesseract_test.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
shared_test_decode_complex_dem,
2727
shared_test_decode_batch,
2828
shared_test_decode_from_detection_events,
29+
shared_test_compile_decoder,
2930
)
3031

3132
_DETECTOR_ERROR_MODEL = stim.DetectorErrorModel(
@@ -61,6 +62,13 @@ def test_create_decoder():
6162
assert decoder.get_observables_from_errors([1]) == []
6263
assert decoder.cost_from_errors([1]) == pytest.approx(0.5108256237659907)
6364

65+
@pytest.mark.parametrize(
66+
"config_class, decoder_class",
67+
[(tesseract_decoder.tesseract.TesseractConfig, tesseract_decoder.tesseract.TesseractDecoder)]
68+
)
69+
def test_tesseract_compile_decoder(config_class, decoder_class):
70+
shared_test_compile_decoder(config_class, decoder_class)
71+
6472
@pytest.mark.parametrize(
6573
"decoder_class, config_class",
6674
[(tesseract_decoder.tesseract.TesseractDecoder, tesseract_decoder.tesseract.TesseractConfig)]

src/simplex.pybind.h

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,11 @@
2828
namespace py = pybind11;
2929

3030
namespace {
31+
// Helper function to compile the decoder.
32+
std::unique_ptr<SimplexDecoder> _compile_simplex_decoder_helper(const SimplexConfig& self) {
33+
return std::make_unique<SimplexDecoder>(self);
34+
}
35+
3136
SimplexConfig simplex_config_maker(py::object dem, bool parallelize = false,
3237
size_t window_length = 0, size_t window_slide_length = 0,
3338
bool verbose = false) {
@@ -51,7 +56,17 @@ void add_simplex_module(py::module& root) {
5156
.def_readwrite("window_slide_length", &SimplexConfig::window_slide_length)
5257
.def_readwrite("verbose", &SimplexConfig::verbose)
5358
.def("windowing_enabled", &SimplexConfig::windowing_enabled)
54-
.def("__str__", &SimplexConfig::str);
59+
.def("__str__", &SimplexConfig::str)
60+
.def("compile_decoder", &_compile_simplex_decoder_helper,
61+
py::return_value_policy::take_ownership, R"pbdoc(
62+
Compiles the configuration into a new SimplexDecoder instance.
63+
64+
Returns
65+
-------
66+
SimplexDecoder
67+
A new SimplexDecoder instance configured with the current
68+
settings.
69+
)pbdoc");
5570

5671
py::class_<SimplexDecoder>(m, "SimplexDecoder")
5772
.def(py::init<SimplexConfig>(), py::arg("config"))

src/tesseract.pybind.h

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,11 @@
2727
namespace py = pybind11;
2828

2929
namespace {
30+
// Helper function to compile the decoder.
31+
std::unique_ptr<TesseractDecoder> _compile_tesseract_decoder_helper(const TesseractConfig& self) {
32+
return std::make_unique<TesseractDecoder>(self);
33+
}
34+
3035
TesseractConfig tesseract_config_maker(
3136
py::object dem, int det_beam = INF_DET_BEAM, bool beam_climbing = false,
3237
bool no_revisit_dets = false, bool at_most_two_errors_per_detector = false,
@@ -58,7 +63,18 @@ void add_tesseract_module(py::module& root) {
5863
.def_readwrite("pqlimit", &TesseractConfig::pqlimit)
5964
.def_readwrite("det_orders", &TesseractConfig::det_orders)
6065
.def_readwrite("det_penalty", &TesseractConfig::det_penalty)
61-
.def("__str__", &TesseractConfig::str);
66+
.def("__str__", &TesseractConfig::str)
67+
.def("compile_decoder", &_compile_tesseract_decoder_helper,
68+
py::return_value_policy::take_ownership,
69+
R"pbdoc(
70+
Compiles the configuration into a new `TesseractDecoder` instance.
71+
72+
Returns
73+
-------
74+
TesseractDecoder
75+
A new `TesseractDecoder` instance configured with the current
76+
settings.
77+
)pbdoc");
6278

6379
py::class_<Node>(m, "Node")
6480
.def(py::init<double, size_t, std::vector<size_t>>(), py::arg("cost") = 0.0,
@@ -225,7 +241,9 @@ void add_tesseract_module(py::module& root) {
225241
)pbdoc")
226242
.def_readwrite("low_confidence_flag", &TesseractDecoder::low_confidence_flag)
227243
.def_readwrite("predicted_errors_buffer", &TesseractDecoder::predicted_errors_buffer)
228-
.def_readwrite("errors", &TesseractDecoder::errors);
244+
.def_readwrite("errors", &TesseractDecoder::errors)
245+
.def_readwrite("config", &TesseractDecoder::config)
246+
.def_readwrite("num_observables", &TesseractDecoder::num_observables);
229247
}
230248

231249
#endif

0 commit comments

Comments
 (0)