Skip to content

Commit 116a75a

Browse files
ssjia (SS-JIA)
authored and committed
[ET-VK] Add fused q8ta_relu unary operator for int8x4 tensors
Pull Request resolved: #17507 This adds a fused quantized unary operator (ReLU) that operates directly on int8x4 packed buffer tensors, avoiding the overhead of separate dequantize-relu-requantize dispatches. The implementation follows the same pattern as q8ta_binary: a single GLSL compute shader dequantizes int8x4 blocks to float, applies the unary operation, and requantizes back to int8x4 in one dispatch. The shader uses the OPERATOR macro for parameterization so additional unary ops can be added as YAML variants without new shader code. Components added: - GLSL shader (q8ta_unary.glsl) and YAML config with relu variant - C++ operator implementation (Q8taUnary.cpp/h) registering et_vk.q8ta_relu.default - Export graph fusion pattern (quantized_unary.py) that detects dequant->relu->quant sequences and replaces them with the fused op - Custom op definition (q8ta_relu in custom_ops_lib.py) for the export pipeline - Test harness (TestQ8taUnary.cpp, test_q8ta_unary.cpp) with reference implementation and coverage across multiple shapes and quantized layouts This diff was authored with Claude. ghstack-source-id: 342806073 @exported-using-ghexport Differential Revision: [D93511629](https://our.internmc.facebook.com/intern/diff/D93511629/)
1 parent ca20b0e commit 116a75a

File tree

11 files changed

+800
-1
lines changed

11 files changed

+800
-1
lines changed

