Skip to content

[webgpu] Add zero points support for dp4 path #24675

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
196 changes: 163 additions & 33 deletions onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc

Large diffs are not rendered by default.

23 changes: 17 additions & 6 deletions onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,24 +20,32 @@ class DP4AMatMulQuantizeProgram final : public Program<DP4AMatMulQuantizeProgram

class DP4AMatMulNBitsProgram final : public Program<DP4AMatMulNBitsProgram> {
public:
DP4AMatMulNBitsProgram(uint32_t block_size, uint32_t nbits) : Program{"DP4AMatMulNBits"}, block_size_(block_size), nbits_(nbits) {}
DP4AMatMulNBitsProgram(uint32_t block_size, uint32_t nbits, bool has_zero_points) : Program{"DP4AMatMulNBits"},
block_size_(block_size),
nbits_(nbits),
has_zero_points_(has_zero_points) {}
Status GenerateShaderCode(ShaderHelper& sh) const override;
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
{"M", ProgramUniformVariableDataType::Uint32},
{"N", ProgramUniformVariableDataType::Uint32},
{"K", ProgramUniformVariableDataType::Uint32},
{"K8", ProgramUniformVariableDataType::Uint32},
{"K16", ProgramUniformVariableDataType::Uint32},
{"num_N_tile", ProgramUniformVariableDataType::Uint32});
{"num_N_tile", ProgramUniformVariableDataType::Uint32},
{"zero_blocks_per_col", ProgramUniformVariableDataType::Uint32});

private:
uint32_t block_size_;
uint32_t nbits_;
bool has_zero_points_;
};

class DP4AMatMulNBitsSmallMProgram final : public Program<DP4AMatMulNBitsSmallMProgram> {
public:
DP4AMatMulNBitsSmallMProgram(uint32_t tile_size, uint32_t nbits) : Program{"DP4AMatMulNBitsSmallMProgram"}, tile_size_(tile_size), nbits_(nbits) {}
DP4AMatMulNBitsSmallMProgram(uint32_t tile_size, uint32_t nbits, bool has_zero_points) : Program{"DP4AMatMulNBitsSmallMProgram"},
tile_size_(tile_size),
nbits_(nbits),
has_zero_points_(has_zero_points) {}
Status GenerateShaderCode(ShaderHelper& sh) const override;
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
{"M", ProgramUniformVariableDataType::Uint32},
Expand All @@ -46,18 +54,22 @@ class DP4AMatMulNBitsSmallMProgram final : public Program<DP4AMatMulNBitsSmallMP
{"K16", ProgramUniformVariableDataType::Uint32},
{"K32", ProgramUniformVariableDataType::Uint32},
{"block_size", ProgramUniformVariableDataType::Uint32},
{"num_N_tile", ProgramUniformVariableDataType::Uint32});
{"num_N_tile", ProgramUniformVariableDataType::Uint32},
{"zero_blocks_per_col", ProgramUniformVariableDataType::Uint32});

private:
uint32_t tile_size_;
uint32_t nbits_;
bool has_zero_points_;
};

Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales,
const Tensor* zero_points,
uint32_t M,
uint32_t N,
uint32_t K,
uint32_t block_size,
uint32_t zero_blocks_per_col,
uint32_t min_M_for_tile_optimization,
uint32_t nbits,
onnxruntime::webgpu::ComputeContext& context,
Expand All @@ -69,8 +81,7 @@ bool CanApplyDP4AMatrixMatMulNBits(onnxruntime::webgpu::ComputeContext& context,
uint32_t batch_count,
uint32_t N,
uint32_t K,
uint32_t components_k,
bool has_zero_points);
uint32_t components_k);

} // namespace webgpu
} // namespace contrib
Expand Down
57 changes: 10 additions & 47 deletions onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <string_view>

