167 changes: 64 additions & 103 deletions onnxruntime/core/providers/qnn/qnn_execution_provider.cc
@@ -642,20 +642,18 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio
}

QNNExecutionProvider::~QNNExecutionProvider() {
// clean up thread local context caches
std::lock_guard<std::mutex> lock(context_state_.mutex);
for (const auto& cache_weak : context_state_.caches_to_update_on_destruction) {
const auto cache = cache_weak.lock();
if (!cache) continue;
ORT_IGNORE_RETURN_VALUE(cache->erase(this));
}

// Unregister the ETW callback
#if defined(_WIN32)
if (callback_ETWSink_provider_ != nullptr) {
logging::EtwRegistrationManager::Instance().UnregisterInternalCallback(callback_ETWSink_provider_);
}
#endif
{
std::lock_guard<std::mutex> lock(htp_power_config_id_mutex_);
if (managed_htp_power_config_id_) {
managed_htp_power_config_id_.reset();
}
}
}

// Logs information about the supported/unsupported nodes.
@@ -955,7 +953,7 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer
if (IsNpuBackend(qnn_backend_manager_->GetQnnBackendType())) {
// Set the power config id and the default power mode from provider option for main thread,
// otherwise it will mess up the power mode if user just create session without run it.
GetPerThreadContext();
CreateHtpPowerConfigId();
}

// Report error if QNN CPU backend is loaded while CPU fallback is disabled
@@ -1355,93 +1353,6 @@ const InlinedVector<const Node*> QNNExecutionProvider::GetEpContextNodes() const
return ep_context_nodes;
}

QNNExecutionProvider::PerThreadContext::PerThreadContext(qnn::QnnBackendManager* qnn_backend_manager,
uint32_t device_id,
uint32_t core_id,
qnn::HtpPerformanceMode default_htp_performance_mode,
uint32_t default_rpc_control_latency,
uint32_t default_rpc_polling_time)
: qnn_backend_manager_(qnn_backend_manager) {
Status rt = qnn_backend_manager_->CreateHtpPowerCfgId(device_id, core_id, htp_power_config_id_);
is_htp_power_config_id_valid_ = rt.IsOK();
// default_htp_performance_mode and default_rpc_control_latency are from QNN EP option.
// set it once only for each thread as default so user don't need to set it for every session run
if (is_htp_power_config_id_valid_) {
if (qnn::HtpPerformanceMode::kHtpDefault != default_htp_performance_mode) {
ORT_IGNORE_RETURN_VALUE(qnn_backend_manager_->SetHtpPowerConfig(htp_power_config_id_,
default_htp_performance_mode));
}
if (default_rpc_control_latency > 0 || default_rpc_polling_time > 0) {
ORT_IGNORE_RETURN_VALUE(qnn_backend_manager_->SetRpcPowerConfigs(htp_power_config_id_,
default_rpc_control_latency,
default_rpc_polling_time));
}
}
}

QNNExecutionProvider::PerThreadContext::~PerThreadContext() {
if (is_htp_power_config_id_valid_) {
ORT_IGNORE_RETURN_VALUE(qnn_backend_manager_->DestroyHTPPowerConfigID(htp_power_config_id_));
}
}

QNNExecutionProvider::PerThreadContext& QNNExecutionProvider::GetPerThreadContext() const {
const auto& per_thread_context_cache = PerThreadContextCache();

// try to use cached context
auto cached_context_it = per_thread_context_cache->find(this);
if (cached_context_it != per_thread_context_cache->end()) {
auto cached_context = cached_context_it->second.lock();
ORT_ENFORCE(cached_context);
return *cached_context;
}

// get context and update cache
std::shared_ptr<PerThreadContext> context;
{
std::lock_guard<std::mutex> lock(context_state_.mutex);

// get or create a context
if (context_state_.retired_context_pool.empty()) {
uint32_t core_id = 0;
context = std::make_shared<PerThreadContext>(qnn_backend_manager_.get(), device_id_, core_id,
default_htp_performance_mode_, default_rpc_control_latency_,
default_rpc_polling_time_);
} else {
context = context_state_.retired_context_pool.back();
context_state_.retired_context_pool.pop_back();
}

// insert into active_contexts, should not already be present
const auto active_contexts_insert_result = context_state_.active_contexts.insert(context);
ORT_ENFORCE(active_contexts_insert_result.second);

// insert into caches_to_update_on_destruction, may already be present
ORT_IGNORE_RETURN_VALUE(context_state_.caches_to_update_on_destruction.insert(per_thread_context_cache));
}

per_thread_context_cache->insert(std::make_pair(this, context));

return *context;
}

void QNNExecutionProvider::ReleasePerThreadContext() const {
const auto& per_thread_context_cache = PerThreadContextCache();

auto cached_context_it = per_thread_context_cache->find(this);
ORT_ENFORCE(cached_context_it != per_thread_context_cache->end());
auto cached_context = cached_context_it->second.lock();
ORT_ENFORCE(cached_context);

{
std::lock_guard<std::mutex> lock(context_state_.mutex);
context_state_.active_contexts.erase(cached_context);
context_state_.retired_context_pool.push_back(cached_context);
}

per_thread_context_cache->erase(cached_context_it);
}

static bool TryGetConfigEntry(const ConfigOptions& config_options, const std::string& key, std::string& value) {
std::optional<std::string> new_value = config_options.GetConfigEntry(key);
if (!new_value.has_value()) {
@@ -1479,14 +1390,14 @@ Status QNNExecutionProvider::OnRunStart(const onnxruntime::RunOptions& run_optio
rpc_polling_time = 9999;
}

if (GetPerThreadContext().IsHtpPowerConfigIdValid()) {
if (IsHtpPowerConfigIdValid()) {
if (qnn::HtpPerformanceMode::kHtpDefault != htp_performance_mode) {
ORT_RETURN_IF_ERROR(qnn_backend_manager_->SetHtpPowerConfig(GetPerThreadContext().GetHtpPowerConfigId(),
ORT_RETURN_IF_ERROR(qnn_backend_manager_->SetHtpPowerConfig(GetHtpPowerConfigId(),
Contributor @adrianlizarraga commented on Nov 14, 2025:

When a session is used across multiple threads, is it possible that they can interfere with each other here (and in other places where the EP calls qnn_backend_manager_->Set*PowerConfig)?

The mutex is only locked during the call to GetHtpPowerConfigId(), but there is no synchronization here.

htp_performance_mode));
}

if (rpc_control_latency > 0 || rpc_polling_time > 0) {
ORT_RETURN_IF_ERROR(qnn_backend_manager_->SetRpcPowerConfigs(GetPerThreadContext().GetHtpPowerConfigId(),
ORT_RETURN_IF_ERROR(qnn_backend_manager_->SetRpcPowerConfigs(GetHtpPowerConfigId(),
rpc_control_latency,
rpc_polling_time));
}
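
To illustrate the synchronization concern raised in the review comment above: a minimal sketch, assuming a hypothetical helper SetHtpPowerConfigLocked (not part of this PR) that holds htp_power_config_id_mutex_ for the entire backend call rather than only while reading the id:

// Hypothetical sketch, not part of this PR: serialize power-config updates by
// holding htp_power_config_id_mutex_ across the whole SetHtpPowerConfig call,
// so concurrent OnRunStart()/SetEpDynamicOptions() calls cannot interleave.
Status QNNExecutionProvider::SetHtpPowerConfigLocked(qnn::HtpPerformanceMode htp_performance_mode) {
  std::lock_guard<std::mutex> lock(htp_power_config_id_mutex_);
  if (!managed_htp_power_config_id_) {
    return Status::OK();  // power config id was never created; nothing to update
  }
  return qnn_backend_manager_->SetHtpPowerConfig(managed_htp_power_config_id_->GetHtpPowerConfigId(),
                                                 htp_performance_mode);
}

A call site in OnRunStart could then invoke this single helper in place of the separate IsHtpPowerConfigIdValid()/GetHtpPowerConfigId() calls shown above.
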
@@ -1517,10 +1428,10 @@ Status QNNExecutionProvider::OnRunEnd(bool /*sync_stream*/, const onnxruntime::R
}

if (qnn::HtpPerformanceMode::kHtpDefault != htp_performance_mode) {
if (!GetPerThreadContext().IsHtpPowerConfigIdValid()) {
if (!IsHtpPowerConfigIdValid()) {
return Status::OK();
}
ORT_RETURN_IF_ERROR(qnn_backend_manager_->SetHtpPowerConfig(GetPerThreadContext().GetHtpPowerConfigId(),
ORT_RETURN_IF_ERROR(qnn_backend_manager_->SetHtpPowerConfig(GetHtpPowerConfigId(),
htp_performance_mode));
}

@@ -1586,8 +1497,8 @@ Status QNNExecutionProvider::SetEpDynamicOptions(gsl::span<const char* const> ke
}
qnn::HtpPerformanceMode htp_performance_mode = qnn::HtpPerformanceMode::kHtpDefault;
ParseHtpPerformanceMode(value, htp_performance_mode);
if (GetPerThreadContext().IsHtpPowerConfigIdValid()) {
ORT_RETURN_IF_ERROR(qnn_backend_manager_->SetHtpPowerConfig(GetPerThreadContext().GetHtpPowerConfigId(),
if (IsHtpPowerConfigIdValid()) {
ORT_RETURN_IF_ERROR(qnn_backend_manager_->SetHtpPowerConfig(GetHtpPowerConfigId(),
htp_performance_mode));
}
} else {
@@ -1602,4 +1513,54 @@ Status QNNExecutionProvider::SetEpDynamicOptions(gsl::span<const char* const> ke
return Status::OK();
}

QNNExecutionProvider::ManagedHtpPowerConfigId::ManagedHtpPowerConfigId(uint32_t htp_power_config_id,
std::shared_ptr<qnn::QnnBackendManager> qnn_backend_manager)
: htp_power_config_id_(htp_power_config_id),
qnn_backend_manager_(qnn_backend_manager) {
}

QNNExecutionProvider::ManagedHtpPowerConfigId::~ManagedHtpPowerConfigId() {
ORT_IGNORE_RETURN_VALUE(qnn_backend_manager_->DestroyHTPPowerConfigID(htp_power_config_id_));
}

uint32_t QNNExecutionProvider::ManagedHtpPowerConfigId::GetHtpPowerConfigId() {
return htp_power_config_id_;
}

void QNNExecutionProvider::CreateHtpPowerConfigId() const {
std::lock_guard<std::mutex> lock(htp_power_config_id_mutex_);
if (managed_htp_power_config_id_) {
return;
}

constexpr uint32_t core_id = 0;
uint32_t htp_power_config_id;

Status rt = qnn_backend_manager_->CreateHtpPowerCfgId(device_id_, core_id, htp_power_config_id);

if (rt == Status::OK()) {
managed_htp_power_config_id_ = std::make_shared<ManagedHtpPowerConfigId>(htp_power_config_id, qnn_backend_manager_);

if (qnn::HtpPerformanceMode::kHtpDefault != default_htp_performance_mode_) {
ORT_IGNORE_RETURN_VALUE(qnn_backend_manager_->SetHtpPowerConfig(htp_power_config_id,
default_htp_performance_mode_));
}
if (default_rpc_control_latency_ > 0 || default_rpc_polling_time_ > 0) {
ORT_IGNORE_RETURN_VALUE(qnn_backend_manager_->SetRpcPowerConfigs(htp_power_config_id,
default_rpc_control_latency_,
default_rpc_polling_time_));
}
}
}

bool QNNExecutionProvider::IsHtpPowerConfigIdValid() {
std::lock_guard<std::mutex> lock(htp_power_config_id_mutex_);
return managed_htp_power_config_id_ != nullptr;
}

uint32_t QNNExecutionProvider::GetHtpPowerConfigId() {
std::lock_guard<std::mutex> lock(htp_power_config_id_mutex_);
return managed_htp_power_config_id_->GetHtpPowerConfigId();
}

} // namespace onnxruntime
82 changes: 22 additions & 60 deletions onnxruntime/core/providers/qnn/qnn_execution_provider.h
@@ -82,6 +82,28 @@ class QNNExecutionProvider : public IExecutionProvider {
bool IsHtpSharedMemoryAllocatorAvailable() const { return rpcmem_library_ != nullptr; }

private:
class ManagedHtpPowerConfigId {
public:
ManagedHtpPowerConfigId(uint32_t htp_power_config_id, std::shared_ptr<qnn::QnnBackendManager> qnn_backend_manager);

~ManagedHtpPowerConfigId();

uint32_t GetHtpPowerConfigId();

private:
uint32_t htp_power_config_id_;
std::shared_ptr<qnn::QnnBackendManager> qnn_backend_manager_;
};

void CreateHtpPowerConfigId() const;

bool IsHtpPowerConfigIdValid();

uint32_t GetHtpPowerConfigId();

mutable std::shared_ptr<ManagedHtpPowerConfigId> managed_htp_power_config_id_ = nullptr;
mutable std::mutex htp_power_config_id_mutex_;

qnn::HtpGraphFinalizationOptimizationMode htp_graph_finalization_opt_mode_ = qnn::HtpGraphFinalizationOptimizationMode::kDefault;
// Note: Using shared_ptr<QnnBackendManager> so that we can refer to it with a weak_ptr from a
// HtpSharedMemoryAllocator allocation cleanup callback.
@@ -114,66 +136,6 @@ class QNNExecutionProvider : public IExecutionProvider {
// Whether this is set depends on a session option enabling it and if the RPCMEM dynamic library is available.
// This is potentially shared with HtpSharedMemoryAllocator which may be returned by CreatePreferredAllocators().
std::shared_ptr<qnn::RpcMemLibrary> rpcmem_library_ = nullptr;

class PerThreadContext final {
public:
PerThreadContext(qnn::QnnBackendManager* qnn_backend_manager,
uint32_t device_id, uint32_t core_id,
qnn::HtpPerformanceMode default_htp_performance_mode,
uint32_t default_rpc_control_latency,
uint32_t default_rpc_polling_time);
~PerThreadContext();
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(PerThreadContext);

bool IsHtpPowerConfigIdValid() { return is_htp_power_config_id_valid_; }

uint32_t GetHtpPowerConfigId() { return htp_power_config_id_; }

private:
bool is_htp_power_config_id_valid_ = false;
uint32_t htp_power_config_id_ = 0;
qnn::QnnBackendManager* qnn_backend_manager_;
};

using PerThreadContextMap = std::unordered_map<const QNNExecutionProvider*, std::weak_ptr<PerThreadContext>>;

struct ContextCacheHolder {
ContextCacheHolder() {
RunOnUnload([&, weak_p_ = std::weak_ptr<PerThreadContextMap>(p)] {
if (auto lock = weak_p_.lock())
p.reset();
});
}

std::shared_ptr<PerThreadContextMap> p = std::make_shared<PerThreadContextMap>();
};

static const std::shared_ptr<PerThreadContextMap>& PerThreadContextCache() {
thread_local const ContextCacheHolder per_thread_context_cache;
return per_thread_context_cache.p;
}

struct PerThreadContextState {
// contexts that are currently active
std::set<std::shared_ptr<PerThreadContext>, std::owner_less<std::shared_ptr<PerThreadContext>>> active_contexts;
// contexts available for reuse
std::vector<std::shared_ptr<PerThreadContext>> retired_context_pool;
// weak references to thread local caches from which this QNNExecutionProvider instance's entry should be removed
// upon destruction
std::set<std::weak_ptr<PerThreadContextMap>, std::owner_less<std::weak_ptr<PerThreadContextMap>>>
caches_to_update_on_destruction;
// synchronizes access to PerThreadContextState members
std::mutex mutex;
};

// The execution provider maintains the PerThreadContexts in this structure.
// Synchronization is required to update the contained structures.
// On the other hand, access to an individual PerThreadContext is assumed to be from a single thread at a time,
// so synchronization is not required for that.
mutable PerThreadContextState context_state_;

PerThreadContext& GetPerThreadContext() const;
void ReleasePerThreadContext() const;
};

} // namespace onnxruntime