Skip to content

Commit 955fcef

Browse files
committed
Add get_option/set_option APIs
Pull Request resolved: #11758 ghstack-source-id: 290994800 Expose the API to set/get backend option. We can either pass in {backend_name, backend options} or {backend options map} Differential Revision: [D76825663](https://our.internmc.facebook.com/intern/diff/D76825663/)
1 parent 6be0d90 commit 955fcef

File tree

4 files changed

+340
-0
lines changed

4 files changed

+340
-0
lines changed

runtime/backend/backend_update.h

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
#include <executorch/runtime/backend/backend_options_map.h>
11+
#include <executorch/runtime/backend/backend_update_context.h>
12+
#include <executorch/runtime/backend/interface.h>
13+
#include <executorch/runtime/core/error.h>
14+
#include <cstddef>
15+
#include <cstring>
16+
17+
using executorch::runtime::BackendOptionsMap;
18+
19+
namespace executorch {
20+
namespace runtime {
21+
22+
/**
23+
* Retrieves backend options for a specific backend.
24+
*
25+
* @param backend_name The name of the backend to get options from
26+
* @param backend_options The backend option objects that will be filled with
27+
* the populated values from the backend
28+
* @return Error::Ok on success, Error::NotFound if backend is not found, or
29+
* other error codes on failure
30+
*/
31+
Error get_option(
32+
const char* backend_name,
33+
executorch::runtime::Span<executorch::runtime::BackendOption>
34+
backend_options) {
35+
auto backend_class = get_backend_class(backend_name);
36+
if (!backend_class) {
37+
return Error::NotFound;
38+
}
39+
executorch::runtime::BackendUpdateContext backend_update_context;
40+
executorch::runtime::Span<BackendOption> backend_options_ref(
41+
backend_options.data(), backend_options.size());
42+
auto result =
43+
backend_class->get_option(backend_update_context, backend_options_ref);
44+
if (result != Error::Ok) {
45+
return result;
46+
}
47+
return Error::Ok;
48+
}
49+
50+
/**
51+
* Retrieves backend options for multiple backends using a backend options map.
52+
*
53+
* @param backend_options_map The backend option map containing backend names
54+
* and their associated options, which will be filled with the populated values
55+
* from the backend
56+
* @return Error::Ok on success, or the first error encountered when processing
57+
* the entries
58+
*/
59+
Error get_option(
60+
executorch::runtime::Span<executorch::runtime::Entry> backend_options_map) {
61+
Error result = Error::Ok;
62+
for (auto& entry : backend_options_map) {
63+
const char* backend_name = entry.backend_name;
64+
auto backend_options = entry.options;
65+
auto result = get_option(backend_name, backend_options);
66+
if (result != Error::Ok) {
67+
return result;
68+
}
69+
}
70+
return Error::Ok;
71+
}
72+
73+
/**
74+
* Sets backend options for a specific backend.
75+
*
76+
* @param backend_name The name of the backend to set options for
77+
* @param backend_options The backend option list containing the options
78+
* to set
79+
* @return Error::Ok on success, Error::NotFound if backend is not found, or
80+
* other error codes on failure
81+
*/
82+
Error set_option(
83+
const char* backend_name,
84+
const executorch::runtime::Span<executorch::runtime::BackendOption>
85+
backend_options) {
86+
auto backend_class = get_backend_class(backend_name);
87+
if (!backend_class) {
88+
return Error::NotFound;
89+
}
90+
91+
executorch::runtime::BackendUpdateContext backend_update_context;
92+
Error result =
93+
backend_class->set_option(backend_update_context, backend_options);
94+
if (result != Error::Ok) {
95+
return result;
96+
}
97+
return Error::Ok;
98+
}
99+
100+
/**
101+
* Sets backend options for multiple backends using a backend options map.
102+
*
103+
* @param backend_options_map The backend option map containing backend names
104+
* and their associated backend options to set
105+
* @return Error::Ok on success, or the first error encountered when processing
106+
*/
107+
Error set_option(const executorch::runtime::Span<executorch::runtime::Entry>
108+
backend_options_map) {
109+
Error result = Error::Ok;
110+
for (const auto& entry : backend_options_map) {
111+
const char* backend_name = entry.backend_name;
112+
auto backend_options = entry.options;
113+
result = set_option(backend_name, backend_options);
114+
115+
if (result != Error::Ok) {
116+
return result;
117+
}
118+
}
119+
return Error::Ok;
120+
}
121+
122+
} // namespace runtime
123+
} // namespace executorch

