Skip to content

Commit 437c194

Browse files
committed
use x-macros for config compat between hlsl and cpp
1 parent 10b7f50 commit 437c194

File tree

5 files changed

+176
-41
lines changed

5 files changed

+176
-41
lines changed

include/nbl/builtin/hlsl/workgroup2/arithmetic_config.hlsl

Lines changed: 125 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
#include "nbl/builtin/hlsl/cpp_compat.hlsl"
88
#include "nbl/builtin/hlsl/tuple.hlsl"
9+
#include "nbl/builtin/hlsl/mpl.hlsl"
910

1011
namespace nbl
1112
{
@@ -19,23 +20,37 @@ namespace impl
1920
template<uint16_t _WorkgroupSizeLog2, uint16_t _SubgroupSizeLog2>
2021
struct virtual_wg_size_log2
2122
{
22-
NBL_CONSTEXPR_STATIC_INLINE uint16_t WorkgroupSizeLog2 = _WorkgroupSizeLog2;
23-
NBL_CONSTEXPR_STATIC_INLINE uint16_t SubgroupSizeLog2 = _SubgroupSizeLog2;
23+
#define DEFINE_ASSIGN(TYPE,ID,...) NBL_CONSTEXPR_STATIC_INLINE TYPE ID = __VA_ARGS__;
24+
#define DEFINE_VIRTUAL_WG_T(ID) ID
25+
#define DEFINE_MPL_MAX_V(TYPE,ARG1,ARG2) mpl::max_v<TYPE, ARG1, ARG2>
26+
#define DEFINE_COND_VAL(TYPE,COND,TRUE_VAL,FALSE_VAL) conditional_value<COND,TYPE,TRUE_VAL,FALSE_VAL>::value
27+
#include "impl/virtual_wg_size_def.hlsl"
28+
#undef DEFINE_COND_VAL
29+
#undef DEFINE_MPL_MAX_V
30+
#undef DEFINE_VIRTUAL_WG_T
31+
#undef DEFINE_ASSIGN
32+
33+
// must have at least enough level 0 outputs to feed a single subgroup
2434
static_assert(WorkgroupSizeLog2>=SubgroupSizeLog2, "WorkgroupSize cannot be smaller than SubgroupSize");
2535
static_assert(WorkgroupSizeLog2<=SubgroupSizeLog2*3+4, "WorkgroupSize cannot be larger than (SubgroupSize^3)*16");
26-
27-
NBL_CONSTEXPR_STATIC_INLINE uint16_t levels = conditional_value<(WorkgroupSizeLog2>SubgroupSizeLog2),uint16_t,conditional_value<(WorkgroupSizeLog2>SubgroupSizeLog2*2+2),uint16_t,3,2>::value,1>::value;
28-
NBL_CONSTEXPR_STATIC_INLINE uint16_t value = mpl::max_v<uint32_t, SubgroupSizeLog2*levels, WorkgroupSizeLog2>;
29-
// must have at least enough level 0 outputs to feed a single subgroup
3036
};
3137

3238
template<class VirtualWorkgroup, uint16_t BaseItemsPerInvocation>
3339
struct items_per_invocation
3440
{
35-
NBL_CONSTEXPR_STATIC_INLINE uint16_t ItemsPerInvocationProductLog2 = mpl::max_v<int16_t,VirtualWorkgroup::WorkgroupSizeLog2-VirtualWorkgroup::SubgroupSizeLog2*VirtualWorkgroup::levels,0>;
36-
NBL_CONSTEXPR_STATIC_INLINE uint16_t value0 = BaseItemsPerInvocation;
37-
NBL_CONSTEXPR_STATIC_INLINE uint16_t value1 = uint16_t(0x1u) << conditional_value<VirtualWorkgroup::levels==3, uint16_t,mpl::min_v<uint16_t,ItemsPerInvocationProductLog2,2>, ItemsPerInvocationProductLog2>::value;
38-
NBL_CONSTEXPR_STATIC_INLINE uint16_t value2 = uint16_t(0x1u) << mpl::max_v<int16_t,ItemsPerInvocationProductLog2-2,0>;
41+
#define DEFINE_ASSIGN(TYPE,ID,...) NBL_CONSTEXPR_STATIC_INLINE TYPE ID = __VA_ARGS__;
42+
#define DEFINE_VIRTUAL_WG_T(ID) VirtualWorkgroup::ID
43+
#define DEFINE_ITEMS_INVOC_T(ID) ID
44+
#define DEFINE_MPL_MIN_V(TYPE,ARG1,ARG2) mpl::min_v<TYPE, ARG1, ARG2>
45+
#define DEFINE_MPL_MAX_V(TYPE,ARG1,ARG2) mpl::max_v<TYPE, ARG1, ARG2>
46+
#define DEFINE_COND_VAL(TYPE,COND,TRUE_VAL,FALSE_VAL) conditional_value<COND,TYPE,TRUE_VAL,FALSE_VAL>::value
47+
#include "impl/items_per_invoc_def.hlsl"
48+
#undef DEFINE_COND_VAL
49+
#undef DEFINE_MPL_MAX_V
50+
#undef DEFINE_MPL_MIN_V
51+
#undef DEFINE_ITEMS_INVOC_T
52+
#undef DEFINE_VIRTUAL_WG_T
53+
#undef DEFINE_ASSIGN
3954

4055
using ItemsPerInvocation = tuple<integral_constant<uint16_t,value0>,integral_constant<uint16_t,value1>,integral_constant<uint16_t,value2> >;
4156
};
@@ -44,47 +59,35 @@ struct items_per_invocation
4459
template<uint16_t _WorkgroupSizeLog2, uint16_t _SubgroupSizeLog2, uint16_t _ItemsPerInvocation>
4560
struct ArithmeticConfiguration
4661
{
47-
NBL_CONSTEXPR_STATIC_INLINE uint16_t WorkgroupSizeLog2 = _WorkgroupSizeLog2;
48-
NBL_CONSTEXPR_STATIC_INLINE uint16_t WorkgroupSize = uint16_t(0x1u) << WorkgroupSizeLog2;
49-
NBL_CONSTEXPR_STATIC_INLINE uint16_t SubgroupSizeLog2 = _SubgroupSizeLog2;
50-
NBL_CONSTEXPR_STATIC_INLINE uint16_t SubgroupSize = uint16_t(0x1u) << SubgroupSizeLog2;
51-
52-
using virtual_wg_t = impl::virtual_wg_size_log2<WorkgroupSizeLog2, SubgroupSizeLog2>;
53-
NBL_CONSTEXPR_STATIC_INLINE uint16_t LevelCount = virtual_wg_t::levels;
54-
NBL_CONSTEXPR_STATIC_INLINE uint16_t VirtualWorkgroupSize = uint16_t(0x1u) << virtual_wg_t::value;
55-
static_assert(VirtualWorkgroupSize<=WorkgroupSize*SubgroupSize);
56-
62+
using virtual_wg_t = impl::virtual_wg_size_log2<_WorkgroupSizeLog2, _SubgroupSizeLog2>;
5763
using items_per_invoc_t = impl::items_per_invocation<virtual_wg_t, _ItemsPerInvocation>;
5864
using ItemsPerInvocation = typename items_per_invoc_t::ItemsPerInvocation;
59-
NBL_CONSTEXPR_STATIC_INLINE uint16_t ItemsPerInvocation_0 = tuple_element<0,ItemsPerInvocation>::type::value;
60-
NBL_CONSTEXPR_STATIC_INLINE uint16_t ItemsPerInvocation_1 = tuple_element<1,ItemsPerInvocation>::type::value;
61-
NBL_CONSTEXPR_STATIC_INLINE uint16_t ItemsPerInvocation_2 = tuple_element<2,ItemsPerInvocation>::type::value;
62-
static_assert(ItemsPerInvocation_2<=4, "4 level scan would have been needed with this config!");
6365

64-
NBL_CONSTEXPR_STATIC_INLINE uint16_t LevelInputCount_1 = conditional_value<LevelCount==3,uint16_t,
65-
mpl::max_v<uint16_t, (VirtualWorkgroupSize>>SubgroupSizeLog2), SubgroupSize>,
66-
SubgroupSize*ItemsPerInvocation_1>::value;
67-
NBL_CONSTEXPR_STATIC_INLINE uint16_t LevelInputCount_2 = conditional_value<LevelCount==3,uint16_t,SubgroupSize*ItemsPerInvocation_2,0>::value;
68-
NBL_CONSTEXPR_STATIC_INLINE uint16_t VirtualInvocationsAtLevel1 = LevelInputCount_1 / ItemsPerInvocation_1;
66+
#define DEFINE_ASSIGN(TYPE,ID,...) NBL_CONSTEXPR_STATIC_INLINE TYPE ID = __VA_ARGS__;
67+
#define DEFINE_VIRTUAL_WG_T(ID) virtual_wg_t::ID
68+
#define DEFINE_ITEMS_INVOC_T(ID) items_per_invoc_t::ID
69+
#define DEFINE_CONFIG_T(ID) ID
70+
#define DEFINE_MPL_MAX_V(TYPE,ARG1,ARG2) mpl::max_v<TYPE, ARG1, ARG2>
71+
#define DEFINE_COND_VAL(TYPE,COND,TRUE_VAL,FALSE_VAL) conditional_value<COND,TYPE,TRUE_VAL,FALSE_VAL>::value
72+
#include "impl/arithmetic_config_def.hlsl"
73+
#undef DEFINE_COND_VAL
74+
#undef DEFINE_MPL_MAX_V
75+
#undef DEFINE_CONFIG_T
76+
#undef DEFINE_ITEMS_INVOC_T
77+
#undef DEFINE_VIRTUAL_WG_T
78+
#undef DEFINE_ASSIGN
6979

70-
NBL_CONSTEXPR_STATIC_INLINE uint16_t __padding = conditional_value<LevelCount==3,uint16_t,SubgroupSize-1,0>::value;
71-
NBL_CONSTEXPR_STATIC_INLINE uint16_t __channelStride_1 = conditional_value<LevelCount==3,uint16_t,VirtualInvocationsAtLevel1,SubgroupSize>::value + __padding;
72-
NBL_CONSTEXPR_STATIC_INLINE uint16_t __channelStride_2 = conditional_value<LevelCount==3,uint16_t,SubgroupSize,0>::value;
7380
using ChannelStride = tuple<integral_constant<uint16_t,__padding>,integral_constant<uint16_t,__channelStride_1>,integral_constant<uint16_t,__channelStride_2> >; // we don't use stride 0
7481

75-
// user specified the shared mem size of Scalars
76-
NBL_CONSTEXPR_STATIC_INLINE uint32_t SharedScratchElementCount = conditional_value<LevelCount==1,uint16_t,
77-
0,
78-
conditional_value<LevelCount==3,uint16_t,
79-
LevelInputCount_2+(SubgroupSize*ItemsPerInvocation_1)-1,
80-
0
81-
>::value + LevelInputCount_1
82-
>::value;
82+
static_assert(VirtualWorkgroupSize<=WorkgroupSize*SubgroupSize);
83+
static_assert(ItemsPerInvocation_2<=4, "4 level scan would have been needed with this config!");
8384

85+
#ifdef __HLSL_VERSION
8486
static bool electLast()
8587
{
8688
return glsl::gl_SubgroupInvocationID()==SubgroupSize-1;
8789
}
90+
#endif
8891

8992
// gets a subgroupID as if each workgroup has (VirtualWorkgroupSize/SubgroupSize) subgroups
9093
// each subgroup does work (VirtualWorkgroupSize/WorkgroupSize) times, the index denoted by workgroupInVirtualIndex
@@ -140,6 +143,88 @@ struct ArithmeticConfiguration
140143
}
141144
};
142145

