diff --git a/slangpy/tests/device/test_buffer_cursor.py b/slangpy/tests/device/test_buffer_cursor.py index 9a010fc5..64cfe802 100644 --- a/slangpy/tests/device/test_buffer_cursor.py +++ b/slangpy/tests/device/test_buffer_cursor.py @@ -346,7 +346,7 @@ def test_cursor_read_write(device_type: spy.DeviceType, seed: int): (kernel, buffer_layout) = make_fill_in_module(device_type, tests) # Create a buffer cursor with its own data - cursor = spy.BufferCursor(buffer_layout.element_type_layout, 1) + cursor = spy.BufferCursor(device_type, buffer_layout.element_type_layout, 1) # Populate the first element element = cursor[0] @@ -355,7 +355,7 @@ def test_cursor_read_write(device_type: spy.DeviceType, seed: int): element[name] = value # Create new cursor by copying the first, and read element - cursor2 = spy.BufferCursor(buffer_layout.element_type_layout, 1) + cursor2 = spy.BufferCursor(device_type, buffer_layout.element_type_layout, 1) cursor2.copy_from_numpy(cursor.to_numpy()) element2 = cursor2[0] @@ -389,7 +389,7 @@ def test_fill_from_kernel(device_type: spy.DeviceType, seed: int): kernel.dispatch([count, 1, 1], buffer=buffer) # Create a cursor and read the buffer by copying its data - cursor = spy.BufferCursor(buffer_layout.element_type_layout, count) + cursor = spy.BufferCursor(device_type, buffer_layout.element_type_layout, count) cursor.copy_from_numpy(buffer.to_numpy()) # Verify data matches @@ -453,7 +453,7 @@ def test_cursor_lifetime(device_type: spy.DeviceType): (kernel, buffer_layout) = make_fill_in_module(device_type, get_tests(device_type)) # Create a buffer cursor with its own data - cursor = spy.BufferCursor(buffer_layout.element_type_layout, 1) + cursor = spy.BufferCursor(device_type, buffer_layout.element_type_layout, 1) # Get element element = cursor[0] diff --git a/src/sgl/device/buffer_cursor.cpp b/src/sgl/device/buffer_cursor.cpp index 2ed84e7f..9c5ea1dd 100644 --- a/src/sgl/device/buffer_cursor.cpp +++ b/src/sgl/device/buffer_cursor.cpp @@ -108,125 +108,19 @@ void BufferElementCursor::set_data(const void* data, size_t size) write_data(m_offset, data, size); } -void BufferElementCursor::_set_array( - const void* data, - size_t size, - TypeReflection::ScalarType scalar_type, - size_t element_count -) +DeviceType BufferElementCursor::_get_device_type() const { - ref element_type = m_type_layout->unwrap_array()->type(); - size_t element_size = cursor_utils::get_scalar_type_size(element_type->scalar_type()); - - cursor_utils::check_array(m_type_layout->slang_target(), size, scalar_type, element_count); - - size_t stride = m_type_layout->element_stride(); - if (element_size == stride) { - write_data(m_offset, data, size); - } else { - size_t offset = m_offset; - for (size_t i = 0; i < element_count; ++i) { - write_data(offset, reinterpret_cast(data) + i * element_size, element_size); - offset += stride; - } - } + return m_buffer->get_device_type(); } -void BufferElementCursor::_get_array( - void* data, - size_t size, - TypeReflection::ScalarType scalar_type, - size_t element_count -) const -{ - ref element_type = m_type_layout->unwrap_array()->type(); - size_t element_size = cursor_utils::get_scalar_type_size(element_type->scalar_type()); - - cursor_utils::check_array(m_type_layout->slang_target(), size, scalar_type, element_count); - - size_t stride = m_type_layout->element_stride(); - if (element_size == stride) { - read_data(m_offset, data, size); - } else { - size_t offset = m_offset; - for (size_t i = 0; i < element_count; ++i) { - read_data(offset, reinterpret_cast(data) + i * element_size, element_size); - offset += stride; - } - } -} -void BufferElementCursor::_set_scalar(const void* data, size_t size, TypeReflection::ScalarType scalar_type) -{ - cursor_utils::check_scalar(m_type_layout->slang_target(), size, scalar_type); - write_data(m_offset, data, size); -} - -void BufferElementCursor::_get_scalar(void* data, size_t size, TypeReflection::ScalarType scalar_type) const -{ - cursor_utils::check_scalar(m_type_layout->slang_target(), size, scalar_type); - read_data(m_offset, data, size); -} - -void BufferElementCursor::_set_vector( - const void* data, - size_t size, - TypeReflection::ScalarType scalar_type, - int dimension -) -{ - cursor_utils::check_vector(m_type_layout->slang_target(), size, scalar_type, dimension); - write_data(m_offset, data, size); -} - -void BufferElementCursor::_get_vector(void* data, size_t size, TypeReflection::ScalarType scalar_type, int dimension) - const -{ - cursor_utils::check_vector(m_type_layout->slang_target(), size, scalar_type, dimension); - read_data(m_offset, data, size); -} - -void BufferElementCursor::_set_matrix( - const void* data, - size_t size, - TypeReflection::ScalarType scalar_type, - int rows, - int cols -) -{ - cursor_utils::check_matrix(m_type_layout->slang_target(), size, scalar_type, rows, cols); - size_t stride = slang_type_layout()->getStride(); - if (stride != size) { - size_t row_stride = stride / rows; - size_t row_size = size / rows; - for (int i = 0; i < rows; ++i) { - write_data(m_offset + i * row_stride, reinterpret_cast(data) + i * row_size, row_size); - } - } else { - write_data(m_offset, data, size); - } -} - -void BufferElementCursor::_get_matrix( - void* data, - size_t size, - TypeReflection::ScalarType scalar_type, - int rows, - int cols -) const -{ - cursor_utils::check_matrix(m_type_layout->slang_target(), size, scalar_type, rows, cols); - size_t stride = slang_type_layout()->getStride(); - if (stride != size) { - size_t row_stride = stride / rows; - size_t row_size = size / rows; - for (int i = 0; i < rows; ++i) { - read_data(m_offset + i * row_stride, reinterpret_cast(data) + i * row_size, row_size); - } - } else { - read_data(m_offset, data, size); - } -} +// Explicit instantiation of the methods +template void +CursorWriteWrappers::_set_array(const void*, size_t, TypeReflection::ScalarType, size_t) + const; +template void +CursorWriteWrappers::_set_vector(const void*, size_t, TypeReflection::ScalarType, int) + const; // // Setter specializations @@ -332,106 +226,59 @@ GETSET_SCALAR(double, float64); template<> SGL_API void BufferElementCursor::set(const bool& value) { - uint v = value ? 1 : 0; - _set_scalar(&v, sizeof(v), TypeReflection::ScalarType::bool_); + _set_bool(value); } template<> SGL_API void BufferElementCursor::get(bool& value) const { - uint v; - _get_scalar(&v, sizeof(v), TypeReflection::ScalarType::bool_); - value = v != 0; + _get_bool(value); } template<> SGL_API void BufferElementCursor::set(const bool1& value) { -#if SGL_MACOS - bool1 v = value; -#else - uint1 v(value.x ? 1 : 0); -#endif - _set_vector(&v, sizeof(v), TypeReflection::ScalarType::bool_, 1); + _set_boolN(value); } template<> SGL_API void BufferElementCursor::get(bool1& value) const { -#if SGL_MACOS - bool1 v; -#else - uint1 v; -#endif - _get_vector(&v, sizeof(v), TypeReflection::ScalarType::bool_, 1); - value = bool1(v.x != 0); + _get_boolN(value); } template<> SGL_API void BufferElementCursor::set(const bool2& value) { -#if SGL_MACOS - bool2 v = value; -#else - uint2 v = {value.x ? 1 : 0, value.y ? 1 : 0}; -#endif - _set_vector(&v, sizeof(v), TypeReflection::ScalarType::bool_, 2); + _set_boolN(value); } template<> SGL_API void BufferElementCursor::get(bool2& value) const { -#if SGL_MACOS - bool2 v; -#else - uint2 v; -#endif - _get_vector(&v, sizeof(v), TypeReflection::ScalarType::bool_, 2); - value = {v.x != 0, v.y != 0}; + _get_boolN(value); } template<> SGL_API void BufferElementCursor::set(const bool3& value) { -#if SGL_MACOS - bool3 v = value; -#else - uint3 v = {value.x ? 1 : 0, value.y ? 1 : 0, value.z ? 1 : 0}; -#endif - _set_vector(&v, sizeof(v), TypeReflection::ScalarType::bool_, 3); + _set_boolN(value); } template<> SGL_API void BufferElementCursor::get(bool3& value) const { -#if SGL_MACOS - bool3 v; -#else - uint3 v; -#endif - _get_vector(&v, sizeof(v), TypeReflection::ScalarType::bool_, 3); - value = {v.x != 0, v.y != 0, v.z != 0}; + _get_boolN(value); } template<> SGL_API void BufferElementCursor::set(const bool4& value) { -#if SGL_MACOS - bool4 v = value; -#else - uint4 v = {value.x ? 1 : 0, value.y ? 1 : 0, value.z ? 1 : 0, value.w ? 1 : 0}; -#endif - _set_vector(&v, sizeof(v), TypeReflection::ScalarType::bool_, 4); + _set_boolN(value); } template<> SGL_API void BufferElementCursor::get(bool4& value) const { -#if SGL_MACOS - bool4 v; -#else - uint4 v; -#endif - _get_vector(&v, sizeof(v), TypeReflection::ScalarType::bool_, 4); - value = {v.x != 0, v.y != 0, v.z != 0, v.w != 0}; + _get_boolN(value); } template<> @@ -446,7 +293,7 @@ SGL_API void BufferElementCursor::get(DescriptorHandle& value) const read_data(m_offset, &value.value, sizeof(value.value)); } -void BufferElementCursor::write_data(size_t offset, const void* data, size_t size) +void BufferElementCursor::write_data(size_t offset, const void* data, size_t size) const { m_buffer->write_data(offset, data, size); } @@ -456,16 +303,18 @@ void BufferElementCursor::read_data(size_t offset, void* data, size_t size) cons m_buffer->read_data(offset, data, size); } -BufferCursor::BufferCursor(ref element_layout, void* data, size_t size) +BufferCursor::BufferCursor(DeviceType device_type, ref element_layout, void* data, size_t size) : m_element_type_layout(std::move(element_layout)) + , m_device_type(device_type) , m_buffer((uint8_t*)data) , m_size(size) , m_owner(false) { } -BufferCursor::BufferCursor(ref element_layout, size_t element_count) +BufferCursor::BufferCursor(DeviceType device_type, ref element_layout, size_t element_count) : m_element_type_layout(std::move(element_layout)) + , m_device_type(device_type) { m_size = element_count * m_element_type_layout->stride(); m_buffer = new uint8_t[m_size]; @@ -474,6 +323,7 @@ BufferCursor::BufferCursor(ref element_layout, size_t elem BufferCursor::BufferCursor(ref element_layout, ref resource, bool load_before_write) : m_element_type_layout(std::move(element_layout)) + , m_device_type(resource->device()->type()) { m_resource = std::move(resource); m_size = m_resource->size(); diff --git a/src/sgl/device/buffer_cursor.h b/src/sgl/device/buffer_cursor.h index 24e4b478..8af06dd8 100644 --- a/src/sgl/device/buffer_cursor.h +++ b/src/sgl/device/buffer_cursor.h @@ -6,17 +6,21 @@ #include "sgl/device/shader_offset.h" #include "sgl/device/reflection.h" #include "sgl/device/cursor_utils.h" +#include "sgl/device/device.h" #include "sgl/core/config.h" #include "sgl/core/macros.h" +#include "sgl/device/cursor_access_wrappers.h" + #include namespace sgl { /// Represents a single element of a given type in a block of memory, and /// provides read/write tools to access its members via reflection. -class SGL_API BufferElementCursor { +class SGL_API BufferElementCursor : public CursorWriteWrappers, + public CursorReadWrappers { public: BufferElementCursor() = default; @@ -65,20 +69,17 @@ class SGL_API BufferElementCursor { template void set(const T& value); - void _set_array(const void* data, size_t size, TypeReflection::ScalarType scalar_type, size_t element_count); - void _set_scalar(const void* data, size_t size, TypeReflection::ScalarType scalar_type); - void _set_vector(const void* data, size_t size, TypeReflection::ScalarType scalar_type, int dimension); - void _set_matrix(const void* data, size_t size, TypeReflection::ScalarType scalar_type, int rows, int cols); - - void _get_array(void* data, size_t size, TypeReflection::ScalarType scalar_type, size_t element_count) const; - void _get_scalar(void* data, size_t size, TypeReflection::ScalarType scalar_type) const; - void _get_vector(void* data, size_t size, TypeReflection::ScalarType scalar_type, int dimension) const; - void _get_matrix(void* data, size_t size, TypeReflection::ScalarType scalar_type, int rows, int cols) const; - void _set_offset(size_t new_offset) { m_offset = new_offset; } + /// CursorWriteWrappers, CursorReadWrappers + void _set_data(size_t offset, const void* data, size_t size) const { write_data(offset, data, size); } + void _get_data(size_t offset, void* data, size_t size) const { return read_data(offset, data, size); } + size_t _get_offset() const { return m_offset; } + static size_t _increment_offset(size_t offset, size_t diff) { return offset + diff; } + DeviceType _get_device_type() const; + private: - void write_data(size_t offset, const void* data, size_t size); + void write_data(size_t offset, const void* data, size_t size) const; void read_data(size_t offset, void* data, size_t size) const; ref m_type_layout; @@ -98,10 +99,10 @@ class SGL_API BufferCursor : Object { /// Create with none-owning view of specific block of memory. Number of /// elements is inferred from the size of the block and the type layout. - BufferCursor(ref element_layout, void* data, size_t size); + BufferCursor(DeviceType device_type, ref element_layout, void* data, size_t size); /// Create buffer + allocate space internally for a given number of elements. - BufferCursor(ref element_layout, size_t element_count); + BufferCursor(DeviceType device_type, ref element_layout, size_t element_count); /// Create as a view onto a buffer resource. Disable load_before_write to /// prevent automatic loading of current buffer state before writing data to it. @@ -161,9 +162,13 @@ class SGL_API BufferCursor : Object { /// Get the resource this cursor represents (if any). ref resource() const { return m_resource; } + /// Get device type that determines the data layout rules. + DeviceType get_device_type() const { return m_device_type; } + private: ref m_element_type_layout; ref m_resource; + DeviceType m_device_type; uint8_t* m_buffer{nullptr}; size_t m_size{0}; bool m_owner{false}; diff --git a/src/sgl/device/cursor_access_wrappers.h b/src/sgl/device/cursor_access_wrappers.h new file mode 100644 index 00000000..b7ead718 --- /dev/null +++ b/src/sgl/device/cursor_access_wrappers.h @@ -0,0 +1,248 @@ +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#pragma once + +#include "sgl/device/fwd.h" +#include "sgl/device/device.h" +#include "sgl/device/reflection.h" +#include "sgl/device/cursor_utils.h" + +// TODO: Decide if we want to disable / optimize type checks +// currently can represent 50% of the cost of writes in +// certain situations. +#define SGL_ENABLE_CURSOR_TYPE_CHECKS + +namespace sgl { + +template +class SGL_API CursorWriteWrappers { + using BaseCursorOffset = TOffset; + +public: + void _set_array(const void* data, size_t size, TypeReflection::ScalarType scalar_type, size_t element_count) const + { + slang::TypeReflection* element_type = cursor_utils::unwrap_array(_get_slang_type_layout())->getType(); + size_t element_size + = cursor_utils::get_scalar_type_size((TypeReflection::ScalarType)element_type->getScalarType()); + +#ifdef SGL_ENABLE_CURSOR_TYPE_CHECKS + cursor_utils::check_array(_get_slang_type_layout(), size, scalar_type, element_count); +#else + SGL_UNUSED(scalar_type); + SGL_UNUSED(element_count); +#endif + + size_t stride = _get_slang_type_layout()->getElementStride(SLANG_PARAMETER_CATEGORY_UNIFORM); + if (element_size == stride) { + _set_data_internal(_get_offset_internal(), data, size); + } else { + auto offset = _get_offset_internal(); + for (size_t i = 0; i < element_count; ++i) { + _set_data_internal(offset, reinterpret_cast(data) + i * element_size, element_size); + offset = _increment_offset_internal(offset, stride); + } + } + } + + void _set_scalar(const void* data, size_t size, TypeReflection::ScalarType scalar_type) const + { +#ifdef SGL_ENABLE_CURSOR_TYPE_CHECKS + cursor_utils::check_scalar(_get_slang_type_layout(), size, scalar_type); +#else + SGL_UNUSED(scalar_type); +#endif + _set_data_internal(_get_offset_internal(), data, size); + } + + void _set_vector(const void* data, size_t size, TypeReflection::ScalarType scalar_type, int dimension) const + { +#ifdef SGL_ENABLE_CURSOR_TYPE_CHECKS + cursor_utils::check_vector(_get_slang_type_layout(), size, scalar_type, dimension); +#else + SGL_UNUSED(scalar_type); + SGL_UNUSED(dimension); +#endif + _set_data_internal(_get_offset_internal(), data, size); + } + + void _set_matrix(const void* data, size_t size, TypeReflection::ScalarType scalar_type, int rows, int cols) const + { +#ifdef SGL_ENABLE_CURSOR_TYPE_CHECKS + cursor_utils::check_matrix(_get_slang_type_layout(), size, scalar_type, rows, cols); +#else + SGL_UNUSED(scalar_type); + SGL_UNUSED(cols); +#endif + + if (rows > 1) { + size_t mat_stride = _get_slang_type_layout()->getStride(); + size_t row_stride = mat_stride / rows; + size_t row_size = size / rows; + + auto offset = _get_offset_internal(); + for (int row = 0; row < rows; ++row) { + _set_data_internal(offset, reinterpret_cast(data) + row * row_size, row_size); + offset = _increment_offset_internal(offset, row_stride); + } + } else { + _set_data_internal(_get_offset_internal(), data, size); + } + } + + void _set_bool(const bool& value) const + { +#if SGL_MACOS + if (_get_device_type_internal() == DeviceType::metal) { + _set_scalar(&value, sizeof(value), TypeReflection::ScalarType::bool_); + return; + } +#endif + uint32_t v = value ? 1 : 0; + _set_scalar(&v, sizeof(v), TypeReflection::ScalarType::bool_); + } + + template + void _set_boolN(const sgl::math::vector& value) const + { +#if SGL_MACOS + if (_get_device_type_internal() == DeviceType::metal) { + _set_vector(&value, sizeof(value), TypeReflection::ScalarType::bool_, 1); + return; + } +#endif + + sgl::math::vector v; + for (int i = 0; i < N; ++i) + v[i] = value[i] ? 1 : 0; + _set_vector(&v, sizeof(v), TypeReflection::ScalarType::bool_, N); + } + +protected: + CursorWriteWrappers() = default; + +private: + void _set_data_internal(BaseCursorOffset offset, const void* data, size_t size) const + { + static_cast(this)->_set_data(offset, data, size); + } + + BaseCursorOffset _get_offset_internal() const { return static_cast(this)->_get_offset(); } + BaseCursorOffset _increment_offset_internal(BaseCursorOffset offset, size_t diff) const + { + return BaseCursor::_increment_offset(offset, diff); + } + + slang::TypeLayoutReflection* _get_slang_type_layout() const + { + return static_cast(this)->slang_type_layout(); + } + + DeviceType _get_device_type_internal() const { return static_cast(this)->_get_device_type(); } +}; + +template +class SGL_API CursorReadWrappers { + using BaseCursorOffset = TOffset; + +public: + void _get_array(void* data, size_t size, TypeReflection::ScalarType scalar_type, size_t element_count) const + { + ref element_type = _get_slang_type_layout()->unwrap_array()->type(); + size_t element_size = cursor_utils::get_scalar_type_size(element_type->scalar_type()); + + cursor_utils::check_array(_get_slang_type_layout()->slang_target(), size, scalar_type, element_count); + + size_t stride = _get_slang_type_layout()->element_stride(); + if (element_size == stride) { + _get_data_internal(_get_offset_internal(), data, size); + } else { + auto offset = _get_offset_internal(); + for (size_t i = 0; i < element_count; ++i) { + read_data(offset, reinterpret_cast(data) + i * element_size, element_size); + offset = _increment_offset_internal(offset, stride); + } + } + } + + void _get_scalar(void* data, size_t size, TypeReflection::ScalarType scalar_type) const + { + cursor_utils::check_scalar(_get_slang_type_layout(), size, scalar_type); + _get_data_internal(_get_offset_internal(), data, size); + } + + void _get_vector(void* data, size_t size, TypeReflection::ScalarType scalar_type, int dimension) const + { + cursor_utils::check_vector(_get_slang_type_layout(), size, scalar_type, dimension); + _get_data_internal(_get_offset_internal(), data, size); + } + + void _get_matrix(void* data, size_t size, TypeReflection::ScalarType scalar_type, int rows, int cols) const + { + cursor_utils::check_matrix(_get_slang_type_layout(), size, scalar_type, rows, cols); + size_t stride = _get_slang_type_layout()->getStride(); + if (stride != size) { + size_t row_stride = stride / rows; + size_t row_size = size / rows; + auto offset = _get_offset_internal(); + for (int i = 0; i < rows; ++i) { + _get_data_internal(offset, reinterpret_cast(data) + i * row_size, row_size); + offset = _increment_offset_internal(offset, row_stride); + } + } else { + _get_data_internal(_get_offset_internal(), data, size); + } + } + + void _get_bool(bool& value) const + { +#if SGL_MACOS + if (_get_device_type_internal() == DeviceType::metal) { + _get_scalar(&value, sizeof(value), TypeReflection::ScalarType::bool_); + return; + } +#endif + uint32_t v; + _get_scalar(&v, sizeof(v), TypeReflection::ScalarType::bool_); + value = (v != 0); + } + + template + void _get_boolN(sgl::math::vector& value) const + { +#if SGL_MACOS + if (_get_device_type_internal() == DeviceType::metal) { + _get_vector(&value, sizeof(value), TypeReflection::ScalarType::bool_, N); + return; + } +#endif + sgl::math::vector v; + _get_vector(&v, sizeof(v), TypeReflection::ScalarType::bool_, N); + for (int i = 0; i < N; ++i) + value[i] = (v[i] != 0); + } + +protected: + CursorReadWrappers() = default; + +private: + void _get_data_internal(BaseCursorOffset offset, void* data, size_t size) const + { + static_cast(this)->_get_data(offset, data, size); + } + + BaseCursorOffset _get_offset_internal() const { return static_cast(this)->_get_offset(); } + BaseCursorOffset _increment_offset_internal(BaseCursorOffset offset, size_t diff) const + { + return BaseCursor::_increment_offset(offset, diff); + } + + slang::TypeLayoutReflection* _get_slang_type_layout() const + { + return static_cast(this)->slang_type_layout(); + } + + DeviceType _get_device_type_internal() const { return static_cast(this)->_get_device_type(); } +}; + + +} // namespace sgl diff --git a/src/sgl/device/cursor_utils.h b/src/sgl/device/cursor_utils.h index 425b6063..4c23995f 100644 --- a/src/sgl/device/cursor_utils.h +++ b/src/sgl/device/cursor_utils.h @@ -5,30 +5,33 @@ #include "sgl/device/fwd.h" #include "sgl/device/reflection.h" +#include "sgl/core/macros.h" + namespace sgl { namespace cursor_utils { - size_t get_scalar_type_size(TypeReflection::ScalarType type); + SGL_API size_t get_scalar_type_size(TypeReflection::ScalarType type); - slang::TypeLayoutReflection* unwrap_array(slang::TypeLayoutReflection* layout); + SGL_API slang::TypeLayoutReflection* unwrap_array(slang::TypeLayoutReflection* layout); - void check_array( + SGL_API void check_array( slang::TypeLayoutReflection* type_layout, size_t size, TypeReflection::ScalarType scalar_type, size_t element_count ); - void check_scalar(slang::TypeLayoutReflection* type_layout, size_t size, TypeReflection::ScalarType scalar_type); + SGL_API void + check_scalar(slang::TypeLayoutReflection* type_layout, size_t size, TypeReflection::ScalarType scalar_type); - void check_vector( + SGL_API void check_vector( slang::TypeLayoutReflection* type_layout, size_t size, TypeReflection::ScalarType scalar_type, int dimension ); - void check_matrix( + SGL_API void check_matrix( slang::TypeLayoutReflection* type_layout, size_t size, TypeReflection::ScalarType scalar_type, diff --git a/src/sgl/device/shader_cursor.cpp b/src/sgl/device/shader_cursor.cpp index 558cc4b2..bac5c5c9 100644 --- a/src/sgl/device/shader_cursor.cpp +++ b/src/sgl/device/shader_cursor.cpp @@ -479,35 +479,6 @@ void ShaderCursor::set_cuda_tensor_view(const cuda::TensorView& tensor_view) con } } -void ShaderCursor::_set_array( - const void* data, - size_t size, - TypeReflection::ScalarType scalar_type, - size_t element_count -) const -{ - slang::TypeReflection* element_type = cursor_utils::unwrap_array(m_type_layout)->getType(); - size_t element_size = cursor_utils::get_scalar_type_size((TypeReflection::ScalarType)element_type->getScalarType()); - -#ifdef SGL_ENABLE_CURSOR_TYPE_CHECKS - cursor_utils::check_array(m_type_layout, size, scalar_type, element_count); -#else - SGL_UNUSED(scalar_type); - SGL_UNUSED(element_count); -#endif - - size_t stride = m_type_layout->getElementStride(SLANG_PARAMETER_CATEGORY_UNIFORM); - if (element_size == stride) { - m_shader_object->set_data(m_offset, data, size); - } else { - ShaderOffset offset = m_offset; - for (size_t i = 0; i < element_count; ++i) { - m_shader_object->set_data(offset, reinterpret_cast(data) + i * element_size, element_size); - offset.uniform_offset += narrow_cast(stride); - } - } -} - void ShaderCursor::_set_array_unsafe(const void* data, size_t size, size_t element_count) const { slang::TypeReflection* element_type = cursor_utils::unwrap_array(m_type_layout)->getType(); @@ -525,57 +496,24 @@ void ShaderCursor::_set_array_unsafe(const void* data, size_t size, size_t eleme } } -void ShaderCursor::_set_scalar(const void* data, size_t size, TypeReflection::ScalarType scalar_type) const +void ShaderCursor::_set_data(ShaderOffset offset, const void* data, size_t size) const { -#ifdef SGL_ENABLE_CURSOR_TYPE_CHECKS - cursor_utils::check_scalar(m_type_layout, size, scalar_type); -#else - SGL_UNUSED(scalar_type); -#endif - m_shader_object->set_data(m_offset, data, size); + m_shader_object->set_data(offset, data, size); } -void ShaderCursor::_set_vector(const void* data, size_t size, TypeReflection::ScalarType scalar_type, int dimension) - const +DeviceType ShaderCursor::_get_device_type() const { -#ifdef SGL_ENABLE_CURSOR_TYPE_CHECKS - cursor_utils::check_vector(m_type_layout, size, scalar_type, dimension); -#else - SGL_UNUSED(scalar_type); - SGL_UNUSED(dimension); -#endif - m_shader_object->set_data(m_offset, data, size); + return m_shader_object->device()->type(); } -void ShaderCursor::_set_matrix( - const void* data, - size_t size, - TypeReflection::ScalarType scalar_type, - int rows, - int cols -) const -{ -#ifdef SGL_ENABLE_CURSOR_TYPE_CHECKS - cursor_utils::check_matrix(m_type_layout, size, scalar_type, rows, cols); -#else - SGL_UNUSED(scalar_type); - SGL_UNUSED(cols); -#endif - - if (rows > 1) { - size_t mat_stride = m_type_layout->getStride(); - size_t row_stride = mat_stride / rows; +// Explicit instantiation of the methods +template void +CursorWriteWrappers::_set_array(const void*, size_t, TypeReflection::ScalarType, size_t) + const; - size_t row_size = size / rows; - ShaderOffset offset = m_offset; - for (int row = 0; row < rows; ++row) { - m_shader_object->set_data(offset, reinterpret_cast(data) + row * row_size, row_size); - offset.uniform_offset += narrow_cast(row_stride); - } - } else { - m_shader_object->set_data(m_offset, data, size); - } -} +template void +CursorWriteWrappers::_set_vector(const void*, size_t, TypeReflection::ScalarType, int) + const; // // Setter specializations @@ -714,66 +652,31 @@ SET_SCALAR(double, float64); template<> SGL_API void ShaderCursor::set(const bool& value) const { -#if SGL_MACOS - if (m_shader_object->device()->type() == DeviceType::metal) { - _set_scalar(&value, sizeof(value), TypeReflection::ScalarType::bool_); - return; - } -#endif - uint v = value ? 1 : 0; - _set_scalar(&v, sizeof(v), TypeReflection::ScalarType::bool_); + _set_bool(value); } template<> SGL_API void ShaderCursor::set(const bool1& value) const { -#if SGL_MACOS - if (m_shader_object->device()->type() == DeviceType::metal) { - _set_vector(&value, sizeof(value), TypeReflection::ScalarType::bool_, 1); - return; - } -#endif - uint1 v(value.x ? 1 : 0); - _set_vector(&v, sizeof(v), TypeReflection::ScalarType::bool_, 1); + _set_boolN(value); } template<> SGL_API void ShaderCursor::set(const bool2& value) const { -#if SGL_MACOS - if (m_shader_object->device()->type() == DeviceType::metal) { - _set_vector(&value, sizeof(value), TypeReflection::ScalarType::bool_, 2); - return; - } -#endif - uint2 v = {value.x ? 1 : 0, value.y ? 1 : 0}; - _set_vector(&v, sizeof(v), TypeReflection::ScalarType::bool_, 2); + _set_boolN(value); } template<> SGL_API void ShaderCursor::set(const bool3& value) const { -#if SGL_MACOS - if (m_shader_object->device()->type() == DeviceType::metal) { - _set_vector(&value, sizeof(value), TypeReflection::ScalarType::bool_, 3); - return; - } -#endif - uint3 v = {value.x ? 1 : 0, value.y ? 1 : 0, value.z ? 1 : 0}; - _set_vector(&v, sizeof(v), TypeReflection::ScalarType::bool_, 3); + _set_boolN(value); } template<> SGL_API void ShaderCursor::set(const bool4& value) const { -#if SGL_MACOS - if (m_shader_object->device()->type() == DeviceType::metal) { - _set_vector(&value, sizeof(value), TypeReflection::ScalarType::bool_, 4); - return; - } -#endif - uint4 v = {value.x ? 1 : 0, value.y ? 1 : 0, value.z ? 1 : 0, value.w ? 1 : 0}; - _set_vector(&v, sizeof(v), TypeReflection::ScalarType::bool_, 4); + _set_boolN(value); } } // namespace sgl diff --git a/src/sgl/device/shader_cursor.h b/src/sgl/device/shader_cursor.h index 4bac5ffb..dd84fece 100644 --- a/src/sgl/device/shader_cursor.h +++ b/src/sgl/device/shader_cursor.h @@ -10,6 +10,8 @@ #include "sgl/core/config.h" #include "sgl/core/macros.h" +#include "sgl/device/cursor_access_wrappers.h" + #include namespace sgl { @@ -19,7 +21,7 @@ namespace sgl { /// allocating/freeing them repeatedly. This is far faster, however does introduce /// a risk of mem access problems if the shader cursor is kept alive longer than /// the shader object it was created from. -class SGL_API ShaderCursor { +class SGL_API ShaderCursor : public CursorWriteWrappers { public: ShaderCursor() = default; @@ -81,12 +83,17 @@ class SGL_API ShaderCursor { template void set(const T& value) const; - void _set_array(const void* data, size_t size, TypeReflection::ScalarType scalar_type, size_t element_count) const; void _set_array_unsafe(const void* data, size_t size, size_t element_count) const; - void _set_scalar(const void* data, size_t size, TypeReflection::ScalarType scalar_type) const; - void _set_vector(const void* data, size_t size, TypeReflection::ScalarType scalar_type, int dimension) const; - void _set_matrix(const void* data, size_t size, TypeReflection::ScalarType scalar_type, int rows, int cols) const; + /// CursorWriteWrappers, CursorReadWrappers + void _set_data(ShaderOffset offset, const void* data, size_t size) const; + ShaderOffset _get_offset() const { return m_offset; } + static ShaderOffset _increment_offset(ShaderOffset offset, size_t diff) + { + offset.uniform_offset += narrow_cast(diff); + return offset; + } + DeviceType _get_device_type() const; private: slang::TypeLayoutReflection* m_type_layout; diff --git a/src/slangpy_ext/device/buffer_cursor.cpp b/src/slangpy_ext/device/buffer_cursor.cpp index 4817538f..902275e8 100644 --- a/src/slangpy_ext/device/buffer_cursor.cpp +++ b/src/slangpy_ext/device/buffer_cursor.cpp @@ -76,7 +76,13 @@ SGL_PY_EXPORT(device_buffer_cursor) // Interface to simpler root cursor object that maps to the larger buffer. nb::class_(m, "BufferCursor", D(BufferCursor)) // - .def(nb::init, size_t>(), "element_layout"_a, "size"_a, D(BufferCursor, BufferCursor)) + .def( + nb::init, size_t>(), + "device_type"_a, + "element_layout"_a, + "size"_a, + D(BufferCursor, BufferCursor) + ) .def( nb::init, ref, bool>(), "element_layout"_a, diff --git a/src/slangpy_ext/device/cursor_utils.h b/src/slangpy_ext/device/cursor_utils.h index e8ee6b17..9c2aabfe 100644 --- a/src/slangpy_ext/device/cursor_utils.h +++ b/src/slangpy_ext/device/cursor_utils.h @@ -420,6 +420,8 @@ class WriteConverterTable { } else { SGL_THROW("Expected dict"); } + default: + break; } }