Commits
46e4559
Implement Split-K on Conv|MatMul
Jiawei-Shao Oct 30, 2025
1f06b95
Address reviewer's comments
Jiawei-Shao Oct 31, 2025
3193815
Remove the check of `is_channels_last` in `UseSplitK`
Jiawei-Shao Nov 1, 2025
ecbc093
Still require `is_channels_last` to be true
Jiawei-Shao Nov 1, 2025
82d3d9b
Check the use of Split-K with ratio and enable Split-K on ACM
Jiawei-Shao Nov 4, 2025
0099edd
Fix incorrect ratio
Jiawei-Shao Nov 4, 2025
05bd1f8
Update ratio
Jiawei-Shao Nov 5, 2025
11ecdfe
Update ratio
Jiawei-Shao Nov 5, 2025
534dc2c
Compute FP16 values with MLFloat16
Jiawei-Shao Nov 6, 2025
cfd2219
Address reviewer's comments
Jiawei-Shao Nov 7, 2025
d03755b
Disallow out-of-bound write
Jiawei-Shao Nov 10, 2025
581828e
Use safer thresholds by now
Jiawei-Shao Nov 12, 2025
2ed25de
Merge branch 'main' into impl-splitk-matmul
Jiawei-Shao Nov 12, 2025
22f9017
Address more reviewer's comments
Jiawei-Shao Nov 13, 2025
418d6c0
Address comments from Copilot
Jiawei-Shao Nov 14, 2025
0ca5e65
Address more comments from Copilot
Jiawei-Shao Nov 14, 2025
7a415de
Address more comments from Copilot
Jiawei-Shao Nov 14, 2025
13b94e8
Address reviewer's comments
Jiawei-Shao Nov 14, 2025
082d1e3
Don't call `SetWorkgroupSize()` as we are using the default value
Jiawei-Shao Nov 14, 2025
6b15ede
Remove a redundant declaration
Jiawei-Shao Nov 14, 2025
04e2890
Merge branch 'main' into impl-splitk-matmul
Jiawei-Shao Nov 14, 2025
fb4c743
Address comments from Copilot
Jiawei-Shao Nov 14, 2025
4ef3ac2
Remove unused declarations
Jiawei-Shao Nov 14, 2025
fa6f226
Fix another typo
Jiawei-Shao Nov 14, 2025
125 changes: 116 additions & 9 deletions onnxruntime/core/providers/webgpu/math/gemm_utils.cc
@@ -53,6 +53,70 @@ void HandleMaybeBiasForMatMul(ShaderHelper& shader,
<< output.SetByIndices("coords", "value") << "\n";
}

void HandleMatMulWithSplitK(
ShaderHelper& shader,
ProgramVariableDataType output_variable_type) {
shader.AdditionalImplementation() << " let coords = vec3(u32(batch), u32(row), u32(colIn));\n";

// With Split-K, the final output will be the sum of the sub-outputs from multiple workgroups,
// so we must add them with atomic built-in functions. Because currently WebGPU doesn't support
// atomic built-in functions on `f32` or `f16`, we implement the `atomicAdd` on `f32` and `f16`
// with `atomicLoad` and `atomicCompareExchangeWeak`:
// 1. Get `old_output_u32` from `output[offset]` with `atomicLoad`.
// 2. Convert `old_output_u32` into `f32` (`old_output_f32`) or `vec2h` (`old_output_vec2h`).
// 3. Add incoming `value` into `old_output_f32` or `old_output_vec2h`.
// 4. Convert the result of step 3 into `u32` values.
// 5. Try to store the result of step 4 into `output[offset]` with `atomicCompareExchangeWeak`,
//    using `old_output_u32` as the expected value. The exchange fails if `output[offset]` no
//    longer equals `old_output_u32` (it was updated by another invocation); in that case we
//    go back to step 1 and repeat the steps above.
switch (output_variable_type) {
case ProgramVariableDataType::Float32x4: {
shader.AdditionalImplementation() << R"(
let offset0 = i2o_output(coords) * 4u;
for (var i = 0u; i < 4u; i++) {
let offset = offset0 + i;
while (true) {
let old_output_u32 = atomicLoad(&output[offset]);
let old_output_f32 = bitcast<f32>(old_output_u32);
let new_output_f32 = old_output_f32 + value[i];
let new_output_u32 = bitcast<u32>(new_output_f32);
let output_compexchange = atomicCompareExchangeWeak(&output[offset], old_output_u32, new_output_u32);
if (output_compexchange.old_value == old_output_u32) {
break;
}
}
}
)";
break;
}
case ProgramVariableDataType::Float16x4: {
shader.AdditionalImplementation() << R"(
let offset0 = i2o_output(coords) * 2u;
var vec2h_values : array<vec2h, 2>;
vec2h_values[0] = value.xy;
vec2h_values[1] = value.zw;
for (var i = 0u; i < 2u; i++) {
let offset = offset0 + i;
while (true) {
let old_output_u32 = atomicLoad(&output[offset]);
let old_output_vec2h = bitcast<vec2h>(old_output_u32);
let new_output_vec2h = old_output_vec2h + vec2h_values[i];
let new_output_u32 = bitcast<u32>(new_output_vec2h);
let output_compexchange = atomicCompareExchangeWeak(&output[offset], old_output_u32, new_output_u32);
if (output_compexchange.old_value == old_output_u32) {
break;
}
}
}
)";
break;
}
default:
break;
}
}

} // namespace

void MatMulReadFnSource(ShaderHelper& shader,
@@ -125,7 +189,9 @@ void MatMulWriteFnSource(ShaderHelper& shader,
int output_components,
bool c_is_scalar,
std::string activation_snippet,
bool is_channels_last) {
bool is_channels_last,
bool use_split_k,
ProgramVariableDataType output_variable_type) {
shader.AdditionalImplementation()
<< "fn mm_write(batch: i32, row: i32, colIn: i32, valueIn: output_value_t) { \n";

@@ -134,7 +200,14 @@ void MatMulWriteFnSource(ShaderHelper& shader,
shader.AdditionalImplementation() << "if(row < i32(uniforms.dim_a_outer) && col < i32(uniforms.dim_b_outer)) { \n"
<< " var value = valueIn; \n";

if (is_gemm) {
if (use_split_k) {
// Set output when MatMul is performed with Split-K.
// When Split-K is used in MatMul, the bias is handled in `MatMulFillBiasBeforeSplitKProgram`
// instead of here, so `has_bias` and `is_channels_last` are not used for Split-K. Note that we
// still need to handle `has_bias` and `is_channels_last` in `MatMulFillBiasBeforeSplitKProgram`.
assert(!has_bias);
HandleMatMulWithSplitK(shader, output_variable_type);
} else if (is_gemm) {
HanldeMaybeHaveBiasForGEMM(shader, output, has_bias, c_components, output_components, c_is_scalar);
} else {
HandleMaybeBiasForMatMul(shader, output, has_bias, activation_snippet, is_channels_last);
@@ -159,9 +232,6 @@ Status MakeMatMulPackedVec4Source(ShaderHelper& shader,
uint32_t tile_inner,
bool split_k,
uint32_t split_dim_inner) {
ORT_UNUSED_PARAMETER(split_k);
ORT_UNUSED_PARAMETER(split_dim_inner);

const std::string type_string = MakeScalarOrVectorType(4 /*components */, data_type);

std::string write_data_to_sub_a_vec4_snippet =
@@ -208,14 +278,51 @@
<< " let tileCol = i32(local_id.x);\n"
<< " let globalRow = i32(global_id.y) * rowPerThread;\n"
<< " let globalCol = i32(global_id.x);\n"
<< " let batch = i32(global_id.z);\n"
<< (nullptr != batch_dims ? " let batchIndices = " + batch_dims->OffsetToIndices("u32(batch)") + ";\n" : "")
<< " let globalRowStart = i32(workgroup_id.y) * " << tile_a_outer << ";\n"
<< " let globalColStart = i32(workgroup_id.x) * " << tile_b_outer << ";\n"
<< " let num_tiles = (uniforms.dim_inner - 1) / tileInner + 1;\n"
<< " var kStart = 0;\n"
<< " var acc: array<vec4<" << data_type << ">, rowPerThread>;\n";

if (split_k) {
// With Split-K, the original workgroup (with `dispatch_z == 1` on the API side) is split into
// multiple workgroups, and the current workgroup only computes `kSplitK` elements along the
// inner dimension, starting from `kSplitK * i32(global_id.z)`.
//
// For example, consider computing Y = (X * W + B) in one workgroup.
// Let kSplitK = 2, B = [d1, d2]
// Let X = [[a1 a1 b1 b1 c1 c1]  = [ A1 B1 C1 ]    W = [[a2 a2]  = [ A2
//          [a1 a1 b1 b1 c1 c1]]                        [a2 a2]      B2
//                                                      [b2 b2]      C2 ]
//                                                      [b2 b2]
//                                                      [c2 c2]
//                                                      [c2 c2]]
//
// With Split-K:
// 1. Initialize output Y with B in `MatMulFillBiasBeforeSplitKProgram`: Y = [[d1, d2]
//                                                                            [d1, d2]]
// 2. Split the original single workgroup into 3 workgroups (now `dispatch_z = 3` on the API side):
//    Workgroup1: compute (A1 * A2)    Workgroup2: compute (B1 * B2)
//    Workgroup3: compute (C1 * C2)
//    In each workgroup:
//    - `num_tiles` is computed from `kSplitK`, and `kStart` is computed from `global_id.z`.
//    - When the computation in a workgroup completes, its result is added to Y with the
//      atomic built-in functions in `HandleMatMulWithSplitK()`.
shader.MainFunctionBody()
<< "const kSplitK = " << split_dim_inner << ";\n"
<< " let num_tiles = (kSplitK - 1) / tileInner + 1;\n"
<< " var kStart = kSplitK * i32(global_id.z);\n"

// When Split-K is used, `batch` should always be 0, and `global_id.z` indicates
// the Split-K slice index instead of the batch index.
<< " let batch = 0;\n"
<< " let batchIndices = 0u;\n";
} else {
shader.MainFunctionBody()
<< " let num_tiles = (uniforms.dim_inner - 1) / tileInner + 1;\n"
<< " var kStart = 0;\n"
<< " let batch = i32(global_id.z);\n"
<< (nullptr != batch_dims ? " let batchIndices = " + batch_dims->OffsetToIndices("u32(batch)") + ";\n" : "");
}

// Loop over shared dimension.
shader.MainFunctionBody() << " let tileRowB = localRow * " << row_per_thread_b << ";\n";

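A minimal, standalone WGSL sketch of the emulated f32 atomicAdd that `HandleMatMulWithSplitK` emits above (the buffer name `partial_sums` and helper name `atomic_add_f32` are illustrative only, not part of this change; the f16 path is analogous, packing two f16 values into one u32). The sketch retries until `exchanged` is true, which also covers spurious failures of the weak exchange:

@group(0) @binding(0) var<storage, read_write> partial_sums : array<atomic<u32>>;

// Accumulate `value` into partial_sums[offset], where each u32 element holds f32 bits.
fn atomic_add_f32(offset : u32, value : f32) {
  loop {
    let old_u32 = atomicLoad(&partial_sums[offset]);
    let new_u32 = bitcast<u32>(bitcast<f32>(old_u32) + value);
    let result = atomicCompareExchangeWeak(&partial_sums[offset], old_u32, new_u32);
    if (result.exchanged) {
      // The store succeeded; no other invocation raced in between.
      break;
    }
    // Another invocation updated the value first (or the weak exchange failed spuriously); retry.
  }
}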
4 changes: 3 additions & 1 deletion onnxruntime/core/providers/webgpu/math/gemm_utils.h
@@ -24,7 +24,9 @@ void MatMulWriteFnSource(ShaderHelper& shader,
int output_components,
bool c_is_scalar,
std::string activation_snippet = "",
bool is_channels_last = false);
bool is_channels_last = false,
bool use_split_k = false,
ProgramVariableDataType output_variable_type = ProgramVariableDataType::Float32x4);

// The two following functions are used to generate shader code for vec4 and scalar.
// It is used in GEMM, Matmul, and Conv.
118 changes: 114 additions & 4 deletions onnxruntime/core/providers/webgpu/math/matmul.cc
@@ -46,6 +46,55 @@ static std::string CalcResult(int64_t components, int64_t a_components, int64_t
return oss.str();
}

SplitKConfig SplitKConfig::GetSplitKConfig(const ComputeContext& context) {
const wgpu::AdapterInfo& adapter_info = context.AdapterInfo();
SplitKConfig config = {};

if (adapter_info.vendor == std::string_view{"intel"}) {
if (adapter_info.architecture == std::string_view{"xe-lpg"} ||
adapter_info.architecture == std::string_view{"xe-2lpg"} ||
adapter_info.architecture == std::string_view{"xe-2hpg"}) {
config.enable_split_k_ = true;

// The thresholds below have only been verified on the above Intel GPUs.
config.split_dim_inner_ = 256;
config.min_dim_inner_with_split_k_ = config.split_dim_inner_ * 2 + 1;
config.max_dim_a_outer_with_split_k_ = 196;
config.max_dim_b_outer_with_split_k_ = 768;
}
}
return config;
}

bool SplitKConfig::UseSplitK(
bool is_vec4,
ActivationKind activation_kind,
uint64_t batch_size,
bool is_channels_last,
uint32_t dim_a_outer,
uint32_t dim_b_outer,
uint32_t dim_inner) const {
bool use_split_k = enable_split_k_;

// TODO: support the cases below.
use_split_k &= activation_kind == ActivationKind::None;
use_split_k &= is_vec4;
use_split_k &= batch_size == 1;
use_split_k &= is_channels_last;

// Split-K works best when `dim_inner` is large and both `dim_a_outer` and `dim_b_outer` are relatively small.
use_split_k &=
dim_a_outer <= max_dim_a_outer_with_split_k_ &&
dim_b_outer <= max_dim_b_outer_with_split_k_ &&
dim_inner >= min_dim_inner_with_split_k_;

return use_split_k;
}

uint32_t SplitKConfig::GetSplitDimInner() const {
return split_dim_inner_;
}

Status MatMulNaiveProgram::GenerateShaderCode(ShaderHelper& shader) const {
const auto& a = shader.AddInput("a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias |
ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
@@ -167,6 +216,7 @@ Status MatMul::ComputeInternal(ComputeContext& context) const {
}

MatMulProgram CreateMatMulProgram(const Activation& activation, std::vector<const Tensor*>& inputs, Tensor* output_tensor, bool is_channels_last,
SplitKConfig split_k_config,
const TensorShape& input_a_reshape,
const TensorShape& input_b_reshape) {
const auto* a = inputs[0];
Expand Down Expand Up @@ -222,21 +272,34 @@ MatMulProgram CreateMatMulProgram(const Activation& activation, std::vector<cons
? InlinedVector<int64_t>({4, 1, 1})
: InlinedVector<int64_t>({4, 4, 1});

bool need_split_k = split_k_config.UseSplitK(is_vec4, activation.activation_kind_, batch_size, is_channels_last, dim_a_outer, dim_b_outer, dim_inner);

// When Split-K is used, bias will be handled in `MatMulFillBiasBeforeSplitKProgram`
// instead of `MatMulProgram`.
if (need_split_k) {
has_bias = false;
}

const uint32_t dispatch_x = narrow<uint32_t>((dim_b_outer + MatMul::MATMUL_PACKED_WORKGROUP_SIZE_X * elements_per_thread[0] - 1) /
(MatMul::MATMUL_PACKED_WORKGROUP_SIZE_X * elements_per_thread[0]));
const uint32_t dispatch_y = narrow<uint32_t>((dim_a_outer + MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Y * elements_per_thread[1] - 1) /
(MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Y * elements_per_thread[1]));
const uint32_t dispatch_z = narrow<uint32_t>((static_cast<uint32_t>(batch_size) + MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Z * elements_per_thread[2] - 1) /
(MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Z * elements_per_thread[2]));
uint32_t dispatch_z = narrow<uint32_t>((static_cast<uint32_t>(batch_size) + MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Z * elements_per_thread[2] - 1) /
(MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Z * elements_per_thread[2]));
uint32_t split_dim_inner = 1;
if (need_split_k) {
split_dim_inner = split_k_config.GetSplitDimInner();
dispatch_z = (dim_inner + split_dim_inner - 1) / split_dim_inner;
}

const int components = is_vec4 ? 4 : 1;
const TensorShape a_shape_temp = CreateMatMulIntermediateShape(outer_dims_a, dim_a_outer, dim_inner, components);
const TensorShape b_shape_temp = CreateMatMulIntermediateShape(outer_dims_b, dim_inner, dim_b_outer, components);
const TensorShape output_shape_temp = TensorShape({batch_size, dim_a_outer, dim_b_outer / components});

MatMulProgram program{activation, has_bias, is_vec4, elements_per_thread, is_channels_last};
MatMulProgram program{activation, has_bias, is_vec4, elements_per_thread, is_channels_last, split_dim_inner};
program
.CacheHint(activation.ToString(), absl::StrJoin(elements_per_thread, "-"), std::to_string(is_vec4), components, is_channels_last)
.CacheHint(activation.ToString(), absl::StrJoin(elements_per_thread, "-"), std::to_string(is_vec4), components, is_channels_last, split_dim_inner)
.AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, a_shape_temp, components},
{b, ProgramTensorMetadataDependency::TypeAndRank, b_shape_temp, components}})
.AddOutputs({{output_tensor, ProgramTensorMetadataDependency::Rank, output_shape_temp, components}})
@@ -254,5 +317,52 @@ MatMulProgram CreateMatMulProgram(const Activation& activation, std::vector<cons
return program;
}

MatMulFillBiasBeforeSplitKProgram CreateMatMulFillBiasBeforeSplitKProgram(
const Tensor* bias,
Tensor* output,
bool is_channels_last,
const TensorShape& input_a_shape,
const TensorShape& input_b_shape) {
const bool has_bias = bias != nullptr;

constexpr uint32_t bias_components = 4;
MatMulFillBiasBeforeSplitKProgram program(has_bias, is_channels_last);

const TensorShape a_shape = input_a_shape;
const TensorShape b_shape = input_b_shape;
MatMulComputeHelper helper;
ORT_THROW_IF_ERROR(helper.Compute(a_shape, b_shape));
TensorShape output_shape = helper.OutputShape();

const uint32_t dim_a_outer = narrow<uint32_t>(a_shape[a_shape.NumDimensions() - 2]);
const uint32_t dim_b_outer = narrow<uint32_t>(b_shape[b_shape.NumDimensions() - 1]);

const uint32_t output_row = dim_a_outer;
const uint32_t output_col = dim_b_outer / bias_components;
constexpr uint32_t workgroup_size_x = MatMulFillBiasBeforeSplitKProgram::WORKGROUP_SIZE_X;
constexpr uint32_t workgroup_size_y = MatMulFillBiasBeforeSplitKProgram::WORKGROUP_SIZE_Y;
constexpr uint32_t elements_per_thread = MatMulFillBiasBeforeSplitKProgram::ELEMENTS_PER_THREAD;
const uint32_t dispatch_x = (output_col + workgroup_size_x * elements_per_thread - 1) / (workgroup_size_x * elements_per_thread);
const uint32_t dispatch_y = (output_row + workgroup_size_y - 1) / workgroup_size_y;

// TODO: support batch_size > 1
constexpr uint32_t batch_size = 1;
TensorShape output_shape_temp = TensorShape({batch_size, output_row, output_col});

const uint32_t data_type = output->GetElementType();
program.CacheHint(has_bias, is_channels_last, data_type)
.AddOutput({output, ProgramTensorMetadataDependency::Rank, output_shape_temp, static_cast<int32_t>(bias_components)})
.AddUniformVariables({{dim_a_outer}, {dim_b_outer}})
.SetDispatchGroupSize(dispatch_x, dispatch_y, 1)
.SetWorkgroupSize(workgroup_size_x, workgroup_size_y, 1);

if (has_bias) {
const TensorShape reduced_bias_shape = ReduceShapeByComponents(bias->Shape(), bias_components);
program.AddInput({bias, ProgramTensorMetadataDependency::TypeAndRank, reduced_bias_shape, static_cast<int32_t>(bias_components)});
}

return program;
}

} // namespace webgpu
} // namespace onnxruntime
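To make the Split-K sizing and dispatch arithmetic above concrete, here is a small, self-contained C++ sketch (not code from this change; the standalone function names are illustrative) that mirrors the size heuristic and the Z-dispatch computation with the Intel thresholds. For instance, `dim_inner = 1024` with `split_dim_inner = 256` yields `dispatch_z = 4`:

#include <cstdint>
#include <cstdio>

// Mirrors the size checks in SplitKConfig::UseSplitK() with the Intel thresholds above.
// The real check also requires vec4 data, no activation, batch_size == 1, and channels-last layout.
bool SizesFavorSplitK(uint32_t dim_a_outer, uint32_t dim_b_outer, uint32_t dim_inner) {
  constexpr uint32_t kSplitDimInner = 256;
  constexpr uint32_t kMinDimInner = kSplitDimInner * 2 + 1;  // 513
  return dim_a_outer <= 196 && dim_b_outer <= 768 && dim_inner >= kMinDimInner;
}

// Mirrors the ceil-divide used for dispatch_z in CreateMatMulProgram() when Split-K is enabled.
uint32_t SplitKDispatchZ(uint32_t dim_inner, uint32_t split_dim_inner) {
  return (dim_inner + split_dim_inner - 1) / split_dim_inner;
}

int main() {
  // A 196 x 768 output with a 1024-deep inner dimension qualifies and uses 4 Z workgroups.
  std::printf("use_split_k = %d, dispatch_z = %u\n",
              SizesFavorSplitK(196, 768, 1024), SplitKDispatchZ(1024, 256));
  return 0;
}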
29 changes: 29 additions & 0 deletions onnxruntime/core/providers/webgpu/math/matmul.h
@@ -14,10 +14,39 @@
namespace onnxruntime {
namespace webgpu {

class SplitKConfig {
public:
SplitKConfig() = default;

static SplitKConfig GetSplitKConfig(const ComputeContext& context);

bool UseSplitK(
bool is_vec4, ActivationKind activation_kind, uint64_t batch_size,
bool is_channels_last, uint32_t dim_a_outer,
uint32_t dim_b_outer, uint32_t dim_inner) const;

uint32_t GetSplitDimInner() const;

private:
bool enable_split_k_ = false;
uint32_t split_dim_inner_ = 0;
uint32_t min_dim_inner_with_split_k_ = 0;
uint32_t max_dim_a_outer_with_split_k_ = 0;
uint32_t max_dim_b_outer_with_split_k_ = 0;
};

MatMulProgram CreateMatMulProgram(const Activation& activation, std::vector<const Tensor*>& inputs, Tensor* output, bool is_channels_last,
SplitKConfig split_k_config = SplitKConfig(),
const TensorShape& input_a_reshape = TensorShape(),
const TensorShape& input_b_reshape = TensorShape());

MatMulFillBiasBeforeSplitKProgram CreateMatMulFillBiasBeforeSplitKProgram(
const Tensor* bias,
Tensor* output,
bool is_channels_last,
const TensorShape& input_a_shape,
const TensorShape& input_b_shape);

class MatMul final : public WebGpuKernel {
public:
MatMul(const OpKernelInfo& info) : WebGpuKernel{info} {}
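For orientation, a hedged sketch of how the declarations above are meant to fit together when Split-K is taken. The surrounding names (`context`, `inputs`, `bias`, `output_tensor`, and the shape/size values) are placeholders; the actual wiring in the Conv/MatMul kernels may differ:

// Sketch only: pre-fill the output with the bias, then run the Split-K MatMul, which
// atomically accumulates partial products on top of the pre-filled values.
SplitKConfig split_k_config = SplitKConfig::GetSplitKConfig(context);
if (split_k_config.UseSplitK(is_vec4, activation.activation_kind_, batch_size,
                             is_channels_last, dim_a_outer, dim_b_outer, dim_inner)) {
  MatMulFillBiasBeforeSplitKProgram fill_program = CreateMatMulFillBiasBeforeSplitKProgram(
      bias, output_tensor, is_channels_last, input_a_shape, input_b_shape);
  ORT_RETURN_IF_ERROR(context.RunProgram(fill_program));
}
MatMulProgram matmul_program =
    CreateMatMulProgram(activation, inputs, output_tensor, is_channels_last, split_k_config);
return context.RunProgram(matmul_program);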