Skip to content

Commit 2324f45

Browse files
Merge pull request #876 from Devsh-Graphics-Programming/improve-workgroup-scan-2
Improvements to workgroup reduce + scan
2 parents 3b3d45c + 029cfeb commit 2324f45

File tree

15 files changed

+959
-48
lines changed

15 files changed

+959
-48
lines changed

include/nbl/builtin/hlsl/concepts/accessors/fft.hlsl

Lines changed: 5 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#ifndef _NBL_BUILTIN_HLSL_CONCEPTS_ACCESSORS_FFT_INCLUDED_
22
#define _NBL_BUILTIN_HLSL_CONCEPTS_ACCESSORS_FFT_INCLUDED_
33

4-
#include "nbl/builtin/hlsl/concepts.hlsl"
4+
#include "nbl/builtin/hlsl/concepts/accessors/generic_shared_data.hlsl"
55
#include "nbl/builtin/hlsl/fft/common.hlsl"
66

77
namespace nbl
@@ -17,49 +17,15 @@ namespace fft
1717
// * void set(uint32_t index, in uint32_t value);
1818
// * void workgroupExecutionAndMemoryBarrier();
1919

20-
#define NBL_CONCEPT_NAME FFTSharedMemoryAccessor
21-
#define NBL_CONCEPT_TPLT_PRM_KINDS (typename)
22-
#define NBL_CONCEPT_TPLT_PRM_NAMES (T)
23-
#define NBL_CONCEPT_PARAM_0 (accessor, T)
24-
#define NBL_CONCEPT_PARAM_1 (index, uint32_t)
25-
#define NBL_CONCEPT_PARAM_2 (val, uint32_t)
26-
NBL_CONCEPT_BEGIN(3)
27-
#define accessor NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_0
28-
#define index NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_1
29-
#define val NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_2
30-
NBL_CONCEPT_END(
31-
((NBL_CONCEPT_REQ_EXPR_RET_TYPE)((accessor.template set<uint32_t, uint32_t>(index, val)), is_same_v, void))
32-
((NBL_CONCEPT_REQ_EXPR_RET_TYPE)((accessor.template get<uint32_t, uint32_t>(index, val)), is_same_v, void))
33-
((NBL_CONCEPT_REQ_EXPR_RET_TYPE)((accessor.workgroupExecutionAndMemoryBarrier()), is_same_v, void))
34-
);
35-
#undef val
36-
#undef index
37-
#undef accessor
38-
#include <nbl/builtin/hlsl/concepts/__end.hlsl>
39-
20+
template<typename T, typename V=uint32_t, typename I=uint32_t>
21+
NBL_BOOL_CONCEPT FFTSharedMemoryAccessor = concepts::accessors::GenericSharedMemoryAccessor<T,V,I>;
4022

4123
// The Accessor (for a small FFT) MUST provide the following methods:
4224
// * void get(uint32_t index, NBL_REF_ARG(complex_t<Scalar>) value);
4325
// * void set(uint32_t index, in complex_t<Scalar> value);
4426

45-
#define NBL_CONCEPT_NAME FFTAccessor
46-
#define NBL_CONCEPT_TPLT_PRM_KINDS (typename)(typename)
47-
#define NBL_CONCEPT_TPLT_PRM_NAMES (T)(Scalar)
48-
#define NBL_CONCEPT_PARAM_0 (accessor, T)
49-
#define NBL_CONCEPT_PARAM_1 (index, uint32_t)
50-
#define NBL_CONCEPT_PARAM_2 (val, complex_t<Scalar>)
51-
NBL_CONCEPT_BEGIN(3)
52-
#define accessor NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_0
53-
#define index NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_1
54-
#define val NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_2
55-
NBL_CONCEPT_END(
56-
((NBL_CONCEPT_REQ_EXPR_RET_TYPE)((accessor.template set<complex_t<Scalar> >(index, val)), is_same_v, void))
57-
((NBL_CONCEPT_REQ_EXPR_RET_TYPE)((accessor.template get<complex_t<Scalar> >(index, val)), is_same_v, void))
58-
);
59-
#undef val
60-
#undef index
61-
#undef accessor
62-
#include <nbl/builtin/hlsl/concepts/__end.hlsl>
27+
template<typename T, typename Scalar, typename I=uint32_t>
28+
NBL_BOOL_CONCEPT FFTAccessor = concepts::accessors::GenericDataAccessor<T,complex_t<Scalar>,I>;
6329

