Skip to content

Commit 18498f0

Browse files
committed
Export ONNX with ONNX's DeformConv
1 parent 94caf02 commit 18498f0

File tree

1 file changed

+85
-57
lines changed

1 file changed

+85
-57
lines changed

tutorials/BiRefNet_pth2onnx.ipynb

+85-57
Original file line numberDiff line numberDiff line change
@@ -158,63 +158,84 @@
158158
},
159159
{
160160
"cell_type": "code",
161-
"execution_count": 6,
162-
"metadata": {},
163-
"outputs": [
164-
{
165-
"name": "stdout",
166-
"output_type": "stream",
167-
"text": [
168-
"Cloning into 'deform_conv2d_onnx_exporter'...\n",
169-
"remote: Enumerating objects: 205, done.\u001b[K\n",
170-
"remote: Counting objects: 100% (7/7), done.\u001b[K\n",
171-
"remote: Total 205 (delta 6), reused 6 (delta 6), pack-reused 198 (from 1)\u001b[K\n",
172-
"Receiving objects: 100% (205/205), 36.21 KiB | 170.00 KiB/s, done.\n",
173-
"Resolving deltas: 100% (102/102), done.\n"
174-
]
175-
}
176-
],
177-
"source": [
178-
"!git clone https://github.com/masamitsu-murase/deform_conv2d_onnx_exporter\n",
179-
"%cp deform_conv2d_onnx_exporter/src/deform_conv2d_onnx_exporter.py .\n",
180-
"!rm -rf deform_conv2d_onnx_exporter"
181-
]
182-
},
183-
{
184-
"cell_type": "code",
185-
"execution_count": 7,
161+
"execution_count": null,
186162
"metadata": {},
187163
"outputs": [],
188164
"source": [
189-
"with open('deform_conv2d_onnx_exporter.py') as fp:\n",
190-
" file_lines = fp.read()\n",
165+
"from torch.onnx.symbolic_helper import parse_args\n",
166+
"from torch.onnx import register_custom_op_symbolic\n",
191167
"\n",
192-
"file_lines = file_lines.replace(\n",
193-
" \"return sym_help._get_tensor_dim_size(tensor, dim)\",\n",
194-
" '''\n",
195-
" tensor_dim_size = sym_help._get_tensor_dim_size(tensor, dim)\n",
196-
" if tensor_dim_size == None and (dim == 2 or dim == 3):\n",
197-
" import typing\n",
198-
" from torch import _C\n",
199168
"\n",
200-
" x_type = typing.cast(_C.TensorType, tensor.type())\n",
201-
" x_strides = x_type.strides()\n",
169+
"@parse_args(\n",
170+
" \"v\", # arg0: input (tensor)\n",
171+
" \"v\", # arg1: weight (tensor)\n",
172+
" \"v\", # arg2: offset (tensor)\n",
173+
" \"v\", # arg3: mask (tensor)\n",
174+
" \"v\", # arg4: bias (tensor)\n",
175+
" \"i\", # arg5: stride_h\n",
176+
" \"i\", # arg6: stride_w\n",
177+
" \"i\", # arg7: pad_h\n",
178+
" \"i\", # arg8: pad_w\n",
179+
" \"i\", # arg9: dilation_h\n",
180+
" \"i\", # arg10: dilation_w\n",
181+
" \"i\", # arg11: groups\n",
182+
" \"i\", # arg12: deform_groups\n",
183+
"    \"b\", # arg13: use_mask (bool flag)\n",
184+
")\n",
185+
"def symbolic_deform_conv_19(\n",
186+
" g,\n",
187+
" input,\n",
188+
" weight,\n",
189+
" offset,\n",
190+
" mask,\n",
191+
" bias,\n",
192+
" stride_h,\n",
193+
" stride_w,\n",
194+
" pad_h,\n",
195+
" pad_w,\n",
196+
" dilation_h,\n",
197+
" dilation_w,\n",
198+
" groups,\n",
199+
" deform_groups,\n",
200+
" maybe_bool,\n",
201+
"):\n",
202+
"    # Repack the scalar stride/pad/dilation args into the list attributes ONNX expects:\n",
203+
" strides = [stride_h, stride_w]\n",
204+
" pads = [pad_h, pad_w, pad_h, pad_w]\n",
205+
" dilations = [dilation_h, dilation_w]\n",
202206
"\n",
203-
" tensor_dim_size = x_strides[2] if dim == 3 else x_strides[1] // x_strides[2]\n",
204-
" elif tensor_dim_size == None and (dim == 0):\n",
205-
" import typing\n",
206-
" from torch import _C\n",
207+
" # If bias is None, you'd do something like:\n",
208+
" # if bias.node().kind() == \"prim::Constant\" and bias.node()[\"value\"] is None:\n",
209+
" # bias = g.op(\"Constant\", value_t=torch.tensor([], dtype=torch.float32))\n",
210+
" #\n",
211+
"    # In this export, arg4 (bias) is a real tensor (e.g. shape [256]), so it is never None here.\n",
207212
"\n",
208-
" x_type = typing.cast(_C.TensorType, tensor.type())\n",
209-
" x_strides = x_type.strides()\n",
210-
" tensor_dim_size = x_strides[3]\n",
213+
"    # Similarly, mask is not None for this export; to support a None mask,\n",
214+
"    # add a Constant-check like the one sketched above.\n",
215+
"\n",
216+
" # Construct the official ONNX DeformConv (Opset 19).\n",
217+
" # 'main' domain => just \"DeformConv\"\n",
218+
" return g.op(\n",
219+
" \"DeformConv\",\n",
220+
" input,\n",
221+
" weight,\n",
222+
" offset,\n",
223+
" bias,\n",
224+
" mask,\n",
225+
" strides_i=strides,\n",
226+
" pads_i=pads,\n",
227+
" dilations_i=dilations,\n",
228+
" group_i=groups,\n",
229+
" offset_group_i=deform_groups,\n",
230+
"        # maybe_bool (torchvision's use_mask flag) is unused here: ONNX DeformConv takes mask as an optional input instead.\n",
231+
" )\n",
211232
"\n",
212-
" return tensor_dim_size\n",
213-
" ''',\n",
214-
")\n",
215233
"\n",
216-
"with open('deform_conv2d_onnx_exporter.py', mode=\"w\") as fp:\n",
217-
" fp.write(file_lines)"
234+
"register_custom_op_symbolic(\n",
235+
" \"torchvision::deform_conv2d\", # PyTorch JIT/FX name\n",
236+
" symbolic_deform_conv_19,\n",
237+
" opset_version=19,\n",
238+
")"
218239
]
219240
},
220241
{
@@ -265,27 +286,34 @@
265286
],
266287
"source": [
267288
"from torchvision.ops.deform_conv import DeformConv2d\n",
268-
"import deform_conv2d_onnx_exporter\n",
269289
"\n",
270-
"# register deform_conv2d operator\n",
271-
"deform_conv2d_onnx_exporter.register_deform_conv2d_onnx_op()\n",
272290
"\n",
273-
"def convert_to_onnx(net, file_name='output.onnx', input_shape=(1024, 1024), device=device):\n",
291+
"def convert_to_onnx(\n",
292+
" net, file_name=\"output.onnx\", input_shape=(1024, 1024), device=device\n",
293+
"):\n",
274294
" input = torch.randn(1, 3, input_shape[0], input_shape[1]).to(device)\n",
275295
"\n",
276-
" input_layer_names = ['input_image']\n",
277-
" output_layer_names = ['output_image']\n",
296+
" input_layer_names = [\"input_image\"]\n",
297+
" output_layer_names = [\"output_image\"]\n",
278298
"\n",
279299
" torch.onnx.export(\n",
280300
" net,\n",
281301
" input,\n",
282302
" file_name,\n",
283303
" verbose=False,\n",
284-
" opset_version=17,\n",
304+
" opset_version=20,\n",
285305
" input_names=input_layer_names,\n",
286306
" output_names=output_layer_names,\n",
307+
" dynamic_axes={\"input_image\": [0]},\n",
287308
" )\n",
288-
"convert_to_onnx(birefnet, weights_file.replace('.pth', '.onnx'), input_shape=(1024, 1024), device=device)"
309+
"\n",
310+
"\n",
311+
"convert_to_onnx(\n",
312+
" birefnet,\n",
313+
" weights_file.replace(\".pth\", \".onnx\"),\n",
314+
" input_shape=(1024, 1024),\n",
315+
" device=device,\n",
316+
")\n"
289317
]
290318
},
291319
{
@@ -451,7 +479,7 @@
451479
"name": "python",
452480
"nbconvert_exporter": "python",
453481
"pygments_lexer": "ipython3",
454-
"version": "3.9.20"
482+
"version": "3.12.0"
455483
}
456484
},
457485
"nbformat": 4,

0 commit comments

Comments
 (0)