Skip to content

Commit d984a2c

Browse files
pytorchbot and morelos authored
[ET-VK][Ops] quantization op shaders and impl (pytorch#11767)
This PR was created by the merge bot to help merge the original PR into the main branch. ghstack PR number: pytorch#11369 by @ahmtox ^ Please use this as the source of truth for the PR details, comments, and reviews ghstack PR base: https://github.com/pytorch/executorch/tree/gh/ahmtox/11/base ghstack PR head: https://github.com/pytorch/executorch/tree/gh/ahmtox/11/head Merge bot PR base: https://github.com/pytorch/executorch/tree/main Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/ahmtox/11/orig @diff-train-skip-merge Co-authored-by: morelos <morelos@devvm4573.ash0.facebook.com>
1 parent a6d8440 commit d984a2c

File tree

8 files changed

+926
-9
lines changed

8 files changed

+926
-9
lines changed
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#ifndef QUANTIZE_GLSLH
#define QUANTIZE_GLSLH

/*
 * Affine-quantize a single scalar:
 *
 *   q = clamp(round(value / scale) + zero_point, quant_min, quant_max)
 *
 * The including shader is expected to define IN_T, OUT_T, quant_min and
 * quant_max before including this header.
 */
OUT_T quantize_val(IN_T value, float scale_val, int zero_point_val) {
  // Multiply by the reciprocal rather than dividing; mathematically the
  // same mapping, and cheaper on most GPUs.
  const float inv_scale = 1.0 / scale_val;
  const float rounded = round(inv_scale * float(value));

  // NOTE(review): GLSL round() leaves the 0.5 tie direction
  // implementation-defined; if bit-exact parity with ATen's
  // round-half-to-even is required, confirm whether roundEven() is needed.
  const int qvalue = clamp(zero_point_val + int(rounded), quant_min, quant_max);

  return OUT_T(qvalue);
}

#endif // QUANTIZE_GLSLH
Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#version 450 core

#define PRECISION ${PRECISION}

#define IN_T ${buffer_scalar_type(IN_DTYPE)}
#define OUT_T ${buffer_scalar_type(OUT_DTYPE)}

#define ${MODE}

${define_active_storage_type("buffer")}
${define_required_extensions(IN_DTYPE)}
${define_required_extensions(OUT_DTYPE)}

layout(std430) buffer;

#include "indexing_utils.h"

${layout_declare_tensor(B, "w", "t_out", OUT_DTYPE, "buffer")}
${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "buffer")}

$if MODE == "per_tensor":
  layout(push_constant) uniform restrict Block {
    float scale;
    int zero_point;
    int quant_min;
    int quant_max;
  };