#include "contrib_ops/webgpu/quantization/matmul_nbits.h"
#include "contrib_ops/webgpu/quantization/matmul_nbits_common.h"
#include "contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.h"
#include "contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h"
#include "contrib_ops/webgpu/webgpu_contrib_kernels.h"
Expand All @@ -19,45 +20,8 @@ namespace webgpu {

namespace {

std::string ReadZeroPoint(uint32_t nbits, bool has_zero_points) {
ORT_ENFORCE(nbits == 8 || nbits == 4, "Only 4/8 bits are supported for webgpu matmulnbits");
std::stringstream ss;
if (has_zero_points) {
ss << "const elements_in_uint32 = " << (32 / nbits) << "u;\n"
<< "const bits = " << nbits << "u;\n";
ss << R"(
fn mm_read_zero(row : u32, col : u32, r_dim: u32, c_dim: u32) -> output_element_t {
if (row < r_dim && col < c_dim) {
let offset = row * c_dim + col;

// u32 holds elements_in_uint32 packed nbits.
let array_index = offset / elements_in_uint32;
let component_index = offset % elements_in_uint32;
let packed_value = zero_points[array_index];

// Extract the nbits component
let shift_amount = component_index * bits;
)";
ss << " let masked_value = (packed_value >> shift_amount) & " << (nbits == 4 ? "0xFu" : "0xFF") << ";\n";
ss << R"(
return output_element_t(masked_value);
}
return output_element_t(0);
}
)";
} else {
ss << "const default_zero_point = " << (nbits == 4 ? 8 : 128) << ";\n";
ss << R"(
fn mm_read_zero(row : u32, col : u32, r_dim: u32, c_dim: u32) -> output_element_t {
// The default zero point is 8.
return output_element_t(default_zero_point);
}
)";
}
return ss.str();
}

constexpr unsigned int kMinMForTileOptimization = 4;

} // namespace

ONNX_OPERATOR_KERNEL_EX(
Expand Down Expand Up @@ -134,7 +98,7 @@ fn dequantize_packed8xU4(packed_value : u32, zero_point : output_element_t, scal
<< " }\n"
<< " return output_element_t(0);\n"
<< "}\n"
<< ReadZeroPoint(nbits_, has_zero_points_);
<< GenerateZeroPointReadingCode(nbits_, has_zero_points_);

shader.AdditionalImplementation() << "\n"
<< "fn mm_write_y(batch : u32, row : u32, col : u32, value : output_value_t) {\n"
Expand Down Expand Up @@ -272,7 +236,7 @@ Status MatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {
}
}
)ADDNL_FN"
<< ReadZeroPoint(nbits_, has_zero_points_);
<< GenerateZeroPointReadingCode(nbits_, has_zero_points_);

shader.MainFunctionBody() << R"MAIN_FN(
let batch = workgroup_idx / (uniforms.M * uniforms.num_N_tile);
Expand Down Expand Up @@ -417,24 +381,23 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context
const uint32_t components_a = GetMaxComponents(K);
const uint32_t components_b = GetMaxComponents(blob_size_in_words);
uint32_t components = GetMaxComponents(N);
// zero_points has shape[N * CeilDiv(n_blocks_per_col * bits, 8)]. So here we need to check whether n_blocks_per_col is divisible by 8/nbits.
uint32_t zero_blocks_per_col = n_blocks_per_col % (8 / nbits) == 0 ? n_blocks_per_col : n_blocks_per_col + 1;

const bool has_zero_points = zero_points != nullptr;
// macOS - Experimental dawn support for subgroup matrix matmul on Metal.
if (M >= kMinMForTileOptimization &&
CanApplySubgroupMatrixMatMulNBits(context, accuracy_level_, block_size, batch_count, N, K, has_zero_points)) {
return ApplySubgroupMatrixMatMulNBits(a, b, scales, M, N, K, nbits, context, y);
CanApplySubgroupMatrixMatMulNBits(context, accuracy_level_, block_size, batch_count, N, K)) {
return ApplySubgroupMatrixMatMulNBits(a, b, scales, zero_points, M, N, K, nbits, zero_blocks_per_col, context, y);
}

