Skip to content

Commit 2ca92f4

Browse files
committed
reduce exclusive scan ops
1 parent a1b2324 commit 2ca92f4

File tree

1 file changed

+11
-13
lines changed

1 file changed

+11
-13
lines changed

include/nbl/builtin/hlsl/subgroup2/arithmetic_portability_impl.hlsl

Lines changed: 11 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -44,8 +44,6 @@ struct inclusive_scan
4444
// assert binop_t == BinOp
4545
using exclusive_scan_op_t = exclusive_scan<Params, binop_t, 1, native>;
4646

47-
// NBL_CONSTEXPR_STATIC_INLINE uint32_t ItemsPerInvocation = vector_traits<T>::Dimension;
48-
4947
type_t operator()(NBL_CONST_REF_ARG(type_t) value)
5048
{
5149
binop_t binop;
@@ -71,22 +69,24 @@ struct exclusive_scan
7169
using type_t = typename Params::type_t;
7270
using scalar_t = typename Params::scalar_t;
7371
using binop_t = typename Params::binop_t;
74-
using inclusive_scan_op_t = inclusive_scan<Params, binop_t, ItemsPerInvocation, native>;
75-
76-
// NBL_CONSTEXPR_STATIC_INLINE uint32_t ItemsPerInvocation = vector_traits<T>::Dimension;
72+
using exclusive_scan_op_t = exclusive_scan<Params, binop_t, 1, native>;
7773

7874
// Exclusive scan over all ItemsPerInvocation-wide items, treating every
// invocation's `value` as a contiguous run of elements.
//
// value   : this invocation's elements, indexed 0 .. ItemsPerInvocation-1.
// returns : for element i, the combination (via binop_t) of every element that
//           precedes it: the prefix contributed by the other invocations
//           combined with this invocation's own elements 0 .. i-1.
//
// NOTE(review): ItemsPerInvocation is not declared in the visible hunk; it is
// presumably a template parameter of the enclosing struct - confirm.
type_t operator()(type_t value)
{
    binop_t binop;

    // Step 1: inclusive scan restricted to this invocation's own elements,
    // so retval[i] = value[0] (op) ... (op) value[i].
    type_t retval;
    retval[0] = value[0];
    [unroll]
    for (uint32_t i = 1; i < ItemsPerInvocation; i++)
        retval[i] = binop(retval[i-1], value[i]);

    // Step 2: one scalar exclusive scan of the per-invocation totals through the
    // single-item specialization (exclusive_scan_op_t). `exclusive` is then the
    // combined contribution of everything that comes before this invocation
    // (presumably the lower-indexed subgroup invocations - confirm against the
    // specialization).
    exclusive_scan_op_t op;
    scalar_t exclusive = op(retval[ItemsPerInvocation-1]);

    // Step 3: shift the local inclusive results right by one and prepend the prefix.
    // This rewrites retval in place, so the indices MUST be walked downwards:
    // retval[i] needs the *local inclusive* value still stored in retval[i-1].
    // Writing retval[0] first and ascending would feed already-rewritten entries
    // back in (retval[1] would become binop(exclusive, exclusive)).
    [unroll]
    for (uint32_t i = ItemsPerInvocation-1; i > 0; i--)
        retval[i] = binop(exclusive, retval[i-1]);
    retval[0] = exclusive;
    return retval;
}
};
9292
};
@@ -99,8 +99,6 @@ struct reduction
9999
using binop_t = typename Params::binop_t;
100100
using op_t = reduction<Params, binop_t, 1, native>;
101101

102-
// NBL_CONSTEXPR_STATIC_INLINE uint32_t ItemsPerInvocation = vector_traits<T>::Dimension;
103-
104102
scalar_t operator()(NBL_CONST_REF_ARG(type_t) value)
105103
{
106104
binop_t binop;

0 commit comments

Comments
 (0)