146+
#ifndef __HLSL_VERSION
147+
namespace impl
148+
{
149+
struct SVirtualWGSizeLog2
150+
{
151+
static SVirtualWGSizeLog2 create(const uint16_t _WorkgroupSizeLog2, const uint16_t _SubgroupSizeLog2)
152+
{
153+
SVirtualWGSizeLog2 retval;
154+
#define DEFINE_ASSIGN(TYPE,ID,...) retval.ID = __VA_ARGS__;
155+
#define DEFINE_VIRTUAL_WG_T(ID) retval.ID
156+
#define DEFINE_MPL_MAX_V(TYPE,ARG1,ARG2) hlsl::max<TYPE>(ARG1, ARG2)
157+
#define DEFINE_COND_VAL(TYPE,COND,TRUE_VAL,FALSE_VAL) (COND ? TRUE_VAL : FALSE_VAL)
158+
#include "impl/virtual_wg_size_def.hlsl"
159+
#undef DEFINE_COND_VAL
160+
#undef DEFINE_MPL_MAX_V
161+
#undef DEFINE_VIRTUAL_WG_T
162+
#undef DEFINE_ASSIGN
163+
return retval;
164+
}
165+
166+
#define DEFINE_ASSIGN(TYPE,ID,...) TYPE ID;
167+
#include "impl/virtual_wg_size_def.hlsl"
168+
#undef DEFINE_ASSIGN
169+
};
170+
171+
struct SItemsPerInvoc
172+
{
173+
static SItemsPerInvoc create(const SVirtualWGSizeLog2 virtualWgSizeLog2, const uint16_t BaseItemsPerInvocation)
174+
{
175+
SItemsPerInvoc retval;
176+
#define DEFINE_ASSIGN(TYPE,ID,...) retval.ID = __VA_ARGS__;
177+
#define DEFINE_VIRTUAL_WG_T(ID) virtualWgSizeLog2.ID
178+
#define DEFINE_ITEMS_INVOC_T(ID) retval.ID
179+
#define DEFINE_MPL_MIN_V(TYPE,ARG1,ARG2) hlsl::min<TYPE>(ARG1, ARG2)
180+
#define DEFINE_MPL_MAX_V(TYPE,ARG1,ARG2) hlsl::max<TYPE>(ARG1, ARG2)
181+
#define DEFINE_COND_VAL(TYPE,COND,TRUE_VAL,FALSE_VAL) (COND ? TRUE_VAL : FALSE_VAL)
182+
#include "impl/items_per_invoc_def.hlsl"
183+
#undef DEFINE_COND_VAL
184+
#undef DEFINE_MPL_MAX_V
185+
#undef DEFINE_MPL_MIN_V
186+
#undef DEFINE_ITEMS_INVOC_T
187+
#undef DEFINE_VIRTUAL_WG_T
188+
#undef DEFINE_ASSIGN
189+
return retval;
190+
}
191+
192+
#define DEFINE_ASSIGN(TYPE,ID,...) TYPE ID;
193+
#include "impl/items_per_invoc_def.hlsl"
194+
#undef DEFINE_ASSIGN
195+
};
196+
}
197+
198+
struct SArithmeticConfiguration
199+
{
200+
static SArithmeticConfiguration create(const uint16_t _WorkgroupSizeLog2, const uint16_t _SubgroupSizeLog2, const uint16_t _ItemsPerInvocation)
201+
{
202+
impl::SVirtualWGSizeLog2 virtualWgSizeLog2 = impl::SVirtualWGSizeLog2::create(_WorkgroupSizeLog2, _SubgroupSizeLog2);
203+
impl::SItemsPerInvoc itemsPerInvoc = impl::SItemsPerInvoc::create(virtualWgSizeLog2, _ItemsPerInvocation);
204+
205+
SArithmeticConfiguration retval;
206+
#define DEFINE_ASSIGN(TYPE,ID,...) retval.ID = __VA_ARGS__;
207+
#define DEFINE_VIRTUAL_WG_T(ID) virtualWgSizeLog2.ID
208+
#define DEFINE_ITEMS_INVOC_T(ID) itemsPerInvoc.ID
209+
#define DEFINE_CONFIG_T(ID) retval.ID
210+
#define DEFINE_MPL_MAX_V(TYPE,ARG1,ARG2) hlsl::max<TYPE>(ARG1, ARG2)
211+
#define DEFINE_COND_VAL(TYPE,COND,TRUE_VAL,FALSE_VAL) (COND ? TRUE_VAL : FALSE_VAL)
212+
#include "impl/arithmetic_config_def.hlsl"
213+
#undef DEFINE_COND_VAL
214+
#undef DEFINE_MPL_MAX_V
215+
#undef DEFINE_CONFIG_T
216+
#undef DEFINE_ITEMS_INVOC_T
217+
#undef DEFINE_VIRTUAL_WG_T
218+
#undef DEFINE_ASSIGN
219+
return retval;
220+
}
221+
222+
#define DEFINE_ASSIGN(TYPE,ID,...) TYPE ID;
223+
#include "impl/arithmetic_config_def.hlsl"
224+
#undef DEFINE_ASSIGN
225+
};
226+
#endif
227+
143228
template<class T>
144229
struct is_configuration : bool_constant<false> {};
145230

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
// Copyright (C) 2025 - DevSH Graphics Programming Sp. z O.O.
2+
// This file is part of the "Nabla Engine".
3+
// For conditions of distribution and use, see copyright notice in nabla.h
4+
5+
DEFINE_ASSIGN(uint16_t, WorkgroupSizeLog2, _WorkgroupSizeLog2)
6+
DEFINE_ASSIGN(uint16_t, WorkgroupSize, uint16_t(0x1u) << DEFINE_CONFIG_T(WorkgroupSizeLog2))
7+
DEFINE_ASSIGN(uint16_t, SubgroupSizeLog2, _SubgroupSizeLog2)
8+
DEFINE_ASSIGN(uint16_t, SubgroupSize, uint16_t(0x1u) << DEFINE_CONFIG_T(SubgroupSizeLog2))
9+
10+
DEFINE_ASSIGN(uint16_t, LevelCount, DEFINE_VIRTUAL_WG_T(levels))
11+
DEFINE_ASSIGN(uint16_t, VirtualWorkgroupSize, uint16_t(0x1u) << DEFINE_VIRTUAL_WG_T(value))
12+
13+
DEFINE_ASSIGN(uint16_t, ItemsPerInvocation_0, DEFINE_ITEMS_INVOC_T(value0))
14+
DEFINE_ASSIGN(uint16_t, ItemsPerInvocation_1, DEFINE_ITEMS_INVOC_T(value1))
15+
DEFINE_ASSIGN(uint16_t, ItemsPerInvocation_2, DEFINE_ITEMS_INVOC_T(value2))
16+
17+
DEFINE_ASSIGN(uint16_t, LevelInputCount_1, DEFINE_COND_VAL(uint16_t,(DEFINE_CONFIG_T(LevelCount)==3),
18+
DEFINE_MPL_MAX_V(uint16_t, (DEFINE_CONFIG_T(VirtualWorkgroupSize)>>DEFINE_CONFIG_T(SubgroupSizeLog2)), DEFINE_CONFIG_T(SubgroupSize)),
19+
DEFINE_CONFIG_T(SubgroupSize)*DEFINE_CONFIG_T(ItemsPerInvocation_1)))
20+
DEFINE_ASSIGN(uint16_t, LevelInputCount_2, DEFINE_COND_VAL(uint16_t,(DEFINE_CONFIG_T(LevelCount)==3),DEFINE_CONFIG_T(SubgroupSize)*DEFINE_CONFIG_T(ItemsPerInvocation_2),0))
21+
DEFINE_ASSIGN(uint16_t, VirtualInvocationsAtLevel1, DEFINE_CONFIG_T(LevelInputCount_1) / DEFINE_CONFIG_T(ItemsPerInvocation_1))
22+
23+
DEFINE_ASSIGN(uint16_t, __padding, DEFINE_COND_VAL(uint16_t,(DEFINE_CONFIG_T(LevelCount)==3),DEFINE_CONFIG_T(SubgroupSize)-1,0))
24+
DEFINE_ASSIGN(uint16_t, __channelStride_1, DEFINE_COND_VAL(uint16_t,(DEFINE_CONFIG_T(LevelCount)==3),DEFINE_CONFIG_T(VirtualInvocationsAtLevel1),DEFINE_CONFIG_T(SubgroupSize)) + DEFINE_CONFIG_T(__padding))
25+
DEFINE_ASSIGN(uint16_t, __channelStride_2, DEFINE_COND_VAL(uint16_t,(DEFINE_CONFIG_T(LevelCount)==3),DEFINE_CONFIG_T(SubgroupSize),0))
26+
27+
// user specified the shared mem size of Scalars
28+
DEFINE_ASSIGN(uint32_t, SharedScratchElementCount, DEFINE_COND_VAL(uint16_t,(DEFINE_CONFIG_T(LevelCount)==1),
29+
0,
30+
DEFINE_COND_VAL(uint16_t,(DEFINE_CONFIG_T(LevelCount)==3),
31+
DEFINE_CONFIG_T(LevelInputCount_2)+(DEFINE_CONFIG_T(SubgroupSize)*DEFINE_CONFIG_T(ItemsPerInvocation_1))-1,
32+
0
33+
) + DEFINE_CONFIG_T(LevelInputCount_1)
34+
))
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
// Copyright (C) 2025 - DevSH Graphics Programming Sp. z O.O.
2+
// This file is part of the "Nabla Engine".
3+
// For conditions of distribution and use, see copyright notice in nabla.h
4+
5+
DEFINE_ASSIGN(uint16_t, ItemsPerInvocationProductLog2, DEFINE_MPL_MAX_V(int16_t,DEFINE_VIRTUAL_WG_T(WorkgroupSizeLog2)-DEFINE_VIRTUAL_WG_T(SubgroupSizeLog2)*DEFINE_VIRTUAL_WG_T(levels),0))
6+
DEFINE_ASSIGN(uint16_t, value0, BaseItemsPerInvocation)
7+
DEFINE_ASSIGN(uint16_t, value1, uint16_t(0x1u) << DEFINE_COND_VAL(uint16_t,(DEFINE_VIRTUAL_WG_T(levels)==3),DEFINE_MPL_MIN_V(uint16_t,DEFINE_ITEMS_INVOC_T(ItemsPerInvocationProductLog2),2),DEFINE_ITEMS_INVOC_T(ItemsPerInvocationProductLog2)))
8+
DEFINE_ASSIGN(uint16_t, value2, uint16_t(0x1u) << DEFINE_MPL_MAX_V(int16_t,DEFINE_ITEMS_INVOC_T(ItemsPerInvocationProductLog2)-2,0))
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
// Copyright (C) 2025 - DevSH Graphics Programming Sp. z O.O.
2+
// This file is part of the "Nabla Engine".
3+
// For conditions of distribution and use, see copyright notice in nabla.h
4+
5+
DEFINE_ASSIGN(uint16_t, WorkgroupSizeLog2, _WorkgroupSizeLog2)
6+
DEFINE_ASSIGN(uint16_t, SubgroupSizeLog2, _SubgroupSizeLog2)
7+
DEFINE_ASSIGN(uint16_t, levels, DEFINE_COND_VAL(uint16_t,(_WorkgroupSizeLog2>_SubgroupSizeLog2),DEFINE_COND_VAL(uint16_t,(_WorkgroupSizeLog2>_SubgroupSizeLog2*2+2),3,2),1))
8+
DEFINE_ASSIGN(uint16_t, value, DEFINE_MPL_MAX_V(uint16_t, _SubgroupSizeLog2*DEFINE_VIRTUAL_WG_T(levels), _WorkgroupSizeLog2))

0 commit comments

Comments
 (0)