Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions internal/RendererGPU.h
Original file line number Diff line number Diff line change
Expand Up @@ -531,28 +531,28 @@ Ray::NS::Renderer::InitUNetFilter(const bool alias_memory,
Buffer temp_upload_buf;

if (use_fp16_) {
const int total_count = SetupUNetWeights<uint16_t>(8, nullptr, nullptr);
const int total_count = SetupUNetWeights<uint16_t>(16, nullptr, nullptr);

temp_upload_buf =
Buffer{"UNet Weights CBN Upload", ctx_.get(), eBufType::Upload, uint32_t(total_count * sizeof(uint16_t))};
unet_weights_ =
Buffer{"UNet Weights CBN", ctx_.get(), eBufType::Storage, uint32_t(total_count * sizeof(uint16_t))};

uint16_t *out_weights = (uint16_t *)temp_upload_buf.Map();
SetupUNetWeights(8, &unet_offsets_, out_weights);
SetupUNetWeights(16, &unet_offsets_, out_weights);
temp_upload_buf.Unmap();

CopyBufferToBuffer(temp_upload_buf, 0, unet_weights_, 0, sizeof(uint16_t) * total_count, cmd_buf);
} else {
const int total_count = SetupUNetWeights<float>(8, nullptr, nullptr);
const int total_count = SetupUNetWeights<float>(16, nullptr, nullptr);

temp_upload_buf =
Buffer{"UNet Weights CBN Upload", ctx_.get(), eBufType::Upload, uint32_t(total_count * sizeof(float))};
unet_weights_ =
Buffer{"UNet Weights CBN", ctx_.get(), eBufType::Storage, uint32_t(total_count * sizeof(float))};

float *out_weights = (float *)temp_upload_buf.Map();
SetupUNetWeights(8, &unet_offsets_, out_weights);
SetupUNetWeights(16, &unet_offsets_, out_weights);
temp_upload_buf.Unmap();

CopyBufferToBuffer(temp_upload_buf, 0, unet_weights_, 0, sizeof(float) * total_count, cmd_buf);
Expand Down
96 changes: 48 additions & 48 deletions internal/RendererVK.cpp

Large diffs are not rendered by default.

25 changes: 16 additions & 9 deletions internal/Vk/ContextVK.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ bool Ray::Vk::Context::Init(ILog *log, const VulkanDevice &vk_device, const Vulk

CheckVkPhysicalDeviceFeatures(api_, physical_device_, device_properties_, mem_properties_, graphics_family_index_,
raytracing_supported_, ray_query_supported_, fp16_supported_, int64_supported_,
int64_atomics_supported_, coop_matrix_supported_, pageable_memory_supported_);
int64_atomics_supported_, coop_matrix_size_, pageable_memory_supported_);

// mask out unsupported stages
if (!raytracing_supported_) {
Expand All @@ -227,7 +227,7 @@ bool Ray::Vk::Context::Init(ILog *log, const VulkanDevice &vk_device, const Vulk

if (!external_ && !InitVkDevice(api_, device_, physical_device_, graphics_family_index_, raytracing_supported_,
ray_query_supported_, fp16_supported_, int64_supported_, int64_atomics_supported_,
coop_matrix_supported_, pageable_memory_supported_, log)) {
coop_matrix_size_[0] != -1, pageable_memory_supported_, log)) {
return false;
}

Expand Down Expand Up @@ -565,7 +565,7 @@ void Ray::Vk::Context::CheckVkPhysicalDeviceFeatures(const Api &api, VkPhysicalD
uint32_t &out_graphics_family_index,
bool &out_raytracing_supported, bool &out_ray_query_supported,
bool &out_shader_fp16_supported, bool &out_shader_int64_supported,
bool &out_int64_atomics_supported, bool &out_coop_matrix_supported,
bool &out_int64_atomics_supported, int out_coop_matrix_size[3],
bool &out_pageable_memory_supported) {
api.vkGetPhysicalDeviceProperties(physical_device, &out_device_properties);
api.vkGetPhysicalDeviceMemoryProperties(physical_device, &out_mem_properties);
Expand All @@ -591,16 +591,19 @@ void Ray::Vk::Context::CheckVkPhysicalDeviceFeatures(const Api &api, VkPhysicalD

bool acc_struct_supported = false, raytracing_supported = false, ray_query_supported = false,
shader_fp16_supported = false, shader_int64_supported = false, storage_fp16_supported = false,
coop_matrix_supported = false, shader_buf_int64_atomics_supported = false, memory_priority_supported = false,
shader_buf_int64_atomics_supported = false, memory_priority_supported = false,
pageable_memory_supported = false;

int coop_matrix_size[3] = {-1, -1, -1};

{ // check for features support
uint32_t extension_count;
api.vkEnumerateDeviceExtensionProperties(physical_device, nullptr, &extension_count, nullptr);

SmallVector<VkExtensionProperties, 16> available_extensions(extension_count);
api.vkEnumerateDeviceExtensionProperties(physical_device, nullptr, &extension_count, &available_extensions[0]);

bool coop_matrix_supported = false;
for (uint32_t j = 0; j < extension_count; j++) {
const VkExtensionProperties &ext = available_extensions[j];

Expand Down Expand Up @@ -671,7 +674,10 @@ void Ray::Vk::Context::CheckVkPhysicalDeviceFeatures(const Api &api, VkPhysicalD
for (const VkCooperativeMatrixPropertiesKHR &p : coop_matrix_props) {
if (p.AType == VK_COMPONENT_TYPE_FLOAT16_KHR && p.BType == VK_COMPONENT_TYPE_FLOAT16_KHR &&
p.CType == VK_COMPONENT_TYPE_FLOAT16_KHR && p.ResultType == VK_COMPONENT_TYPE_FLOAT16_KHR &&
p.MSize == 16 && p.NSize == 8 && p.KSize == 8 && p.scope == VK_SCOPE_SUBGROUP_KHR) {
p.MSize == 16 && p.NSize == 16 && p.KSize == 16 && p.scope == VK_SCOPE_SUBGROUP_KHR) {
coop_matrix_size[0] = 16;
coop_matrix_size[1] = 16;
coop_matrix_size[2] = 16;
found = true;
break;
}
Expand All @@ -685,14 +691,15 @@ void Ray::Vk::Context::CheckVkPhysicalDeviceFeatures(const Api &api, VkPhysicalD
out_shader_fp16_supported = (shader_fp16_supported && storage_fp16_supported);
out_shader_int64_supported = shader_int64_supported;
out_int64_atomics_supported = shader_buf_int64_atomics_supported;
out_coop_matrix_supported = coop_matrix_supported;
memcpy(out_coop_matrix_size, coop_matrix_size, 3 * sizeof(int));
out_pageable_memory_supported = (memory_priority_supported && pageable_memory_supported);
}

bool Ray::Vk::Context::InitVkDevice(const Api &api, VkDevice &device, VkPhysicalDevice physical_device,
uint32_t graphics_family_index, bool enable_raytracing, bool enable_ray_query,
bool enable_fp16, bool enable_int64, bool enable_int64_atomics,
bool enable_coop_matrix, bool enable_pageable_memory, ILog *log) {
const uint32_t graphics_family_index, const bool enable_raytracing,
const bool enable_ray_query, const bool enable_fp16, const bool enable_int64,
const bool enable_int64_atomics, const bool enable_coop_matrix,
const bool enable_pageable_memory, ILog *log) {
VkDeviceQueueCreateInfo queue_create_infos[2] = {{}, {}};
const float queue_priorities[] = {1.0f};

Expand Down
10 changes: 5 additions & 5 deletions internal/Vk/ContextVK.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class Context {

bool subgroup_supported_ = false;

bool coop_matrix_supported_ = false;
int coop_matrix_size_[3] = {-1, -1, -1};

bool pageable_memory_supported_ = false;

Expand Down Expand Up @@ -94,7 +94,7 @@ class Context {
// Returns the stored int64_supported_ flag (written by CheckVkPhysicalDeviceFeatures
// when it is called from Init()).
bool int64_supported() const {
    return int64_supported_;
}
// Returns the stored int64_atomics_supported_ flag (written by
// CheckVkPhysicalDeviceFeatures when it is called from Init()).
bool int64_atomics_supported() const {
    return int64_atomics_supported_;
}
// Returns the stored subgroup_supported_ flag (defaults to false).
// NOTE(review): where this flag gets set is not visible here - confirm in Init().
bool subgroup_supported() const {
    return subgroup_supported_;
}
bool coop_matrix_supported() const { return coop_matrix_supported_; }
const int *coop_matrix_size() const { return coop_matrix_size_; }

// Returns the stored supported_stages_mask_ value.
// NOTE(review): presumably a bitmask of usable pipeline stages with unsupported ones
// cleared during Init() - confirm against the code that writes it.
uint32_t supported_stages_mask() const { return supported_stages_mask_; }
// Unconditionally reports true; no device state is consulted.
bool image_blit_supported() const {
    return true;
}
Expand Down Expand Up @@ -148,15 +148,15 @@ class Context {
private:
static bool InitVkInstance(const Api &api, VkInstance &instance, const char *enabled_layers[],
int enabled_layers_count, int validation_level, ILog *log);
static bool ChooseVkPhysicalDevice(const Api &api, VkPhysicalDevice &physical_device, std::string_view preferred_device,
VkInstance instance, ILog *log);
static bool ChooseVkPhysicalDevice(const Api &api, VkPhysicalDevice &physical_device,
std::string_view preferred_device, VkInstance instance, ILog *log);
static void CheckVkPhysicalDeviceFeatures(const Api &api, VkPhysicalDevice &physical_device,
VkPhysicalDeviceProperties &device_properties,
VkPhysicalDeviceMemoryProperties &mem_properties,
uint32_t &graphics_family_index, bool &out_raytracing_supported,
bool &out_ray_query_supported, bool &out_shader_fp16_supported,
bool &out_shader_int64_supported, bool &out_int64_atomics_supported,
bool &out_coop_matrix_supported, bool &out_pageable_memory_supported);
int out_coop_matrix_size[3], bool &out_pageable_memory_supported);
static bool InitVkDevice(const Api &api, VkDevice &device, VkPhysicalDevice physical_device,
uint32_t graphics_family_index, bool enable_raytracing, bool enable_ray_query,
bool enable_fp16, bool enable_int64, bool enable_int64_atomics, bool enable_coop_matrix,
Expand Down
Loading