Skip to content

Commit 68230f4

Browse files
authored
[ET-VK][qconv] Fix depthwise weight_sums sum dimension
Differential Revision: D93511635. Pull Request resolved: #17504.
1 parent b80e3da commit 68230f4

File tree

1 file changed

+11
-3
lines changed

1 file changed

+11
-3
lines changed

backends/vulkan/patterns/quantized_convolution.py

Lines changed: 11 additions & 3 deletions
Original file line number · Diff line number · Diff line change
@@ -215,9 +215,17 @@ def make_q8ta_conv2d_custom_op(
215215
with graph_module.graph.inserting_before(first_graph_node):
216216
qweight_tensor_name = utils.get_tensor_name(ep, match.weight_node)
217217
# Pre-compute the weight sums which are needed to apply activation zero point
218-
# when using integer accumulation. For the reshaped 2D weight matrix (IC_per_group * H * W, OC),
219-
# sum over dimension 0 to get sums per output channel
220-
sum_per_output_channel = weight_tensor.sum(dim=1).to(torch.int32).contiguous()
218+
# when using integer accumulation. Sum all weight elements per output channel.
219+
if is_depthwise_conv:
220+
# weight_tensor shape is (H, W, OC); sum over spatial dims (H, W)
221+
sum_per_output_channel = (
222+
weight_tensor.sum(dim=(0, 1)).to(torch.int32).contiguous()
223+
)
224+
else:
225+
# weight_tensor shape is (OC, H*W*IC_per_group); sum over dim 1
226+
sum_per_output_channel = (
227+
weight_tensor.sum(dim=1).to(torch.int32).contiguous()
228+
)
221229
sums_name = qweight_tensor_name + "_sums"
222230
# Sanitize the name
223231
sums_name = sums_name.replace(".", "_")

0 commit comments

Comments (0)