23
23
#include < executorch/extension/data_loader/buffer_data_loader.h>
24
24
#include < executorch/extension/data_loader/mmap_data_loader.h>
25
25
#include < executorch/extension/memory_allocator/malloc_memory_allocator.h>
26
+ #include < executorch/extension/module/bundled_module.h>
26
27
#include < executorch/extension/threadpool/threadpool.h>
27
28
#include < executorch/runtime/backend/interface.h>
28
29
#include < executorch/runtime/core/data_loader.h>
@@ -425,13 +426,54 @@ inline std::unique_ptr<Module> load_module_from_file(
425
426
program_verification);
426
427
}
427
428
429
+ inline py::list get_outputs_as_py_list (
430
+ const std::vector<EValue>& outputs,
431
+ bool clone_outputs = true ) {
432
+ const auto outputs_size = outputs.size ();
433
+ py::list list (outputs_size);
434
+ for (size_t i = 0 ; i < outputs_size; ++i) {
435
+ auto & v = outputs[i];
436
+ if (Tag::None == v.tag ) {
437
+ list[i] = py::none ();
438
+ } else if (Tag::Int == v.tag ) {
439
+ list[i] = py::cast (v.toInt ());
440
+ } else if (Tag::Double == v.tag ) {
441
+ list[i] = py::cast (v.toDouble ());
442
+ } else if (Tag::Bool == v.tag ) {
443
+ list[i] = py::cast (v.toBool ());
444
+ } else if (Tag::String == v.tag ) {
445
+ list[i] = py::cast (std::string (v.toString ().data ()));
446
+ } else if (Tag::Tensor == v.tag ) {
447
+ #ifdef USE_ATEN_LIB
448
+ // Clone so the outputs in python do not share a lifetime with the
449
+ // module object
450
+ if (clone_outputs) {
451
+ list[i] = py::cast (v.toTensor ().clone ());
452
+ } else {
453
+ list[i] = py::cast (v.toTensor ());
454
+ }
455
+ #else
456
+ if (clone_outputs) {
457
+ list[i] = py::cast (alias_attensor_to_etensor (v.toTensor ()).clone ());
458
+ } else {
459
+ list[i] = py::cast (alias_attensor_to_etensor (v.toTensor ()));
460
+ }
461
+ #endif
462
+ } else {
463
+ ET_ASSERT_UNREACHABLE_MSG (" Invalid model output type" );
464
+ }
465
+ }
466
+ return list;
467
+ }
468
+
428
469
static constexpr size_t kDEFAULT_BUNDLED_INPUT_POOL_SIZE = 16 * 1024U ;
429
470
430
- struct PyBundledModule final {
471
+ struct PyBundledModule : public BundledModule {
431
472
explicit PyBundledModule (
432
473
const py::bytes& buffer,
433
474
uint32_t bundled_input_pool_size)
434
- : bundled_program_ptr_(buffer),
475
+ : BundledModule(buffer.cast<std::string_view>().data()),
476
+ bundled_program_ptr_(buffer),
435
477
program_ptr_(static_cast <const void *>(
436
478
bundled_program_flatbuffer::GetBundledProgram (
437
479
get_bundled_program_ptr ())
@@ -460,6 +502,33 @@ struct PyBundledModule final {
460
502
return program_len_;
461
503
}
462
504
505
+ py::list verify_result_with_bundled_expected_output (
506
+ const std::string& method_name,
507
+ size_t testset_idx,
508
+ double rtol = 1e-5 ,
509
+ double atol = 1e-8 ) {
510
+ // Execute the method
511
+ auto result = BundledModule::execute (method_name, testset_idx);
512
+ if (!result.ok ()) {
513
+ THROW_IF_ERROR (
514
+ result.error (),
515
+ " Method execution failed with status 0x%" PRIx32,
516
+ static_cast <uint32_t >(result.error ()));
517
+ }
518
+
519
+ // Convert outputs to py::list
520
+ const auto & outputs = result.get ();
521
+ py::list py_outputs = get_outputs_as_py_list (outputs);
522
+
523
+ Error status = BundledModule::verify_method_outputs (
524
+ method_name, testset_idx, rtol, atol);
525
+ THROW_IF_ERROR (
526
+ status,
527
+ " Result verification failed with status %" PRIu32,
528
+ static_cast <uint32_t >(status));
529
+ return py_outputs;
530
+ }
531
+
463
532
private:
464
533
// Store the bytes object instead of a raw pointer so that this module will
465
534
// keep the bytes alive.
@@ -816,43 +885,6 @@ struct PyModule final {
816
885
}
817
886
}
818
887
819
- void load_bundled_input (
820
- PyBundledModule& m,
821
- const std::string method_name,
822
- size_t testset_idx) {
823
- const void * bundled_program_ptr = m.get_bundled_program_ptr ();
824
- Error status = executorch::BUNDLED_PROGRAM_NAMESPACE::load_bundled_input (
825
- module_->get_method (method_name), bundled_program_ptr, testset_idx);
826
- THROW_IF_ERROR (
827
- status,
828
- " load_bundled_input failed with status 0x%" PRIx32,
829
- static_cast <uint32_t >(status));
830
- }
831
-
832
- py::list verify_result_with_bundled_expected_output (
833
- PyBundledModule& m,
834
- const std::string method_name,
835
- size_t testset_idx,
836
- double rtol = 1e-5 ,
837
- double atol = 1e-8 ) {
838
- const void * bundled_program_ptr = m.get_bundled_program_ptr ();
839
- auto & method = module_->get_method (method_name);
840
- Error status = executorch::BUNDLED_PROGRAM_NAMESPACE::load_bundled_input (
841
- method, bundled_program_ptr, testset_idx);
842
- THROW_IF_ERROR (
843
- status,
844
- " load_bundled_input failed with status 0x%" PRIx32,
845
- static_cast <uint32_t >(status));
846
- py::list outputs = plan_execute (method_name);
847
- status = executorch::BUNDLED_PROGRAM_NAMESPACE::verify_method_outputs (
848
- method, bundled_program_ptr, testset_idx, rtol, atol);
849
- THROW_IF_ERROR (
850
- status,
851
- " Result verification failed with status %" PRIu32,
852
- static_cast <uint32_t >(status));
853
- return outputs;
854
- }
855
-
856
888
py::list plan_execute (
857
889
const std::string method_name,
858
890
bool clone_outputs = true ) {
@@ -875,46 +907,6 @@ struct PyModule final {
875
907
return get_outputs_as_py_list (outputs, clone_outputs);
876
908
}
877
909
878
- py::list get_outputs_as_py_list (
879
- const std::vector<EValue>& outputs,
880
- bool clone_outputs = true ) {
881
- const auto outputs_size = outputs.size ();
882
- py::list list (outputs_size);
883
- for (size_t i = 0 ; i < outputs_size; ++i) {
884
- auto & v = outputs[i];
885
- if (Tag::None == v.tag ) {
886
- list[i] = py::none ();
887
- } else if (Tag::Int == v.tag ) {
888
- list[i] = py::cast (v.toInt ());
889
- } else if (Tag::Double == v.tag ) {
890
- list[i] = py::cast (v.toDouble ());
891
- } else if (Tag::Bool == v.tag ) {
892
- list[i] = py::cast (v.toBool ());
893
- } else if (Tag::String == v.tag ) {
894
- list[i] = py::cast (std::string (v.toString ().data ()));
895
- } else if (Tag::Tensor == v.tag ) {
896
- #ifdef USE_ATEN_LIB
897
- // Clone so the outputs in python do not share a lifetime with the
898
- // module object
899
- if (clone_outputs) {
900
- list[i] = py::cast (v.toTensor ().clone ());
901
- } else {
902
- list[i] = py::cast (v.toTensor ());
903
- }
904
- #else
905
- if (clone_outputs) {
906
- list[i] = py::cast (alias_attensor_to_etensor (v.toTensor ()).clone ());
907
- } else {
908
- list[i] = py::cast (alias_attensor_to_etensor (v.toTensor ()));
909
- }
910
- #endif
911
- } else {
912
- ET_ASSERT_UNREACHABLE_MSG (" Invalid model output type" );
913
- }
914
- }
915
- return list;
916
- }
917
-
918
910
std::unique_ptr<PyMethodMeta> method_meta (const std::string method_name) {
919
911
auto & method = module_->get_method (method_name);
920
912
return std::make_unique<PyMethodMeta>(module_, method.method_meta ());
@@ -1074,16 +1066,6 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
1074
1066
call_guard);
1075
1067
1076
1068
py::class_<PyModule>(m, " ExecuTorchModule" )
1077
- .def (" load_bundled_input" , &PyModule::load_bundled_input, call_guard)
1078
- .def (
1079
- " verify_result_with_bundled_expected_output" ,
1080
- &PyModule::verify_result_with_bundled_expected_output,
1081
- py::arg (" bundle" ),
1082
- py::arg (" method_name" ),
1083
- py::arg (" testset_idx" ),
1084
- py::arg (" rtol" ) = 1e-5 ,
1085
- py::arg (" atol" ) = 1e-8 ,
1086
- call_guard)
1087
1069
.def (
1088
1070
" plan_execute" ,
1089
1071
&PyModule::plan_execute,
@@ -1129,7 +1111,16 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
1129
1111
py::arg (" clone_outputs" ) = true ,
1130
1112
call_guard);
1131
1113
1132
- py::class_<PyBundledModule>(m, " BundledModule" );
1114
+ py::class_<PyBundledModule>(m, " BundledModule" )
1115
+ .def (
1116
+ " verify_result_with_bundled_expected_output" ,
1117
+ &PyBundledModule::verify_result_with_bundled_expected_output,
1118
+ py::arg (" method_name" ),
1119
+ py::arg (" testset_idx" ),
1120
+ py::arg (" rtol" ) = 1e-5 ,
1121
+ py::arg (" atol" ) = 1e-8 ,
1122
+ call_guard);
1123
+
1133
1124
py::class_<PyTensorInfo>(m, " TensorInfo" )
1134
1125
.def (" sizes" , &PyTensorInfo::sizes, call_guard)
1135
1126
.def (" dtype" , &PyTensorInfo::dtype, call_guard)
0 commit comments