|
158 | 158 | },
159 | 159 | {
160 | 160 | "cell_type": "code",
161 |     | - "execution_count": 6,
162 |     | - "metadata": {},
163 |     | - "outputs": [
164 |     | - {
165 |     | - "name": "stdout",
166 |     | - "output_type": "stream",
167 |     | - "text": [
168 |     | - "Cloning into 'deform_conv2d_onnx_exporter'...\n",
169 |     | - "remote: Enumerating objects: 205, done.\u001b[K\n",
170 |     | - "remote: Counting objects: 100% (7/7), done.\u001b[K\n",
171 |     | - "remote: Total 205 (delta 6), reused 6 (delta 6), pack-reused 198 (from 1)\u001b[K\n",
172 |     | - "Receiving objects: 100% (205/205), 36.21 KiB | 170.00 KiB/s, done.\n",
173 |     | - "Resolving deltas: 100% (102/102), done.\n"
174 |     | - ]
175 |     | - }
176 |     | - ],
177 |     | - "source": [
178 |     | - "!git clone https://github.com/masamitsu-murase/deform_conv2d_onnx_exporter\n",
179 |     | - "%cp deform_conv2d_onnx_exporter/src/deform_conv2d_onnx_exporter.py .\n",
180 |     | - "!rm -rf deform_conv2d_onnx_exporter"
181 |     | - ]
182 |     | - },
183 |     | - {
184 |     | - "cell_type": "code",
185 |     | - "execution_count": 7,
    | 161 | + "execution_count": null,
186 | 162 | "metadata": {},
187 | 163 | "outputs": [],
188 | 164 | "source": [
189 |     | - "with open('deform_conv2d_onnx_exporter.py') as fp:\n",
190 |     | - "    file_lines = fp.read()\n",
    | 165 | + "from torch.onnx.symbolic_helper import parse_args\n",
    | 166 | + "from torch.onnx import register_custom_op_symbolic\n",
191 | 167 | "\n",
192 |     | - "file_lines = file_lines.replace(\n",
193 |     | - "    \"return sym_help._get_tensor_dim_size(tensor, dim)\",\n",
194 |     | - "    '''\n",
195 |     | - "    tensor_dim_size = sym_help._get_tensor_dim_size(tensor, dim)\n",
196 |     | - "    if tensor_dim_size == None and (dim == 2 or dim == 3):\n",
197 |     | - "        import typing\n",
198 |     | - "        from torch import _C\n",
199 | 168 | "\n",
200 |     | - "        x_type = typing.cast(_C.TensorType, tensor.type())\n",
201 |     | - "        x_strides = x_type.strides()\n",
    | 169 | + "@parse_args(\n",
    | 170 | + "    \"v\",  # arg0: input (tensor)\n",
    | 171 | + "    \"v\",  # arg1: weight (tensor)\n",
    | 172 | + "    \"v\",  # arg2: offset (tensor)\n",
    | 173 | + "    \"v\",  # arg3: mask (tensor)\n",
    | 174 | + "    \"v\",  # arg4: bias (tensor)\n",
    | 175 | + "    \"i\",  # arg5: stride_h\n",
    | 176 | + "    \"i\",  # arg6: stride_w\n",
    | 177 | + "    \"i\",  # arg7: pad_h\n",
    | 178 | + "    \"i\",  # arg8: pad_w\n",
    | 179 | + "    \"i\",  # arg9: dilation_h\n",
    | 180 | + "    \"i\",  # arg10: dilation_w\n",
    | 181 | + "    \"i\",  # arg11: groups\n",
    | 182 | + "    \"i\",  # arg12: deform_groups\n",
    | 183 | + "    \"b\",  # arg13: use_mask (bool)\n",
    | 184 | + ")\n",
    | 185 | + "def symbolic_deform_conv_19(\n",
    | 186 | + "    g,\n",
    | 187 | + "    input,\n",
    | 188 | + "    weight,\n",
    | 189 | + "    offset,\n",
    | 190 | + "    mask,\n",
    | 191 | + "    bias,\n",
    | 192 | + "    stride_h,\n",
    | 193 | + "    stride_w,\n",
    | 194 | + "    pad_h,\n",
    | 195 | + "    pad_w,\n",
    | 196 | + "    dilation_h,\n",
    | 197 | + "    dilation_w,\n",
    | 198 | + "    groups,\n",
    | 199 | + "    deform_groups,\n",
    | 200 | + "    use_mask,\n",
    | 201 | + "):\n",
    | 202 | + "    # Re-assemble the scalar arguments into the list attributes ONNX expects.\n",
    | 203 | + "    strides = [stride_h, stride_w]\n",
    | 204 | + "    pads = [pad_h, pad_w, pad_h, pad_w]\n",
    | 205 | + "    dilations = [dilation_h, dilation_w]\n",
202 | 206 | "\n",
203 |     | - "        tensor_dim_size = x_strides[2] if dim == 3 else x_strides[1] // x_strides[2]\n",
204 |     | - "    elif tensor_dim_size == None and (dim == 0):\n",
205 |     | - "        import typing\n",
206 |     | - "        from torch import _C\n",
    | 207 | + "    # If bias were None, it could be replaced with an empty constant, e.g.:\n",
    | 208 | + "    # if bias.node().kind() == \"prim::Constant\" and bias.node()[\"value\"] is None:\n",
    | 209 | + "    #     bias = g.op(\"Constant\", value_t=torch.tensor([], dtype=torch.float32))\n",
    | 210 | + "    #\n",
    | 211 | + "    # For this model bias is a real tensor (shape [256]), so that path is not needed.\n",
207 | 212 | "\n",
208 |     | - "        x_type = typing.cast(_C.TensorType, tensor.type())\n",
209 |     | - "        x_strides = x_type.strides()\n",
210 |     | - "        tensor_dim_size = x_strides[3]\n",
    | 213 | + "    # The same holds for mask here; add an analogous check if a None mask\n",
    | 214 | + "    # ever needs to be handled.\n",
    | 215 | + "\n",
    | 216 | + "    # Emit the standard ONNX DeformConv op (introduced in opset 19).\n",
    | 217 | + "    # It lives in the default domain, so no domain prefix is needed.\n",
    | 218 | + "    return g.op(\n",
    | 219 | + "        \"DeformConv\",\n",
    | 220 | + "        input,\n",
    | 221 | + "        weight,\n",
    | 222 | + "        offset,\n",
    | 223 | + "        bias,\n",
    | 224 | + "        mask,\n",
    | 225 | + "        strides_i=strides,\n",
    | 226 | + "        pads_i=pads,\n",
    | 227 | + "        dilations_i=dilations,\n",
    | 228 | + "        group_i=groups,\n",
    | 229 | + "        offset_group_i=deform_groups,\n",
    | 230 | + "        # use_mask is not forwarded: the standard DeformConv op has no such attribute.\n",
    | 231 | + "    )\n",
211 | 232 | "\n",
212 |     | - "    return tensor_dim_size\n",
213 |     | - "    ''',\n",
214 |     | - ")\n",
215 | 233 | "\n",
216 |     | - "with open('deform_conv2d_onnx_exporter.py', mode=\"w\") as fp:\n",
217 |     | - "    fp.write(file_lines)"
    | 234 | + "register_custom_op_symbolic(\n",
    | 235 | + "    \"torchvision::deform_conv2d\",  # name of the torchvision custom op\n",
    | 236 | + "    symbolic_deform_conv_19,\n",
    | 237 | + "    opset_version=19,\n",
    | 238 | + ")"
218 | 239 | ]
219 | 240 | },
220 | 241 | {
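As a quick sanity check for the hunk above, the freshly registered symbolic can be exercised on a bare DeformConv2d layer before exporting the full model. The sketch below is illustrative only: the layer sizes, the offset/mask shapes, and the deform_check.onnx file name are made up, and it assumes the onnx package is installed and the registration cell has already run. It exports at opset 19 and checks that a DeformConv node actually appears in the graph.

import torch
import onnx
from torchvision.ops import DeformConv2d

# Tiny stand-in layer; all sizes here are arbitrary.
layer = DeformConv2d(3, 8, kernel_size=3, padding=1).eval()

x = torch.randn(1, 3, 16, 16)
# offset carries 2 * kH * kW channels per offset group -> 2 * 3 * 3 = 18
offset = torch.randn(1, 18, 16, 16)
# mask carries kH * kW channels per offset group -> 9
mask = torch.rand(1, 9, 16, 16)

torch.onnx.export(layer, (x, offset, mask), "deform_check.onnx", opset_version=19)

graph = onnx.load("deform_check.onnx").graph
assert any(node.op_type == "DeformConv" for node in graph.node), "custom symbolic was not used"
print("DeformConv node found in the exported graph")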
|
|
265 | 286 | ],
266 | 287 | "source": [
267 | 288 | "from torchvision.ops.deform_conv import DeformConv2d\n",
268 |     | - "import deform_conv2d_onnx_exporter\n",
269 | 289 | "\n",
270 |     | - "# register deform_conv2d operator\n",
271 |     | - "deform_conv2d_onnx_exporter.register_deform_conv2d_onnx_op()\n",
272 | 290 | "\n",
273 |     | - "def convert_to_onnx(net, file_name='output.onnx', input_shape=(1024, 1024), device=device):\n",
    | 291 | + "def convert_to_onnx(\n",
    | 292 | + "    net, file_name=\"output.onnx\", input_shape=(1024, 1024), device=device\n",
    | 293 | + "):\n",
274 | 294 | "    input = torch.randn(1, 3, input_shape[0], input_shape[1]).to(device)\n",
275 | 295 | "\n",
276 |     | - "    input_layer_names = ['input_image']\n",
277 |     | - "    output_layer_names = ['output_image']\n",
    | 296 | + "    input_layer_names = [\"input_image\"]\n",
    | 297 | + "    output_layer_names = [\"output_image\"]\n",
278 | 298 | "\n",
279 | 299 | "    torch.onnx.export(\n",
280 | 300 | "        net,\n",
281 | 301 | "        input,\n",
282 | 302 | "        file_name,\n",
283 | 303 | "        verbose=False,\n",
284 |     | - "        opset_version=17,\n",
    | 304 | + "        opset_version=20,\n",
285 | 305 | "        input_names=input_layer_names,\n",
286 | 306 | "        output_names=output_layer_names,\n",
    | 307 | + "        dynamic_axes={\"input_image\": [0]},\n",
287 | 308 | "    )\n",
288 |     | - "convert_to_onnx(birefnet, weights_file.replace('.pth', '.onnx'), input_shape=(1024, 1024), device=device)"
    | 309 | + "\n",
    | 310 | + "\n",
    | 311 | + "convert_to_onnx(\n",
    | 312 | + "    birefnet,\n",
    | 313 | + "    weights_file.replace(\".pth\", \".onnx\"),\n",
    | 314 | + "    input_shape=(1024, 1024),\n",
    | 315 | + "    device=device,\n",
    | 316 | + ")\n"
289 | 317 | ]
290 | 318 | },
291 | 319 | {
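Once the export has run, the dynamic batch axis and the DeformConv node can be checked end to end. The sketch below is a hypothetical follow-up, not part of the notebook: it assumes onnxruntime is installed and that the runtime build actually implements the opset-19 DeformConv operator, and it uses the default output.onnx name as a stand-in for whatever weights_file.replace(".pth", ".onnx") produced. input_image and output_image are the tensor names set in the export call above.

import numpy as np
import onnxruntime as ort

# Placeholder path; substitute the file written by convert_to_onnx above.
session = ort.InferenceSession("output.onnx", providers=["CPUExecutionProvider"])

# dynamic_axes={"input_image": [0]} left the batch dimension symbolic,
# so batch sizes other than 1 should be accepted as well.
for batch in (1, 2):
    x = np.random.rand(batch, 3, 1024, 1024).astype(np.float32)
    (out,) = session.run(["output_image"], {"input_image": x})
    print(batch, out.shape)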
|
|
451 | 479 | "name": "python",
452 | 480 | "nbconvert_exporter": "python",
453 | 481 | "pygments_lexer": "ipython3",
454 |     | - "version": "3.9.20"
    | 482 | + "version": "3.12.0"
455 | 483 | }
456 | 484 | },
457 | 485 | "nbformat": 4,
|
|