runtime/backend/targets.bzl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,25 @@ def define_common_targets():
2828
],
2929
)
3030

31+
runtime.cxx_library(
32+
name = "backend_update" + aten_suffix,
33+
exported_headers = [
34+
"backend_update.h",
35+
],
36+
preprocessor_flags = ["-DUSE_ATEN_LIB"] if aten_mode else [],
37+
visibility = [
38+
"//executorch/...",
39+
"@EXECUTORCH_CLIENTS",
40+
],
41+
exported_deps = [
42+
"//executorch/runtime/core:core",
43+
"//executorch/runtime/core:evalue" + aten_suffix,
44+
"//executorch/runtime/core:event_tracer" + aten_suffix,
45+
":backend_options_map" + aten_suffix,
46+
":interface" + aten_suffix,
47+
],
48+
)
49+
3150
runtime.cxx_library(
3251
name = "interface" + aten_suffix,
3352
srcs = [
Lines changed: 188 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,188 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/runtime/backend/backend_options.h>
10+
#include <executorch/runtime/backend/backend_options_map.h>
11+
#include <executorch/runtime/backend/backend_update.h>
12+
#include <executorch/runtime/backend/interface.h>
13+
#include <executorch/runtime/core/array_ref.h>
14+
#include <executorch/runtime/platform/runtime.h>
15+
#include <gtest/gtest.h>
16+
17+
using namespace ::testing;
18+
using executorch::runtime::ArrayRef;
19+
using executorch::runtime::Backend;
20+
using executorch::runtime::BackendExecutionContext;
21+
using executorch::runtime::BackendInitContext;
22+
using executorch::runtime::BackendInterface;
23+
using executorch::runtime::BackendOptions;
24+
using executorch::runtime::BackendOptionsMap;
25+
using executorch::runtime::BackendUpdateContext;
26+
using executorch::runtime::BoolKey;
27+
using executorch::runtime::CompileSpec;
28+
using executorch::runtime::DelegateHandle;
29+
using executorch::runtime::Error;
30+
using executorch::runtime::EValue;
31+
using executorch::runtime::FreeableBuffer;
32+
using executorch::runtime::IntKey;
33+
using executorch::runtime::OptionKey;
34+
using executorch::runtime::register_backend;
35+
using executorch::runtime::Result;
36+
using executorch::runtime::Span;
37+
using executorch::runtime::StrKey;
38+
39+
// Mock backend for testing
40+
class StubBackend : public BackendInterface {
41+
public:
42+
~StubBackend() override = default;
43+
44+
bool is_available() const override {
45+
return true;
46+
}
47+
48+
Result<DelegateHandle*> init(
49+
BackendInitContext& context,
50+
FreeableBuffer* processed,
51+
ArrayRef<CompileSpec> compile_specs) const override {
52+
return nullptr;
53+
}
54+
55+
Error execute(
56+
BackendExecutionContext& context,
57+
DelegateHandle* handle,
58+
EValue** args) const override {
59+
return Error::Ok;
60+
}
61+
62+
Error get_option(
63+
BackendUpdateContext& context,
64+
executorch::runtime::Span<executorch::runtime::BackendOption>&
65+
backend_options) override {
66+
// For testing purposes, just record that get_option was called
67+
// and verify the input parameters
68+
get_option_called = true;
69+
get_option_call_count++;
70+
last_get_option_size = backend_options.size();
71+
72+
// Verify that the expected option key is present and modify the value
73+
for (size_t i = 0; i < backend_options.size(); ++i) {
74+
if (strcmp(backend_options[i].key, "NumberOfThreads") == 0) {
75+
// Set the value to what was stored by set_option
76+
backend_options[i].value = last_num_threads;
77+
found_expected_key = true;
78+
break;
79+
}
80+
}
81+
82+
return Error::Ok;
83+
}
84+
85+
Error set_option(
86+
BackendUpdateContext& context,
87+
const Span<executorch::runtime::BackendOption>& backend_options)
88+
override {
89+
// Store the options for verification
90+
last_options_size = backend_options.size();
91+
if (backend_options.size() > 0) {
92+
for (const auto& option : backend_options) {
93+
if (strcmp(option.key, "NumberOfThreads") == 0) {
94+
if (auto* val = std::get_if<int>(&option.value)) {
95+
last_num_threads = *val;
96+
}
97+
}
98+
}
99+
}
100+
return Error::Ok;
101+
}
102+
103+
// Mutable for testing verification
104+
size_t last_options_size = 0;
105+
int last_num_threads = 0;
106+
bool get_option_called = false;
107+
int get_option_call_count = 0;
108+
size_t last_get_option_size = 0;
109+
bool found_expected_key = false;
110+
};
111+
112+
class BackendUpdateTest : public ::testing::Test {
113+
protected:
114+
void SetUp() override {
115+
// Since these tests cause ET_LOG to be called, the PAL must be initialized
116+
// first.
117+
executorch::runtime::runtime_init();
118+
119+
// Register the stub backend
120+
stub_backend = std::make_unique<StubBackend>();
121+
Backend backend_config{"StubBackend", stub_backend.get()};
122+
auto register_result = register_backend(backend_config);
123+
ASSERT_EQ(register_result, Error::Ok);
124+
}
125+
126+
std::unique_ptr<StubBackend> stub_backend;
127+
};
128+
129+
// Test basic string functionality
130+
TEST_F(BackendUpdateTest, TestSetOption) {
131+
BackendOptionsMap<3> map;
132+
BackendOptions<1> backend_options;
133+
int new_num_threads = 4;
134+
backend_options.set_option(IntKey("NumberOfThreads"), new_num_threads);
135+
map.add("StubBackend", backend_options.view());
136+
137+
auto status = set_option(map.entries());
138+
ASSERT_EQ(status, Error::Ok);
139+
140+
// Verify the map contains the expected data
141+
ASSERT_EQ(map.size(), 1);
142+
auto options = map.get("StubBackend");
143+
ASSERT_EQ(options.size(), 1);
144+
145+
// Verify that the backend actually received the options
146+
ASSERT_EQ(stub_backend->last_options_size, 1);
147+
ASSERT_EQ(stub_backend->last_num_threads, new_num_threads);
148+
}
149+
150+
// Test get_option functionality
151+
TEST_F(BackendUpdateTest, TestGetOption) {
152+
// First, set some options in the backend
153+
BackendOptionsMap<3> set_map;
154+
BackendOptions<1> set_backend_options;
155+
int expected_num_threads = 8;
156+
set_backend_options.set_option(
157+
IntKey("NumberOfThreads"), expected_num_threads);
158+
set_map.add("StubBackend", set_backend_options.view());
159+
160+
auto set_status = set_option(set_map.entries());
161+
ASSERT_EQ(set_status, Error::Ok);
162+
ASSERT_EQ(stub_backend->last_num_threads, expected_num_threads);
163+
164+
// Reset get_option tracking variables
165+
stub_backend->get_option_called = false;
166+
stub_backend->get_option_call_count = 0;
167+
stub_backend->found_expected_key = false;
168+
169+
// Now create a map with options for get_option to process
170+
BackendOptionsMap<3> get_map;
171+
BackendOptions<1> get_backend_options;
172+
get_backend_options.set_option(IntKey("NumberOfThreads"), 0);
173+
get_map.add("StubBackend", get_backend_options.view());
174+
175+
// Call get_option to test the API
176+
auto get_status = get_option(get_map.entries());
177+
ASSERT_EQ(get_status, Error::Ok);
178+
179+
ASSERT_TRUE(
180+
std::get<int>(get_map.entries()[0].options[0].value) ==
181+
expected_num_threads);
182+
183+
// // Verify that the backend's get_option method was called correctly
184+
// ASSERT_TRUE(stub_backend->get_option_called);
185+
// ASSERT_EQ(stub_backend->get_option_call_count, 1);
186+
// ASSERT_EQ(stub_backend->last_get_option_size, 1);
187+
// ASSERT_TRUE(stub_backend->found_expected_key);
188+
}

runtime/backend/test/targets.bzl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,16 @@ def define_common_targets():
2424
],
2525
)
2626

27+
runtime.cxx_test(
28+
name = "backend_update_test",
29+
srcs = ["backend_update_test.cpp"],
30+
deps = [
31+
"//executorch/runtime/core:core",
32+
"//executorch/runtime/backend:backend_options_map",
33+
"//executorch/runtime/backend:backend_update",
34+
],
35+
)
36+
2737
runtime.cxx_test(
2838
name = "backend_interface_update_test",
2939
srcs = ["backend_interface_update_test.cpp"],

0 commit comments

Comments
 (0)