Skip to content

Refactored common cursors setters and getters #299

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions slangpy/tests/device/test_buffer_cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ def test_cursor_read_write(device_type: spy.DeviceType, seed: int):
(kernel, buffer_layout) = make_fill_in_module(device_type, tests)

# Create a buffer cursor with its own data
cursor = spy.BufferCursor(buffer_layout.element_type_layout, 1)
cursor = spy.BufferCursor(device_type, buffer_layout.element_type_layout, 1)

# Populate the first element
element = cursor[0]
Expand All @@ -355,7 +355,7 @@ def test_cursor_read_write(device_type: spy.DeviceType, seed: int):
element[name] = value

# Create new cursor by copying the first, and read element
cursor2 = spy.BufferCursor(buffer_layout.element_type_layout, 1)
cursor2 = spy.BufferCursor(device_type, buffer_layout.element_type_layout, 1)
cursor2.copy_from_numpy(cursor.to_numpy())
element2 = cursor2[0]

Expand Down Expand Up @@ -389,7 +389,7 @@ def test_fill_from_kernel(device_type: spy.DeviceType, seed: int):
kernel.dispatch([count, 1, 1], buffer=buffer)

# Create a cursor and read the buffer by copying its data
cursor = spy.BufferCursor(buffer_layout.element_type_layout, count)
cursor = spy.BufferCursor(device_type, buffer_layout.element_type_layout, count)
cursor.copy_from_numpy(buffer.to_numpy())

# Verify data matches
Expand Down Expand Up @@ -453,7 +453,7 @@ def test_cursor_lifetime(device_type: spy.DeviceType):
(kernel, buffer_layout) = make_fill_in_module(device_type, get_tests(device_type))

# Create a buffer cursor with its own data
cursor = spy.BufferCursor(buffer_layout.element_type_layout, 1)
cursor = spy.BufferCursor(device_type, buffer_layout.element_type_layout, 1)

# Get element
element = cursor[0]
Expand Down
200 changes: 25 additions & 175 deletions src/sgl/device/buffer_cursor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,125 +108,19 @@ void BufferElementCursor::set_data(const void* data, size_t size)
write_data(m_offset, data, size);
}

void BufferElementCursor::_set_array(
const void* data,
size_t size,
TypeReflection::ScalarType scalar_type,
size_t element_count
)
DeviceType BufferElementCursor::_get_device_type() const
{
ref<const TypeReflection> element_type = m_type_layout->unwrap_array()->type();
size_t element_size = cursor_utils::get_scalar_type_size(element_type->scalar_type());

cursor_utils::check_array(m_type_layout->slang_target(), size, scalar_type, element_count);

size_t stride = m_type_layout->element_stride();
if (element_size == stride) {
write_data(m_offset, data, size);
} else {
size_t offset = m_offset;
for (size_t i = 0; i < element_count; ++i) {
write_data(offset, reinterpret_cast<const uint8_t*>(data) + i * element_size, element_size);
offset += stride;
}
}
return m_buffer->get_device_type();
}

void BufferElementCursor::_get_array(
void* data,
size_t size,
TypeReflection::ScalarType scalar_type,
size_t element_count
) const
{
ref<const TypeReflection> element_type = m_type_layout->unwrap_array()->type();
size_t element_size = cursor_utils::get_scalar_type_size(element_type->scalar_type());

cursor_utils::check_array(m_type_layout->slang_target(), size, scalar_type, element_count);

size_t stride = m_type_layout->element_stride();
if (element_size == stride) {
read_data(m_offset, data, size);
} else {
size_t offset = m_offset;
for (size_t i = 0; i < element_count; ++i) {
read_data(offset, reinterpret_cast<uint8_t*>(data) + i * element_size, element_size);
offset += stride;
}
}
}
void BufferElementCursor::_set_scalar(const void* data, size_t size, TypeReflection::ScalarType scalar_type)
{
cursor_utils::check_scalar(m_type_layout->slang_target(), size, scalar_type);
write_data(m_offset, data, size);
}