backends/vulkan/custom_ops_lib.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -616,6 +616,41 @@ def q8ta_add_impl(
616616
lib.impl(name, q8ta_add_impl, "CompositeExplicitAutograd")
617617
q8ta_add_op = getattr(getattr(torch.ops, namespace), name)
618618

619+
########################
620+
## q8ta_relu ##
621+
########################
622+
623+
624+
def q8ta_relu_impl(
    input: torch.Tensor,
    input_scale: float,
    input_zero_point: int,
    output_scale: float,
    output_zero_point: int,
):
    """Reference (CPU) implementation of the fused quantized ReLU.

    Dequantizes the int8 input tensor to float, applies ReLU, and
    requantizes the result using the output quantization parameters.
    Mirrors what the Vulkan q8ta_unary shader computes in one dispatch.
    """
    # Move into the float domain using the input quantization params.
    dequantized = torch.ops.quantized_decomposed.dequantize_per_tensor(
        input, input_scale, input_zero_point, -128, 127, input.dtype
    )

    # Apply the unary op, then quantize straight back to int8 with the
    # output quantization params.
    activated = torch.nn.functional.relu(dequantized)
    return torch.ops.quantized_decomposed.quantize_per_tensor(
        activated, output_scale, output_zero_point, -128, 127, torch.int8
    )


name = "q8ta_relu"
lib.define(
    f"{name}(Tensor input, float input_scale, int input_zero_point, float output_scale, int output_zero_point) -> Tensor"
)
lib.impl(name, q8ta_relu_impl, "CompositeExplicitAutograd")
q8ta_relu_op = getattr(getattr(torch.ops, namespace), name)
653+
619654
#############################
620655
## select_as_symint ##
621656
#############################

backends/vulkan/op_registry.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -514,7 +514,19 @@ def register_q8ta_add():
514514

515515

516516
# =============================================================================
517-
# Reduce.cpp
517+
# Q8taUnary.cpp
518+
# =============================================================================
519+
520+
521+
# Registers operator features for the fused quantized ReLU so the Vulkan
# partitioner/lowering knows how to place and dispatch it.
@update_features(exir_ops.edge.et_vk.q8ta_relu.default)
def register_q8ta_relu():
    return OpFeatures(
        # The op operates directly on packed int8x4 buffer tensors.
        inputs_storage=utils.PACKED_INT8_BUFFER,
        supports_resize=True,
    )
527+
528+
529+
# =============================================================================
518530
# =============================================================================
519531

520532

backends/vulkan/patterns/BUCK

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ fbcode_target(_kind = runtime.python_library,
1313
"quantized_linear.py",
1414
"quantized_convolution.py",
1515
"quantized_binary.py",
16+
"quantized_unary.py",
1617
"sdpa.py",
1718
"select_as_symint.py",
1819
],

backends/vulkan/patterns/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212

1313
import executorch.backends.vulkan.patterns.quantized_linear # noqa
1414

15+
import executorch.backends.vulkan.patterns.quantized_unary # noqa
16+
1517
import executorch.backends.vulkan.patterns.rope # noqa
1618

1719
import executorch.backends.vulkan.patterns.sdpa # noqa
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from typing import Optional
8+
9+
import executorch.backends.vulkan.utils as utils
10+
11+
import torch
12+
13+
from executorch.backends.vulkan.patterns.pattern_registry import (
14+
PatternMatch,
15+
register_pattern_detector,
16+
register_pattern_replacement,
17+
)
18+
19+
from executorch.exir import ExportedProgram
20+
from executorch.exir.dialects._ops import ops as exir_ops
21+
22+
23+
class QuantizedUnaryMatch(PatternMatch):
    """Describes a dequantize -> unary op -> quantize subgraph.

    Matching is anchored at the unary node (e.g. relu). `match_found` is
    True only when the complete pattern is present; the node attributes
    below are assigned incrementally as matching proceeds, so callers must
    check `match_found` before relying on any of them.
    """

    def __init__(self, unary_node: torch.fx.Node) -> None:
        self.anchor_node = unary_node
        self.match_found = False
        # All graph nodes that belong to this pattern; consumers use this
        # to know which nodes the fused replacement subsumes.
        self.all_nodes = [self.anchor_node]

        # The unary op takes a single input which must be a dequantize node
        if len(unary_node.args) < 1:
            return

        input_node = unary_node.args[0]
        assert isinstance(input_node, torch.fx.Node)

        if not utils.is_dequant_node(input_node):
            return

        self.dequantize_input_node = input_node

        # Extract quantization parameters for the input
        # (dequantize args: quantized tensor, scale, zero point, ...)
        self.quantize_input_node = self.dequantize_input_node.args[0]
        self.input_scales_node = self.dequantize_input_node.args[1]
        self.input_zeros_node = self.dequantize_input_node.args[2]

        self.all_nodes.append(self.dequantize_input_node)

        # The unary op output must have exactly one user: a quantize node
        self.output_node = self.anchor_node

        if len(self.output_node.users) != 1:
            return

        cur_node = list(self.output_node.users)[0]

        if not utils.is_quant_node(cur_node):
            return

        self.quantize_output_node = cur_node
        # quantize args: float tensor, scale, zero point, ...
        self.output_scales_node = self.quantize_output_node.args[1]
        self.output_zeros_node = self.quantize_output_node.args[2]

        self.all_nodes.append(self.quantize_output_node)

        self.match_found = True
66+
67+
68+
# Unary operation anchor nodes that we support.
unary_anchor_nodes = {
    exir_ops.edge.aten.relu.default,
}


@register_pattern_detector("quantized_unary")
def find_quantized_unary_patterns(
    node: torch.fx.Node,
) -> Optional[QuantizedUnaryMatch]:
    """Attempt to match a dequant -> unary -> quant pattern anchored at `node`.

    Returns the match object when the full pattern is present, else None.
    """
    if node.target not in unary_anchor_nodes:
        return None

    candidate = QuantizedUnaryMatch(node)
    return candidate if candidate.match_found else None
86+
87+
88+
##
89+
## Pattern Replacement
90+
##
91+
92+
93+
@register_pattern_replacement("quantized_unary")
def make_q8ta_unary_custom_op(
    ep: ExportedProgram,
    graph_module: torch.fx.GraphModule,
    match: QuantizedUnaryMatch,
):
    """Replace a matched dequant -> unary -> quant sequence with the fused
    et_vk quantized unary custom op.
    """
    # Resolve the fused op for the matched anchor; only relu is handled so
    # far, so anything else is a programming error in the detector.
    if match.anchor_node.target != exir_ops.edge.aten.relu.default:
        raise NotImplementedError(
            f"Unsupported unary operation: {match.anchor_node.target}"
        )
    op_target = exir_ops.edge.et_vk.q8ta_relu.default

    # The fused node consumes the still-quantized input tensor together
    # with both the input and output quantization parameters.
    with graph_module.graph.inserting_before(match.output_node):
        fused_node = graph_module.graph.create_node(
            "call_function",
            op_target,
            args=(
                match.quantize_input_node,
                match.input_scales_node,
                match.input_zeros_node,
                match.output_scales_node,
                match.output_zeros_node,
            ),
        )

    # Carry over the FakeTensor metadata, then route every user of the
    # trailing quantize node to the fused op.
    fused_node.meta["val"] = match.output_node.meta["val"]
    match.quantize_output_node.replace_all_uses_with(fused_node)
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#version 450 core
10+
11+
#define PRECISION ${PRECISION}
12+
13+
${define_active_storage_type("buffer")}
14+
15+
#define op(X) ${OPERATOR}
16+
17+
layout(std430) buffer;
18+
19+
#include "indexing.glslh"
20+
#include "common.glslh"
21+
#include "block_indexing.glslh"
22+
#include "block_int8x4_load.glslh"
23+
#include "block_int8x4_store.glslh"
24+
25+
// Output buffer: packed int8x4 values
26+
${layout_declare_tensor(B, "w", "t_out", "int", "buffer")}
27+
// Input buffer: packed int8x4 values
28+
${layout_declare_tensor(B, "r", "t_in", "int", "buffer")}
29+
30+
// Metadata for output and input tensors
31+
${layout_declare_ubo(B, "BufferMetadata", "out_meta")}
32+
${layout_declare_ubo(B, "BufferMetadata", "in_meta")}
33+
34+
layout(push_constant) uniform restrict Block {
35+
float input_scale;
36+
int input_zp;
37+
float output_inv_scale;
38+
int output_zp;
39+
};
40+
41+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
42+
43+
${layout_declare_spec_const(C, "int", "out_layout", "CONTIG_LAYOUT_INT")}
44+
${layout_declare_spec_const(C, "int", "in_layout", "CONTIG_LAYOUT_INT")}
45+
${layout_declare_spec_const(C, "int", "block_config", "0")}
46+
47+
// Generate loading functions for input buffer
48+
define_load_int8x4_buffer_fns(t_in)
49+
50+
// Generate storing functions for output buffer
51+
define_store_int8x4_buffer_fns(t_out)
52+
53+
void main() {
54+
// Buffer storage: use linear dispatch
55+
const uint contig_block_idx = gl_GlobalInvocationID.x;
56+
TensorIndex4D tidx = contiguous_block_idx_to_tensor4d_idx_with_block_config(
57+
out_meta, contig_block_idx, block_config);
58+
59+
if (out_of_bounds(tidx, out_meta)) {
60+
return;
61+
}
62+
63+
const int block_outer_dim = get_block_outer_dim(block_config);
64+
65+
// Load int8x4 block from input
66+
ivec4 in_block = load_int8x4_block_from_t_in(
67+
in_meta, tidx, in_layout, block_outer_dim);
68+
69+
ivec4 out_block;
70+
71+
for (int row = 0; row < 4; row++) {
72+
vec4 in_texel = unpack_and_dequantize(
73+
in_block[row], input_scale, input_zp);
74+
75+
vec4 out_texel = op(in_texel);
76+
out_block[row] = quantize_and_pack(out_texel, output_inv_scale, output_zp);
77+
}
78+
79+
// Store to output buffer
80+
store_int8x4_block_to_t_out(
81+
out_meta, tidx, out_layout, block_outer_dim, out_block);
82+
}
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
q8ta_unary:
8+
parameter_names_with_default_values:
9+
OPERATOR: X
10+
shader_variants:
11+
- NAME: q8ta_relu_buffer
12+
OPERATOR: max(X, vec4(0.0))

0 commit comments

Comments
 (0)