diff --git a/slangpy/tests/device/slang/test_nested_structs.py b/slangpy/tests/device/slang/test_nested_structs.py index c130e2ee..86da2f56 100644 --- a/slangpy/tests/device/slang/test_nested_structs.py +++ b/slangpy/tests/device/slang/test_nested_structs.py @@ -12,10 +12,6 @@ @pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) def test_nested_structs(device_type: spy.DeviceType): - if device_type in [spy.DeviceType.cuda, spy.DeviceType.metal]: - pytest.skip( - "bool is currently not handled correctly on CUDA/Metal, see issue: https://github.com/shader-slang/slangpy/issues/274" - ) device = helpers.get_device(device_type) program = device.load_program("slang/test_nested_structs.slang", ["compute_main"]) diff --git a/slangpy/tests/device/test_buffer_cursor.py b/slangpy/tests/device/test_buffer_cursor.py index 9a010fc5..a419c135 100644 --- a/slangpy/tests/device/test_buffer_cursor.py +++ b/slangpy/tests/device/test_buffer_cursor.py @@ -72,6 +72,12 @@ "float3x4(1.0, 2.0, 3.0, 4.0, -1.0, -2.0, -3.0, -4.0, 5.0, 6.0, 7.0, 8.0)", spy.float3x4([1.0, 2.0, 3.0, 4.0, -1.0, -2.0, -3.0, -4.0, 5.0, 6.0, 7.0, 8.0]), ), + ( + "f_float4x3", + "float4x3", + "float4x3(1.0, 2.0, 3.0, 4.0, -1.0, -2.0, -3.0, -4.0, 5.0, 6.0, 7.0, 8.0)", + spy.float4x3([1.0, 2.0, 3.0, 4.0, -1.0, -2.0, -3.0, -4.0, 5.0, 6.0, 7.0, 8.0]), + ), ( "f_float4x4", "float4x4", @@ -231,13 +237,8 @@ ] -# Filter out all bool tests for CUDA/Metal backend, as it is not handled correct. See issue: -# https://github.com/shader-slang/slangpy/issues/274 def get_tests(device_type: spy.DeviceType): - if device_type not in [spy.DeviceType.cuda, spy.DeviceType.metal]: - return TESTS - tests = [x for x in TESTS if "bool" not in x[0]] - return tests + return TESTS def variable_decls(tests: list[Any]): @@ -531,6 +532,178 @@ def test_apply_changes(device_type: spy.DeviceType, seed: int): check_match(test, element[name].read()) +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +@pytest.mark.parametrize("seed", RAND_SEEDS) +@pytest.mark.parametrize("element_class", [np.array, spy.bool2, tuple, list]) +def test_bool_buffers(device_type: spy.DeviceType, seed: int, element_class: Any): + code = f""" + [shader("compute")] + [numthreads(1, 1, 1)] + void compute_main(uint3 tid: SV_DispatchThreadID, StructuredBuffer src, RWStructuredBuffer dest) {{ + uint i = tid.x; + dest[i] = src[i]; + }} + """ + mod_name = ( + "test_buffer_cursor_TestBoolBuffers_" + hashlib.sha256(code.encode()).hexdigest()[0:8] + ) + device = helpers.get_device(device_type) + module = device.load_module_from_source(mod_name, code) + prog = device.link_program([module], [module.entry_point("compute_main")]) + buffer_layout = module.layout.get_type_layout( + module.layout.find_type_by_name("StructuredBuffer") + ) + (kernel, buffer_layout) = (device.create_compute_kernel(prog), buffer_layout) + + # Make a buffer with 128 elements and a cursor to wrap it + count = 128 + src = kernel.device.create_buffer( + element_count=count, + struct_type=buffer_layout, + usage=spy.BufferUsage.shader_resource | spy.BufferUsage.unordered_access, + data=np.zeros(buffer_layout.element_type_layout.stride * count, dtype=np.uint8), + ) + dest = kernel.device.create_buffer( + element_count=count, + struct_type=buffer_layout, + usage=spy.BufferUsage.shader_resource | spy.BufferUsage.unordered_access, + data=np.zeros(buffer_layout.element_type_layout.stride * count, dtype=np.uint8), + ) + src_cursor = spy.BufferCursor(buffer_layout.element_type_layout, src) + dest_cursor = spy.BufferCursor(buffer_layout.element_type_layout, dest) + + random.seed(seed) + list_data = [[random.randint(0, 1) == 1, random.randint(0, 1) == 1] for i in range(count)] + data = [] + if element_class == np.array: + data = [element_class(x, dtype=np.bool_) for x in list_data] + elif element_class == spy.bool2: + data = [element_class(x) for x in list_data] + elif element_class == tuple: + data = [(x[0], x[1]) for x in list_data] + elif element_class == list: + data = list_data + + for i in range(count): + src_cursor[i].write(data[i]) + + # Apply changes to source + src_cursor.apply() + + # Dispatch the kernel + kernel.dispatch([count, 1, 1], src=src, dest=dest) + + dest_cursor.load() + for i in range(count): + result = dest_cursor[i].read() + data_ref = spy.bool2(list_data[i]) + src_ref = src_cursor[i].read() + assert result == data_ref + assert result == src_ref + + +# test introduced to warn us when issue https://github.com/shader-slang/slang/issues/7441 +# has been resolved and the type information or the underlying types have changed. +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_boolX_reflection(device_type: spy.DeviceType): + code = f""" + [shader("compute")] + [numthreads(1, 1, 1)] + void compute_main(uint3 tid: SV_DispatchThreadID, StructuredBuffer src, RWStructuredBuffer dest) {{ + uint i = tid.x; + dest[i] = src[i]; + }} + """ + mod_name = ( + "test_buffer_cursor_test_boolX_reflection_" + hashlib.sha256(code.encode()).hexdigest()[0:8] + ) + device = helpers.get_device(device_type) + module = device.load_module_from_source(mod_name, code) + prog = device.link_program([module], [module.entry_point("compute_main")]) + sb_bool2_layout = module.layout.get_type_layout( + module.layout.find_type_by_name("StructuredBuffer") + ) + pb_bool2_layout = module.layout.get_type_layout( + module.layout.find_type_by_name("ParameterBlock") + ) + u_bool2_layout = module.layout.get_type_layout(module.layout.find_type_by_name("bool2")) + + sb_bool2_element_layout = sb_bool2_layout.element_type_layout + pb_bool2_element_layout = pb_bool2_layout.element_type_layout + + def make_layout(type_layout: spy.TypeLayoutReflection): + return { + "size": type_layout.size, + "stride": type_layout.size, + "element_stride": type_layout.element_stride(), + "element_type_layout.size": type_layout.element_type_layout.size, + "element_type_layout.stride": type_layout.element_type_layout.stride, + } + + def make_layout_ref(): + if device_type == spy.DeviceType.d3d12: + return { + "size": 8, + "stride": 8, + "element_stride": 4, + "element_type_layout.size": 4, + "element_type_layout.stride": 4, + } + if device_type == spy.DeviceType.vulkan: + return { + "size": 8, + "stride": 8, + "element_stride": 4, + "element_type_layout.size": 4, + "element_type_layout.stride": 4, + } + if device_type == spy.DeviceType.metal: + return { + "size": 2, + "stride": 2, + "element_stride": 1, + "element_type_layout.size": 1, + "element_type_layout.stride": 1, + } + if device_type == spy.DeviceType.wgpu: + return { + "size": 8, + "stride": 8, + "element_stride": 4, + "element_type_layout.size": 4, + "element_type_layout.stride": 4, + } + if device_type == spy.DeviceType.cpu: + return { + "size": 2, + "stride": 2, + "element_stride": 1, + "element_type_layout.size": 1, + "element_type_layout.stride": 1, + } + # This is actually reporting wrong, see issue: https://github.com/shader-slang/slang/issues/7441 + # Once that issue has been resolved, this test should trigger and workarounds can be removed + if device_type == spy.DeviceType.cuda: + return { + "size": 8, + "stride": 8, + "element_stride": 1, + "element_type_layout.size": 1, + "element_type_layout.stride": 1, + } + + layout_descs = { + "u_bool2": make_layout(u_bool2_layout), + "sb_bool2_element": make_layout(sb_bool2_element_layout), + "pb_bool2_element": make_layout(pb_bool2_element_layout), + } + + ref_desc = make_layout_ref() + + for k, v in layout_descs.items(): + assert v == ref_desc + + @pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) @pytest.mark.parametrize("seed", RAND_SEEDS) def test_apply_changes_ndarray(device_type: spy.DeviceType, seed: int): diff --git a/slangpy/tests/device/test_shader_cursor.py b/slangpy/tests/device/test_shader_cursor.py index 19a7f31f..49d13c57 100644 --- a/slangpy/tests/device/test_shader_cursor.py +++ b/slangpy/tests/device/test_shader_cursor.py @@ -35,7 +35,7 @@ class TypeInfo: TYPE_INFOS = { - "bool": TypeInfo(size=4, struct="I", dtype=np.uint32), # np.bool is 8 bits + "bool": TypeInfo(size=4, struct="I", dtype=np.bool_), # np.bool is 8 bits "int": TypeInfo(size=4, struct="i", dtype=np.int32), "uint": TypeInfo(size=4, struct="I", dtype=np.uint32), "float": TypeInfo(size=4, struct="f", dtype=np.float32), @@ -49,10 +49,7 @@ class TypeInfo: def get_type_info(device_type: spy.DeviceType, type: str): - if device_type != spy.DeviceType.cuda or type != "bool": - return TYPE_INFOS[type] - # CUDA bool is size 1 - TypeInfo(size=1, struct="I", dtype=np.uint32) + return TYPE_INFOS[type] @dataclass @@ -321,10 +318,7 @@ def write_var( sizes.append(size) references.append(struct.pack(struct_pattern, *flat_value).hex()) - # CUDA/Metal have bool size of 1, which is currently not handled, see issue: - # https://github.com/shader-slang/slangpy/issues/274 - if device_type not in [spy.DeviceType.cuda, spy.DeviceType.metal] or var.type != "bool": - cursor[name_or_index] = value + cursor[name_or_index] = value def write_vars( device_type: spy.DeviceType, @@ -375,13 +369,7 @@ def write_vars( named_typed_result[0] == "u_float2x2" or named_typed_result[0] == "u_float3x3" ): continue - # CUDA/Metal have bool size of 1, which is currently not handled, see issue: - # https://github.com/shader-slang/slangpy/issues/274 - if ( - device_type in [spy.DeviceType.cuda, spy.DeviceType.metal] - and named_typed_result[1] == "bool" - ): - continue + assert named_typed_result == named_typed_reference diff --git a/slangpy/tests/device/test_shader_cursor.slang b/slangpy/tests/device/test_shader_cursor.slang index c8611397..32e2f54e 100644 --- a/slangpy/tests/device/test_shader_cursor.slang +++ b/slangpy/tests/device/test_shader_cursor.slang @@ -154,7 +154,7 @@ extension bool : IWritable { void write(inout Writer writer) { - writer.buffer[writer.offset++] = asuint(this); + writer.buffer[writer.offset++] = (this) ? 1u : 0u; } } diff --git a/src/sgl/device/buffer_cursor.cpp b/src/sgl/device/buffer_cursor.cpp index 2ed84e7f..6ab3a137 100644 --- a/src/sgl/device/buffer_cursor.cpp +++ b/src/sgl/device/buffer_cursor.cpp @@ -11,6 +11,7 @@ #include "sgl/math/vector_types.h" #include "sgl/math/matrix_types.h" +#include "sgl/device/device.h" namespace sgl { @@ -108,6 +109,146 @@ void BufferElementCursor::set_data(const void* data, size_t size) write_data(m_offset, data, size); } +template +void BufferElementCursor::write_data( + size_t dst_offset, + size_t dst_stride, + const void* src_data, + size_t src_stride, + size_t dimension, + Func&& convert +) +{ + const uint8_t* src_ptr = reinterpret_cast(src_data); + for (size_t i = 0; i < dimension; ++i, dst_offset += dst_stride, src_ptr += src_stride) { + const TDst& dst = convert(*reinterpret_cast(src_ptr)); + write_data(dst_offset, &dst, sizeof(dst)); + } +} + +template +void BufferElementCursor::read_data( + void* dst_data, + size_t dst_stride, + size_t src_offset, + size_t src_stride, + size_t dimension, + Func&& convert +) const +{ + uint8_t* dst_ptr = reinterpret_cast(dst_data); + TSrc src; + for (size_t i = 0; i < dimension; ++i, dst_ptr += dst_stride, src_offset += src_stride) { + read_data(src_offset, &src, sizeof(TSrc)); + *reinterpret_cast(dst_ptr) = convert(src); + } +} + +void BufferElementCursor::set_bool_array(const void* data, size_t src_size, size_t element_count) +{ + cursor_utils::check_array( + m_type_layout->slang_target(), + src_size, + TypeReflection::ScalarType::bool_, + element_count + ); + + size_t src_element_size = src_size / element_count; + size_t src_element_stride = src_element_size; + size_t dst_element_size = m_type_layout->slang_target()->getElementTypeLayout()->getSize(); + size_t dst_element_stride = m_type_layout->element_stride(); + + if (m_type_layout->slang_target()->getSize() == src_size) { + write_data(m_offset, &data, src_size); + return; + } + + if (src_element_size == 1) // cpu bool + { + if (dst_element_size == 4) // d4d12, vulkan + { + write_data( + m_offset, + dst_element_stride, + data, + src_element_stride, + element_count, + [](bool b) -> uint32_t { return b ? 1 : 0; } + ); + return; + } else if (dst_element_size == 1) // cuda, metal + { + write_data( + m_offset, + dst_element_stride, + data, + src_element_stride, + element_count, + [](bool b) -> uint8_t { return b ? 1 : 0; } + ); + return; + } + } + SGL_THROW( + "Invalid configuration of bool array write, source is {}B, device is {}B.", + src_element_size, + dst_element_size + ); +} + +void BufferElementCursor::get_bool_array(void* dst_data, size_t dst_size, size_t element_count) const +{ + cursor_utils::check_array( + m_type_layout->slang_target(), + dst_size, + TypeReflection::ScalarType::bool_, + element_count + ); + + size_t src_element_size = m_type_layout->element_type_layout()->slang_target()->getSize(); + size_t src_element_stride = m_type_layout->element_stride(); + size_t dst_element_size = dst_size / element_count; + size_t dst_element_stride = dst_element_size; + + if (dst_size == m_type_layout->slang_target()->getSize()) { + read_data(m_offset, dst_data, dst_size); + return; + } + + if (dst_element_size == 1) // cpu bool + { + if (src_element_size == 4) // d4d12, vulkan + { + read_data( + dst_data, + dst_element_stride, + m_offset, + src_element_stride, + element_count, + [](uint32_t b) { return b != 0; } + ); + return; + } else if (src_element_size == 1) // cuda, metal + { + read_data( + dst_data, + dst_element_stride, + m_offset, + src_element_stride, + element_count, + [](uint8_t b) { return b != 0; } + ); + return; + } + } + SGL_THROW( + "Invalid configuration of bool array write, source is {}B, device is {}B.", + src_element_size, + dst_element_size + ); +} + + void BufferElementCursor::_set_array( const void* data, size_t size, @@ -115,8 +256,11 @@ void BufferElementCursor::_set_array( size_t element_count ) { - ref element_type = m_type_layout->unwrap_array()->type(); - size_t element_size = cursor_utils::get_scalar_type_size(element_type->scalar_type()); + if (scalar_type == TypeReflection::ScalarType::bool_) + return set_bool_array(data, size, element_count); + + ref element_type_layout = m_type_layout->unwrap_array(); + size_t element_size = element_type_layout->slang_target()->getSize(); cursor_utils::check_array(m_type_layout->slang_target(), size, scalar_type, element_count); @@ -139,8 +283,11 @@ void BufferElementCursor::_get_array( size_t element_count ) const { - ref element_type = m_type_layout->unwrap_array()->type(); - size_t element_size = cursor_utils::get_scalar_type_size(element_type->scalar_type()); + if (scalar_type == TypeReflection::ScalarType::bool_) + return get_bool_array(data, size, element_count); + + ref element_type_layout = m_type_layout->unwrap_array(); + size_t element_size = element_type_layout->slang_target()->getSize(); cursor_utils::check_array(m_type_layout->slang_target(), size, scalar_type, element_count); @@ -194,16 +341,9 @@ void BufferElementCursor::_set_matrix( ) { cursor_utils::check_matrix(m_type_layout->slang_target(), size, scalar_type, rows, cols); - size_t stride = slang_type_layout()->getStride(); - if (stride != size) { - size_t row_stride = stride / rows; - size_t row_size = size / rows; - for (int i = 0; i < rows; ++i) { - write_data(m_offset + i * row_stride, reinterpret_cast(data) + i * row_size, row_size); - } - } else { - write_data(m_offset, data, size); - } + // In Buffer, we should always be tightly packed, the float4x3 padding is for paramblocks and constant buffers. + SGL_ASSERT(size == slang_type_layout()->getStride()); + write_data(m_offset, data, size); } void BufferElementCursor::_get_matrix( @@ -215,16 +355,9 @@ void BufferElementCursor::_get_matrix( ) const { cursor_utils::check_matrix(m_type_layout->slang_target(), size, scalar_type, rows, cols); - size_t stride = slang_type_layout()->getStride(); - if (stride != size) { - size_t row_stride = stride / rows; - size_t row_size = size / rows; - for (int i = 0; i < rows; ++i) { - read_data(m_offset + i * row_stride, reinterpret_cast(data) + i * row_size, row_size); - } - } else { - read_data(m_offset, data, size); - } + // In Buffer, we should always be tightly packed, the float4x3 padding is for paramblocks and constant buffers. + SGL_ASSERT(size == slang_type_layout()->getStride()); + read_data(m_offset, data, size); } @@ -343,95 +476,104 @@ SGL_API void BufferElementCursor::get(bool& value) const value = v != 0; } +template +void BufferElementCursor::set_boolN(const sgl::math::vector& value) +{ + /// Workaround for issue: https://github.com/shader-slang/slang/issues/7441 + if (m_buffer->resource()->device()->type() == DeviceType::cuda) { + sgl::math::vector v; + for (int i = 0; i < N; ++i) + v[i] = value[i] ? 1 : 0; + set_data(&v, sizeof(v)); + return; + } + + if (slang_type_layout()->getElementTypeLayout()->getSize() == 1) { + SGL_ASSERT_GE(slang_type_layout()->getSize(), sizeof(value)); + SGL_ASSERT_EQ(slang_type_layout()->getElementStride(SLANG_PARAMETER_CATEGORY_UNIFORM), 1); + _set_vector(&value, sizeof(value), TypeReflection::ScalarType::bool_, N); + return; + } + sgl::math::vector v; + for (int i = 0; i < N; ++i) + v[i] = value[i] ? 1 : 0; + SGL_ASSERT_GE(slang_type_layout()->getStride(), sizeof(v)); + SGL_ASSERT_EQ(slang_type_layout()->getElementStride(SLANG_PARAMETER_CATEGORY_UNIFORM), sizeof(uint32_t)); + _set_vector(&v, sizeof(v), TypeReflection::ScalarType::bool_, N); +} + +template +void BufferElementCursor::get_boolN(sgl::math::vector& value) const +{ + /// Workaround for issue: https://github.com/shader-slang/slang/issues/7441 + if (m_buffer->resource()->device()->type() == DeviceType::cuda) { + sgl::math::vector v; + read_data(m_offset, &v, sizeof(v)); + for (int i = 0; i < N; ++i) + value[i] = v[i] != 0; + return; + } + + if (slang_type_layout()->getElementTypeLayout()->getSize() == 1) { + SGL_ASSERT_GE(slang_type_layout()->getSize(), sizeof(value)); + SGL_ASSERT_EQ(slang_type_layout()->getElementStride(SLANG_PARAMETER_CATEGORY_UNIFORM), 1); + _get_vector(&value, sizeof(value), TypeReflection::ScalarType::bool_, N); + return; + } + + sgl::math::vector v; + SGL_ASSERT_GE(slang_type_layout()->getStride(), sizeof(v)); + SGL_ASSERT_EQ(slang_type_layout()->getElementStride(SLANG_PARAMETER_CATEGORY_UNIFORM), sizeof(uint32_t)); + _get_vector(&v, sizeof(v), TypeReflection::ScalarType::bool_, N); + for (int i = 0; i < N; ++i) + value[i] = v[i] != 0; +} + template<> SGL_API void BufferElementCursor::set(const bool1& value) { -#if SGL_MACOS - bool1 v = value; -#else - uint1 v(value.x ? 1 : 0); -#endif - _set_vector(&v, sizeof(v), TypeReflection::ScalarType::bool_, 1); + set_boolN(value); } template<> SGL_API void BufferElementCursor::get(bool1& value) const { -#if SGL_MACOS - bool1 v; -#else - uint1 v; -#endif - _get_vector(&v, sizeof(v), TypeReflection::ScalarType::bool_, 1); - value = bool1(v.x != 0); + get_boolN(value); } template<> SGL_API void BufferElementCursor::set(const bool2& value) { -#if SGL_MACOS - bool2 v = value; -#else - uint2 v = {value.x ? 1 : 0, value.y ? 1 : 0}; -#endif - _set_vector(&v, sizeof(v), TypeReflection::ScalarType::bool_, 2); + set_boolN(value); } template<> SGL_API void BufferElementCursor::get(bool2& value) const { -#if SGL_MACOS - bool2 v; -#else - uint2 v; -#endif - _get_vector(&v, sizeof(v), TypeReflection::ScalarType::bool_, 2); - value = {v.x != 0, v.y != 0}; + get_boolN(value); } template<> SGL_API void BufferElementCursor::set(const bool3& value) { -#if SGL_MACOS - bool3 v = value; -#else - uint3 v = {value.x ? 1 : 0, value.y ? 1 : 0, value.z ? 1 : 0}; -#endif - _set_vector(&v, sizeof(v), TypeReflection::ScalarType::bool_, 3); + set_boolN(value); } template<> SGL_API void BufferElementCursor::get(bool3& value) const { -#if SGL_MACOS - bool3 v; -#else - uint3 v; -#endif - _get_vector(&v, sizeof(v), TypeReflection::ScalarType::bool_, 3); - value = {v.x != 0, v.y != 0, v.z != 0}; + get_boolN(value); } template<> SGL_API void BufferElementCursor::set(const bool4& value) { -#if SGL_MACOS - bool4 v = value; -#else - uint4 v = {value.x ? 1 : 0, value.y ? 1 : 0, value.z ? 1 : 0, value.w ? 1 : 0}; -#endif - _set_vector(&v, sizeof(v), TypeReflection::ScalarType::bool_, 4); + set_boolN(value); } template<> SGL_API void BufferElementCursor::get(bool4& value) const { -#if SGL_MACOS - bool4 v; -#else - uint4 v; -#endif - _get_vector(&v, sizeof(v), TypeReflection::ScalarType::bool_, 4); - value = {v.x != 0, v.y != 0, v.z != 0, v.w != 0}; + get_boolN(value); } template<> diff --git a/src/sgl/device/buffer_cursor.h b/src/sgl/device/buffer_cursor.h index 24e4b478..fef1fac5 100644 --- a/src/sgl/device/buffer_cursor.h +++ b/src/sgl/device/buffer_cursor.h @@ -81,6 +81,29 @@ class SGL_API BufferElementCursor { void write_data(size_t offset, const void* data, size_t size); void read_data(size_t offset, void* data, size_t size) const; + void set_bool_array(const void* data, size_t size, size_t element_count); + void get_bool_array(void* data, size_t size, size_t element_count) const; + + template + void set_boolN(const sgl::math::vector& value); + template + void get_boolN(sgl::math::vector& value) const; + + template + void write_data( + size_t dst_offset, + size_t dst_stride, + const void* src_data, + size_t src_stride, + size_t dimension, + Func&& convert + ); + + template + void + read_data(void* dst_data, size_t dst_stride, size_t src_offset, size_t src_stride, size_t dimension, Func&& convert) + const; + ref m_type_layout; ref m_buffer; size_t m_offset{0}; diff --git a/src/sgl/device/cursor_utils.cpp b/src/sgl/device/cursor_utils.cpp index c9eddb21..30f31a8f 100644 --- a/src/sgl/device/cursor_utils.cpp +++ b/src/sgl/device/cursor_utils.cpp @@ -23,12 +23,13 @@ namespace cursor_utils { }; using ST = TypeReflection::ScalarType; + // bool can be converted to 32 or 8b int types, but needs additional size check. add_conversion(ST::int32, ST::uint32, ST::bool_); add_conversion(ST::uint32, ST::int32, ST::bool_); add_conversion(ST::int64, ST::uint64); add_conversion(ST::uint64, ST::int64); - add_conversion(ST::int8, ST::uint8); - add_conversion(ST::uint8, ST::int8); + add_conversion(ST::int8, ST::uint8, ST::bool_); + add_conversion(ST::uint8, ST::int8, ST::bool_); add_conversion(ST::int16, ST::uint16); add_conversion(ST::uint16, ST::int16); } @@ -47,30 +48,6 @@ namespace cursor_utils { return table.allow_conversion(from, to); } - size_t get_scalar_type_size(TypeReflection::ScalarType type) - { - switch (type) { - case TypeReflection::ScalarType::int8: - case TypeReflection::ScalarType::uint8: - return 1; - case TypeReflection::ScalarType::int16: - case TypeReflection::ScalarType::uint16: - case TypeReflection::ScalarType::float16: - return 2; - case TypeReflection::ScalarType::bool_: - case TypeReflection::ScalarType::int32: - case TypeReflection::ScalarType::uint32: - case TypeReflection::ScalarType::float32: - return 4; - case TypeReflection::ScalarType::int64: - case TypeReflection::ScalarType::uint64: - case TypeReflection::ScalarType::float64: - return 8; - default: - return 0; - } - } - slang::TypeLayoutReflection* unwrap_array(slang::TypeLayoutReflection* layout) { while (layout->isArray()) { @@ -80,136 +57,171 @@ namespace cursor_utils { } void check_array( - slang::TypeLayoutReflection* type_layout, - size_t size, - TypeReflection::ScalarType scalar_type, - size_t element_count + slang::TypeLayoutReflection* dst_type_layout, + size_t src_size, + TypeReflection::ScalarType src_scalar_type, + size_t src_element_count ) { - slang::TypeReflection* type = type_layout->getType(); - slang::TypeReflection* element_type = unwrap_array(type_layout)->getType(); - size_t element_size = get_scalar_type_size((TypeReflection::ScalarType)element_type->getScalarType()); + slang::TypeLayoutReflection* element_type_layout = unwrap_array(dst_type_layout); + size_t dst_element_size = element_type_layout->getSize(); + size_t src_element_size = src_size / src_element_count; - SGL_CHECK(type->isArray(), "\"{}\" cannot bind an array", type_layout->getName()); + SGL_CHECK(dst_type_layout->isArray(), "\"{}\" cannot bind an array", dst_type_layout->getName()); SGL_CHECK( - allow_scalar_conversion(scalar_type, (TypeReflection::ScalarType)element_type->getScalarType()), - "\"{}\" expects scalar type {} (no implicit conversion from type {})", - type_layout->getName(), - (TypeReflection::ScalarType)element_type->getScalarType(), - scalar_type + dst_element_size == src_element_size + && allow_scalar_conversion( + src_scalar_type, + (TypeReflection::ScalarType)element_type_layout->getScalarType() + ), + "\"{}\" expects scalar type {} ({}B) (no implicit conversion from type {} ({}B))", + dst_type_layout->getName(), + (TypeReflection::ScalarType)element_type_layout->getScalarType(), + dst_element_size, + src_scalar_type, + src_element_size ); SGL_CHECK( - element_count <= type->getElementCount(), + src_element_count == dst_type_layout->getElementCount(), "\"{}\" expects an array with at most {} elements (got {})", - type_layout->getName(), - type->getElementCount(), - element_count + dst_type_layout->getName(), + dst_type_layout->getElementCount(), + src_element_count ); - SGL_ASSERT(element_count * element_size == size); + SGL_ASSERT(src_element_count * dst_element_size == src_size); } - void check_scalar(slang::TypeLayoutReflection* type_layout, size_t size, TypeReflection::ScalarType scalar_type) + void check_scalar( + slang::TypeLayoutReflection* dst_type_layout, + size_t src_size, + TypeReflection::ScalarType src_scalar_type + ) { - slang::TypeReflection* type = unwrap_array(type_layout)->getType(); + size_t dst_size = dst_type_layout->getSize(); SGL_CHECK( - (TypeReflection::Kind)type->getKind() == TypeReflection::Kind::scalar, + (TypeReflection::Kind)dst_type_layout->getKind() == TypeReflection::Kind::scalar, "\"{}\" cannot bind a scalar value", - type_layout->getName() + dst_type_layout->getName() ); SGL_CHECK( - allow_scalar_conversion(scalar_type, (TypeReflection::ScalarType)type->getScalarType()), - "\"{}\" expects scalar type {} (no implicit conversion from type {})", - type_layout->getName(), - (TypeReflection::ScalarType)type->getScalarType(), - scalar_type + dst_size == src_size + && allow_scalar_conversion( + src_scalar_type, + (TypeReflection::ScalarType)dst_type_layout->getScalarType() + ), + "\"{}\" expects scalar type {} ({}B) (no implicit conversion from type {} ({}B))", + dst_type_layout->getName(), + (TypeReflection::ScalarType)dst_type_layout->getScalarType(), + dst_size, + src_scalar_type, + src_size ); SGL_CHECK( - type_layout->getSize() >= size, + src_size <= dst_type_layout->getSize(), "Mismatched size, writing {} B into backend type ({}) of only {} B.", - size, - type_layout->getName(), - type_layout->getSize() + src_size, + dst_type_layout->getName(), + dst_type_layout->getSize() ); } void check_vector( - slang::TypeLayoutReflection* type_layout, - size_t size, - TypeReflection::ScalarType scalar_type, - int dimension + slang::TypeLayoutReflection* dst_type_layout, + size_t src_size, + TypeReflection::ScalarType src_scalar_type, + int src_dimension ) { - slang::TypeReflection* type = unwrap_array(type_layout)->getType(); + slang::TypeLayoutReflection* element_type_layout = dst_type_layout->getElementTypeLayout(); + size_t dst_element_size = element_type_layout->getSize(); + size_t src_element_size = src_size / src_dimension; SGL_CHECK( - (TypeReflection::Kind)type->getKind() == TypeReflection::Kind::vector, + (TypeReflection::Kind)dst_type_layout->getKind() == TypeReflection::Kind::vector, "\"{}\" cannot bind a vector value", - type_layout->getName() + dst_type_layout->getName() ); SGL_CHECK( - type->getColumnCount() == uint32_t(dimension), + dst_type_layout->getColumnCount() == uint32_t(src_dimension), "\"{}\" expects a vector with dimension {} (got dimension {})", - type_layout->getName(), - type->getColumnCount(), - dimension + dst_type_layout->getName(), + dst_type_layout->getColumnCount(), + src_dimension ); SGL_CHECK( - allow_scalar_conversion(scalar_type, (TypeReflection::ScalarType)type->getScalarType()), - "\"{}\" expects a vector with scalar type {} (no implicit conversion from type {})", - type_layout->getName(), - (TypeReflection::ScalarType)type->getScalarType(), - scalar_type + dst_element_size == src_element_size + && allow_scalar_conversion( + src_scalar_type, + (TypeReflection::ScalarType)dst_type_layout->getScalarType() + ), + "\"{}\" expects a vector with scalar type {} ({}B) (no implicit conversion from type {} ({}B))", + dst_type_layout->getName(), + (TypeReflection::ScalarType)element_type_layout->getScalarType(), + dst_element_size, + src_scalar_type, + src_element_size ); SGL_CHECK( - type_layout->getSize() >= size, + src_size <= dst_type_layout->getSize(), "Mismatched size, writing {} B into backend type ({}) of only {} B.", - size, - type_layout->getName(), - type_layout->getSize() + src_size, + dst_type_layout->getName(), + dst_type_layout->getSize() ); } void check_matrix( - slang::TypeLayoutReflection* type_layout, - size_t size, - TypeReflection::ScalarType scalar_type, - int rows, - int cols + slang::TypeLayoutReflection* dst_type_layout, + size_t src_size, + TypeReflection::ScalarType src_scalar_type, + int src_rows, + int src_cols ) { - slang::TypeReflection* type = unwrap_array(type_layout)->getType(); + // Element of `matrix` is a vector, so the `scalar` is element applied twice. + slang::TypeLayoutReflection* element_type_layout + = dst_type_layout->getElementTypeLayout()->getElementTypeLayout(); + size_t dst_element_size = element_type_layout->getSize(); + size_t src_element_size = src_size / (src_rows * src_cols); SGL_CHECK( - (TypeReflection::Kind)type->getKind() == TypeReflection::Kind::matrix, + (TypeReflection::Kind)dst_type_layout->getKind() == TypeReflection::Kind::matrix, "\"{}\" cannot bind a matrix value", - type_layout->getName() + dst_type_layout->getName() ); - bool dimensionCondition = type->getRowCount() == uint32_t(rows) && type->getColumnCount() == uint32_t(cols); + bool dimensionCondition = dst_type_layout->getRowCount() == uint32_t(src_rows) + && dst_type_layout->getColumnCount() == uint32_t(src_cols); SGL_CHECK( dimensionCondition, "\"{}\" expects a matrix with dimension {}x{} (got dimension {}x{})", - type_layout->getName(), - type->getRowCount(), - type->getColumnCount(), - rows, - cols + dst_type_layout->getName(), + element_type_layout->getRowCount(), + element_type_layout->getColumnCount(), + src_rows, + src_cols ); SGL_CHECK( - allow_scalar_conversion(scalar_type, (TypeReflection::ScalarType)type->getScalarType()), - "\"{}\" expects a matrix with scalar type {} (no implicit conversion from type {})", - type_layout->getName(), - (TypeReflection::ScalarType)type->getScalarType(), - scalar_type + dst_element_size == src_element_size + && allow_scalar_conversion( + src_scalar_type, + (TypeReflection::ScalarType)element_type_layout->getScalarType() + ), + "\"{}\" expects a matrix with scalar type {} ({}B) (no implicit conversion from type {} ({}B))", + dst_type_layout->getName(), + (TypeReflection::ScalarType)element_type_layout->getScalarType(), + dst_element_size, + src_scalar_type, + src_element_size ); SGL_CHECK( - type_layout->getSize() >= size, + src_size <= dst_type_layout->getSize(), "Mismatched size, writing {} B into backend type ({}) of only {} B.", - size, - type_layout->getName(), - type_layout->getSize() + src_size, + dst_type_layout->getName(), + dst_type_layout->getSize() ); } diff --git a/src/sgl/device/cursor_utils.h b/src/sgl/device/cursor_utils.h index 425b6063..b545b7d8 100644 --- a/src/sgl/device/cursor_utils.h +++ b/src/sgl/device/cursor_utils.h @@ -8,8 +8,6 @@ namespace sgl { namespace cursor_utils { - size_t get_scalar_type_size(TypeReflection::ScalarType type); - slang::TypeLayoutReflection* unwrap_array(slang::TypeLayoutReflection* layout); void check_array( diff --git a/src/sgl/device/shader_cursor.cpp b/src/sgl/device/shader_cursor.cpp index 558cc4b2..f17eafd6 100644 --- a/src/sgl/device/shader_cursor.cpp +++ b/src/sgl/device/shader_cursor.cpp @@ -479,6 +479,69 @@ void ShaderCursor::set_cuda_tensor_view(const cuda::TensorView& tensor_view) con } } +template +void ShaderCursor::write_data( + ShaderOffset dst_offset, + uint32_t dst_stride, + const void* src_data, + size_t src_stride, + size_t dimension, + Func&& convert +) const +{ + const uint8_t* src_ptr = reinterpret_cast(src_data); + for (size_t i = 0; i < dimension; ++i, dst_offset.uniform_offset += dst_stride, src_ptr += src_stride) { + const TDst& dst = convert(*reinterpret_cast(src_ptr)); + m_shader_object->set_data(dst_offset, &dst, sizeof(TDst)); + } +} + +void ShaderCursor::set_bool_array(const void* data, size_t src_size, size_t element_count) const +{ + size_t src_element_size = src_size / element_count; + size_t src_element_stride = src_element_size; + size_t dst_element_size = m_type_layout->getElementTypeLayout()->getSize(); + uint32_t dst_element_stride + = narrow_cast(m_type_layout->getElementStride(SLANG_PARAMETER_CATEGORY_UNIFORM)); + + if (m_type_layout->getSize() == src_size) { + m_shader_object->set_data(m_offset, data, src_size); + return; + } + + if (src_element_size == 1) // cpu bool + { + if (dst_element_size == 4) // d4d12, vulkan + { + write_data( + m_offset, + dst_element_stride, + data, + src_element_stride, + element_count, + [](bool b) -> uint32_t { return b ? 1 : 0; } + ); + return; + } else if (dst_element_size == 1) // cuda, metal + { + write_data( + m_offset, + dst_element_stride, + data, + src_element_stride, + element_count, + [](bool b) -> uint8_t { return b ? 1 : 0; } + ); + return; + } + } + SGL_THROW( + "Invalid configuration of bool array write, source is {}B, device is {}B.", + src_element_size, + dst_element_size + ); +} + void ShaderCursor::_set_array( const void* data, size_t size, @@ -486,8 +549,11 @@ void ShaderCursor::_set_array( size_t element_count ) const { - slang::TypeReflection* element_type = cursor_utils::unwrap_array(m_type_layout)->getType(); - size_t element_size = cursor_utils::get_scalar_type_size((TypeReflection::ScalarType)element_type->getScalarType()); + if (scalar_type == TypeReflection::ScalarType::bool_) + return set_bool_array(data, size, element_count); + + slang::TypeLayoutReflection* element_type_layout = cursor_utils::unwrap_array(m_type_layout); + size_t element_size = element_type_layout->getSize(); #ifdef SGL_ENABLE_CURSOR_TYPE_CHECKS cursor_utils::check_array(m_type_layout, size, scalar_type, element_count); @@ -510,8 +576,11 @@ void ShaderCursor::_set_array( void ShaderCursor::_set_array_unsafe(const void* data, size_t size, size_t element_count) const { - slang::TypeReflection* element_type = cursor_utils::unwrap_array(m_type_layout)->getType(); - size_t element_size = cursor_utils::get_scalar_type_size((TypeReflection::ScalarType)element_type->getScalarType()); + slang::TypeLayoutReflection* element_type_layout = cursor_utils::unwrap_array(m_type_layout); + size_t element_size = element_type_layout->getSize(); + + // Check that we are not writing too much memory. + SGL_ASSERT(size <= element_size * element_count); size_t stride = m_type_layout->getElementStride(SLANG_PARAMETER_CATEGORY_UNIFORM); if (element_size == stride) { @@ -565,12 +634,15 @@ void ShaderCursor::_set_matrix( if (rows > 1) { size_t mat_stride = m_type_layout->getStride(); size_t row_stride = mat_stride / rows; - - size_t row_size = size / rows; - ShaderOffset offset = m_offset; - for (int row = 0; row < rows; ++row) { - m_shader_object->set_data(offset, reinterpret_cast(data) + row * row_size, row_size); - offset.uniform_offset += narrow_cast(row_stride); + if (mat_stride == size) { + m_shader_object->set_data(m_offset, data, size); + } else { + size_t row_size = size / rows; + ShaderOffset offset = m_offset; + for (int row = 0; row < rows; ++row) { + m_shader_object->set_data(offset, reinterpret_cast(data) + row * row_size, row_size); + offset.uniform_offset += narrow_cast(row_stride); + } } } else { m_shader_object->set_data(m_offset, data, size); @@ -714,66 +786,64 @@ SET_SCALAR(double, float64); template<> SGL_API void ShaderCursor::set(const bool& value) const { -#if SGL_MACOS - if (m_shader_object->device()->type() == DeviceType::metal) { + if (slang_type_layout()->getSize() == 1) { + SGL_ASSERT_GE(slang_type_layout()->getSize(), sizeof(value)); _set_scalar(&value, sizeof(value), TypeReflection::ScalarType::bool_); return; } -#endif uint v = value ? 1 : 0; + SGL_ASSERT_GE(slang_type_layout()->getSize(), sizeof(v)); _set_scalar(&v, sizeof(v), TypeReflection::ScalarType::bool_); } -template<> -SGL_API void ShaderCursor::set(const bool1& value) const +template +void ShaderCursor::set_boolN(const sgl::math::vector& value) const { -#if SGL_MACOS - if (m_shader_object->device()->type() == DeviceType::metal) { - _set_vector(&value, sizeof(value), TypeReflection::ScalarType::bool_, 1); + /// Workaround for issue: https://github.com/shader-slang/slang/issues/7441 + if (m_shader_object->device()->type() == DeviceType::cuda) { + sgl::math::vector v; + for (int i = 0; i < N; ++i) + v[i] = value[i] ? 1 : 0; + set_data(&v, sizeof(v)); return; } -#endif - uint1 v(value.x ? 1 : 0); - _set_vector(&v, sizeof(v), TypeReflection::ScalarType::bool_, 1); + + if (slang_type_layout()->getElementTypeLayout()->getSize() == 1) { + SGL_ASSERT_GE(slang_type_layout()->getSize(), sizeof(value)); + SGL_ASSERT_EQ(slang_type_layout()->getElementStride(SLANG_PARAMETER_CATEGORY_UNIFORM), 1); + _set_vector(&value, sizeof(value), TypeReflection::ScalarType::bool_, N); + return; + } + sgl::math::vector v; + for (int i = 0; i < N; ++i) + v[i] = value[i] ? 1 : 0; + SGL_ASSERT_GE(slang_type_layout()->getStride(), sizeof(v)); + SGL_ASSERT_EQ(slang_type_layout()->getElementStride(SLANG_PARAMETER_CATEGORY_UNIFORM), sizeof(uint32_t)); + _set_vector(&v, sizeof(v), TypeReflection::ScalarType::bool_, N); +} + +template<> +SGL_API void ShaderCursor::set(const bool1& value) const +{ + set_boolN(value); } template<> SGL_API void ShaderCursor::set(const bool2& value) const { -#if SGL_MACOS - if (m_shader_object->device()->type() == DeviceType::metal) { - _set_vector(&value, sizeof(value), TypeReflection::ScalarType::bool_, 2); - return; - } -#endif - uint2 v = {value.x ? 1 : 0, value.y ? 1 : 0}; - _set_vector(&v, sizeof(v), TypeReflection::ScalarType::bool_, 2); + set_boolN(value); } template<> SGL_API void ShaderCursor::set(const bool3& value) const { -#if SGL_MACOS - if (m_shader_object->device()->type() == DeviceType::metal) { - _set_vector(&value, sizeof(value), TypeReflection::ScalarType::bool_, 3); - return; - } -#endif - uint3 v = {value.x ? 1 : 0, value.y ? 1 : 0, value.z ? 1 : 0}; - _set_vector(&v, sizeof(v), TypeReflection::ScalarType::bool_, 3); + set_boolN(value); } template<> SGL_API void ShaderCursor::set(const bool4& value) const { -#if SGL_MACOS - if (m_shader_object->device()->type() == DeviceType::metal) { - _set_vector(&value, sizeof(value), TypeReflection::ScalarType::bool_, 4); - return; - } -#endif - uint4 v = {value.x ? 1 : 0, value.y ? 1 : 0, value.z ? 1 : 0, value.w ? 1 : 0}; - _set_vector(&v, sizeof(v), TypeReflection::ScalarType::bool_, 4); + set_boolN(value); } } // namespace sgl diff --git a/src/sgl/device/shader_cursor.h b/src/sgl/device/shader_cursor.h index 4bac5ffb..e17078a8 100644 --- a/src/sgl/device/shader_cursor.h +++ b/src/sgl/device/shader_cursor.h @@ -89,6 +89,21 @@ class SGL_API ShaderCursor { void _set_matrix(const void* data, size_t size, TypeReflection::ScalarType scalar_type, int rows, int cols) const; private: + void set_bool_array(const void* data, size_t size, size_t element_count) const; + + template + void set_boolN(const sgl::math::vector& value) const; + + template + void write_data( + ShaderOffset dst_offset, + uint32_t dst_stride, + const void* src_data, + size_t src_stride, + size_t dimension, + Func&& convert + ) const; + slang::TypeLayoutReflection* m_type_layout; ShaderObject* m_shader_object{nullptr}; ShaderOffset m_offset; diff --git a/src/slangpy_ext/device/cursor_utils.h b/src/slangpy_ext/device/cursor_utils.h index e8ee6b17..65bf69be 100644 --- a/src/slangpy_ext/device/cursor_utils.h +++ b/src/slangpy_ext/device/cursor_utils.h @@ -55,6 +55,8 @@ inline std::optional dtype_to_scalar_type(nb::dlpack case 64: return TypeReflection::ScalarType::float64; } + case uint8_t(nb::dlpack::dtype_code::Bool): + return TypeReflection::ScalarType::bool_; break; } return {}; @@ -518,6 +520,13 @@ class WriteConverterTable { } case TypeReflection::Kind::constant_buffer: case TypeReflection::Kind::parameter_block: + if constexpr (requires { self.dereference(); }) { + // Unwrap constant buffers or parameter blocks for shader cursors + auto child = self.dereference(); + write_internal(child, nbval); + return; + } else + SGL_THROW("constant_buffer and param_block not expected in BufferElementCursor"); case TypeReflection::Kind::struct_: { // Unwrap constant buffers or parameter blocks if (kind != TypeReflection::Kind::struct_) @@ -561,7 +570,7 @@ class WriteConverterTable { self._set_array( nbarray.data(), nbarray.nbytes(), - (TypeReflection::ScalarType)type_layout->getElementTypeLayout()->getType()->getScalarType(), + *dtype_to_scalar_type(nbarray.dtype()), narrow_cast(nbarray.shape(0)) ); return; @@ -622,15 +631,6 @@ class WriteConverterTable { self.set(val); } - /// Version of vector write specifically for bool vectors (which are stored as uint32_t) - template - requires IsSpecializationOfVector - inline static void _write_bool_vector_from_numpy(CursorType& self, nb::ndarray nbarray) - { - SGL_CHECK(nbarray.nbytes() == ValType::dimension * 4, "numpy array has wrong size."); - self._set_vector(nbarray.data(), nbarray.nbytes(), TypeReflection::ScalarType::bool_, ValType::dimension); - } - /// Write vector value to buffer element cursor from Python object template requires IsSpecializationOfVector @@ -682,7 +682,7 @@ class WriteConverterTable { for (size_t i = 0; i < nbarray.ndim(); ++i) dimension *= nbarray.shape(i); SGL_CHECK(dimension == ValType::dimension, "numpy array has wrong dimension."); - _write_bool_vector_from_numpy(self, nbarray); + _write_vector_from_numpy(self, nbarray); } else if (nb::isinstance(nbval)) { // A list or tuple. Attempt to cast each element of list to element of vector. auto seq = nb::cast(nbval); diff --git a/src/slangpy_ext/device/reflection.cpp b/src/slangpy_ext/device/reflection.cpp index 59da0256..8e7c5e19 100644 --- a/src/slangpy_ext/device/reflection.cpp +++ b/src/slangpy_ext/device/reflection.cpp @@ -131,6 +131,12 @@ SGL_PY_EXPORT(device_reflection) D(TypeLayoutReflection, element_type_layout) ) .def("unwrap_array", &TypeLayoutReflection::unwrap_array, D(TypeLayoutReflection, unwrap_array)) + .def( + "element_stride", + &TypeLayoutReflection::element_stride, + "category"_a = TypeReflection::ParameterCategory::uniform, + D(TypeLayoutReflection, element_stride) + ) .def("__repr__", &TypeLayoutReflection::to_string); bind_list_type(m, "TypeLayoutReflectionFieldList", D(TypeLayoutReflectionFieldList));