6430
}
6531
}
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
#ifndef _NBL_BUILTIN_HLSL_CONCEPTS_ACCESSORS_GENERIC_SHARED_DATA_INCLUDED_
2+
#define _NBL_BUILTIN_HLSL_CONCEPTS_ACCESSORS_GENERIC_SHARED_DATA_INCLUDED_
3+
4+
#include "nbl/builtin/hlsl/concepts.hlsl"
5+
6+
namespace nbl
7+
{
8+
namespace hlsl
9+
{
10+
namespace concepts
11+
{
12+
namespace accessors
13+
{
14+
15+
// Concept `GenericSharedMemoryAccessor<T,V,I>`: satisfied when, for an accessor of type `T`,
// a value of type `V` and an index of type `I`, all of the following expressions have type `void`:
//  * `accessor.template set<V,I>(index, val)`
//  * `accessor.template get<V,I>(index, val)`
//  * `accessor.workgroupExecutionAndMemoryBarrier()`
// The `accessor`/`val`/`index` macros are only placeholders for the requirement expressions and
// are undefined again before the concept machinery is closed by including `__end.hlsl`.
#define NBL_CONCEPT_NAME GenericSharedMemoryAccessor
16+
#define NBL_CONCEPT_TPLT_PRM_KINDS (typename)(typename)(typename)
17+
#define NBL_CONCEPT_TPLT_PRM_NAMES (T)(V)(I)
18+
#define NBL_CONCEPT_PARAM_0 (accessor, T)
19+
#define NBL_CONCEPT_PARAM_1 (val, V)
20+
#define NBL_CONCEPT_PARAM_2 (index, I)
21+
NBL_CONCEPT_BEGIN(3)
22+
#define accessor NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_0
23+
#define val NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_1
24+
#define index NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_2
25+
NBL_CONCEPT_END(
26+
((NBL_CONCEPT_REQ_EXPR_RET_TYPE)((accessor.template set<V,I>(index, val)), is_same_v, void))
27+
((NBL_CONCEPT_REQ_EXPR_RET_TYPE)((accessor.template get<V,I>(index, val)), is_same_v, void))
28+
((NBL_CONCEPT_REQ_EXPR_RET_TYPE)((accessor.workgroupExecutionAndMemoryBarrier()), is_same_v, void))
29+
);
30+
#undef val
31+
#undef index
32+
#undef accessor
33+
#include <nbl/builtin/hlsl/concepts/__end.hlsl>
34+
35+
// Concept `GenericReadAccessor<T,V,I>`: satisfied when `accessor.template get<V,I>(index, val)`
// has type `void` for an accessor of type `T`, a value of type `V` and an index of type `I`.
// No `set` and no barrier method is required.
#define NBL_CONCEPT_NAME GenericReadAccessor
36+
#define NBL_CONCEPT_TPLT_PRM_KINDS (typename)(typename)(typename)
37+
#define NBL_CONCEPT_TPLT_PRM_NAMES (T)(V)(I)
38+
#define NBL_CONCEPT_PARAM_0 (accessor, T)
39+
#define NBL_CONCEPT_PARAM_1 (val, V)
40+
#define NBL_CONCEPT_PARAM_2 (index, I)
41+
NBL_CONCEPT_BEGIN(3)
42+
#define accessor NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_0
43+
#define val NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_1
44+
#define index NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_2
45+
NBL_CONCEPT_END(
46+
((NBL_CONCEPT_REQ_EXPR_RET_TYPE)((accessor.template get<V,I>(index, val)), is_same_v, void))
47+
);
48+
#undef val
49+
#undef index
50+
#undef accessor
51+
#include <nbl/builtin/hlsl/concepts/__end.hlsl>
52+
53+
// Concept `GenericWriteAccessor<T,V,I>`: satisfied when `accessor.template set<V,I>(index, val)`
// has type `void` for an accessor of type `T`, a value of type `V` and an index of type `I`.
// No `get` and no barrier method is required.
#define NBL_CONCEPT_NAME GenericWriteAccessor
54+
#define NBL_CONCEPT_TPLT_PRM_KINDS (typename)(typename)(typename)
55+
#define NBL_CONCEPT_TPLT_PRM_NAMES (T)(V)(I)
56+
#define NBL_CONCEPT_PARAM_0 (accessor, T)
57+
#define NBL_CONCEPT_PARAM_1 (val, V)
58+
#define NBL_CONCEPT_PARAM_2 (index, I)
59+
NBL_CONCEPT_BEGIN(3)
60+
#define accessor NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_0
61+
#define val NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_1
62+
#define index NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_2
63+
NBL_CONCEPT_END(
64+
((NBL_CONCEPT_REQ_EXPR_RET_TYPE)((accessor.template set<V,I>(index, val)), is_same_v, void))
65+
);
66+
#undef val
67+
#undef index
68+
#undef accessor
69+
#include <nbl/builtin/hlsl/concepts/__end.hlsl>
70+
71+
// Concept `GenericDataAccessor<T,V,I>`: an accessor that is BOTH readable and writable, i.e. it
// satisfies `GenericReadAccessor` (`get<V,I>(index,val)` is void) and `GenericWriteAccessor`
// (`set<V,I>(index,val)` is void). `I` defaults to `uint32_t`.
// The read half must be checked explicitly: requiring the write concept twice would let
// write-only accessors through.
template<typename T, typename V, typename I=uint32_t>
72+
NBL_BOOL_CONCEPT GenericDataAccessor = GenericReadAccessor<T,V,I> && GenericWriteAccessor<T,V,I>;
73+
74+
}
75+
}
76+
}
77+
}
78+
79+
#endif
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
#ifndef _NBL_BUILTIN_HLSL_CONCEPTS_ACCESSORS_WORKGROUP_ARITHMETIC_INCLUDED_
2+
#define _NBL_BUILTIN_HLSL_CONCEPTS_ACCESSORS_WORKGROUP_ARITHMETIC_INCLUDED_
3+
4+
#include "nbl/builtin/hlsl/concepts/accessors/generic_shared_data.hlsl"
5+
6+
namespace nbl
7+
{
8+
namespace hlsl
9+
{
10+
namespace workgroup2
11+
{
12+
13+
template<typename T, typename V, typename I=uint32_t>
14+
NBL_BOOL_CONCEPT ArithmeticSharedMemoryAccessor = concepts::accessors::GenericSharedMemoryAccessor<T,V,I>;
15+
16+
template<typename T, typename V, typename I=uint32_t>
17+
NBL_BOOL_CONCEPT ArithmeticReadOnlyDataAccessor = concepts::accessors::GenericReadAccessor<T,V,I>;
18+
19+
template<typename T, typename V, typename I=uint32_t>
20+
NBL_BOOL_CONCEPT ArithmeticDataAccessor = concepts::accessors::GenericDataAccessor<T,V,I>;
21+
22+
}
23+
}
24+
}
25+
26+
#endif

