Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions onnxruntime/core/providers/webgpu/webgpu_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -995,6 +995,11 @@ WebGpuContext& WebGpuContextFactory::GetContext(int context_id) {
return *it->second.context;
}

bool WebGpuContextFactory::HasContext(int context_id) {
std::lock_guard<std::mutex> lock(mutex_);
return contexts_.find(context_id) != contexts_.end();
}

void WebGpuContextFactory::ReleaseContext(int context_id) {
std::lock_guard<std::mutex> lock(mutex_);

Expand Down
1 change: 1 addition & 0 deletions onnxruntime/core/providers/webgpu/webgpu_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ class WebGpuContextFactory {

static WebGpuContext& CreateContext(const WebGpuContextConfig& config);
static WebGpuContext& GetContext(int context_id);
static bool HasContext(int context_id);

static void ReleaseContext(int context_id);

Expand Down
103 changes: 103 additions & 0 deletions onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,48 @@
#include "core/session/ort_apis.h"

#include "core/providers/webgpu/webgpu_provider_options.h"
#include "core/providers/webgpu/data_transfer.h"
using namespace onnxruntime::webgpu::options;

namespace onnxruntime {
// Helper to get default context config, buffer cache config, backend type, and enable_pix_capture
struct WebGpuContextParams {
webgpu::WebGpuContextConfig context_config;
webgpu::WebGpuBufferCacheConfig buffer_cache_config;
int backend_type;
bool enable_pix_capture;
};

static WebGpuContextParams GetDefaultWebGpuContextParams() {
WebGpuContextParams params;
params.context_config.context_id = 0;
params.context_config.instance = nullptr;
params.context_config.device = nullptr;
params.context_config.dawn_proc_table = nullptr;
params.context_config.validation_mode = webgpu::ValidationMode::Disabled;
params.context_config.preserve_device = false;
params.context_config.max_storage_buffer_binding_size = 0;
params.context_config.power_preference = static_cast<int>(WGPUPowerPreference_HighPerformance);

params.buffer_cache_config.storage.mode = webgpu::BufferCacheMode::Bucket;
params.buffer_cache_config.uniform.mode = webgpu::BufferCacheMode::Simple;
params.buffer_cache_config.query_resolve.mode = webgpu::BufferCacheMode::Disabled;
params.buffer_cache_config.default_entry.mode = webgpu::BufferCacheMode::Disabled;

#ifdef _WIN32
#if defined(DAWN_ENABLE_D3D12)
params.backend_type = static_cast<int>(WGPUBackendType_D3D12);
#elif defined(DAWN_ENABLE_VULKAN)
params.backend_type = static_cast<int>(WGPUBackendType_Vulkan);
#else
params.backend_type = static_cast<int>(WGPUBackendType_D3D12);
#endif
#else
params.backend_type = 0;
#endif
params.enable_pix_capture = false;
return params;
}

struct WebGpuProviderFactory : IExecutionProviderFactory {
WebGpuProviderFactory(int context_id, webgpu::WebGpuContext& context, WebGpuExecutionProviderConfig&& webgpu_ep_config)
Expand Down Expand Up @@ -291,4 +330,68 @@ std::shared_ptr<IExecutionProviderFactory> WebGpuProviderFactoryCreator::Create(
return std::make_shared<WebGpuProviderFactory>(context_id, context, std::move(webgpu_ep_config));
}

// WebGPU DataTransfer implementation wrapper for the C API
struct WebGpuDataTransferImpl : OrtDataTransferImpl {
WebGpuDataTransferImpl(const OrtApi& ort_api_in, webgpu::BufferManager& buffer_manager)
: ort_api{ort_api_in},
ep_api{*ort_api_in.GetEpApi()},
data_transfer_{buffer_manager} {
ort_version_supported = ORT_API_VERSION;
CanCopy = CanCopyImpl;
CopyTensors = CopyTensorsImpl;
Release = ReleaseImpl;
}

static bool CanCopyImpl(const OrtDataTransferImpl* this_ptr,
const OrtMemoryDevice* src_memory_device,
const OrtMemoryDevice* dst_memory_device) noexcept {
const auto& impl = *static_cast<const WebGpuDataTransferImpl*>(this_ptr);
OrtMemoryInfoDeviceType src_type = impl.ep_api.MemoryDevice_GetDeviceType(src_memory_device);
OrtMemoryInfoDeviceType dst_type = impl.ep_api.MemoryDevice_GetDeviceType(dst_memory_device);

// WebGPU supports GPU<->GPU, GPU<->CPU copies
return (src_type == OrtMemoryInfoDeviceType_GPU && dst_type == OrtMemoryInfoDeviceType_GPU) ||
(src_type == OrtMemoryInfoDeviceType_GPU && dst_type == OrtMemoryInfoDeviceType_CPU) ||
(src_type == OrtMemoryInfoDeviceType_CPU && dst_type == OrtMemoryInfoDeviceType_GPU);
}

static OrtStatus* CopyTensorsImpl(OrtDataTransferImpl* this_ptr,
const OrtValue** src_tensors,
OrtValue** dst_tensors,
OrtSyncStream** /*streams*/,
size_t num_tensors) noexcept {
auto& impl = *static_cast<WebGpuDataTransferImpl*>(this_ptr);
for (size_t idx = 0; idx < num_tensors; ++idx) {
const OrtValue* src_tensor = src_tensors[idx];
OrtValue* dst_tensor = dst_tensors[idx];
auto status = impl.data_transfer_.CopyTensor(src_tensor->Get<Tensor>(), *dst_tensor->GetMutable<Tensor>());
if (!status.IsOK()) {
// Convert common::Status to OrtStatus
return OrtApis::CreateStatus(ORT_RUNTIME_EXCEPTION, status.ErrorMessage().c_str());
}
}
return nullptr;
}

static void ReleaseImpl(OrtDataTransferImpl* this_ptr) noexcept {
delete static_cast<WebGpuDataTransferImpl*>(this_ptr);
}

const OrtApi& ort_api;
const OrtEpApi& ep_api;
webgpu::DataTransfer data_transfer_;
};

OrtDataTransferImpl* OrtWebGpuCreateDataTransfer(int context_id) {
webgpu::WebGpuContext* context_ptr = nullptr;
if (webgpu::WebGpuContextFactory::HasContext(context_id)) {
context_ptr = &webgpu::WebGpuContextFactory::GetContext(context_id);
} else {
WebGpuContextParams params = GetDefaultWebGpuContextParams();
context_ptr = &webgpu::WebGpuContextFactory::CreateContext(params.context_config);
context_ptr->Initialize(params.buffer_cache_config, params.backend_type, params.enable_pix_capture);
}
return new WebGpuDataTransferImpl(*OrtApis::GetApi(ORT_API_VERSION), context_ptr->BufferManager());
}

} // namespace onnxruntime
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,18 @@

#include "core/providers/webgpu/webgpu_provider_options.h"

struct OrtDataTransferImpl;

namespace onnxruntime {
struct ConfigOptions;

struct WebGpuProviderFactoryCreator {
static std::shared_ptr<IExecutionProviderFactory> Create(const ConfigOptions& config_options);
};

// C API to create data transfer for WebGPU EP
// If the context doesn't exist, creates a default one (context_id=0)
// Caller takes ownership of the returned OrtDataTransferImpl*
OrtDataTransferImpl* OrtWebGpuCreateDataTransfer(int context_id);

} // namespace onnxruntime
17 changes: 5 additions & 12 deletions onnxruntime/core/session/plugin_ep/ep_factory_webgpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,20 +57,13 @@ OrtStatus* WebGpuEpFactory::CreateIExecutionProvider(const OrtHardwareDevice* co
return nullptr;
}

/* TODO: Implement CreateAllocator and CreateDataTransfer to support shared allocators and data transfer outside of
an InferenceSession.
OrtStatus* WebGpuEpFactory::CreateAllocator(const OrtMemoryInfo* memory_info,
const OrtKeyValuePairs* allocator_options,
OrtAllocator** allocator) noexcept override {
*allocator = device_allocators[memory_info->device.Id()].get();
}

OrtStatus* WebGpuEpFactory::CreateDataTransfer(_Outptr_result_maybenull_ OrtDataTransferImpl** data_transfer) override {
// TODO: Wrap the IDataTransfer implementation so we can copy to device using OrtApi CopyTensors.
*data_transfer = nullptr;
OrtStatus* WebGpuEpFactory::CreateDataTransfer(_Outptr_result_maybenull_ OrtDataTransferImpl** data_transfer) noexcept {
// Call the WebGPU provider's C API to create the data transfer
// This is implemented in the WebGPU provider backend which has access to WebGPU headers
*data_transfer = OrtWebGpuCreateDataTransfer(0); // Use default context (context_id=0)
return nullptr;
}
*/

} // namespace onnxruntime

#endif // USE_WEBGPU
2 changes: 2 additions & 0 deletions onnxruntime/core/session/plugin_ep/ep_factory_webgpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ class WebGpuEpFactory : public EpFactoryInternalImpl {
const OrtSessionOptions* session_options,
const OrtLogger* session_logger,
std::unique_ptr<IExecutionProvider>* ep) noexcept override;

OrtStatus* CreateDataTransfer(_Outptr_result_maybenull_ OrtDataTransferImpl** data_transfer) noexcept override;
};
} // namespace onnxruntime

Expand Down
Loading