Commit 8e118f5

ssjia authored and SS-JIA committed
[ET-VK][ez][qconv] Add auto-selection to prefer im2col for q8ta_conv2d
Pull Request resolved: #17568

The q8ta_conv2d operator previously always delegated to the general (sliding window) implementation, even though the im2col implementation is 2-5x faster for non-grouped convolutions with in_channels % 4 == 0. This change adds runtime auto-selection logic that checks the groups parameter, the input channel alignment, and a size heuristic (at least 64 input channels, or an output spatial area of at most 4096 elements). It dispatches to q8ta_conv2d_im2col when those conditions are met and falls back to q8ta_conv2d_general otherwise.

On ResNet50 int8, this reduces Vulkan inference latency from 14.2ms to 6.8ms (2.1x speedup) on Samsung Galaxy S24, making it 30% faster than XNNPACK (9.7ms).

Also adds performance test cases for deep-channel small-spatial scenarios (512ch 7x7, 1024→2048ch 1x1 stride-2) that stress-test the optimization.

ghstack-source-id: 343460520
@exported-using-ghexport

Differential Revision: [D93768637](https://our.internmc.facebook.com/intern/diff/D93768637/)
1 parent 4a06a4f commit 8e118f5
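The commit message contrasts the sliding-window and im2col strategies. As a rough CPU-side illustration of why the two are interchangeable, the sketch below computes the same int8 convolution both ways: once by walking each input window directly, and once by unrolling the windows into a matrix and running a plain GEMM over it. Everything here (`ConvShape`, `conv_direct`, `im2col`, `conv_im2col`) is hypothetical and is not the Vulkan shader code; it assumes a single image, a square kernel, dilation 1, groups == 1, and it ignores zero points and requantization.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical shape descriptor: one image, square kernel, dilation 1, groups 1.
struct ConvShape {
  int64_t c_in, h, w, c_out, k, stride, pad;
};

inline int64_t out_dim(int64_t in, const ConvShape& s) {
  assert(s.stride > 0);
  return (in + 2 * s.pad - s.k) / s.stride + 1;
}

// Returns x[ic][iy][ix], or 0 when (iy, ix) falls in the padding region.
inline int8_t at_or_zero(const std::vector<int8_t>& x, const ConvShape& s,
                         int64_t ic, int64_t iy, int64_t ix) {
  if (iy < 0 || iy >= s.h || ix < 0 || ix >= s.w) return 0;
  return x[(ic * s.h + iy) * s.w + ix];
}

// Sliding window: every output element re-walks its own input window.
inline std::vector<int32_t> conv_direct(const std::vector<int8_t>& x,
                                        const std::vector<int8_t>& wgt,
                                        const ConvShape& s) {
  const int64_t ho = out_dim(s.h, s), wo = out_dim(s.w, s);
  std::vector<int32_t> y(s.c_out * ho * wo, 0);
  for (int64_t oc = 0; oc < s.c_out; ++oc)
    for (int64_t oy = 0; oy < ho; ++oy)
      for (int64_t ox = 0; ox < wo; ++ox) {
        int32_t acc = 0;
        for (int64_t ic = 0; ic < s.c_in; ++ic)
          for (int64_t ky = 0; ky < s.k; ++ky)
            for (int64_t kx = 0; kx < s.k; ++kx) {
              const int8_t v = at_or_zero(x, s, ic, oy * s.stride - s.pad + ky,
                                          ox * s.stride - s.pad + kx);
              const int8_t f = wgt[((oc * s.c_in + ic) * s.k + ky) * s.k + kx];
              acc += int32_t(v) * int32_t(f);
            }
        y[(oc * ho + oy) * wo + ox] = acc;
      }
  return y;
}

// im2col: unroll each window into one row of length K = c_in * k * k.
inline std::vector<int8_t> im2col(const std::vector<int8_t>& x,
                                  const ConvShape& s) {
  const int64_t ho = out_dim(s.h, s), wo = out_dim(s.w, s);
  const int64_t K = s.c_in * s.k * s.k;
  std::vector<int8_t> m(ho * wo * K, 0);
  for (int64_t oy = 0; oy < ho; ++oy)
    for (int64_t ox = 0; ox < wo; ++ox)
      for (int64_t ic = 0; ic < s.c_in; ++ic)
        for (int64_t ky = 0; ky < s.k; ++ky)
          for (int64_t kx = 0; kx < s.k; ++kx)
            m[(oy * wo + ox) * K + (ic * s.k + ky) * s.k + kx] = at_or_zero(
                x, s, ic, oy * s.stride - s.pad + ky, ox * s.stride - s.pad + kx);
  return m;
}

// GEMM over the unrolled matrix: y[oc][p] = dot(weight row oc, window row p).
inline std::vector<int32_t> conv_im2col(const std::vector<int8_t>& x,
                                        const std::vector<int8_t>& wgt,
                                        const ConvShape& s) {
  const int64_t P = out_dim(s.h, s) * out_dim(s.w, s);
  const int64_t K = s.c_in * s.k * s.k;
  const std::vector<int8_t> m = im2col(x, s);
  std::vector<int32_t> y(s.c_out * P, 0);
  for (int64_t oc = 0; oc < s.c_out; ++oc)
    for (int64_t p = 0; p < P; ++p)
      for (int64_t j = 0; j < K; ++j)
        y[oc * P + p] += int32_t(wgt[oc * K + j]) * int32_t(m[p * K + j]);
  return y;
}
```

The unrolled matrix costs extra memory (one row of `c_in * k * k` values per output pixel), which is the trade-off the new heuristic weighs against the regular access pattern of the inner GEMM loop.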

2 files changed, +41 -2 lines changed

backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.cpp

Lines changed: 24 additions & 1 deletion
@@ -417,7 +417,30 @@ void q8ta_conv2d_general(
 }

 void q8ta_conv2d(ComputeGraph& graph, const std::vector<ValueRef>& args) {
-  q8ta_conv2d_general(graph, args);
+  const ValueRef input = args.at(0);
+  const ValueRef groups_ref = args.at(13);
+  const ValueRef output = args.at(15);
+
+  const int64_t groups = graph.extract_scalar<int64_t>(groups_ref);
+  const int64_t in_channels = graph.size_at<int64_t>(-3, input);
+  const int64_t in_channels_per_group = in_channels / groups;
+
+  const int64_t H_out = graph.size_at<int64_t>(-2, output);
+  const int64_t W_out = graph.size_at<int64_t>(-1, output);
+  const int64_t spatial_out = H_out * W_out;
+
+  // Use im2col when the channel depth is sufficient for tiled GEMM to win, or
+  // when the output spatial area is small enough that the im2col buffer stays
+  // manageable. For large spatial outputs with few channels, the im2col buffer
+  // becomes too large and the general shader is more efficient.
+  const bool use_im2col = groups == 1 && in_channels_per_group % 4 == 0 &&
+      (in_channels_per_group >= 64 || spatial_out <= 4096);
+
+  if (use_im2col) {
+    q8ta_conv2d_im2col(graph, args);
+  } else {
+    q8ta_conv2d_general(graph, args);
+  }
 }

 REGISTER_OPERATORS {
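
To see how the dispatch condition in this hunk behaves on concrete shapes, it can be restated as a free function over plain integers. This is only a sketch: the name `prefers_im2col` and its parameter list are hypothetical, and the real code reads these values from the `ComputeGraph` rather than taking them as arguments.

```cpp
#include <cassert>
#include <cstdint>

// Standalone restatement of the use_im2col condition from the diff above.
// Hypothetical helper: the operator itself extracts these values from the graph.
inline bool prefers_im2col(
    int64_t groups,
    int64_t in_channels,
    int64_t h_out,
    int64_t w_out) {
  assert(groups > 0);
  const int64_t in_channels_per_group = in_channels / groups;
  const int64_t spatial_out = h_out * w_out;
  // Non-grouped, 4-aligned channels, and either deep channels or a small output.
  return groups == 1 && in_channels_per_group % 4 == 0 &&
      (in_channels_per_group >= 64 || spatial_out <= 4096);
}
```

Under this restatement, `prefers_im2col(1, 512, 7, 7)` is true (deep channels), `prefers_im2col(1, 32, 64, 64)` is true (32 channels but exactly 4096 output elements), `prefers_im2col(1, 32, 128, 128)` is false (shallow channels and a large output), and `prefers_im2col(1, 3, 112, 112)` is false because 3 is not a multiple of 4.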

backends/vulkan/test/custom_ops/test_q8ta_conv2d.cpp

Lines changed: 17 additions & 1 deletion
@@ -378,7 +378,23 @@ static std::vector<TestCase> generate_quantized_conv2d_test_cases() {
        Stride(2, 2),
        Padding(2, 2),
        Dilation(1, 1),
-       4}};
+       4},
+      // Deep channels + small spatial (ResNet50 stage 5 bottleneck)
+      {OutInChannels(512, 512),
+       InputSize2D(7, 7),
+       KernelSize(3, 3),
+       Stride(1, 1),
+       Padding(1, 1),
+       Dilation(1, 1),
+       1},
+      // Strided 1x1 shortcut (worst-case strided downsample)
+      {OutInChannels(2048, 1024),
+       InputSize2D(14, 14),
+       KernelSize(1, 1),
+       Stride(2, 2),
+       Padding(0, 0),
+       Dilation(1, 1),
+       1}};

   // Test with different storage types and memory layouts
   std::vector<utils::StorageType> fp_storage_types = {utils::kTexture3D};
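
Both new test cases produce a 7x7 output, which is why they land on the im2col path. The sketch below works that out with the standard convolution output-size formula and also estimates how many int8 elements an unrolled im2col matrix would hold. The helper names are hypothetical, and the size estimate assumes a fully materialized matrix with one row per output pixel, which may not match how the Vulkan implementation actually lays out its buffer.

```cpp
#include <cassert>
#include <cstdint>

// Standard convolution output size along one spatial dimension.
inline int64_t conv_out_size(
    int64_t in,
    int64_t kernel,
    int64_t stride,
    int64_t padding,
    int64_t dilation) {
  assert(stride > 0);
  return (in + 2 * padding - dilation * (kernel - 1) - 1) / stride + 1;
}

// Element count of a fully materialized im2col matrix (an assumption, see
// above): one row per output pixel, one column per (channel, ky, kx) tap.
inline int64_t im2col_elements(
    int64_t in_channels,
    int64_t kernel_h,
    int64_t kernel_w,
    int64_t h_out,
    int64_t w_out) {
  return h_out * w_out * in_channels * kernel_h * kernel_w;
}
```

For the first case, `conv_out_size(7, 3, 1, 1, 1)` is 7 and `im2col_elements(512, 3, 3, 7, 7)` is 225792. For the strided 1x1 case, `conv_out_size(14, 1, 2, 0, 1)` is 7 and `im2col_elements(1024, 1, 1, 7, 7)` is 50176.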
