diff --git a/extension/module/module.cpp b/extension/module/module.cpp index e6ead0014c2..974ebe0bada 100644 --- a/extension/module/module.cpp +++ b/extension/module/module.cpp @@ -317,5 +317,46 @@ runtime::Error Module::update( return method->update(backend_options); } +runtime::Error Module::update( + runtime::ArrayRef backend_options) { + return update("forward", backend_options); +} + +runtime::Error Module::update( + const std::string& method_name, + const std::unordered_map>& + backend_options) { + std::vector entries; + entries.reserve(backend_options.size()); + + for (const auto& [backend_name, options] : backend_options) { + entries.push_back( + {backend_name.c_str(), + runtime::ArrayRef( + options.data(), options.size())}); + } + + return update( + method_name, + runtime::ArrayRef(entries.data(), entries.size())); +} + +runtime::Error Module::update( + const std::unordered_map>& + backend_options) { + std::vector entries; + entries.reserve(backend_options.size()); + + for (const auto& [backend_name, options] : backend_options) { + entries.push_back( + {backend_name.c_str(), + runtime::ArrayRef( + options.data(), options.size())}); + } + + return update( + runtime::ArrayRef(entries.data(), entries.size())); +} +} // namespace ET_MODULE_NAMESPACE } // namespace extension } // namespace executorch diff --git a/extension/module/module.h b/extension/module/module.h index 6002ddd7c63..1e844f54824 100644 --- a/extension/module/module.h +++ b/extension/module/module.h @@ -14,6 +14,8 @@ #include #include +#include +#include #include #ifdef USE_ATEN_LIB @@ -487,10 +489,41 @@ class Module { * * @returns An Error to indicate success or failure. */ - ET_EXPERIMENTAL ET_NODISCARD inline runtime::Error update( - runtime::ArrayRef backend_options) { - return update("forward", backend_options); - } + ET_EXPERIMENTAL ET_NODISCARD runtime::Error update( + runtime::ArrayRef backend_options); + + /** + * EXPERIMENTAL: Updates backend options for a specific method. + * Loads the program and method before updating if needed. It uses simple + * std library like unordered_map to store backend options. + * + * @param[in] method_name The name of the method to update. + * @param[in] backend_options A map of >. + * + * @returns An Error to indicate success or failure. + */ + ET_EXPERIMENTAL ET_NODISCARD runtime::Error update( + const std::string& method_name, + const std::unordered_map< + std::string, + std::vector>& backend_options); + + /** + * EXPERIMENTAL: Updates backend options for a specific method. + * Loads the program and method before updating if needed. It uses simple + * std library like unordered_map to store backend options. + * + * @param[in] method_name The name of the method to update. + * @param[in] backend_options A map of >. + * + * @returns An Error to indicate success or failure. + */ + ET_EXPERIMENTAL ET_NODISCARD runtime::Error update( + const std::unordered_map< + std::string, + std::vector>& backend_options); /** * Retrieves the EventTracer instance being used by the Module. diff --git a/extension/module/test/module_test.cpp b/extension/module/test/module_test.cpp index 24476c4adab..9a44230b6dd 100644 --- a/extension/module/test/module_test.cpp +++ b/extension/module/test/module_test.cpp @@ -15,9 +15,9 @@ #include #include -#include #include #include +#include #include using namespace ::executorch::extension; @@ -33,7 +33,7 @@ class ModuleTest : public ::testing::Test { add_mul_path_ = std::getenv("ET_MODULE_ADD_MUL_PROGRAM_PATH"); add_mul_data_path_ = std::getenv("ET_MODULE_ADD_MUL_DATA_PATH"); stub_model_path_ = std::getenv("ET_MODULE_ADD_MUL_DELEGATED_PATH"); - + // Register the StubBackend for testing StubBackend::register_singleton(); } @@ -492,7 +492,6 @@ TEST_F(ModuleTest, TestUpdate) { EXPECT_EQ(update_result, Error::Ok); ASSERT_EQ(StubBackend::singleton().num_threads(), new_num_threads); - } TEST_F(ModuleTest, TestUpdateNonExistentMethod) { @@ -503,8 +502,22 @@ TEST_F(ModuleTest, TestUpdateNonExistentMethod) { int new_num_threads = 4; backend_options.set_option(IntKey("NumberOfThreads"), new_num_threads); map.add("StubBackend", backend_options.view()); - + // Test update method with non-existent method name const auto update_result = module.update("nonexistent", map.entries()); EXPECT_NE(update_result, Error::Ok); } + +TEST_F(ModuleTest, TestUpdateSugarSyntax) { + Module module(stub_model_path_); + int new_num_threads = 4; + + // Using std::unordered_map and std::vector directly + std::unordered_map> backend_options = + {{"StubBackend", {{"NumberOfThreads", new_num_threads}}}}; + + const auto update_result = module.update("forward", backend_options); + + EXPECT_EQ(update_result, Error::Ok); + ASSERT_EQ(StubBackend::singleton().num_threads(), new_num_threads); +}