$if MODE == "per_token":
  ${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")}
  ${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")}

  layout(push_constant) uniform restrict Block {
    int num_tokens;
    int quant_min;
    int quant_max;
  };

${layout_declare_ubo(B, "int", "out_numel")}
${layout_declare_ubo(B, "ivec4", "t_in_sizes")}
${layout_declare_ubo(B, "ivec4", "t_in_strides")}
${layout_declare_ubo(B, "ivec4", "t_out_sizes")}
${layout_declare_ubo(B, "ivec4", "t_out_strides")}

${layout_declare_spec_const(C, "int", "out_layout", "DEFAULT_LAYOUT")}
${layout_declare_spec_const(C, "int", "in_layout", "DEFAULT_LAYOUT")}

#include "quantize.glslh"

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

const lowp ivec4 out_dim_order = unhash_dim_order(out_layout);
const lowp ivec4 in_dim_order = unhash_dim_order(in_layout);

/*
 * QUANTIZATION SHADER (BUFFER STORAGE)
 *
 * Maps floating-point tensor values onto an n-bit integer range using
 * precomputed quantization parameters (scale, zero_point):
 *
 *   qvalue = clamp(round(value / scale) + zero_point, quant_min, quant_max)
 *
 * Dispatch: one thread per output element; global WG size is
 * {out_numel, 1, 1} in both modes, with a default local WG size.
 *
 * Two modes, selected at codegen time via ${MODE}:
 *
 * - per_tensor: a single (scale, zero_point) pair — delivered as push
 *   constants — applies to every element. Any tensor layout is supported
 *   through stride-based indexing plus the dim-order spec constants.
 *
 * - per_token: each token owns its own (scale, zero_point), read from the
 *   t_scale / t_zero_point buffers. A "token" is every dimension except the
 *   innermost one (e.g. a [B, S, H] tensor has B*S tokens of H elements).
 *   This path assumes a width-packed layout (packed_dim = 0), since that is
 *   how the token index is derived from the tensor coordinates.
 *
 * Worked example (scale = 0.1, zero_point = -128, int8 range):
 *   input 2.5 -> 2.5 / 0.1 = 25.0 -> round -> 25 + (-128) = -103
 *         -> clamp(-103, -128, 127) = -103
 */

#ifdef per_tensor

// Per-tensor path: uniform parameters from the push-constant block.
void quantize_per_tensor() {
  const int out_id = int(gl_GlobalInvocationID.x);
  if (out_id >= out_numel) {
    return;
  }

  // Linear output index -> tensor index -> linear index into the
  // (possibly differently strided) input buffer.
  const ivec4 tidx = bufi_to_tidx(out_id, t_out_strides, out_dim_order);
  const int in_id = tidx_to_bufi(tidx, t_in_strides);

  t_out[out_id] = quantize_val(t_in[in_id], scale, zero_point);
}

#else

// Per-token path: parameters looked up per token from buffer arrays.
void quantize_per_token() {
  const int out_id = int(gl_GlobalInvocationID.x);
  if (out_id >= out_numel) {
    return;
  }

  const ivec4 tidx = bufi_to_tidx(out_id, t_out_strides, out_dim_order);
  const int in_id = tidx_to_bufi(tidx, t_in_strides);

  const IN_T in_val = t_in[in_id];

  // Flatten every dimension except the innermost (x) into a token id.
  // A 1D tensor is a single token, so the id stays 0.
  int token = 0;
  if (t_out_sizes.w > 1) {
    // 4D tensor
    token = tidx.w * (t_out_sizes.z * t_out_sizes.y) + tidx.z * t_out_sizes.y + tidx.y;
  } else if (t_out_sizes.z > 1) {
    // 3D tensor
    token = tidx.z * t_out_sizes.y + tidx.y;
  } else if (t_out_sizes.y > 1) {
    // 2D tensor
    token = tidx.y;
  }

  // Defensive clamp so a stray coordinate can never read past the
  // parameter buffers.
  token = min(token, num_tokens - 1);

  t_out[out_id] = quantize_val(in_val, t_scale[token], t_zero_point[token]);
}

#endif

void main() {
  quantize_${MODE}();
}
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
# Codegen config for the buffer-storage quantization shader.
# Generates one SPIR-V variant per (IN_DTYPE x OUT_DTYPE) combination for
# each named shader variant below.
quantize_buffer:
  # Defaults used when a variant does not override a parameter.
  parameter_names_with_default_values:
    IN_DTYPE: float
    OUT_DTYPE: int32
    MODE: per_tensor
  # Cross-product of input (fp) and output (integer) dtypes to emit.
  generate_variant_forall:
    IN_DTYPE:
      - VALUE: half
      - VALUE: float
    OUT_DTYPE:
      - VALUE: uint8
      - VALUE: int8
      - VALUE: int32
  # One variant per quantization granularity; MODE selects the shader path
  # via the ${MODE} preprocessor define.
  shader_variants:
    - NAME: quantize_per_tensor_buffer
      MODE: per_tensor
    - NAME: quantize_per_token_buffer
      MODE: per_token

0 commit comments

Comments
 (0)