Skip to content

Commit 04434ea

Browse files
committed
register DataTransfer to Env
1 parent f7fd3b5 commit 04434ea

File tree

1 file changed

+16
-0
lines changed

1 file changed

+16
-0
lines changed

onnxruntime/core/session/inference_session.cc

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -854,10 +854,26 @@ common::Status InferenceSession::RegisterExecutionProvider(const std::shared_ptr
854854
VLOGS(*session_logger_, 1) << "Adding execution provider of type: " << provider_type;
855855
auto p_data_xfr = p_exec_provider->GetDataTransfer();
856856
if (p_data_xfr) {
857+
// Register with session's data transfer manager
857858
auto st = data_transfer_mgr_.RegisterDataTransfer(std::move(p_data_xfr));
858859
if (!st.IsOK()) {
859860
return st;
860861
}
862+
863+
// For WebGPU EP, also register with environment's data transfer manager
864+
// so that CopyTensors C API can work (it only checks environment's DTM)
865+
if (provider_type == kWebGpuExecutionProvider) {
866+
auto p_data_xfr_env = p_exec_provider->GetDataTransfer();
867+
if (p_data_xfr_env) {
868+
auto& env_data_transfer_mgr = environment_.GetDataTransferManager();
869+
auto st_env = const_cast<DataTransferManager&>(env_data_transfer_mgr).RegisterDataTransfer(std::move(p_data_xfr_env));
870+
if (!st_env.IsOK()) {
871+
LOGS(*session_logger_, WARNING) << "Failed to register WebGPU data transfer with environment: " << st_env.ErrorMessage();
872+
} else {
873+
VLOGS(*session_logger_, 1) << "Registered WebGPU data transfer with environment for CopyTensors API support";
874+
}
875+
}
876+
}
861877
}
862878

863879
auto p_external_data_loader = p_exec_provider->GetExternalDataLoader();

0 commit comments

Comments
 (0)