Skip to content

Commit 7896216

Browse files
Merge pull request #874 from Devsh-Graphics-Programming/quick-fix-subgroup-arithmetic
Quick fix to subgroup arithmetic
2 parents: cc3bfe1 + 727f9a6; commit: 7896216

File tree

2 files changed

+62 / -14 lines changed

2 files changed

+62 / -14 lines changed

include/nbl/builtin/hlsl/glsl_compat/subgroup_arithmetic.hlsl

Lines changed: 60 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -81,29 +81,77 @@ T subgroupExclusiveXor(T value) {
8181
}
8282

8383
template<typename T>
84-
T subgroupMin(T value) {
85-
return spirv::groupBitwiseMin(spv::ScopeSubgroup, spv::GroupOperationReduce, value);
84+
enable_if_t<!is_matrix_v<T> && is_signed_v<T> && is_integral_v<typename vector_traits<T>::scalar_type>, T> subgroupMin(T value) {
85+
return spirv::groupSMin(spv::ScopeSubgroup, spv::GroupOperationReduce, value);
8686
}
8787
template<typename T>
88-
T subgroupInclusiveMin(T value) {
89-
return spirv::groupBitwiseMin(spv::ScopeSubgroup, spv::GroupOperationInclusiveScan, value);
88+
enable_if_t<!is_matrix_v<T> && !is_signed_v<T> && is_integral_v<typename vector_traits<T>::scalar_type>, T> subgroupMin(T value) {
89+
return spirv::groupUMin(spv::ScopeSubgroup, spv::GroupOperationReduce, value);
9090
}
9191
template<typename T>
92-
T subgroupExclusiveMin(T value) {
93-
return spirv::groupBitwiseMin(spv::ScopeSubgroup, spv::GroupOperationExclusiveScan, value);
92+
enable_if_t<!is_matrix_v<T> && is_floating_point_v<typename vector_traits<T>::scalar_type>, T> subgroupMin(T value) {
93+
return spirv::groupFMin(spv::ScopeSubgroup, spv::GroupOperationReduce, value);
94+
}
95+
template<typename T>
96+
enable_if_t<!is_matrix_v<T> && is_signed_v<T> && is_integral_v<typename vector_traits<T>::scalar_type>, T> subgroupInclusiveMin(T value) {
97+
return spirv::groupSMin(spv::ScopeSubgroup, spv::GroupOperationInclusiveScan, value);
98+
}
99+
template<typename T>
100+
enable_if_t<!is_matrix_v<T> && !is_signed_v<T> && is_integral_v<typename vector_traits<T>::scalar_type>, T> subgroupInclusiveMin(T value) {
101+
return spirv::groupUMin(spv::ScopeSubgroup, spv::GroupOperationInclusiveScan, value);
102+
}
103+
template<typename T>
104+
enable_if_t<!is_matrix_v<T> && is_floating_point_v<typename vector_traits<T>::scalar_type>, T> subgroupInclusiveMin(T value) {
105+
return spirv::groupFMin(spv::ScopeSubgroup, spv::GroupOperationInclusiveScan, value);
106+
}
107+
template<typename T>
108+
enable_if_t<!is_matrix_v<T> && is_signed_v<T> && is_integral_v<typename vector_traits<T>::scalar_type>, T> subgroupExclusiveMin(T value) {
109+
return spirv::groupSMin(spv::ScopeSubgroup, spv::GroupOperationExclusiveScan, value);
110+
}
111+
template<typename T>
112+
enable_if_t<!is_matrix_v<T> && !is_signed_v<T> && is_integral_v<typename vector_traits<T>::scalar_type>, T> subgroupExclusiveMin(T value) {
113+
return spirv::groupUMin(spv::ScopeSubgroup, spv::GroupOperationExclusiveScan, value);
114+
}
115+
template<typename T>
116+
enable_if_t<!is_matrix_v<T> && is_floating_point_v<typename vector_traits<T>::scalar_type>, T> subgroupExclusiveMin(T value) {
117+
return spirv::groupFMin(spv::ScopeSubgroup, spv::GroupOperationExclusiveScan, value);
94118
}
95119

96120
template<typename T>
97-
T subgroupMax(T value) {
98-
return spirv::groupBitwiseMax(spv::ScopeSubgroup, spv::GroupOperationReduce, value);
121+
enable_if_t<!is_matrix_v<T> && is_signed_v<T> && is_integral_v<typename vector_traits<T>::scalar_type>, T> subgroupMax(T value) {
122+
return spirv::groupSMax(spv::ScopeSubgroup, spv::GroupOperationReduce, value);
123+
}
124+
template<typename T>
125+
enable_if_t<!is_matrix_v<T> && !is_signed_v<T> && is_integral_v<typename vector_traits<T>::scalar_type>, T> subgroupMax(T value) {
126+
return spirv::groupUMax(spv::ScopeSubgroup, spv::GroupOperationReduce, value);
127+
}
128+
template<typename T>
129+
enable_if_t<!is_matrix_v<T> && is_floating_point_v<typename vector_traits<T>::scalar_type>, T> subgroupMax(T value) {
130+
return spirv::groupFMax(spv::ScopeSubgroup, spv::GroupOperationReduce, value);
131+
}
132+
template<typename T>
133+
enable_if_t<!is_matrix_v<T> && is_signed_v<T> && is_integral_v<typename vector_traits<T>::scalar_type>, T> subgroupInclusiveMax(T value) {
134+
return spirv::groupSMax(spv::ScopeSubgroup, spv::GroupOperationInclusiveScan, value);
135+
}
136+
template<typename T>
137+
enable_if_t<!is_matrix_v<T> && !is_signed_v<T> && is_integral_v<typename vector_traits<T>::scalar_type>, T> subgroupInclusiveMax(T value) {
138+
return spirv::groupUMax(spv::ScopeSubgroup, spv::GroupOperationInclusiveScan, value);
139+
}
140+
template<typename T>
141+
enable_if_t<!is_matrix_v<T> && is_floating_point_v<typename vector_traits<T>::scalar_type>, T> subgroupInclusiveMax(T value) {
142+
return spirv::groupFMax(spv::ScopeSubgroup, spv::GroupOperationInclusiveScan, value);
143+
}
144+
template<typename T>
145+
enable_if_t<!is_matrix_v<T> && is_signed_v<T> && is_integral_v<typename vector_traits<T>::scalar_type>, T> subgroupExclusiveMax(T value) {
146+
return spirv::groupSMax(spv::ScopeSubgroup, spv::GroupOperationExclusiveScan, value);
99147
}
100148
template<typename T>
101-
T subgroupInclusiveMax(T value) {
102-
return spirv::groupBitwiseMax(spv::ScopeSubgroup, spv::GroupOperationInclusiveScan, value);
149+
enable_if_t<!is_matrix_v<T> && !is_signed_v<T> && is_integral_v<typename vector_traits<T>::scalar_type>, T> subgroupExclusiveMax(T value) {
150+
return spirv::groupUMax(spv::ScopeSubgroup, spv::GroupOperationExclusiveScan, value);
103151
}
104152
template<typename T>
105-
T subgroupExclusiveMax(T value) {
106-
return spirv::groupBitwiseMax(spv::ScopeSubgroup, spv::GroupOperationExclusiveScan, value);
153+
enable_if_t<!is_matrix_v<T> && is_floating_point_v<typename vector_traits<T>::scalar_type>, T> subgroupExclusiveMax(T value) {
154+
return spirv::groupFMax(spv::ScopeSubgroup, spv::GroupOperationExclusiveScan, value);
107155
}
108156

109157
}

include/nbl/builtin/hlsl/subgroup2/arithmetic_portability_impl.hlsl

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -83,10 +83,10 @@ struct exclusive_scan
8383
exclusive_scan_op_t op;
8484
scalar_t exclusive = op(retval[ItemsPerInvocation-1]);
8585

86-
retval[0] = exclusive;
8786
[unroll]
88-
for (uint32_t i = 1; i < ItemsPerInvocation; i++)
87+
for (uint32_t i = ItemsPerInvocation-1; i > 0; i--)
8988
retval[i] = binop(exclusive,retval[i-1]);
89+
retval[0] = exclusive;
9090
return retval;
9191
}
9292
};

0 commit comments

Comments (0)