include/nbl/builtin/hlsl/memory_accessor.hlsl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -112,8 +112,8 @@ struct StructureOfArrays : impl::StructureOfArraysBase<IndexType,ElementStride,S
112112
BaseAccessor accessor;
113113

114114
// Question: shall we go back to requiring a `access_t get(index_t)` on the `BaseAccessor`, then we could `enable_if` check the return type (via `has_method_get`) matches and we won't get Nasty HLSL copy-in copy-out conversions
115-
template<typename T>
116-
enable_if_t<sizeof(T)%sizeof(access_t)==0,void> get(const index_t ix, NBL_REF_ARG(T) value)
115+
template<typename T, typename I=index_t>
116+
enable_if_t<sizeof(T)%sizeof(access_t)==0,void> get(const I ix, NBL_REF_ARG(T) value)
117117
{
118118
NBL_CONSTEXPR uint64_t SubElementCount = sizeof(T)/sizeof(access_t);
119119
// `vector` for now, we'll use `array` later when `bit_cast` gets fixed
@@ -123,8 +123,8 @@ struct StructureOfArrays : impl::StructureOfArraysBase<IndexType,ElementStride,S
123123
value = bit_cast<T,vector<access_t,SubElementCount> >(aux);
124124
}
125125

126-
template<typename T>
127-
enable_if_t<sizeof(T)%sizeof(access_t)==0,void> set(const index_t ix, NBL_CONST_REF_ARG(T) value)
126+
template<typename T, typename I=index_t>
127+
enable_if_t<sizeof(T)%sizeof(access_t)==0,void> set(const I ix, NBL_CONST_REF_ARG(T) value)
128128
{
129129
NBL_CONSTEXPR uint64_t SubElementCount = sizeof(T)/sizeof(access_t);
130130
// `vector` for now, we'll use `array` later when `bit_cast` gets fixed
@@ -209,11 +209,11 @@ struct Offset : impl::OffsetBase<IndexType,_Offset>
209209

210210
BaseAccessor accessor;
211211

212-
template <typename T>
213-
void set(index_t idx, T value) {accessor.set(idx+base_t::offset,value); }
212+
template <typename T, typename I=index_t>
213+
void set(I idx, T value) {accessor.set(idx+base_t::offset,value); }
214214

215-
template <typename T>
216-
void get(index_t idx, NBL_REF_ARG(T) value) {accessor.get(idx+base_t::offset,value);}
215+
template <typename T, typename I=index_t>
216+
void get(I idx, NBL_REF_ARG(T) value) {accessor.get(idx+base_t::offset,value);}
217217

218218
template<typename S=BaseAccessor>
219219
enable_if_t<

include/nbl/builtin/hlsl/subgroup2/ballot.hlsl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,29 @@
44
#ifndef _NBL_BUILTIN_HLSL_SUBGROUP2_BALLOT_INCLUDED_
55
#define _NBL_BUILTIN_HLSL_SUBGROUP2_BALLOT_INCLUDED_
66

7+
#include "nbl/builtin/hlsl/glsl_compat/subgroup_ballot.hlsl"
8+
79
namespace nbl
810
{
911
namespace hlsl
1012
{
1113
namespace subgroup2
1214
{
1315

16+
// Returns a subgroup invocation index chosen by the `AssumeAllActive` template flag:
//  * non-zero: `glsl::gl_SubgroupSize()-1`, with no ballot performed;
//  * zero (default): `glsl::subgroupBallotFindMSB(glsl::subgroupBallot(true))`, presumably the
//    highest currently active invocation per GLSL ballot semantics — TODO confirm against glsl_compat.
// NOTE(review): `AssumeAllActive` is declared `int32_t` yet defaulted to `false` and only used as a
// truth value — confirm whether `bool` was intended.
template<int32_t AssumeAllActive=false>
17+
uint32_t LastSubgroupInvocation()
18+
{
19+
// the flag is a template constant, so only one of the two returns is ever taken per instantiation
if (AssumeAllActive)
20+
return glsl::gl_SubgroupSize()-1;
21+
else
22+
return glsl::subgroupBallotFindMSB(glsl::subgroupBallot(true));
23+
}
24+
25+
// Returns true when the calling invocation's `glsl::gl_SubgroupInvocationID()` equals
// `LastSubgroupInvocation()`. The helper is called with its default template argument
// (`AssumeAllActive=false`), so the ballot-based path is the one used here.
bool ElectLast()
26+
{
27+
return glsl::gl_SubgroupInvocationID()==LastSubgroupInvocation();
28+
}
29+
1430
template<uint32_t SubgroupSizeLog2>
1531
struct Configuration
1632
{

include/nbl/builtin/hlsl/tuple.hlsl

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
// Copyright (C) 2025 - DevSH Graphics Programming Sp. z O.O.
2+
// This file is part of the "Nabla Engine".
3+
// For conditions of distribution and use, see copyright notice in nabla.h
4+
#ifndef _NBL_BUILTIN_HLSL_TUPLE_INCLUDED_
5+
#define _NBL_BUILTIN_HLSL_TUPLE_INCLUDED_
6+
7+
#include "nbl/builtin/hlsl/type_traits.hlsl"
8+
9+
namespace nbl
10+
{
11+
namespace hlsl
12+
{
13+
14+
// Primary `tuple` template: a plain aggregate with up to three members `t0`, `t1`, `t2`.
// `T1` and `T2` default to `void`; the partial specializations `tuple<T0,void,void>` and
// `tuple<T0,T1,void>` in this file drop the corresponding members, so this primary definition
// is the one selected when `T2` is not `void`.
template<typename T0, typename T1=void, typename T2=void> // TODO: in the future use BOOST_PP to make this
15+
struct tuple
16+
{
17+
T0 t0;
18+
T1 t1;
19+
T2 t2;
20+
};
21+
22+
template<uint32_t N, typename Tuple>
23+
struct tuple_element;
24+
25+
template<typename T0>
26+
struct tuple<T0,void,void>
27+
{
28+
T0 t0;
29+
};
30+
31+
template<typename T0, typename T1>
32+
struct tuple<T0,T1,void>
33+
{
34+
T0 t0;
35+
T1 t1;
36+
};
37+
// specializations for less and less void elements
38+
39+
// base case
40+
template<typename Head, typename T1, typename T2>
41+
struct tuple_element<0,tuple<Head,T1,T2> >
42+
{
43+
using type = Head;
44+
};
45+
46+
template<typename T0, typename Head, typename T2>
47+
struct tuple_element<1,tuple<T0,Head,T2> >
48+
{
49+
using type = Head;
50+
};
51+
52+
template<typename T0, typename T1, typename Head>
53+
struct tuple_element<2,tuple<T0,T1,Head> >
54+
{
55+
using type = Head;
56+
};
57+
58+
}
59+
}
60+
61+
#endif

include/nbl/builtin/hlsl/vector_utils/vector_traits.hlsl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ struct vector_traits<vector<T, DIMENSION> >\
2828
NBL_CONSTEXPR_STATIC_INLINE bool IsVector = true;\
2929
};\
3030

31+
DEFINE_VECTOR_TRAITS_TEMPLATE_SPECIALIZATION(1)
3132
DEFINE_VECTOR_TRAITS_TEMPLATE_SPECIALIZATION(2)
3233
DEFINE_VECTOR_TRAITS_TEMPLATE_SPECIALIZATION(3)
3334
DEFINE_VECTOR_TRAITS_TEMPLATE_SPECIALIZATION(4)
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
// Copyright (C) 2025 - DevSH Graphics Programming Sp. z O.O.
2+
// This file is part of the "Nabla Engine".
3+
// For conditions of distribution and use, see copyright notice in nabla.h
4+
#ifndef _NBL_BUILTIN_HLSL_WORKGROUP2_ARITHMETIC_INCLUDED_
5+
#define _NBL_BUILTIN_HLSL_WORKGROUP2_ARITHMETIC_INCLUDED_
6+
7+
8+
#include "nbl/builtin/hlsl/functional.hlsl"
9+
#include "nbl/builtin/hlsl/concepts/accessors/workgroup_arithmetic.hlsl"
10+
#include "nbl/builtin/hlsl/workgroup2/shared_scan.hlsl"
11+
12+
13+
namespace nbl
14+
{
15+
namespace hlsl
16+
{
17+
namespace workgroup2
18+
{
19+
20+
// Reduction entry point of the `workgroup2` namespace.
// Template parameters:
//  * `Config` - must satisfy `is_configuration_v<Config>`; its `LevelCount` selects the `impl::reduce` variant.
//  * `BinOp` - binary operation type; its `type_t` is the scalar type of the reduction.
//  * `device_capabilities` - forwarded unchanged to `impl::reduce` (defaults to `void`).
template<class Config, class BinOp, class device_capabilities=void NBL_PRIMARY_REQUIRES(is_configuration_v<Config>)
21+
struct reduction
22+
{
23+
using scalar_t = typename BinOp::type_t;
24+
25+
// Runs the reduction and returns whatever `impl::reduce::__call` returns.
// `dataAccessor` only needs to be readable (`ArithmeticReadOnlyDataAccessor`), while `scratchAccessor`
// must provide get/set plus a workgroup barrier (`ArithmeticSharedMemoryAccessor`), both over `scalar_t`.
template<class ReadOnlyDataAccessor, class ScratchAccessor NBL_FUNC_REQUIRES(ArithmeticReadOnlyDataAccessor<ReadOnlyDataAccessor,scalar_t> && ArithmeticSharedMemoryAccessor<ScratchAccessor,scalar_t>)
26+
static scalar_t __call(NBL_REF_ARG(ReadOnlyDataAccessor) dataAccessor, NBL_REF_ARG(ScratchAccessor) scratchAccessor)
27+
{
28+
impl::reduce<Config,BinOp,Config::LevelCount,device_capabilities> fn;
29+
return fn.template __call<ReadOnlyDataAccessor,ScratchAccessor>(dataAccessor, scratchAccessor);
30+
}
31+
};
32+
33+
// Inclusive scan entry point of the `workgroup2` namespace.
// Template parameters:
//  * `Config` - must satisfy `is_configuration_v<Config>`; its `LevelCount` selects the `impl::scan` variant.
//  * `BinOp` - binary operation type; its `type_t` is the scalar type of the scan.
//  * `device_capabilities` - forwarded unchanged to `impl::scan` (defaults to `void`).
template<class Config, class BinOp, class device_capabilities=void NBL_PRIMARY_REQUIRES(is_configuration_v<Config>)
34+
struct inclusive_scan
35+
{
36+
using scalar_t = typename BinOp::type_t;
37+
38+
// Forwards both accessors to `impl::scan` instantiated with its third template argument set to `false`
// (the `exclusive_scan` struct in this file passes `true` in the same position). Returns nothing.
// `dataAccessor` must satisfy `ArithmeticDataAccessor` and `scratchAccessor` must satisfy
// `ArithmeticSharedMemoryAccessor`, both over `scalar_t`.
template<class DataAccessor, class ScratchAccessor NBL_FUNC_REQUIRES(ArithmeticDataAccessor<DataAccessor,scalar_t> && ArithmeticSharedMemoryAccessor<ScratchAccessor,scalar_t>)
39+
static void __call(NBL_REF_ARG(DataAccessor) dataAccessor, NBL_REF_ARG(ScratchAccessor) scratchAccessor)
40+
{
41+
impl::scan<Config,BinOp,false,Config::LevelCount,device_capabilities> fn;
42+
fn.template __call<DataAccessor,ScratchAccessor>(dataAccessor, scratchAccessor);
43+
}
44+
};
45+
46+
// Exclusive scan entry point of the `workgroup2` namespace.
// Template parameters:
//  * `Config` - must satisfy `is_configuration_v<Config>`; its `LevelCount` selects the `impl::scan` variant.
//  * `BinOp` - binary operation type; its `type_t` is the scalar type of the scan.
//  * `device_capabilities` - forwarded unchanged to `impl::scan` (defaults to `void`).
template<class Config, class BinOp, class device_capabilities=void NBL_PRIMARY_REQUIRES(is_configuration_v<Config>)
47+
struct exclusive_scan
48+
{
49+
using scalar_t = typename BinOp::type_t;
50+
51+
// Forwards both accessors to `impl::scan` instantiated with its third template argument set to `true`
// (the `inclusive_scan` struct in this file passes `false` in the same position). Returns nothing.
// `dataAccessor` must satisfy `ArithmeticDataAccessor` and `scratchAccessor` must satisfy
// `ArithmeticSharedMemoryAccessor`, both over `scalar_t`.
template<class DataAccessor, class ScratchAccessor NBL_FUNC_REQUIRES(ArithmeticDataAccessor<DataAccessor,scalar_t> && ArithmeticSharedMemoryAccessor<ScratchAccessor,scalar_t>)
52+
static void __call(NBL_REF_ARG(DataAccessor) dataAccessor, NBL_REF_ARG(ScratchAccessor) scratchAccessor)
53+
{
54+
impl::scan<Config,BinOp,true,Config::LevelCount,device_capabilities> fn;
55+
fn.template __call<DataAccessor,ScratchAccessor>(dataAccessor, scratchAccessor);
56+
}
57+
};
58+
59+
}
60+
}
61+
}
62+
63+
#endif

0 commit comments

Comments
 (0)