// On FP32 only GPUs, integer math is faster than FP32 therefore always use DP4A independent of length of M.
if ((M >= kMinMForTileOptimization || y->DataType() == DataTypeImpl::GetType<float>() ||
context.AdapterInfo().vendor == std::string_view{"qualcomm"}) &&
CanApplyDP4AMatrixMatMulNBits(context, accuracy_level_, block_size, batch_count, N, K, components_a, has_zero_points)) {
return ApplyDP4AMatrixMatMulNBits(a, b, scales, M, N, K, block_size, kMinMForTileOptimization, nbits, context, y);
CanApplyDP4AMatrixMatMulNBits(context, accuracy_level_, block_size, batch_count, N, K, components_a)) {
return ApplyDP4AMatrixMatMulNBits(a, b, scales, zero_points, M, N, K, block_size, zero_blocks_per_col, kMinMForTileOptimization, nbits, context, y);
}

// zero_points has shape[N * CeilDiv(n_blocks_per_col * bits, 8)]. So here we need to check whether n_blocks_per_col is divisible by 8/nbits.
uint32_t zero_blocks_per_col = n_blocks_per_col % (8 / nbits) == 0 ? n_blocks_per_col : n_blocks_per_col + 1;

// WideTileProgram
// This program is optimized for Block32 prefill using Tile16x128.
const bool use_wide_tile_program = block_size == 32 && components_a == 4 && components_b == 4 && M >= kMinMForTileOptimization;
Expand Down
68 changes: 68 additions & 0 deletions onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_common.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include <cstdint>  // uint32_t
#include <sstream>
#include <string>

namespace onnxruntime {
namespace contrib {
namespace webgpu {

/**
 * Generates WebGPU (WGSL) shader code for reading zero points in quantized
 * matrix multiplication (MatMulNBits and its DP4A / subgroup-matrix variants).
 *
 * The generated code defines `fn mm_read_zero(row, col, r_dim, c_dim)`:
 *  - When zero points are provided, it unpacks the nbits-wide zero point for
 *    (row, col) from the u32-packed `zero_points` buffer; out-of-range
 *    coordinates yield 0.
 *  - Otherwise it returns the default zero point: the midpoint of the
 *    unsigned quantized range (8 for 4-bit, 128 for 8-bit).
 *
 * NOTE(review): the generated code indexes `zero_points[array_index]` as a
 * flat array of u32 — confirm every caller declares the zero_points shader
 * input with scalar u32 elements (a vec4-component binding would make the
 * scalar shift `packed_value >> shift_amount` invalid WGSL).
 *
 * NOTE(review): ORT_ENFORCE is used below but this header does not include
 * "core/common/common.h"; it currently relies on a transitive include from
 * its consumers — consider including it here explicitly.
 *
 * @param nbits Number of bits for quantization (4 or 8); anything else fails ORT_ENFORCE.
 * @param has_zero_points Whether zero points are provided as a shader input.
 * @param output_type Type name used for zero point values in the generated
 *                    code (default: "output_element_t").
 * @return String containing the generated WebGPU shader code.
 */
inline std::string GenerateZeroPointReadingCode(uint32_t nbits, bool has_zero_points,
                                                const std::string& output_type = "output_element_t") {
  ORT_ENFORCE(nbits == 8 || nbits == 4, "Only 4/8 bits are supported for webgpu matmulnbits");
  std::stringstream ss;

  if (has_zero_points) {
    // Each u32 of the zero_points buffer packs 32/nbits quantized values.
    ss << "const elements_in_uint32 = " << (32 / nbits) << "u;\n"
       << "const bits = " << nbits << "u;\n";
    ss << R"(
fn mm_read_zero(row : u32, col : u32, r_dim: u32, c_dim: u32) -> )"
       << output_type << R"( {
  if (row < r_dim && col < c_dim) {
    let offset = row * c_dim + col;

    // u32 holds elements_in_uint32 packed nbits.
    let array_index = offset / elements_in_uint32;
    let component_index = offset % elements_in_uint32;
    let packed_value = zero_points[array_index];

    // Extract the nbits component
    let shift_amount = component_index * bits;
)";
    // Use an explicit `u` suffix on the 8-bit mask as well, matching the
    // 4-bit case (WGSL also accepts the abstract-int literal 0xFF, which
    // would implicitly convert, but the explicit u32 literal is unambiguous).
    ss << "    let masked_value = (packed_value >> shift_amount) & " << (nbits == 4 ? "0xFu" : "0xFFu") << ";\n";
    ss << R"(
    return )"
       << output_type << R"((masked_value);
  }
  return )"
       << output_type << R"((0);
}
)";
  } else {
    // No explicit zero_points input: use the midpoint of the unsigned range
    // (8 for 4-bit, 128 for 8-bit) as the implicit zero point.
    ss << "const default_zero_point = " << (nbits == 4 ? 8 : 128) << ";\n";
    ss << R"(
fn mm_read_zero(row : u32, col : u32, r_dim: u32, c_dim: u32) -> )"
       << output_type << R"( {
  return )"
       << output_type << R"((default_zero_point);
}
)";
  }

  return ss.str();
}

}  // namespace webgpu
}  // namespace contrib
}  // namespace onnxruntime
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the MIT License.

