From 575b22376fba2817dc96d3593c22f191c6fdef02 Mon Sep 17 00:00:00 2001 From: Ky Anh Pham <40734986+itskyf@users.noreply.github.com> Date: Tue, 21 Jan 2025 10:17:52 +0700 Subject: [PATCH] Export with ONNX's DeformConv2d --- tutorials/BiRefNet_pth2onnx.ipynb | 92 +++++++++++++++++++++++++++++-- 1 file changed, 87 insertions(+), 5 deletions(-) diff --git a/tutorials/BiRefNet_pth2onnx.ipynb b/tutorials/BiRefNet_pth2onnx.ipynb index 781fa3d..510e515 100644 --- a/tutorials/BiRefNet_pth2onnx.ipynb +++ b/tutorials/BiRefNet_pth2onnx.ipynb @@ -217,6 +217,81 @@ " fp.write(file_lines)" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from torch.onnx.symbolic_helper import parse_args\n", + "from torch.onnx import register_custom_op_symbolic\n", + "\n", + "\n", + "@parse_args(\n", + " \"v\", # arg0: input (tensor)\n", + " \"v\", # arg1: weight (tensor)\n", + " \"v\", # arg2: offset (tensor)\n", + " \"v\", # arg3: mask (tensor)\n", + " \"v\", # arg4: bias (tensor)\n", + " \"i\", # arg5: stride_h\n", + " \"i\", # arg6: stride_w\n", + " \"i\", # arg7: pad_h\n", + " \"i\", # arg8: pad_w\n", + " \"i\", # arg9: dilation_h\n", + " \"i\", # arg10: dilation_w\n", + " \"i\", # arg11: groups\n", + " \"i\", # arg12: deform_groups\n", + " \"b\", # arg13: some bool\n", + ")\n", + "def symbolic_deform_conv_19(\n", + " g,\n", + " input,\n", + " weight,\n", + " offset,\n", + " mask,\n", + " bias,\n", + " stride_h,\n", + " stride_w,\n", + " pad_h,\n", + " pad_w,\n", + " dilation_h,\n", + " dilation_w,\n", + " groups,\n", + " deform_groups,\n", + " maybe_bool,\n", + "):\n", + " # Convert them back into lists where needed:\n", + " strides = [stride_h, stride_w]\n", + " pads = [pad_h, pad_w, pad_h, pad_w]\n", + " dilations = [dilation_h, dilation_w]\n", + "\n", + " # If bias is None, you'd do something like:\n", + " # if bias.node().kind() == \"prim::Constant\" and bias.node()[\"value\"] is None:\n", + " # bias = g.op(\"Constant\", value_t=torch.tensor([], dtype=torch.float32))\n", + " #\n", + " # But from your debug, arg4 is a real tensor of shape [256], so it's not None.\n", + "\n", + " # Similarly for mask not being None in your debug, but if you want to handle\n", + " # a None path, do a check like above.\n", + "\n", + " # Construct the official ONNX DeformConv (Opset 19).\n", + " # 'main' domain => just \"DeformConv\"\n", + " return g.op(\n", + " \"DeformConv\",\n", + " input,\n", + " weight,\n", + " offset,\n", + " bias,\n", + " mask,\n", + " strides_i=strides,\n", + " pads_i=pads,\n", + " dilations_i=dilations,\n", + " group_i=groups,\n", + " offset_group_i=deform_groups,\n", + " # You can ignore maybe_bool if you don't need it, or pass it as an attribute.\n", + " )" + ] + }, { "cell_type": "code", "execution_count": 8, @@ -265,10 +340,16 @@ ], "source": [ "from torchvision.ops.deform_conv import DeformConv2d\n", - "import deform_conv2d_onnx_exporter\n", "\n", - "# register deform_conv2d operator\n", - "deform_conv2d_onnx_exporter.register_deform_conv2d_onnx_op()\n", + "# import deform_conv2d_onnx_exporter\n", + "# # register deform_conv2d operator\n", + "# deform_conv2d_onnx_exporter.register_deform_conv2d_onnx_op()\n", + "\n", + "register_custom_op_symbolic(\n", + " \"torchvision::deform_conv2d\", # PyTorch JIT/FX name\n", + " symbolic_deform_conv_19,\n", + " opset_version=19,\n", + ")\n", "\n", "def convert_to_onnx(net, file_name='output.onnx', input_shape=(1024, 1024), device=device):\n", " input = torch.randn(1, 3, input_shape[0], input_shape[1]).to(device)\n", @@ -281,9 +362,10 @@ " input,\n", " file_name,\n", " verbose=False,\n", - " opset_version=17,\n", + " opset_version=20,\n", " input_names=input_layer_names,\n", " output_names=output_layer_names,\n", + " dynamic_axes={\"input_image\": [0]},\n", " )\n", "convert_to_onnx(birefnet, weights_file.replace('.pth', '.onnx'), input_shape=(1024, 1024), device=device)" ] @@ -451,7 +533,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.20" + "version": "3.12.0" } }, "nbformat": 4,