Restrict ONNX opset to 16 and up #3051

Merged
merged 34 commits on Apr 28, 2025
Changes from 32 commits
Commits (34)
f56c4c9
Update ONNX-IR documentation with more comprehensive description
antimora Mar 4, 2025
47339fa
Fix build issues with data structure changes
antimora Mar 4, 2025
e1557e9
Fix build issues with TensorType structure changes
antimora Mar 7, 2025
1946252
Add static shape handling and rank inference for tensor operations
antimora Apr 16, 2025
02d7b29
Fix clippy warnings
antimora Apr 16, 2025
56de3f1
Merge remote-tracking branch 'upstream/main' into onnx-shape
antimora Apr 16, 2025
731c6d9
Fix merge issues
antimora Apr 17, 2025
b23f3c6
Merge remote-tracking branch 'upstream/main' into onnx-shape
antimora Apr 17, 2025
d554562
Merge remote-tracking branch 'upstream/main' into onnx-shape
antimora Apr 17, 2025
f4e815c
Merge remote-tracking branch 'upstream/main' into onnx-shape
antimora Apr 18, 2025
d5acc51
Merge remote-tracking branch 'upstream/main' into onnx-shape
antimora Apr 23, 2025
8728372
Enable unsqueeze with runtime axes values
antimora Apr 24, 2025
55a677a
Fix clippy error
antimora Apr 24, 2025
43af757
Remove default fall back
antimora Apr 24, 2025
03cdbe5
Removed dead code.
antimora Apr 24, 2025
c9b32f2
Removed rank from TensorData
antimora Apr 24, 2025
5ae4685
Removed elem_type from TensorData
antimora Apr 24, 2025
ee7f329
Merge remote-tracking branch 'upstream/main' into onnx-shape
antimora Apr 24, 2025
a29aba2
Merge remote-tracking branch 'upstream/main' into onnx-shape
antimora Apr 25, 2025
1aeb4a1
Simplify elem_type match expressions with pattern grouping
antimora Apr 25, 2025
285e361
Add static_shape back
antimora Apr 25, 2025
31c7714
Add restriction for ONNX opset version >= 16
antimora Apr 20, 2025
dabeb29
Add onnx opset upgrade script
antimora Apr 21, 2025
b508df1
Update onnx-model.md
antimora Apr 21, 2025
7693250
Removed onnx files for opsets < 16
antimora Apr 24, 2025
f6c2297
Skip opset upgrades if opset >= 16
antimora Apr 24, 2025
768bb84
Bring back moved onnx file
antimora Apr 24, 2025
f607b89
Fix clippy
antimora Apr 24, 2025
017fca3
Updated opset script per PR feedback
antimora Apr 25, 2025
bfa0388
Reimplement topk onnx and tests for opset16
antimora Apr 25, 2025
662b57d
Merge branch 'main' into restrict-opset-16
antimora Apr 25, 2025
e15a2b3
Update README.md
antimora Apr 25, 2025
1fd49aa
Include infer_shapes step in the upgrade script
antimora Apr 26, 2025
759515b
Merge remote-tracking branch 'upstream/main' into restrict-opset-16
antimora Apr 26, 2025
166 changes: 87 additions & 79 deletions burn-book/src/import/onnx-model.md
@@ -1,72 +1,86 @@
# Importing ONNX Models in Burn

## Table of Contents

1. [Introduction](#introduction)
2. [Why Import Models?](#why-import-models)
3. [Understanding ONNX](#understanding-onnx)
4. [Burn's ONNX Support](#burns-onnx-support)
5. [Step-by-Step Guide](#step-by-step-guide)
6. [Advanced Configuration](#advanced-configuration)
7. [Loading and Using Models](#loading-and-using-models)
8. [Troubleshooting](#troubleshooting)
9. [Examples and Resources](#examples-and-resources)
10. [Conclusion](#conclusion)

## Introduction

As the field of deep learning continues to evolve, the need for interoperability between different
frameworks becomes increasingly important. Burn, a modern deep learning framework in Rust,
recognizes this need and provides robust support for importing models from other popular frameworks.
This section focuses on importing
As deep learning evolves, interoperability between frameworks becomes crucial. Burn, a modern deep
learning framework in Rust, provides robust support for importing models from other popular
frameworks. This section focuses on importing
[ONNX (Open Neural Network Exchange)](https://onnx.ai/onnx/intro/index.html) models into Burn,
enabling you to leverage pre-trained models and seamlessly integrate them into your Rust-based deep
learning projects.
enabling you to leverage pre-trained models in your Rust-based deep learning projects.

## Why Import Models?

Importing pre-trained models offers several advantages:

1. **Time-saving**: Avoid the need to train models from scratch, which can be time-consuming and
resource-intensive.
1. **Time-saving**: Skip the resource-intensive process of training models from scratch.
2. **Access to state-of-the-art architectures**: Utilize cutting-edge models developed by
researchers and industry leaders.
3. **Transfer learning**: Fine-tune imported models for your specific tasks, benefiting from
knowledge transfer.
4. **Consistency across frameworks**: Ensure consistent performance when moving from one framework
to another.
4. **Consistency across frameworks**: Maintain consistent performance when moving between
frameworks.

## Understanding ONNX

ONNX (Open Neural Network Exchange) is an open format designed to represent machine learning models.
Key features include:
ONNX (Open Neural Network Exchange) is an open format designed to represent machine learning models
with these key features:

- **Framework agnostic**: ONNX provides a common format that works across various deep learning
- **Framework agnostic**: Provides a common format that works across various deep learning
frameworks.
- **Comprehensive representation**: It captures both the model architecture and trained weights.
- **Wide support**: Many popular frameworks like PyTorch, TensorFlow, and scikit-learn support ONNX
export.
- **Comprehensive representation**: Captures both the model architecture and trained weights.
- **Wide support**: Compatible with popular frameworks like PyTorch, TensorFlow, and scikit-learn.

By using ONNX, you can easily move models between different frameworks and deployment environments.
This standardization allows seamless movement of models between different frameworks and deployment
environments.

## Burn's ONNX Support

Burn takes a unique approach to ONNX import, offering several advantages:
Burn's approach to ONNX import offers unique advantages:

1. **Native Rust code generation**: ONNX models are translated into Rust source code, allowing for
deep integration with Burn's ecosystem.
2. **Compile-time optimization**: The generated Rust code can be optimized by the Rust compiler,
1. **Native Rust code generation**: Translates ONNX models into Rust source code for deep
integration with Burn's ecosystem.
2. **Compile-time optimization**: Leverages the Rust compiler to optimize the generated code,
potentially improving performance.
3. **No runtime dependency**: Unlike some solutions that require an ONNX runtime, Burn's approach
eliminates this dependency.
4. **Trainability**: Imported models can be further trained or fine-tuned using Burn.
5. **Portability**: The generated Rust code can be compiled for various targets, including
WebAssembly and embedded devices.
6. **Any Burn Backend**: The imported models can be used with any of Burn's backends.
3. **No runtime dependency**: Eliminates the need for an ONNX runtime, unlike many other solutions.
4. **Trainability**: Allows imported models to be further trained or fine-tuned using Burn.
5. **Portability**: Enables compilation for various targets, including WebAssembly and embedded
devices.
6. **Backend flexibility**: Works with any of Burn's supported backends.

## ONNX Compatibility

Burn requires ONNX models to use **opset version 16 or higher**. If your model uses an older
version, you'll need to upgrade it using the ONNX version converter.

### Upgrading ONNX Models

There are two simple ways to upgrade your ONNX models to the required opset version:

Option 1: Use the provided utility script:

```
uv run --script https://raw.githubusercontent.com/tracel-ai/burn/refs/heads/main/crates/burn-import/onnx_opset_upgrade.py
```

Option 2: Use a custom Python script:

```python
import onnx
from onnx import version_converter

# Load the ONNX model
model = onnx.load('path/to/your/model.onnx')

# Upgrade to opset version 16
converted_model = version_converter.convert_version(model, 16)

# Save the upgraded model
onnx.save(converted_model, 'upgraded_model.onnx')
```
Member:

Btw, why are there so many more steps in the onnx_opset_upgrade.py script? Actually, most of the additional code is for debug info or user input, but what about the additional shape inference?

Collaborator (Author):

Updated the document to include that step. I added it as a safe measure in case the older opset files didn't already have inferred shapes, and also as a precursor to using inferred node output shapes. For now, I think we can figure out Input/Output ranks from the static shapes.
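
For reference, the shape inference step discussed here can also be added to the custom Python script from the documentation; this is a sketch using the standard `onnx.shape_inference` API:

```python
import onnx
from onnx import shape_inference, version_converter

# Load the ONNX model
model = onnx.load('path/to/your/model.onnx')

# Upgrade to opset version 16
converted_model = version_converter.convert_version(model, 16)

# Infer shapes as a safety measure, in case the original model
# did not already carry shape information
inferred_model = shape_inference.infer_shapes(converted_model)

# Save the upgraded model
onnx.save(inferred_model, 'upgraded_model.onnx')
```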


## Step-by-Step Guide

Let's walk through the process of importing an ONNX model into a Burn project:
Follow these steps to import an ONNX model into your Burn project:

### Step 1: Update `build.rs`

@@ -90,7 +104,7 @@ fn main() {
}
```

This script uses `ModelGen` to generate Rust code from your ONNX model during the build process.
This generates Rust code from your ONNX model during the build process.
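
The collapsed `build.rs` body above typically follows this pattern; this is a sketch that assumes the model lives at `src/model/my_model.onnx`, so adjust the path for your project:

```rust
use burn_import::onnx::ModelGen;

fn main() {
    // Generate Rust code from the ONNX model at build time.
    ModelGen::new()
        .input("src/model/my_model.onnx")
        .out_dir("model/")
        .run_from_script();
}
```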

### Step 2: Modify `mod.rs`

@@ -102,11 +116,9 @@ pub mod my_model {
}
```

This makes the generated model code available in your project.
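
The collapsed block above is the module declaration that includes the generated file. With the assumed `out_dir("model/")` setting it usually looks like this sketch (the generated file name follows your ONNX model's name):

```rust
// src/model/mod.rs: pull in the code generated by build.rs.
pub mod my_model {
    include!(concat!(env!("OUT_DIR"), "/model/my_model.rs"));
}
```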

### Step 3: Use the Imported Model

Now you can use the imported model in your Rust code:
Now you can use the imported model in your code:

```rust
use burn::tensor;
@@ -116,8 +128,7 @@ use model::my_model::Model;
fn main() {
let device = NdArrayDevice::default();

// Create model instance and load weights from target dir default device.
// (see more load options below in "Loading and Using Models" section)
// Create model instance and load weights from target dir default device
let model: Model<NdArray<f32>> = Model::default();

// Create input tensor (replace with your actual input)
@@ -132,7 +143,7 @@ fn main() {
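
Since the diff truncates the example, a complete minimal inference sketch looks roughly like the following; the `[1, 3, 224, 224]` input shape is a placeholder, so substitute your model's actual input shape and type:

```rust
use burn::backend::ndarray::{NdArray, NdArrayDevice};
use burn::tensor::Tensor;
use model::my_model::Model;

fn main() {
    let device = NdArrayDevice::default();

    // Create model instance and load weights from the build output directory.
    let model: Model<NdArray<f32>> = Model::default();

    // Placeholder input tensor; replace with your model's expected shape.
    let input = Tensor::<NdArray<f32>, 4>::zeros([1, 3, 224, 224], &device);

    // Run inference and print the result.
    let output = model.forward(input);
    println!("{}", output);
}
```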

## Advanced Configuration

The `ModelGen` struct offers several configuration options:
The `ModelGen` struct provides several configuration options:

```rust
ModelGen::new()
@@ -144,72 +155,69 @@ ModelGen::new()
.run_from_script();
```

- `record_type`: Specifies the format for storing weights (Bincode, NamedMpk, NamedMpkGz, or
- `record_type`: Defines the format for storing weights (Bincode, NamedMpk, NamedMpkGz, or
PrettyJson).
- `half_precision`: Use half-precision (f16) for weights to reduce model size.
- `embed_states`: Embed model weights directly in the generated Rust code. Note: This requires
record type `Bincode`.
- `half_precision`: Reduces model size by using half-precision (f16) for weights.
- `embed_states`: Embeds model weights directly in the generated Rust code (requires record type
`Bincode`).
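
Putting these options together, a fully specified configuration might look like the sketch below; the values are illustrative, not required defaults:

```rust
ModelGen::new()
    .input("src/model/my_model.onnx")
    .out_dir("model/")
    // Store weights as named MessagePack records.
    .record_type(RecordType::NamedMpk)
    // Keep full f32 precision; set to true to halve the weight size.
    .half_precision(false)
    // Keep weights in a separate record file instead of embedding them.
    .embed_states(false)
    .run_from_script();
```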

## Loading and Using Models

Depending on your configuration, you can load models in different ways:
Depending on your configuration, you can load models in several ways:

```rust
// Create a new model instance with device. Initializes weights randomly and lazily.
// You can load weights via `load_record` afterwards.
// Create a new model instance with device
// (initializes weights randomly and lazily; load weights via `load_record` afterward)
let model = Model::<Backend>::new(&device);

// Load from a file (must specify weights file in the target output directory or copy it from there).
// File type should match the record type specified in `ModelGen`.
// Load from a file
// (file type should match the record type specified in `ModelGen`)
let model = Model::<Backend>::from_file("path/to/weights", &device);

// Load from embedded weights (if embed_states was true)
let model = Model::<Backend>::from_embedded(&device);

// Load from the out director location and load to default device (useful for testing)
// Load from the output directory with default device (useful for testing)
let model = Model::<Backend>::default();
```

## Troubleshooting

Here are some common issues and their solutions:
Common issues and solutions:

1. **Unsupported ONNX operator**: If you encounter an error about an unsupported operator, check the
1. **Unsupported ONNX operator**: Check the
[list of supported ONNX operators](https://github.com/tracel-ai/burn/blob/main/crates/burn-import/SUPPORTED-ONNX-OPS.md).
You may need to simplify your model or wait for support to be added.
You may need to simplify your model or wait for support.

2. **Build errors**: Ensure that your `burn-import` version matches your Burn version. Also, check
that the ONNX file path in `build.rs` is correct.
2. **Build errors**: Ensure your `burn-import` version matches your Burn version and verify the ONNX
file path in `build.rs`.

3. **Runtime errors**: If you get errors when running your model, double-check that your input
tensors match the expected shape and data type of your model.
3. **Runtime errors**: Confirm that your input tensors match the expected shape and data type of
your model.

4. **Performance issues**: If your imported model is slower than expected, try using the
`half_precision` option to reduce memory usage, or experiment with different `record_type`
options.
4. **Performance issues**: Try using the `half_precision` option to reduce memory usage or
experiment with different `record_type` options.

5. **Artifact Files**: You can view the generated Rust code and weights files in the `OUT_DIR`
directory specified in `build.rs` (usually `target/debug/build/<project>/out`).
5. **Viewing generated files**: Find the generated Rust code and weights in the `OUT_DIR` directory
(usually `target/debug/build/<project>/out`).

## Examples and Resources

For more detailed examples, check out:
For practical examples, check out:

1. [MNIST Inference Example](https://github.com/tracel-ai/burn/tree/main/examples/onnx-inference)
2. [SqueezeNet Image Classification](https://github.com/tracel-ai/models/tree/main/squeezenet-burn)

These examples demonstrate real-world usage of ONNX import in Burn projects.
These demonstrate real-world usage of ONNX import in Burn projects.

## Conclusion

Importing ONNX models into Burn opens up a world of possibilities, allowing you to leverage
pre-trained models from other frameworks while taking advantage of Burn's performance and Rust's
safety features. By following this guide, you should be able to seamlessly integrate ONNX models
into your Burn projects, whether for inference, fine-tuning, or as a starting point for further
development.
Importing ONNX models into Burn combines the vast ecosystem of pre-trained models with Burn's
performance and Rust's safety features. Following this guide, you can seamlessly integrate ONNX
models into your Burn projects for inference, fine-tuning, or further development.

Remember that the `burn-import` crate is actively developed, with ongoing work to support more ONNX
operators and improve performance. Stay tuned to the Burn repository for updates and new features!
The `burn-import` crate is actively developed, with ongoing work to support more ONNX operators and
improve performance. Stay tuned to the Burn repository for updates!

---

19 changes: 7 additions & 12 deletions crates/burn-import/onnx-tests/build.rs
@@ -13,8 +13,7 @@ fn main() {
.input("tests/avg_pool2d/avg_pool2d.onnx")
.input("tests/batch_norm/batch_norm.onnx")
.input("tests/cast/cast.onnx")
.input("tests/clip/clip_opset16.onnx")
.input("tests/clip/clip_opset7.onnx")
.input("tests/clip/clip.onnx")
.input("tests/concat/concat.onnx")
.input("tests/constant/constant_f32.onnx")
.input("tests/constant/constant_f64.onnx")
@@ -31,8 +30,7 @@ fn main() {
.input("tests/cos/cos.onnx")
.input("tests/cosh/cosh.onnx")
.input("tests/div/div.onnx")
.input("tests/dropout/dropout_opset16.onnx")
.input("tests/dropout/dropout_opset7.onnx")
.input("tests/dropout/dropout.onnx")
.input("tests/equal/equal.onnx")
.input("tests/erf/erf.onnx")
.input("tests/exp/exp.onnx")
@@ -97,8 +95,7 @@ fn main() {
.input("tests/reduce_mean/reduce_mean.onnx")
.input("tests/reduce_min/reduce_min.onnx")
.input("tests/reduce_prod/reduce_prod.onnx")
.input("tests/reduce_sum/reduce_sum_opset11.onnx")
.input("tests/reduce_sum/reduce_sum_opset13.onnx")
.input("tests/reduce_sum/reduce_sum.onnx")
.input("tests/relu/relu.onnx")
.input("tests/reshape/reshape.onnx")
.input("tests/resize/resize_with_sizes.onnx")
@@ -116,22 +113,20 @@ fn main() {
.input("tests/softmax/softmax.onnx")
.input("tests/sqrt/sqrt.onnx")
.input("tests/squeeze/squeeze_multiple.onnx")
.input("tests/squeeze/squeeze_opset13.onnx")
.input("tests/squeeze/squeeze_opset16.onnx")
.input("tests/squeeze/squeeze.onnx")
.input("tests/sub/sub.onnx")
.input("tests/sub/sub_int.onnx")
.input("tests/sum/sum.onnx")
.input("tests/sum/sum_int.onnx")
.input("tests/tan/tan.onnx")
.input("tests/tanh/tanh.onnx")
.input("tests/tile/tile.onnx")
.input("tests/top_k/top_k_opset_1.onnx")
.input("tests/topk/topk.onnx")
.input("tests/trilu/trilu_upper.onnx")
.input("tests/trilu/trilu_lower.onnx")
.input("tests/transpose/transpose.onnx")
.input("tests/unsqueeze/unsqueeze.onnx")
.input("tests/unsqueeze/unsqueeze_opset11.onnx")
.input("tests/unsqueeze/unsqueeze_opset16.onnx")
.input("tests/unsqueeze/unsqueeze_runtime_axes.onnx")
.input("tests/unsqueeze/unsqueeze_like.onnx")
.input("tests/split/split.onnx")
.out_dir("model/")
.run_from_script();
@@ -29,7 +29,7 @@ def main():
model.eval()
device = torch.device("cpu")

file_name = "clip_opset16.onnx"
file_name = "clip.onnx"
test_input = torch.rand(6, device=device)
torch.onnx.export(model, test_input, file_name,
verbose=False, opset_version=16)
52 changes: 0 additions & 52 deletions crates/burn-import/onnx-tests/tests/clip/clip_opset7.py

This file was deleted.

2 changes: 1 addition & 1 deletion crates/burn-import/onnx-tests/tests/dropout/dropout.py
@@ -23,7 +23,7 @@ def main():
model.eval()
device = torch.device("cpu")

file_name = "dropout_opset16.onnx"
file_name = "dropout.onnx"
test_input = torch.ones(2, 4, 10, 15, device=device)
torch.onnx.export(model, test_input, file_name,
training=torch.onnx.TrainingMode.TRAINING,
Expand Down
Binary file not shown.
Binary file not shown.
Loading