6
6
7
7
#include "nbl/builtin/hlsl/cpp_compat.hlsl"
8
8
#include "nbl/builtin/hlsl/tuple.hlsl"
9
+ #include "nbl/builtin/hlsl/mpl.hlsl"
9
10
10
11
namespace nbl
11
12
{
@@ -19,23 +20,37 @@ namespace impl
19
20
template<uint16_t _WorkgroupSizeLog2, uint16_t _SubgroupSizeLog2>
20
21
struct virtual_wg_size_log2
21
22
{
22
- NBL_CONSTEXPR_STATIC_INLINE uint16_t WorkgroupSizeLog2 = _WorkgroupSizeLog2;
23
- NBL_CONSTEXPR_STATIC_INLINE uint16_t SubgroupSizeLog2 = _SubgroupSizeLog2;
23
+ #define DEFINE_ASSIGN (TYPE,ID,...) NBL_CONSTEXPR_STATIC_INLINE TYPE ID = __VA_ARGS__;
24
+ #define DEFINE_VIRTUAL_WG_T (ID) ID
25
+ #define DEFINE_MPL_MAX_V (TYPE,ARG1,ARG2) mpl::max_v<TYPE, ARG1, ARG2>
26
+ #define DEFINE_COND_VAL (TYPE,COND,TRUE_VAL,FALSE_VAL) conditional_value<COND,TYPE,TRUE_VAL,FALSE_VAL>::value
27
+ #include "impl/virtual_wg_size_def.hlsl"
28
+ #undef DEFINE_COND_VAL
29
+ #undef DEFINE_MPL_MAX_V
30
+ #undef DEFINE_VIRTUAL_WG_T
31
+ #undef DEFINE_ASSIGN
32
+
33
+ // must have at least enough level 0 outputs to feed a single subgroup
24
34
static_assert (WorkgroupSizeLog2>=SubgroupSizeLog2, "WorkgroupSize cannot be smaller than SubgroupSize" );
25
35
static_assert (WorkgroupSizeLog2<=SubgroupSizeLog2*3 +4 , "WorkgroupSize cannot be larger than (SubgroupSize^3)*16" );
26
-
27
- NBL_CONSTEXPR_STATIC_INLINE uint16_t levels = conditional_value<(WorkgroupSizeLog2>SubgroupSizeLog2),uint16_t,conditional_value<(WorkgroupSizeLog2>SubgroupSizeLog2*2 +2 ),uint16_t,3 ,2 >::value,1 >::value;
28
- NBL_CONSTEXPR_STATIC_INLINE uint16_t value = mpl::max_v<uint32_t, SubgroupSizeLog2*levels, WorkgroupSizeLog2>;
29
- // must have at least enough level 0 outputs to feed a single subgroup
30
36
};
31
37
32
38
template<class VirtualWorkgroup, uint16_t BaseItemsPerInvocation>
33
39
struct items_per_invocation
34
40
{
35
- NBL_CONSTEXPR_STATIC_INLINE uint16_t ItemsPerInvocationProductLog2 = mpl::max_v<int16_t,VirtualWorkgroup::WorkgroupSizeLog2-VirtualWorkgroup::SubgroupSizeLog2*VirtualWorkgroup::levels,0 >;
36
- NBL_CONSTEXPR_STATIC_INLINE uint16_t value0 = BaseItemsPerInvocation;
37
- NBL_CONSTEXPR_STATIC_INLINE uint16_t value1 = uint16_t (0x1u) << conditional_value<VirtualWorkgroup::levels==3 , uint16_t,mpl::min_v<uint16_t,ItemsPerInvocationProductLog2,2 >, ItemsPerInvocationProductLog2>::value;
38
- NBL_CONSTEXPR_STATIC_INLINE uint16_t value2 = uint16_t (0x1u) << mpl::max_v<int16_t,ItemsPerInvocationProductLog2-2 ,0 >;
41
+ #define DEFINE_ASSIGN (TYPE,ID,...) NBL_CONSTEXPR_STATIC_INLINE TYPE ID = __VA_ARGS__;
42
+ #define DEFINE_VIRTUAL_WG_T (ID) VirtualWorkgroup::ID
43
+ #define DEFINE_ITEMS_INVOC_T (ID) ID
44
+ #define DEFINE_MPL_MIN_V (TYPE,ARG1,ARG2) mpl::min_v<TYPE, ARG1, ARG2>
45
+ #define DEFINE_MPL_MAX_V (TYPE,ARG1,ARG2) mpl::max_v<TYPE, ARG1, ARG2>
46
+ #define DEFINE_COND_VAL (TYPE,COND,TRUE_VAL,FALSE_VAL) conditional_value<COND,TYPE,TRUE_VAL,FALSE_VAL>::value
47
+ #include "impl/items_per_invoc_def.hlsl"
48
+ #undef DEFINE_COND_VAL
49
+ #undef DEFINE_MPL_MAX_V
50
+ #undef DEFINE_MPL_MIN_V
51
+ #undef DEFINE_ITEMS_INVOC_T
52
+ #undef DEFINE_VIRTUAL_WG_T
53
+ #undef DEFINE_ASSIGN
39
54
40
55
using ItemsPerInvocation = tuple<integral_constant<uint16_t,value0>,integral_constant<uint16_t,value1>,integral_constant<uint16_t,value2> >;
41
56
};
@@ -44,47 +59,35 @@ struct items_per_invocation
44
59
template<uint16_t _WorkgroupSizeLog2, uint16_t _SubgroupSizeLog2, uint16_t _ItemsPerInvocation>
45
60
struct ArithmeticConfiguration
46
61
{
47
- NBL_CONSTEXPR_STATIC_INLINE uint16_t WorkgroupSizeLog2 = _WorkgroupSizeLog2;
48
- NBL_CONSTEXPR_STATIC_INLINE uint16_t WorkgroupSize = uint16_t (0x1u) << WorkgroupSizeLog2;
49
- NBL_CONSTEXPR_STATIC_INLINE uint16_t SubgroupSizeLog2 = _SubgroupSizeLog2;
50
- NBL_CONSTEXPR_STATIC_INLINE uint16_t SubgroupSize = uint16_t (0x1u) << SubgroupSizeLog2;
51
-
52
- using virtual_wg_t = impl::virtual_wg_size_log2<WorkgroupSizeLog2, SubgroupSizeLog2>;
53
- NBL_CONSTEXPR_STATIC_INLINE uint16_t LevelCount = virtual_wg_t::levels;
54
- NBL_CONSTEXPR_STATIC_INLINE uint16_t VirtualWorkgroupSize = uint16_t (0x1u) << virtual_wg_t::value;
55
- static_assert (VirtualWorkgroupSize<=WorkgroupSize*SubgroupSize);
56
-
62
+ using virtual_wg_t = impl::virtual_wg_size_log2<_WorkgroupSizeLog2, _SubgroupSizeLog2>;
57
63
using items_per_invoc_t = impl::items_per_invocation<virtual_wg_t, _ItemsPerInvocation>;
58
64
using ItemsPerInvocation = typename items_per_invoc_t::ItemsPerInvocation;
59
- NBL_CONSTEXPR_STATIC_INLINE uint16_t ItemsPerInvocation_0 = tuple_element<0 ,ItemsPerInvocation>::type::value;
60
- NBL_CONSTEXPR_STATIC_INLINE uint16_t ItemsPerInvocation_1 = tuple_element<1 ,ItemsPerInvocation>::type::value;
61
- NBL_CONSTEXPR_STATIC_INLINE uint16_t ItemsPerInvocation_2 = tuple_element<2 ,ItemsPerInvocation>::type::value;
62
- static_assert (ItemsPerInvocation_2<=4 , "4 level scan would have been needed with this config!" );
63
65
64
- NBL_CONSTEXPR_STATIC_INLINE uint16_t LevelInputCount_1 = conditional_value<LevelCount==3 ,uint16_t,
65
- mpl::max_v<uint16_t, (VirtualWorkgroupSize>>SubgroupSizeLog2), SubgroupSize>,
66
- SubgroupSize*ItemsPerInvocation_1>::value;
67
- NBL_CONSTEXPR_STATIC_INLINE uint16_t LevelInputCount_2 = conditional_value<LevelCount==3 ,uint16_t,SubgroupSize*ItemsPerInvocation_2,0 >::value;
68
- NBL_CONSTEXPR_STATIC_INLINE uint16_t VirtualInvocationsAtLevel1 = LevelInputCount_1 / ItemsPerInvocation_1;
66
+ #define DEFINE_ASSIGN (TYPE,ID,...) NBL_CONSTEXPR_STATIC_INLINE TYPE ID = __VA_ARGS__;
67
+ #define DEFINE_VIRTUAL_WG_T (ID) virtual_wg_t::ID
68
+ #define DEFINE_ITEMS_INVOC_T (ID) items_per_invoc_t::ID
69
+ #define DEFINE_CONFIG_T (ID) ID
70
+ #define DEFINE_MPL_MAX_V (TYPE,ARG1,ARG2) mpl::max_v<TYPE, ARG1, ARG2>
71
+ #define DEFINE_COND_VAL (TYPE,COND,TRUE_VAL,FALSE_VAL) conditional_value<COND,TYPE,TRUE_VAL,FALSE_VAL>::value
72
+ #include "impl/arithmetic_config_def.hlsl"
73
+ #undef DEFINE_COND_VAL
74
+ #undef DEFINE_MPL_MAX_V
75
+ #undef DEFINE_CONFIG_T
76
+ #undef DEFINE_ITEMS_INVOC_T
77
+ #undef DEFINE_VIRTUAL_WG_T
78
+ #undef DEFINE_ASSIGN
69
79
70
- NBL_CONSTEXPR_STATIC_INLINE uint16_t __padding = conditional_value<LevelCount==3 ,uint16_t,SubgroupSize-1 ,0 >::value;
71
- NBL_CONSTEXPR_STATIC_INLINE uint16_t __channelStride_1 = conditional_value<LevelCount==3 ,uint16_t,VirtualInvocationsAtLevel1,SubgroupSize>::value + __padding;
72
- NBL_CONSTEXPR_STATIC_INLINE uint16_t __channelStride_2 = conditional_value<LevelCount==3 ,uint16_t,SubgroupSize,0 >::value;
73
80
using ChannelStride = tuple<integral_constant<uint16_t,__padding>,integral_constant<uint16_t,__channelStride_1>,integral_constant<uint16_t,__channelStride_2> >; // we don't use stride 0
74
81
75
- // user specified the shared mem size of Scalars
76
- NBL_CONSTEXPR_STATIC_INLINE uint32_t SharedScratchElementCount = conditional_value<LevelCount==1 ,uint16_t,
77
- 0 ,
78
- conditional_value<LevelCount==3 ,uint16_t,
79
- LevelInputCount_2+(SubgroupSize*ItemsPerInvocation_1)-1 ,
80
- 0
81
- >::value + LevelInputCount_1
82
- >::value;
82
+ static_assert (VirtualWorkgroupSize<=WorkgroupSize*SubgroupSize);
83
+ static_assert (ItemsPerInvocation_2<=4 , "4 level scan would have been needed with this config!" );
83
84
85
+ #ifdef __HLSL_VERSION
84
86
static bool electLast ()
85
87
{
86
88
return glsl::gl_SubgroupInvocationID ()==SubgroupSize-1 ;
87
89
}
90
+ #endif
88
91
89
92
// gets a subgroupID as if each workgroup has (VirtualWorkgroupSize/SubgroupSize) subgroups
90
93
// each subgroup does work (VirtualWorkgroupSize/WorkgroupSize) times, the index denoted by workgroupInVirtualIndex
@@ -140,6 +143,88 @@ struct ArithmeticConfiguration
140
143
}
141
144
};
142
145
146
+ #ifndef __HLSL_VERSION
147
+ namespace impl
148
+ {
149
+ struct SVirtualWGSizeLog2
150
+ {
151
+ static SVirtualWGSizeLog2 create (const uint16_t _WorkgroupSizeLog2, const uint16_t _SubgroupSizeLog2)
152
+ {
153
+ SVirtualWGSizeLog2 retval;
154
+ #define DEFINE_ASSIGN (TYPE,ID,...) retval.ID = __VA_ARGS__;
155
+ #define DEFINE_VIRTUAL_WG_T (ID) retval.ID
156
+ #define DEFINE_MPL_MAX_V (TYPE,ARG1,ARG2) hlsl::max <TYPE>(ARG1, ARG2)
157
+ #define DEFINE_COND_VAL (TYPE,COND,TRUE_VAL,FALSE_VAL) (COND ? TRUE_VAL : FALSE_VAL)
158
+ #include "impl/virtual_wg_size_def.hlsl"
159
+ #undef DEFINE_COND_VAL
160
+ #undef DEFINE_MPL_MAX_V
161
+ #undef DEFINE_VIRTUAL_WG_T
162
+ #undef DEFINE_ASSIGN
163
+ return retval;
164
+ }
165
+
166
+ #define DEFINE_ASSIGN (TYPE,ID,...) TYPE ID;
167
+ #include "impl/virtual_wg_size_def.hlsl"
168
+ #undef DEFINE_ASSIGN
169
+ };
170
+
171
+ struct SItemsPerInvoc
172
+ {
173
+ static SItemsPerInvoc create (const SVirtualWGSizeLog2 virtualWgSizeLog2, const uint16_t BaseItemsPerInvocation)
174
+ {
175
+ SItemsPerInvoc retval;
176
+ #define DEFINE_ASSIGN (TYPE,ID,...) retval.ID = __VA_ARGS__;
177
+ #define DEFINE_VIRTUAL_WG_T (ID) virtualWgSizeLog2.ID
178
+ #define DEFINE_ITEMS_INVOC_T (ID) retval.ID
179
+ #define DEFINE_MPL_MIN_V (TYPE,ARG1,ARG2) hlsl::min <TYPE>(ARG1, ARG2)
180
+ #define DEFINE_MPL_MAX_V (TYPE,ARG1,ARG2) hlsl::max <TYPE>(ARG1, ARG2)
181
+ #define DEFINE_COND_VAL (TYPE,COND,TRUE_VAL,FALSE_VAL) (COND ? TRUE_VAL : FALSE_VAL)
182
+ #include "impl/items_per_invoc_def.hlsl"
183
+ #undef DEFINE_COND_VAL
184
+ #undef DEFINE_MPL_MAX_V
185
+ #undef DEFINE_MPL_MIN_V
186
+ #undef DEFINE_ITEMS_INVOC_T
187
+ #undef DEFINE_VIRTUAL_WG_T
188
+ #undef DEFINE_ASSIGN
189
+ return retval;
190
+ }
191
+
192
+ #define DEFINE_ASSIGN (TYPE,ID,...) TYPE ID;
193
+ #include "impl/items_per_invoc_def.hlsl"
194
+ #undef DEFINE_ASSIGN
195
+ };
196
+ }
197
+
198
+ struct SArithmeticConfiguration
199
+ {
200
+ static SArithmeticConfiguration create (const uint16_t _WorkgroupSizeLog2, const uint16_t _SubgroupSizeLog2, const uint16_t _ItemsPerInvocation)
201
+ {
202
+ impl::SVirtualWGSizeLog2 virtualWgSizeLog2 = impl::SVirtualWGSizeLog2::create (_WorkgroupSizeLog2, _SubgroupSizeLog2);
203
+ impl::SItemsPerInvoc itemsPerInvoc = impl::SItemsPerInvoc::create (virtualWgSizeLog2, _ItemsPerInvocation);
204
+
205
+ SArithmeticConfiguration retval;
206
+ #define DEFINE_ASSIGN (TYPE,ID,...) retval.ID = __VA_ARGS__;
207
+ #define DEFINE_VIRTUAL_WG_T (ID) virtualWgSizeLog2.ID
208
+ #define DEFINE_ITEMS_INVOC_T (ID) itemsPerInvoc.ID
209
+ #define DEFINE_CONFIG_T (ID) retval.ID
210
+ #define DEFINE_MPL_MAX_V (TYPE,ARG1,ARG2) hlsl::max <TYPE>(ARG1, ARG2)
211
+ #define DEFINE_COND_VAL (TYPE,COND,TRUE_VAL,FALSE_VAL) (COND ? TRUE_VAL : FALSE_VAL)
212
+ #include "impl/arithmetic_config_def.hlsl"
213
+ #undef DEFINE_COND_VAL
214
+ #undef DEFINE_MPL_MAX_V
215
+ #undef DEFINE_CONFIG_T
216
+ #undef DEFINE_ITEMS_INVOC_T
217
+ #undef DEFINE_VIRTUAL_WG_T
218
+ #undef DEFINE_ASSIGN
219
+ return retval;
220
+ }
221
+
222
+ #define DEFINE_ASSIGN (TYPE,ID,...) TYPE ID;
223
+ #include "impl/arithmetic_config_def.hlsl"
224
+ #undef DEFINE_ASSIGN
225
+ };
226
+ #endif
227
+
143
228
template<class T>
144
229
struct is_configuration : bool_constant<false > {};
145
230
0 commit comments