From 457ec202235e1b2d13310ab3f3ea8e8ba23f8465 Mon Sep 17 00:00:00 2001 From: Jaeyong Park <72537190+nostaljic@users.noreply.github.com> Date: Fri, 31 Jan 2025 13:35:49 +0900 Subject: [PATCH] chore(improve): reduce overhead from global and shared memory --- marlin/marlin_cuda_kernel.cu | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/marlin/marlin_cuda_kernel.cu b/marlin/marlin_cuda_kernel.cu index ae4cef5..03a1a51 100644 --- a/marlin/marlin_cuda_kernel.cu +++ b/marlin/marlin_cuda_kernel.cu @@ -527,18 +527,14 @@ __global__ void Marlin( int row = (threadIdx.x % 32) / 4; if (!first) { - // Interestingly, doing direct global accesses here really seems to mess up the compiler and lead to slowdowns, - // hence we also use async-copies even though these fetches are not actually asynchronous. #pragma unroll for (int i = 0; i < thread_m_blocks * 4; i++) { - cp_async4_pred( - &sh[c_sh_wr + c_sh_wr_delta * i], - &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)], - i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m - ); + if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) { + int4 c_val = C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)]; + sh[c_sh_wr + c_sh_wr_delta * i] = c_val; + } } - cp_async_fence(); - cp_async_wait<0>(); + __syncthreads(); } #pragma unroll