diff --git a/.gitignore b/.gitignore index c2d4d2a1c42..9f9df389d30 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,13 @@ +# CMAKE +CmakePresets.json + +# MacOS +**/.DS_Store + +build_xcode/ build/ dist/ +framework/ torchvision.egg-info/ torchvision/version.py */**/__pycache__ @@ -10,6 +18,9 @@ torchvision/version.py */**/*~ *~ +#Misc +collect_env.py + docs/build # sphinx-gallery docs/source/auto_examples/ diff --git a/android/gradle/wrapper/gradle-wrapper.properties b/android/gradle/wrapper/gradle-wrapper.properties index 442d9132ea3..5ef7a1a3a60 100644 --- a/android/gradle/wrapper/gradle-wrapper.properties +++ b/android/gradle/wrapper/gradle-wrapper.properties @@ -1,5 +1,6 @@ +#Tue Aug 27 15:56:14 CEST 2024 distributionBase=GRADLE_USER_HOME +distributionUrl=https\://services.gradle.org/distributions/gradle-8.9-bin.zip distributionPath=wrapper/dists -distributionUrl=https\://services.gradle.org/distributions/gradle-6.8.3-bin.zip -zipStoreBase=GRADLE_USER_HOME zipStorePath=wrapper/dists +zipStoreBase=GRADLE_USER_HOME diff --git a/test/optest.py b/test/optest.py new file mode 100644 index 00000000000..68caf575dc1 --- /dev/null +++ b/test/optest.py @@ -0,0 +1,15 @@ +# ========================================================= +# BEGIN REPRO SCRIPT +# ========================================================= +import torch +from torch.testing._internal.optests import opcheck + +# Make sure you have loaded the library that contains the op +# via an import or torch.ops.load_library(...) +# op = torch.ops.torchvision.deform_conv2d.default +op = torch.ops.torchvision.roi_align.default +args, kwargs = torch.load("/var/folders/m7/m4jyvbb97ml6nftpw7b6fsk00000gn/T/pytorch_opcheck_safe_to_delete/repro_173109241941725.22.pt") +opcheck(op, args, kwargs, test_utils="test_autograd_registration") +# ========================================================= +# END REPRO SCRIPT +# ========================================================= \ No newline at end of file diff --git a/test/playground/test_mps_import.py b/test/playground/test_mps_import.py new file mode 100644 index 00000000000..8452fe75fde --- /dev/null +++ b/test/playground/test_mps_import.py @@ -0,0 +1,6 @@ +import torch +import torchvision as tv + + + +print(torch.backends.mps.is_available()) diff --git a/test/test_ops.py b/test/test_ops.py index 75940b7e509..289a6444f5e 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -929,7 +929,10 @@ def test_batched_nms_implementations(self, seed): class TestDeformConv: dtype = torch.float64 - + mps_dtype = torch.float32 + mps_backward_atol = 2e-2 + mps_backward_eps = 1e-3 + def expected_fn(self, x, weight, offset, mask, bias, stride=1, padding=0, dilation=1): stride_h, stride_w = _pair(stride) pad_h, pad_w = _pair(padding) @@ -1041,7 +1044,7 @@ def make_obj(self, in_channels=6, out_channels=2, kernel_size=(3, 2), groups=2, ) return DeformConvModuleWrapper(obj) if wrap else obj - @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("device", cpu_and_cuda_and_mps()) def test_is_leaf_node(self, device): op_obj = self.make_obj(wrap=True).to(device=device) graph_node_names = get_graph_node_names(op_obj) @@ -1050,12 +1053,17 @@ def test_is_leaf_node(self, device): assert len(graph_node_names[0]) == len(graph_node_names[1]) assert len(graph_node_names[0]) == 1 + op_obj.n_inputs - @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("device", cpu_and_cuda_and_mps()) + @pytest.mark.parametrize("dtype", (torch.float16, torch.float32, torch.float64)) # , 
ids=str) @pytest.mark.parametrize("contiguous", (True, False)) @pytest.mark.parametrize("batch_sz", (0, 33)) @pytest.mark.opcheck_only_one() - def test_forward(self, device, contiguous, batch_sz, dtype=None): + def test_forward(self, device, contiguous, batch_sz, dtype): dtype = dtype or self.dtype + + if device == "mps" and dtype is torch.float64: + pytest.skip("MPS does not support float64") + x, _, offset, mask, _, stride, padding, dilation = self.get_fn_args(device, contiguous, batch_sz, dtype) in_channels = 6 out_channels = 2 @@ -1103,28 +1111,37 @@ def test_wrong_sizes(self): wrong_mask = torch.rand_like(mask[:, :2]) layer(x, offset, wrong_mask) - @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("device", cpu_and_cuda_and_mps()) @pytest.mark.parametrize("contiguous", (True, False)) @pytest.mark.parametrize("batch_sz", (0, 33)) @pytest.mark.opcheck_only_one() - def test_backward(self, device, contiguous, batch_sz): + def test_backward(self, device, contiguous, batch_sz, deterministic=False): + # A batch size of zero fails a check in OperationUtils.mm because tensors with a zero-sized dimension + # cause Placeholder construction to fail. + if device == "mps" and batch_sz == 0: + pytest.skip("MPS does not currently support zero batch size for backpropagation") + + atol = self.mps_backward_atol if device == "mps" else 1e-05 + dtype = self.mps_dtype if device == "mps" else self.dtype + eps = self.mps_backward_eps if device == "mps" else 1e-6 + x, weight, offset, mask, bias, stride, padding, dilation = self.get_fn_args( - device, contiguous, batch_sz, self.dtype + device, contiguous, batch_sz, dtype ) def func(x_, offset_, mask_, weight_, bias_): return ops.deform_conv2d( x_, offset_, weight_, bias_, stride=stride, padding=padding, dilation=dilation, mask=mask_ ) - - gradcheck(func, (x, offset, mask, weight, bias), nondet_tol=1e-5, fast_mode=True) - + + gradcheck(func, (x, offset, mask, weight, bias), nondet_tol=1e-5, fast_mode=True, atol=atol, eps=eps) + def func_no_mask(x_, offset_, weight_, bias_): return ops.deform_conv2d( x_, offset_, weight_, bias_, stride=stride, padding=padding, dilation=dilation, mask=None ) - - gradcheck(func_no_mask, (x, offset, weight, bias), nondet_tol=1e-5, fast_mode=True) + + gradcheck(func_no_mask, (x, offset, weight, bias), nondet_tol=1e-5, fast_mode=True, atol=atol, eps=eps) @torch.jit.script def script_func(x_, offset_, mask_, weight_, bias_, stride_, pad_, dilation_): @@ -1137,7 +1154,7 @@ def script_func(x_, offset_, mask_, weight_, bias_, stride_, pad_, dilation_): lambda z, off, msk, wei, bi: script_func(z, off, msk, wei, bi, stride, padding, dilation), (x, offset, mask, weight, bias), nondet_tol=1e-5, - fast_mode=True, + fast_mode=True, eps=eps, atol=atol ) @torch.jit.script def script_func_no_mask(x_, offset_, weight_, bias_, stride_, pad_, dilation_): @@ -1151,7 +1168,7 @@ def script_func_no_mask(x_, offset_, weight_, bias_, stride_, pad_, dilation_): lambda z, off, wei, bi: script_func_no_mask(z, off, wei, bi, stride, padding, dilation), (x, offset, weight, bias), nondet_tol=1e-5, - fast_mode=True, + fast_mode=True, eps=eps, atol=atol ) @needs_cuda @@ -2035,4 +2052,4 @@ def test_is_leaf_node(self, dim, p, block_size, inplace): if __name__ == "__main__": - pytest.main([__file__]) + pytest.main([__file__ + "::TestDeformConv::test_forward"]) diff --git a/torchvision/csrc/ops/mps/deform_conv2d_kernel.mm b/torchvision/csrc/ops/mps/deform_conv2d_kernel.mm new file mode 100644 index 00000000000..7377760b169 --- /dev/null +++ b/torchvision/csrc/ops/mps/deform_conv2d_kernel.mm @@ -0,0 +1,917 @@ + + 
+#include +#include +#include +#include +#include "mps_helpers.h" +#include "mps_kernels.h" + +namespace vision { +namespace ops { + +namespace { + +const int tkMaxParallelImgs = 32; + +void deformable_im2col(const at::Tensor& input, + const at::Tensor& data_offset, + const at::Tensor& data_mask, + int64_t n_in_channels, + int64_t height, + int64_t width, + int64_t weight_h, + int64_t weight_w, + int64_t pad_h, + int64_t pad_w, + int64_t stride_h, + int64_t stride_w, + int64_t dilation_h, + int64_t dilation_w, + int64_t out_h, + int64_t out_w, + int64_t parallel_imgs, + int64_t deformable_group, + bool use_mask, + at::Tensor data_col) { + using namespace at::native::mps; + + TORCH_CHECK(input.is_mps(), "input must be a MPS tensor"); + TORCH_CHECK(data_offset.is_mps(), "data_offset must be a MPS tensor"); + TORCH_CHECK(data_mask.is_mps(), "data_mask must be a MPS tensor"); + + at::TensorArg input_t{input, "input", 1}, + data_offset_t{data_offset, "data_offset", 2}, + data_mask_t{data_mask, "data_mask", 3}; + + at::CheckedFrom c = "deformable_im2col"; + at::checkAllSameGPU(c, {input_t, data_offset_t, data_mask_t}); + at::checkAllSameType(c, {input_t, data_offset_t, data_mask_t}); + + + const int64_t num_kernels = (int64_t)n_in_channels * out_h * out_w * parallel_imgs; + + id inputBuffer = getMTLBufferStorage(input); + id data_offsetBuffer = getMTLBufferStorage(data_offset); + id data_maskBuffer = getMTLBufferStorage(data_mask); + id data_colBuffer = getMTLBufferStorage(data_col); + + id device = MPSDevice::getInstance()->device(); + + MPSStream* mpsStream = getCurrentMPSStream(); + dispatch_sync(mpsStream->queue(), ^() { + @autoreleasepool { + const std::string kernel = "deformable_im2col_" + scalarToMetalTypeString(input.scalar_type()); + id visionPSO = mps::visionPipelineState(device, kernel); + + MTLSize threadgroupsPerGrid = MTLSizeMake( + std::min(ceil_div(static_cast(num_kernels), + static_cast(512)), + static_cast(4096)), + 1, + 1); + + getMPSProfiler().beginProfileKernel(visionPSO, kernel, {input, data_offset, data_mask}); + + id computeEncoder = mpsStream->commandEncoder(); + [computeEncoder setComputePipelineState:visionPSO]; + + [computeEncoder setBuffer:inputBuffer offset:input.storage_offset() * input.element_size() atIndex:0]; + [computeEncoder setBuffer:data_offsetBuffer offset:data_offset.storage_offset() * data_offset.element_size() atIndex:1]; + [computeEncoder setBuffer:data_maskBuffer offset:data_mask.storage_offset() * data_mask.element_size() atIndex:2]; + [computeEncoder setBuffer:data_colBuffer offset:data_col.storage_offset() * data_col.element_size() atIndex:3]; + + [computeEncoder setBytes:&num_kernels length:sizeof(int64_t) atIndex:4]; + [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:5]; + [computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:6]; + [computeEncoder setBytes:&weight_h length:sizeof(int64_t) atIndex:7]; + [computeEncoder setBytes:&weight_w length:sizeof(int64_t) atIndex:8]; + [computeEncoder setBytes:&pad_h length:sizeof(int64_t) atIndex:9]; + [computeEncoder setBytes:&pad_w length:sizeof(int64_t) atIndex:10]; + [computeEncoder setBytes:&stride_h length:sizeof(int64_t) atIndex:11]; + [computeEncoder setBytes:&stride_w length:sizeof(int64_t) atIndex:12]; + [computeEncoder setBytes:&dilation_h length:sizeof(int64_t) atIndex:13]; + [computeEncoder setBytes:&dilation_w length:sizeof(int64_t) atIndex:14]; + [computeEncoder setBytes:¶llel_imgs length:sizeof(int64_t) atIndex:15]; + [computeEncoder setBytes:&n_in_channels 
length:sizeof(int64_t) atIndex:16]; + [computeEncoder setBytes:&deformable_group length:sizeof(int64_t) atIndex:17]; + [computeEncoder setBytes:&out_h length:sizeof(int64_t) atIndex:18]; + [computeEncoder setBytes:&out_w length:sizeof(int64_t) atIndex:19]; + [computeEncoder setBytes:&use_mask length:sizeof(bool) atIndex:20]; + + NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup; + if (tgSize > threadsPerBlock) { + tgSize = threadsPerBlock; + } + + MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1); + [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize]; + + getMPSProfiler().endProfileKernel(visionPSO); + } + }); + +} + +int get_greatest_divisor_below_bound(int n, int bound) { + for (int k = bound; k > 1; --k) { + if (n % k == 0) { + return k; + } + } + return 1; +} + +void compute_grad_input( + const at::Tensor& columns, + const at::Tensor& offset, + const at::Tensor& mask, + int64_t channels, + int64_t height, + int64_t width, + int64_t weight_h, + int64_t weight_w, + int64_t pad_h, + int64_t pad_w, + int64_t stride_h, + int64_t stride_w, + int64_t dilation_h, + int64_t dilation_w, + int64_t parallel_imgs, + int64_t n_offset_grps, + bool use_mask, + at::Tensor grad_im) { + using namespace at::native::mps; + + at::globalContext().alertNotDeterministic("compute_grad_input"); + + const int64_t out_h = + (height + 2 * pad_h - (dilation_h * (weight_h - 1) + 1)) / stride_h + 1; + const int64_t out_w = + (width + 2 * pad_w - (dilation_w * (weight_w - 1) + 1)) / stride_w + 1; + + const int64_t num_kernels = + (int64_t)channels * weight_h * weight_w * out_h * out_w * parallel_imgs; + + id columnsBuffer = getMTLBufferStorage(columns); + id offsetBuffer = getMTLBufferStorage(offset); + id maskBuffer = getMTLBufferStorage(mask); + id grad_imBuffer = getMTLBufferStorage(grad_im); + + id device = MPSDevice::getInstance()->device(); + + MPSStream* mpsStream = getCurrentMPSStream(); + dispatch_sync(mpsStream->queue(), ^() { + @autoreleasepool { + id computeEncoder = mpsStream->commandEncoder(); + + const std::string kernel = "deformable_col2im_" + scalarToMetalTypeString(columns.scalar_type()); + id visionPSO = mps::visionPipelineState(device, kernel); + + getMPSProfiler().beginProfileKernel(visionPSO, kernel, {columns, offset, mask}); + + [computeEncoder setComputePipelineState:visionPSO]; + + [computeEncoder setBuffer:columnsBuffer offset:columns.storage_offset() * columns.element_size() atIndex:1]; + [computeEncoder setBuffer:offsetBuffer offset:offset.storage_offset() * offset.element_size() atIndex:2]; + [computeEncoder setBuffer:maskBuffer offset:mask.storage_offset() * mask.element_size() atIndex:3]; + [computeEncoder setBuffer:grad_imBuffer + offset:grad_im.storage_offset() * grad_im.element_size() + atIndex:20]; + + [computeEncoder setBytes:&num_kernels length:sizeof(int64_t) atIndex:0]; + [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:4]; + [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:5]; + [computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:6]; + [computeEncoder setBytes:&weight_h length:sizeof(int64_t) atIndex:7]; + [computeEncoder setBytes:&weight_w length:sizeof(int64_t) atIndex:8]; + [computeEncoder setBytes:&pad_h length:sizeof(int64_t) atIndex:9]; + [computeEncoder setBytes:&pad_w length:sizeof(int64_t) atIndex:10]; + [computeEncoder setBytes:&stride_h length:sizeof(int64_t) atIndex:11]; + [computeEncoder setBytes:&stride_w length:sizeof(int64_t) atIndex:12]; + [computeEncoder 
setBytes:&dilation_h length:sizeof(int64_t) atIndex:13]; + [computeEncoder setBytes:&dilation_w length:sizeof(int64_t) atIndex:14]; + [computeEncoder setBytes:¶llel_imgs length:sizeof(int64_t) atIndex:15]; + [computeEncoder setBytes:&n_offset_grps length:sizeof(int64_t) atIndex:16]; + [computeEncoder setBytes:&out_h length:sizeof(int64_t) atIndex:17]; + [computeEncoder setBytes:&out_w length:sizeof(int64_t) atIndex:18]; + [computeEncoder setBytes:&use_mask length:sizeof(bool) atIndex:19]; + + NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup; + if (tgSize > threadsPerBlock) { + tgSize = threadsPerBlock; + } + + MTLSize threadgroupsPerGrid = MTLSizeMake( + std::min(ceil_div(static_cast(num_kernels), static_cast(512)), static_cast(4096)), + 1, + 1); + + MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1); + [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize]; + + getMPSProfiler().endProfileKernel(visionPSO); + } + }); +} + +void compute_grad_offset_and_mask( + const at::Tensor& columns, + const at::Tensor& input, + const at::Tensor& offset, + const at::Tensor& mask, + int64_t channels, + int64_t height, + int64_t width, + int64_t weight_h, + int64_t weight_w, + int64_t pad_h, + int64_t pad_w, + int64_t stride_h, + int64_t stride_w, + int64_t dilation_h, + int64_t dilation_w, + int64_t parallel_imgs, + int64_t n_offset_grps, + bool use_mask, + at::Tensor grad_offset, + at::Tensor grad_mask) { + + using namespace at::native::mps; + + const int64_t out_h = + (height + 2 * pad_h - (dilation_h * (weight_h - 1) + 1)) / stride_h + 1; + const int64_t out_w = + (width + 2 * pad_w - (dilation_w * (weight_w - 1) + 1)) / stride_w + 1; + const int64_t num_kernels = (int64_t)out_h * out_w * 2 * weight_h * weight_w * + n_offset_grps * parallel_imgs; + + const int64_t offset_channels = 2 * weight_h * weight_w * n_offset_grps; + + id columnsBuffer = getMTLBufferStorage(columns); + id inputBuffer = getMTLBufferStorage(input); + id offsetBuffer = getMTLBufferStorage(offset); + id maskBuffer = getMTLBufferStorage(mask); + id grad_offsetBuffer = getMTLBufferStorage(grad_offset); + id grad_maskBuffer = getMTLBufferStorage(grad_mask); + + id device = MPSDevice::getInstance()->device(); + + MPSStream* mpsStream = getCurrentMPSStream(); + dispatch_sync(mpsStream->queue(), ^() { + @autoreleasepool { + id computeEncoder = mpsStream->commandEncoder(); + MTLSize threadgroupsPerGrid = MTLSizeMake(std::min(ceil_div(static_cast(num_kernels), static_cast(512)), static_cast(4096)), 1, 1); + + const std::string kernel = "deformable_col2im_coord_" + scalarToMetalTypeString(columns.scalar_type()); + id visionPSO = mps::visionPipelineState(device, kernel); + + getMPSProfiler().beginProfileKernel(visionPSO, kernel, {columns, input, offset, mask}); + + [computeEncoder setComputePipelineState:visionPSO]; + + [computeEncoder setBuffer:columnsBuffer offset:columns.storage_offset() * columns.element_size() atIndex:1]; + [computeEncoder setBuffer:inputBuffer offset:input.storage_offset() * input.element_size() atIndex:2]; + [computeEncoder setBuffer:offsetBuffer offset:offset.storage_offset() * offset.element_size() atIndex:3]; + [computeEncoder setBuffer:maskBuffer offset:mask.storage_offset() * mask.element_size() atIndex:4]; + [computeEncoder setBuffer:grad_offsetBuffer + offset:grad_offset.storage_offset() * grad_offset.element_size() + atIndex:22]; + [computeEncoder setBuffer:grad_maskBuffer + offset:grad_mask.storage_offset() * grad_mask.element_size() + atIndex:23]; + + 
[computeEncoder setBytes:&num_kernels length:sizeof(int64_t) atIndex:0]; + [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:5]; + [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:6]; + [computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:7]; + [computeEncoder setBytes:&weight_h length:sizeof(int64_t) atIndex:8]; + [computeEncoder setBytes:&weight_w length:sizeof(int64_t) atIndex:9]; + [computeEncoder setBytes:&pad_h length:sizeof(int64_t) atIndex:10]; + [computeEncoder setBytes:&pad_w length:sizeof(int64_t) atIndex:11]; + [computeEncoder setBytes:&stride_h length:sizeof(int64_t) atIndex:12]; + [computeEncoder setBytes:&stride_w length:sizeof(int64_t) atIndex:13]; + [computeEncoder setBytes:&dilation_h length:sizeof(int64_t) atIndex:14]; + [computeEncoder setBytes:&dilation_w length:sizeof(int64_t) atIndex:15]; + [computeEncoder setBytes:¶llel_imgs length:sizeof(int64_t) atIndex:16]; + [computeEncoder setBytes:&offset_channels length:sizeof(int64_t) atIndex:17]; + [computeEncoder setBytes:&n_offset_grps length:sizeof(int64_t) atIndex:18]; + [computeEncoder setBytes:&out_h length:sizeof(int64_t) atIndex:19]; + [computeEncoder setBytes:&out_w length:sizeof(int64_t) atIndex:20]; + [computeEncoder setBytes:&use_mask length:sizeof(bool) atIndex:21]; + + NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup; + if (tgSize > threadsPerBlock) { + tgSize = threadsPerBlock; + } + + + MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1); + [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize]; + + getMPSProfiler().endProfileKernel(visionPSO); + } + }); +} + +std::tuple backward_gradient_inputs( + at::Tensor input, + at::Tensor weight, + at::Tensor offset, + at::Tensor mask, + at::Tensor grad_out, + int64_t stride_h, + int64_t stride_w, + int64_t pad_h, + int64_t pad_w, + int64_t dilation_h, + int64_t dilation_w, + int64_t n_weight_grps, + int64_t n_offset_grps, + int64_t n_parallel_imgs, + bool use_mask) { + + int64_t batch_sz = input.size(0); + int64_t n_in_channels = input.size(1); + int64_t in_h = input.size(2); + int64_t in_w = input.size(3); + + n_parallel_imgs = std::min(batch_sz, n_parallel_imgs); + + int64_t n_out_channels = weight.size(0); + int64_t weight_h = weight.size(2); + int64_t weight_w = weight.size(3); + + int64_t out_w = + (in_w + 2 * pad_w - (dilation_w * (weight_w - 1) + 1)) / stride_w + 1; + int64_t out_h = + (in_h + 2 * pad_h - (dilation_h * (weight_h - 1) + 1)) / stride_h + 1; + + auto grad_input = at::zeros_like(input); + auto grad_offset = at::zeros_like(offset); + auto grad_mask = at::zeros_like(mask); + + if (batch_sz == 0) { + return std::make_tuple(grad_input, grad_offset, grad_mask); + } + + auto columns = at::empty( + {n_in_channels * weight_w * weight_h, n_parallel_imgs * out_h * out_w}, + input.options()); + + // Separate into blocks + grad_input = grad_input.reshape( + {batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w}); + input = input.reshape( + {batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w}); + + grad_offset = grad_offset.reshape( + {batch_sz / n_parallel_imgs, + n_parallel_imgs, + n_offset_grps * 2 * weight_h * weight_w, + out_h, + out_w}); + offset = offset.reshape( + {batch_sz / n_parallel_imgs, + n_parallel_imgs, + n_offset_grps * 2 * weight_h * weight_w, + out_h, + out_w}); + + if (use_mask) { + grad_mask = grad_mask.reshape( + {batch_sz / n_parallel_imgs, + n_parallel_imgs, + n_offset_grps * weight_h * weight_w, + out_h, + 
out_w}); + mask = mask.reshape( + {batch_sz / n_parallel_imgs, + n_parallel_imgs, + n_offset_grps * weight_h * weight_w, + out_h, + out_w}); + } + + grad_out = grad_out + .reshape( + {batch_sz / n_parallel_imgs, + n_parallel_imgs, + n_weight_grps, + n_out_channels / n_weight_grps, + out_h, + out_w}) + .permute({0, 2, 3, 1, 4, 5}); + + weight = weight.reshape( + {n_weight_grps, + weight.size(0) / n_weight_grps, + weight.size(1), + weight.size(2), + weight.size(3)}); + + columns = columns.view( + {n_weight_grps, columns.size(0) / n_weight_grps, columns.size(1)}); + for (int64_t elt = 0; elt < batch_sz / n_parallel_imgs; elt++) { + columns.zero_(); + // Separate into weight groups + for (int64_t g = 0; g < n_weight_grps; g++) { + columns[g] = addmm(columns[g], + weight[g].flatten(1).transpose(0, 1), grad_out[elt][g].flatten(1)); + } + + compute_grad_offset_and_mask( + columns, + input[elt], + offset[elt], + mask[elt], + n_in_channels, + in_h, + in_w, + weight_h, + weight_w, + pad_h, + pad_w, + stride_h, + stride_w, + dilation_h, + dilation_w, + n_parallel_imgs, + n_offset_grps, + use_mask, + grad_offset[elt], + grad_mask[elt]); + + compute_grad_input( + columns, + offset[elt], + mask[elt], + n_in_channels, + in_h, + in_w, + weight_h, + weight_w, + pad_h, + pad_w, + stride_h, + stride_w, + dilation_h, + dilation_w, + n_parallel_imgs, + n_offset_grps, + use_mask, + grad_input[elt]); + } + + grad_input = grad_input.view({batch_sz, n_in_channels, in_h, in_w}); + grad_offset = grad_offset.view( + {batch_sz, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w}); + + if (use_mask) { + grad_mask = grad_mask.view( + {batch_sz, n_offset_grps * weight_h * weight_w, out_h, out_w}); + } + + return std::make_tuple(grad_input, grad_offset, grad_mask); +} + +at::Tensor backward_gradient_parameters( + at::Tensor input, + const at::Tensor& weight, + at::Tensor offset, + at::Tensor mask, + const at::Tensor& grad_out, + int64_t stride_h, + int64_t stride_w, + int64_t pad_h, + int64_t pad_w, + int64_t dilation_h, + int64_t dilation_w, + int64_t n_weight_grps, + int64_t n_offset_grps, + int64_t n_parallel_imgs, + bool use_mask) { + + int64_t batch_sz = input.size(0); + int64_t n_in_channels = input.size(1); + int64_t in_h = input.size(2); + int64_t in_w = input.size(3); + + n_parallel_imgs = std::min(batch_sz, n_parallel_imgs); + + int64_t n_out_channels = weight.size(0); + int64_t weight_h = weight.size(2); + int64_t weight_w = weight.size(3); + + int64_t out_h = grad_out.size(2); + int64_t out_w = grad_out.size(3); + + auto grad_weight = at::zeros_like(weight); + if (batch_sz == 0) { + return grad_weight; + } + + at::Tensor grad_out_buf = grad_out + .reshape( + {batch_sz / n_parallel_imgs, + n_parallel_imgs, + n_weight_grps, + n_out_channels / n_weight_grps, + out_h, + out_w}) + .permute({0, 2, 3, 1, 4, 5}) + .contiguous(); + + input = input.reshape( + {batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w}); + + offset = offset.reshape( + {batch_sz / n_parallel_imgs, + n_parallel_imgs, + n_offset_grps * 2 * weight_h * weight_w, + out_h, + out_w}); + + if (use_mask) { + mask = mask.reshape( + {batch_sz / n_parallel_imgs, + n_parallel_imgs, + n_offset_grps * weight_h * weight_w, + out_h, + out_w}); + } + + grad_weight = grad_weight.reshape( + {n_weight_grps, + grad_weight.size(0) / n_weight_grps, + grad_weight.size(1), + grad_weight.size(2), + grad_weight.size(3)}); + + auto columns = at::empty( + {n_weight_grps, + n_in_channels * weight_w * weight_h / n_weight_grps, + n_parallel_imgs * out_h * 
out_w}, + input.options()); + + for (int64_t elt = 0; elt < batch_sz / n_parallel_imgs; elt++) { + deformable_im2col( + input[elt], + offset[elt], + mask[elt], + n_in_channels, + in_h, + in_w, + weight_h, + weight_w, + pad_h, + pad_w, + stride_h, + stride_w, + dilation_h, + dilation_w, + out_h, + out_w, + n_parallel_imgs, + n_offset_grps, + use_mask, + columns); + + // We need to use addmm instead of addmm_ here to avoid zero values for weight group > 1 + for (int64_t g = 0; g < n_weight_grps; g++) { + grad_weight[g] = + addmm((grad_weight[g].flatten(1)), + grad_out_buf[elt][g].flatten(1), columns[g].transpose(1, 0)) + .view_as(grad_weight[g]); + } + } + + grad_weight = grad_weight.view( + {grad_weight.size(0) * grad_weight.size(1), + grad_weight.size(2), + grad_weight.size(3), + grad_weight.size(4)}); + return grad_weight; +} + +at::Tensor deform_conv2d_forward_kernel( + const at::Tensor& input, + const at::Tensor& weight, + const at::Tensor& offset, + const at::Tensor& mask, + const at::Tensor& bias, + int64_t stride_h, + int64_t stride_w, + int64_t pad_h, + int64_t pad_w, + int64_t dilation_h, + int64_t dilation_w, + int64_t n_weight_grps, + int64_t n_offset_grps, + bool use_mask) { + at::Tensor input_c = input.contiguous(); + at::Tensor offset_c = offset.contiguous(); + at::Tensor weight_c = weight.contiguous(); + at::Tensor mask_c = mask.contiguous(); + at::Tensor bias_c = bias.contiguous(); + + TORCH_CHECK(input_c.ndimension() == 4); + TORCH_CHECK(offset_c.ndimension() == 4); + TORCH_CHECK(!use_mask || mask_c.ndimension() == 4); + TORCH_CHECK(weight_c.ndimension() == 4); + TORCH_CHECK(input_c.is_mps(), "input must be a MPS tensor"); + + int batch_sz = input_c.size(0); + int in_channels = input_c.size(1); + int in_h = input_c.size(2); + int in_w = input_c.size(3); + + int n_parallel_imgs = + get_greatest_divisor_below_bound(batch_sz, tkMaxParallelImgs); + + int out_channels = weight_c.size(0); + int weight_h = weight_c.size(2); + int weight_w = weight_c.size(3); + + int ker_h = dilation_h * (weight_h - 1) + 1; + int ker_w = dilation_w * (weight_w - 1) + 1; + int out_h = ((in_h + 2 * pad_h - ker_h) / stride_h) + 1; + int out_w = ((in_w + 2 * pad_w - ker_w) / stride_w) + 1; + + TORCH_CHECK( + weight_h > 0 && weight_w > 0, + "weight_h: ", + weight_h, + " weight_w: ", + weight_w); + TORCH_CHECK( + stride_h > 0 && stride_w > 0, + "stride_h: ", + stride_h, + " stride_w: ", + stride_w); + TORCH_CHECK(pad_h >= 0 && pad_w >= 0, "pad_h: ", pad_h, " pad_w: ", pad_w); + TORCH_CHECK( + dilation_h > 0 && dilation_w > 0, + "dilation_h: ", + dilation_h, + " dilation_w: ", + dilation_w); + + TORCH_CHECK(weight_c.size(1) * n_weight_grps == input_c.size(1)); + TORCH_CHECK(weight_c.size(0) % n_weight_grps == 0); + TORCH_CHECK( + (offset_c.size(1) == n_offset_grps * 2 * weight_h * weight_w), + "offset.shape[1] is not valid: got: ", + offset_c.size(1), + " expected: ", + n_offset_grps * 2 * weight_h * weight_w); + TORCH_CHECK( + (!use_mask || mask_c.size(1) == n_offset_grps * weight_h * weight_w), + "mask.shape[1] is not valid: got: ", + mask_c.size(1), + " expected: ", + n_offset_grps * weight_h * weight_w); + TORCH_CHECK(input_c.size(1) % n_offset_grps == 0); + + TORCH_CHECK( + (offset_c.size(0) == input_c.size(0)), "invalid batch size of offset"); + TORCH_CHECK( + (offset_c.size(2) == out_h && offset_c.size(3) == out_w), + "offset output dims: (", + offset_c.size(2), + ", ", + offset_c.size(3), + ") - ", + "computed output dims: (", + out_h, + ", ", + out_w, + ")"); + TORCH_CHECK( + (mask_c.size(0) == 
input_c.size(0)), "invalid batch size of mask"); + TORCH_CHECK( + (!use_mask || (mask_c.size(2) == out_h && mask_c.size(3) == out_w)), + "mask output dims: (", + mask_c.size(2), + ", ", + mask_c.size(3), + ") - ", + "computed output dims: (", + out_h, + ", ", + out_w, + ")"); + TORCH_CHECK( + out_h > 0 && out_w > 0, + "Calculated output size too small - out_h: ", + out_h, + " out_w: ", + out_w); + + auto out = + at::zeros({batch_sz, out_channels, out_h, out_w}, input_c.options()); + if (batch_sz == 0) { + return out; + } + + // Separate batches into blocks + out = out.view( + {batch_sz / n_parallel_imgs, + n_parallel_imgs, + out_channels, + out_h, + out_w}); + + input_c = input_c.view( + {batch_sz / n_parallel_imgs, n_parallel_imgs, in_channels, in_h, in_w}); + + offset_c = offset_c.view( + {batch_sz / n_parallel_imgs, + n_parallel_imgs, + n_offset_grps * 2 * weight_h * weight_w, + out_h, + out_w}); + + if (use_mask) { + mask_c = mask_c.view( + {batch_sz / n_parallel_imgs, + n_parallel_imgs, + n_offset_grps * weight_h * weight_w, + out_h, + out_w}); + } + + at::Tensor out_buf = at::zeros( + {batch_sz / n_parallel_imgs, + out_channels, + n_parallel_imgs * out_h, + out_w}, + out.options()); + + // Separate channels into convolution groups + out_buf = out_buf.view( + {out_buf.size(0), + n_weight_grps, + out_buf.size(1) / n_weight_grps, + out_buf.size(2), + out_buf.size(3)}); + weight_c = weight_c.view( + {n_weight_grps, + weight_c.size(0) / n_weight_grps, + weight_c.size(1), + weight_c.size(2), + weight_c.size(3)}); + + // Sample points and perform convolution + auto columns = at::zeros( + {in_channels * weight_h * weight_w, n_parallel_imgs * out_h * out_w}, + input_c.options()); + + for (int b = 0; b < batch_sz / n_parallel_imgs; b++) { + deformable_im2col( + input_c[b], + offset_c[b], + mask_c[b], + in_channels, + in_h, + in_w, + weight_h, + weight_w, + pad_h, + pad_w, + stride_h, + stride_w, + dilation_h, + dilation_w, + out_h, + out_w, + n_parallel_imgs, + n_offset_grps, + use_mask, + columns); + + columns = columns.view( + {n_weight_grps, columns.size(0) / n_weight_grps, columns.size(1)}); + + // The use of in-place .addmm_ has a bug in pytorch, so we use addmm instead + // This needs to be fixed in the future + for (int g = 0; g < n_weight_grps; g++) { + out_buf[b][g] = + addmm(out_buf[b][g] + .flatten(1), weight_c[g].flatten(1), columns[g]) + .view_as(out_buf[b][g]); + } + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + } + + out_buf = out_buf.view( + {batch_sz / n_parallel_imgs, + out_channels, + n_parallel_imgs, + out_h, + out_w}); + out_buf.transpose_(1, 2); + out.copy_(out_buf); + out = out.view({batch_sz, out_channels, out_h, out_w}); + + return out + bias_c.view({1, out_channels, 1, 1}); +} + +std::tuple +deform_conv2d_backward_kernel( + const at::Tensor& grad_out, + const at::Tensor& input, + const at::Tensor& weight, + const at::Tensor& offset, + const at::Tensor& mask, + const at::Tensor& bias, + int64_t stride_h, + int64_t stride_w, + int64_t pad_h, + int64_t pad_w, + int64_t dilation_h, + int64_t dilation_w, + int64_t n_weight_grps, + int64_t n_offset_grps, + bool use_mask) { + + TORCH_CHECK(grad_out.is_mps(), "grad must be a MPS tensor"); + TORCH_CHECK(input.is_mps(), "input must be a MPS tensor"); + TORCH_CHECK(weight.is_mps(), "weight must be a MPS tensor"); + TORCH_CHECK(offset.is_mps(), "offset must be a MPS tensor"); + TORCH_CHECK(mask.is_mps(), "mask must be a MPS tensor"); + TORCH_CHECK(bias.is_mps(), "bias must be a MPS tensor"); + 
TORCH_CHECK(grad_out.scalar_type() != at::kHalf, "MPS does not support deform_conv2 backward with float16 inputs."); + + at::Tensor grad_out_c = grad_out.contiguous(); + at::Tensor input_c = input.contiguous(); + at::Tensor weight_c = weight.contiguous(); + at::Tensor offset_c = offset.contiguous(); + at::Tensor mask_c = mask.contiguous(); + at::Tensor bias_c = bias.contiguous(); + + const int64_t batch_sz = input_c.size(0); + const int64_t n_parallel_imgs = + get_greatest_divisor_below_bound(batch_sz, tkMaxParallelImgs); + + auto grad_input_and_offset_and_mask = backward_gradient_inputs( + input_c, + weight_c, + offset_c, + mask_c, + grad_out_c, + stride_h, + stride_w, + pad_h, + pad_w, + dilation_h, + dilation_w, + n_weight_grps, + n_offset_grps, + n_parallel_imgs, + use_mask); + + auto grad_input = std::get<0>(grad_input_and_offset_and_mask); + auto grad_offset = std::get<1>(grad_input_and_offset_and_mask); + auto grad_mask = std::get<2>(grad_input_and_offset_and_mask); + + auto grad_weight = backward_gradient_parameters( + input_c, + weight_c, + offset_c, + mask_c, + grad_out_c, + stride_h, + stride_w, + pad_h, + pad_w, + dilation_h, + dilation_w, + n_weight_grps, + n_offset_grps, + n_parallel_imgs, + use_mask); + + auto value = grad_out_c.sum({0, 2, 3}); + auto grad_bias = at::ones_like(bias_c) * value; + + return std::make_tuple( + grad_input, grad_weight, grad_offset, grad_mask, grad_bias); +} +} // namespace + + +TORCH_LIBRARY_IMPL(torchvision, MPS, m) { + m.impl( + TORCH_SELECTIVE_NAME("torchvision::deform_conv2d"), + TORCH_FN(deform_conv2d_forward_kernel)); + m.impl( + TORCH_SELECTIVE_NAME("torchvision::_deform_conv2d_backward"), + TORCH_FN(deform_conv2d_backward_kernel)); +} + +} // namespace ops +} // namespace vision + diff --git a/torchvision/csrc/ops/mps/mps_kernels.h b/torchvision/csrc/ops/mps/mps_kernels.h index f85546a6c41..6c9f169068a 100644 --- a/torchvision/csrc/ops/mps/mps_kernels.h +++ b/torchvision/csrc/ops/mps/mps_kernels.h @@ -37,6 +37,53 @@ inline void atomic_add_float(device half* data_ptr, const half val) atomic_fetch_add_explicit((device atomic_float*) data_ptr, static_cast(val), memory_order_relaxed); } +// ********************** deform_conv2d implementation of bilinear_interpolate ********************** +// This implementation is used by the cpu and cuda implementation of the deform_conv2d kernel +// and is needed here in order for the pytest operator test not to fail. 
+ +template +inline scalar_t bilinear_interpolate_deform_conv2d( + constant scalar_t* in, + integer_t height, + integer_t width, + scalar_t h, + scalar_t w, + uint index /* index for debug only*/) { + if (h <= -1 || height <= h || w <= -1 || width <= w) { + return 0; + } + + integer_t h_low = floor(h); + integer_t w_low = floor(w); + integer_t h_high = h_low + 1; + integer_t w_high = w_low + 1; + + scalar_t lh = h - h_low; + scalar_t lw = w - w_low; + scalar_t hh = 1 - lh, hw = 1 - lw; + + scalar_t v1 = 0; + if (h_low >= 0 && w_low >= 0) + v1 = in[h_low * width + w_low]; + scalar_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + v2 = in[h_low * width + w_high]; + scalar_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + v3 = in[h_high * width + w_low]; + scalar_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + v4 = in[h_high * width + w_high]; + + scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + + scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + return val; +} + + + + template inline T bilinear_interpolate( constant T* input, @@ -1005,11 +1052,429 @@ kernel void ps_roi_pool_backward( \ constant int64_t & width [[buffer(7)]], \ constant int64_t & pooled_height [[buffer(8)]], \ constant int64_t & pooled_width [[buffer(9)]], \ - constant int64_t & channels_out [[buffer(10)]], \ + constant int64_t & channels_out [[buffer(10)]], \ constant float & spatial_scale [[buffer(11)]], \ uint2 tgid [[threadgroup_position_in_grid]], \ uint2 tptg [[threads_per_threadgroup]], \ uint2 tid2 [[thread_position_in_threadgroup]]); + + + +/*----------- START OF DEFORM_CONV2D KERNEL IMPLEMENTATION -----------------*/ + +template +kernel void deformable_im2col( + constant scalar_t * input_ptr [[buffer(0)]], + constant scalar_t * offset_ptr [[buffer(1)]], + constant scalar_t * mask_ptr [[buffer(2)]], + device scalar_t * columns_ptr [[buffer(3)]], + constant int64_t & n [[buffer(4)]], + constant int64_t & height [[buffer(5)]], + constant int64_t & width [[buffer(6)]], + constant int64_t & weight_h [[buffer(7)]], + constant int64_t & weight_w [[buffer(8)]], + constant int64_t & pad_h [[buffer(9)]], + constant int64_t & pad_w [[buffer(10)]], + constant int64_t & stride_h [[buffer(11)]], + constant int64_t & stride_w [[buffer(12)]], + constant int64_t & dilation_h [[buffer(13)]], + constant int64_t & dilation_w [[buffer(14)]], + constant int64_t & batch_sz [[buffer(15)]], + constant int64_t & n_in_channels [[buffer(16)]], + constant int64_t & n_offset_grps [[buffer(17)]], + constant int64_t & out_h [[buffer(18)]], + constant int64_t & out_w [[buffer(19)]], + constant bool & use_mask [[buffer(20)]], + uint2 tgid [[threadgroup_position_in_grid]], + uint2 tptg [[threads_per_threadgroup]], + uint2 tid2 [[thread_position_in_threadgroup]], + uint2 tgpg [[threadgroups_per_grid]]) { + MPS_1D_KERNEL_LOOP(index, n, tgpg.x) { + const integer_t out_x = index % out_w; + const integer_t out_y = (index / out_w) % out_h; + const integer_t out_b = (index / (out_w * out_h)) % batch_sz; + const integer_t in_c = index / (out_w * out_h * batch_sz); + const integer_t out_c = in_c * weight_h * weight_w; + + integer_t c_per_offset_grp = n_in_channels / n_offset_grps; + const integer_t grp_idx = in_c / c_per_offset_grp; + + columns_ptr += + (out_c * (batch_sz * out_h * out_w) + out_b * (out_h * out_w) + + out_y * out_w + out_x); + + input_ptr += + (out_b * (n_in_channels * height * width) + in_c * (height * width)); + + offset_ptr += + (out_b * n_offset_grps + grp_idx) * 2 * weight_h * weight_w * 
out_h * + out_w; + + if (use_mask) { + mask_ptr += (out_b * n_offset_grps + grp_idx) * weight_h * weight_w * + out_h * out_w; + } + + // For each element in the filter + for (int i = 0; i < weight_h; ++i) { + for (int j = 0; j < weight_w; ++j) { + const integer_t mask_idx = i * weight_w + j; + const integer_t offset_idx = 2 * mask_idx; + + scalar_t mask_value = 1; + if (use_mask) { + mask_value = + mask_ptr[mask_idx * (out_h * out_w) + out_y * out_w + out_x]; + } + + const scalar_t offset_h = + offset_ptr[offset_idx * (out_h * out_w) + out_y * out_w + out_x]; + const scalar_t offset_w = + offset_ptr[(offset_idx + 1) * (out_h * out_w) + out_y * out_w + out_x]; + const scalar_t y = + (out_y * stride_h - pad_h) + i * dilation_h + offset_h; + const scalar_t x = + (out_x * stride_w - pad_w) + j * dilation_w + offset_w; + *columns_ptr = + mask_value * bilinear_interpolate_deform_conv2d(input_ptr, height, width, y, x, index); + columns_ptr += batch_sz * out_h * out_w; + } + } + } +} + +#define REGISTER_DEFORMABLE_IM2COL_OP(DTYPE, INT_DTYPE) \ +template \ +[[host_name("deformable_im2col_" #DTYPE)]] \ +kernel void deformable_im2col( \ + constant DTYPE * input_ptr [[buffer(0)]], \ + constant DTYPE * offset_ptr [[buffer(1)]], \ + constant DTYPE * mask_ptr [[buffer(2)]], \ + device DTYPE * columns_ptr [[buffer(3)]], \ + constant int64_t & n [[buffer(4)]], \ + constant int64_t & height [[buffer(5)]], \ + constant int64_t & width [[buffer(6)]], \ + constant int64_t & weight_h [[buffer(7)]], \ + constant int64_t & weight_w [[buffer(8)]], \ + constant int64_t & pad_h [[buffer(9)]], \ + constant int64_t & pad_w [[buffer(10)]], \ + constant int64_t & stride_h [[buffer(11)]], \ + constant int64_t & stride_w [[buffer(12)]], \ + constant int64_t & dilation_h [[buffer(13)]], \ + constant int64_t & dilation_w [[buffer(14)]], \ + constant int64_t & batch_sz [[buffer(15)]], \ + constant int64_t & n_in_channels [[buffer(16)]], \ + constant int64_t & n_offset_grps [[buffer(17)]], \ + constant int64_t & out_h [[buffer(18)]], \ + constant int64_t & out_w [[buffer(19)]], \ + constant bool & use_mask [[buffer(20)]], \ + uint2 tgid [[threadgroup_position_in_grid]], \ + uint2 tptg [[threads_per_threadgroup]], \ + uint2 tid2 [[thread_position_in_threadgroup]], \ + uint2 tgpg [[threadgroups_per_grid]]); + + + +template +kernel void deformable_col2im( + constant int64_t & n [[buffer(0)]], + constant scalar_t * col [[buffer(1)]], + constant scalar_t * offset_ptr [[buffer(2)]], + constant scalar_t * mask_ptr [[buffer(3)]], + constant int64_t & channels [[buffer(4)]], + constant int64_t & height [[buffer(5)]], + constant int64_t & width [[buffer(6)]], + constant int64_t & kernel_h [[buffer(7)]], + constant int64_t & kernel_w [[buffer(8)]], + constant int64_t & pad_h [[buffer(9)]], + constant int64_t & pad_w [[buffer(10)]], + constant int64_t & stride_h [[buffer(11)]], + constant int64_t & stride_w [[buffer(12)]], + constant int64_t & dilation_h [[buffer(13)]], + constant int64_t & dilation_w [[buffer(14)]], + constant int64_t & batch_sz [[buffer(15)]], + constant int64_t & n_offset_grps [[buffer(16)]], + constant int64_t & out_h [[buffer(17)]], + constant int64_t & out_w [[buffer(18)]], + constant bool & use_mask [[buffer(19)]], + device scalar_t * grad_im [[buffer(20)]], + uint2 tgid [[threadgroup_position_in_grid]], + uint2 tptg [[threads_per_threadgroup]], + uint2 tid2 [[thread_position_in_threadgroup]], + uint2 tgpg [[threadgroups_per_grid]]) { + const integer_t grad_im_numel = width * height * channels * batch_sz; + 
MPS_1D_KERNEL_LOOP(index, n, tgpg.x) { + const integer_t out_x = index % out_w; + const integer_t out_y = (index / out_w) % out_h; + const integer_t b = (index / (out_w * out_h)) % batch_sz; + const integer_t j = (index / (out_w * out_h * batch_sz)) % kernel_w; + const integer_t i = + (index / (out_w * out_h * batch_sz * kernel_w)) % kernel_h; + const integer_t c = index / (out_w * out_h * batch_sz * kernel_w * kernel_h); + + integer_t c_per_offset_grp = channels / n_offset_grps; + const integer_t offset_grp = c / c_per_offset_grp; + + offset_ptr += (b * n_offset_grps + offset_grp) * 2 * kernel_h * kernel_w * + out_h * out_w; + + if (use_mask) { + mask_ptr += (b * n_offset_grps + offset_grp) * kernel_h * kernel_w * + out_h * out_w; + } + + const integer_t mask_idx = i * kernel_w + j; + const integer_t offset_idx = 2 * mask_idx; + + const integer_t offset_h_ptr = ((offset_idx)*out_h + out_y) * out_w + out_x; + const integer_t offset_w_ptr = + ((offset_idx + 1) * out_h + out_y) * out_w + out_x; + + const scalar_t offset_h = offset_ptr[offset_h_ptr]; + const scalar_t offset_w = offset_ptr[offset_w_ptr]; + + scalar_t mask_value = 1; + if (use_mask) { + mask_value = mask_ptr[(mask_idx * out_h + out_y) * out_w + out_x]; + } + + const scalar_t y = (out_y * stride_h - pad_h) + i * dilation_h + offset_h; + const scalar_t x = (out_x * stride_w - pad_w) + j * dilation_w + offset_w; + + for (integer_t dy = -1; dy <= 1; dy++) { + for (integer_t dx = -1; dx <= 1; dx++) { + integer_t yp = (integer_t)y + dy; + integer_t xp = (integer_t)x + dx; + if (0 <= yp && yp < height && 0 <= xp && xp < width && + abs(y - yp) < 1 && abs(x - xp) < 1) { + integer_t grad_pos = ((b * channels + c) * height + yp) * width + xp; + scalar_t weight = (1 - abs(y - yp)) * (1 - abs(x - xp)); + // MSL doesn't support at::native::fastAtomicAdd + if (grad_pos >= 0 && grad_pos < grad_im_numel) { + // Atomically add the computed value directly + atomic_add_float(grad_im + grad_pos, static_cast(mask_value * weight * col[index])); + } + } + } + } + } +} + +#define REGISTER_DEFORMABLE_COL2IM_OP(DTYPE, INT_DTYPE) \ +template \ +[[host_name("deformable_col2im_" #DTYPE)]] \ +kernel void deformable_col2im( \ + constant int64_t & n [[buffer(0)]], \ + constant DTYPE * col [[buffer(1)]], \ + constant DTYPE * offset_ptr [[buffer(2)]], \ + constant DTYPE * mask_ptr [[buffer(3)]], \ + constant int64_t & channels [[buffer(4)]], \ + constant int64_t & height [[buffer(5)]], \ + constant int64_t & width [[buffer(6)]], \ + constant int64_t & kernel_h [[buffer(7)]], \ + constant int64_t & kernel_w [[buffer(8)]], \ + constant int64_t & pad_h [[buffer(9)]], \ + constant int64_t & pad_w [[buffer(10)]], \ + constant int64_t & stride_h [[buffer(11)]], \ + constant int64_t & stride_w [[buffer(12)]], \ + constant int64_t & dilation_h [[buffer(13)]], \ + constant int64_t & dilation_w [[buffer(14)]], \ + constant int64_t & batch_sz [[buffer(15)]], \ + constant int64_t & n_offset_grps [[buffer(16)]], \ + constant int64_t & out_h [[buffer(17)]], \ + constant int64_t & out_w [[buffer(18)]], \ + constant bool & use_mask [[buffer(19)]], \ + device DTYPE * grad_im [[buffer(20)]], \ + uint2 tgid [[threadgroup_position_in_grid]], \ + uint2 tptg [[threads_per_threadgroup]], \ + uint2 tid2 [[thread_position_in_threadgroup]], \ + uint2 tgpg [[threadgroups_per_grid]]); + + +template +scalar_t get_coordinate_weight( + constant scalar_t* im_data, + index_t height, + index_t width, + scalar_t y, + scalar_t x, + bool is_y_direction) { + index_t y_l = floor(y); + index_t x_l = 
floor(x); + index_t y_h = y_l + 1; + index_t x_h = x_l + 1; + + bool valid_y_l = 0 <= y_l && y_l < height; + bool valid_y_h = 0 <= y_h && y_h < height; + bool valid_x_l = 0 <= x_l && x_l < width; + bool valid_x_h = 0 <= x_h && x_h < width; + + scalar_t zero = 0; + scalar_t v_yx = (valid_y_l && valid_x_l) ? im_data[y_l * width + x_l] : zero; + scalar_t v_yX = (valid_y_l && valid_x_h) ? im_data[y_l * width + x_h] : zero; + scalar_t v_Yx = (valid_y_h && valid_x_l) ? im_data[y_h * width + x_l] : zero; + scalar_t v_YX = (valid_y_h && valid_x_h) ? im_data[y_h * width + x_h] : zero; + + if (is_y_direction) { + scalar_t dx = x - x_l; + return dx * (v_YX - v_yX) + (1 - dx) * (v_Yx - v_yx); + } else { + scalar_t dy = y - y_l; + return dy * (v_YX - v_Yx) + (1 - dy) * (v_yX - v_yx); + } +} + + + +template +kernel void deformable_col2im_coord( + constant int64_t & n [[buffer(0)]], + constant scalar_t * col_ptr [[buffer(1)]], + constant scalar_t * im_ptr [[buffer(2)]], + constant scalar_t * offset_ptr [[buffer(3)]], + constant scalar_t * mask_ptr [[buffer(4)]], + constant int64_t & channels [[buffer(5)]], + constant int64_t & height [[buffer(6)]], + constant int64_t & width [[buffer(7)]], + constant int64_t & weight_h [[buffer(8)]], + constant int64_t & weight_w [[buffer(9)]], + constant int64_t & pad_h [[buffer(10)]], + constant int64_t & pad_w [[buffer(11)]], + constant int64_t & stride_h [[buffer(12)]], + constant int64_t & stride_w [[buffer(13)]], + constant int64_t & dilation_h [[buffer(14)]], + constant int64_t & dilation_w [[buffer(15)]], + constant int64_t & batch_sz [[buffer(16)]], + constant int64_t & offset_channels [[buffer(17)]], + constant int64_t & n_offset_grps [[buffer(18)]], + constant int64_t & out_h [[buffer(19)]], + constant int64_t & out_w [[buffer(20)]], + constant bool & use_mask [[buffer(21)]], + device scalar_t* grad_offset [[buffer(22)]], + device scalar_t* grad_mask [[buffer(23)]], + uint2 tgid [[threadgroup_position_in_grid]], + uint2 tptg [[threads_per_threadgroup]], + uint2 tid2 [[thread_position_in_threadgroup]], + uint2 tgpg [[threadgroups_per_grid]]) { + MPS_1D_KERNEL_LOOP(index, n, tgpg.x) { + scalar_t grad_offset_val = 0; + scalar_t grad_mask_val = 0; + integer_t w = index % out_w; + integer_t h = (index / out_w) % out_h; + integer_t w_w = (index / (out_w * out_h * 2)) % weight_w; + integer_t w_h = (index / (out_w * out_h * 2 * weight_w)) % weight_h; + integer_t c = (index / (out_w * out_h)) % offset_channels; + integer_t b = index / (out_w * out_h * offset_channels); + + const integer_t offset_grp = c / (2 * weight_h * weight_w); + const integer_t col_step = weight_h * weight_w; + + integer_t c_per_offset_grp = channels / n_offset_grps; + + col_ptr += offset_grp * c_per_offset_grp * weight_h * weight_w * batch_sz * + out_w * out_h; + im_ptr += + (b * n_offset_grps + offset_grp) * c_per_offset_grp * height * width; + offset_ptr += (b * n_offset_grps + offset_grp) * 2 * weight_h * weight_w * + out_h * out_w; + + if (use_mask) { + mask_ptr += (b * n_offset_grps + offset_grp) * weight_h * weight_w * + out_h * out_w; + } + + const integer_t offset_c = c - offset_grp * 2 * weight_h * weight_w; + const bool is_y_direction = offset_c % 2 == 0; + + const integer_t c_bound = c_per_offset_grp * weight_h * weight_w; + for (integer_t col_c = (offset_c / 2); col_c < c_bound; col_c += col_step) { + const integer_t col_pos = + (((col_c * batch_sz + b) * out_h) + h) * out_w + w; + + integer_t out_x = col_pos % out_w; + integer_t out_y = (col_pos / out_w) % out_h; + integer_t j = 
(col_pos / (out_w * out_h * batch_sz)) % weight_w; + integer_t i = (col_pos / (out_w * out_h * batch_sz * weight_w)) % weight_h; + + const integer_t mask_idx = i * weight_w + j; + + const integer_t offset_h_ptr = + (((2 * mask_idx) * out_h + out_y) * out_w + out_x); + const integer_t offset_w_ptr = + (((2 * mask_idx + 1) * out_h + out_y) * out_w + out_x); + const scalar_t offset_h = offset_ptr[offset_h_ptr]; + const scalar_t offset_w = offset_ptr[offset_w_ptr]; + + scalar_t mask_value = 1; + if (use_mask) { + mask_value = mask_ptr[(mask_idx * out_h + out_y) * out_w + out_x]; + } + + scalar_t y = (out_y * stride_h - pad_h) + i * dilation_h + offset_h; + scalar_t x = (out_x * stride_w - pad_w) + j * dilation_w + offset_w; + + const scalar_t weight = + get_coordinate_weight(im_ptr, height, width, y, x, is_y_direction); + grad_offset_val += mask_value * weight * col_ptr[col_pos]; + + if (use_mask && is_y_direction) { + grad_mask_val += col_ptr[col_pos] * + bilinear_interpolate_deform_conv2d(im_ptr, height, width, y, x, index); + } + + im_ptr += height * width; + } + + grad_offset[index] = grad_offset_val; + + if (use_mask && is_y_direction) { + const integer_t idx = + ((((b * n_offset_grps + offset_grp) * weight_h + w_h) * weight_w + + w_w) * + out_h + + h) * + out_w + + w; + grad_mask[idx] = grad_mask_val; + } + } +} + +#define REGISTER_DEFORMABLE_COL2IM_COORD_OP(DTYPE, INT_DTYPE) \ +template \ +[[host_name("deformable_col2im_coord_" #DTYPE)]] \ +kernel void deformable_col2im_coord( \ + constant int64_t & n [[buffer(0)]], \ + constant DTYPE * col_ptr [[buffer(1)]], \ + constant DTYPE * im_ptr [[buffer(2)]], \ + constant DTYPE * offset_ptr [[buffer(3)]], \ + constant DTYPE * mask_ptr [[buffer(4)]], \ + constant int64_t & channels [[buffer(5)]], \ + constant int64_t & height [[buffer(6)]], \ + constant int64_t & width [[buffer(7)]], \ + constant int64_t & weight_h [[buffer(8)]], \ + constant int64_t & weight_w [[buffer(9)]], \ + constant int64_t & pad_h [[buffer(10)]], \ + constant int64_t & pad_w [[buffer(11)]], \ + constant int64_t & stride_h [[buffer(12)]], \ + constant int64_t & stride_w [[buffer(13)]], \ + constant int64_t & dilation_h [[buffer(14)]], \ + constant int64_t & dilation_w [[buffer(15)]], \ + constant int64_t & batch_sz [[buffer(16)]], \ + constant int64_t & offset_channels [[buffer(17)]], \ + constant int64_t & n_offset_grps [[buffer(18)]], \ + constant int64_t & out_h [[buffer(19)]], \ + constant int64_t & out_w [[buffer(20)]], \ + constant bool & use_mask [[buffer(21)]], \ + device DTYPE * grad_offset [[buffer(22)]], \ + device DTYPE * grad_mask [[buffer(23)]], \ + uint2 tgid [[threadgroup_position_in_grid]], \ + uint2 tptg [[threads_per_threadgroup]], \ + uint2 tid2 [[thread_position_in_threadgroup]], \ + uint2 tgpg [[threadgroups_per_grid]]); + +/* ----------END OF DEFORM_CONV2D KERNELS ----------------------*/ + REGISTER_NMS_OP(float); REGISTER_NMS_OP(half); @@ -1029,6 +1494,12 @@ REGISTER_PS_ROI_POOL_OP(float, int64_t); REGISTER_PS_ROI_POOL_OP(half, int64_t); REGISTER_PS_ROI_POOL_BACKWARD_OP(float, int64_t); REGISTER_PS_ROI_POOL_BACKWARD_OP(half, int64_t); +REGISTER_DEFORMABLE_IM2COL_OP(float, int64_t); +REGISTER_DEFORMABLE_IM2COL_OP(half, int64_t); +REGISTER_DEFORMABLE_COL2IM_OP(float, int64_t); +REGISTER_DEFORMABLE_COL2IM_OP(half, int64_t); +REGISTER_DEFORMABLE_COL2IM_COORD_OP(float, int64_t); +REGISTER_DEFORMABLE_COL2IM_COORD_OP(half, int64_t); )VISION_METAL"); diff --git a/torchvision/installTest.cpp b/torchvision/installTest.cpp new file mode 100644 index 
00000000000..6e9f1c90e50 --- /dev/null +++ b/torchvision/installTest.cpp @@ -0,0 +1,5 @@ +#include +#include + + +
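
For a quick manual check of the new MPS forward path (in the spirit of the `test/playground/test_mps_import.py` scratch file above), the sketch below compares the MPS result against the CPU reference. It assumes a torchvision build that includes the new `deform_conv2d_kernel.mm` and a machine where `torch.backends.mps.is_available()` is True; the shapes and the comparison tolerance are illustrative, not taken from the test suite.

```python
# Minimal sketch: compare the MPS deform_conv2d forward against the CPU reference.
# Assumes torchvision was built with the new MPS kernel; shapes are illustrative.
import torch
import torchvision.ops as ops

assert torch.backends.mps.is_available()

n, c_in, c_out, h, w, k = 2, 3, 4, 8, 8, 3
offset_grps = 1

x = torch.randn(n, c_in, h, w)
weight = torch.randn(c_out, c_in, k, k)
bias = torch.randn(c_out)
# With stride=1, padding=1 and a 3x3 kernel the output keeps the input spatial size,
# so offset has 2 * offset_grps * k * k channels and mask has offset_grps * k * k.
offset = torch.randn(n, 2 * offset_grps * k * k, h, w)
mask = torch.rand(n, offset_grps * k * k, h, w)

out_cpu = ops.deform_conv2d(x, offset, weight, bias, stride=1, padding=1, mask=mask)
out_mps = ops.deform_conv2d(
    x.to("mps"), offset.to("mps"), weight.to("mps"), bias.to("mps"),
    stride=1, padding=1, mask=mask.to("mps"),
)
# Loose float32 tolerance; the illustrative value is not the one used in test_ops.py.
print(torch.allclose(out_cpu, out_mps.cpu(), atol=1e-4))
```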
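The backward path can be spot-checked the way the updated `test_backward` does it: float32 inputs on MPS (float64 is unavailable there) with the relaxed tolerances added to the test class, `mps_backward_atol = 2e-2` and `mps_backward_eps = 1e-3`. A minimal sketch, using small illustrative shapes rather than the ones produced by `get_fn_args`:

```python
# Minimal gradcheck sketch for the MPS backward, mirroring the relaxed tolerances
# TestDeformConv uses on MPS (float32, atol=2e-2, eps=1e-3). Shapes are illustrative.
import torch
import torchvision.ops as ops
from torch.autograd import gradcheck

dev, dt = "mps", torch.float32
k = 2
x = torch.randn(1, 2, 5, 5, device=dev, dtype=dt, requires_grad=True)
weight = torch.randn(2, 2, k, k, device=dev, dtype=dt, requires_grad=True)
bias = torch.randn(2, device=dev, dtype=dt, requires_grad=True)
# 5x5 input, 2x2 kernel, stride 1, no padding -> 4x4 output grid for offset/mask.
offset = torch.randn(1, 2 * k * k, 4, 4, device=dev, dtype=dt, requires_grad=True)
mask = torch.rand(1, k * k, 4, 4, device=dev, dtype=dt, requires_grad=True)

def fn(x_, offset_, mask_, weight_, bias_):
    return ops.deform_conv2d(x_, offset_, weight_, bias_, stride=1, padding=0, mask=mask_)

# nondet_tol matches the existing tests; atol/eps are the relaxed MPS values.
print(gradcheck(fn, (x, offset, mask, weight, bias),
                eps=1e-3, atol=2e-2, nondet_tol=1e-5, fast_mode=True))
```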
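The `test/optest.py` repro above points `opcheck` at `roi_align` with pickled arguments; the same entry point can be aimed at the raw `torchvision::deform_conv2d` overload with hand-built MPS tensors. The argument order below is assumed to follow the forward kernel signature added in this diff (input, weight, offset, mask, bias, stride_h/w, pad_h/w, dilation_h/w, weight groups, offset groups, use_mask); treat it as a sketch rather than a drop-in test.

```python
# Sketch: run opcheck against the raw torchvision::deform_conv2d overload on MPS.
# Argument order is assumed to mirror deform_conv2d_forward_kernel from this diff.
import torch
import torchvision  # importing torchvision loads and registers the custom ops
from torch.testing._internal.optests import opcheck

dev, k = "mps", 3
x = torch.randn(2, 3, 8, 8, device=dev)
weight = torch.randn(4, 3, k, k, device=dev)
# 8x8 input, 3x3 kernel, stride 1, no padding -> 6x6 output grid.
offset = torch.randn(2, 2 * k * k, 6, 6, device=dev)
mask = torch.rand(2, k * k, 6, 6, device=dev)
bias = torch.randn(4, device=dev)

op = torch.ops.torchvision.deform_conv2d.default
args = (x, weight, offset, mask, bias, 1, 1, 0, 0, 1, 1, 1, 1, True)
opcheck(op, args, {}, test_utils="test_autograd_registration")
```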