void BufferElementCursor::_get_scalar(void* data, size_t size, TypeReflection::ScalarType scalar_type) const
{
cursor_utils::check_scalar(m_type_layout->slang_target(), size, scalar_type);
read_data(m_offset, data, size);
}

void BufferElementCursor::_set_vector(
const void* data,
size_t size,
TypeReflection::ScalarType scalar_type,
int dimension
)
{
cursor_utils::check_vector(m_type_layout->slang_target(), size, scalar_type, dimension);
write_data(m_offset, data, size);
}

void BufferElementCursor::_get_vector(void* data, size_t size, TypeReflection::ScalarType scalar_type, int dimension)
const
{
cursor_utils::check_vector(m_type_layout->slang_target(), size, scalar_type, dimension);
read_data(m_offset, data, size);
}

void BufferElementCursor::_set_matrix(
const void* data,
size_t size,
TypeReflection::ScalarType scalar_type,
int rows,
int cols
)
{
cursor_utils::check_matrix(m_type_layout->slang_target(), size, scalar_type, rows, cols);
size_t stride = slang_type_layout()->getStride();
if (stride != size) {
size_t row_stride = stride / rows;
size_t row_size = size / rows;
for (int i = 0; i < rows; ++i) {
write_data(m_offset + i * row_stride, reinterpret_cast<const uint8_t*>(data) + i * row_size, row_size);
}
} else {
write_data(m_offset, data, size);
}
}

void BufferElementCursor::_get_matrix(
void* data,
size_t size,
TypeReflection::ScalarType scalar_type,
int rows,
int cols
) const
{
cursor_utils::check_matrix(m_type_layout->slang_target(), size, scalar_type, rows, cols);
size_t stride = slang_type_layout()->getStride();
if (stride != size) {
size_t row_stride = stride / rows;
size_t row_size = size / rows;
for (int i = 0; i < rows; ++i) {
read_data(m_offset + i * row_stride, reinterpret_cast<uint8_t*>(data) + i * row_size, row_size);
}
} else {
read_data(m_offset, data, size);
}
}
// Explicit instantiation of the methods
template void
CursorWriteWrappers<BufferElementCursor, size_t>::_set_array(const void*, size_t, TypeReflection::ScalarType, size_t)
const;

template void
CursorWriteWrappers<BufferElementCursor, size_t>::_set_vector(const void*, size_t, TypeReflection::ScalarType, int)
const;

//
// Setter specializations
Expand Down Expand Up @@ -332,106 +226,59 @@ GETSET_SCALAR(double, float64);
template<>
SGL_API void BufferElementCursor::set(const bool& value)
{
uint v = value ? 1 : 0;
_set_scalar(&v, sizeof(v), TypeReflection::ScalarType::bool_);
_set_bool(value);
}
template<>
SGL_API void BufferElementCursor::get(bool& value) const
{
uint v;
_get_scalar(&v, sizeof(v), TypeReflection::ScalarType::bool_);
value = v != 0;
_get_bool(value);
}

template<>
SGL_API void BufferElementCursor::set(const bool1& value)
{
#if SGL_MACOS
bool1 v = value;
#else
uint1 v(value.x ? 1 : 0);
#endif
_set_vector(&v, sizeof(v), TypeReflection::ScalarType::bool_, 1);
_set_boolN(value);
}
template<>
SGL_API void BufferElementCursor::get(bool1& value) const
{
#if SGL_MACOS
bool1 v;
#else
uint1 v;
#endif
_get_vector(&v, sizeof(v), TypeReflection::ScalarType::bool_, 1);
value = bool1(v.x != 0);
_get_boolN(value);
}

