Commits (26)
46e4559
Implement Split-K on Conv|MatMul
Jiawei-Shao Oct 30, 2025
1f06b95
Address reviewer's comments
Jiawei-Shao Oct 31, 2025
3193815
Remove the check of `is_channels_last` in `UseSplitK`
Jiawei-Shao Nov 1, 2025
ecbc093
Still require `is_channels_last` to be true
Jiawei-Shao Nov 1, 2025
82d3d9b
Check the use of Split-K with ratio and enable Split-K on ACM
Jiawei-Shao Nov 4, 2025
0099edd
Fix incorrect ratio
Jiawei-Shao Nov 4, 2025
05bd1f8
Update ratio
Jiawei-Shao Nov 5, 2025
11ecdfe
Update ratio
Jiawei-Shao Nov 5, 2025
534dc2c
Compute FP16 values with MLFloat16
Jiawei-Shao Nov 6, 2025
cfd2219
Address reviewer's comments
Jiawei-Shao Nov 7, 2025
d03755b
Disallow out-of-bound write
Jiawei-Shao Nov 10, 2025
581828e
Use safer thresholds by now
Jiawei-Shao Nov 12, 2025
2ed25de
Merge branch 'main' into impl-splitk-matmul
Jiawei-Shao Nov 12, 2025
22f9017
Address more reviewer's comments
Jiawei-Shao Nov 13, 2025
418d6c0
Address comments from Copilot
Jiawei-Shao Nov 14, 2025
0ca5e65
Address more comments from Copilot
Jiawei-Shao Nov 14, 2025
7a415de
Address more comments from Copilot
Jiawei-Shao Nov 14, 2025
13b94e8
Address reviewer's comments
Jiawei-Shao Nov 14, 2025
082d1e3
Don't call `SetWorkgroupSize()` as we are using the default value
Jiawei-Shao Nov 14, 2025
6b15ede
Remove a redundant declaration
Jiawei-Shao Nov 14, 2025
04e2890
Merge branch 'main' into impl-splitk-matmul
Jiawei-Shao Nov 14, 2025
fb4c743
Address comments from Copilot
Jiawei-Shao Nov 14, 2025
4ef3ac2
Remove unused declarations
Jiawei-Shao Nov 14, 2025
fa6f226
Fix another typo
Jiawei-Shao Nov 14, 2025
ec8d47d
Merge branch 'main' into impl-splitk-matmul
Jiawei-Shao Nov 17, 2025
d010d4b
Use a higher rel_error for Linux ARM64 bots
Jiawei-Shao Nov 17, 2025
126 changes: 117 additions & 9 deletions onnxruntime/core/providers/webgpu/math/gemm_utils.cc
@@ -53,6 +53,70 @@ void HandleMaybeBiasForMatMul(ShaderHelper& shader,
<< output.SetByIndices("coords", "value") << "\n";
}

void HandleMatMulWithSplitK(
ShaderHelper& shader,
ProgramVariableDataType output_variable_type) {
shader.AdditionalImplementation() << " let coords = vec3(u32(batch), u32(row), u32(colIn));\n";

// With Split-K, the final output is the sum of the sub-outputs from multiple workgroups,
// so we must accumulate them with atomic built-in functions. Because WebGPU currently doesn't
// support atomic built-in functions on `f32` or `f16`, we emulate `atomicAdd` on `f32` and
// `f16` with `atomicLoad` and `atomicCompareExchangeWeak`:
// 1. Get `old_output_i32` from `output[offset]` with `atomicLoad`.
// 2. Convert `old_output_i32` into `f32` (`old_output_f32`) or `vec2h` (`old_output_vec2h`).
// 3. Add the incoming `value` to `old_output_f32` or `old_output_vec2h`.
// 4. Convert the result of step 3 back into an `i32` value.
// 5. Try to store the result of step 4 into `output[offset]` with `atomicCompareExchangeWeak`
//    and `old_output_i32`. The store fails if `output[offset]` is no longer equal to
//    `old_output_i32` (it was updated by another invocation); in that case go back to step 1
//    and repeat. (A host-side sketch of this retry loop is shown after this function.)
switch (output_variable_type) {
case ProgramVariableDataType::Float32x4: {
shader.AdditionalImplementation() << R"(
let offset0 = i2o_output(coords) * 4u;
for (var i = 0u; i < 4u; i++) {
let offset = offset0 + i;
while (true) {
let old_output_i32 = atomicLoad(&output[offset]);
let old_output_f32 = bitcast<f32>(old_output_i32);
let new_output_f32 = old_output_f32 + value[i];
let new_output_i32 = bitcast<i32>(new_output_f32);
let output_compexchange = atomicCompareExchangeWeak(&output[offset], old_output_i32, new_output_i32);
if (output_compexchange.old_value == old_output_i32) {
break;
}
}
}
)";
break;
}
case ProgramVariableDataType::Float16x4: {
shader.AdditionalImplementation() << R"(
let offset0 = i2o_output(coords) * 2u;
var vec2h_values : array<vec2h, 2>;
vec2h_values[0] = value.xy;
vec2h_values[1] = value.zw;
for (var i = 0u; i < 2u; i++) {
let offset = offset0 + i;
while (true) {
let old_output_i32 = atomicLoad(&output[offset]);
let old_output_vec2h = bitcast<vec2h>(old_output_i32);
let new_output_vec2h = old_output_vec2h + vec2h_values[i];
let new_output_i32 = bitcast<i32>(new_output_vec2h);
let output_compexchange = atomicCompareExchangeWeak(&output[offset], old_output_i32, new_output_i32);
if (output_compexchange.old_value == old_output_i32) {
break;
}
}
}
)";
break;
}
default:
break;
}
}
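
For readers unfamiliar with the CAS-emulation pattern, here is a minimal host-side C++ sketch of the same retry loop (illustration only, not part of this change; the `AtomicAddF32` helper name is hypothetical):

```cpp
#include <atomic>
#include <cstdint>
#include <cstring>

// Emulate a floating-point atomicAdd on a 32-bit word when only integer atomics are available,
// mirroring steps 1-5 described in the comment above.
void AtomicAddF32(std::atomic<uint32_t>& word, float value) {
  uint32_t old_bits = word.load(std::memory_order_relaxed);  // step 1: atomicLoad
  while (true) {
    float old_f32;
    std::memcpy(&old_f32, &old_bits, sizeof(float));   // step 2: bitcast to f32
    const float new_f32 = old_f32 + value;              // step 3: add the incoming value
    uint32_t new_bits;
    std::memcpy(&new_bits, &new_f32, sizeof(float));    // step 4: bitcast back to integer bits
    // Step 5: publish only if no other thread updated the word in the meantime.
    // On failure, compare_exchange_weak reloads `old_bits`, so the loop retries from step 2.
    if (word.compare_exchange_weak(old_bits, new_bits)) {
      break;
    }
  }
}
```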

} // namespace

void MatMulReadFnSource(ShaderHelper& shader,
@@ -125,7 +189,9 @@ void MatMulWriteFnSource(ShaderHelper& shader,
int output_components,
bool c_is_scalar,
std::string activation_snippet,
bool is_channels_last) {
bool is_channels_last,
bool use_split_k,
ProgramVariableDataType output_variable_type) {
shader.AdditionalImplementation()
<< "fn mm_write(batch: i32, row: i32, colIn: i32, valueIn: output_value_t) { \n";

@@ -134,7 +200,15 @@ void MatMulWriteFnSource(ShaderHelper& shader,
shader.AdditionalImplementation() << "if(row < i32(uniforms.dim_a_outer) && col < i32(uniforms.dim_b_outer)) { \n"
<< " var value = valueIn; \n";

if (is_gemm) {
if (use_split_k) {
// Set the output when MatMul is performed with Split-K.
// When Split-K is used in MatMul, the bias is handled in `MatMulFillBiasOrZeroBeforeSplitKProgram`
// instead of here, so `has_bias` and `is_channels_last` are not used for Split-K. Note that we
// still need to handle `has_bias` (and `is_channels_last` in the future) in
// `MatMulFillBiasOrZeroBeforeSplitKProgram`.
assert(!has_bias);
HandleMatMulWithSplitK(shader, output_variable_type);
} else if (is_gemm) {
HanldeMaybeHaveBiasForGEMM(shader, output, has_bias, c_components, output_components, c_is_scalar);
} else {
HandleMaybeBiasForMatMul(shader, output, has_bias, activation_snippet, is_channels_last);
@@ -159,9 +233,6 @@ Status MakeMatMulPackedVec4Source(ShaderHelper& shader,
uint32_t tile_inner,
bool split_k,
uint32_t split_dim_inner) {
ORT_UNUSED_PARAMETER(split_k);
ORT_UNUSED_PARAMETER(split_dim_inner);

const std::string type_string = MakeScalarOrVectorType(4 /*components */, data_type);

std::string write_data_to_sub_a_vec4_snippet =
@@ -208,14 +279,51 @@ Status MakeMatMulPackedVec4Source(ShaderHelper& shader,
<< " let tileCol = i32(local_id.x);\n"
<< " let globalRow = i32(global_id.y) * rowPerThread;\n"
<< " let globalCol = i32(global_id.x);\n"
<< " let batch = i32(global_id.z);\n"
<< (nullptr != batch_dims ? " let batchIndices = " + batch_dims->OffsetToIndices("u32(batch)") + ";\n" : "")
<< " let globalRowStart = i32(workgroup_id.y) * " << tile_a_outer << ";\n"
<< " let globalColStart = i32(workgroup_id.x) * " << tile_b_outer << ";\n"
<< " let num_tiles = (uniforms.dim_inner - 1) / tileInner + 1;\n"
<< " var kStart = 0;\n"
<< " var acc: array<vec4<" << data_type << ">, rowPerThread>;\n";

if (split_k) {
// With Split-K, the original workgroup grid (with `dispatch_z == 1` on the API side) is split
// into multiple workgroups along the K dimension, and the current workgroup only computes
// `kSplitK` elements starting from `kSplitK * i32(global_id.z)`.
//
// For example: consider computing Y = (X * W + B) in one workgroup.
// Let kSplitK = 2, B = [d1, d2]
// Let X = [[a1 a1 b1 b1 c1 c1]  = [ A1 B1 C1 ],   W = [[a2 a2]  = [ A2
//          [a1 a1 b1 b1 c1 c1]]                        [a2 a2]      B2
//                                                      [b2 b2]      C2 ]
//                                                      [b2 b2]
//                                                      [c2 c2]
//                                                      [c2 c2]]
//
// With Split-K:
// 1. Initialize output Y with B in `MatMulFillBiasOrZeroBeforeSplitKProgram`: Y = [[d1, d2]
//                                                                                  [d1, d2]]
// 2. Split the original single workgroup into 3 workgroups (now `dispatch_z = 3` on the API side)
//    Workgroup1: compute (A1 * A2)    Workgroup2: compute (B1 * B2)
//    Workgroup3: compute (C1 * C2)
//    In each workgroup:
//    - `num_tiles` is computed from `kSplitK`, and `kStart` is computed from `global_id.z`
//    - When the computation in a workgroup is completed, its result is added to Y with the
//      atomic built-in functions in `HandleMatMulWithSplitK()`.
//    (A numeric sketch of this split follows the `if`/`else` below.)
shader.MainFunctionBody()
<< "const kSplitK = " << split_dim_inner << ";\n"
<< " let num_tiles = (kSplitK - 1) / tileInner + 1;\n"
<< " var kStart = kSplitK * i32(global_id.z);\n"

// When Split-K is used, `batch` should always be 0 and `global_id.z` indicates the index
// of the K split instead of the batch.
<< " let batch = 0;\n"
<< " let batchIndices = 0u;\n";
} else {
shader.MainFunctionBody()
<< " let num_tiles = (uniforms.dim_inner - 1) / tileInner + 1;\n"
<< " var kStart = 0;\n"
<< " let batch = i32(global_id.z);\n"
<< (nullptr != batch_dims ? " let batchIndices = " + batch_dims->OffsetToIndices("u32(batch)") + ";\n" : "");
}
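
As a concrete illustration of the branch above (not part of this change; the sizes are made up): with `dim_inner = 768`, `kSplitK = 256`, and `tileInner = 32`, the host dispatches three z-slices and each one covers a disjoint K range:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  const uint32_t dim_inner = 768;   // K dimension of the MatMul (assumed)
  const uint32_t kSplitK = 256;     // split_dim_inner chosen by SplitKConfig (assumed)
  const uint32_t tileInner = 32;    // tile size along K (assumed)

  const uint32_t dispatch_z = (dim_inner + kSplitK - 1) / kSplitK;  // 3 z-slices
  const uint32_t num_tiles = (kSplitK - 1) / tileInner + 1;         // 8 tiles per slice

  for (uint32_t z = 0; z < dispatch_z; ++z) {
    // Each z-slice starts at its own K offset and accumulates its partial result into the
    // shared output with the atomic loop from HandleMatMulWithSplitK().
    const uint32_t kStart = kSplitK * z;  // 0, 256, 512
    std::printf("z=%u: kStart=%u, num_tiles=%u\n", z, kStart, num_tiles);
  }
  return 0;
}
```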

// Loop over shared dimension.
shader.MainFunctionBody() << " let tileRowB = localRow * " << row_per_thread_b << ";\n";

4 changes: 3 additions & 1 deletion onnxruntime/core/providers/webgpu/math/gemm_utils.h
@@ -24,7 +24,9 @@ void MatMulWriteFnSource(ShaderHelper& shader,
int output_components,
bool c_is_scalar,
std::string activation_snippet = "",
bool is_channels_last = false);
bool is_channels_last = false,
bool use_split_k = false,
ProgramVariableDataType output_variable_type = ProgramVariableDataType::Float32x4);

// The two following functions are used to generate shader code for vec4 and scalar.
// It is used in GEMM, Matmul, and Conv.
162 changes: 149 additions & 13 deletions onnxruntime/core/providers/webgpu/math/matmul.cc
@@ -46,6 +46,62 @@ static std::string CalcResult(int64_t components, int64_t a_components, int64_t
return oss.str();
}

SplitKConfig SplitKConfig::GetSplitKConfig(const ComputeContext& context) {
const wgpu::AdapterInfo& adapter_info = context.AdapterInfo();
SplitKConfig config = {};

if (adapter_info.vendor == std::string_view{"intel"}) {
if (adapter_info.architecture == std::string_view{"xe-2lpg"} ||
adapter_info.architecture == std::string_view{"xe-2hpg"} ||
adapter_info.architecture == std::string_view{"xe-lpg"} ||
adapter_info.architecture == std::string_view{"gen-12hp"}) {
config.enable_split_k_ = true;

// The thresholds below have only been verified on the above Intel GPUs (no regressions
// observed). The proper value of `max_dim_a_outer_multiplies_dim_b_outer_divides_dim_inner_`
// may need to be reduced when we support a larger `dim_inner`, because a larger `dim_inner`
// issues more atomic calls for each output value.
config.split_dim_inner_ = 256;
config.min_dim_inner_with_split_k_ = config.split_dim_inner_ * 2;
config.max_dim_inner_with_split_k_ = config.split_dim_inner_ * 9;
config.max_dim_a_outer_multiplies_dim_b_outer_divides_dim_inner_ = 35.0f;
}
}
return config;
}

bool SplitKConfig::UseSplitK(
bool is_vec4,
ActivationKind activation_kind,
uint64_t batch_size,
bool is_channels_last,
uint32_t dim_a_outer,
uint32_t dim_b_outer,
uint32_t dim_inner) const {
bool use_split_k = enable_split_k_;

// TODO: support the cases below.
use_split_k &= activation_kind == ActivationKind::None;
use_split_k &= is_vec4;
use_split_k &= batch_size == 1;
// Currently only `is_channels_last` is supported because we only generate vec4 shaders in
// `MatMulFillBiasOrZeroBeforeSplitKProgram`.
use_split_k &= is_channels_last;

// Split-K works best when `dim_inner` is relatively large compared with `dim_a_outer` and
// `dim_b_outer`. Currently we use the ratio between `(dim_a_outer * dim_b_outer)` and
// `dim_inner` as the metric to decide whether to use Split-K or not (see the numeric sketch
// after this function).
use_split_k &= (dim_inner >= min_dim_inner_with_split_k_);
use_split_k &= (dim_inner <= max_dim_inner_with_split_k_);
use_split_k &= ((dim_a_outer * dim_b_outer * 1.0f / dim_inner) <= max_dim_a_outer_multiplies_dim_b_outer_divides_dim_inner_);

return use_split_k;
}
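
To make the heuristic concrete, here is a small standalone sketch (not part of this change) that evaluates the same checks with the thresholds chosen above for the listed Intel GPUs; the helper name and shapes are hypothetical:

```cpp
#include <cstdint>
#include <cstdio>

// Same shape checks as UseSplitK(), with the Intel thresholds hard-coded:
// min_dim_inner = 256 * 2, max_dim_inner = 256 * 9, max ratio = 35.0f.
static bool RatioAllowsSplitK(uint32_t dim_a_outer, uint32_t dim_b_outer, uint32_t dim_inner) {
  const uint32_t min_dim_inner = 512;
  const uint32_t max_dim_inner = 2304;
  const float max_ratio = 35.0f;
  return dim_inner >= min_dim_inner && dim_inner <= max_dim_inner &&
         (dim_a_outer * dim_b_outer * 1.0f / dim_inner) <= max_ratio;
}

int main() {
  // Small output, large K: ratio = 64 * 64 / 1024 = 4 <= 35, so Split-K is worthwhile.
  std::printf("64x64,  K=1024 -> %d\n", RatioAllowsSplitK(64, 64, 1024));
  // Large output, same K: ratio = 512 * 512 / 1024 = 256 > 35, enough parallelism already.
  std::printf("512x512, K=1024 -> %d\n", RatioAllowsSplitK(512, 512, 1024));
  return 0;
}
```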

uint32_t SplitKConfig::GetSplitDimInner() const {
return split_dim_inner_;
}

Status MatMulNaiveProgram::GenerateShaderCode(ShaderHelper& shader) const {
const auto& a = shader.AddInput("a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias |
ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
@@ -161,14 +217,15 @@ Status MatMul::ComputeInternal(ComputeContext& context) const {
const auto* bias = context.Input(2);
inputs.push_back(bias);
}
auto program = CreateMatMulProgram(Activation(), inputs, output_tensor, false);

return context.RunProgram(program);
return ComputeMatMul(&context, Activation(), inputs, output_tensor, false);
}

MatMulProgram CreateMatMulProgram(const Activation& activation, std::vector<const Tensor*>& inputs, Tensor* output_tensor, bool is_channels_last,
const TensorShape& input_a_reshape,
const TensorShape& input_b_reshape) {
Status ComputeMatMul(ComputeContext* context,
const Activation& activation, std::vector<const Tensor*>& inputs, Tensor* output_tensor, bool is_channels_last,
SplitKConfig split_k_config,
const TensorShape& input_a_reshape,
const TensorShape& input_b_reshape) {
const auto* a = inputs[0];
const auto* b = inputs[1];
bool has_bias = inputs.size() > 2;
@@ -222,35 +279,114 @@ MatMulProgram CreateMatMulProgram(const Activation& activation, std::vector<cons
? InlinedVector<int64_t>({4, 1, 1})
: InlinedVector<int64_t>({4, 4, 1});

bool need_split_k = split_k_config.UseSplitK(is_vec4, activation.activation_kind_, batch_size, is_channels_last, dim_a_outer, dim_b_outer, dim_inner);

bool use_bias_in_matmul = has_bias;

// When Split-K is used, the bias is handled in `MatMulFillBiasOrZeroBeforeSplitKProgram`
// instead of `MatMulProgram`.
if (need_split_k) {
use_bias_in_matmul = false;
}

const uint32_t dispatch_x = narrow<uint32_t>((dim_b_outer + MatMul::MATMUL_PACKED_WORKGROUP_SIZE_X * elements_per_thread[0] - 1) /
(MatMul::MATMUL_PACKED_WORKGROUP_SIZE_X * elements_per_thread[0]));
const uint32_t dispatch_y = narrow<uint32_t>((dim_a_outer + MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Y * elements_per_thread[1] - 1) /
(MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Y * elements_per_thread[1]));
const uint32_t dispatch_z = narrow<uint32_t>((static_cast<uint32_t>(batch_size) + MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Z * elements_per_thread[2] - 1) /
(MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Z * elements_per_thread[2]));
uint32_t dispatch_z = narrow<uint32_t>((static_cast<uint32_t>(batch_size) + MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Z * elements_per_thread[2] - 1) /
(MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Z * elements_per_thread[2]));
uint32_t split_dim_inner = 1;
if (need_split_k) {
split_dim_inner = split_k_config.GetSplitDimInner();
dispatch_z = (dim_inner + split_dim_inner - 1) / split_dim_inner;
}

const int components = is_vec4 ? 4 : 1;
const TensorShape a_shape_temp = CreateMatMulIntermediateShape(outer_dims_a, dim_a_outer, dim_inner, components);
const TensorShape b_shape_temp = CreateMatMulIntermediateShape(outer_dims_b, dim_inner, dim_b_outer, components);
const TensorShape output_shape_temp = TensorShape({batch_size, dim_a_outer, dim_b_outer / components});

MatMulProgram program{activation, has_bias, is_vec4, elements_per_thread, is_channels_last};
program
.CacheHint(activation.ToString(), absl::StrJoin(elements_per_thread, "-"), std::to_string(is_vec4), components, is_channels_last)
MatMulProgram matmul_program{activation, use_bias_in_matmul, is_vec4, elements_per_thread, is_channels_last, split_dim_inner};
matmul_program
.CacheHint(activation.ToString(), absl::StrJoin(elements_per_thread, "-"), std::to_string(is_vec4), components, is_channels_last, split_dim_inner)
.AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, a_shape_temp, components},
{b, ProgramTensorMetadataDependency::TypeAndRank, b_shape_temp, components}})
.AddOutputs({{output_tensor, ProgramTensorMetadataDependency::Rank, output_shape_temp, components}})
.AddUniformVariables({{dim_a_outer}, {dim_b_outer}, {dim_inner}})
.AddIndices(outer_dims)
.SetDispatchGroupSize(dispatch_x, dispatch_y, dispatch_z)
.SetWorkgroupSize(MatMul::MATMUL_PACKED_WORKGROUP_SIZE_X, MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Y, MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Z);

if (need_split_k) {
matmul_program.AddOutputs({{output_tensor, ProgramTensorMetadataDependency::Rank, output_shape_temp, components, ProgramOutput::Atomic}});
} else {
matmul_program.AddOutputs({{output_tensor, ProgramTensorMetadataDependency::Rank, output_shape_temp, components}});
}

const Tensor* bias = nullptr;
if (has_bias) {
bias = inputs[2];
}

if (use_bias_in_matmul) {
auto bias_components = is_channels_last ? components : 1;
const auto* bias = inputs[2];
TensorShape reduced_bias_shape = ReduceShapeByComponents(bias->Shape(), bias_components);
program.AddInput({bias, ProgramTensorMetadataDependency::Rank, reduced_bias_shape, bias_components});
matmul_program.AddInput({bias, ProgramTensorMetadataDependency::Rank, reduced_bias_shape, bias_components});
}

if (need_split_k) {
// Currently we only support bias in vec4 and channels last format for Split-K MatMul.
assert(is_channels_last);

MatMulFillBiasOrZeroBeforeSplitKProgram fill_bias_program =
CreateMatMulFillBiasOrZeroBeforeSplitKProgram(bias, output_tensor, a_shape, b_shape);
ORT_RETURN_IF_ERROR(context->RunProgram(fill_bias_program));
}

return context->RunProgram(matmul_program);
}
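
For orientation, a minimal host-side sketch (illustration only, not part of this change) of the two-pass sequence `ComputeMatMul` runs when Split-K kicks in; all sizes and the fill program's workgroup size are assumptions:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  const uint32_t dim_a_outer = 64, dim_b_outer = 256, dim_inner = 1024;  // assumed shape
  const uint32_t split_dim_inner = 256;     // from SplitKConfig (assumed)
  const uint32_t fill_workgroup_size = 64;  // WORKGROUP_SIZE_X of the fill program (assumed)

  // Pass 1: MatMulFillBiasOrZeroBeforeSplitKProgram seeds the output with bias (or zero),
  // one vec4 per invocation, so the atomic adds in pass 2 accumulate onto the right base.
  const uint32_t total_vec4_outputs = dim_a_outer * (dim_b_outer / 4);
  const uint32_t fill_dispatch_x =
      (total_vec4_outputs + fill_workgroup_size - 1) / fill_workgroup_size;
  std::printf("pass 1 (fill): dispatch = (%u, 1, 1)\n", fill_dispatch_x);

  // Pass 2: the Split-K MatMulProgram; dispatch_z now indexes the K slice instead of the batch.
  const uint32_t dispatch_z = (dim_inner + split_dim_inner - 1) / split_dim_inner;
  std::printf("pass 2 (matmul): dispatch_z = %u, each z-slice covers %u K elements\n",
              dispatch_z, split_dim_inner);
  return 0;
}
```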

MatMulFillBiasOrZeroBeforeSplitKProgram CreateMatMulFillBiasOrZeroBeforeSplitKProgram(
const Tensor* bias,
Tensor* output,
const TensorShape& input_a_shape,
const TensorShape& input_b_shape) {
const bool has_bias = bias != nullptr;

// Currently we only support bias in vec4 and channels last format for Split-K MatMul.
constexpr uint32_t bias_components = 4;
MatMulFillBiasOrZeroBeforeSplitKProgram program(has_bias);

const TensorShape a_shape = input_a_shape;
const TensorShape b_shape = input_b_shape;
MatMulComputeHelper helper;
ORT_THROW_IF_ERROR(helper.Compute(a_shape, b_shape));
TensorShape output_shape = helper.OutputShape();

const uint32_t dim_a_outer = narrow<uint32_t>(a_shape[a_shape.NumDimensions() - 2]);
const uint32_t dim_b_outer = narrow<uint32_t>(b_shape[b_shape.NumDimensions() - 1]);

// Fill one value (currently only vec4) per invocation.
const uint32_t output_row = dim_a_outer;
const uint32_t output_col = dim_b_outer / bias_components;
constexpr uint32_t workgroup_size_x = MatMulFillBiasOrZeroBeforeSplitKProgram::WORKGROUP_SIZE_X;
const uint32_t total_outputs = output_row * output_col;
const uint32_t dispatch_x = (total_outputs + workgroup_size_x - 1) / workgroup_size_x;

// TODO: support batch_size > 1
constexpr uint32_t batch_size = 1;
TensorShape output_shape_temp = TensorShape({batch_size, output_row, output_col});

program.CacheHint(has_bias)
.AddOutput({output, ProgramTensorMetadataDependency::TypeAndRank, output_shape_temp, static_cast<int32_t>(bias_components)})
.AddUniformVariables({{dim_a_outer}, {dim_b_outer}})
.SetDispatchGroupSize(dispatch_x, 1, 1)
.SetWorkgroupSize(workgroup_size_x, 1, 1);

if (has_bias) {
const TensorShape reduced_bias_shape = ReduceShapeByComponents(bias->Shape(), bias_components);
program.AddInput({bias, ProgramTensorMetadataDependency::TypeAndRank, reduced_bias_shape, static_cast<int32_t>(bias_components)});
}

return program;
}
