Commits
46e4559
Implement Split-K on Conv|MatMul
Jiawei-Shao Oct 30, 2025
1f06b95
Address reviewer's comments
Jiawei-Shao Oct 31, 2025
3193815
Remove the check of `is_channels_last` in `UseSplitK`
Jiawei-Shao Nov 1, 2025
ecbc093
Still require `is_channels_last` to be true
Jiawei-Shao Nov 1, 2025
82d3d9b
Check the use of Split-K with ratio and enable Split-K on ACM
Jiawei-Shao Nov 4, 2025
0099edd
Fix incorrect ratio
Jiawei-Shao Nov 4, 2025
05bd1f8
Update ratio
Jiawei-Shao Nov 5, 2025
11ecdfe
Update ratio
Jiawei-Shao Nov 5, 2025
534dc2c
Compute FP16 values with MLFloat16
Jiawei-Shao Nov 6, 2025
cfd2219
Address reviewer's comments
Jiawei-Shao Nov 7, 2025
d03755b
Disallow out-of-bound write
Jiawei-Shao Nov 10, 2025
581828e
Use safer thresholds by now
Jiawei-Shao Nov 12, 2025
2ed25de
Merge branch 'main' into impl-splitk-matmul
Jiawei-Shao Nov 12, 2025
22f9017
Address more reviewer's comments
Jiawei-Shao Nov 13, 2025
418d6c0
Address comments from Copilot
Jiawei-Shao Nov 14, 2025
0ca5e65
Address more comments from Copilot
Jiawei-Shao Nov 14, 2025
7a415de
Address more comments from Copilot
Jiawei-Shao Nov 14, 2025
13b94e8
Address reviewer's comments
Jiawei-Shao Nov 14, 2025
082d1e3
Don't call `SetWorkgroupSize()` as we are using the default value
Jiawei-Shao Nov 14, 2025
6b15ede
Remove a redundant declaration
Jiawei-Shao Nov 14, 2025
04e2890
Merge branch 'main' into impl-splitk-matmul
Jiawei-Shao Nov 14, 2025
fb4c743
Address comments from Copilot
Jiawei-Shao Nov 14, 2025
4ef3ac2
Remove unused declarations
Jiawei-Shao Nov 14, 2025
fa6f226
Fix another typo
Jiawei-Shao Nov 14, 2025
ec8d47d
Merge branch 'main' into impl-splitk-matmul
Jiawei-Shao Nov 17, 2025
d010d4b
Use a higher rel_error for Linux ARM64 bots
Jiawei-Shao Nov 17, 2025
127 changes: 118 additions & 9 deletions onnxruntime/core/providers/webgpu/math/gemm_utils.cc
@@ -53,6 +53,70 @@ void HandleMaybeBiasForMatMul(ShaderHelper& shader,
<< output.SetByIndices("coords", "value") << "\n";
}

void HandleMatMulWithSplitK(
ShaderHelper& shader,
ProgramVariableDataType output_variable_type) {
shader.AdditionalImplementation() << " let coords = vec3(u32(batch), u32(row), u32(colIn));\n";

// With Split-K, the final output will be the sum of the sub-outputs from multiple workgroups,
// so we must add them with atomic built-in functions. Because currently WebGPU doesn't support
// atomic built-in functions on `f32` or `f16`, we implement the `atomicAdd` on `f32` and `f16`
// with `atomicLoad` and `atomicCompareExchangeWeak`:
// 1. Get `old_output_i32` from `output[offset]` with `atomicLoad`.
// 2. Convert `old_output_i32` into `f32` (`old_output_f32`) or `vec2h` (`old_output_vec2h`).
// 3. Add incoming `value` into `old_output_f32` or `old_output_vec2h`.
// 4. Convert the result of step 3 into `i32` values.
// 5. Try assigning the result of step 4 into `output[offset]` with `atomicCompareExchangeWeak`
// and `old_output_i32`. The assignment fails if `output[offset]` is no longer equal to
// `old_output_i32` (i.e. it has been updated by another invocation); in that case we go
// back to step 1 and repeat the steps above.
switch (output_variable_type) {
case ProgramVariableDataType::Float32x4: {
shader.AdditionalImplementation() << R"(
let offset0 = i2o_output(coords) * 4u;
for (var i = 0u; i < 4u; i++) {
let offset = offset0 + i;
while (true) {
let old_output_i32 = atomicLoad(&output[offset]);
let old_output_f32 = bitcast<f32>(old_output_i32);
let new_output_f32 = old_output_f32 + value[i];
let new_output_i32 = bitcast<i32>(new_output_f32);
let output_compare_exchange = atomicCompareExchangeWeak(&output[offset], old_output_i32, new_output_i32);
if (output_compare_exchange.old_value == old_output_i32) {
break;
}
}
}
)";
break;
}
case ProgramVariableDataType::Float16x4: {
shader.AdditionalImplementation() << R"(
let offset0 = i2o_output(coords) * 2u;
var vec2h_values : array<vec2h, 2>;
vec2h_values[0] = value.xy;
vec2h_values[1] = value.zw;
for (var i = 0u; i < 2u; i++) {
let offset = offset0 + i;
while (true) {
let old_output_i32 = atomicLoad(&output[offset]);
let old_output_vec2h = bitcast<vec2h>(old_output_i32);
let new_output_vec2h = old_output_vec2h + vec2h_values[i];
let new_output_i32 = bitcast<i32>(new_output_vec2h);
let output_compare_exchange = atomicCompareExchangeWeak(&output[offset], old_output_i32, new_output_i32);
if (output_compare_exchange.old_value == old_output_i32) {
break;
}
}
}
)";
break;
}
default:
break;
}
}

} // namespace
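The loops emitted above implement a floating-point atomic add via compare-and-swap, since WGSL only provides atomic built-in functions on 32-bit integers. For reference, here is a minimal sketch of the same pattern factored into a standalone helper (the name `atomic_add_f32` is hypothetical and not part of this change, and it assumes `output` is bound as `array<atomic<i32>>`), emitted the same way other helpers in this file append WGSL:

shader.AdditionalImplementation() << R"(
fn atomic_add_f32(offset : u32, value : f32) {
  loop {
    // Read the current bits and reinterpret them as f32.
    let old_i32 = atomicLoad(&output[offset]);
    let new_i32 = bitcast<i32>(bitcast<f32>(old_i32) + value);
    // The exchange only succeeds if no other invocation updated the word in between;
    // otherwise retry with the freshly observed value on the next iteration.
    if (atomicCompareExchangeWeak(&output[offset], old_i32, new_i32).exchanged) {
      break;
    }
  }
}
)";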

void MatMulReadFnSource(ShaderHelper& shader,
@@ -125,7 +189,9 @@ void MatMulWriteFnSource(ShaderHelper& shader,
int output_components,
bool c_is_scalar,
std::string activation_snippet,
bool is_channels_last) {
bool is_channels_last,
bool use_split_k,
ProgramVariableDataType output_variable_type) {
shader.AdditionalImplementation()
<< "fn mm_write(batch: i32, row: i32, colIn: i32, valueIn: output_value_t) { \n";

@@ -134,7 +200,16 @@ void MatMulWriteFnSource(ShaderHelper& shader,
shader.AdditionalImplementation() << "if(row < i32(uniforms.dim_a_outer) && col < i32(uniforms.dim_b_outer)) { \n"
<< " var value = valueIn; \n";

if (is_gemm) {
if (use_split_k) {
// Set output when MatMul is performed with Split-K.
// When Split-K is used in MatMul, the bias will be handled in `MatMulFillBiasOrZeroBeforeSplitKProgram`
// instead of here, so `has_bias` and `is_channels_last` are not used for Split-K. Note that we
// still need to handle `has_bias` (and `is_channels_last` in the future) in
// `MatMulFillBiasOrZeroBeforeSplitKProgram`.
ORT_ENFORCE(!has_bias, "Bias is not supported in MatMulProgram when Split-K is enabled.");
ORT_ENFORCE(is_channels_last, "Only channels-last is supported in MatMulProgram when Split-K is enabled.");
HandleMatMulWithSplitK(shader, output_variable_type);
} else if (is_gemm) {
HanldeMaybeHaveBiasForGEMM(shader, output, has_bias, c_components, output_components, c_is_scalar);
} else {
HandleMaybeBiasForMatMul(shader, output, has_bias, activation_snippet, is_channels_last);
Expand All @@ -159,9 +234,6 @@ Status MakeMatMulPackedVec4Source(ShaderHelper& shader,
uint32_t tile_inner,
bool split_k,
uint32_t split_dim_inner) {
ORT_UNUSED_PARAMETER(split_k);
ORT_UNUSED_PARAMETER(split_dim_inner);

const std::string type_string = MakeScalarOrVectorType(4 /*components */, data_type);

std::string write_data_to_sub_a_vec4_snippet =
@@ -208,14 +280,51 @@ Status MakeMatMulPackedVec4Source(ShaderHelper& shader,
<< " let tileCol = i32(local_id.x);\n"
<< " let globalRow = i32(global_id.y) * rowPerThread;\n"
<< " let globalCol = i32(global_id.x);\n"
<< " let batch = i32(global_id.z);\n"
<< (nullptr != batch_dims ? " let batchIndices = " + batch_dims->OffsetToIndices("u32(batch)") + ";\n" : "")
<< " let globalRowStart = i32(workgroup_id.y) * " << tile_a_outer << ";\n"
<< " let globalColStart = i32(workgroup_id.x) * " << tile_b_outer << ";\n"
<< " let num_tiles = (uniforms.dim_inner - 1) / tileInner + 1;\n"
<< " var kStart = 0;\n"
<< " var acc: array<vec4<" << data_type << ">, rowPerThread>;\n";

if (split_k) {
// With Split-K, the original "workgroup" (with dispatch_z == 1 on the API side) is split into
// multiple ones, and in the current workgroup we only compute `kSplitK` elements starting from
// `kSplitK * i32(global_id.z)`.
//
// For example, consider computing Y = (X * W + B) in one workgroup.
// Let kSplitK = 2, B = [d1, d2]
// Let X = [[a1 a1 b1 b1 c1 c1] = [ A1 B1 C1 ], W = [[a2 a2] = [ A2
// [a1 a1 b1 b1 c1 c1]] [a2 a2] B2
// [b2 b2] C2 ]
// [b2 b2]
// [c2 c2]
// [c2 c2]]
//
// With Split-K:
// 1. Initialize output Y with B in `MatMulFillBiasOrZeroBeforeSplitKProgram`: Y = [[d1, d2]
// [d1, d2]]
// 2. Split the original single workgroup into 3 workgroups (now `dispatch_z = 3` on the API side)
// Workgroup1: compute (A1 * A2) Workgroup2: compute (B1 * B2)
// Workgroup3: compute (C1 * C2)
// In each workgroup:
// - `num_tiles` is computed with `kSplitK`, and `kStart` is computed with `global_id.z`
// - When the computation in each workgroup is completed, add the result to Y with several
// atomic built-in functions in `HandleMatMulWithSplitK()`.
shader.MainFunctionBody()
<< "const kSplitK = " << split_dim_inner << ";\n"
<< " let num_tiles = (kSplitK - 1) / tileInner + 1;\n"
<< " var kStart = kSplitK * i32(global_id.z);\n"

// When Split-K is used, `batch` should always be 0 and `global_id.z` indicates the
// Split-K slice index instead of the batch index.
<< " let batch = 0;\n"
<< " let batchIndices = 0u;\n";
} else {
shader.MainFunctionBody()
<< " let num_tiles = (uniforms.dim_inner - 1) / tileInner + 1;\n"
<< " var kStart = 0;\n"
<< " let batch = i32(global_id.z);\n"
<< (nullptr != batch_dims ? " let batchIndices = " + batch_dims->OffsetToIndices("u32(batch)") + ";\n" : "");
}
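When `split_k` is true, the insertions in the first branch above concatenate into a prologue roughly like the following WGSL (assuming `split_dim_inner` is 256, the value configured for the supported Intel GPUs in matmul.cc; reconstructed here only for illustration):

const kSplitK = 256;
  let num_tiles = (kSplitK - 1) / tileInner + 1;
  var kStart = kSplitK * i32(global_id.z);
  let batch = 0;
  let batchIndices = 0u;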

// Loop over shared dimension.
shader.MainFunctionBody() << " let tileRowB = localRow * " << row_per_thread_b << ";\n";

4 changes: 3 additions & 1 deletion onnxruntime/core/providers/webgpu/math/gemm_utils.h
@@ -24,7 +24,9 @@ void MatMulWriteFnSource(ShaderHelper& shader,
int output_components,
bool c_is_scalar,
std::string activation_snippet = "",
bool is_channels_last = false);
bool is_channels_last = false,
bool use_split_k = false,
ProgramVariableDataType output_variable_type = ProgramVariableDataType::Float32x4);

// The two following functions are used to generate shader code for vec4 and scalar.
// It is used in GEMM, Matmul, and Conv.
153 changes: 138 additions & 15 deletions onnxruntime/core/providers/webgpu/math/matmul.cc
@@ -46,6 +46,62 @@ static std::string CalcResult(int64_t components, int64_t a_components, int64_t
return oss.str();
}

SplitKConfig SplitKConfig::GetSplitKConfig(const ComputeContext& context) {
const wgpu::AdapterInfo& adapter_info = context.AdapterInfo();
SplitKConfig config = {};

if (adapter_info.vendor == std::string_view{"intel"}) {
if (adapter_info.architecture == std::string_view{"xe-2lpg"} ||
adapter_info.architecture == std::string_view{"xe-2hpg"} ||
adapter_info.architecture == std::string_view{"xe-lpg"} ||
adapter_info.architecture == std::string_view{"gen-12hp"}) {
config.enable_split_k_ = true;

// The thresholds below have only been verified on the above Intel GPUs without regressions.
// The proper value of `max_dim_a_outer_multiplies_dim_b_outer_divides_dim_inner_` may need to
// be reduced when we support a larger `dim_inner`, because a larger `dim_inner` brings more
// atomic calls for each output value.
config.split_dim_inner_ = 256;
config.min_dim_inner_with_split_k_ = config.split_dim_inner_ * 2;
config.max_dim_inner_with_split_k_ = config.split_dim_inner_ * 9;
config.max_dim_a_outer_multiplies_dim_b_outer_divides_dim_inner_ = 35.0f;
}
}
return config;
}

bool SplitKConfig::UseSplitK(
bool is_vec4,
ActivationKind activation_kind,
uint64_t batch_size,
bool is_channels_last,
uint32_t dim_a_outer,
uint32_t dim_b_outer,
uint32_t dim_inner) const {
bool use_split_k = enable_split_k_;

// TODO: support the cases below.
use_split_k &= activation_kind == ActivationKind::None;
use_split_k &= is_vec4;
use_split_k &= batch_size == 1;
// Currently only `is_channels_last` is supported because we only generate vec4 shaders in
// `MatMulFillBiasOrZeroBeforeSplitKProgram`.
use_split_k &= is_channels_last;

// Split-K works best when `dim_inner` is relatively large compared with `dim_a_outer` and
// `dim_b_outer`. Currently we use the ratio of `(dim_a_outer * dim_b_outer)` to
// `dim_inner` as the metric to decide whether to use Split-K or not.
use_split_k &= (dim_inner >= min_dim_inner_with_split_k_);
use_split_k &= (dim_inner <= max_dim_inner_with_split_k_);
use_split_k &= ((dim_a_outer * dim_b_outer * 1.0f / dim_inner) <= max_dim_a_outer_multiplies_dim_b_outer_divides_dim_inner_);

return use_split_k;
}

uint32_t SplitKConfig::GetSplitDimInner() const {
return split_dim_inner_;
}
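As a worked example of how these thresholds interact (the shape below is hypothetical and chosen only for illustration), with the Intel configuration above, where `split_dim_inner_` is 256:

// Hypothetical GEMM shape: dim_a_outer (M) = 64, dim_b_outer (N) = 512, dim_inner (K) = 2048,
// vec4 layout, no activation, batch_size == 1, channels-last.
//   min_dim_inner_with_split_k_ = 256 * 2 = 512, max_dim_inner_with_split_k_ = 256 * 9 = 2304,
//   so 512 <= 2048 <= 2304 holds, and (64 * 512) / 2048.0f = 16.0f <= 35.0f.
//   UseSplitK() therefore returns true, and ComputeMatMul() launches
//   dispatch_z = (2048 + 256 - 1) / 256 = 8 workgroup slices whose partial sums are
//   accumulated atomically into the output.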

Status MatMulNaiveProgram::GenerateShaderCode(ShaderHelper& shader) const {
const auto& a = shader.AddInput("a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias |
ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
@@ -161,14 +217,14 @@ Status MatMul::ComputeInternal(ComputeContext& context) const {
const auto* bias = context.Input(2);
inputs.push_back(bias);
}
auto program = CreateMatMulProgram(Activation(), inputs, output_tensor, false);

return context.RunProgram(program);
return ComputeMatMul(&context, Activation(), inputs, output_tensor, false);
}

MatMulProgram CreateMatMulProgram(const Activation& activation, std::vector<const Tensor*>& inputs, Tensor* output_tensor, bool is_channels_last,
const TensorShape& input_a_reshape,
const TensorShape& input_b_reshape) {
Status ComputeMatMul(ComputeContext* context,
const Activation& activation, std::vector<const Tensor*>& inputs, Tensor* output_tensor, bool is_channels_last,
const TensorShape& input_a_reshape,
const TensorShape& input_b_reshape) {
const auto* a = inputs[0];
const auto* b = inputs[1];
bool has_bias = inputs.size() > 2;
@@ -226,31 +282,98 @@ MatMulProgram CreateMatMulProgram(const Activation& activation, std::vector<cons
(MatMul::MATMUL_PACKED_WORKGROUP_SIZE_X * elements_per_thread[0]));
const uint32_t dispatch_y = narrow<uint32_t>((dim_a_outer + MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Y * elements_per_thread[1] - 1) /
(MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Y * elements_per_thread[1]));
const uint32_t dispatch_z = narrow<uint32_t>((static_cast<uint32_t>(batch_size) + MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Z * elements_per_thread[2] - 1) /
(MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Z * elements_per_thread[2]));
uint32_t dispatch_z = narrow<uint32_t>((static_cast<uint32_t>(batch_size) + MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Z * elements_per_thread[2] - 1) /
(MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Z * elements_per_thread[2]));

const int components = is_vec4 ? 4 : 1;
const TensorShape a_shape_temp = CreateMatMulIntermediateShape(outer_dims_a, dim_a_outer, dim_inner, components);
const TensorShape b_shape_temp = CreateMatMulIntermediateShape(outer_dims_b, dim_inner, dim_b_outer, components);
const TensorShape output_shape_temp = TensorShape({batch_size, dim_a_outer, dim_b_outer / components});

MatMulProgram program{activation, has_bias, is_vec4, elements_per_thread, is_channels_last};
program
.CacheHint(activation.ToString(), absl::StrJoin(elements_per_thread, "-"), std::to_string(is_vec4), components, is_channels_last)
ProgramOutput output(output_tensor, ProgramTensorMetadataDependency::Rank, output_shape_temp, components);
const Tensor* bias = has_bias ? inputs[2] : nullptr;
bool use_bias_in_matmul = has_bias;
uint32_t split_dim_inner = 1;

const SplitKConfig& split_k_config = SplitKConfig::GetSplitKConfig(*context);
const bool need_split_k = split_k_config.UseSplitK(is_vec4, activation.activation_kind_, batch_size, is_channels_last, dim_a_outer, dim_b_outer, dim_inner);
if (need_split_k) {
ORT_ENFORCE(batch_size == 1, "Split-K MatMul only supports batch_size == 1.");
ORT_ENFORCE(is_vec4, "Split-K MatMul only supports bias in vec4 format.");
ORT_ENFORCE(is_channels_last, "Split-K MatMul only supports channels-last format.");

// Initialize `output_tensor` with 0 or the bias before running `MatMulProgram` with Split-K enabled.
const auto fill_bias_program = CreateMatMulFillBiasOrZeroBeforeSplitKProgram(bias, output_tensor, output_shape_temp);
ORT_RETURN_IF_ERROR(context->RunProgram(fill_bias_program));

// `bias` has been handled in the execution of `fill_bias_program` so we don't need to set
// `bias` again in `MatMulProgram`.
use_bias_in_matmul = false;

// With Split-K, `dim_inner` will be split into multiple parts and `dispatch_z` will be the
// number of splits along `dim_inner`.
// TODO: avoid using `global_id.xxx` or `workgroup_id.xxx` in `MatMulProgram` when we normalize
// the dispatch size with `ProgramManager::NormalizeDispatchGroupSize()` for `MatMulProgram`.
split_dim_inner = split_k_config.GetSplitDimInner();
dispatch_z = (dim_inner + split_dim_inner - 1) / split_dim_inner;

// The output should be declared in atomic types in `MatMulProgram` for the use of atomic
// built-in functions.
output.is_atomic = true;
}

MatMulProgram matmul_program{activation, use_bias_in_matmul, is_vec4, elements_per_thread, is_channels_last, split_dim_inner};
matmul_program
.CacheHint(activation.ToString(), absl::StrJoin(elements_per_thread, "-"), std::to_string(is_vec4), components, is_channels_last, split_dim_inner)
.AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, a_shape_temp, components},
{b, ProgramTensorMetadataDependency::TypeAndRank, b_shape_temp, components}})
.AddOutputs({{output_tensor, ProgramTensorMetadataDependency::Rank, output_shape_temp, components}})
.AddUniformVariables({{dim_a_outer}, {dim_b_outer}, {dim_inner}})
.AddIndices(outer_dims)
.SetDispatchGroupSize(dispatch_x, dispatch_y, dispatch_z)
.SetWorkgroupSize(MatMul::MATMUL_PACKED_WORKGROUP_SIZE_X, MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Y, MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Z);
.SetWorkgroupSize(MatMul::MATMUL_PACKED_WORKGROUP_SIZE_X, MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Y, MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Z)
.AddOutput(std::move(output));

if (has_bias) {
if (use_bias_in_matmul) {
auto bias_components = is_channels_last ? components : 1;
const auto* bias = inputs[2];
TensorShape reduced_bias_shape = ReduceShapeByComponents(bias->Shape(), bias_components);
program.AddInput({bias, ProgramTensorMetadataDependency::Rank, reduced_bias_shape, bias_components});
matmul_program.AddInput({bias, ProgramTensorMetadataDependency::Rank, reduced_bias_shape, bias_components});
}

return context->RunProgram(matmul_program);
}

MatMulFillBiasOrZeroBeforeSplitKProgram CreateMatMulFillBiasOrZeroBeforeSplitKProgram(
const Tensor* bias,
Tensor* output,
const TensorShape& output_shape_vec4) {
const bool has_bias = bias != nullptr;

// Currently we only support bias in vec4 and channels-last format for Split-K MatMul.
constexpr uint32_t bias_components = 4;
MatMulFillBiasOrZeroBeforeSplitKProgram program(has_bias);

const uint32_t dim_a_outer = narrow<uint32_t>(output_shape_vec4[output_shape_vec4.NumDimensions() - 2]);
const uint32_t dim_b_outer_vec4 = narrow<uint32_t>(output_shape_vec4[output_shape_vec4.NumDimensions() - 1]);

// Fill one value (currently only vec4) per invocation.
constexpr uint32_t workgroup_size_x = MatMulFillBiasOrZeroBeforeSplitKProgram::WORKGROUP_SIZE_X;
const uint32_t total_outputs_vec4 = dim_a_outer * dim_b_outer_vec4;
const uint32_t dispatch_x = (total_outputs_vec4 + workgroup_size_x - 1) / workgroup_size_x;

// To reuse `MatMulWriteFnSource()` we need to set `dim_a_outer` and `dim_b_outer` in scalar
// units instead of vec4, while using `output_shape_vec4` directly as the output shape.
const uint32_t dim_b_outer = narrow<uint32_t>(dim_b_outer_vec4 * bias_components);
program.CacheHint(has_bias)
.AddOutput({output, ProgramTensorMetadataDependency::TypeAndRank, output_shape_vec4, static_cast<int32_t>(bias_components)})
.AddUniformVariables({{dim_a_outer}, {dim_b_outer}})
.SetDispatchGroupSize(dispatch_x, 1, 1)
.SetWorkgroupSize(workgroup_size_x, 1, 1);

if (has_bias) {
const TensorShape reduced_bias_shape = ReduceShapeByComponents(bias->Shape(), bias_components);
program.AddInput({bias, ProgramTensorMetadataDependency::TypeAndRank, reduced_bias_shape, static_cast<int32_t>(bias_components)});
}

return program;
}
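Continuing the hypothetical 64 x 512 output from the earlier example, the dispatch math above works out as follows (the workgroup size value is an assumption for illustration; the real constant is `MatMulFillBiasOrZeroBeforeSplitKProgram::WORKGROUP_SIZE_X`):

// Hypothetical output shape [1, 64, 128] in vec4, i.e. dim_a_outer = 64 and
// dim_b_outer_vec4 = 128 (512 scalar columns).
//   total_outputs_vec4 = 64 * 128 = 8192 vec4 values to initialize with bias or zero.
//   Assuming WORKGROUP_SIZE_X == 64: dispatch_x = (8192 + 64 - 1) / 64 = 128 workgroups,
//   each invocation writing one vec4.
//   The `dim_b_outer` uniform is passed as 128 * 4 = 512, i.e. in scalars, so the bounds
//   check generated by MatMulWriteFnSource() can be reused unchanged.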