template<>
SGL_API void BufferElementCursor::set(const bool2& value)
{
#if SGL_MACOS
bool2 v = value;
#else
uint2 v = {value.x ? 1 : 0, value.y ? 1 : 0};
#endif
_set_vector(&v, sizeof(v), TypeReflection::ScalarType::bool_, 2);
_set_boolN(value);
}

template<>
SGL_API void BufferElementCursor::get(bool2& value) const
{
#if SGL_MACOS
bool2 v;
#else
uint2 v;
#endif
_get_vector(&v, sizeof(v), TypeReflection::ScalarType::bool_, 2);
value = {v.x != 0, v.y != 0};
_get_boolN(value);
}

template<>
SGL_API void BufferElementCursor::set(const bool3& value)
{
#if SGL_MACOS
bool3 v = value;
#else
uint3 v = {value.x ? 1 : 0, value.y ? 1 : 0, value.z ? 1 : 0};
#endif
_set_vector(&v, sizeof(v), TypeReflection::ScalarType::bool_, 3);
_set_boolN(value);
}

template<>
SGL_API void BufferElementCursor::get(bool3& value) const
{
#if SGL_MACOS
bool3 v;
#else
uint3 v;
#endif
_get_vector(&v, sizeof(v), TypeReflection::ScalarType::bool_, 3);
value = {v.x != 0, v.y != 0, v.z != 0};
_get_boolN(value);
}

template<>
SGL_API void BufferElementCursor::set(const bool4& value)
{
#if SGL_MACOS
bool4 v = value;
#else
uint4 v = {value.x ? 1 : 0, value.y ? 1 : 0, value.z ? 1 : 0, value.w ? 1 : 0};
#endif
_set_vector(&v, sizeof(v), TypeReflection::ScalarType::bool_, 4);
_set_boolN(value);
}

template<>
SGL_API void BufferElementCursor::get(bool4& value) const
{
#if SGL_MACOS
bool4 v;
#else
uint4 v;
#endif
_get_vector(&v, sizeof(v), TypeReflection::ScalarType::bool_, 4);
value = {v.x != 0, v.y != 0, v.z != 0, v.w != 0};
_get_boolN(value);
}

template<>
Expand All @@ -446,7 +293,7 @@ SGL_API void BufferElementCursor::get(DescriptorHandle& value) const
read_data(m_offset, &value.value, sizeof(value.value));
}

void BufferElementCursor::write_data(size_t offset, const void* data, size_t size)
void BufferElementCursor::write_data(size_t offset, const void* data, size_t size) const
{
m_buffer->write_data(offset, data, size);
}
Expand All @@ -456,16 +303,18 @@ void BufferElementCursor::read_data(size_t offset, void* data, size_t size) cons
m_buffer->read_data(offset, data, size);
}

BufferCursor::BufferCursor(ref<TypeLayoutReflection> element_layout, void* data, size_t size)
BufferCursor::BufferCursor(DeviceType device_type, ref<TypeLayoutReflection> element_layout, void* data, size_t size)
: m_element_type_layout(std::move(element_layout))
, m_device_type(device_type)
, m_buffer((uint8_t*)data)
, m_size(size)
, m_owner(false)
{
}

BufferCursor::BufferCursor(ref<TypeLayoutReflection> element_layout, size_t element_count)
BufferCursor::BufferCursor(DeviceType device_type, ref<TypeLayoutReflection> element_layout, size_t element_count)
: m_element_type_layout(std::move(element_layout))
, m_device_type(device_type)
{
m_size = element_count * m_element_type_layout->stride();
m_buffer = new uint8_t[m_size];
Expand All @@ -474,6 +323,7 @@ BufferCursor::BufferCursor(ref<TypeLayoutReflection> element_layout, size_t elem

BufferCursor::BufferCursor(ref<TypeLayoutReflection> element_layout, ref<Buffer> resource, bool load_before_write)
: m_element_type_layout(std::move(element_layout))
, m_device_type(resource->device()->type())
{
m_resource = std::move(resource);
m_size = m_resource->size();
Expand Down
Loading