#include "contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.h"
#include "contrib_ops/webgpu/quantization/matmul_nbits_common.h"

namespace onnxruntime {
namespace contrib {
Expand All @@ -11,6 +12,9 @@ Status SubgroupMatrixMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader
shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
shader.AddInput("input_b", ShaderUsage::UseUniform);
shader.AddInput("scales_b", ShaderUsage::UseUniform);
if (has_zero_points_) {
shader.AddInput("zero_points", ShaderUsage::UseUniform);
}
shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseElementTypeAlias);

// tile/subtile sizes and work distribution are inspired from metal shaders in llama.cpp (kernel_mul_mm)
Expand Down Expand Up @@ -41,7 +45,8 @@ Status SubgroupMatrixMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader
tile_A[row * tile_k + col + col_offset] = compute_precision(input_a[a_global*uniforms.K + k_idx + col + col_offset]);
}
}
)ADDNL_FN";
)ADDNL_FN"
<< GenerateZeroPointReadingCode(nbits_, has_zero_points_, "compute_precision");
if (nbits_ == 4) {
shader.AdditionalImplementation() << R"ADDNL_FN(
fn loadSHMB(tile_base: u32, k_idx: u32, row: u32, c_idx: u32) {
Expand All @@ -54,12 +59,13 @@ Status SubgroupMatrixMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader
// 128 threads need to load 64 x 32. 2 threads per row or 16 col per thread.
// Stored in column major fashion.
let b_idx = u32((b_global*uniforms.K + k_idx + col)/8);
let scale = compute_precision(scales_b[(b_global*uniforms.K + k_idx + col)/quantization_block_size]);
let scale = compute_precision(scales_b[(b_global*uniforms.K + k_idx + col)/quantization_block_size]);
let zero = mm_read_zero(b_global, (k_idx + col) / quantization_block_size, uniforms.N, uniforms.zero_blocks_per_col);
for (var step:u32 = 0; step < 2; step++)
{
var b_value = input_b[b_idx+step];
var b_value_lower = (vec4<compute_precision>(unpack4xU8(b_value & 0x0F0F0F0Fu)) - vec4<compute_precision>(8)) * scale;
var b_value_upper = (vec4<compute_precision>(unpack4xU8((b_value >> 4) & 0x0F0F0F0Fu)) - vec4<compute_precision>(8)) * scale;
var b_value_lower = (vec4<compute_precision>(unpack4xU8(b_value & 0x0F0F0F0Fu)) - vec4<compute_precision>(zero)) * scale;
var b_value_upper = (vec4<compute_precision>(unpack4xU8((b_value >> 4) & 0x0F0F0F0Fu)) - vec4<compute_precision>(zero)) * scale;
let tile_b_base = row * tile_k + col + step * 8;
tile_B[tile_b_base] = b_value_lower[0];
tile_B[tile_b_base + 1] = b_value_upper[0];
Expand All @@ -85,12 +91,13 @@ Status SubgroupMatrixMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader
// 128 threads need to load 64 x 32. 2 threads per row or 16 col per thread.
// Stored in column major fashion.
let b_idx = u32((b_global*uniforms.K + k_idx + col)/8);
let scale = compute_precision(scales_b[(b_global*uniforms.K + k_idx + col)/quantization_block_size]);
let scale = compute_precision(scales_b[(b_global*uniforms.K + k_idx + col)/quantization_block_size]);
let zero = mm_read_zero(b_global, (k_idx + col) / quantization_block_size, uniforms.N, uniforms.zero_blocks_per_col);
for (var step:u32 = 0; step < 2; step++)
{
var b_value = input_b[b_idx+step];
var b_value0 = (vec4<compute_precision>(unpack4xU8(b_value[0])) - vec4<compute_precision>(128)) * scale;
var b_value1 = (vec4<compute_precision>(unpack4xU8(b_value[1])) - vec4<compute_precision>(128)) * scale;
var b_value0 = (vec4<compute_precision>(unpack4xU8(b_value[0])) - vec4<compute_precision>(zero)) * scale;
var b_value1 = (vec4<compute_precision>(unpack4xU8(b_value[1])) - vec4<compute_precision>(zero)) * scale;
let tile_b_base = row * tile_k + col + step * 8;
tile_B[tile_b_base] = b_value0[0];
tile_B[tile_b_base + 1] = b_value0[1];
Expand Down Expand Up @@ -206,17 +213,20 @@ Status SubgroupMatrixMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader
}

Status ApplySubgroupMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales,
const Tensor* zero_points,
uint32_t M,
uint32_t N,
uint32_t K,
uint32_t nbits,
uint32_t zero_blocks_per_col,
onnxruntime::webgpu::ComputeContext& context,
Tensor* y) {
constexpr uint32_t kTileSizeA = 32;
constexpr uint32_t kTileSizeB = 64;
constexpr uint32_t kU32Components = 4;
TensorShape y_shape{1, M, N};
SubgroupMatrixMatMulNBitsProgram mul_program{nbits};
const bool has_zero_points = zero_points != nullptr;
SubgroupMatrixMatMulNBitsProgram mul_program{nbits, has_zero_points};
mul_program.SetWorkgroupSize(128);
mul_program.SetDispatchGroupSize(
(N + kTileSizeB - 1) / kTileSizeB,
Expand All @@ -226,9 +236,13 @@ Status ApplySubgroupMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Te
{scales, ProgramTensorMetadataDependency::TypeAndRank, 1}})
.AddUniformVariables({{static_cast<uint32_t>(M)},
{static_cast<uint32_t>(N)},
{static_cast<uint32_t>(K)}})
{static_cast<uint32_t>(K)},
{zero_blocks_per_col}})
.AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, y_shape, 1})
.CacheHint(nbits);
.CacheHint(nbits, has_zero_points);
if (has_zero_points) {
mul_program.AddInput({zero_points, ProgramTensorMetadataDependency::None, {(zero_points->Shape().Size() + 3) / 4}, 4});
}
return context.RunProgram(mul_program);
}

Expand All @@ -237,8 +251,7 @@ bool CanApplySubgroupMatrixMatMulNBits(onnxruntime::webgpu::ComputeContext& cont
uint32_t block_size,
uint32_t batch_count,
uint32_t N,
uint32_t K,
bool has_zero_points) {
uint32_t K) {
#if !defined(__wasm__)
const bool has_subgroup_matrix = context.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix);
#else
Expand All @@ -254,8 +267,7 @@ bool CanApplySubgroupMatrixMatMulNBits(onnxruntime::webgpu::ComputeContext& cont
block_size == 32 &&
batch_count == 1 &&
K % 32 == 0 &&
N % 64 == 0 &&
!has_zero_points;
N % 64 == 0;
}
} // namespace webgpu
} // namespace contrib
Expand Down
Loading
Loading