Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions include/onnxruntime/core/session/environment.h
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,12 @@
const DataTransferManager& GetDataTransferManager() const {
return data_transfer_mgr_;
}

// Register a data transfer for an execution provider with the environment's data transfer manager
// This is needed for EPs like WebGPU where CopyTensors C API needs access to the data transfer
Status RegisterDataTransferForEP(std::unique_ptr<IDataTransfer> data_transfer) {
return data_transfer_mgr_.RegisterDataTransfer(std::move(data_transfer));

Check warning on line 161 in include/onnxruntime/core/session/environment.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <utility> for move [build/include_what_you_use] [4] Raw Output: include/onnxruntime/core/session/environment.h:161: Add #include <utility> for move [build/include_what_you_use] [4]
}
#endif // !defined(ORT_MINIMAL_BUILD)

// return a shared allocator from a plugin EP or custom allocator added with RegisterAllocator
Expand Down
15 changes: 15 additions & 0 deletions onnxruntime/core/session/inference_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -854,10 +854,25 @@ common::Status InferenceSession::RegisterExecutionProvider(const std::shared_ptr
VLOGS(*session_logger_, 1) << "Adding execution provider of type: " << provider_type;
auto p_data_xfr = p_exec_provider->GetDataTransfer();
if (p_data_xfr) {
// Register with session's data transfer manager
auto st = data_transfer_mgr_.RegisterDataTransfer(std::move(p_data_xfr));
if (!st.IsOK()) {
return st;
}

#if !defined(ORT_MINIMAL_BUILD)
// For WebGPU EP, also register with environment's data transfer manager
// so that CopyTensors C API can work (it only checks environment's DTM)
if (provider_type == kWebGpuExecutionProvider) {
auto p_data_xfr_env = p_exec_provider->GetDataTransfer();
if (p_data_xfr_env) {
auto st_env = const_cast<Environment&>(environment_).RegisterDataTransferForEP(std::move(p_data_xfr_env));
if (!st_env.IsOK()) {
LOGS(*session_logger_, WARNING) << "Failed to register WebGPU data transfer with environment: " << st_env.ErrorMessage();
}
}
}
#endif
}

auto p_external_data_loader = p_exec_provider->GetExternalDataLoader();
Expand Down
Loading