Skip to content

Commit f44bd46

Browse files
committed
Optimize dot4{I, U}8Packed for all spv versions
Emit optimized code for `dot4{I, U}8Packed` regardless of SPIR-V version as long as the required capabilities are available. On SPIR-V < 1.6, require the extension "SPV_KHR_integer_dot_product" for this. On SPIR-V >= 1.6, don't require the extension because the corresponding capabilities are part of SPIR-V >= 1.6 proper.
1 parent be9debd commit f44bd46

11 files changed

+122
-108
lines changed

naga/src/back/spv/block.rs

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1143,17 +1143,21 @@ impl BlockContext<'_> {
11431143
),
11441144
},
11451145
fun @ (Mf::Dot4I8Packed | Mf::Dot4U8Packed) => {
1146-
if self.writer.lang_version() >= (1, 6)
1147-
&& self
1148-
.writer
1149-
.require_all(&[
1150-
spirv::Capability::DotProduct,
1151-
spirv::Capability::DotProductInput4x8BitPacked,
1152-
])
1153-
.is_ok()
1146+
if self
1147+
.writer
1148+
.require_all(&[
1149+
spirv::Capability::DotProduct,
1150+
spirv::Capability::DotProductInput4x8BitPacked,
1151+
])
1152+
.is_ok()
11541153
{
11551154
// Write optimized code using `PackedVectorFormat4x8Bit`.
1156-
self.writer.use_extension("SPV_KHR_integer_dot_product");
1155+
if self.writer.lang_version() < (1, 6) {
1156+
// SPIR-V 1.6 supports the required capabilities natively, so the extension
1157+
// is only required for earlier versions. See right column of
1158+
// <https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpSDot>.
1159+
self.writer.use_extension("SPV_KHR_integer_dot_product");
1160+
}
11571161

11581162
let op = match fun {
11591163
Mf::Dot4I8Packed => spirv::Op::SDot,
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# Turn on optimizations for `dot4I8Packed` and `dot4U8Packed` by enabling the
2+
# required capabilities on a SPIR-V version where these capabilities are only
3+
# available via the extension "SPV_KHR_integer_dot_product".
4+
5+
targets = "SPIRV"
6+
7+
[spv]
8+
capabilities = ["DotProduct", "DotProductInput4x8BitPacked"]
9+
version = [1, 0]
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
# Turn on optimizations for `dot4I8Packed` and `dot4U8Packed` on SPIR-V and HLSL by
2+
# using a version of SPIR-V / shader model that supports these without any extensions.
3+
4+
targets = "SPIRV | HLSL"
5+
6+
[spv]
7+
# We also need to provide the corresponding capabilities (which are part of SPIR-V >= 1.6).
8+
capabilities = ["DotProduct", "DotProductInput4x8BitPacked"]
9+
version = [1, 6]
10+
11+
[hlsl]
12+
shader_model = "V6_4"
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
fn test_packed_integer_dot_product() -> u32 {
2+
let a_5 = 1u;
3+
let b_5 = 2u;
4+
let c_5: i32 = dot4I8Packed(a_5, b_5);
5+
6+
let a_6 = 3u;
7+
let b_6 = 4u;
8+
let c_6: u32 = dot4U8Packed(a_6, b_6);
9+
10+
// test baking of arguments
11+
let c_7: i32 = dot4I8Packed(5u + c_6, 6u + c_6);
12+
let c_8: u32 = dot4U8Packed(7u + c_6, 8u + c_6);
13+
return c_8;
14+
}
15+
16+
@compute @workgroup_size(1)
17+
fn main() {
18+
let c = test_packed_integer_dot_product();
19+
}

naga/tests/in/wgsl/functions-optimized.toml

Lines changed: 0 additions & 11 deletions
This file was deleted.

naga/tests/out/spv/wgsl-functions-optimized.spvasm renamed to naga/tests/out/spv/wgsl-functions-optimized-by-capability.spvasm

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
; SPIR-V
2-
; Version: 1.6
2+
; Version: 1.0
33
; Generator: rspirv
44
; Bound: 30
55
OpCapability Shader
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
; SPIR-V
2+
; Version: 1.6
3+
; Generator: rspirv
4+
; Bound: 30
5+
OpCapability Shader
6+
OpCapability DotProductKHR
7+
OpCapability DotProductInput4x8BitPackedKHR
8+
%1 = OpExtInstImport "GLSL.std.450"
9+
OpMemoryModel Logical GLSL450
10+
OpEntryPoint GLCompute %26 "main"
11+
OpExecutionMode %26 LocalSize 1 1 1
12+
%2 = OpTypeVoid
13+
%3 = OpTypeInt 32 0
14+
%6 = OpTypeFunction %3
15+
%7 = OpConstant %3 1
16+
%8 = OpConstant %3 2
17+
%9 = OpConstant %3 3
18+
%10 = OpConstant %3 4
19+
%11 = OpConstant %3 5
20+
%12 = OpConstant %3 6
21+
%13 = OpConstant %3 7
22+
%14 = OpConstant %3 8
23+
%16 = OpTypeInt 32 1
24+
%27 = OpTypeFunction %2
25+
%5 = OpFunction %3 None %6
26+
%4 = OpLabel
27+
OpBranch %15
28+
%15 = OpLabel
29+
%17 = OpSDotKHR %16 %7 %8 PackedVectorFormat4x8BitKHR
30+
%18 = OpUDotKHR %3 %9 %10 PackedVectorFormat4x8BitKHR
31+
%19 = OpIAdd %3 %11 %18
32+
%20 = OpIAdd %3 %12 %18
33+
%21 = OpSDotKHR %16 %19 %20 PackedVectorFormat4x8BitKHR
34+
%22 = OpIAdd %3 %13 %18
35+
%23 = OpIAdd %3 %14 %18
36+
%24 = OpUDotKHR %3 %22 %23 PackedVectorFormat4x8BitKHR
37+
OpReturnValue %24
38+
OpFunctionEnd
39+
%26 = OpFunction %2 None %27
40+
%25 = OpLabel
41+
OpBranch %28
42+
%28 = OpLabel
43+
%29 = OpFunctionCall %3 %5
44+
OpReturn
45+
OpFunctionEnd

0 commit comments

Comments
 (0)