From 4e28e5cc77bccfc0664848d2304128e15f31c23f Mon Sep 17 00:00:00 2001 From: Chen Lai Date: Tue, 17 Jun 2025 09:12:46 -0700 Subject: [PATCH] Add get_option/set_option APIs Differential Revision: [D76825663](https://our.internmc.facebook.com/intern/diff/D76825663/) [ghstack-poisoned] --- runtime/backend/backend_update.h | 67 +++++++ runtime/backend/targets.bzl | 19 ++ runtime/backend/test/backend_update_test.cpp | 187 +++++++++++++++++++ runtime/backend/test/targets.bzl | 10 + 4 files changed, 283 insertions(+) create mode 100644 runtime/backend/backend_update.h create mode 100644 runtime/backend/test/backend_update_test.cpp diff --git a/runtime/backend/backend_update.h b/runtime/backend/backend_update.h new file mode 100644 index 00000000000..0fad4195824 --- /dev/null +++ b/runtime/backend/backend_update.h @@ -0,0 +1,67 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once +#include +#include +#include +#include +#include +#include + +using executorch::runtime::BackendOptionsMap; + +namespace executorch { +namespace runtime { + +Error get_option( + executorch::runtime::Span backend_options_map) { + for (auto& entry : backend_options_map) { + const char* backend_name = entry.backend_name; + auto backend_options = entry.options; + + auto backend_class = get_backend_class(backend_name); + if (!backend_class) { + return Error::NotFound; + } + + executorch::runtime::BackendUpdateContext backend_update_context; + executorch::runtime::Span backend_options_ref( + backend_options.data(), backend_options.size()); + auto result = + backend_class->get_option(backend_update_context, backend_options_ref); + if (result != Error::Ok) { + return result; + } + } + return Error::Ok; +} + +Error set_option( + const executorch::runtime::Span backend_options_map) { + for (const auto& entry : backend_options_map) { + const char* backend_name = entry.backend_name; + auto backend_options = entry.options; + + auto backend_class = get_backend_class(backend_name); + if (!backend_class) { + return Error::NotFound; + } + + executorch::runtime::BackendUpdateContext backend_update_context; + auto update_result = + backend_class->set_option(backend_update_context, backend_options); + if (update_result != Error::Ok) { + return update_result; + } + } + return Error::Ok; +} + +} // namespace runtime +} // namespace executorch diff --git a/runtime/backend/targets.bzl b/runtime/backend/targets.bzl index c58913e2bb4..fcac89d30e2 100644 --- a/runtime/backend/targets.bzl +++ b/runtime/backend/targets.bzl @@ -28,6 +28,25 @@ def define_common_targets(): ], ) + runtime.cxx_library( + name = "backend_update" + aten_suffix, + exported_headers = [ + "backend_update.h", + ], + preprocessor_flags = ["-DUSE_ATEN_LIB"] if aten_mode else [], + visibility = [ + "//executorch/...", + "@EXECUTORCH_CLIENTS", + ], + exported_deps = [ + "//executorch/runtime/core:core", + "//executorch/runtime/core:evalue" + aten_suffix, + "//executorch/runtime/core:event_tracer" + aten_suffix, + ":backend_options_map" + aten_suffix, + ":interface" + aten_suffix, + ], + ) + runtime.cxx_library( name = "interface" + aten_suffix, srcs = [ diff --git a/runtime/backend/test/backend_update_test.cpp b/runtime/backend/test/backend_update_test.cpp new file mode 100644 index 00000000000..15c3f1a5c49 --- /dev/null +++ b/runtime/backend/test/backend_update_test.cpp @@ -0,0 +1,187 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include +#include +#include + +using namespace ::testing; +using executorch::runtime::ArrayRef; +using executorch::runtime::Backend; +using executorch::runtime::BackendExecutionContext; +using executorch::runtime::BackendInitContext; +using executorch::runtime::BackendInterface; +using executorch::runtime::BackendOptions; +using executorch::runtime::BackendOptionsMap; +using executorch::runtime::BackendUpdateContext; +using executorch::runtime::BoolKey; +using executorch::runtime::CompileSpec; +using executorch::runtime::DelegateHandle; +using executorch::runtime::Error; +using executorch::runtime::EValue; +using executorch::runtime::FreeableBuffer; +using executorch::runtime::IntKey; +using executorch::runtime::OptionKey; +using executorch::runtime::register_backend; +using executorch::runtime::Result; +using executorch::runtime::Span; +using executorch::runtime::StrKey; + +// Mock backend for testing +class StubBackend : public BackendInterface { + public: + ~StubBackend() override = default; + + bool is_available() const override { + return true; + } + + Result init( + BackendInitContext& context, + FreeableBuffer* processed, + ArrayRef compile_specs) const override { + return nullptr; + } + + Error execute( + BackendExecutionContext& context, + DelegateHandle* handle, + EValue** args) const override { + return Error::Ok; + } + + Error get_option( + BackendUpdateContext& context, + executorch::runtime::Span& + backend_options) override { + // For testing purposes, just record that get_option was called + // and verify the input parameters + get_option_called = true; + get_option_call_count++; + last_get_option_size = backend_options.size(); + + // Verify that the expected option key is present and modify the value + for (size_t i = 0; i < backend_options.size(); ++i) { + if (strcmp(backend_options[i].key, "NumberOfThreads") == 0) { + // Set the value to what was stored by set_option + backend_options[i].value = last_num_threads; + found_expected_key = true; + break; + } + } + + return Error::Ok; + } + + Error set_option( + BackendUpdateContext& context, + const Span& backend_options) override { + // Store the options for verification + last_options_size = backend_options.size(); + if (backend_options.size() > 0) { + for (const auto& option : backend_options) { + if (strcmp(option.key, "NumberOfThreads") == 0) { + if (auto* val = std::get_if(&option.value)) { + last_num_threads = *val; + } + } + } + } + return Error::Ok; + } + + // Mutable for testing verification + size_t last_options_size = 0; + int last_num_threads = 0; + bool get_option_called = false; + int get_option_call_count = 0; + size_t last_get_option_size = 0; + bool found_expected_key = false; +}; + +class BackendUpdateTest : public ::testing::Test { + protected: + void SetUp() override { + // Since these tests cause ET_LOG to be called, the PAL must be initialized + // first. + executorch::runtime::runtime_init(); + + // Register the stub backend + stub_backend = std::make_unique(); + Backend backend_config{"StubBackend", stub_backend.get()}; + auto register_result = register_backend(backend_config); + ASSERT_EQ(register_result, Error::Ok); + } + + std::unique_ptr stub_backend; +}; + +// Test basic string functionality +TEST_F(BackendUpdateTest, TestSetOption) { + BackendOptionsMap<3> map; + BackendOptions<1> backend_options; + int new_num_threads = 4; + backend_options.set_option(IntKey("NumberOfThreads"), new_num_threads); + map.add("StubBackend", backend_options.view()); + + auto status = set_option(map.entries()); + ASSERT_EQ(status, Error::Ok); + + // Verify the map contains the expected data + ASSERT_EQ(map.size(), 1); + auto options = map.get("StubBackend"); + ASSERT_EQ(options.size(), 1); + + // Verify that the backend actually received the options + ASSERT_EQ(stub_backend->last_options_size, 1); + ASSERT_EQ(stub_backend->last_num_threads, new_num_threads); +} + +// Test get_option functionality +TEST_F(BackendUpdateTest, TestGetOption) { + // First, set some options in the backend + BackendOptionsMap<3> set_map; + BackendOptions<1> set_backend_options; + int expected_num_threads = 8; + set_backend_options.set_option( + IntKey("NumberOfThreads"), expected_num_threads); + set_map.add("StubBackend", set_backend_options.view()); + + auto set_status = set_option(set_map.entries()); + ASSERT_EQ(set_status, Error::Ok); + ASSERT_EQ(stub_backend->last_num_threads, expected_num_threads); + + // Reset get_option tracking variables + stub_backend->get_option_called = false; + stub_backend->get_option_call_count = 0; + stub_backend->found_expected_key = false; + + // Now create a map with options for get_option to process + BackendOptionsMap<3> get_map; + BackendOptions<1> get_backend_options; + get_backend_options.set_option(IntKey("NumberOfThreads"), 0); + get_map.add("StubBackend", get_backend_options.view()); + + // Call get_option to test the API + auto get_status = get_option(get_map.entries()); + ASSERT_EQ(get_status, Error::Ok); + + ASSERT_TRUE( + std::get(get_map.entries()[0].options[0].value) == + expected_num_threads); + + // // Verify that the backend's get_option method was called correctly + // ASSERT_TRUE(stub_backend->get_option_called); + // ASSERT_EQ(stub_backend->get_option_call_count, 1); + // ASSERT_EQ(stub_backend->last_get_option_size, 1); + // ASSERT_TRUE(stub_backend->found_expected_key); +} diff --git a/runtime/backend/test/targets.bzl b/runtime/backend/test/targets.bzl index 5d1ef2d5c81..f5c84f8c579 100644 --- a/runtime/backend/test/targets.bzl +++ b/runtime/backend/test/targets.bzl @@ -24,6 +24,16 @@ def define_common_targets(): ], ) + runtime.cxx_test( + name = "backend_update_test", + srcs = ["backend_update_test.cpp"], + deps = [ + "//executorch/runtime/core:core", + "//executorch/runtime/backend:backend_options_map", + "//executorch/runtime/backend:backend_update", + ], + ) + runtime.cxx_test( name = "backend_interface_update_test", srcs = ["backend_interface_update_test.cpp"],