Skip to content

Commit 3277fd3

Browse files
committed
[7/N] Add sugar syntax for module.update
Pull Request resolved: #11534 The update API in method is supposed to be portable, but we can make it more user friendly for the update API in module. Add a bit sugar syntax in module to improve UX. Such that user can update backend option in module like following: ``` Module module(stub_model_path_); int new_num_threads = 4; const auto update_result = module.update("forward", { {"StubBackend", {{IntKey("NumberOfThreads"), new_num_threads}} }, ); ``` ghstack-source-id: 290372285 @exported-using-ghexport Differential Revision: [D76242292](https://our.internmc.facebook.com/intern/diff/D76242292/)
1 parent 74ee780 commit 3277fd3

File tree

3 files changed

+95
-8
lines changed

3 files changed

+95
-8
lines changed

extension/module/module.cpp

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,5 +317,46 @@ runtime::Error Module::update(
317317
return method->update(backend_options);
318318
}
319319

320+
runtime::Error Module::update(
321+
runtime::ArrayRef<runtime::Entry> backend_options) {
322+
return update("forward", backend_options);
323+
}
324+
325+
runtime::Error Module::update(
326+
const std::string& method_name,
327+
const std::unordered_map<std::string, std::vector<runtime::BackendOption>>&
328+
backend_options) {
329+
std::vector<runtime::Entry> entries;
330+
entries.reserve(backend_options.size());
331+
332+
for (const auto& [backend_name, options] : backend_options) {
333+
entries.push_back(
334+
{backend_name.c_str(),
335+
runtime::ArrayRef<runtime::BackendOption>(
336+
options.data(), options.size())});
337+
}
338+
339+
return update(
340+
method_name,
341+
runtime::ArrayRef<runtime::Entry>(entries.data(), entries.size()));
342+
}
343+
344+
runtime::Error Module::update(
345+
const std::unordered_map<std::string, std::vector<runtime::BackendOption>>&
346+
backend_options) {
347+
std::vector<runtime::Entry> entries;
348+
entries.reserve(backend_options.size());
349+
350+
for (const auto& [backend_name, options] : backend_options) {
351+
entries.push_back(
352+
{backend_name.c_str(),
353+
runtime::ArrayRef<runtime::BackendOption>(
354+
options.data(), options.size())});
355+
}
356+
357+
return update(
358+
runtime::ArrayRef<runtime::Entry>(entries.data(), entries.size()));
359+
}
360+
} // namespace ET_MODULE_NAMESPACE
320361
} // namespace extension
321362
} // namespace executorch

extension/module/module.h

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
#include <unordered_set>
1515
#include <vector>
1616

17+
#include <executorch/runtime/backend/backend_options.h>
18+
#include <executorch/runtime/backend/backend_options_map.h>
1719
#include <executorch/runtime/executor/program.h>
1820

1921
#ifdef USE_ATEN_LIB
@@ -487,10 +489,41 @@ class Module {
487489
*
488490
* @returns An Error to indicate success or failure.
489491
*/
490-
ET_EXPERIMENTAL ET_NODISCARD inline runtime::Error update(
491-
runtime::ArrayRef<runtime::Entry> backend_options) {
492-
return update("forward", backend_options);
493-
}
492+
ET_EXPERIMENTAL ET_NODISCARD runtime::Error update(
493+
runtime::ArrayRef<runtime::Entry> backend_options);
494+
495+
/**
496+
* EXPERIMENTAL: Updates backend options for a specific method.
497+
* Loads the program and method before updating if needed. It uses simple
498+
* std library like unordered_map to store backend options.
499+
*
500+
* @param[in] method_name The name of the method to update.
501+
* @param[in] backend_options A map of <backend_name,
502+
* vector<backend_options>>.
503+
*
504+
* @returns An Error to indicate success or failure.
505+
*/
506+
ET_EXPERIMENTAL ET_NODISCARD runtime::Error update(
507+
const std::string& method_name,
508+
const std::unordered_map<
509+
std::string,
510+
std::vector<runtime::BackendOption>>& backend_options);
511+
512+
/**
513+
* EXPERIMENTAL: Updates backend options for a specific method.
514+
* Loads the program and method before updating if needed. It uses simple
515+
* std library like unordered_map to store backend options.
516+
*
517+
* @param[in] method_name The name of the method to update.
518+
* @param[in] backend_options A map of <backend_name,
519+
* vector<backend_options>>.
520+
*
521+
* @returns An Error to indicate success or failure.
522+
*/
523+
ET_EXPERIMENTAL ET_NODISCARD runtime::Error update(
524+
const std::unordered_map<
525+
std::string,
526+
std::vector<runtime::BackendOption>>& backend_options);
494527

495528
/**
496529
* Retrieves the EventTracer instance being used by the Module.

extension/module/test/module_test.cpp

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@
1515

1616
#include <executorch/extension/data_loader/file_data_loader.h>
1717
#include <executorch/extension/tensor/tensor.h>
18-
#include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
1918
#include <executorch/runtime/backend/backend_options.h>
2019
#include <executorch/runtime/backend/backend_options_map.h>
20+
#include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
2121
#include <executorch/runtime/executor/test/stub_backend.h>
2222

2323
using namespace ::executorch::extension;
@@ -33,7 +33,7 @@ class ModuleTest : public ::testing::Test {
3333
add_mul_path_ = std::getenv("ET_MODULE_ADD_MUL_PROGRAM_PATH");
3434
add_mul_data_path_ = std::getenv("ET_MODULE_ADD_MUL_DATA_PATH");
3535
stub_model_path_ = std::getenv("ET_MODULE_ADD_MUL_DELEGATED_PATH");
36-
36+
3737
// Register the StubBackend for testing
3838
StubBackend::register_singleton();
3939
}
@@ -492,7 +492,6 @@ TEST_F(ModuleTest, TestUpdate) {
492492
EXPECT_EQ(update_result, Error::Ok);
493493

494494
ASSERT_EQ(StubBackend::singleton().num_threads(), new_num_threads);
495-
496495
}
497496

498497
TEST_F(ModuleTest, TestUpdateNonExistentMethod) {
@@ -503,8 +502,22 @@ TEST_F(ModuleTest, TestUpdateNonExistentMethod) {
503502
int new_num_threads = 4;
504503
backend_options.set_option(IntKey("NumberOfThreads"), new_num_threads);
505504
map.add("StubBackend", backend_options.view());
506-
505+
507506
// Test update method with non-existent method name
508507
const auto update_result = module.update("nonexistent", map.entries());
509508
EXPECT_NE(update_result, Error::Ok);
510509
}
510+
511+
TEST_F(ModuleTest, TestUpdateSugarSyntax) {
512+
Module module(stub_model_path_);
513+
int new_num_threads = 4;
514+
515+
// Using std::unordered_map and std::vector directly
516+
std::unordered_map<std::string, std::vector<BackendOption>> backend_options =
517+
{{"StubBackend", {{"NumberOfThreads", new_num_threads}}}};
518+
519+
const auto update_result = module.update("forward", backend_options);
520+
521+
EXPECT_EQ(update_result, Error::Ok);
522+
ASSERT_EQ(StubBackend::singleton().num_threads(), new_num_threads);
523+
}

0 commit comments

Comments
 (0)