diff --git a/torchvision/csrc/ops/mps/deform_conv2d_kernel.mm b/torchvision/csrc/ops/mps/deform_conv2d_kernel.mm index 63371365655..fbefafc930b 100644 --- a/torchvision/csrc/ops/mps/deform_conv2d_kernel.mm +++ b/torchvision/csrc/ops/mps/deform_conv2d_kernel.mm @@ -3,8 +3,7 @@ #include #include "mps_kernels.h" -namespace vision { -namespace ops { +namespace vision::ops { namespace { @@ -61,25 +60,25 @@ uint32_t dilation_w_u = static_cast(dilation_w); TORCH_CHECK(weight_c.size(1) * n_weight_grps == in_channels, - "Input channels (", in_channels, + "Input channels (", in_channels, ") must equal weight.size(1) * n_weight_grps (", weight_c.size(1), " * ", n_weight_grps, ")"); TORCH_CHECK(weight_c.size(0) % n_weight_grps == 0, - "Weight tensor's out channels (", weight_c.size(0), + "Weight tensor's out channels (", weight_c.size(0), ") must be divisible by n_weight_grps (", n_weight_grps, ")"); TORCH_CHECK(offset_c.size(1) == n_offset_grps * 2 * weight_h * weight_w, - "Offset tensor shape[1] is invalid: got ", offset_c.size(1), + "Offset tensor shape[1] is invalid: got ", offset_c.size(1), ", expected ", n_offset_grps * 2 * weight_h * weight_w); TORCH_CHECK(!use_mask || mask_c.size(1) == n_offset_grps * weight_h * weight_w, - "Mask tensor shape[1] is invalid: got ", mask_c.size(1), + "Mask tensor shape[1] is invalid: got ", mask_c.size(1), ", expected ", n_offset_grps * weight_h * weight_w); TORCH_CHECK(in_channels % n_offset_grps == 0, - "Input tensor channels (", in_channels, + "Input tensor channels (", in_channels, ") must be divisible by n_offset_grps (", n_offset_grps, ")"); TORCH_CHECK(offset_c.size(0) == batch, "Offset tensor batch size (", offset_c.size(0), ") must match input tensor batch size (", batch, ")"); TORCH_CHECK(offset_c.size(2) == out_h && offset_c.size(3) == out_w, - "Offset tensor spatial dimensions (", offset_c.size(2), ", ", offset_c.size(3), + "Offset tensor spatial dimensions (", offset_c.size(2), ", ", offset_c.size(3), ") must match calculated output dimensions (", out_h, ", ", out_w, ")"); TORCH_CHECK(!use_mask || mask_c.size(0) == batch, "Mask tensor batch size (", mask_c.size(0), @@ -145,5 +144,4 @@ TORCH_FN(deform_conv2d_forward_kernel)); } -} // namespace ops -} // namespace vision +} // namespace vision::ops diff --git a/torchvision/csrc/ops/mps/mps_kernels.h b/torchvision/csrc/ops/mps/mps_kernels.h index 35c60fa0064..aab4e6b8ed2 100644 --- a/torchvision/csrc/ops/mps/mps_kernels.h +++ b/torchvision/csrc/ops/mps/mps_kernels.h @@ -1,9 +1,6 @@ #include -namespace vision { -namespace ops { - -namespace mps { +namespace vision::ops::mps { static at::native::mps::MetalShaderLibrary lib(R"VISION_METAL( @@ -115,15 +112,15 @@ inline T bilinear_interpolate_deformable_conv2d( T v1 = 0; if (y_low >= 0 && x_low >= 0) v1 = input[y_low * width + x_low]; - + T v2 = 0; if (y_low >= 0 && x_high <= width - 1) v2 = input[y_low * width + x_high]; - + T v3 = 0; if (y_high <= height - 1 && x_low >= 0) v3 = input[y_high * width + x_low]; - + T v4 = 0; if (y_high <= height - 1 && x_high <= width - 1) v4 = input[y_high * width + x_high]; @@ -228,7 +225,7 @@ kernel void nms(constant T * dev_boxes [[buffer(0)]], constant float & iou_threshold [[buffer(3)]], uint2 tgid [[threadgroup_position_in_grid]], uint2 tid2 [[thread_position_in_threadgroup]]) { - + const uint row_start = tgid.y; const uint col_start = tgid.x; const uint tid = tid2.x; @@ -245,7 +242,7 @@ kernel void nms(constant T * dev_boxes [[buffer(0)]], const uint cur_box_idx = nmsThreadsPerBlock * row_start + tid; uint64_t t = 0; uint start = 0; - + if (row_start == col_start) { start = tid + 1; } @@ -309,48 +306,48 @@ kernel void deformable_im2col_kernel( int out_b = (tid / (out_w * out_h)) % batch_size; int in_c = tid / (out_w * out_h * batch_size); int out_c = in_c * weight_h * weight_w; - + int c_per_offset_grp = n_in_channels / n_offset_grps; int grp_idx = in_c / c_per_offset_grp; - + int col_offset = out_c * (batch_size * out_h * out_w) + out_b * (out_h * out_w) + out_y * out_w + out_x; device T* local_columns_ptr = columns_ptr + col_offset; - + int input_offset = out_b * (n_in_channels * height * width) + in_c * (height * width); constant T* local_input_ptr = input_ptr + input_offset; - + int offset_offset = (out_b * n_offset_grps + grp_idx) * 2 * weight_h * weight_w * out_h * out_w; constant T* local_offset_ptr = offset_ptr + offset_offset; - + constant T* local_mask_ptr = nullptr; if (use_mask) { int mask_offset = (out_b * n_offset_grps + grp_idx) * weight_h * weight_w * out_h * out_w; local_mask_ptr = mask_ptr + mask_offset; } - + for (int i = 0; i < weight_h; ++i) { for (int j = 0; j < weight_w; ++j) { int mask_index = i * weight_w + j; int offset_index = 2 * mask_index; - + T mask_value = 1; if (use_mask) { mask_value = local_mask_ptr[mask_index * (out_h * out_w) + out_y * out_w + out_x]; } - + T offset_h_val = local_offset_ptr[offset_index * (out_h * out_w) + out_y * out_w + out_x]; T offset_w_val = local_offset_ptr[(offset_index + 1) * (out_h * out_w) + out_y * out_w + out_x]; - + T y = (out_y * stride_h - pad_h) + i * dilation_h + offset_h_val; T x = (out_x * stride_w - pad_w) + j * dilation_w + offset_w_val; - + T interp = bilinear_interpolate_deformable_conv2d(local_input_ptr, height, width, y, x, tid); - + *local_columns_ptr = mask_value * interp; - + local_columns_ptr += batch_size * out_h * out_w; } } @@ -584,7 +581,7 @@ kernel void roi_align_backward( atomic_add_float(grad_input + input_offset + y_low * width + x_high, static_cast(g2)); atomic_add_float(grad_input + input_offset + y_high * width + x_low, static_cast(g3)); atomic_add_float(grad_input + input_offset + y_high * width + x_high, static_cast(g4)); - + } // if } // ix } // iy @@ -742,7 +739,6 @@ kernel void roi_pool_backward( if (argmax != -1) { atomic_add_float(grad_input + offset + argmax, static_cast(grad_output[output_offset + ph * h_stride + pw * w_stride])); } - } // MPS_1D_KERNEL_LOOP } @@ -1139,7 +1135,6 @@ kernel void ps_roi_pool_backward( atomic_add_float(grad_input + offset + grad_input_index, diff_val); } } - } // MPS_1D_KERNEL_LOOP } @@ -1157,7 +1152,7 @@ kernel void ps_roi_pool_backward( \ constant int64_t & width [[buffer(7)]], \ constant int64_t & pooled_height [[buffer(8)]], \ constant int64_t & pooled_width [[buffer(9)]], \ - constant int64_t & channels_out [[buffer(10)]], \ + constant int64_t & channels_out [[buffer(10)]], \ constant float & spatial_scale [[buffer(11)]], \ uint2 tgid [[threadgroup_position_in_grid]], \ uint2 tptg [[threads_per_threadgroup]], \ @@ -1192,6 +1187,4 @@ static id visionPipelineState( return lib.getPipelineStateForFunc(kernel); } -} // namespace mps -} // namespace ops -} // namespace vision +} // namespace vision::ops::mps diff --git a/torchvision/csrc/ops/mps/nms_kernel.mm b/torchvision/csrc/ops/mps/nms_kernel.mm index 5ee9b5cbeae..7982b04a268 100644 --- a/torchvision/csrc/ops/mps/nms_kernel.mm +++ b/torchvision/csrc/ops/mps/nms_kernel.mm @@ -2,8 +2,7 @@ #include #include "mps_kernels.h" -namespace vision { -namespace ops { +namespace vision::ops { namespace { @@ -105,5 +104,4 @@ m.impl(TORCH_SELECTIVE_NAME("torchvision::nms"), TORCH_FN(nms_kernel)); } -} // namespace ops -} // namespace vision +} // namespace vision::ops diff --git a/torchvision/csrc/ops/mps/ps_roi_align_kernel.mm b/torchvision/csrc/ops/mps/ps_roi_align_kernel.mm index 16b711ad5ef..30ba424391f 100644 --- a/torchvision/csrc/ops/mps/ps_roi_align_kernel.mm +++ b/torchvision/csrc/ops/mps/ps_roi_align_kernel.mm @@ -3,8 +3,7 @@ #include "mps_helpers.h" #include "mps_kernels.h" -namespace vision { -namespace ops { +namespace vision::ops { namespace { @@ -201,5 +200,4 @@ m.impl(TORCH_SELECTIVE_NAME("torchvision::_ps_roi_align_backward"), TORCH_FN(ps_roi_align_backward_kernel)); } -} // namespace ops -} // namespace vision +} // namespace vision::ops diff --git a/torchvision/csrc/ops/mps/ps_roi_pool_kernel.mm b/torchvision/csrc/ops/mps/ps_roi_pool_kernel.mm index 75d0ff4845f..893ab2cc17d 100644 --- a/torchvision/csrc/ops/mps/ps_roi_pool_kernel.mm +++ b/torchvision/csrc/ops/mps/ps_roi_pool_kernel.mm @@ -3,8 +3,7 @@ #include "mps_helpers.h" #include "mps_kernels.h" -namespace vision { -namespace ops { +namespace vision::ops { namespace { @@ -195,5 +194,4 @@ m.impl(TORCH_SELECTIVE_NAME("torchvision::_ps_roi_pool_backward"), TORCH_FN(ps_roi_pool_backward_kernel)); } -} // namespace ops -} // namespace vision +} // namespace vision::ops diff --git a/torchvision/csrc/ops/mps/roi_align_kernel.mm b/torchvision/csrc/ops/mps/roi_align_kernel.mm index d4ed8b43fd2..df047743e20 100644 --- a/torchvision/csrc/ops/mps/roi_align_kernel.mm +++ b/torchvision/csrc/ops/mps/roi_align_kernel.mm @@ -3,8 +3,7 @@ #include "mps_helpers.h" #include "mps_kernels.h" -namespace vision { -namespace ops { +namespace vision::ops { namespace { @@ -193,5 +192,4 @@ m.impl(TORCH_SELECTIVE_NAME("torchvision::_roi_align_backward"), TORCH_FN(roi_align_backward_kernel)); } -} // namespace ops -} // namespace vision +} // namespace vision::ops diff --git a/torchvision/csrc/ops/mps/roi_pool_kernel.mm b/torchvision/csrc/ops/mps/roi_pool_kernel.mm index 816d8d70863..99541f23da6 100644 --- a/torchvision/csrc/ops/mps/roi_pool_kernel.mm +++ b/torchvision/csrc/ops/mps/roi_pool_kernel.mm @@ -3,8 +3,7 @@ #include "mps_helpers.h" #include "mps_kernels.h" -namespace vision { -namespace ops { +namespace vision::ops { namespace { @@ -192,5 +191,4 @@ m.impl(TORCH_SELECTIVE_NAME("torchvision::_roi_pool_backward"), TORCH_FN(roi_pool_backward_kernel)); } -} // namespace ops -} // namespace vision +} // namespace vision::ops