From 46e45590e1a93498c3fdda9d2c57ff25d389fa55 Mon Sep 17 00:00:00 2001
From: Jiawei Shao
Date: Thu, 30 Oct 2025 20:17:43 +0800
Subject: [PATCH 01/22] Implement Split-K on Conv|MatMul

This patch implements the `Split-K` optimization on `Conv|MatMul`. With
`Split-K`, we can rearrange the computation across multiple workgroups when
`K` is large to increase parallelism.

---
 .../core/providers/webgpu/math/gemm_utils.cc | 125 ++++++++++++++++--
 .../core/providers/webgpu/math/gemm_utils.h  |   4 +-
 .../core/providers/webgpu/math/matmul.cc     | 118 ++++++++++++++++-
 .../core/providers/webgpu/math/matmul.h      |  29 ++++
 .../providers/webgpu/math/matmul_packed.cc   |  58 +++++++-
 .../providers/webgpu/math/matmul_packed.h    |  42 +++++-
 onnxruntime/core/providers/webgpu/nn/conv.cc |  11 +-
 .../core/providers/webgpu/shader_helper.cc   |  13 +-
 .../core/providers/webgpu/shader_variable.h  |   1 +
 .../test/providers/cpu/nn/conv_fp16_test.cc  | 118 ++++++++++++++++-
 .../test/providers/cpu/nn/conv_op_test.cc    | 103 ++++++++++++++-
 11 files changed, 589 insertions(+), 33 deletions(-)

diff --git a/onnxruntime/core/providers/webgpu/math/gemm_utils.cc b/onnxruntime/core/providers/webgpu/math/gemm_utils.cc
index 7dd3b50c656f4..f0d2fbfd134b0 100644
--- a/onnxruntime/core/providers/webgpu/math/gemm_utils.cc
+++ b/onnxruntime/core/providers/webgpu/math/gemm_utils.cc
@@ -53,6 +53,70 @@ void HandleMaybeBiasForMatMul(ShaderHelper& shader,
          << output.SetByIndices("coords", "value") << "\n";
 }
 
+void HandleMatMulWithSplitK(
+    ShaderHelper& shader,
+    ProgramVariableDataType output_variable_type) {
+  shader.AdditionalImplementation() << " let coords = vec3(u32(batch), u32(row), u32(colIn));\n";
+
+  // With Split-K, the final output will be the sum of the sub-outputs from multiple workgroups,
+  // so we must add them with atomic built-in functions. Because currently WebGPU doesn't support
+  // atomic built-in functions on `f32` or `f16`, we implement the `atomicAdd` on `f32` and `f16`
+  // with `atomicLoad` and `atomicCompareExchangeWeak`:
+  // 1. Get `old_output_u32` from `output[offset]` with `atomicLoad`.
+  // 2. Convert `old_output_u32` into `f32` (`old_output_f32`) or `vec2h` (`old_output_vec2h`).
+  // 3. Add incoming `value` into `old_output_f32` or `old_output_vec2h`.
+  // 4. Convert the result of step 3 into `u32` values.
+  // 5. Try assigning the result of step 4 into `output[offset]` with `atomicCompareExchangeWeak`
+  //    and `old_output_u32`. The assignment will fail if at this time `output[offset]` is not
+  //    equal to `old_output_u32` (it is updated in another invocation). If the assignment fails
+  //    we have to go to step 1 and repeat all the above steps.
+ switch (output_variable_type) { + case ProgramVariableDataType::Float32x4: { + shader.AdditionalImplementation() << R"( + let offset0 = i2o_output(coords) * 4u; + for (var i = 0u; i < 4u; i++) { + let offset = offset0 + i; + while (true) { + let old_output_u32 = atomicLoad(&output[offset]); + let old_output_f32 = bitcast(old_output_u32); + let new_output_f32 = old_output_f32 + value[i]; + let new_output_u32 = bitcast(new_output_f32); + let output_compexchange = atomicCompareExchangeWeak(&output[offset], old_output_u32, new_output_u32); + if (output_compexchange.old_value == old_output_u32) { + break; + } + } + } +)"; + break; + } + case ProgramVariableDataType::Float16x4: { + shader.AdditionalImplementation() << R"( + let offset0 = i2o_output(coords) * 2u; + var vec2h_values : array; + vec2h_values[0] = value.xy; + vec2h_values[1] = value.zw; + for (var i = 0u; i < 2u; i++) { + let offset= offset0 + i; + while(true) { + let old_output_u32 = atomicLoad(&output[offset]); + let old_output_vec2h = bitcast(old_output_u32); + let new_output_vec2h = old_output_vec2h + vec2h_values[i]; + let new_output_u32 = bitcast(new_output_vec2h); + let output_compexchange = atomicCompareExchangeWeak(&output[offset], old_output_u32, new_output_u32); + if (output_compexchange.old_value == old_output_u32) { + break; + } + } + } +)"; + break; + } + default: + break; + } +} + } // namespace void MatMulReadFnSource(ShaderHelper& shader, @@ -125,7 +189,9 @@ void MatMulWriteFnSource(ShaderHelper& shader, int output_components, bool c_is_scalar, std::string activation_snippet, - bool is_channels_last) { + bool is_channels_last, + bool use_split_k, + ProgramVariableDataType output_variable_type) { shader.AdditionalImplementation() << "fn mm_write(batch: i32, row: i32, colIn: i32, valueIn: output_value_t) { \n"; @@ -134,7 +200,14 @@ void MatMulWriteFnSource(ShaderHelper& shader, shader.AdditionalImplementation() << "if(row < i32(uniforms.dim_a_outer) && col < i32(uniforms.dim_b_outer)) { \n" << " var value = valueIn; \n"; - if (is_gemm) { + if (use_split_k) { + // Set output when MatMul is performed with Split-K. + // When split-k is used in MatMul, the bias will be handled in `MatMulFillBiasBeforeSplitKProgram` + // instead of here, so `has_bias` and `is_channels_last` is not used for Split-K. Note that we + // still need to handle `has_bias` and `is_channels_last` in `MatMulFillBiasBeforeSplitKProgram`. + assert(!has_bias); + HandleMatMulWithSplitK(shader, output_variable_type); + } else if (is_gemm) { HanldeMaybeHaveBiasForGEMM(shader, output, has_bias, c_components, output_components, c_is_scalar); } else { HandleMaybeBiasForMatMul(shader, output, has_bias, activation_snippet, is_channels_last); @@ -159,9 +232,6 @@ Status MakeMatMulPackedVec4Source(ShaderHelper& shader, uint32_t tile_inner, bool split_k, uint32_t split_dim_inner) { - ORT_UNUSED_PARAMETER(split_k); - ORT_UNUSED_PARAMETER(split_dim_inner); - const std::string type_string = MakeScalarOrVectorType(4 /*components */, data_type); std::string write_data_to_sub_a_vec4_snippet = @@ -208,14 +278,51 @@ Status MakeMatMulPackedVec4Source(ShaderHelper& shader, << " let tileCol = i32(local_id.x);\n" << " let globalRow = i32(global_id.y) * rowPerThread;\n" << " let globalCol = i32(global_id.x);\n" - << " let batch = i32(global_id.z);\n" - << (nullptr != batch_dims ? 
" let batchIndices = " + batch_dims->OffsetToIndices("u32(batch)") + ";\n" : "") << " let globalRowStart = i32(workgroup_id.y) * " << tile_a_outer << ";\n" << " let globalColStart = i32(workgroup_id.x) * " << tile_b_outer << ";\n" - << " let num_tiles = (uniforms.dim_inner - 1) / tileInner + 1;\n" - << " var kStart = 0;\n" << " var acc: array, rowPerThread>;\n"; + if (split_k) { + // With Split-K, the original "workgroup" (with dispatch_z == 1 in API side) into multiple + // ones, and in the current workgroup we only compute `kSplitK` elements starting from + // `kSplitK * i32(global_id.z)`. + // + // For example: considering computing Y = (X * W + B) in one workgroup. + // Let kSplitk = 2, B = [d1, d2] + // Let X = [[a1 a1 b1 b1 c1 c1] = [ A1 B1 C1 ] W = [[a2 a2] = [ A2 + // [a1 a1 b1 b1 c1 c1]] [a2 a2] B2 + // [b2 b2] C2 ] + // [b2 b2] + // [c2 c2] + // [c2 c2]] + // + // With Split-K: + // 1. Initialize output Y with B in `MatMulFillBiasBeforeSplitKProgram`: Y = [[d1, d2] + // [d1, d2]] + // 2. Split the original 1 workgroup into 3 workgroups (now `dispatch_z = 3` in API side) + // Workgroup1: compute (A1 * A2) Workgroup2: compute (B1 * B2) + // Workgroup3: compute (C1 * C2) + // In each workgroup: + // - `num_tiles` is computed with `kSplitK`, and `kStart` is computed with `global_id.z` + // - When the computation in each workgroup is completed, add the result to Y with several + // atomic built-in functions in `HandleMatMulWithSplitK()`. + shader.MainFunctionBody() + << "const kSplitK = " << split_dim_inner << ";\n" + << " let num_tiles = (kSplitK - 1) / tileInner + 1;\n" + << " var kStart = kSplitK * i32(global_id.z);\n" + + // When Split-K is used, `batch` should always be 0 and `global_id.z` is used to indicate + // the index of split-k instead of batch. + << " let batch = 0;\n" + << " let batchIndices = 0u;\n"; + } else { + shader.MainFunctionBody() + << " let num_tiles = (uniforms.dim_inner - 1) / tileInner + 1;\n" + << " var kStart = 0;\n" + << " let batch = i32(global_id.z);\n" + << (nullptr != batch_dims ? " let batchIndices = " + batch_dims->OffsetToIndices("u32(batch)") + ";\n" : ""); + } + // Loop over shared dimension. shader.MainFunctionBody() << " let tileRowB = localRow * " << row_per_thread_b << ";\n"; diff --git a/onnxruntime/core/providers/webgpu/math/gemm_utils.h b/onnxruntime/core/providers/webgpu/math/gemm_utils.h index ed4cf997d2f00..7075debeb9952 100644 --- a/onnxruntime/core/providers/webgpu/math/gemm_utils.h +++ b/onnxruntime/core/providers/webgpu/math/gemm_utils.h @@ -24,7 +24,9 @@ void MatMulWriteFnSource(ShaderHelper& shader, int output_components, bool c_is_scalar, std::string activation_snippet = "", - bool is_channels_last = false); + bool is_channels_last = false, + bool use_split_k = false, + ProgramVariableDataType output_variable_type = ProgramVariableDataType::Float32x4); // The two following functions are used to generate shader code for vec4 and scalar. // It is used in GEMM, Matmul, and Conv. 
diff --git a/onnxruntime/core/providers/webgpu/math/matmul.cc b/onnxruntime/core/providers/webgpu/math/matmul.cc index cf4b9d3fae2d2..fe669d63aa2d2 100644 --- a/onnxruntime/core/providers/webgpu/math/matmul.cc +++ b/onnxruntime/core/providers/webgpu/math/matmul.cc @@ -46,6 +46,55 @@ static std::string CalcResult(int64_t components, int64_t a_components, int64_t return oss.str(); } +SplitKConfig SplitKConfig::GetSplitKConfig(const ComputeContext& context) { + const wgpu::AdapterInfo& adapter_info = context.AdapterInfo(); + SplitKConfig config = {}; + + if (adapter_info.vendor == std::string_view{"intel"}) { + if (adapter_info.architecture == std::string_view{"xe-lpg"} || + adapter_info.architecture == std::string_view{"xe-2lpg"} || + adapter_info.architecture == std::string_view{"xe-2hpg"}) { + config.enable_split_k_ = true; + + // Below thresholds are only verified on the above Intel GPUs. + config.split_dim_inner_ = 256; + config.min_dim_inner_with_split_k_ = config.split_dim_inner_ * 2 + 1; + config.max_dim_a_outer_with_split_k_ = 196; + config.max_dim_b_outer_with_split_k_ = 768; + } + } + return config; +} + +bool SplitKConfig::UseSplitK( + bool is_vec4, + ActivationKind activation_kind, + uint64_t batch_size, + bool is_channels_last, + uint32_t dim_a_outer, + uint32_t dim_b_outer, + uint32_t dim_inner) const { + bool use_split_k = enable_split_k_; + + // TODO: support the cases below. + use_split_k &= activation_kind == ActivationKind::None; + use_split_k &= is_vec4; + use_split_k &= batch_size == 1; + use_split_k &= is_channels_last; + + // Split-K works best when `dim_inner` is large and both `a_outer` and `b_outer` are relatively small. + use_split_k &= + dim_a_outer <= max_dim_a_outer_with_split_k_ && + dim_b_outer <= max_dim_b_outer_with_split_k_ && + dim_inner >= min_dim_inner_with_split_k_; + + return use_split_k; +} + +uint32_t SplitKConfig::GetSplitDimInner() const { + return split_dim_inner_; +} + Status MatMulNaiveProgram::GenerateShaderCode(ShaderHelper& shader) const { const auto& a = shader.AddInput("a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); @@ -167,6 +216,7 @@ Status MatMul::ComputeInternal(ComputeContext& context) const { } MatMulProgram CreateMatMulProgram(const Activation& activation, std::vector& inputs, Tensor* output_tensor, bool is_channels_last, + SplitKConfig split_k_config, const TensorShape& input_a_reshape, const TensorShape& input_b_reshape) { const auto* a = inputs[0]; @@ -222,21 +272,34 @@ MatMulProgram CreateMatMulProgram(const Activation& activation, std::vector({4, 1, 1}) : InlinedVector({4, 4, 1}); + bool need_split_k = split_k_config.UseSplitK(is_vec4, activation.activation_kind_, batch_size, is_channels_last, dim_a_outer, dim_b_outer, dim_inner); + + // When Split-K is used, bias will be handled in `MatMulFillBiasBeforeSplitKProgram` + // instead of `MatMulProgram`. 
+ if (need_split_k) { + has_bias = false; + } + const uint32_t dispatch_x = narrow((dim_b_outer + MatMul::MATMUL_PACKED_WORKGROUP_SIZE_X * elements_per_thread[0] - 1) / (MatMul::MATMUL_PACKED_WORKGROUP_SIZE_X * elements_per_thread[0])); const uint32_t dispatch_y = narrow((dim_a_outer + MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Y * elements_per_thread[1] - 1) / (MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Y * elements_per_thread[1])); - const uint32_t dispatch_z = narrow((static_cast(batch_size) + MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Z * elements_per_thread[2] - 1) / - (MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Z * elements_per_thread[2])); + uint32_t dispatch_z = narrow((static_cast(batch_size) + MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Z * elements_per_thread[2] - 1) / + (MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Z * elements_per_thread[2])); + uint32_t split_dim_inner = 1; + if (need_split_k) { + split_dim_inner = split_k_config.GetSplitDimInner(); + dispatch_z = (dim_inner + split_dim_inner - 1) / split_dim_inner; + } const int components = is_vec4 ? 4 : 1; const TensorShape a_shape_temp = CreateMatMulIntermediateShape(outer_dims_a, dim_a_outer, dim_inner, components); const TensorShape b_shape_temp = CreateMatMulIntermediateShape(outer_dims_b, dim_inner, dim_b_outer, components); const TensorShape output_shape_temp = TensorShape({batch_size, dim_a_outer, dim_b_outer / components}); - MatMulProgram program{activation, has_bias, is_vec4, elements_per_thread, is_channels_last}; + MatMulProgram program{activation, has_bias, is_vec4, elements_per_thread, is_channels_last, split_dim_inner}; program - .CacheHint(activation.ToString(), absl::StrJoin(elements_per_thread, "-"), std::to_string(is_vec4), components, is_channels_last) + .CacheHint(activation.ToString(), absl::StrJoin(elements_per_thread, "-"), std::to_string(is_vec4), components, is_channels_last, split_dim_inner) .AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, a_shape_temp, components}, {b, ProgramTensorMetadataDependency::TypeAndRank, b_shape_temp, components}}) .AddOutputs({{output_tensor, ProgramTensorMetadataDependency::Rank, output_shape_temp, components}}) @@ -254,5 +317,52 @@ MatMulProgram CreateMatMulProgram(const Activation& activation, std::vector(a_shape[a_shape.NumDimensions() - 2]); + const uint32_t dim_b_outer = narrow(b_shape[b_shape.NumDimensions() - 1]); + + const uint32_t output_row = dim_a_outer; + const uint32_t output_col = dim_b_outer / bias_components; + constexpr uint32_t workgroup_size_x = MatMulFillBiasBeforeSplitKProgram::WORKGROUP_SIZE_X; + constexpr uint32_t workgroup_size_y = MatMulFillBiasBeforeSplitKProgram::WORKGROUP_SIZE_Y; + constexpr uint32_t elements_per_thread = MatMulFillBiasBeforeSplitKProgram::ELEMENTS_PER_THREAD; + const uint32_t dispatch_x = (output_col + workgroup_size_x * elements_per_thread - 1) / (workgroup_size_x * elements_per_thread); + const uint32_t dispatch_y = (output_row + workgroup_size_y - 1) / workgroup_size_y; + + // TODO: support batch_size > 1 + constexpr uint32_t batch_size = 1; + TensorShape output_shape_temp = TensorShape({batch_size, output_row, output_col}); + + const uint32_t data_type = output->GetElementType(); + program.CacheHint(has_bias, is_channels_last, data_type) + .AddOutput({output, ProgramTensorMetadataDependency::Rank, output_shape_temp, static_cast(bias_components)}) + .AddUniformVariables({{dim_a_outer}, {dim_b_outer}}) + .SetDispatchGroupSize(dispatch_x, dispatch_y, 1) + .SetWorkgroupSize(workgroup_size_x, workgroup_size_y, 1); + + if (has_bias) { + const 
TensorShape reduced_bias_shape = ReduceShapeByComponents(bias->Shape(), bias_components); + program.AddInput({bias, ProgramTensorMetadataDependency::TypeAndRank, reduced_bias_shape, static_cast(bias_components)}); + } + + return program; +} + } // namespace webgpu } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/math/matmul.h b/onnxruntime/core/providers/webgpu/math/matmul.h index 8ab8c3a6ba2d0..632024f729974 100644 --- a/onnxruntime/core/providers/webgpu/math/matmul.h +++ b/onnxruntime/core/providers/webgpu/math/matmul.h @@ -14,10 +14,39 @@ namespace onnxruntime { namespace webgpu { +class SplitKConfig { + public: + SplitKConfig() = default; + + static SplitKConfig GetSplitKConfig(const ComputeContext& context); + + bool UseSplitK( + bool is_vec4, ActivationKind activation_kind, uint64_t batch_size, + bool is_channels_last, uint32_t dim_a_outer, + uint32_t dim_b_outer, uint32_t dim_inner) const; + + uint32_t GetSplitDimInner() const; + + private: + bool enable_split_k_ = false; + uint32_t split_dim_inner_ = 0; + uint32_t min_dim_inner_with_split_k_ = 0; + uint32_t max_dim_a_outer_with_split_k_ = 0; + uint32_t max_dim_b_outer_with_split_k_ = 0; +}; + MatMulProgram CreateMatMulProgram(const Activation& activation, std::vector& inputs, Tensor* output, bool is_channels_last, + SplitKConfig split_k_config = SplitKConfig(), const TensorShape& input_a_reshape = TensorShape(), const TensorShape& input_b_reshape = TensorShape()); +MatMulFillBiasBeforeSplitKProgram CreateMatMulFillBiasBeforeSplitKProgram( + const Tensor* bias, + Tensor* output, + bool is_channels_last, + const TensorShape& input_a_shape, + const TensorShape& input_b_shape); + class MatMul final : public WebGpuKernel { public: MatMul(const OpKernelInfo& info) : WebGpuKernel{info} {} diff --git a/onnxruntime/core/providers/webgpu/math/matmul_packed.cc b/onnxruntime/core/providers/webgpu/math/matmul_packed.cc index 585f8f1e011c4..cbebbb9e7e474 100644 --- a/onnxruntime/core/providers/webgpu/math/matmul_packed.cc +++ b/onnxruntime/core/providers/webgpu/math/matmul_packed.cc @@ -14,7 +14,15 @@ namespace webgpu { Status MatMulProgram::GenerateShaderCode(ShaderHelper& shader) const { const auto& a = shader.AddInput("a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); const auto& b = shader.AddInput("b", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); - const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + + ShaderUsage output_usage = ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias; + ProgramVariableDataType output_var_type = this->Outputs()[0].var_type; + if (NeedSplitK()) { + // When Split-K is enabled, we should declare output as `atomic` to call atomic built-in functions on it. 
+ output_usage |= ShaderUsage::UseAtomicU32ForSplitK | ShaderUsage::UseIndicesToOffset | ShaderUsage::UseShapeAndStride; + } + const auto& output = shader.AddOutput("output", output_usage); + const auto& batch_dims = shader.AddIndices("batch_dims", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); if (has_bias_) { @@ -23,16 +31,60 @@ Status MatMulProgram::GenerateShaderCode(ShaderHelper& shader) const { std::string apply_activation = GetActivationSnippet(activation_, "output_value_t", "output_element_t"); // declare the read and write functions MatMulReadFnSource(shader, a, b, &batch_dims, /*transA = */ false, /*transB = */ false, is_vec4_); - MatMulWriteFnSource(shader, output, has_bias_, /* is_gemm = */ false, 1, is_vec4_ ? 4 : 1, false, apply_activation, is_channels_last_); + MatMulWriteFnSource(shader, output, has_bias_, /* is_gemm = */ false, 1, is_vec4_ ? 4 : 1, false, apply_activation, is_channels_last_, NeedSplitK(), output_var_type); std::string data_type = "a_element_t"; // generate the main function if (is_vec4_) { - ORT_RETURN_IF_ERROR(MakeMatMulPackedVec4Source(shader, elements_per_thread_, WorkgroupSizeX(), WorkgroupSizeY(), data_type, &batch_dims)); + ORT_RETURN_IF_ERROR(MakeMatMulPackedVec4Source( + shader, elements_per_thread_, WorkgroupSizeX(), WorkgroupSizeY(), data_type, &batch_dims, + /*transA = */ false, /*transB = */ false, /*alpha = */ 1.f, /*need_handle_matmul = */ true, + /*output_components = */ 4, /*tile_inner = */ 32, NeedSplitK(), split_dim_inner_)); } else { ORT_RETURN_IF_ERROR(MakeMatMulPackedSource(shader, elements_per_thread_, WorkgroupSizeX(), WorkgroupSizeY(), data_type, &batch_dims)); } return Status::OK(); } +bool MatMulProgram::NeedSplitK() const { + return split_dim_inner_ > 1; +} + +Status MatMulFillBiasBeforeSplitKProgram::GenerateShaderCode(ShaderHelper& shader) const { + const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + + if (has_bias_) { + shader.AddInput("bias", ShaderUsage::UseUniform); + } + + // Handle bias with `MatMulWriteFnSource()`. + // Here `use_split_k` is false because we just initialize `output` with bias. + // The computation with Split-K will all be implemented in `MakeMatMulPackedVec4Source()`. 
+ MatMulWriteFnSource( + shader, output, has_bias_, /*is_gemm*/ false, /*c_components*/ 4, /*output_components*/ 4, /*c_is_scalar*/ false, + /*activation_snippet*/ "", is_channels_last_, /*use_split_k*/ false); + + shader.MainFunctionBody() << R"( + let output_components = 4;)"; + shader.MainFunctionBody() << R"( + let elements_per_thread = )" + << ELEMENTS_PER_THREAD + << ";\n"; + shader.MainFunctionBody() << R"( + let global_row = global_id.y; + if (global_row >= uniforms.dim_a_outer) { + return; + } + let dim_b_outer = i32(uniforms.dim_b_outer) / output_components; + let batch = 0; + let row = i32(global_row); + let value = output_value_t(); + let start_col = i32(global_id.x) * elements_per_thread; + for (var col = start_col; col < start_col + elements_per_thread; col++) { + mm_write(batch, row, col, value); + })"; + + return Status::OK(); +} + } // namespace webgpu } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/math/matmul_packed.h b/onnxruntime/core/providers/webgpu/math/matmul_packed.h index 767fdd8802e5b..2ee3c13301ff4 100644 --- a/onnxruntime/core/providers/webgpu/math/matmul_packed.h +++ b/onnxruntime/core/providers/webgpu/math/matmul_packed.h @@ -13,24 +13,54 @@ namespace onnxruntime { namespace webgpu { class MatMulProgram final : public Program { public: - MatMulProgram(const Activation& activation, bool bias, bool is_vec4, const gsl::span& elements_per_thread, bool is_channels_last = false) : Program{"MatMul"}, - activation_(activation), - has_bias_{bias}, - is_vec4_{is_vec4}, - elements_per_thread_(elements_per_thread.begin(), elements_per_thread.end()), - is_channels_last_(is_channels_last) {} + MatMulProgram(const Activation& activation, bool bias, bool is_vec4, const gsl::span& elements_per_thread, bool is_channels_last = false, uint32_t split_dim_inner = 1) : Program{"MatMul"}, + activation_(activation), + has_bias_{bias}, + is_vec4_{is_vec4}, + elements_per_thread_(elements_per_thread.begin(), elements_per_thread.end()), + is_channels_last_(is_channels_last), + split_dim_inner_(split_dim_inner) {} Status GenerateShaderCode(ShaderHelper& sh) const override; WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"dim_a_outer", ProgramUniformVariableDataType::Uint32}, {"dim_b_outer", ProgramUniformVariableDataType::Uint32}, {"dim_inner", ProgramUniformVariableDataType::Uint32}); + bool NeedSplitK() const; + private: const Activation activation_; const bool has_bias_; const bool is_vec4_; const InlinedVector elements_per_thread_; bool is_channels_last_ = false; + uint32_t split_dim_inner_ = 1; +}; + +// The program to initialize the output with 0 or bias before doing MatMul with Split-K. In Split-K, +// we set the output values with `atomicLoad` and `atomicCompareExchangeWeak` instead of a direct +// assignment (see the function `HandleMatMulWithSplitK()` in `gemm_utils.cc`), so we must initialize +// the output with 0 or bias first to make sure `atomicLoad` won't return garbage data. 
+class MatMulFillBiasBeforeSplitKProgram final : public Program { + public: + explicit MatMulFillBiasBeforeSplitKProgram(bool has_bias, bool is_channels_last) + : Program{"MatMul_Fill_Bias_Before_Split_K"}, + has_bias_(has_bias), + is_channels_last_(is_channels_last) { + } + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"dim_a_outer", ProgramUniformVariableDataType::Uint32}, + {"dim_b_outer", ProgramUniformVariableDataType::Uint32}); + + constexpr static uint32_t WORKGROUP_SIZE_X = 8; + constexpr static uint32_t WORKGROUP_SIZE_Y = 8; + constexpr static uint32_t ELEMENTS_PER_THREAD = 8; + + private: + bool has_bias_ = false; + bool is_channels_last_ = false; }; } // namespace webgpu diff --git a/onnxruntime/core/providers/webgpu/nn/conv.cc b/onnxruntime/core/providers/webgpu/nn/conv.cc index a2777979ae983..6a2246864078f 100644 --- a/onnxruntime/core/providers/webgpu/nn/conv.cc +++ b/onnxruntime/core/providers/webgpu/nn/conv.cc @@ -200,7 +200,16 @@ Status Conv::ComputeInternal(ComputeContext& context .AddUniformVariables({{output_size}, {static_cast(matmul_output_shape[1])}, {static_cast(matmul_output_shape[2])}, {static_cast(K)}}); return context.RunProgram(program); } else { - MatMulProgram program = CreateMatMulProgram(activation_, matmul_inputs, output, is_channels_last, matmul_input_reshapes[0], matmul_input_reshapes[1]); + // Explicitly pass `SplitKConfig` to `CreateMatMulProgram()` to enable Split-K. Now it is not + // used in any other places that call `CreateMatMulProgram()` (e.g. in `MatMul::ComputeInternal()`). + // TODO: enable Split-K in all the places that call `CreateMatMulProgram()`. + SplitKConfig split_K_config = SplitKConfig::GetSplitKConfig(context); + MatMulProgram program = CreateMatMulProgram(activation_, matmul_inputs, output, is_channels_last, split_K_config, matmul_input_reshapes[0], matmul_input_reshapes[1]); + if (program.NeedSplitK()) { + MatMulFillBiasBeforeSplitKProgram fill_bias_program = CreateMatMulFillBiasBeforeSplitKProgram( + bias, output, is_channels_last, matmul_input_reshapes[0], matmul_input_reshapes[1]); + ORT_RETURN_IF_ERROR(context.RunProgram(fill_bias_program)); + } return context.RunProgram(program); } } diff --git a/onnxruntime/core/providers/webgpu/shader_helper.cc b/onnxruntime/core/providers/webgpu/shader_helper.cc index 0182bdc607173..d6756feda7376 100644 --- a/onnxruntime/core/providers/webgpu/shader_helper.cc +++ b/onnxruntime/core/providers/webgpu/shader_helper.cc @@ -465,6 +465,11 @@ Status ShaderHelper::GenerateSourceCode(std::string& code, std::vector& sha for (size_t i = 0; i < output_vars_.size(); ++i) { const auto& output = output_vars_[i]; bool is_atomic = program_.Outputs()[i].is_atomic; + ProgramVariableDataType atomic_type = output->type_; + if (output->usage_ & ShaderUsage::UseAtomicU32ForSplitK) { + is_atomic = true; + atomic_type = ProgramVariableDataType::Uint32; + } uint32_t segments = output->segments_; for (uint32_t seg = 0; seg < segments; ++seg) { ss << "@group(0) @binding(" << binding_index++ << ") var "; @@ -475,14 +480,14 @@ Status ShaderHelper::GenerateSourceCode(std::string& code, std::vector& sha } ss << ": array<"; if (is_atomic) { - if (output->type_ == ProgramVariableDataType::Float32) { + if (atomic_type == ProgramVariableDataType::Float32) { ss << "atomic"; // emulate float atomic via i32 - } else if (output->type_ == ProgramVariableDataType::Uint32) { + } else if (atomic_type == ProgramVariableDataType::Uint32) { ss << "atomic"; - } else if 
(output->type_ == ProgramVariableDataType::Int32) { + } else if (atomic_type == ProgramVariableDataType::Int32) { ss << "atomic"; } else { - ORT_RETURN_IF(true, "Unsupported atomic type: ", int(output->type_)); + ORT_RETURN_IF(true, "Unsupported atomic type: ", int(atomic_type)); } } else { ss << output->StorageType(); diff --git a/onnxruntime/core/providers/webgpu/shader_variable.h b/onnxruntime/core/providers/webgpu/shader_variable.h index 8e921d6deafbb..491ab9b9e193b 100644 --- a/onnxruntime/core/providers/webgpu/shader_variable.h +++ b/onnxruntime/core/providers/webgpu/shader_variable.h @@ -72,6 +72,7 @@ struct ShaderUsage { UseGetByOffsetSegments = 4096, // use implementation of fn get_{name}_by_offset UseSetByOffsetSegments = 8192, // use implementation of fn set_{name}_by_offset UseUniform = 32768, // use uniform for shape and stride + UseAtomicU32ForSplitK = 65536, // use atomic for the output when using Split-K in MatMul } usage; ShaderUsage(decltype(usage) usage) : usage{usage} {} diff --git a/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc b/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc index 8382258bf39b4..5bbd0d6bcaebb 100644 --- a/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc +++ b/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc @@ -3,9 +3,10 @@ #include "core/mlas/inc/mlas.h" -#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) || defined(USE_COREML) || defined(USE_XNNPACK) +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) || defined(USE_COREML) || defined(USE_XNNPACK) || defined(USE_WEBGPU) #include "gtest/gtest.h" +#include "test/common/random_generator.h" #include "test/providers/provider_test_utils.h" #include "default_providers.h" @@ -36,10 +37,11 @@ If attributes.activation is set the NhwcFusedConv contrib op is used. If you are adding support for a new EP to the test and the EP does not support NhwcFusedConv please add the EP to the excluded_providers list. 
*/ +template void TestConvFp16Op(const ConvOpAndTestAttributes& attributes, const vector>& inputs, const vector>& input_shapes, - const std::initializer_list& expected_output, + const T& expected_output, const vector& expected_output_shape, bool weight_is_initializer = false, OpTester::ExpectResult expect_result = OpTester::ExpectResult::kExpectSuccess, @@ -424,6 +426,116 @@ TEST(ConvFp16Test, Conv2D_2) { TestConvFp16Op(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, true); } +TEST(ConvFp16Test, Conv2D_MatMul_SplitK_No_Bias) { + ConvOpAndTestAttributes attrs = { + "", // auto_pad + vector{1, 1}, // dilations + 1, // group + vector{1, 1}, // kernel_shape + vector{0, 0, 0, 0}, // pads + vector{1, 1}, // strides + {} // excluded EPs + }; + + // Define the matrix shapes to test a matmul-like convolution + constexpr int64_t M = 16; + constexpr int64_t K = 768; + constexpr int64_t N = 64; + + vector X_shape = {1, K, M, 1}; + vector W_shape = {N, K, 1, 1}; + vector Y_shape = {1, N, M, 1}; + + RandomValueGenerator random{1234}; + vector X_float32(random.Gaussian(AsSpan(X_shape), 0.0f, 0.025f)); + vector W_float32(random.Gaussian(AsSpan(W_shape), 0.0f, 0.025f)); + + // Calculate expected output values + vector expected_vals_float32; + expected_vals_float32.resize(M * N); + for (int m = 0; m < M; ++m) { + for (int n = 0; n < N; ++n) { + float sum{}; + for (int k = 0; k < K; ++k) { + int x_index = k * M + m; + int w_index = n * K + k; + sum += X_float32[x_index] * W_float32[w_index]; + } + int y_index = n * M + m; + expected_vals_float32[y_index] = sum; + } + } + + vector X = FloatsToMLFloat16s(X_float32); + vector W = FloatsToMLFloat16s(W_float32); + vector expected_vals = FloatsToMLFloat16s(expected_vals_float32); + + TestConvFp16Op(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, false, + OpTester::ExpectResult::kExpectSuccess, "", 11); + + // NNAPI/CoreML EP requires weight to be an initializer + TestConvFp16Op(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, true, + OpTester::ExpectResult::kExpectSuccess, "", 11); +} + +TEST(ConvFp16Test, Conv2D_MatMul_SplitK_With_Bias) { + ConvOpAndTestAttributes attrs = { + "", // auto_pad + vector{1, 1}, // dilations + 1, // group + vector{1, 1}, // kernel_shape + vector{0, 0, 0, 0}, // pads + vector{1, 1}, // strides + {} // excluded EPs + }; + + // Define the matrix shapes to test a matmul-like convolution + constexpr int64_t M = 16; + constexpr int64_t K = 768; + constexpr int64_t N = 64; + + vector X_shape = {1, K, M, 1}; + vector W_shape = {N, K, 1, 1}; + vector Y_shape = {1, N, M, 1}; + vector B_shape = {N}; + + RandomValueGenerator random{1234}; + vector X_float32(random.Gaussian(AsSpan(X_shape), 0.0f, 0.025f)); + vector W_float32(random.Gaussian(AsSpan(W_shape), 0.0f, 0.025f)); + vector B_float32(random.Gaussian(AsSpan(B_shape), 0.0f, 0.25f)); + + // Calculate expected output values + vector expected_vals_float32; + expected_vals_float32.resize(M * N); + for (int m = 0; m < M; ++m) { + for (int n = 0; n < N; ++n) { + float sum{}; + for (int k = 0; k < K; ++k) { + int x_index = k * M + m; + int w_index = n * K + k; + sum += X_float32[x_index] * W_float32[w_index]; + } + sum += B_float32[n]; + int y_index = n * M + m; + expected_vals_float32[y_index] = sum; + } + } + + vector X = FloatsToMLFloat16s(X_float32); + vector W = FloatsToMLFloat16s(W_float32); + vector B = FloatsToMLFloat16s(B_float32); + vector expected_vals = FloatsToMLFloat16s(expected_vals_float32); + + TestConvFp16Op( + attrs, {X, W, B}, {X_shape, W_shape, 
B_shape}, expected_vals, Y_shape, false, + OpTester::ExpectResult::kExpectSuccess, "", 11); + + // NNAPI/CoreML EP requires weight to be an initializer + TestConvFp16Op( + attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, true, + OpTester::ExpectResult::kExpectSuccess, "", 11); +} + TEST(ConvFp16Test, Conv2D_Bias_1) { ConvOpAndTestAttributes attrs = { "", // auto_pad @@ -1038,7 +1150,7 @@ TEST(ConvFp16Test, ConvDimWithZero) { vector W_shape = {2, 2, 1, 1}; vector out_shape = {0, 2, 4, 4}; - TestConvFp16Op(attrs, {X, W}, {X_shape, W_shape}, {}, out_shape); + TestConvFp16Op(attrs, {X, W}, {X_shape, W_shape}, vector(), out_shape); } TEST(ConvFp16Test, Conv1D_asymmetric_padding) { diff --git a/onnxruntime/test/providers/cpu/nn/conv_op_test.cc b/onnxruntime/test/providers/cpu/nn/conv_op_test.cc index 7c84aefa1c01f..b5dc08560ccca 100644 --- a/onnxruntime/test/providers/cpu/nn/conv_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/conv_op_test.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include "core/graph/constants.h" #include "gtest/gtest.h" +#include "test/common/random_generator.h" #include "test/providers/provider_test_utils.h" using namespace std; @@ -20,10 +21,11 @@ struct ConvOpAndTestAttributes { std::unordered_set excluded_providers; }; +template void TestConvOp(const ConvOpAndTestAttributes& attributes, const vector>& inputs, const vector>& input_shapes, - const std::initializer_list& expected_output, + const T& expected_output, const vector& expected_output_shape, bool weight_is_initializer = false, optional epsilon = optional(), @@ -535,6 +537,103 @@ TEST(ConvTest, Conv2D_AutoPad2) { TestConvOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, true); } +TEST(ConvTest, Conv2D_MatMul_SplitK_No_Bias) { + ConvOpAndTestAttributes attrs = { + "", // auto_pad + vector{1, 1}, // dilations + 1, // group + vector{1, 1}, // kernel_shape + vector{0, 0, 0, 0}, // pads + vector{1, 1}, // strides + {} // excluded EPs + }; + + // Define the matrix shapes to test a matmul-like convolution + constexpr int64_t M = 16; + constexpr int64_t K = 768; + constexpr int64_t N = 64; + + vector X_shape = {1, K, M, 1}; + vector W_shape = {N, K, 1, 1}; + vector Y_shape = {1, N, M, 1}; + + // Fill X and W + RandomValueGenerator random{1234}; + vector X(random.Gaussian(AsSpan(X_shape), 0.0f, 0.025f)); + vector W(random.Gaussian(AsSpan(W_shape), 0.0f, 0.025f)); + + // Calculate expected output values + vector expected_vals; + expected_vals.resize(M * N); + for (int m = 0; m < M; ++m) { + for (int n = 0; n < N; ++n) { + float sum = 0.0f; + for (int k = 0; k < K; ++k) { + int x_index = k * M + m; + int w_index = n * K + k; + sum += X[x_index] * W[w_index]; + } + int y_index = n * M + m; + expected_vals[y_index] = sum; + } + } + + TestConvOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); + + // NNAPI/CoreML EP requires weight to be an initializer + TestConvOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, true); +} + +TEST(ConvTest, Conv2D_MatMul_SplitK_With_Bias) { + ConvOpAndTestAttributes attrs = { + "", // auto_pad + vector{1, 1}, // dilations + 1, // group + vector{1, 1}, // kernel_shape + vector{0, 0, 0, 0}, // pads + vector{1, 1}, // strides + {} // excluded EPs + }; + + // Define the matrix shapes to test a matmul-like convolution + constexpr int64_t M = 16; + constexpr int64_t K = 768; + constexpr int64_t N = 64; + + vector X_shape = {1, K, M, 1}; + vector W_shape = {N, K, 1, 1}; + vector Y_shape = {1, N, M, 1}; + vector B_shape = {N}; + + // 
Fill X, W and B + RandomValueGenerator random{1234}; + vector X(random.Gaussian(AsSpan(X_shape), 0.0f, 0.025f)); + vector W(random.Gaussian(AsSpan(W_shape), 0.0f, 0.025f)); + vector B(random.Gaussian(AsSpan(B_shape), 0.0f, 0.25f)); + + // Calculate expected output values + vector expected_vals; + expected_vals.resize(M * N); + for (int m = 0; m < M; ++m) { + for (int n = 0; n < N; ++n) { + float sum = 0.0f; + for (int k = 0; k < K; ++k) { + int x_index = k * M + m; + int w_index = n * K + k; + sum += X[x_index] * W[w_index]; + } + sum += B[n]; + int y_index = n * M + m; + expected_vals[y_index] = sum; + } + } + + TestConvOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape); + + // NNAPI/CoreML EP requires weight to be an initializer + TestConvOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, true); +} + // Conv10 TEST(ConvTest, Conv3D_1) { ConvOpAndTestAttributes attrs = { @@ -1037,7 +1136,7 @@ TEST(ConvTest, ConvDimWithZero) { // not handled by ACL attrs.excluded_providers.insert(kAclExecutionProvider); - TestConvOp(attrs, {X, W}, {X_shape, W_shape}, {}, out_shape, false, optional(), + TestConvOp(attrs, {X, W}, {X_shape, W_shape}, vector(), out_shape, false, optional(), OpTester::ExpectResult::kExpectSuccess, "", 10); } From 1f06b95f44b188af55117e3f04ac17da1858fbb7 Mon Sep 17 00:00:00 2001 From: Jiawei Shao Date: Fri, 31 Oct 2025 21:38:17 +0800 Subject: [PATCH 02/22] Address reviewer's comments --- .../core/providers/webgpu/math/gemm_utils.cc | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/math/gemm_utils.cc b/onnxruntime/core/providers/webgpu/math/gemm_utils.cc index f0d2fbfd134b0..98149efb08ea1 100644 --- a/onnxruntime/core/providers/webgpu/math/gemm_utils.cc +++ b/onnxruntime/core/providers/webgpu/math/gemm_utils.cc @@ -283,18 +283,18 @@ Status MakeMatMulPackedVec4Source(ShaderHelper& shader, << " var acc: array, rowPerThread>;\n"; if (split_k) { - // With Split-K, the original "workgroup" (with dispatch_z == 1 in API side) into multiple - // ones, and in the current workgroup we only compute `kSplitK` elements starting from + // With Split-K, the original "workgroup" (with dispatch_z == 1 in API side) is split into + // multiple ones, and in the current workgroup we only compute `kSplitK` elements starting from // `kSplitK * i32(global_id.z)`. // // For example: considering computing Y = (X * W + B) in one workgroup. // Let kSplitk = 2, B = [d1, d2] - // Let X = [[a1 a1 b1 b1 c1 c1] = [ A1 B1 C1 ] W = [[a2 a2] = [ A2 - // [a1 a1 b1 b1 c1 c1]] [a2 a2] B2 - // [b2 b2] C2 ] - // [b2 b2] - // [c2 c2] - // [c2 c2]] + // Let X = [[a1 a1 b1 b1 c1 c1] = [ A1 B1 C1 ], W = [[a2 a2] = [ A2 + // [a1 a1 b1 b1 c1 c1]] [a2 a2] B2 + // [b2 b2] C2 ] + // [b2 b2] + // [c2 c2] + // [c2 c2]] // // With Split-K: // 1. 
Initialize output Y with B in `MatMulFillBiasBeforeSplitKProgram`: Y = [[d1, d2] From 31938158ce4b95b36db2bc85ea3c05a0f995462e Mon Sep 17 00:00:00 2001 From: Jiawei Shao Date: Sat, 1 Nov 2025 17:28:11 +0800 Subject: [PATCH 03/22] Remove the check of `is_channels_last` in `UseSplitK` --- onnxruntime/core/providers/webgpu/math/matmul.cc | 4 +--- onnxruntime/core/providers/webgpu/math/matmul.h | 3 +-- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/math/matmul.cc b/onnxruntime/core/providers/webgpu/math/matmul.cc index fe669d63aa2d2..c17885314178f 100644 --- a/onnxruntime/core/providers/webgpu/math/matmul.cc +++ b/onnxruntime/core/providers/webgpu/math/matmul.cc @@ -70,7 +70,6 @@ bool SplitKConfig::UseSplitK( bool is_vec4, ActivationKind activation_kind, uint64_t batch_size, - bool is_channels_last, uint32_t dim_a_outer, uint32_t dim_b_outer, uint32_t dim_inner) const { @@ -80,7 +79,6 @@ bool SplitKConfig::UseSplitK( use_split_k &= activation_kind == ActivationKind::None; use_split_k &= is_vec4; use_split_k &= batch_size == 1; - use_split_k &= is_channels_last; // Split-K works best when `dim_inner` is large and both `a_outer` and `b_outer` are relatively small. use_split_k &= @@ -272,7 +270,7 @@ MatMulProgram CreateMatMulProgram(const Activation& activation, std::vector({4, 1, 1}) : InlinedVector({4, 4, 1}); - bool need_split_k = split_k_config.UseSplitK(is_vec4, activation.activation_kind_, batch_size, is_channels_last, dim_a_outer, dim_b_outer, dim_inner); + bool need_split_k = split_k_config.UseSplitK(is_vec4, activation.activation_kind_, batch_size, dim_a_outer, dim_b_outer, dim_inner); // When Split-K is used, bias will be handled in `MatMulFillBiasBeforeSplitKProgram` // instead of `MatMulProgram`. diff --git a/onnxruntime/core/providers/webgpu/math/matmul.h b/onnxruntime/core/providers/webgpu/math/matmul.h index 632024f729974..9dc490e5f43ad 100644 --- a/onnxruntime/core/providers/webgpu/math/matmul.h +++ b/onnxruntime/core/providers/webgpu/math/matmul.h @@ -22,8 +22,7 @@ class SplitKConfig { bool UseSplitK( bool is_vec4, ActivationKind activation_kind, uint64_t batch_size, - bool is_channels_last, uint32_t dim_a_outer, - uint32_t dim_b_outer, uint32_t dim_inner) const; + uint32_t dim_a_outer, uint32_t dim_b_outer, uint32_t dim_inner) const; uint32_t GetSplitDimInner() const; From ecbc0933b0cafcadabb3e52e617fee63d64a8826 Mon Sep 17 00:00:00 2001 From: Jiawei Shao Date: Sat, 1 Nov 2025 19:04:06 +0800 Subject: [PATCH 04/22] Still require `is_channels_last` to be true This reverts commit 31938158ce4b95b36db2bc85ea3c05a0f995462e. --- onnxruntime/core/providers/webgpu/math/gemm_utils.cc | 3 ++- onnxruntime/core/providers/webgpu/math/matmul.cc | 12 ++++++++---- onnxruntime/core/providers/webgpu/math/matmul.h | 4 ++-- .../core/providers/webgpu/math/matmul_packed.cc | 5 +++-- .../core/providers/webgpu/math/matmul_packed.h | 6 ++---- onnxruntime/core/providers/webgpu/nn/conv.cc | 4 +++- 6 files changed, 20 insertions(+), 14 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/math/gemm_utils.cc b/onnxruntime/core/providers/webgpu/math/gemm_utils.cc index 98149efb08ea1..909ec03cb711f 100644 --- a/onnxruntime/core/providers/webgpu/math/gemm_utils.cc +++ b/onnxruntime/core/providers/webgpu/math/gemm_utils.cc @@ -204,7 +204,8 @@ void MatMulWriteFnSource(ShaderHelper& shader, // Set output when MatMul is performed with Split-K. 
// When split-k is used in MatMul, the bias will be handled in `MatMulFillBiasBeforeSplitKProgram` // instead of here, so `has_bias` and `is_channels_last` is not used for Split-K. Note that we - // still need to handle `has_bias` and `is_channels_last` in `MatMulFillBiasBeforeSplitKProgram`. + // still need to handle `has_bias` (and `is_channels_last` in the future) in + // `MatMulFillBiasBeforeSplitKProgram`. assert(!has_bias); HandleMatMulWithSplitK(shader, output_variable_type); } else if (is_gemm) { diff --git a/onnxruntime/core/providers/webgpu/math/matmul.cc b/onnxruntime/core/providers/webgpu/math/matmul.cc index c17885314178f..ef6b0a3cff97e 100644 --- a/onnxruntime/core/providers/webgpu/math/matmul.cc +++ b/onnxruntime/core/providers/webgpu/math/matmul.cc @@ -70,6 +70,7 @@ bool SplitKConfig::UseSplitK( bool is_vec4, ActivationKind activation_kind, uint64_t batch_size, + bool is_channels_last, uint32_t dim_a_outer, uint32_t dim_b_outer, uint32_t dim_inner) const { @@ -79,6 +80,9 @@ bool SplitKConfig::UseSplitK( use_split_k &= activation_kind == ActivationKind::None; use_split_k &= is_vec4; use_split_k &= batch_size == 1; + // Now `is_channels_last` is only supported because we only generate vec4 shaders in + // `MatMulFillBiasBeforeSplitKProgram`. + use_split_k &= is_channels_last; // Split-K works best when `dim_inner` is large and both `a_outer` and `b_outer` are relatively small. use_split_k &= @@ -270,7 +274,7 @@ MatMulProgram CreateMatMulProgram(const Activation& activation, std::vector({4, 1, 1}) : InlinedVector({4, 4, 1}); - bool need_split_k = split_k_config.UseSplitK(is_vec4, activation.activation_kind_, batch_size, dim_a_outer, dim_b_outer, dim_inner); + bool need_split_k = split_k_config.UseSplitK(is_vec4, activation.activation_kind_, batch_size, is_channels_last, dim_a_outer, dim_b_outer, dim_inner); // When Split-K is used, bias will be handled in `MatMulFillBiasBeforeSplitKProgram` // instead of `MatMulProgram`. @@ -318,13 +322,13 @@ MatMulProgram CreateMatMulProgram(const Activation& activation, std::vectorGetElementType(); - program.CacheHint(has_bias, is_channels_last, data_type) + program.CacheHint(has_bias, data_type) .AddOutput({output, ProgramTensorMetadataDependency::Rank, output_shape_temp, static_cast(bias_components)}) .AddUniformVariables({{dim_a_outer}, {dim_b_outer}}) .SetDispatchGroupSize(dispatch_x, dispatch_y, 1) diff --git a/onnxruntime/core/providers/webgpu/math/matmul.h b/onnxruntime/core/providers/webgpu/math/matmul.h index 9dc490e5f43ad..ee97145af8d3b 100644 --- a/onnxruntime/core/providers/webgpu/math/matmul.h +++ b/onnxruntime/core/providers/webgpu/math/matmul.h @@ -22,7 +22,8 @@ class SplitKConfig { bool UseSplitK( bool is_vec4, ActivationKind activation_kind, uint64_t batch_size, - uint32_t dim_a_outer, uint32_t dim_b_outer, uint32_t dim_inner) const; + bool is_channels_last, uint32_t dim_a_outer, + uint32_t dim_b_outer, uint32_t dim_inner) const; uint32_t GetSplitDimInner() const; @@ -42,7 +43,6 @@ MatMulProgram CreateMatMulProgram(const Activation& activation, std::vector { // the output with 0 or bias first to make sure `atomicLoad` won't return garbage data. 
class MatMulFillBiasBeforeSplitKProgram final : public Program { public: - explicit MatMulFillBiasBeforeSplitKProgram(bool has_bias, bool is_channels_last) + explicit MatMulFillBiasBeforeSplitKProgram(bool has_bias) : Program{"MatMul_Fill_Bias_Before_Split_K"}, - has_bias_(has_bias), - is_channels_last_(is_channels_last) { + has_bias_(has_bias) { } Status GenerateShaderCode(ShaderHelper& sh) const override; @@ -60,7 +59,6 @@ class MatMulFillBiasBeforeSplitKProgram final : public Program::ComputeInternal(ComputeContext& context SplitKConfig split_K_config = SplitKConfig::GetSplitKConfig(context); MatMulProgram program = CreateMatMulProgram(activation_, matmul_inputs, output, is_channels_last, split_K_config, matmul_input_reshapes[0], matmul_input_reshapes[1]); if (program.NeedSplitK()) { + // Currently we only support bias in vec4 and channels last format for Split-K MatMul. + assert(is_channels_last); MatMulFillBiasBeforeSplitKProgram fill_bias_program = CreateMatMulFillBiasBeforeSplitKProgram( - bias, output, is_channels_last, matmul_input_reshapes[0], matmul_input_reshapes[1]); + bias, output, matmul_input_reshapes[0], matmul_input_reshapes[1]); ORT_RETURN_IF_ERROR(context.RunProgram(fill_bias_program)); } return context.RunProgram(program); From 82d3d9b40d59a4c374e7d04e9661b8820ec46793 Mon Sep 17 00:00:00 2001 From: Jiawei Shao Date: Tue, 4 Nov 2025 21:14:46 +0800 Subject: [PATCH 05/22] Check the use of Split-K with ratio and enable Split-K on ACM --- .../core/providers/webgpu/math/matmul.cc | 20 +++++++++---------- .../core/providers/webgpu/math/matmul.h | 3 +-- 2 files changed, 11 insertions(+), 12 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/math/matmul.cc b/onnxruntime/core/providers/webgpu/math/matmul.cc index ef6b0a3cff97e..6273951b84243 100644 --- a/onnxruntime/core/providers/webgpu/math/matmul.cc +++ b/onnxruntime/core/providers/webgpu/math/matmul.cc @@ -51,16 +51,16 @@ SplitKConfig SplitKConfig::GetSplitKConfig(const ComputeContext& context) { SplitKConfig config = {}; if (adapter_info.vendor == std::string_view{"intel"}) { - if (adapter_info.architecture == std::string_view{"xe-lpg"} || - adapter_info.architecture == std::string_view{"xe-2lpg"} || - adapter_info.architecture == std::string_view{"xe-2hpg"}) { + if (adapter_info.architecture == std::string_view{"xe-2lpg"} || + adapter_info.architecture == std::string_view{"xe-2hpg"} || + adapter_info.architecture == std::string_view{"xe-lpg"} || + adapter_info.architecture == std::string_view{"gen-12hp"}) { config.enable_split_k_ = true; // Below thresholds are only verified on the above Intel GPUs. config.split_dim_inner_ = 256; config.min_dim_inner_with_split_k_ = config.split_dim_inner_ * 2 + 1; - config.max_dim_a_outer_with_split_k_ = 196; - config.max_dim_b_outer_with_split_k_ = 768; + config.max_dim_a_outer_multiplies_dim_b_outer_divides_dim_inner_ = 0.6f; } } return config; @@ -84,11 +84,11 @@ bool SplitKConfig::UseSplitK( // `MatMulFillBiasBeforeSplitKProgram`. use_split_k &= is_channels_last; - // Split-K works best when `dim_inner` is large and both `a_outer` and `b_outer` are relatively small. - use_split_k &= - dim_a_outer <= max_dim_a_outer_with_split_k_ && - dim_b_outer <= max_dim_b_outer_with_split_k_ && - dim_inner >= min_dim_inner_with_split_k_; + // Split-K works best when `dim_inner` is relatively large compared with `dim_a_outer` and + // `dim_b_outer`. Currently we use `(dim_a_outer * dim_b_outer * 1.0f / dim_inner)` as the + // metric to decide whether to use Split-K or not. 
+ use_split_k &= (dim_inner >= min_dim_inner_with_split_k_); + use_split_k &= ((dim_a_outer * dim_b_outer * 1.0f / dim_inner) <= max_dim_a_outer_multiplies_dim_b_outer_divides_dim_inner_); return use_split_k; } diff --git a/onnxruntime/core/providers/webgpu/math/matmul.h b/onnxruntime/core/providers/webgpu/math/matmul.h index ee97145af8d3b..ee9caea84df24 100644 --- a/onnxruntime/core/providers/webgpu/math/matmul.h +++ b/onnxruntime/core/providers/webgpu/math/matmul.h @@ -31,8 +31,7 @@ class SplitKConfig { bool enable_split_k_ = false; uint32_t split_dim_inner_ = 0; uint32_t min_dim_inner_with_split_k_ = 0; - uint32_t max_dim_a_outer_with_split_k_ = 0; - uint32_t max_dim_b_outer_with_split_k_ = 0; + float max_dim_a_outer_multiplies_dim_b_outer_divides_dim_inner_ = 0.0f; }; MatMulProgram CreateMatMulProgram(const Activation& activation, std::vector& inputs, Tensor* output, bool is_channels_last, From 0099eddfb24b197a3a71f60ebd1b533e6cd6cf13 Mon Sep 17 00:00:00 2001 From: Jiawei Shao Date: Tue, 4 Nov 2025 21:38:27 +0800 Subject: [PATCH 06/22] Fix incorrect ratio --- onnxruntime/core/providers/webgpu/math/matmul.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/webgpu/math/matmul.cc b/onnxruntime/core/providers/webgpu/math/matmul.cc index 6273951b84243..09be4ccfdcf3c 100644 --- a/onnxruntime/core/providers/webgpu/math/matmul.cc +++ b/onnxruntime/core/providers/webgpu/math/matmul.cc @@ -60,7 +60,7 @@ SplitKConfig SplitKConfig::GetSplitKConfig(const ComputeContext& context) { // Below thresholds are only verified on the above Intel GPUs. config.split_dim_inner_ = 256; config.min_dim_inner_with_split_k_ = config.split_dim_inner_ * 2 + 1; - config.max_dim_a_outer_multiplies_dim_b_outer_divides_dim_inner_ = 0.6f; + config.max_dim_a_outer_multiplies_dim_b_outer_divides_dim_inner_ = 60.0f; } } return config; From 05bd1f84666bb78f16978824e208d101e3986144 Mon Sep 17 00:00:00 2001 From: Jiawei Shao Date: Wed, 5 Nov 2025 09:15:43 +0800 Subject: [PATCH 07/22] Update ratio --- onnxruntime/core/providers/webgpu/math/matmul.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/webgpu/math/matmul.cc b/onnxruntime/core/providers/webgpu/math/matmul.cc index 09be4ccfdcf3c..767922e6a8736 100644 --- a/onnxruntime/core/providers/webgpu/math/matmul.cc +++ b/onnxruntime/core/providers/webgpu/math/matmul.cc @@ -60,7 +60,7 @@ SplitKConfig SplitKConfig::GetSplitKConfig(const ComputeContext& context) { // Below thresholds are only verified on the above Intel GPUs. config.split_dim_inner_ = 256; config.min_dim_inner_with_split_k_ = config.split_dim_inner_ * 2 + 1; - config.max_dim_a_outer_multiplies_dim_b_outer_divides_dim_inner_ = 60.0f; + config.max_dim_a_outer_multiplies_dim_b_outer_divides_dim_inner_ = 48.0f; } } return config; From 11ecdfead6ee77245fc708cdc0ca740289e03a28 Mon Sep 17 00:00:00 2001 From: Jiawei Shao Date: Wed, 5 Nov 2025 09:26:23 +0800 Subject: [PATCH 08/22] Update ratio --- onnxruntime/core/providers/webgpu/math/matmul.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/webgpu/math/matmul.cc b/onnxruntime/core/providers/webgpu/math/matmul.cc index 767922e6a8736..4d959a91d421c 100644 --- a/onnxruntime/core/providers/webgpu/math/matmul.cc +++ b/onnxruntime/core/providers/webgpu/math/matmul.cc @@ -60,7 +60,7 @@ SplitKConfig SplitKConfig::GetSplitKConfig(const ComputeContext& context) { // Below thresholds are only verified on the above Intel GPUs. 
config.split_dim_inner_ = 256; config.min_dim_inner_with_split_k_ = config.split_dim_inner_ * 2 + 1; - config.max_dim_a_outer_multiplies_dim_b_outer_divides_dim_inner_ = 48.0f; + config.max_dim_a_outer_multiplies_dim_b_outer_divides_dim_inner_ = 54.0f; } } return config; From 534dc2c2196483f1a3bf459b470be7f84beb0645 Mon Sep 17 00:00:00 2001 From: Jiawei Shao Date: Thu, 6 Nov 2025 10:29:53 +0800 Subject: [PATCH 09/22] Compute FP16 values with MLFloat16 --- .../test/providers/cpu/nn/conv_fp16_test.cc | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc b/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc index 5bbd0d6bcaebb..2daeaf96f24e1 100644 --- a/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc +++ b/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc @@ -450,6 +450,9 @@ TEST(ConvFp16Test, Conv2D_MatMul_SplitK_No_Bias) { vector X_float32(random.Gaussian(AsSpan(X_shape), 0.0f, 0.025f)); vector W_float32(random.Gaussian(AsSpan(W_shape), 0.0f, 0.025f)); + vector X = FloatsToMLFloat16s(X_float32); + vector W = FloatsToMLFloat16s(W_float32); + // Calculate expected output values vector expected_vals_float32; expected_vals_float32.resize(M * N); @@ -459,15 +462,12 @@ TEST(ConvFp16Test, Conv2D_MatMul_SplitK_No_Bias) { for (int k = 0; k < K; ++k) { int x_index = k * M + m; int w_index = n * K + k; - sum += X_float32[x_index] * W_float32[w_index]; + sum += X[x_index].ToFloat() * W[w_index].ToFloat(); } int y_index = n * M + m; expected_vals_float32[y_index] = sum; } } - - vector X = FloatsToMLFloat16s(X_float32); - vector W = FloatsToMLFloat16s(W_float32); vector expected_vals = FloatsToMLFloat16s(expected_vals_float32); TestConvFp16Op(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, false, @@ -504,6 +504,10 @@ TEST(ConvFp16Test, Conv2D_MatMul_SplitK_With_Bias) { vector W_float32(random.Gaussian(AsSpan(W_shape), 0.0f, 0.025f)); vector B_float32(random.Gaussian(AsSpan(B_shape), 0.0f, 0.25f)); + vector X = FloatsToMLFloat16s(X_float32); + vector W = FloatsToMLFloat16s(W_float32); + vector B = FloatsToMLFloat16s(B_float32); + // Calculate expected output values vector expected_vals_float32; expected_vals_float32.resize(M * N); @@ -513,17 +517,13 @@ TEST(ConvFp16Test, Conv2D_MatMul_SplitK_With_Bias) { for (int k = 0; k < K; ++k) { int x_index = k * M + m; int w_index = n * K + k; - sum += X_float32[x_index] * W_float32[w_index]; + sum += X[x_index].ToFloat() * W[w_index].ToFloat(); } - sum += B_float32[n]; + sum += B[n].ToFloat(); int y_index = n * M + m; expected_vals_float32[y_index] = sum; } } - - vector X = FloatsToMLFloat16s(X_float32); - vector W = FloatsToMLFloat16s(W_float32); - vector B = FloatsToMLFloat16s(B_float32); vector expected_vals = FloatsToMLFloat16s(expected_vals_float32); TestConvFp16Op( From cfd22194913957718c5759e421008f0d9c42e806 Mon Sep 17 00:00:00 2001 From: Jiawei Shao Date: Fri, 7 Nov 2025 14:44:59 +0800 Subject: [PATCH 10/22] Address reviewer's comments - Renamed `MatMulFillBiasBeforeSplitKProgram` to `MatMulFillBiasOrZeroBeforeSplitKProgram` - Fill one vec4 value (0 or bias) per invocation in `MatMulFillBiasOrZeroBeforeSplitKProgram` - Renamed `CreateMatMulProgram()` to `ComputeMatMul()` and run both `MatMulProgram` and `MatMulFillBiasOrZeroBeforeSplitKProgram` in `ComputeMatMul()` - Removed `ShaderUsage::UseAtomicU32ForSplitK` and use `ProgramOutput::Atomic` instead - Removed `data_type` in the `CacheHint` of `MatMulFillBiasOrZeroBeforeSplitKProgram` - Updated the value 
of `config.split_dim_inner_` to 512 after more experiments --- .../core/providers/webgpu/math/gemm_utils.cc | 30 ++++---- .../core/providers/webgpu/math/matmul.cc | 71 ++++++++++++------- .../core/providers/webgpu/math/matmul.h | 10 +-- .../providers/webgpu/math/matmul_packed.cc | 45 ++++++------ .../providers/webgpu/math/matmul_packed.h | 10 ++- onnxruntime/core/providers/webgpu/nn/conv.cc | 16 ++--- onnxruntime/core/providers/webgpu/program.cc | 8 +++ onnxruntime/core/providers/webgpu/program.h | 1 + .../core/providers/webgpu/shader_helper.cc | 13 ++-- .../core/providers/webgpu/shader_variable.h | 1 - 10 files changed, 108 insertions(+), 97 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/math/gemm_utils.cc b/onnxruntime/core/providers/webgpu/math/gemm_utils.cc index 909ec03cb711f..093a5a6327020 100644 --- a/onnxruntime/core/providers/webgpu/math/gemm_utils.cc +++ b/onnxruntime/core/providers/webgpu/math/gemm_utils.cc @@ -62,13 +62,13 @@ void HandleMatMulWithSplitK( // so we must add them with atomic built-in functions. Because currently WebGPU doesn't support // atomic built-in functions on `f32` or `f16`, we implement the `atomicAdd` on `f32` and `f16` // with `atomicLoad` and `atomicCompareExchangeWeak`: - // 1. Get `old_output_u32` from `output[offset]` with `atomicLoad`. - // 2. Convert `old_output_u32` into `f32` (`old_output_f32`) or `vec2h` (`old_output_vec2h`). + // 1. Get `old_output_i32` from `output[offset]` with `atomicLoad`. + // 2. Convert `old_output_i32` into `f32` (`old_output_f32`) or `vec2h` (`old_output_vec2h`). // 3. Add incoming `value` into `old_output_f32` or `old_output_vec2h`. - // 4. Convert the result of step 3 into `u32` values. + // 4. Convert the result of step 3 into `i32` values. // 5. Try assigning the result of step 4 into `output[offset]` with `atomicCompareExchangeWeak` - // and `old_output_u32`. The assignment will fail if at this time `output[offset]` is not - // equal to `old_output_u32` (it is updated in another invocation). If the assignment fails + // and `old_output_i32`. The assignment will fail if at this time `output[offset]` is not + // equal to `old_output_i32` (it is updated in another invocation). If the assignment fails // we have to go to step 1 and repeat all the above steps. 
switch (output_variable_type) { case ProgramVariableDataType::Float32x4: { @@ -77,12 +77,12 @@ void HandleMatMulWithSplitK( for (var i = 0u; i < 4u; i++) { let offset = offset0 + i; while (true) { - let old_output_u32 = atomicLoad(&output[offset]); - let old_output_f32 = bitcast(old_output_u32); + let old_output_i32 = atomicLoad(&output[offset]); + let old_output_f32 = bitcast(old_output_i32); let new_output_f32 = old_output_f32 + value[i]; - let new_output_u32 = bitcast(new_output_f32); - let output_compexchange = atomicCompareExchangeWeak(&output[offset], old_output_u32, new_output_u32); - if (output_compexchange.old_value == old_output_u32) { + let new_output_i32 = bitcast(new_output_f32); + let output_compexchange = atomicCompareExchangeWeak(&output[offset], old_output_i32, new_output_i32); + if (output_compexchange.old_value == old_output_i32) { break; } } @@ -99,12 +99,12 @@ void HandleMatMulWithSplitK( for (var i = 0u; i < 2u; i++) { let offset= offset0 + i; while(true) { - let old_output_u32 = atomicLoad(&output[offset]); - let old_output_vec2h = bitcast(old_output_u32); + let old_output_i32 = atomicLoad(&output[offset]); + let old_output_vec2h = bitcast(old_output_i32); let new_output_vec2h = old_output_vec2h + vec2h_values[i]; - let new_output_u32 = bitcast(new_output_vec2h); - let output_compexchange = atomicCompareExchangeWeak(&output[offset], old_output_u32, new_output_u32); - if (output_compexchange.old_value == old_output_u32) { + let new_output_i32 = bitcast(new_output_vec2h); + let output_compexchange = atomicCompareExchangeWeak(&output[offset], old_output_i32, new_output_i32); + if (output_compexchange.old_value == old_output_i32) { break; } } diff --git a/onnxruntime/core/providers/webgpu/math/matmul.cc b/onnxruntime/core/providers/webgpu/math/matmul.cc index 4d959a91d421c..85a5004d72c5c 100644 --- a/onnxruntime/core/providers/webgpu/math/matmul.cc +++ b/onnxruntime/core/providers/webgpu/math/matmul.cc @@ -59,7 +59,7 @@ SplitKConfig SplitKConfig::GetSplitKConfig(const ComputeContext& context) { // Below thresholds are only verified on the above Intel GPUs. 
config.split_dim_inner_ = 256; - config.min_dim_inner_with_split_k_ = config.split_dim_inner_ * 2 + 1; + config.min_dim_inner_with_split_k_ = config.split_dim_inner_ * 2; config.max_dim_a_outer_multiplies_dim_b_outer_divides_dim_inner_ = 54.0f; } } @@ -212,15 +212,15 @@ Status MatMul::ComputeInternal(ComputeContext& context) const { const auto* bias = context.Input(2); inputs.push_back(bias); } - auto program = CreateMatMulProgram(Activation(), inputs, output_tensor, false); - return context.RunProgram(program); + return ComputeMatMul(&context, Activation(), inputs, output_tensor, false); } -MatMulProgram CreateMatMulProgram(const Activation& activation, std::vector& inputs, Tensor* output_tensor, bool is_channels_last, - SplitKConfig split_k_config, - const TensorShape& input_a_reshape, - const TensorShape& input_b_reshape) { +Status ComputeMatMul(ComputeContext* context, + const Activation& activation, std::vector& inputs, Tensor* output_tensor, bool is_channels_last, + SplitKConfig split_k_config, + const TensorShape& input_a_reshape, + const TensorShape& input_b_reshape) { const auto* a = inputs[0]; const auto* b = inputs[1]; bool has_bias = inputs.size() > 2; @@ -276,10 +276,12 @@ MatMulProgram CreateMatMulProgram(const Activation& activation, std::vector((dim_b_outer + MatMul::MATMUL_PACKED_WORKGROUP_SIZE_X * elements_per_thread[0] - 1) / @@ -299,27 +301,46 @@ MatMulProgram CreateMatMulProgram(const Activation& activation, std::vectorShape(), bias_components); - program.AddInput({bias, ProgramTensorMetadataDependency::Rank, reduced_bias_shape, bias_components}); + matmul_program.AddInput({bias, ProgramTensorMetadataDependency::Rank, reduced_bias_shape, bias_components}); } - return program; + + if (need_split_k) { + // Currently we only support bias in vec4 and channels last format for Split-K MatMul. + assert(is_channels_last); + + MatMulFillBiasOrZeroBeforeSplitKProgram fill_bias_program = + CreateMatMulFillBiasOrZeroBeforeSplitKProgram(bias, output_tensor, a_shape, b_shape); + ORT_RETURN_IF_ERROR(context->RunProgram(fill_bias_program)); + } + + return context->RunProgram(matmul_program); } -MatMulFillBiasBeforeSplitKProgram CreateMatMulFillBiasBeforeSplitKProgram( +MatMulFillBiasOrZeroBeforeSplitKProgram CreateMatMulFillBiasOrZeroBeforeSplitKProgram( const Tensor* bias, Tensor* output, const TensorShape& input_a_shape, @@ -328,7 +349,7 @@ MatMulFillBiasBeforeSplitKProgram CreateMatMulFillBiasBeforeSplitKProgram( // Currently we only support bias in vec4 and channels last format for Split-K MatMul. constexpr uint32_t bias_components = 4; - MatMulFillBiasBeforeSplitKProgram program(has_bias); + MatMulFillBiasOrZeroBeforeSplitKProgram program(has_bias); const TensorShape a_shape = input_a_shape; const TensorShape b_shape = input_b_shape; @@ -339,24 +360,22 @@ MatMulFillBiasBeforeSplitKProgram CreateMatMulFillBiasBeforeSplitKProgram( const uint32_t dim_a_outer = narrow(a_shape[a_shape.NumDimensions() - 2]); const uint32_t dim_b_outer = narrow(b_shape[b_shape.NumDimensions() - 1]); + // Fill one value (currently only vec4) per invocation. 
const uint32_t output_row = dim_a_outer; const uint32_t output_col = dim_b_outer / bias_components; - constexpr uint32_t workgroup_size_x = MatMulFillBiasBeforeSplitKProgram::WORKGROUP_SIZE_X; - constexpr uint32_t workgroup_size_y = MatMulFillBiasBeforeSplitKProgram::WORKGROUP_SIZE_Y; - constexpr uint32_t elements_per_thread = MatMulFillBiasBeforeSplitKProgram::ELEMENTS_PER_THREAD; - const uint32_t dispatch_x = (output_col + workgroup_size_x * elements_per_thread - 1) / (workgroup_size_x * elements_per_thread); - const uint32_t dispatch_y = (output_row + workgroup_size_y - 1) / workgroup_size_y; + constexpr uint32_t workgroup_size_x = MatMulFillBiasOrZeroBeforeSplitKProgram::WORKGROUP_SIZE_X; + const uint32_t total_outputs = output_row * output_col; + const uint32_t dispatch_x = (total_outputs + workgroup_size_x - 1) / workgroup_size_x; // TODO: support batch_size > 1 constexpr uint32_t batch_size = 1; TensorShape output_shape_temp = TensorShape({batch_size, output_row, output_col}); - const uint32_t data_type = output->GetElementType(); - program.CacheHint(has_bias, data_type) - .AddOutput({output, ProgramTensorMetadataDependency::Rank, output_shape_temp, static_cast(bias_components)}) + program.CacheHint(has_bias) + .AddOutput({output, ProgramTensorMetadataDependency::TypeAndRank, output_shape_temp, static_cast(bias_components)}) .AddUniformVariables({{dim_a_outer}, {dim_b_outer}}) - .SetDispatchGroupSize(dispatch_x, dispatch_y, 1) - .SetWorkgroupSize(workgroup_size_x, workgroup_size_y, 1); + .SetDispatchGroupSize(dispatch_x, 1, 1) + .SetWorkgroupSize(workgroup_size_x, 1, 1); if (has_bias) { const TensorShape reduced_bias_shape = ReduceShapeByComponents(bias->Shape(), bias_components); diff --git a/onnxruntime/core/providers/webgpu/math/matmul.h b/onnxruntime/core/providers/webgpu/math/matmul.h index ee9caea84df24..8cf2ad02cd4f0 100644 --- a/onnxruntime/core/providers/webgpu/math/matmul.h +++ b/onnxruntime/core/providers/webgpu/math/matmul.h @@ -34,12 +34,12 @@ class SplitKConfig { float max_dim_a_outer_multiplies_dim_b_outer_divides_dim_inner_ = 0.0f; }; -MatMulProgram CreateMatMulProgram(const Activation& activation, std::vector& inputs, Tensor* output, bool is_channels_last, - SplitKConfig split_k_config = SplitKConfig(), - const TensorShape& input_a_reshape = TensorShape(), - const TensorShape& input_b_reshape = TensorShape()); +Status ComputeMatMul(ComputeContext* context, const Activation& activation, std::vector& inputs, Tensor* output, bool is_channels_last, + SplitKConfig split_k_config = SplitKConfig(), + const TensorShape& input_a_reshape = TensorShape(), + const TensorShape& input_b_reshape = TensorShape()); -MatMulFillBiasBeforeSplitKProgram CreateMatMulFillBiasBeforeSplitKProgram( +MatMulFillBiasOrZeroBeforeSplitKProgram CreateMatMulFillBiasOrZeroBeforeSplitKProgram( const Tensor* bias, Tensor* output, const TensorShape& input_a_shape, diff --git a/onnxruntime/core/providers/webgpu/math/matmul_packed.cc b/onnxruntime/core/providers/webgpu/math/matmul_packed.cc index 89746cb97a6dd..be388f313db27 100644 --- a/onnxruntime/core/providers/webgpu/math/matmul_packed.cc +++ b/onnxruntime/core/providers/webgpu/math/matmul_packed.cc @@ -15,11 +15,12 @@ Status MatMulProgram::GenerateShaderCode(ShaderHelper& shader) const { const auto& a = shader.AddInput("a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); const auto& b = shader.AddInput("b", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | 
ShaderUsage::UseValueTypeAlias); + const bool need_split_k = NeedSplitK(); ShaderUsage output_usage = ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias; - ProgramVariableDataType output_var_type = this->Outputs()[0].var_type; - if (NeedSplitK()) { - // When Split-K is enabled, we should declare output as `atomic` to call atomic built-in functions on it. - output_usage |= ShaderUsage::UseAtomicU32ForSplitK | ShaderUsage::UseIndicesToOffset | ShaderUsage::UseShapeAndStride; + if (need_split_k) { + // When Split-K is enabled, we will declare output as `atomic` to call atomic built-in + // functions on it, so we need below information to correctly compute the index on the output. + output_usage |= ShaderUsage::UseIndicesToOffset | ShaderUsage::UseShapeAndStride; } const auto& output = shader.AddOutput("output", output_usage); @@ -29,16 +30,17 @@ Status MatMulProgram::GenerateShaderCode(ShaderHelper& shader) const { shader.AddInput("bias", ShaderUsage::UseUniform); } std::string apply_activation = GetActivationSnippet(activation_, "output_value_t", "output_element_t"); + ProgramVariableDataType output_var_type = this->Outputs()[0].var_type; // declare the read and write functions MatMulReadFnSource(shader, a, b, &batch_dims, /*transA = */ false, /*transB = */ false, is_vec4_); - MatMulWriteFnSource(shader, output, has_bias_, /* is_gemm = */ false, 1, is_vec4_ ? 4 : 1, false, apply_activation, is_channels_last_, NeedSplitK(), output_var_type); + MatMulWriteFnSource(shader, output, has_bias_, /* is_gemm = */ false, 1, is_vec4_ ? 4 : 1, false, apply_activation, is_channels_last_, need_split_k, output_var_type); std::string data_type = "a_element_t"; // generate the main function if (is_vec4_) { ORT_RETURN_IF_ERROR(MakeMatMulPackedVec4Source( shader, elements_per_thread_, WorkgroupSizeX(), WorkgroupSizeY(), data_type, &batch_dims, /*transA = */ false, /*transB = */ false, /*alpha = */ 1.f, /*need_handle_matmul = */ true, - /*output_components = */ 4, /*tile_inner = */ 32, NeedSplitK(), split_dim_inner_)); + /*output_components = */ 4, /*tile_inner = */ 32, need_split_k, split_dim_inner_)); } else { ORT_RETURN_IF_ERROR(MakeMatMulPackedSource(shader, elements_per_thread_, WorkgroupSizeX(), WorkgroupSizeY(), data_type, &batch_dims)); } @@ -49,7 +51,7 @@ bool MatMulProgram::NeedSplitK() const { return split_dim_inner_ > 1; } -Status MatMulFillBiasBeforeSplitKProgram::GenerateShaderCode(ShaderHelper& shader) const { +Status MatMulFillBiasOrZeroBeforeSplitKProgram::GenerateShaderCode(ShaderHelper& shader) const { const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); if (has_bias_) { @@ -65,24 +67,21 @@ Status MatMulFillBiasBeforeSplitKProgram::GenerateShaderCode(ShaderHelper& shade /*activation_snippet*/ "", /*is_channels_last*/ true, /*use_split_k*/ false); shader.MainFunctionBody() << R"( - let output_components = 4;)"; - shader.MainFunctionBody() << R"( - let elements_per_thread = )" - << ELEMENTS_PER_THREAD - << ";\n"; - shader.MainFunctionBody() << R"( - let global_row = global_id.y; - if (global_row >= uniforms.dim_a_outer) { + let output_components = 4; + let output_id = i32(global_id.x); + + let dim_a_outer = i32(uniforms.dim_a_outer); + let dim_b_outer = i32(uniforms.dim_b_outer) / output_components; + let output_row = output_id / dim_b_outer; + if (output_row >= dim_a_outer) { return; } - let 
dim_b_outer = i32(uniforms.dim_b_outer) / output_components; - let batch = 0; - let row = i32(global_row); - let value = output_value_t(); - let start_col = i32(global_id.x) * elements_per_thread; - for (var col = start_col; col < start_col + elements_per_thread; col++) { - mm_write(batch, row, col, value); - })"; + + let output_col = output_id % dim_b_outer; + let output_batch = 0; + let output_value = output_value_t(); + mm_write(output_batch, output_row, output_col, output_value); +)"; return Status::OK(); } diff --git a/onnxruntime/core/providers/webgpu/math/matmul_packed.h b/onnxruntime/core/providers/webgpu/math/matmul_packed.h index 759a3e08e1081..22fdd460f58e5 100644 --- a/onnxruntime/core/providers/webgpu/math/matmul_packed.h +++ b/onnxruntime/core/providers/webgpu/math/matmul_packed.h @@ -41,10 +41,10 @@ class MatMulProgram final : public Program { // we set the output values with `atomicLoad` and `atomicCompareExchangeWeak` instead of a direct // assignment (see the function `HandleMatMulWithSplitK()` in `gemm_utils.cc`), so we must initialize // the output with 0 or bias first to make sure `atomicLoad` won't return garbage data. -class MatMulFillBiasBeforeSplitKProgram final : public Program { +class MatMulFillBiasOrZeroBeforeSplitKProgram final : public Program { public: - explicit MatMulFillBiasBeforeSplitKProgram(bool has_bias) - : Program{"MatMul_Fill_Bias_Before_Split_K"}, + explicit MatMulFillBiasOrZeroBeforeSplitKProgram(bool has_bias) + : Program{"MatMul_Fill_Bias_Or_Zero_Before_Split_K"}, has_bias_(has_bias) { } @@ -53,9 +53,7 @@ class MatMulFillBiasBeforeSplitKProgram final : public Program::ComputeInternal(ComputeContext& context .AddUniformVariables({{output_size}, {static_cast(matmul_output_shape[1])}, {static_cast(matmul_output_shape[2])}, {static_cast(K)}}); return context.RunProgram(program); } else { - // Explicitly pass `SplitKConfig` to `CreateMatMulProgram()` to enable Split-K. Now it is not - // used in any other places that call `CreateMatMulProgram()` (e.g. in `MatMul::ComputeInternal()`). - // TODO: enable Split-K in all the places that call `CreateMatMulProgram()`. + // Explicitly pass `SplitKConfig` to `ComputeMatMul()` to enable Split-K. Now it is not + // used in any other places that call `ComputeMatMul()` (e.g. in `MatMul::ComputeInternal()`). + // TODO: enable Split-K in all the places that call `ComputeMatMul()`. SplitKConfig split_K_config = SplitKConfig::GetSplitKConfig(context); - MatMulProgram program = CreateMatMulProgram(activation_, matmul_inputs, output, is_channels_last, split_K_config, matmul_input_reshapes[0], matmul_input_reshapes[1]); - if (program.NeedSplitK()) { - // Currently we only support bias in vec4 and channels last format for Split-K MatMul. 
- assert(is_channels_last); - MatMulFillBiasBeforeSplitKProgram fill_bias_program = CreateMatMulFillBiasBeforeSplitKProgram( - bias, output, matmul_input_reshapes[0], matmul_input_reshapes[1]); - ORT_RETURN_IF_ERROR(context.RunProgram(fill_bias_program)); - } - return context.RunProgram(program); + return ComputeMatMul(&context, activation_, matmul_inputs, output, is_channels_last, split_K_config, matmul_input_reshapes[0], matmul_input_reshapes[1]); } } // Transpose weights diff --git a/onnxruntime/core/providers/webgpu/program.cc b/onnxruntime/core/providers/webgpu/program.cc index 4a515cb988f1f..f4fabaabc5627 100644 --- a/onnxruntime/core/providers/webgpu/program.cc +++ b/onnxruntime/core/providers/webgpu/program.cc @@ -313,6 +313,14 @@ ProgramOutput::ProgramOutput(Tensor* tensor, ProgramTensorMetadataDependency dep use_override_shape{true}, override_shape{override_shape} {} +ProgramOutput::ProgramOutput(Tensor* tensor, ProgramTensorMetadataDependency dependency, const TensorShape& override_shape, int component, ProgramOutput::AtomicTag) + : tensor{tensor}, + dependency{dependency}, + var_type{ToProgramVariableDataType(tensor->GetElementType(), component)}, + is_atomic{true}, + use_override_shape{true}, + override_shape{override_shape} {} + ProgramBase::ProgramBase(std::string_view name, ProgramMetadata&& metadata) : name_{name}, metadata_{metadata}, diff --git a/onnxruntime/core/providers/webgpu/program.h b/onnxruntime/core/providers/webgpu/program.h index c8f50837cd8e5..616688c40af0d 100644 --- a/onnxruntime/core/providers/webgpu/program.h +++ b/onnxruntime/core/providers/webgpu/program.h @@ -244,6 +244,7 @@ struct ProgramOutput { ProgramOutput(Tensor* tensor, ProgramTensorMetadataDependency dependency, int component = 1); ProgramOutput(Tensor* tensor, ProgramTensorMetadataDependency dependency, AtomicTag); ProgramOutput(Tensor* tensor, ProgramTensorMetadataDependency dependency, const TensorShape& override_shape, int component); + ProgramOutput(Tensor* tensor, ProgramTensorMetadataDependency dependency, const TensorShape& override_shape, int component, AtomicTag); Tensor* tensor; uint32_t segments = 1; diff --git a/onnxruntime/core/providers/webgpu/shader_helper.cc b/onnxruntime/core/providers/webgpu/shader_helper.cc index d6756feda7376..b048985973b23 100644 --- a/onnxruntime/core/providers/webgpu/shader_helper.cc +++ b/onnxruntime/core/providers/webgpu/shader_helper.cc @@ -465,11 +465,6 @@ Status ShaderHelper::GenerateSourceCode(std::string& code, std::vector& sha for (size_t i = 0; i < output_vars_.size(); ++i) { const auto& output = output_vars_[i]; bool is_atomic = program_.Outputs()[i].is_atomic; - ProgramVariableDataType atomic_type = output->type_; - if (output->usage_ & ShaderUsage::UseAtomicU32ForSplitK) { - is_atomic = true; - atomic_type = ProgramVariableDataType::Uint32; - } uint32_t segments = output->segments_; for (uint32_t seg = 0; seg < segments; ++seg) { ss << "@group(0) @binding(" << binding_index++ << ") var "; @@ -480,14 +475,14 @@ Status ShaderHelper::GenerateSourceCode(std::string& code, std::vector& sha } ss << ": array<"; if (is_atomic) { - if (atomic_type == ProgramVariableDataType::Float32) { + if (output->type_ == ProgramVariableDataType::Float32 || output->type_ == ProgramVariableDataType::Float16x4 || output->type_ == ProgramVariableDataType::Float32x4) { ss << "atomic"; // emulate float atomic via i32 - } else if (atomic_type == ProgramVariableDataType::Uint32) { + } else if (output->type_ == ProgramVariableDataType::Uint32) { ss << "atomic"; - } else if 
(atomic_type == ProgramVariableDataType::Int32) { + } else if (output->type_ == ProgramVariableDataType::Int32) { ss << "atomic"; } else { - ORT_RETURN_IF(true, "Unsupported atomic type: ", int(atomic_type)); + ORT_RETURN_IF(true, "Unsupported atomic type: ", int(output->type_)); } } else { ss << output->StorageType(); diff --git a/onnxruntime/core/providers/webgpu/shader_variable.h b/onnxruntime/core/providers/webgpu/shader_variable.h index 491ab9b9e193b..8e921d6deafbb 100644 --- a/onnxruntime/core/providers/webgpu/shader_variable.h +++ b/onnxruntime/core/providers/webgpu/shader_variable.h @@ -72,7 +72,6 @@ struct ShaderUsage { UseGetByOffsetSegments = 4096, // use implementation of fn get_{name}_by_offset UseSetByOffsetSegments = 8192, // use implementation of fn set_{name}_by_offset UseUniform = 32768, // use uniform for shape and stride - UseAtomicU32ForSplitK = 65536, // use atomic for the output when using Split-K in MatMul } usage; ShaderUsage(decltype(usage) usage) : usage{usage} {} From d03755b52ab76956472a2523649ff5118ecc7097 Mon Sep 17 00:00:00 2001 From: Jiawei Shao Date: Mon, 10 Nov 2025 18:20:15 +0800 Subject: [PATCH 11/22] Disallow out-of-bound write --- onnxruntime/core/providers/webgpu/math/matmul_packed.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/math/matmul_packed.cc b/onnxruntime/core/providers/webgpu/math/matmul_packed.cc index be388f313db27..3741faa5ef346 100644 --- a/onnxruntime/core/providers/webgpu/math/matmul_packed.cc +++ b/onnxruntime/core/providers/webgpu/math/matmul_packed.cc @@ -72,11 +72,11 @@ Status MatMulFillBiasOrZeroBeforeSplitKProgram::GenerateShaderCode(ShaderHelper& let dim_a_outer = i32(uniforms.dim_a_outer); let dim_b_outer = i32(uniforms.dim_b_outer) / output_components; - let output_row = output_id / dim_b_outer; - if (output_row >= dim_a_outer) { + if (output_id >= dim_a_outer * dim_b_outer) { return; } + let output_row = output_id / dim_b_outer; let output_col = output_id % dim_b_outer; let output_batch = 0; let output_value = output_value_t(); From 581828e6d881a02975314e114abd0c9deb14e7e3 Mon Sep 17 00:00:00 2001 From: Jiawei Shao Date: Wed, 12 Nov 2025 13:05:31 +0800 Subject: [PATCH 12/22] Use safer thresholds by now --- onnxruntime/core/providers/webgpu/math/matmul.cc | 13 +++++++++---- onnxruntime/core/providers/webgpu/math/matmul.h | 1 + 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/math/matmul.cc b/onnxruntime/core/providers/webgpu/math/matmul.cc index 85a5004d72c5c..8800f4775b4dd 100644 --- a/onnxruntime/core/providers/webgpu/math/matmul.cc +++ b/onnxruntime/core/providers/webgpu/math/matmul.cc @@ -57,10 +57,14 @@ SplitKConfig SplitKConfig::GetSplitKConfig(const ComputeContext& context) { adapter_info.architecture == std::string_view{"gen-12hp"}) { config.enable_split_k_ = true; - // Below thresholds are only verified on the above Intel GPUs. + // Below thresholds are only verified on the above Intel GPUs without any regressions. The + // proper value of `max_dim_a_outer_multiplies_dim_b_outer_divides_dim_inner_` may be + // reduced when we support a larger `dim_inner` because larger `dim_inner` will bring more + // atomic calls for each output value. 
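Taken together with the thresholds that follow, the selection rule amounts to a small predicate like the sketch below (illustrative only; `ShouldUseSplitK` is a made-up name, the numeric values are the ones this patch tunes on the listed Intel GPUs, and the prerequisite checks -- vec4, batch size 1, channels-last, no fused activation -- are omitted here):

    #include <cstdint>
    #include <cstdio>

    // Sketch of the Split-K selection heuristic described above.
    bool ShouldUseSplitK(uint32_t dim_a_outer, uint32_t dim_b_outer, uint32_t dim_inner) {
      constexpr uint32_t kSplitDimInner = 256;
      constexpr uint32_t kMinDimInner = kSplitDimInner * 2;  // 512
      constexpr uint32_t kMaxDimInner = kSplitDimInner * 9;  // 2304
      constexpr float kMaxOuterProductOverInner = 35.0f;

      if (dim_inner < kMinDimInner || dim_inner > kMaxDimInner) {
        return false;
      }
      // Split-K pays off when the output (dim_a_outer * dim_b_outer) is small
      // relative to the reduction dimension dim_inner.
      const float ratio = static_cast<float>(dim_a_outer) * static_cast<float>(dim_b_outer) /
                          static_cast<float>(dim_inner);
      return ratio <= kMaxOuterProductOverInner;
    }

    int main() {
      std::printf("%d\n", ShouldUseSplitK(49, 256, 1024));    // 1: small output, large K
      std::printf("%d\n", ShouldUseSplitK(4096, 4096, 1024)); // 0: output dominates
    }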
config.split_dim_inner_ = 256; config.min_dim_inner_with_split_k_ = config.split_dim_inner_ * 2; - config.max_dim_a_outer_multiplies_dim_b_outer_divides_dim_inner_ = 54.0f; + config.max_dim_inner_with_split_k_ = config.split_dim_inner_ * 9; + config.max_dim_a_outer_multiplies_dim_b_outer_divides_dim_inner_ = 35.0f; } } return config; @@ -85,9 +89,10 @@ bool SplitKConfig::UseSplitK( use_split_k &= is_channels_last; // Split-K works best when `dim_inner` is relatively large compared with `dim_a_outer` and - // `dim_b_outer`. Currently we use `(dim_a_outer * dim_b_outer * 1.0f / dim_inner)` as the - // metric to decide whether to use Split-K or not. + // `dim_b_outer`. Currently we use the factor between `(dim_a_outer * dim_b_outer)` and + // `dim_inner)` as the metric to decide whether to use Split-K or not. use_split_k &= (dim_inner >= min_dim_inner_with_split_k_); + use_split_k &= (dim_inner <= max_dim_inner_with_split_k_); use_split_k &= ((dim_a_outer * dim_b_outer * 1.0f / dim_inner) <= max_dim_a_outer_multiplies_dim_b_outer_divides_dim_inner_); return use_split_k; diff --git a/onnxruntime/core/providers/webgpu/math/matmul.h b/onnxruntime/core/providers/webgpu/math/matmul.h index 8cf2ad02cd4f0..d76527c4e2453 100644 --- a/onnxruntime/core/providers/webgpu/math/matmul.h +++ b/onnxruntime/core/providers/webgpu/math/matmul.h @@ -31,6 +31,7 @@ class SplitKConfig { bool enable_split_k_ = false; uint32_t split_dim_inner_ = 0; uint32_t min_dim_inner_with_split_k_ = 0; + uint32_t max_dim_inner_with_split_k_ = 0; float max_dim_a_outer_multiplies_dim_b_outer_divides_dim_inner_ = 0.0f; }; From 22f9017e3fd61cbd60bdaefd1cb7d8899ba02a9d Mon Sep 17 00:00:00 2001 From: Jiawei Shao Date: Thu, 13 Nov 2025 09:15:29 +0800 Subject: [PATCH 13/22] Address more reviewer's comments - Move the query of `SplitKConfig` into `ComputeMatMul()`. It's safe because in `MatMul::ComputeInternal()` `is_channels_last` is always false, while currently `Split-K` only supports `is_channels_last` being true. - Add a comment about avoiding the use of `global_id` or `global_idx` - Directly pass the temporary output shape in FillBiasOrZeroProgram - Merge multiple `if(needs_split_k)` into one in `ComputeMatMul()` - Use `global_idx` instead of `global_id.x` --- .../core/providers/webgpu/math/matmul.cc | 100 ++++++++---------- .../core/providers/webgpu/math/matmul.h | 6 +- .../providers/webgpu/math/matmul_packed.cc | 2 +- onnxruntime/core/providers/webgpu/nn/conv.cc | 6 +- onnxruntime/core/providers/webgpu/program.cc | 8 -- onnxruntime/core/providers/webgpu/program.h | 1 - 6 files changed, 47 insertions(+), 76 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/math/matmul.cc b/onnxruntime/core/providers/webgpu/math/matmul.cc index 8800f4775b4dd..60a3a323d9977 100644 --- a/onnxruntime/core/providers/webgpu/math/matmul.cc +++ b/onnxruntime/core/providers/webgpu/math/matmul.cc @@ -223,7 +223,6 @@ Status MatMul::ComputeInternal(ComputeContext& context) const { Status ComputeMatMul(ComputeContext* context, const Activation& activation, std::vector& inputs, Tensor* output_tensor, bool is_channels_last, - SplitKConfig split_k_config, const TensorShape& input_a_reshape, const TensorShape& input_b_reshape) { const auto* a = inputs[0]; @@ -279,33 +278,51 @@ Status ComputeMatMul(ComputeContext* context, ? 
InlinedVector({4, 1, 1}) : InlinedVector({4, 4, 1}); - bool need_split_k = split_k_config.UseSplitK(is_vec4, activation.activation_kind_, batch_size, is_channels_last, dim_a_outer, dim_b_outer, dim_inner); - - bool use_bias_in_matmul = has_bias; - - // When Split-K is used, bias will be handled in `MatMulFillBiasBeforeSplitKProgram` - // instead of `MatMulProgram`. - if (need_split_k) { - use_bias_in_matmul = false; - } - const uint32_t dispatch_x = narrow((dim_b_outer + MatMul::MATMUL_PACKED_WORKGROUP_SIZE_X * elements_per_thread[0] - 1) / (MatMul::MATMUL_PACKED_WORKGROUP_SIZE_X * elements_per_thread[0])); const uint32_t dispatch_y = narrow((dim_a_outer + MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Y * elements_per_thread[1] - 1) / (MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Y * elements_per_thread[1])); uint32_t dispatch_z = narrow((static_cast(batch_size) + MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Z * elements_per_thread[2] - 1) / (MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Z * elements_per_thread[2])); - uint32_t split_dim_inner = 1; - if (need_split_k) { - split_dim_inner = split_k_config.GetSplitDimInner(); - dispatch_z = (dim_inner + split_dim_inner - 1) / split_dim_inner; - } const int components = is_vec4 ? 4 : 1; const TensorShape a_shape_temp = CreateMatMulIntermediateShape(outer_dims_a, dim_a_outer, dim_inner, components); const TensorShape b_shape_temp = CreateMatMulIntermediateShape(outer_dims_b, dim_inner, dim_b_outer, components); const TensorShape output_shape_temp = TensorShape({batch_size, dim_a_outer, dim_b_outer / components}); + ProgramOutput output(output_tensor, ProgramTensorMetadataDependency::Rank, output_shape_temp, components); + const Tensor* bias = has_bias ? inputs[2] : nullptr; + bool use_bias_in_matmul = has_bias; + uint32_t split_dim_inner = 1; + + const SplitKConfig& split_k_config = SplitKConfig::GetSplitKConfig(*context); + const bool need_split_k = split_k_config.UseSplitK(is_vec4, activation.activation_kind_, batch_size, is_channels_last, dim_a_outer, dim_b_outer, dim_inner); + if (need_split_k) { + // Currently we only support `batch_size==1`, bias in vec4 and channels-last format for Split-K MatMul. + assert(batch_size == 1); + assert(is_vec4); + assert(is_channels_last); + + // Initialize `output_tensor` with 0 or bias before MatMulProgram with Split-K enabled. + const auto fill_bias_program = CreateMatMulFillBiasOrZeroBeforeSplitKProgram(bias, output_tensor, output_shape_temp); + ORT_RETURN_IF_ERROR(context->RunProgram(fill_bias_program)); + + // `bias` has been handled in the execution of `fill_bias_program` so we don't need to set + // `bias` again in `MatMulProgram`. + use_bias_in_matmul = false; + + // With Split-K, `dim_inner` will be split into multiple parts and `dispatch_z` will be the + // number of splits along `dim_inner`. + // TODO: avoid using `global_id.xxx` or `workgroup_id.xxx` in `MatMulProgram` when we normalize + // the dispatch size with `ProgramManager::NormalizeDispatchGroupSize()` for `MatMulProgram`. + split_dim_inner = split_k_config.GetSplitDimInner(); + dispatch_z = (dim_inner + split_dim_inner - 1) / split_dim_inner; + + // The output should be declared in atomic types in `MatMulProgram` for the use of atomic + // built-in functions. 
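A host-side sketch of the slicing described above may help (illustrative only, not part of the patch; the dimensions are made-up sample values): each value of `global_id.z` reduces one slice of `dim_inner` and atomically accumulates its partial sum into the shared output.

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    int main() {
      const uint32_t dim_inner = 1152;
      const uint32_t split_dim_inner = 256;
      // Same ceil-division as the dispatch_z computation above.
      const uint32_t dispatch_z = (dim_inner + split_dim_inner - 1) / split_dim_inner;  // 5

      for (uint32_t z = 0; z < dispatch_z; ++z) {
        const uint32_t k_start = split_dim_inner * z;
        const uint32_t k_end = std::min(k_start + split_dim_inner, dim_inner);
        std::printf("z=%u reduces k in [%u, %u)\n", z, k_start, k_end);
      }
    }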
+ output.is_atomic = true; + } + MatMulProgram matmul_program{activation, use_bias_in_matmul, is_vec4, elements_per_thread, is_channels_last, split_dim_inner}; matmul_program .CacheHint(activation.ToString(), absl::StrJoin(elements_per_thread, "-"), std::to_string(is_vec4), components, is_channels_last, split_dim_inner) @@ -314,18 +331,8 @@ Status ComputeMatMul(ComputeContext* context, .AddUniformVariables({{dim_a_outer}, {dim_b_outer}, {dim_inner}}) .AddIndices(outer_dims) .SetDispatchGroupSize(dispatch_x, dispatch_y, dispatch_z) - .SetWorkgroupSize(MatMul::MATMUL_PACKED_WORKGROUP_SIZE_X, MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Y, MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Z); - - if (need_split_k) { - matmul_program.AddOutputs({{output_tensor, ProgramTensorMetadataDependency::Rank, output_shape_temp, components, ProgramOutput::Atomic}}); - } else { - matmul_program.AddOutputs({{output_tensor, ProgramTensorMetadataDependency::Rank, output_shape_temp, components}}); - } - - const Tensor* bias = nullptr; - if (has_bias) { - bias = inputs[2]; - } + .SetWorkgroupSize(MatMul::MATMUL_PACKED_WORKGROUP_SIZE_X, MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Y, MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Z) + .AddOutput(std::move(output)); if (use_bias_in_matmul) { auto bias_components = is_channels_last ? components : 1; @@ -333,51 +340,32 @@ Status ComputeMatMul(ComputeContext* context, matmul_program.AddInput({bias, ProgramTensorMetadataDependency::Rank, reduced_bias_shape, bias_components}); } - if (need_split_k) { - // Currently we only support bias in vec4 and channels last format for Split-K MatMul. - assert(is_channels_last); - - MatMulFillBiasOrZeroBeforeSplitKProgram fill_bias_program = - CreateMatMulFillBiasOrZeroBeforeSplitKProgram(bias, output_tensor, a_shape, b_shape); - ORT_RETURN_IF_ERROR(context->RunProgram(fill_bias_program)); - } - return context->RunProgram(matmul_program); } MatMulFillBiasOrZeroBeforeSplitKProgram CreateMatMulFillBiasOrZeroBeforeSplitKProgram( const Tensor* bias, Tensor* output, - const TensorShape& input_a_shape, - const TensorShape& input_b_shape) { + const TensorShape& output_shape_vec4) { const bool has_bias = bias != nullptr; // Currently we only support bias in vec4 and channels last format for Split-K MatMul. constexpr uint32_t bias_components = 4; MatMulFillBiasOrZeroBeforeSplitKProgram program(has_bias); - const TensorShape a_shape = input_a_shape; - const TensorShape b_shape = input_b_shape; - MatMulComputeHelper helper; - ORT_THROW_IF_ERROR(helper.Compute(a_shape, b_shape)); - TensorShape output_shape = helper.OutputShape(); - - const uint32_t dim_a_outer = narrow(a_shape[a_shape.NumDimensions() - 2]); - const uint32_t dim_b_outer = narrow(b_shape[b_shape.NumDimensions() - 1]); + const uint32_t dim_a_outer = narrow(output_shape_vec4[output_shape_vec4.NumDimensions() - 2]); + const uint32_t dim_b_outer_vec4 = narrow(output_shape_vec4[output_shape_vec4.NumDimensions() - 1]); // Fill one value (currently only vec4) per invocation. 
- const uint32_t output_row = dim_a_outer; - const uint32_t output_col = dim_b_outer / bias_components; constexpr uint32_t workgroup_size_x = MatMulFillBiasOrZeroBeforeSplitKProgram::WORKGROUP_SIZE_X; - const uint32_t total_outputs = output_row * output_col; - const uint32_t dispatch_x = (total_outputs + workgroup_size_x - 1) / workgroup_size_x; - - // TODO: support batch_size > 1 - constexpr uint32_t batch_size = 1; - TensorShape output_shape_temp = TensorShape({batch_size, output_row, output_col}); + const uint32_t total_outputs_vec4 = dim_a_outer * dim_b_outer_vec4; + const uint32_t dispatch_x = (total_outputs_vec4 + workgroup_size_x - 1) / workgroup_size_x; + // To reuse `MatMulWriteFnSource()` we need to set `dim_a_outer` and `dim_b_outer` in scalar + // instead of vec4, while use `output_shape_vec4` directly as the output shape. + const uint32_t dim_b_outer = narrow(dim_b_outer_vec4 * bias_components); program.CacheHint(has_bias) - .AddOutput({output, ProgramTensorMetadataDependency::TypeAndRank, output_shape_temp, static_cast(bias_components)}) + .AddOutput({output, ProgramTensorMetadataDependency::TypeAndRank, output_shape_vec4, static_cast(bias_components)}) .AddUniformVariables({{dim_a_outer}, {dim_b_outer}}) .SetDispatchGroupSize(dispatch_x, 1, 1) .SetWorkgroupSize(workgroup_size_x, 1, 1); diff --git a/onnxruntime/core/providers/webgpu/math/matmul.h b/onnxruntime/core/providers/webgpu/math/matmul.h index d76527c4e2453..2c9fb33850c8b 100644 --- a/onnxruntime/core/providers/webgpu/math/matmul.h +++ b/onnxruntime/core/providers/webgpu/math/matmul.h @@ -16,8 +16,6 @@ namespace webgpu { class SplitKConfig { public: - SplitKConfig() = default; - static SplitKConfig GetSplitKConfig(const ComputeContext& context); bool UseSplitK( @@ -36,15 +34,13 @@ class SplitKConfig { }; Status ComputeMatMul(ComputeContext* context, const Activation& activation, std::vector& inputs, Tensor* output, bool is_channels_last, - SplitKConfig split_k_config = SplitKConfig(), const TensorShape& input_a_reshape = TensorShape(), const TensorShape& input_b_reshape = TensorShape()); MatMulFillBiasOrZeroBeforeSplitKProgram CreateMatMulFillBiasOrZeroBeforeSplitKProgram( const Tensor* bias, Tensor* output, - const TensorShape& input_a_shape, - const TensorShape& input_b_shape); + const TensorShape& output_shape_vec4); class MatMul final : public WebGpuKernel { public: diff --git a/onnxruntime/core/providers/webgpu/math/matmul_packed.cc b/onnxruntime/core/providers/webgpu/math/matmul_packed.cc index 3741faa5ef346..4daabe8246aa7 100644 --- a/onnxruntime/core/providers/webgpu/math/matmul_packed.cc +++ b/onnxruntime/core/providers/webgpu/math/matmul_packed.cc @@ -68,7 +68,7 @@ Status MatMulFillBiasOrZeroBeforeSplitKProgram::GenerateShaderCode(ShaderHelper& shader.MainFunctionBody() << R"( let output_components = 4; - let output_id = i32(global_id.x); + let output_id = i32(global_idx); let dim_a_outer = i32(uniforms.dim_a_outer); let dim_b_outer = i32(uniforms.dim_b_outer) / output_components; diff --git a/onnxruntime/core/providers/webgpu/nn/conv.cc b/onnxruntime/core/providers/webgpu/nn/conv.cc index f7ba93cc120e2..77fa46cb87518 100644 --- a/onnxruntime/core/providers/webgpu/nn/conv.cc +++ b/onnxruntime/core/providers/webgpu/nn/conv.cc @@ -200,11 +200,7 @@ Status Conv::ComputeInternal(ComputeContext& context .AddUniformVariables({{output_size}, {static_cast(matmul_output_shape[1])}, {static_cast(matmul_output_shape[2])}, {static_cast(K)}}); return context.RunProgram(program); } else { - // Explicitly pass 
`SplitKConfig` to `ComputeMatMul()` to enable Split-K. Now it is not - // used in any other places that call `ComputeMatMul()` (e.g. in `MatMul::ComputeInternal()`). - // TODO: enable Split-K in all the places that call `ComputeMatMul()`. - SplitKConfig split_K_config = SplitKConfig::GetSplitKConfig(context); - return ComputeMatMul(&context, activation_, matmul_inputs, output, is_channels_last, split_K_config, matmul_input_reshapes[0], matmul_input_reshapes[1]); + return ComputeMatMul(&context, activation_, matmul_inputs, output, is_channels_last, matmul_input_reshapes[0], matmul_input_reshapes[1]); } } // Transpose weights diff --git a/onnxruntime/core/providers/webgpu/program.cc b/onnxruntime/core/providers/webgpu/program.cc index 55e7eb4d830a7..2c1b70222a5f6 100644 --- a/onnxruntime/core/providers/webgpu/program.cc +++ b/onnxruntime/core/providers/webgpu/program.cc @@ -313,14 +313,6 @@ ProgramOutput::ProgramOutput(Tensor* tensor, ProgramTensorMetadataDependency dep use_override_shape{true}, override_shape{override_shape} {} -ProgramOutput::ProgramOutput(Tensor* tensor, ProgramTensorMetadataDependency dependency, const TensorShape& override_shape, int component, ProgramOutput::AtomicTag) - : tensor{tensor}, - dependency{dependency}, - var_type{ToProgramVariableDataType(tensor->GetElementType(), component)}, - is_atomic{true}, - use_override_shape{true}, - override_shape{override_shape} {} - ProgramBase::ProgramBase(std::string_view name, ProgramMetadata&& metadata) : name_{name}, metadata_{metadata}, diff --git a/onnxruntime/core/providers/webgpu/program.h b/onnxruntime/core/providers/webgpu/program.h index c522b2fde00ef..80f6d831d0909 100644 --- a/onnxruntime/core/providers/webgpu/program.h +++ b/onnxruntime/core/providers/webgpu/program.h @@ -243,7 +243,6 @@ struct ProgramOutput { ProgramOutput(Tensor* tensor, ProgramTensorMetadataDependency dependency, int component = 1); ProgramOutput(Tensor* tensor, ProgramTensorMetadataDependency dependency, AtomicTag); ProgramOutput(Tensor* tensor, ProgramTensorMetadataDependency dependency, const TensorShape& override_shape, int component); - ProgramOutput(Tensor* tensor, ProgramTensorMetadataDependency dependency, const TensorShape& override_shape, int component, AtomicTag); Tensor* tensor; ProgramTensorMetadataDependency dependency; From 418d6c0c518ae3370be8cbea9dfc373da635886e Mon Sep 17 00:00:00 2001 From: Jiawei Shao Date: Fri, 14 Nov 2025 11:01:06 +0800 Subject: [PATCH 14/22] Address comments from Copilot --- .../core/providers/webgpu/math/gemm_utils.cc | 26 +++++++++---------- .../core/providers/webgpu/math/matmul.cc | 7 +++-- 2 files changed, 16 insertions(+), 17 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/math/gemm_utils.cc b/onnxruntime/core/providers/webgpu/math/gemm_utils.cc index 093a5a6327020..ae61019fac82b 100644 --- a/onnxruntime/core/providers/webgpu/math/gemm_utils.cc +++ b/onnxruntime/core/providers/webgpu/math/gemm_utils.cc @@ -81,8 +81,8 @@ void HandleMatMulWithSplitK( let old_output_f32 = bitcast(old_output_i32); let new_output_f32 = old_output_f32 + value[i]; let new_output_i32 = bitcast(new_output_f32); - let output_compexchange = atomicCompareExchangeWeak(&output[offset], old_output_i32, new_output_i32); - if (output_compexchange.old_value == old_output_i32) { + let output_compare_exchange = atomicCompareExchangeWeak(&output[offset], old_output_i32, new_output_i32); + if (output_compare_exchange.old_value == old_output_i32) { break; } } @@ -97,14 +97,14 @@ void HandleMatMulWithSplitK( vec2h_values[0] = 
value.xy; vec2h_values[1] = value.zw; for (var i = 0u; i < 2u; i++) { - let offset= offset0 + i; - while(true) { + let offset = offset0 + i; + while (true) { let old_output_i32 = atomicLoad(&output[offset]); let old_output_vec2h = bitcast(old_output_i32); let new_output_vec2h = old_output_vec2h + vec2h_values[i]; let new_output_i32 = bitcast(new_output_vec2h); - let output_compexchange = atomicCompareExchangeWeak(&output[offset], old_output_i32, new_output_i32); - if (output_compexchange.old_value == old_output_i32) { + let output_compare_exchange = atomicCompareExchangeWeak(&output[offset], old_output_i32, new_output_i32); + if (output_compare_exchange.old_value == old_output_i32) { break; } } @@ -202,11 +202,11 @@ void MatMulWriteFnSource(ShaderHelper& shader, if (use_split_k) { // Set output when MatMul is performed with Split-K. - // When split-k is used in MatMul, the bias will be handled in `MatMulFillBiasBeforeSplitKProgram` + // When Split-K is used in MatMul, the bias will be handled in `MatMulFillBiasOrZeroBeforeSplitKProgram` // instead of here, so `has_bias` and `is_channels_last` is not used for Split-K. Note that we // still need to handle `has_bias` (and `is_channels_last` in the future) in - // `MatMulFillBiasBeforeSplitKProgram`. - assert(!has_bias); + // `MatMulFillBiasOrZeroBeforeSplitKProgram`. + ORT_ENFORCE(!has_bias, "Bias is not supported in MatMulProgram when Split-K is enabled."); HandleMatMulWithSplitK(shader, output_variable_type); } else if (is_gemm) { HanldeMaybeHaveBiasForGEMM(shader, output, has_bias, c_components, output_components, c_is_scalar); @@ -289,7 +289,7 @@ Status MakeMatMulPackedVec4Source(ShaderHelper& shader, // `kSplitK * i32(global_id.z)`. // // For example: considering computing Y = (X * W + B) in one workgroup. - // Let kSplitk = 2, B = [d1, d2] + // Let kSplitK = 2, B = [d1, d2] // Let X = [[a1 a1 b1 b1 c1 c1] = [ A1 B1 C1 ], W = [[a2 a2] = [ A2 // [a1 a1 b1 b1 c1 c1]] [a2 a2] B2 // [b2 b2] C2 ] @@ -298,8 +298,8 @@ Status MakeMatMulPackedVec4Source(ShaderHelper& shader, // [c2 c2]] // // With Split-K: - // 1. Initialize output Y with B in `MatMulFillBiasBeforeSplitKProgram`: Y = [[d1, d2] - // [d1, d2]] + // 1. Initialize output Y with B in `MatMulFillBiasOrZeroBeforeSplitKProgram`: Y = [[d1, d2] + // [d1, d2]] // 2. Split the original 1 workgroup into 3 workgroups (now `dispatch_z = 3` in API side) // Workgroup1: compute (A1 * A2) Workgroup2: compute (B1 * B2) // Workgroup3: compute (C1 * C2) @@ -310,7 +310,7 @@ Status MakeMatMulPackedVec4Source(ShaderHelper& shader, shader.MainFunctionBody() << "const kSplitK = " << split_dim_inner << ";\n" << " let num_tiles = (kSplitK - 1) / tileInner + 1;\n" - << " var kStart = kSplitK * i32(global_id.z);\n" + << " var kStart = kSplitK * i32(global_id.z);\n" // When Split-K is used, `batch` should always be 0 and `global_id.z` is used to indicate // the index of split-k instead of batch. 
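The decomposition in the worked example above can be checked numerically with a short stand-alone program (illustrative only, not part of the patch): initialize the output with the bias, let each k-slice add its partial product, and compare against a single-pass matmul. Integer-valued floats keep the comparison exact; on the GPU the atomic accumulation order is nondeterministic, so floating-point results may differ in the last bits.

    #include <cstdio>
    #include <vector>

    int main() {
      const int M = 2, N = 2, K = 6, kSplitK = 2;
      const std::vector<float> X = {1, 1, 2, 2, 3, 3,   // M x K, row-major
                                    1, 1, 2, 2, 3, 3};
      const std::vector<float> W = {1, 1, 1, 1, 2, 2,   // K x N, row-major
                                    2, 2, 3, 3, 3, 3};
      const std::vector<float> B = {4, 5};              // per-column bias

      // Reference: one pass over the full K.
      std::vector<float> ref(M * N);
      for (int m = 0; m < M; ++m) {
        for (int n = 0; n < N; ++n) {
          float sum = B[n];
          for (int k = 0; k < K; ++k) {
            sum += X[m * K + k] * W[k * N + n];
          }
          ref[m * N + n] = sum;
        }
      }

      // Split-K: step 1 fills the output with the bias (or zero), step 2 lets
      // each k-slice ("workgroup") accumulate its partial product.
      std::vector<float> out(M * N);
      for (int m = 0; m < M; ++m) {
        for (int n = 0; n < N; ++n) {
          out[m * N + n] = B[n];
        }
      }
      for (int k_start = 0; k_start < K; k_start += kSplitK) {
        for (int m = 0; m < M; ++m) {
          for (int n = 0; n < N; ++n) {
            float partial = 0.0f;
            for (int k = k_start; k < k_start + kSplitK; ++k) {
              partial += X[m * K + k] * W[k * N + n];
            }
            out[m * N + n] += partial;  // the shader does this with the atomic loop
          }
        }
      }

      std::printf("match: %d\n", ref == out);  // match: 1
    }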
diff --git a/onnxruntime/core/providers/webgpu/math/matmul.cc b/onnxruntime/core/providers/webgpu/math/matmul.cc index 60a3a323d9977..83785de83c66b 100644 --- a/onnxruntime/core/providers/webgpu/math/matmul.cc +++ b/onnxruntime/core/providers/webgpu/math/matmul.cc @@ -298,10 +298,9 @@ Status ComputeMatMul(ComputeContext* context, const SplitKConfig& split_k_config = SplitKConfig::GetSplitKConfig(*context); const bool need_split_k = split_k_config.UseSplitK(is_vec4, activation.activation_kind_, batch_size, is_channels_last, dim_a_outer, dim_b_outer, dim_inner); if (need_split_k) { - // Currently we only support `batch_size==1`, bias in vec4 and channels-last format for Split-K MatMul. - assert(batch_size == 1); - assert(is_vec4); - assert(is_channels_last); + ORT_ENFORCE(batch_size == 1, "Split-K MatMul only supports batch_size == 1."); + ORT_ENFORCE(is_vec4, "Split-K MatMul only supports bias in vec4 format."); + ORT_ENFORCE(is_channels_last, "Split-K MatMul only supports channels-last format."); // Initialize `output_tensor` with 0 or bias before MatMulProgram with Split-K enabled. const auto fill_bias_program = CreateMatMulFillBiasOrZeroBeforeSplitKProgram(bias, output_tensor, output_shape_temp); From 0ca5e656eec97aeac8a382d7e00cc7ce5264753b Mon Sep 17 00:00:00 2001 From: Jiawei Shao Date: Fri, 14 Nov 2025 11:16:17 +0800 Subject: [PATCH 15/22] Address more comments from Copilot --- onnxruntime/core/providers/webgpu/math/gemm_utils.cc | 1 + onnxruntime/core/providers/webgpu/math/matmul.cc | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/webgpu/math/gemm_utils.cc b/onnxruntime/core/providers/webgpu/math/gemm_utils.cc index ae61019fac82b..dbb648788231e 100644 --- a/onnxruntime/core/providers/webgpu/math/gemm_utils.cc +++ b/onnxruntime/core/providers/webgpu/math/gemm_utils.cc @@ -207,6 +207,7 @@ void MatMulWriteFnSource(ShaderHelper& shader, // still need to handle `has_bias` (and `is_channels_last` in the future) in // `MatMulFillBiasOrZeroBeforeSplitKProgram`. ORT_ENFORCE(!has_bias, "Bias is not supported in MatMulProgram when Split-K is enabled."); + ORT_ENFORCE(is_channels_last, "Only channels-last is supported in MatMulProgram when Split-K is enabled."); HandleMatMulWithSplitK(shader, output_variable_type); } else if (is_gemm) { HanldeMaybeHaveBiasForGEMM(shader, output, has_bias, c_components, output_components, c_is_scalar); diff --git a/onnxruntime/core/providers/webgpu/math/matmul.cc b/onnxruntime/core/providers/webgpu/math/matmul.cc index 83785de83c66b..bfe4b65d32dc8 100644 --- a/onnxruntime/core/providers/webgpu/math/matmul.cc +++ b/onnxruntime/core/providers/webgpu/math/matmul.cc @@ -85,7 +85,7 @@ bool SplitKConfig::UseSplitK( use_split_k &= is_vec4; use_split_k &= batch_size == 1; // Now `is_channels_last` is only supported because we only generate vec4 shaders in - // `MatMulFillBiasBeforeSplitKProgram`. + // `MatMulFillBiasOrZeroBeforeSplitKProgram`. 
use_split_k &= is_channels_last; // Split-K works best when `dim_inner` is relatively large compared with `dim_a_outer` and From 7a415de7a1e83b2fd8995d75fd2131650ff20d1c Mon Sep 17 00:00:00 2001 From: Jiawei Shao Date: Fri, 14 Nov 2025 11:19:00 +0800 Subject: [PATCH 16/22] Address more comments from Copilot --- onnxruntime/core/providers/webgpu/math/gemm_utils.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/math/gemm_utils.cc b/onnxruntime/core/providers/webgpu/math/gemm_utils.cc index dbb648788231e..7cbc7f6a4a821 100644 --- a/onnxruntime/core/providers/webgpu/math/gemm_utils.cc +++ b/onnxruntime/core/providers/webgpu/math/gemm_utils.cc @@ -13,7 +13,7 @@ namespace webgpu { // which are used in the MatMulWriteFnSource function. namespace { -void HanldeMaybeHaveBiasForGEMM(ShaderHelper& shader, +void HandleMaybeHaveBiasForGEMM(ShaderHelper& shader, const ShaderVariableHelper& output, bool has_bias, int c_components, @@ -210,7 +210,7 @@ void MatMulWriteFnSource(ShaderHelper& shader, ORT_ENFORCE(is_channels_last, "Only channels-last is supported in MatMulProgram when Split-K is enabled."); HandleMatMulWithSplitK(shader, output_variable_type); } else if (is_gemm) { - HanldeMaybeHaveBiasForGEMM(shader, output, has_bias, c_components, output_components, c_is_scalar); + HandleMaybeHaveBiasForGEMM(shader, output, has_bias, c_components, output_components, c_is_scalar); } else { HandleMaybeBiasForMatMul(shader, output, has_bias, activation_snippet, is_channels_last); } From 13b94e86826a8940fd453b15793a79084fcd4cb8 Mon Sep 17 00:00:00 2001 From: Jiawei Shao Date: Fri, 14 Nov 2025 13:25:28 +0800 Subject: [PATCH 17/22] Address reviewer's comments - Cache `SplitKConfig` to `WebGPUContext` and only initialize it once - Early return false when `enable_split_k_` is false - Use `WORKGROUP_SIZE` (in Program.h) instead of declaring another one --- .../core/providers/webgpu/compute_context.cc | 4 ++ .../core/providers/webgpu/compute_context.h | 8 +++ .../core/providers/webgpu/math/matmul.cc | 65 ++----------------- .../core/providers/webgpu/math/matmul.h | 19 ------ .../providers/webgpu/math/matmul_packed.h | 2 - .../core/providers/webgpu/nn/fuse_utils.cc | 1 + .../core/providers/webgpu/nn/fuse_utils.h | 8 ++- .../core/providers/webgpu/webgpu_context.cc | 7 ++ .../core/providers/webgpu/webgpu_context.h | 11 ++++ .../core/providers/webgpu/webgpu_utils.cc | 59 +++++++++++++++++ .../core/providers/webgpu/webgpu_utils.h | 21 ++++++ 11 files changed, 122 insertions(+), 83 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/compute_context.cc b/onnxruntime/core/providers/webgpu/compute_context.cc index 25caa9b954fc0..514f00a66a080 100644 --- a/onnxruntime/core/providers/webgpu/compute_context.cc +++ b/onnxruntime/core/providers/webgpu/compute_context.cc @@ -34,5 +34,9 @@ const webgpu::BufferManager& ComputeContext::BufferManager() const { return ep_.BufferManager(); } +const SplitKConfig& ComputeContext::GetSplitKConfig() { + return webgpu_context_.GetSplitKConfig(); +} + } // namespace webgpu } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/compute_context.h b/onnxruntime/core/providers/webgpu/compute_context.h index 6bf7df74ea861..a4613fc740d95 100644 --- a/onnxruntime/core/providers/webgpu/compute_context.h +++ b/onnxruntime/core/providers/webgpu/compute_context.h @@ -5,6 +5,7 @@ #include "core/providers/webgpu/webgpu_external_header.h" +#include #include #include "core/framework/execution_provider.h" @@ -146,6 +147,13 @@ 
class ComputeContext { // Status PopErrorScope(); + // + // Get Split-K configuration. + // + // `split_k_config_` won't be initialized until the first call to this method. + // + const SplitKConfig& GetSplitKConfig(); + protected: WebGpuContext& webgpu_context_; OpKernelContext& kernel_context_; diff --git a/onnxruntime/core/providers/webgpu/math/matmul.cc b/onnxruntime/core/providers/webgpu/math/matmul.cc index bfe4b65d32dc8..7f2ec024dedde 100644 --- a/onnxruntime/core/providers/webgpu/math/matmul.cc +++ b/onnxruntime/core/providers/webgpu/math/matmul.cc @@ -46,62 +46,6 @@ static std::string CalcResult(int64_t components, int64_t a_components, int64_t return oss.str(); } -SplitKConfig SplitKConfig::GetSplitKConfig(const ComputeContext& context) { - const wgpu::AdapterInfo& adapter_info = context.AdapterInfo(); - SplitKConfig config = {}; - - if (adapter_info.vendor == std::string_view{"intel"}) { - if (adapter_info.architecture == std::string_view{"xe-2lpg"} || - adapter_info.architecture == std::string_view{"xe-2hpg"} || - adapter_info.architecture == std::string_view{"xe-lpg"} || - adapter_info.architecture == std::string_view{"gen-12hp"}) { - config.enable_split_k_ = true; - - // Below thresholds are only verified on the above Intel GPUs without any regressions. The - // proper value of `max_dim_a_outer_multiplies_dim_b_outer_divides_dim_inner_` may be - // reduced when we support a larger `dim_inner` because larger `dim_inner` will bring more - // atomic calls for each output value. - config.split_dim_inner_ = 256; - config.min_dim_inner_with_split_k_ = config.split_dim_inner_ * 2; - config.max_dim_inner_with_split_k_ = config.split_dim_inner_ * 9; - config.max_dim_a_outer_multiplies_dim_b_outer_divides_dim_inner_ = 35.0f; - } - } - return config; -} - -bool SplitKConfig::UseSplitK( - bool is_vec4, - ActivationKind activation_kind, - uint64_t batch_size, - bool is_channels_last, - uint32_t dim_a_outer, - uint32_t dim_b_outer, - uint32_t dim_inner) const { - bool use_split_k = enable_split_k_; - - // TODO: support the cases below. - use_split_k &= activation_kind == ActivationKind::None; - use_split_k &= is_vec4; - use_split_k &= batch_size == 1; - // Now `is_channels_last` is only supported because we only generate vec4 shaders in - // `MatMulFillBiasOrZeroBeforeSplitKProgram`. - use_split_k &= is_channels_last; - - // Split-K works best when `dim_inner` is relatively large compared with `dim_a_outer` and - // `dim_b_outer`. Currently we use the factor between `(dim_a_outer * dim_b_outer)` and - // `dim_inner)` as the metric to decide whether to use Split-K or not. 
- use_split_k &= (dim_inner >= min_dim_inner_with_split_k_); - use_split_k &= (dim_inner <= max_dim_inner_with_split_k_); - use_split_k &= ((dim_a_outer * dim_b_outer * 1.0f / dim_inner) <= max_dim_a_outer_multiplies_dim_b_outer_divides_dim_inner_); - - return use_split_k; -} - -uint32_t SplitKConfig::GetSplitDimInner() const { - return split_dim_inner_; -} - Status MatMulNaiveProgram::GenerateShaderCode(ShaderHelper& shader) const { const auto& a = shader.AddInput("a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); @@ -295,7 +239,7 @@ Status ComputeMatMul(ComputeContext* context, bool use_bias_in_matmul = has_bias; uint32_t split_dim_inner = 1; - const SplitKConfig& split_k_config = SplitKConfig::GetSplitKConfig(*context); + const SplitKConfig& split_k_config = context->GetSplitKConfig(); const bool need_split_k = split_k_config.UseSplitK(is_vec4, activation.activation_kind_, batch_size, is_channels_last, dim_a_outer, dim_b_outer, dim_inner); if (need_split_k) { ORT_ENFORCE(batch_size == 1, "Split-K MatMul only supports batch_size == 1."); @@ -356,9 +300,8 @@ MatMulFillBiasOrZeroBeforeSplitKProgram CreateMatMulFillBiasOrZeroBeforeSplitKPr const uint32_t dim_b_outer_vec4 = narrow(output_shape_vec4[output_shape_vec4.NumDimensions() - 1]); // Fill one value (currently only vec4) per invocation. - constexpr uint32_t workgroup_size_x = MatMulFillBiasOrZeroBeforeSplitKProgram::WORKGROUP_SIZE_X; const uint32_t total_outputs_vec4 = dim_a_outer * dim_b_outer_vec4; - const uint32_t dispatch_x = (total_outputs_vec4 + workgroup_size_x - 1) / workgroup_size_x; + const uint32_t dispatch_x = (total_outputs_vec4 + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE; // To reuse `MatMulWriteFnSource()` we need to set `dim_a_outer` and `dim_b_outer` in scalar // instead of vec4, while use `output_shape_vec4` directly as the output shape. 
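For completeness, the flat-index decomposition used by the fill program (one vec4 written per invocation, default workgroup size 64) looks like this on the host (illustrative only; the dimensions are made-up sample values):

    #include <cstdint>
    #include <cstdio>

    int main() {
      const uint32_t dim_a_outer = 49;       // output rows
      const uint32_t dim_b_outer = 256;      // output columns, in scalars
      const uint32_t dim_b_outer_vec4 = dim_b_outer / 4;
      const uint32_t total_vec4 = dim_a_outer * dim_b_outer_vec4;
      const uint32_t workgroup_size = 64;    // default WORKGROUP_SIZE
      const uint32_t dispatch_x = (total_vec4 + workgroup_size - 1) / workgroup_size;
      std::printf("dispatch_x = %u\n", dispatch_x);

      // One invocation: decompose a flat id the same way the shader does.
      const uint32_t output_id = 200;
      if (output_id < total_vec4) {          // out-of-bound invocations just return
        const uint32_t row = output_id / dim_b_outer_vec4;
        const uint32_t col = output_id % dim_b_outer_vec4;  // column in vec4 units
        std::printf("id %u -> row %u, vec4 col %u\n", output_id, row, col);
      }
    }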
@@ -366,8 +309,8 @@ MatMulFillBiasOrZeroBeforeSplitKProgram CreateMatMulFillBiasOrZeroBeforeSplitKPr program.CacheHint(has_bias) .AddOutput({output, ProgramTensorMetadataDependency::TypeAndRank, output_shape_vec4, static_cast(bias_components)}) .AddUniformVariables({{dim_a_outer}, {dim_b_outer}}) - .SetDispatchGroupSize(dispatch_x, 1, 1) - .SetWorkgroupSize(workgroup_size_x, 1, 1); + .SetDispatchGroupSize(dispatch_x) + .SetWorkgroupSize(WORKGROUP_SIZE); if (has_bias) { const TensorShape reduced_bias_shape = ReduceShapeByComponents(bias->Shape(), bias_components); diff --git a/onnxruntime/core/providers/webgpu/math/matmul.h b/onnxruntime/core/providers/webgpu/math/matmul.h index 2c9fb33850c8b..0b65827be7f17 100644 --- a/onnxruntime/core/providers/webgpu/math/matmul.h +++ b/onnxruntime/core/providers/webgpu/math/matmul.h @@ -14,25 +14,6 @@ namespace onnxruntime { namespace webgpu { -class SplitKConfig { - public: - static SplitKConfig GetSplitKConfig(const ComputeContext& context); - - bool UseSplitK( - bool is_vec4, ActivationKind activation_kind, uint64_t batch_size, - bool is_channels_last, uint32_t dim_a_outer, - uint32_t dim_b_outer, uint32_t dim_inner) const; - - uint32_t GetSplitDimInner() const; - - private: - bool enable_split_k_ = false; - uint32_t split_dim_inner_ = 0; - uint32_t min_dim_inner_with_split_k_ = 0; - uint32_t max_dim_inner_with_split_k_ = 0; - float max_dim_a_outer_multiplies_dim_b_outer_divides_dim_inner_ = 0.0f; -}; - Status ComputeMatMul(ComputeContext* context, const Activation& activation, std::vector& inputs, Tensor* output, bool is_channels_last, const TensorShape& input_a_reshape = TensorShape(), const TensorShape& input_b_reshape = TensorShape()); diff --git a/onnxruntime/core/providers/webgpu/math/matmul_packed.h b/onnxruntime/core/providers/webgpu/math/matmul_packed.h index 22fdd460f58e5..143ba61c99e13 100644 --- a/onnxruntime/core/providers/webgpu/math/matmul_packed.h +++ b/onnxruntime/core/providers/webgpu/math/matmul_packed.h @@ -53,8 +53,6 @@ class MatMulFillBiasOrZeroBeforeSplitKProgram final : public Program namespace onnxruntime { namespace webgpu { diff --git a/onnxruntime/core/providers/webgpu/nn/fuse_utils.h b/onnxruntime/core/providers/webgpu/nn/fuse_utils.h index f5d2585bb9b45..fad7d3d145bc6 100644 --- a/onnxruntime/core/providers/webgpu/nn/fuse_utils.h +++ b/onnxruntime/core/providers/webgpu/nn/fuse_utils.h @@ -1,11 +1,17 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
#include -#include "core/providers/webgpu/webgpu_kernel.h" +#include + +#include "core/common/status.h" #pragma once namespace onnxruntime { + +class OpKernelInfo; + namespace webgpu { + enum class ActivationKind { None, Relu, diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc index a91e34c334687..29c95a08ef538 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_context.cc @@ -910,6 +910,13 @@ void WebGpuContext::ReleaseGraphResources(std::vector WebGpuContextFactory::contexts_; std::mutex WebGpuContextFactory::mutex_; std::once_flag WebGpuContextFactory::init_default_flag_; diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.h b/onnxruntime/core/providers/webgpu/webgpu_context.h index 1ead7b3a005bb..bd7dae75f2e2d 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.h +++ b/onnxruntime/core/providers/webgpu/webgpu_context.h @@ -5,12 +5,14 @@ #include #include +#include #include "core/providers/webgpu/webgpu_external_header.h" #include "core/common/common.h" #include "core/providers/webgpu/buffer_manager.h" #include "core/providers/webgpu/program_manager.h" +#include "core/providers/webgpu/webgpu_utils.h" #if defined(ENABLE_PIX_FOR_WEBGPU_EP) #include "core/providers/webgpu/webgpu_pix_frame_generator.h" @@ -171,6 +173,13 @@ class WebGpuContext final { Status Run(ComputeContext& context, const ProgramBase& program); void OnRunEnd(); + // + // Get Split-K configuration. + // + // `split_k_config_` won't be initialized until the first call to this method. + // + const SplitKConfig& GetSplitKConfig(); + private: enum class TimestampQueryType { None = 0, @@ -268,6 +277,8 @@ class WebGpuContext final { uint32_t num_pending_dispatches_ = 0; const uint32_t max_num_pending_dispatches_ = 16; + std::optional split_k_config_; + // profiling TimestampQueryType query_type_; wgpu::QuerySet query_set_; diff --git a/onnxruntime/core/providers/webgpu/webgpu_utils.cc b/onnxruntime/core/providers/webgpu/webgpu_utils.cc index 53b96dfe7a346..568d29a96cb88 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_utils.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_utils.cc @@ -21,5 +21,64 @@ TensorShape ReduceShapeByComponents(const TensorShape& shape, int64_t components return TensorShape(shape_vector); } +SplitKConfig SplitKConfig::GetSplitKConfig(const wgpu::AdapterInfo& adapter_info) { + SplitKConfig config = {}; + + if (adapter_info.vendor == std::string_view{"intel"}) { + if (adapter_info.architecture == std::string_view{"xe-2lpg"} || + adapter_info.architecture == std::string_view{"xe-2hpg"} || + adapter_info.architecture == std::string_view{"xe-lpg"} || + adapter_info.architecture == std::string_view{"gen-12hp"}) { + config.enable_split_k_ = true; + + // Below thresholds are only verified on the above Intel GPUs without any regressions. The + // proper value of `max_dim_a_outer_multiplies_dim_b_outer_divides_dim_inner_` may be + // reduced when we support a larger `dim_inner` because larger `dim_inner` will bring more + // atomic calls for each output value. 
+      config.split_dim_inner_ = 256;
+      config.min_dim_inner_with_split_k_ = config.split_dim_inner_ * 2;
+      config.max_dim_inner_with_split_k_ = config.split_dim_inner_ * 9;
+      config.max_dim_a_outer_multiplies_dim_b_outer_divides_dim_inner_ = 35.0f;
+    }
+  }
+  return config;
+}
+
+bool SplitKConfig::UseSplitK(
+    bool is_vec4,
+    ActivationKind activation_kind,
+    uint64_t batch_size,
+    bool is_channels_last,
+    uint32_t dim_a_outer,
+    uint32_t dim_b_outer,
+    uint32_t dim_inner) const {
+  if (!enable_split_k_) {
+    return false;
+  }
+
+  bool use_split_k = true;
+
+  // TODO: support the cases below.
+  use_split_k &= activation_kind == ActivationKind::None;
+  use_split_k &= is_vec4;
+  use_split_k &= batch_size == 1;
+  // Currently only `is_channels_last` is supported because we only generate vec4 shaders in
+  // `MatMulFillBiasOrZeroBeforeSplitKProgram`.
+  use_split_k &= is_channels_last;
+
+  // Split-K works best when `dim_inner` is relatively large compared with `dim_a_outer` and
+  // `dim_b_outer`. Currently we use the ratio between `(dim_a_outer * dim_b_outer)` and
+  // `dim_inner` as the metric to decide whether to use Split-K or not.
+  use_split_k &= (dim_inner >= min_dim_inner_with_split_k_);
+  use_split_k &= (dim_inner <= max_dim_inner_with_split_k_);
+  use_split_k &= ((dim_a_outer * dim_b_outer * 1.0f / dim_inner) <= max_dim_a_outer_multiplies_dim_b_outer_divides_dim_inner_);
+
+  return use_split_k;
+}
+
+uint32_t SplitKConfig::GetSplitDimInner() const {
+  return split_dim_inner_;
+}
+
 }  // namespace webgpu
 }  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/webgpu/webgpu_utils.h b/onnxruntime/core/providers/webgpu/webgpu_utils.h
index 86eb57f99f3b3..d45b9bf4dd119 100644
--- a/onnxruntime/core/providers/webgpu/webgpu_utils.h
+++ b/onnxruntime/core/providers/webgpu/webgpu_utils.h
@@ -7,6 +7,8 @@
 #include "core/common/common.h"
 #include "core/framework/tensor.h"
 #include "core/framework/tensor_shape.h"
+#include "core/providers/webgpu/webgpu_external_header.h"
+#include "core/providers/webgpu/nn/fuse_utils.h"
 
 namespace onnxruntime {
 namespace webgpu {
@@ -89,5 +91,24 @@ inline Tensor CreateTensorView(const Tensor& tensor, MLDataType new_data_type, c
   return {new_data_type, new_shape, const_cast(tensor.DataRaw()), tensor.Location()};
 }
 
+class SplitKConfig {
+ public:
+  static SplitKConfig GetSplitKConfig(const wgpu::AdapterInfo& adapter_info);
+
+  bool UseSplitK(
+      bool is_vec4, ActivationKind activation_kind, uint64_t batch_size,
+      bool is_channels_last, uint32_t dim_a_outer,
+      uint32_t dim_b_outer, uint32_t dim_inner) const;
+
+  uint32_t GetSplitDimInner() const;
+
+ private:
+  bool enable_split_k_ = false;
+  uint32_t split_dim_inner_ = 0;
+  uint32_t min_dim_inner_with_split_k_ = 0;
+  uint32_t max_dim_inner_with_split_k_ = 0;
+  float max_dim_a_outer_multiplies_dim_b_outer_divides_dim_inner_ = 0.0f;
+};
+
 }  // namespace webgpu
 }  // namespace onnxruntime

From 082d1e3169e7d58645d20be81b0e8f4805b02791 Mon Sep 17 00:00:00 2001
From: Jiawei Shao
Date: Fri, 14 Nov 2025 14:37:25 +0800
Subject: [PATCH 18/22] Don't call `SetWorkgroupSize()` as we are using the
 default value

---
 onnxruntime/core/providers/webgpu/math/matmul.cc | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/onnxruntime/core/providers/webgpu/math/matmul.cc b/onnxruntime/core/providers/webgpu/math/matmul.cc
index 7f2ec024dedde..55c2c5773cc1f 100644
--- a/onnxruntime/core/providers/webgpu/math/matmul.cc
+++ b/onnxruntime/core/providers/webgpu/math/matmul.cc
@@ -299,7 +299,8 @@ MatMulFillBiasOrZeroBeforeSplitKProgram CreateMatMulFillBiasOrZeroBeforeSplitKPr
   const uint32_t dim_a_outer = narrow(output_shape_vec4[output_shape_vec4.NumDimensions() - 2]);
   const uint32_t dim_b_outer_vec4 = narrow(output_shape_vec4[output_shape_vec4.NumDimensions() - 1]);
 
-  // Fill one value (currently only vec4) per invocation.
+  // Fill one value (currently only vec4) per invocation. We use the default workgroup size (64)
+  // for this program.
   const uint32_t total_outputs_vec4 = dim_a_outer * dim_b_outer_vec4;
   const uint32_t dispatch_x = (total_outputs_vec4 + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE;
 
@@ -309,8 +310,7 @@ MatMulFillBiasOrZeroBeforeSplitKProgram CreateMatMulFillBiasOrZeroBeforeSplitKPr
   program.CacheHint(has_bias)
       .AddOutput({output, ProgramTensorMetadataDependency::TypeAndRank, output_shape_vec4, static_cast(bias_components)})
       .AddUniformVariables({{dim_a_outer}, {dim_b_outer}})
-      .SetDispatchGroupSize(dispatch_x)
-      .SetWorkgroupSize(WORKGROUP_SIZE);
+      .SetDispatchGroupSize(dispatch_x);
 
   if (has_bias) {
     const TensorShape reduced_bias_shape = ReduceShapeByComponents(bias->Shape(), bias_components);

From 6b15ede1cabfb2dbe53c1c9204a39b7cd1153a0e Mon Sep 17 00:00:00 2001
From: Jiawei Shao
Date: Fri, 14 Nov 2025 14:41:20 +0800
Subject: [PATCH 19/22] Remove a redundant declaration

---
 onnxruntime/core/providers/webgpu/compute_context.h | 1 -
 1 file changed, 1 deletion(-)

diff --git a/onnxruntime/core/providers/webgpu/compute_context.h b/onnxruntime/core/providers/webgpu/compute_context.h
index a4613fc740d95..69808b322c742 100644
--- a/onnxruntime/core/providers/webgpu/compute_context.h
+++ b/onnxruntime/core/providers/webgpu/compute_context.h
@@ -5,7 +5,6 @@
 
 #include "core/providers/webgpu/webgpu_external_header.h"
 
-#include
 #include
 
 #include "core/framework/execution_provider.h"

From fb4c7430fdd3ac595fa4c92bd84dbf94ec8626a4 Mon Sep 17 00:00:00 2001
From: Jiawei Shao
Date: Fri, 14 Nov 2025 15:18:11 +0800
Subject: [PATCH 20/22] Address comments from Copilot

---
 onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc | 5 ++---
 onnxruntime/test/providers/cpu/nn/conv_op_test.cc   | 5 ++---
 2 files changed, 4 insertions(+), 6 deletions(-)

diff --git a/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc b/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc
index 2daeaf96f24e1..af0f0b374192b 100644
--- a/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc
+++ b/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc
@@ -37,11 +37,10 @@
 If attributes.activation is set the NhwcFusedConv contrib op is used.
 If you are adding support for a new EP to the test and the EP does not support NhwcFusedConv
 please add the EP to the excluded_providers list.
 */
-template
 void TestConvFp16Op(const ConvOpAndTestAttributes& attributes,
                     const vector>& inputs,
                     const vector>& input_shapes,
-                    const T& expected_output,
+                    const vector& expected_output,
                     const vector& expected_output_shape,
                     bool weight_is_initializer = false,
                     OpTester::ExpectResult expect_result = OpTester::ExpectResult::kExpectSuccess,
@@ -1150,7 +1149,7 @@ TEST(ConvFp16Test, ConvDimWithZero) {
   vector W_shape = {2, 2, 1, 1};
   vector out_shape = {0, 2, 4, 4};
 
-  TestConvFp16Op(attrs, {X, W}, {X_shape, W_shape}, vector(), out_shape);
+  TestConvFp16Op(attrs, {X, W}, {X_shape, W_shape}, {}, out_shape);
 }
 
 TEST(ConvFp16Test, Conv1D_asymmetric_padding) {
diff --git a/onnxruntime/test/providers/cpu/nn/conv_op_test.cc b/onnxruntime/test/providers/cpu/nn/conv_op_test.cc
index b5dc08560ccca..4efbb8cfd5c19 100644
--- a/onnxruntime/test/providers/cpu/nn/conv_op_test.cc
+++ b/onnxruntime/test/providers/cpu/nn/conv_op_test.cc
@@ -21,11 +21,10 @@ struct ConvOpAndTestAttributes {
   std::unordered_set excluded_providers;
 };
 
-template
 void TestConvOp(const ConvOpAndTestAttributes& attributes,
                 const vector>& inputs,
                 const vector>& input_shapes,
-                const T& expected_output,
+                const vector& expected_output,
                 const vector& expected_output_shape,
                 bool weight_is_initializer = false,
                 optional epsilon = optional(),
@@ -1136,7 +1135,7 @@ TEST(ConvTest, ConvDimWithZero) {
   // not handled by ACL
   attrs.excluded_providers.insert(kAclExecutionProvider);
 
-  TestConvOp(attrs, {X, W}, {X_shape, W_shape}, vector(), out_shape, false, optional(),
+  TestConvOp(attrs, {X, W}, {X_shape, W_shape}, {}, out_shape, false, optional(),
              OpTester::ExpectResult::kExpectSuccess, "", 10);
 }

From 4ef3ac2d5c78543864a1f5dee990e0fa43bdf6bf Mon Sep 17 00:00:00 2001
From: Jiawei Shao
Date: Fri, 14 Nov 2025 15:25:10 +0800
Subject: [PATCH 21/22] Remove unused declarations

---
 .../core/providers/webgpu/compute_context.h | 19 -------------------
 1 file changed, 19 deletions(-)

diff --git a/onnxruntime/core/providers/webgpu/compute_context.h b/onnxruntime/core/providers/webgpu/compute_context.h
index 77498796ab72c..57800e0687cc1 100644
--- a/onnxruntime/core/providers/webgpu/compute_context.h
+++ b/onnxruntime/core/providers/webgpu/compute_context.h
@@ -139,25 +139,6 @@ class ComputeContext final {
     return webgpu_context_.Run(*this, program);
   }
 
-  //
-  // Get the buffer manager from the GPU allocator.
-  //
-  const webgpu::BufferManager& BufferManager() const;
-
-  //
-  // Push error scope.
-  //
-  // This is useful only when "skip_validation" is not set.
-  //
-  void PushErrorScope();
-
-  //
-  // Pop error scope.
-  //
-  // This is useful only when "skip_validation" is not set.
-  //
-  Status PopErrorScope();
-
   //
   // Get Split-K configuration.
   //

From fa6f22635446eb583c985218a6cf1c92cfd6cb26 Mon Sep 17 00:00:00 2001
From: Jiawei Shao
Date: Fri, 14 Nov 2025 15:26:41 +0800
Subject: [PATCH 22/22] Fix another typo

---
 onnxruntime/core/providers/webgpu/compute_context.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/onnxruntime/core/providers/webgpu/compute_context.h b/onnxruntime/core/providers/webgpu/compute_context.h
index 57800e0687cc1..0aeefeaca492f 100644
--- a/onnxruntime/core/providers/webgpu/compute_context.h
+++ b/onnxruntime/core/providers/webgpu/compute_context.h
@@ -146,7 +146,7 @@ class ComputeContext final {
   //
   const SplitKConfig& GetSplitKConfig();
 
- protected:
+ private:
   WebGpuContext& webgpu_context_;
   OpKernelContext& kernel_context_;
   const WebGpuExecutionProvider& ep_;
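
For reference, a minimal self-contained sketch of how the Split-K heuristic in `SplitKConfig::UseSplitK` plays out for a hypothetical GEMM shape. The shape values below are made up purely for illustration; the only thresholds assumed are the ones configured above for the listed Intel GPUs (split_dim_inner = 256, dim_inner in [512, 2304], and (dim_a_outer * dim_b_outer) / dim_inner <= 35):

// Illustrative only: hypothetical GEMM shape, thresholds as configured above for Intel GPUs.
#include <cstdint>
#include <iostream>

int main() {
  const uint32_t dim_a_outer = 196;      // M (e.g. output spatial size of a conv)
  const uint32_t dim_b_outer = 128;      // N (e.g. output channels)
  const uint32_t dim_inner = 2304;       // K (e.g. input channels * kernel window)
  const uint32_t split_dim_inner = 256;  // size of each K slice

  const bool k_in_range = dim_inner >= split_dim_inner * 2 && dim_inner <= split_dim_inner * 9;
  const float ratio = static_cast<float>(dim_a_outer) * dim_b_outer / dim_inner;  // ~10.9
  const bool use_split_k = k_in_range && ratio <= 35.0f;

  // With Split-K enabled, K is divided into dim_inner / split_dim_inner = 9 slices, each
  // computed by a separate workgroup whose partial sum is accumulated into the shared output.
  std::cout << "use_split_k = " << std::boolalpha << use_split_k
            << ", k_slices = " << dim_inner / split_dim_inner << "\n";
  return 0;
}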
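
Similarly, a rough sketch, under assumptions, of the lazily-initialized accessor described by the `WebGpuContext::GetSplitKConfig()` declaration and the `std::optional` member added above: the configuration is derived from the adapter info on the first call and cached for later calls. The `adapter_info_` member name is an assumption here and may not match the actual implementation:

// Sketch only: assumes an `adapter_info_` member of type wgpu::AdapterInfo; the real code may differ.
const SplitKConfig& WebGpuContext::GetSplitKConfig() {
  if (!split_k_config_.has_value()) {
    // Compute the Split-K configuration from the adapter once and cache it.
    split_k_config_ = SplitKConfig::GetSplitKConfig(adapter_info_);
  }
  return *split_k_config_;
}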