diff --git a/burn-book/src/import/onnx-model.md b/burn-book/src/import/onnx-model.md index b02cabfb20..256fc78684 100644 --- a/burn-book/src/import/onnx-model.md +++ b/burn-book/src/import/onnx-model.md @@ -1,72 +1,89 @@ # Importing ONNX Models in Burn -## Table of Contents - -1. [Introduction](#introduction) -2. [Why Import Models?](#why-import-models) -3. [Understanding ONNX](#understanding-onnx) -4. [Burn's ONNX Support](#burns-onnx-support) -5. [Step-by-Step Guide](#step-by-step-guide) -6. [Advanced Configuration](#advanced-configuration) -7. [Loading and Using Models](#loading-and-using-models) -8. [Troubleshooting](#troubleshooting) -9. [Examples and Resources](#examples-and-resources) -10. [Conclusion](#conclusion) - ## Introduction -As the field of deep learning continues to evolve, the need for interoperability between different -frameworks becomes increasingly important. Burn, a modern deep learning framework in Rust, -recognizes this need and provides robust support for importing models from other popular frameworks. -This section focuses on importing +As deep learning evolves, interoperability between frameworks becomes crucial. Burn, a modern deep +learning framework in Rust, provides robust support for importing models from other popular +frameworks. This section focuses on importing [ONNX (Open Neural Network Exchange)](https://onnx.ai/onnx/intro/index.html) models into Burn, -enabling you to leverage pre-trained models and seamlessly integrate them into your Rust-based deep -learning projects. +enabling you to leverage pre-trained models in your Rust-based deep learning projects. ## Why Import Models? Importing pre-trained models offers several advantages: -1. **Time-saving**: Avoid the need to train models from scratch, which can be time-consuming and - resource-intensive. +1. **Time-saving**: Skip the resource-intensive process of training models from scratch. 2. **Access to state-of-the-art architectures**: Utilize cutting-edge models developed by researchers and industry leaders. 3. **Transfer learning**: Fine-tune imported models for your specific tasks, benefiting from knowledge transfer. -4. **Consistency across frameworks**: Ensure consistent performance when moving from one framework - to another. +4. **Consistency across frameworks**: Maintain consistent performance when moving between + frameworks. ## Understanding ONNX -ONNX (Open Neural Network Exchange) is an open format designed to represent machine learning models. -Key features include: +ONNX (Open Neural Network Exchange) is an open format designed to represent machine learning models +with these key features: -- **Framework agnostic**: ONNX provides a common format that works across various deep learning +- **Framework agnostic**: Provides a common format that works across various deep learning frameworks. -- **Comprehensive representation**: It captures both the model architecture and trained weights. -- **Wide support**: Many popular frameworks like PyTorch, TensorFlow, and scikit-learn support ONNX - export. +- **Comprehensive representation**: Captures both the model architecture and trained weights. +- **Wide support**: Compatible with popular frameworks like PyTorch, TensorFlow, and scikit-learn. -By using ONNX, you can easily move models between different frameworks and deployment environments. +This standardization allows seamless movement of models between different frameworks and deployment +environments. 
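In practice, obtaining an ONNX file is usually a single call in the source framework. Here is a minimal PyTorch sketch (the module, input shape, and file name are placeholders; the test models in this repository are generated the same way with `torch.onnx.export`):

```python
import torch
import torch.nn as nn


class TinyModel(nn.Module):
    """Stand-in for whatever network you want to bring into Burn."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)

    def forward(self, x):
        return self.linear(x)


model = TinyModel().eval()
dummy_input = torch.randn(1, 4)

# Export with opset 16, the minimum version burn-import accepts.
torch.onnx.export(model, dummy_input, "tiny_model.onnx", opset_version=16)
```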
## Burn's ONNX Support -Burn takes a unique approach to ONNX import, offering several advantages: +Burn's approach to ONNX import offers unique advantages: -1. **Native Rust code generation**: ONNX models are translated into Rust source code, allowing for - deep integration with Burn's ecosystem. -2. **Compile-time optimization**: The generated Rust code can be optimized by the Rust compiler, +1. **Native Rust code generation**: Translates ONNX models into Rust source code for deep + integration with Burn's ecosystem. +2. **Compile-time optimization**: Leverages the Rust compiler to optimize the generated code, potentially improving performance. -3. **No runtime dependency**: Unlike some solutions that require an ONNX runtime, Burn's approach - eliminates this dependency. -4. **Trainability**: Imported models can be further trained or fine-tuned using Burn. -5. **Portability**: The generated Rust code can be compiled for various targets, including - WebAssembly and embedded devices. -6. **Any Burn Backend**: The imported models can be used with any of Burn's backends. +3. **No runtime dependency**: Eliminates the need for an ONNX runtime, unlike many other solutions. +4. **Trainability**: Allows imported models to be further trained or fine-tuned using Burn. +5. **Portability**: Enables compilation for various targets, including WebAssembly and embedded + devices. +6. **Backend flexibility**: Works with any of Burn's supported backends. + +## ONNX Compatibility + +Burn requires ONNX models to use **opset version 16 or higher**. If your model uses an older +version, you'll need to upgrade it using the ONNX version converter. + +### Upgrading ONNX Models + +There are two simple ways to upgrade your ONNX models to the required opset version: + +Option 1: Use the provided utility script: + +``` +uv run --script https://raw.githubusercontent.com/tracel-ai/burn/refs/heads/main/crates/burn-import/onnx_opset_upgrade.py +``` + +Option 2: Use a custom Python script: + +```python +import onnx +from onnx import version_converter, shape_inference + +# Load your ONNX model +model = onnx.load('path/to/your/model.onnx') + +# Convert the model to opset version 16 +upgraded_model = version_converter.convert_version(model, 16) + +# Apply shape inference to the upgraded model +inferred_model = shape_inference.infer_shapes(upgraded_model) + +# Save the converted model +onnx.save(inferred_model, 'upgraded_model.onnx') +``` ## Step-by-Step Guide -Let's walk through the process of importing an ONNX model into a Burn project: +Follow these steps to import an ONNX model into your Burn project: ### Step 1: Update `build.rs` @@ -90,7 +107,7 @@ fn main() { } ``` -This script uses `ModelGen` to generate Rust code from your ONNX model during the build process. +This generates Rust code from your ONNX model during the build process. ### Step 2: Modify `mod.rs` @@ -102,11 +119,9 @@ pub mod my_model { } ``` -This makes the generated model code available in your project. - ### Step 3: Use the Imported Model -Now you can use the imported model in your Rust code: +Now you can use the imported model in your code: ```rust use burn::tensor; @@ -116,8 +131,7 @@ use model::my_model::Model; fn main() { let device = NdArrayDevice::default(); - // Create model instance and load weights from target dir default device. 
- // (see more load options below in "Loading and Using Models" section) + // Create model instance and load weights from target dir default device let model: Model<NdArray<f32>> = Model::default(); // Create input tensor (replace with your actual input) @@ -132,7 +146,7 @@ fn main() { ## Advanced Configuration -The `ModelGen` struct offers several configuration options: +The `ModelGen` struct provides several configuration options: ```rust ModelGen::new() @@ -144,72 +158,69 @@ ModelGen::new() .run_from_script(); ``` -- `record_type`: Specifies the format for storing weights (Bincode, NamedMpk, NamedMpkGz, or +- `record_type`: Defines the format for storing weights (Bincode, NamedMpk, NamedMpkGz, or PrettyJson). -- `half_precision`: Use half-precision (f16) for weights to reduce model size. -- `embed_states`: Embed model weights directly in the generated Rust code. Note: This requires - record type `Bincode`. +- `half_precision`: Reduces model size by using half-precision (f16) for weights. +- `embed_states`: Embeds model weights directly in the generated Rust code (requires record type + `Bincode`). ## Loading and Using Models -Depending on your configuration, you can load models in different ways: +Depending on your configuration, you can load models in several ways: ```rust -// Create a new model instance with device. Initializes weights randomly and lazily. -// You can load weights via `load_record` afterwards. +// Create a new model instance with device +// (initializes weights randomly and lazily; load weights via `load_record` afterward) let model = Model::<Backend>::new(&device); -// Load from a file (must specify weights file in the target output directory or copy it from there). -// File type should match the record type specified in `ModelGen`. +// Load from a file +// (file type should match the record type specified in `ModelGen`) let model = Model::<Backend>::from_file("path/to/weights", &device); // Load from embedded weights (if embed_states was true) let model = Model::<Backend>::from_embedded(&device); -// Load from the out director location and load to default device (useful for testing) +// Load from the output directory with default device (useful for testing) let model = Model::<Backend>::default(); ``` ## Troubleshooting -Here are some common issues and their solutions: +Common issues and solutions: -1. **Unsupported ONNX operator**: If you encounter an error about an unsupported operator, check the +1. **Unsupported ONNX operator**: Check the [list of supported ONNX operators](https://github.com/tracel-ai/burn/blob/main/crates/burn-import/SUPPORTED-ONNX-OPS.md). - You may need to simplify your model or wait for support to be added. + You may need to simplify your model or wait for support. -2. **Build errors**: Ensure that your `burn-import` version matches your Burn version. Also, check - that the ONNX file path in `build.rs` is correct. +2. **Build errors**: Ensure your `burn-import` version matches your Burn version and verify the ONNX + file path in `build.rs`. -3. **Runtime errors**: If you get errors when running your model, double-check that your input - tensors match the expected shape and data type of your model. +3. **Runtime errors**: Confirm that your input tensors match the expected shape and data type of + your model. -4. **Performance issues**: If your imported model is slower than expected, try using the - `half_precision` option to reduce memory usage, or experiment with different `record_type` - options. +4.
**Performance issues**: Try using the `half_precision` option to reduce memory usage or + experiment with different `record_type` options. -5. **Artifact Files**: You can view the generated Rust code and weights files in the `OUT_DIR` - directory specified in `build.rs` (usually `target/debug/build//out`). +5. **Viewing generated files**: Find the generated Rust code and weights in the `OUT_DIR` directory + (usually `target/debug/build//out`). ## Examples and Resources -For more detailed examples, check out: +For practical examples, check out: 1. [MNIST Inference Example](https://github.com/tracel-ai/burn/tree/main/examples/onnx-inference) 2. [SqueezeNet Image Classification](https://github.com/tracel-ai/models/tree/main/squeezenet-burn) -These examples demonstrate real-world usage of ONNX import in Burn projects. +These demonstrate real-world usage of ONNX import in Burn projects. ## Conclusion -Importing ONNX models into Burn opens up a world of possibilities, allowing you to leverage -pre-trained models from other frameworks while taking advantage of Burn's performance and Rust's -safety features. By following this guide, you should be able to seamlessly integrate ONNX models -into your Burn projects, whether for inference, fine-tuning, or as a starting point for further -development. +Importing ONNX models into Burn combines the vast ecosystem of pre-trained models with Burn's +performance and Rust's safety features. Following this guide, you can seamlessly integrate ONNX +models into your Burn projects for inference, fine-tuning, or further development. -Remember that the `burn-import` crate is actively developed, with ongoing work to support more ONNX -operators and improve performance. Stay tuned to the Burn repository for updates and new features! +The `burn-import` crate is actively developed, with ongoing work to support more ONNX operators and +improve performance. Stay tuned to the Burn repository for updates! 
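A practical aside to the Troubleshooting section above: before digging into the Burn side of a shape or dtype mismatch, it can help to run the ONNX file directly and inspect its declared inputs. A minimal sketch using `onnxruntime` (an assumption here; it is only used for this offline check and is not required by Burn, and the input name and shape are placeholders):

```python
import numpy as np
import onnxruntime

# Run the ONNX file directly to confirm its declared inputs and to get
# reference outputs to compare against the imported Burn model.
session = onnxruntime.InferenceSession("path/to/your/model.onnx")

for inp in session.get_inputs():
    print(f"input {inp.name}: shape={inp.shape} dtype={inp.type}")

# Replace the name and shape below with the values printed above.
dummy = np.zeros((1, 3, 224, 224), dtype=np.float32)
outputs = session.run(None, {"input": dummy})
print("reference output shape:", outputs[0].shape)
```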
--- diff --git a/crates/burn-import/onnx-tests/build.rs b/crates/burn-import/onnx-tests/build.rs index abf9ad71fd..f05d7c810a 100644 --- a/crates/burn-import/onnx-tests/build.rs +++ b/crates/burn-import/onnx-tests/build.rs @@ -13,8 +13,7 @@ fn main() { .input("tests/avg_pool2d/avg_pool2d.onnx") .input("tests/batch_norm/batch_norm.onnx") .input("tests/cast/cast.onnx") - .input("tests/clip/clip_opset16.onnx") - .input("tests/clip/clip_opset7.onnx") + .input("tests/clip/clip.onnx") .input("tests/concat/concat.onnx") .input("tests/constant/constant_f32.onnx") .input("tests/constant/constant_f64.onnx") @@ -31,8 +30,7 @@ fn main() { .input("tests/cos/cos.onnx") .input("tests/cosh/cosh.onnx") .input("tests/div/div.onnx") - .input("tests/dropout/dropout_opset16.onnx") - .input("tests/dropout/dropout_opset7.onnx") + .input("tests/dropout/dropout.onnx") .input("tests/equal/equal.onnx") .input("tests/erf/erf.onnx") .input("tests/exp/exp.onnx") @@ -97,8 +95,7 @@ fn main() { .input("tests/reduce_mean/reduce_mean.onnx") .input("tests/reduce_min/reduce_min.onnx") .input("tests/reduce_prod/reduce_prod.onnx") - .input("tests/reduce_sum/reduce_sum_opset11.onnx") - .input("tests/reduce_sum/reduce_sum_opset13.onnx") + .input("tests/reduce_sum/reduce_sum.onnx") .input("tests/relu/relu.onnx") .input("tests/reshape/reshape.onnx") .input("tests/resize/resize_with_sizes.onnx") @@ -116,8 +113,7 @@ fn main() { .input("tests/softmax/softmax.onnx") .input("tests/sqrt/sqrt.onnx") .input("tests/squeeze/squeeze_multiple.onnx") - .input("tests/squeeze/squeeze_opset13.onnx") - .input("tests/squeeze/squeeze_opset16.onnx") + .input("tests/squeeze/squeeze.onnx") .input("tests/sub/sub.onnx") .input("tests/sub/sub_int.onnx") .input("tests/sum/sum.onnx") @@ -125,13 +121,12 @@ fn main() { .input("tests/tan/tan.onnx") .input("tests/tanh/tanh.onnx") .input("tests/tile/tile.onnx") - .input("tests/top_k/top_k_opset_1.onnx") + .input("tests/topk/topk.onnx") .input("tests/trilu/trilu_upper.onnx") .input("tests/trilu/trilu_lower.onnx") .input("tests/transpose/transpose.onnx") - .input("tests/unsqueeze/unsqueeze.onnx") - .input("tests/unsqueeze/unsqueeze_opset11.onnx") - .input("tests/unsqueeze/unsqueeze_opset16.onnx") + .input("tests/unsqueeze/unsqueeze_runtime_axes.onnx") + .input("tests/unsqueeze/unsqueeze_like.onnx") .input("tests/split/split.onnx") .out_dir("model/") .run_from_script(); diff --git a/crates/burn-import/onnx-tests/tests/clip/clip_opset16.onnx b/crates/burn-import/onnx-tests/tests/clip/clip.onnx similarity index 86% rename from crates/burn-import/onnx-tests/tests/clip/clip_opset16.onnx rename to crates/burn-import/onnx-tests/tests/clip/clip.onnx index 9be70c1926..c35dc3dfba 100644 Binary files a/crates/burn-import/onnx-tests/tests/clip/clip_opset16.onnx and b/crates/burn-import/onnx-tests/tests/clip/clip.onnx differ diff --git a/crates/burn-import/onnx-tests/tests/clip/clip_opset16.py b/crates/burn-import/onnx-tests/tests/clip/clip.py similarity index 97% rename from crates/burn-import/onnx-tests/tests/clip/clip_opset16.py rename to crates/burn-import/onnx-tests/tests/clip/clip.py index 144dc009df..027bbc907d 100755 --- a/crates/burn-import/onnx-tests/tests/clip/clip_opset16.py +++ b/crates/burn-import/onnx-tests/tests/clip/clip.py @@ -29,7 +29,7 @@ def main(): model.eval() device = torch.device("cpu") - file_name = "clip_opset16.onnx" + file_name = "clip.onnx" test_input = torch.rand(6, device=device) torch.onnx.export(model, test_input, file_name, verbose=False, opset_version=16) diff --git 
a/crates/burn-import/onnx-tests/tests/clip/clip_opset7.onnx b/crates/burn-import/onnx-tests/tests/clip/clip_opset7.onnx deleted file mode 100644 index 54d65adc6b..0000000000 Binary files a/crates/burn-import/onnx-tests/tests/clip/clip_opset7.onnx and /dev/null differ diff --git a/crates/burn-import/onnx-tests/tests/clip/clip_opset7.py b/crates/burn-import/onnx-tests/tests/clip/clip_opset7.py deleted file mode 100755 index bc56a96e52..0000000000 --- a/crates/burn-import/onnx-tests/tests/clip/clip_opset7.py +++ /dev/null @@ -1,52 +0,0 @@ -#!/usr/bin/env python3 - -# used to generate model: clip_opset7.onnx - -import torch -import torch.nn as nn - -# TODO test Int - - -class Model(nn.Module): - def __init__(self): - super(Model, self).__init__() - - def forward(self, x): - x1 = x.clamp(min=0.3) - x2 = x.clamp(min=0.5, max=0.7) - x3 = x.clamp(max=0.8) - return x1, x2, x3 - - -def main(): - - # Set seed for reproducibility - torch.manual_seed(42) - - torch.set_printoptions(precision=8) - - # Export to onnx - model = Model() - model.eval() - device = torch.device("cpu") - - file_name = "clip_opset7.onnx" - test_input = torch.rand(6, device=device) - torch.onnx.export(model, test_input, file_name, - verbose=False, opset_version=7) - - print("Finished exporting model to {}".format(file_name)) - - # Output some test data for use in the test - print("Test input data: {}".format(test_input)) - print("Test input data shape: {}".format(test_input.shape)) - x1, x2, x3 = model.forward(test_input) - print("Test output data shape: {}, {}, {}".format( - x1.shape, x2.shape, x3.shape)) - - print("Test output: {}, {}, {}".format(x1, x2, x3)) - - -if __name__ == '__main__': - main() diff --git a/crates/burn-import/onnx-tests/tests/dropout/dropout_opset16.onnx b/crates/burn-import/onnx-tests/tests/dropout/dropout.onnx similarity index 84% rename from crates/burn-import/onnx-tests/tests/dropout/dropout_opset16.onnx rename to crates/burn-import/onnx-tests/tests/dropout/dropout.onnx index 9b1457a5a6..740a81d906 100644 Binary files a/crates/burn-import/onnx-tests/tests/dropout/dropout_opset16.onnx and b/crates/burn-import/onnx-tests/tests/dropout/dropout.onnx differ diff --git a/crates/burn-import/onnx-tests/tests/dropout/dropout.py b/crates/burn-import/onnx-tests/tests/dropout/dropout.py index a624a3cb0c..9213917dfb 100755 --- a/crates/burn-import/onnx-tests/tests/dropout/dropout.py +++ b/crates/burn-import/onnx-tests/tests/dropout/dropout.py @@ -23,7 +23,7 @@ def main(): model.eval() device = torch.device("cpu") - file_name = "dropout_opset16.onnx" + file_name = "dropout.onnx" test_input = torch.ones(2, 4, 10, 15, device=device) torch.onnx.export(model, test_input, file_name, training=torch.onnx.TrainingMode.TRAINING, diff --git a/crates/burn-import/onnx-tests/tests/dropout/dropout_opset7.onnx b/crates/burn-import/onnx-tests/tests/dropout/dropout_opset7.onnx deleted file mode 100644 index 11eadda126..0000000000 Binary files a/crates/burn-import/onnx-tests/tests/dropout/dropout_opset7.onnx and /dev/null differ diff --git a/crates/burn-import/onnx-tests/tests/reduce_sum/reduce_sum_opset13.onnx b/crates/burn-import/onnx-tests/tests/reduce_sum/reduce_sum.onnx similarity index 90% rename from crates/burn-import/onnx-tests/tests/reduce_sum/reduce_sum_opset13.onnx rename to crates/burn-import/onnx-tests/tests/reduce_sum/reduce_sum.onnx index 9b40090d6c..99fec06f87 100644 Binary files a/crates/burn-import/onnx-tests/tests/reduce_sum/reduce_sum_opset13.onnx and 
b/crates/burn-import/onnx-tests/tests/reduce_sum/reduce_sum.onnx differ diff --git a/crates/burn-import/onnx-tests/tests/reduce_sum/reduce_sum.py b/crates/burn-import/onnx-tests/tests/reduce_sum/reduce_sum.py index 7bb83e6b90..4484f4910b 100755 --- a/crates/burn-import/onnx-tests/tests/reduce_sum/reduce_sum.py +++ b/crates/burn-import/onnx-tests/tests/reduce_sum/reduce_sum.py @@ -31,8 +31,7 @@ def main(): device = torch.device("cpu") test_input = torch.tensor([[[[1.0, 4.0, 9.0, 25.0]]]], device=device) - torch.onnx.export(model, test_input, "reduce_sum_opset11.onnx", verbose=False, opset_version=11) - torch.onnx.export(model, test_input, "reduce_sum_opset13.onnx", verbose=False, opset_version=13) + torch.onnx.export(model, test_input, "reduce_sum.onnx", verbose=False, opset_version=16) print("Finished exporting model") diff --git a/crates/burn-import/onnx-tests/tests/reduce_sum/reduce_sum_opset11.onnx b/crates/burn-import/onnx-tests/tests/reduce_sum/reduce_sum_opset11.onnx deleted file mode 100644 index cb9f5773b8..0000000000 Binary files a/crates/burn-import/onnx-tests/tests/reduce_sum/reduce_sum_opset11.onnx and /dev/null differ diff --git a/crates/burn-import/onnx-tests/tests/squeeze/squeeze.onnx b/crates/burn-import/onnx-tests/tests/squeeze/squeeze.onnx index 04062004bf..028ab77b99 100644 Binary files a/crates/burn-import/onnx-tests/tests/squeeze/squeeze.onnx and b/crates/burn-import/onnx-tests/tests/squeeze/squeeze.onnx differ diff --git a/crates/burn-import/onnx-tests/tests/squeeze/squeeze.py b/crates/burn-import/onnx-tests/tests/squeeze/squeeze.py index 8b4b7ebcf6..fefac6219a 100644 --- a/crates/burn-import/onnx-tests/tests/squeeze/squeeze.py +++ b/crates/burn-import/onnx-tests/tests/squeeze/squeeze.py @@ -32,10 +32,9 @@ def main(): test_input = torch.randn(3, 4, 1, 5, device=device) # Export to ONNX - torch.onnx.export(model, test_input, "squeeze_opset16.onnx", verbose=False, opset_version=16) - torch.onnx.export(model, test_input, "squeeze_opset13.onnx", verbose=False, opset_version=13) + torch.onnx.export(model, test_input, "squeeze.onnx", verbose=False, opset_version=16) - print("Finished exporting model to 16 and 13") + print("Finished exporting model") # Output some test data for use in the test output = model(test_input) @@ -50,7 +49,7 @@ def main(): squeeze = helper.make_node(op_type="Squeeze", inputs=["input", "axes"], outputs=["output"], name="SqueezeOp") axes = helper.make_tensor("axes", TensorProto.INT64, dims=[2], vals=[2, 4]) graph = helper.make_graph([squeeze], "SqueezeMultiple", [test_input_ms], [output], [axes]) - opset = helper.make_opsetid("", 13) + opset = helper.make_opsetid("", 16) m = helper.make_model(graph, opset_imports=[opset]) onnx.checker.check_model(m, full_check=True) diff --git a/crates/burn-import/onnx-tests/tests/squeeze/squeeze_multiple.onnx b/crates/burn-import/onnx-tests/tests/squeeze/squeeze_multiple.onnx index 46760e4469..cf214b9a81 100644 Binary files a/crates/burn-import/onnx-tests/tests/squeeze/squeeze_multiple.onnx and b/crates/burn-import/onnx-tests/tests/squeeze/squeeze_multiple.onnx differ diff --git a/crates/burn-import/onnx-tests/tests/squeeze/squeeze_opset13.onnx b/crates/burn-import/onnx-tests/tests/squeeze/squeeze_opset13.onnx deleted file mode 100644 index 595ff74489..0000000000 Binary files a/crates/burn-import/onnx-tests/tests/squeeze/squeeze_opset13.onnx and /dev/null differ diff --git a/crates/burn-import/onnx-tests/tests/squeeze/squeeze_opset16.onnx b/crates/burn-import/onnx-tests/tests/squeeze/squeeze_opset16.onnx deleted 
file mode 100644 index 04062004bf..0000000000 Binary files a/crates/burn-import/onnx-tests/tests/squeeze/squeeze_opset16.onnx and /dev/null differ diff --git a/crates/burn-import/onnx-tests/tests/test_onnx.rs b/crates/burn-import/onnx-tests/tests/test_onnx.rs index 3b1cfd3c36..f4788d51fa 100644 --- a/crates/burn-import/onnx-tests/tests/test_onnx.rs +++ b/crates/burn-import/onnx-tests/tests/test_onnx.rs @@ -24,8 +24,7 @@ include_models!( avg_pool2d, batch_norm, cast, - clip_opset16, - clip_opset7, + clip, concat, constant_f32, constant_f64, @@ -42,54 +41,53 @@ include_models!( cos, cosh, div, - dropout_opset16, - dropout_opset7, + dropout, equal, erf, exp, expand, - expand_tensor, expand_shape, + expand_tensor, flatten, flatten_2d, floor, gather_1d_idx, gather_2d_idx, + gather_elements, gather_scalar, gather_scalar_out, gather_shape, - gather_elements, gelu, gemm, - gemm_non_unit_alpha_beta, gemm_no_c, + gemm_non_unit_alpha_beta, global_avr_pool, graph_multiple_output_tracking, greater, - greater_scalar, greater_or_equal, greater_or_equal_scalar, + greater_scalar, hard_sigmoid, layer_norm, leaky_relu, less, - less_scalar, less_or_equal, less_or_equal_scalar, + less_scalar, linear, log, log_softmax, mask_where, + mask_where_all_scalar, mask_where_broadcast, mask_where_scalar_x, mask_where_scalar_y, - mask_where_all_scalar, matmul, max, maxpool1d, maxpool2d, - min, mean, + min, mul, neg, not, @@ -108,16 +106,15 @@ include_models!( reduce_mean, reduce_min, reduce_prod, - reduce_sum_opset11, - reduce_sum_opset13, + reduce_sum, relu, reshape, - resize_with_sizes, resize_1d_linear_scale, resize_1d_nearest_scale, resize_2d_bicubic_scale, resize_2d_bilinear_scale, resize_2d_nearest_scale, + resize_with_sizes, shape, sigmoid, sign, @@ -125,10 +122,10 @@ include_models!( sinh, slice, softmax, + split, sqrt, + squeeze, squeeze_multiple, - squeeze_opset13, - squeeze_opset16, sub, sub_int, sum, @@ -136,14 +133,12 @@ include_models!( tan, tanh, tile, - top_k_opset_1, - trilu_upper, - trilu_lower, + topk, transpose, - unsqueeze, - unsqueeze_opset11, - unsqueeze_opset16, - split + trilu_lower, + trilu_upper, + unsqueeze_like, + unsqueeze_runtime_axes ); #[cfg(test)] @@ -152,9 +147,7 @@ mod tests { use super::*; - use burn::tensor::{ - Bool, Int, Shape, Tensor, TensorData, Tolerance, cast::ToElement, ops::FloatElem, - }; + use burn::tensor::{Bool, Int, Shape, Tensor, TensorData, Tolerance, ops::FloatElem}; use float_cmp::ApproxEq; @@ -434,27 +427,8 @@ mod tests { } #[test] - fn dropout_opset16() { - let model: dropout_opset16::Model = dropout_opset16::Model::default(); - - // Run the model with ones as input for easier testing - let input = Tensor::::ones([2, 4, 10, 15], &Default::default()); - - let output = model.forward(input); - - let expected_shape = Shape::from([2, 4, 10, 15]); - assert_eq!(output.shape(), expected_shape); - - let output_sum = output.sum().into_scalar(); - - let expected_sum = 1200.0; // from pytorch - - assert!(expected_sum.approx_eq(output_sum, (1.0e-4, 2))); - } - - #[test] - fn dropout_opset7() { - let model: dropout_opset7::Model = dropout_opset7::Model::default(); + fn dropout() { + let model: dropout::Model = dropout::Model::default(); // Run the model with ones as input for easier testing let input = Tensor::::ones([2, 4, 10, 15], &Default::default()); @@ -965,25 +939,9 @@ mod tests { } #[test] - fn reduce_sum_opset11() { - let device = Default::default(); - let model: reduce_sum_opset11::Model = reduce_sum_opset11::Model::new(&device); - - // Run the model - let input = 
Tensor::::from_floats([[[[1.0, 4.0, 9.0, 25.0]]]], &device); - let (output_scalar, output_tensor, output_value) = model.forward(input.clone()); - let expected_scalar = TensorData::from([39f32]); - let expected = TensorData::from([[[[39f32]]]]); - - output_scalar.to_data().assert_eq(&expected_scalar, true); - output_tensor.to_data().assert_eq(&input.to_data(), true); - output_value.to_data().assert_eq(&expected, true); - } - - #[test] - fn reduce_sum_opset13() { + fn reduce_sum() { let device = Default::default(); - let model: reduce_sum_opset13::Model = reduce_sum_opset13::Model::new(&device); + let model: reduce_sum::Model = reduce_sum::Model::new(&device); // Run the model let input = Tensor::::from_floats([[[[1.0, 4.0, 9.0, 25.0]]]], &device); @@ -1462,10 +1420,10 @@ mod tests { } #[test] - fn clip_opset16() { + fn clip() { // Initialize the model without weights (because the exported file does not contain them) let device = Default::default(); - let model: clip_opset16::Model = clip_opset16::Model::new(&device); + let model: clip::Model = clip::Model::new(&device); // Run the model let input = Tensor::::from_floats( @@ -1496,41 +1454,6 @@ mod tests { output3.to_data().assert_eq(&expected3, true); } - #[test] - fn clip_opset7() { - // Initialize the model without weights (because the exported file does not contain them) - let device = Default::default(); - let model: clip_opset7::Model = clip_opset7::Model::new(&device); - - // Run the model - let input = Tensor::::from_floats( - [ - 0.88226926f32, - 0.91500396, - 0.38286376, - 0.95930564, - 0.390_448_2, - 0.60089535, - ], - &device, - ); - let (output1, output2, output3) = model.forward(input); - let expected1 = TensorData::from([ - 0.88226926f32, - 0.91500396, - 0.38286376, - 0.95930564, - 0.390_448_2, - 0.60089535, - ]); - let expected2 = TensorData::from([0.7f32, 0.7, 0.5, 0.7, 0.5, 0.60089535]); - let expected3 = TensorData::from([0.8f32, 0.8, 0.38286376, 0.8, 0.390_448_2, 0.60089535]); - - output1.to_data().assert_eq(&expected1, true); - output2.to_data().assert_eq(&expected2, true); - output3.to_data().assert_eq(&expected3, true); - } - #[test] fn linear() { let device = Default::default(); @@ -2084,9 +2007,10 @@ mod tests { } #[test] - fn unsqueeze() { + fn unsqueeze_runtime_axes() { let device = Default::default(); - let model: unsqueeze::Model = unsqueeze::Model::new(&device); + let model: unsqueeze_runtime_axes::Model = + unsqueeze_runtime_axes::Model::new(&device); let input_shape = Shape::from([3, 4, 5]); let expected_shape = Shape::from([1, 3, 1, 4, 5, 1]); let input = Tensor::ones(input_shape, &device); @@ -2100,21 +2024,9 @@ mod tests { } #[test] - fn unsqueeze_opset16() { - let device = Default::default(); - let model = unsqueeze_opset16::Model::::new(&device); - let input_shape = Shape::from([3, 4, 5]); - let expected_shape = Shape::from([3, 4, 5, 1]); - let input = Tensor::ones(input_shape, &device); - let output = model.forward(input, 1.0); - assert_eq!(expected_shape, output.0.shape()); - assert_eq!(Shape::from([1]), output.1.shape()); - } - - #[test] - fn unsqueeze_opset11() { + fn unsqueeze_like() { let device = Default::default(); - let model = unsqueeze_opset11::Model::::new(&device); + let model = unsqueeze_like::Model::::new(&device); let input_shape = Shape::from([3, 4, 5]); let expected_shape = Shape::from([3, 4, 5, 1]); let input = Tensor::ones(input_shape, &device); @@ -2269,20 +2181,9 @@ mod tests { } #[test] - fn squeeze_opset16() { + fn squeeze() { let device = Default::default(); - let model = 
squeeze_opset16::Model::::new(&device); - let input_shape = Shape::from([3, 4, 1, 5]); - let expected_shape = Shape::from([3, 4, 5]); - let input = Tensor::ones(input_shape, &device); - let output = model.forward(input); - assert_eq!(expected_shape, output.shape()); - } - - #[test] - fn squeeze_opset13() { - let device = Default::default(); - let model = squeeze_opset13::Model::::new(&device); + let model = squeeze::Model::::new(&device); let input_shape = Shape::from([3, 4, 1, 5]); let expected_shape = Shape::from([3, 4, 5]); let input = Tensor::ones(input_shape, &device); @@ -2440,22 +2341,29 @@ mod tests { } #[test] - fn top_k_opset_1() { + fn topk() { // Initialize the model let device = Default::default(); - let model = top_k_opset_1::Model::::new(&device); + let model = topk::Model::::new(&device); // Run the model let input = Tensor::::from_floats( - [[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]], + [ + [0.33669037, 0.12880941, 0.23446237, 0.23033303, -1.12285638], + [-0.18632829, 2.20820141, -0.63799703, 0.46165723, 0.26735088], + [0.53490466, 0.80935723, 1.11029029, -1.68979895, -0.98895991], + ], &device, ); let (values_tensor, indices_tensor) = model.forward(input); // expected results - let expected_values_tensor = - TensorData::from([[4.0, 3.0, 2.to_f32()], [4.0, 3.0, 2.to_f32()]]); - let expected_indices_tensor = TensorData::from([[3i64, 2, 1], [3, 2, 1]]); + let expected_values_tensor = TensorData::from([ + [0.33669037f32, 0.23446237], + [2.208_201_4, 0.46165723], + [1.110_290_3, 0.809_357_2], + ]); + let expected_indices_tensor = TensorData::from([[0i64, 2], [1, 3], [2, 1]]); values_tensor .to_data() diff --git a/crates/burn-import/onnx-tests/tests/top_k/top_k.py b/crates/burn-import/onnx-tests/tests/top_k/top_k.py deleted file mode 100644 index dc1579b452..0000000000 --- a/crates/burn-import/onnx-tests/tests/top_k/top_k.py +++ /dev/null @@ -1,76 +0,0 @@ -import numpy as np -import onnx -from onnx import helper, TensorProto - -# Define the input tensor -X = np.array([[0, 1, 2, 3], - [4, 5, 6, 7], - [8, 9, 10, 11]], dtype=np.float32) - -# Define the value of K -k = 3 -K = np.array([k], dtype=np.int64) -axis = 1 -new_dims = [X.shape[0], k] - -def create_model(op_set_version: int): - input_tensors = [helper.make_tensor_value_info('X', TensorProto.FLOAT, X.shape)] - - output_tensors = [ - helper.make_tensor_value_info('Values', TensorProto.FLOAT, new_dims), - helper.make_tensor_value_info('Indices', TensorProto.INT32, new_dims) - ] - - # Create the TopK node - if op_set_version > 1: - node = helper.make_node( - 'TopK', - inputs=['X', 'K'], - outputs=['Values', 'Indices'], - axis=axis, # Axis along which to find the top K elements - ) - input_tensors.append(helper.make_tensor_value_info('K', TensorProto.INT32, K.shape)) - else: - node = helper.make_node( - 'TopK', - inputs=['X'], - outputs=['Values', 'Indices'], - axis=axis, # Axis along which to find the top K elements - k=k - ) - - # Create the graph - graph = helper.make_graph( - nodes = [node], - name = 'TopKGraph', - inputs = input_tensors, - outputs = output_tensors, - # Uncomment when initializers are supported. Currently we can't test opset 10/11 since the code will require a k value to be initialized for testing. 
- #initializer = [ - # helper.make_tensor('X', TensorProto.FLOAT, X.shape, X), - # helper.make_tensor('K', TensorProto.INT64, [1], [k]), - #] - ) - - # Create the model - model = helper.make_model( - graph, - ir_version=8, - opset_imports=[onnx.helper.make_operatorsetid("", op_set_version)] - ) - # Check the model - onnx.checker.check_model(model) - - # Save the model to a file - onnx.save(model, f'top_k_opset_{op_set_version}.onnx') - print(f"Model saved to top_k_opset_{op_set_version}.onnx") - -def main(): - # Uncomment when initializers are supported. - # for op_set_version in [1, 10, 11]: - for op_set_version in [1]: - create_model(op_set_version) - - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/crates/burn-import/onnx-tests/tests/top_k/top_k_opset_1.onnx b/crates/burn-import/onnx-tests/tests/top_k/top_k_opset_1.onnx deleted file mode 100644 index 4eb08a05c9..0000000000 Binary files a/crates/burn-import/onnx-tests/tests/top_k/top_k_opset_1.onnx and /dev/null differ diff --git a/crates/burn-import/onnx-tests/tests/topk/topk.onnx b/crates/burn-import/onnx-tests/tests/topk/topk.onnx new file mode 100644 index 0000000000..4753dc00a0 Binary files /dev/null and b/crates/burn-import/onnx-tests/tests/topk/topk.onnx differ diff --git a/crates/burn-import/onnx-tests/tests/topk/topk.py b/crates/burn-import/onnx-tests/tests/topk/topk.py new file mode 100644 index 0000000000..7822b6e13a --- /dev/null +++ b/crates/burn-import/onnx-tests/tests/topk/topk.py @@ -0,0 +1,61 @@ +#!/usr/bin/env python3 + +import torch +import torch.nn as nn + + +class TopKModel(nn.Module): + def __init__(self, k=1, dim=-1, largest=True, sorted=True): + super(TopKModel, self).__init__() + self.k = k + self.dim = dim + self.largest = largest + self.sorted = sorted + + def forward(self, x): + values, indices = torch.topk( + x, + k=self.k, + dim=self.dim, + largest=self.largest, + sorted=self.sorted + ) + return values, indices + + +def main(): + # Set seed for reproducibility + torch.manual_seed(42) + + # Set print options for better precision output + torch.set_printoptions(precision=8) + + # Export TopK Model + k = 2 # Number of top elements to return + dim = 1 # Dimension along which to find top k elements + largest = True # Whether to return largest or smallest elements + sorted = True # Whether to return the elements in sorted order + + model = TopKModel(k=k, dim=dim, largest=largest, sorted=sorted) + model.eval() + device = torch.device("cpu") + + # Generate test input + file_name = "topk.onnx" + test_input = torch.randn(3, 5, device=device) # 3 sequences of 5 elements + torch.onnx.export(model, test_input, file_name, + verbose=False, opset_version=16) + + print("Finished exporting model to {}".format(file_name)) + + # Output some test data for use in the test + print("Test input data: {}".format(test_input)) + print("Test input data shape: {}".format(test_input.shape)) + values, indices = model.forward(test_input) + print("Test output values shape: {}".format(values.shape)) + print("Test output values: {}".format(values)) + print("Test output indices shape: {}".format(indices.shape)) + print("Test output indices: {}".format(indices)) + +if __name__ == '__main__': + main() diff --git a/crates/burn-import/onnx-tests/tests/unsqueeze/unsqueeze_opset16.onnx b/crates/burn-import/onnx-tests/tests/unsqueeze/unsqueeze_like.onnx similarity index 95% rename from crates/burn-import/onnx-tests/tests/unsqueeze/unsqueeze_opset16.onnx rename to 
crates/burn-import/onnx-tests/tests/unsqueeze/unsqueeze_like.onnx index 6081df70eb..30c101d16b 100644 Binary files a/crates/burn-import/onnx-tests/tests/unsqueeze/unsqueeze_opset16.onnx and b/crates/burn-import/onnx-tests/tests/unsqueeze/unsqueeze_like.onnx differ diff --git a/crates/burn-import/onnx-tests/tests/unsqueeze/unsqueeze_torch.py b/crates/burn-import/onnx-tests/tests/unsqueeze/unsqueeze_like.py similarity index 74% rename from crates/burn-import/onnx-tests/tests/unsqueeze/unsqueeze_torch.py rename to crates/burn-import/onnx-tests/tests/unsqueeze/unsqueeze_like.py index a00dccb00e..da759f5167 100755 --- a/crates/burn-import/onnx-tests/tests/unsqueeze/unsqueeze_torch.py +++ b/crates/burn-import/onnx-tests/tests/unsqueeze/unsqueeze_like.py @@ -33,18 +33,14 @@ def main(): output = model.forward(*test_input) - torch.onnx.export(model, test_input, "unsqueeze_opset16.onnx", verbose=False, opset_version=16) - torch.onnx.export(model, test_input, "unsqueeze_opset11.onnx", verbose=False, opset_version=11) + torch.onnx.export(model, test_input, "unsqueeze_like.onnx", verbose=False, opset_version=16) print(f"Finished exporting model") # Output some test data for use in the test - print(f"Test input data of ones: {test_input}") print(f"Test input data shape of ones: {test_input[0].shape}") - # output = model.forward(test_input) print(f"Test output data shape: {output[0].shape}") - print(f"Test output: {output}") if __name__ == "__main__": diff --git a/crates/burn-import/onnx-tests/tests/unsqueeze/unsqueeze_onnx.py b/crates/burn-import/onnx-tests/tests/unsqueeze/unsqueeze_onnx.py deleted file mode 100644 index 434b8687eb..0000000000 --- a/crates/burn-import/onnx-tests/tests/unsqueeze/unsqueeze_onnx.py +++ /dev/null @@ -1,463 +0,0 @@ -""" -This module is used to generate Onnx models for testing operator support, and to validate the generated models to ensure they produce spec compliant results. -It's expected that the generated models nodes will be chained Ops of the same type or constants. -Inputs are stored to validate the model results, to debug issues with generating test models, and to potentially generate deserializable test data for the model -""" -from pathlib import Path -from typing import List, Optional, Tuple, NewType, TypeAlias -from numpy.typing import ArrayLike, NDArray -import onnx -import onnxruntime -import numpy as np -from dataclasses import dataclass, field, InitVar - -# TODO: need to come up with some examples of valid inputs -TensorMap: TypeAlias = ( - dict[str, ArrayLike] | List[ArrayLike] | str | Tuple[str, ArrayLike] -) -"""TypeAlias for the inputs to OnnxGraphBuilder. if it's a dictionary, the keys are the names of the inputs and the values are the input data. -If it's a list, for non tuple elements, names are autogenerated. For tuple elements, the first element is the name of the input and the second element is the value of the input. - -Note: - -The ">" char can be used to indicate the output of the previous node if the previous node has a single output. Example: - -[">", np.array([1,2,3])] - -""" - -NodeOutput = NewType("NodeOutput", str) - - -def validate_sequence(arr: List | Tuple) -> NDArray: - """Function to validate that all elements in a sequence are of the same type. - - - Args: - arr (List | Tuple): Sequence to validate - - Raises: - ValueError: raised if the elements in the sequence are not of the same type. 
- """ - el_type = type(arr[0]) - for el in arr: - if type(el) != el_type: - raise ValueError( - f"Expected all elements to be of type {el_type} but got {type(el)}" - ) - return np.array(arr) - - -var_counter = 0 -op_counter = 0 - - -def validate_input( - input_data: TensorMap, - value_names: set[NodeOutput], - out_names: set[str], -) -> dict[str, ArrayLike | NodeOutput]: - """helper function to validate inputs to a node - - Args: - input_data (dict[str, ArrayLike | str]): The inputs to the node that need to be validated - value_names (set[NodeOutput]): the list of names of outputs from nodes already in the graph - out_names (set[str]): set of names of onnx values produced by the Node, used to make sure that the input names don't conflict with the output names. - - Raises: - ValueError: _description_ - TypeError: _description_ - - Returns: - dict[str, ArrayLike | NodeOutput]: _description_ - """ - - def kv_check(value_names, out_names, res, var_name, var_value): - if var_name in out_names: - raise ValueError( - f"Tensor Input {var_name} cannot have the same name as a tensors output" - ) - if isinstance(var_value, str): - if var_value not in value_names: - raise TypeError( - f"NodeOutput {var_value} not found in outputs. Please provide the output of a previous node as an input." - ) - res[var_name] = NodeOutput(var_value) - else: - res[var_name] = var_value - - def next_var(): - global var_counter - var_counter += 1 - return f"var_{var_counter}" - - res: dict[str, ArrayLike | NodeOutput] = {} - match input_data: - case dict(): - for k, v in input_data.items(): - kv_check(value_names, out_names, res, k, v) - case list(): - for item in input_data: - if isinstance(item, str): - if item == ">": - res[item] = NodeOutput(item) - if item not in value_names: - raise TypeError( - f"NodeOutput {item} not found in outputs. Please provide the output of a previous node as an input." - ) - elif isinstance(item, tuple): - # until I come up with something better - if isinstance(item[0], str): - ( - kv_check( - value_names, - out_names, - res, - next_var(), - item[1], - ), - ) - - else: # it's an arraylike which will be validated later - res[next_var()] = item - elif isinstance(item, np.ndarray | list) or np.isscalar(item): - res[f"var_{var_counter}"] = item - - return res - - -def _get_tensor(name: str, arr: np.ndarray): - tensor_type = onnx.helper.np_dtype_to_tensor_dtype(arr.dtype) - return onnx.helper.make_tensor_value_info(name, tensor_type, arr.shape) - - -def _get_scalar(name: str, scalar_type: int): - return onnx.helper.make_value_info(name, scalar_type) # type: ignore - - -def make_onnx_types(name: str, input_data: ArrayLike) -> onnx.ValueInfoProto: - """Function to map inputs to OnnxOpData to onnx types - - Args: - name (str): The name of the input - v (ArrayLike): The input data - - Raises: - ValueError: If the input data is not a supported type (np.ndarray, list, tuple, int, float, bool) then an error is raised. - - Returns: - onnx.TensorProto | onnx.ValueInfoProto: returns a tensor proto or value info proto based on the input data. 
- """ - match input_data: - case np.ndarray(): - return _get_tensor(name, input_data) # type: ignore - case list() | tuple(): - return _get_tensor(name, validate_sequence(input_data)) # type: ignore - case int(): - return _get_scalar(name, onnx.ValueInfoProto.INT64) - case float(): - return _get_scalar(name, onnx.ValueInfoProto.FLOAT) - case bool(): - return _get_scalar(name, onnx.ValueInfoProto.BOOL) - case _: - raise ValueError(f"Unsupported type: {type(input_data)}") - - -@dataclass -class OnnxConst: - name: str - value: InitVar[ArrayLike] - # __value: NDArray = field(init=False, default_factory=np.array) # type: ignore - __tensor: onnx.TensorProto = field(init=False, default_factory=onnx.TensorProto) # type: ignore - - def __post_init__(self, value): - if np.isscalar(value) or isinstance(value, (list, tuple)): - value = np.array(value) - self.__tensor = onnx.helper.make_tensor( - name=self.name, - data_type=onnx.helper.np_dtype_to_tensor_dtype(value.dtype), - dims=value.shape, - vals=value.flatten(), - ) - - def to_onnx(self): - return onnx.helper.make_node( - "Constant", - inputs=[], - outputs=[self.name], - value=self.__tensor, - ) - - def to_ndarray(self): - return np.frombuffer( - self.__tensor.raw_data, - onnx.helper.tensor_dtype_to_np_dtype(self.__tensor.data_type), - ).reshape(self.__tensor.dims) - - -@dataclass -class OnnxOpData: - """helper for generating and validating nodes for testing operator support. - - Attributes: - name (str): The name of the operator. Must match the name of the operator in onnx. - inputs (dict[str, ArrayLike]): The inputs to the operator - output (dict[str, ArrayLike]): The expected output of the operator - """ - - op_name: str - inputs: dict[str, ArrayLike | NodeOutput] - output: dict[str, ArrayLike] - count: int = field(init=False) - - def __post_init__(self): - global op_counter - op_counter += 1 - self.count = op_counter - - @property - def input_names(self) -> List[str]: - return list(self.inputs.keys()) - - @property - def output_names(self) -> List[str]: - return list(self.output.keys()) - - @property - def output_vals(self) -> List[onnx.ValueInfoProto]: - return [make_onnx_types(k, v) for k, v in self.output.items()] - - @property - def input_vals(self) -> List[onnx.TensorProto | onnx.ValueInfoProto]: - return [ - make_onnx_types(k, v) - for k, v in self.inputs.items() - if type(v) != NodeOutput - ] - - def to_onnx(self): - return onnx.helper.make_node( - self.op_name, - inputs=self.input_names, - outputs=self.output_names, - name=f"{self.op_name}{self.count}", - ) - - -def _get_path( - graph_name: str, path: Optional[Path | str] = None, ext: str = ".onnx" -) -> str: - out_path = Path(".") / f"{graph_name.lower()}{ext}" - if path: - if (tmp := Path(path)).is_dir(): - out_path = tmp / f"{graph_name.lower()}{ext}" - elif tmp.suffix != ".onnx": - raise ValueError( - f"Provide path {path} must include the model name and end with .onnx extension" - ) - - return str(out_path) - - -@dataclass -class OnnxGraphBuilder: - name: InitVar[str] - inputs: InitVar[dict[str, ArrayLike]] - output: InitVar[dict[str, ArrayLike]] - rhs_constant: InitVar[bool] = field(default=False) - - graph_name: str = field(init=False) - value_map: dict[str, ArrayLike] = field(init=False) - node_counter: int = field(init=False, default=0) - output_set: set[NodeOutput] = field(init=False, default_factory=set) - nodes: List[OnnxOpData] = field(init=False, default_factory=list) - constants: dict[str, OnnxConst] = field(default_factory=dict) - - def __post_init__(self, name, inputs, 
output, rhs_constant): - self.graph_name = f"{name}" - if rhs_constant: - self.first_idx = 1 - out_names = set(output.keys()) - validated_input = validate_input(inputs, self.output_set, out_names) - - for k in output: - out_name = NodeOutput(k) - if out_name in self.output_set: - raise ValueError(f"Output {k} already exists in the graph") - self.output_set.add(out_name) - - if rhs_constant: - const_name = list(validated_input.keys())[1] - const_val = validated_input[const_name] - self.constants[const_name] = OnnxConst(const_name, const_val) - const_out = NodeOutput(const_name) - if const_out in self.output_set: - raise ValueError( - f"Name for rhs const {const_name} already exists in the graph" - ) - self.output_set.add(const_out) - - self.nodes.append(OnnxOpData(name, inputs, output)) - - def add_node( - self, - name: str, - inputs: dict[str, ArrayLike], - output: dict[str, ArrayLike], - rhs_constant: bool = False, - ): - out_names = set(output.keys()) - validated_input = validate_input(inputs, self.output_set, out_names) - if ">" in validated_input: - # only works if the previous node has a single output - if len((prev_out := self.nodes[-1].output_names)) != 1: - raise KeyError( - "Previous node has more than one output. Please specify the output to use" - ) - validated_input[prev_out[0]] = NodeOutput(prev_out[0]) - - for k in output: - out_name = NodeOutput(k) - if out_name in self.output_set: - raise ValueError(f"Output {k} already exists in the graph") - self.output_set.add(out_name) - - if rhs_constant: - const_name = list(validated_input.keys())[1] - const_val = validated_input[const_name] - self.constants[const_name] = OnnxConst(const_name, const_val) - const_out = NodeOutput(const_name) - if const_out in self.output_set: - raise ValueError( - f"Name for rhs const {const_name} already exists in the graph" - ) - self.output_set.add(const_out) - - self.nodes.append(OnnxOpData(name, inputs, output)) - - @property - def graph_nodes(self): - res = [const.to_onnx() for const in self.constants.values()] - res.extend([node.to_onnx() for node in self.nodes]) - return res - - def get_graph_inputs(self) -> List[onnx.ValueInfoProto]: - res: List[onnx.ValueInfoProto] = [] - for node in self.nodes: - res.extend( - make_onnx_types(k, v) - for k, v in node.inputs.items() - if k not in self.constants and NodeOutput(k) not in self.output_set - ) - return res - - def get_graph_outputs(self): - return self.nodes[-1].output_vals - - def get_output_names(self): - return self.nodes[-1].output_names - - def get_expected_outputs(self): - return list(self.nodes[-1].output.values()) - - def make_onnx_graph(self) -> onnx.GraphProto: - """Create a graph with a single node for testing. - - Args: - op_inputs: The input tensor to the node.""" - - graph: onnx.GraphProto = onnx.helper.make_graph( - self.graph_nodes, - self.graph_name, - inputs=self.get_graph_inputs(), # type: ignore - outputs=self.get_graph_outputs(), # type: ignore - ) - - return graph - - def save_model(self, path: Optional[Path | str] = None): - """Converts the generated graph to an onnx model and saves it to a file. - - Args: - path (Optional[Path | str], optional): desired path to the output. 
if unspecified, defaults to {op_name}.onnx - - Raises: - ValueError: If you provide a path and it doesn't end with .onnx then an error is raised - """ - model = onnx.helper.make_model(self.make_onnx_graph()) - out_path = _get_path(self.graph_name, path) - onnx.save(model, out_path) - print(f"Model saved to {out_path}") - - def get_sess_inputs(self): - res = {} - for node in self.nodes: - for k, v in node.inputs.items(): - if k not in self.constants and NodeOutput(k) not in self.output_set: - res[k] = v - return res - - def validate_model(self, model_path: Optional[Path | str] = None): - """Loads the generated model and runs it with the provided inputs to validate the output. - More of a sanity check than anything else. - - Returns: - Outputs (Any): returns the outputs of the model in case there is a need to inspect them. - """ - model_path = _get_path(self.graph_name, model_path) - sess = onnxruntime.InferenceSession(model_path) - sess_inputs = [inp.name for inp in sess.get_inputs()] - sess_outputs = sess.run( - self.get_output_names(), - self.get_sess_inputs(), - ) - - for i, out in enumerate( - self.get_expected_outputs(), - ): - assert np.allclose(out, sess_outputs[i]) - print("Output is the same as expected. Test passed.") - return sess_outputs - - def model_to_txt(self, path: Optional[Path | str] = None): - """load the generated model and save it to a txt file for debugging purposes. - - Args: - path (Optional[Path | str], optional): desired path to the output. if unspecified, defaults to {op_name}.txt - in the current directory. Defaults to None. - - Raises: - ValueError: If you provide a path and it doesn't end with .txt then an error is raised - """ - model = onnx.helper.make_model(self.make_onnx_graph()) - out_path = _get_path(self.graph_name, path, ".txt") - with open(out_path, "w") as f: - f.write(str(model)) - print(f"Model saved to {out_path}") - - -if __name__ == "__main__": - const_axes = [0, 4] - axis = [1] - x = np.array(np.random.randn(3, 4, 5)) - y = np.expand_dims(x, axis=const_axes) - z = np.expand_dims(y, axis=axis) - - if y.shape != (1, 3, 4, 5, 1): - raise ValueError(f"Expected shape (1,3,4,5,1) but got {y.shape}") - if z.shape != (1, 1, 3, 4, 5, 1): - raise ValueError(f"Expected shape (1,1,3,4,5,1) but got {z.shape}") - - data = OnnxGraphBuilder( - "Unsqueeze", - {"x": x, "axes": const_axes}, - {"y": y}, - rhs_constant=True, - ) - data.add_node("Unsqueeze", {"y": "y", "axis": axis}, {"z": z}) - - # data.model_to_txt() - result = data.validate_model() - - assert np.allclose(result[0], data.get_expected_outputs()[0]) - print("Test passed") diff --git a/crates/burn-import/onnx-tests/tests/unsqueeze/unsqueeze_opset11.onnx b/crates/burn-import/onnx-tests/tests/unsqueeze/unsqueeze_opset11.onnx deleted file mode 100644 index 858a41fd8d..0000000000 Binary files a/crates/burn-import/onnx-tests/tests/unsqueeze/unsqueeze_opset11.onnx and /dev/null differ diff --git a/crates/burn-import/onnx-tests/tests/unsqueeze/unsqueeze.onnx b/crates/burn-import/onnx-tests/tests/unsqueeze/unsqueeze_runtime_axes.onnx similarity index 100% rename from crates/burn-import/onnx-tests/tests/unsqueeze/unsqueeze.onnx rename to crates/burn-import/onnx-tests/tests/unsqueeze/unsqueeze_runtime_axes.onnx diff --git a/crates/burn-import/onnx_opset_upgrade.py b/crates/burn-import/onnx_opset_upgrade.py new file mode 100755 index 0000000000..419374138a --- /dev/null +++ b/crates/burn-import/onnx_opset_upgrade.py @@ -0,0 +1,104 @@ +#!/usr/bin/env -S uv run --script + +# /// script +# dependencies = [ +# 
"onnx-weekly==1.19.0.dev20250419", +# ] +# /// +# +# Learn more about Astral's UV tool at +# https://docs.astral.sh/uv/guides/scripts/#declaring-script-dependencies + +import os +import sys +import onnx +from onnx import shape_inference +from onnx import version_converter + + +def validate_model_path(model_path): + if not os.path.exists(model_path): + raise FileNotFoundError(f"Model file not found: {model_path}") + + +def load_onnx_model(model_path): + try: + model = onnx.load(model_path) + except Exception as e: + raise RuntimeError(f"Failed to load ONNX model: {str(e)}") + + try: + onnx.checker.check_model(model) + print("Model loaded successfully.") + except Exception as e: + raise RuntimeError(f"Model validation failed: {str(e)}") + + return model + + +def print_opset_version(model): + try: + print("Opset version:", model.opset_import[0].version) + except (IndexError, AttributeError): + print("Warning: Could not determine opset version") + + +def upgrade_model(model): + try: + current_opset = model.opset_import[0].version + if current_opset >= 16: + print(f"Current opset version {current_opset} is already >= 16, skipping upgrade.") + return model + + upgraded_model = version_converter.convert_version(model, 16) + print("Model upgraded to opset 16.") + return upgraded_model + except Exception as e: + raise RuntimeError(f"Failed to upgrade model to opset 16: {str(e)}") + + +def apply_shape_inference(upgraded_model): + try: + inferred_model = shape_inference.infer_shapes(upgraded_model) + print("Model shape inference applied.") + return inferred_model + except Exception as e: + print(f"Warning: Shape inference partially applied: {str(e)}") + return upgraded_model + + +def save_model(inferred_model, output_path): + try: + output_dir = os.path.dirname(output_path) + if output_dir and not os.path.exists(output_dir): + os.makedirs(output_dir) + onnx.save(inferred_model, output_path) + print(f"Model saved to: {output_path}") + except Exception as e: + raise RuntimeError(f"Failed to save model: {str(e)}") + +def main(): + # Get input path from user prompt + model_path = input("Enter the path to the input ONNX model: ") + validate_model_path(model_path) + + # Process the model + model = load_onnx_model(model_path) + print_opset_version(model) + upgraded_model = upgrade_model(model) + inferred_model = apply_shape_inference(upgraded_model) + + # Get output path from user prompt + default_output = model_path.replace('.onnx', '_opset16.onnx') + output_path_input = input(f"Enter the path to save the output ONNX model (press Enter for default '{default_output}'): ") + output_path = output_path_input if output_path_input else default_output + + save_model(inferred_model, output_path) + + +if __name__ == "__main__": + try: + main() + except KeyboardInterrupt: + print("\nOperation cancelled by user.") + sys.exit(1) diff --git a/crates/onnx-ir/README.md b/crates/onnx-ir/README.md index 876ba7296c..69e2cf8085 100644 --- a/crates/onnx-ir/README.md +++ b/crates/onnx-ir/README.md @@ -1,7 +1,42 @@ # ONNX-IR -ONNX-IR is a pure Rust library for parsing ONNX models into an intermediate representation that can be used to generate code for various ML/DL frameworks. It's part of the Burn project, with key features including ONNX model parsing, rank inference, and node remapping. The crate supports converting ONNX models to Burn graphs and includes utilities for handling constants and graph transformations. 
+ONNX-IR is a pure Rust library for parsing ONNX models into an intermediate representation that can +be used to generate code for various ML/DL frameworks. It's part of the Burn project, with key +features including ONNX model parsing, rank inference, and node remapping. The crate supports +converting ONNX models to Burn graphs and includes utilities for handling constants and graph +transformations. -For a full list of currently supported operators, please check [here](https://github.com/tracel-ai/burn/blob/main/crates/burn-import/SUPPORTED-ONNX-OPS.md) +For a full list of currently supported operators, please check +[here](https://github.com/tracel-ai/burn/blob/main/crates/burn-import/SUPPORTED-ONNX-OPS.md) -To see how to use this for generating burn graphs, see [here](crates/burn-import/src/onnx/to_burn.rs). \ No newline at end of file +## ONNX Compatibility + +This library requires ONNX models to use **opset version 16 or higher**. If your model uses an older +opset version, you'll need to upgrade it using the ONNX version converter. + +### Upgrading ONNX Models + +You can upgrade your ONNX models using the following Python script: + +```python +import onnx +from onnx import version_converter, shape_inference + +# Load your ONNX model +model = onnx.load('path/to/your/model.onnx') + +# Convert the model to opset version 16 +upgraded_model = version_converter.convert_version(model, 16) + +# Apply shape inference to the upgraded model +inferred_model = shape_inference.infer_shapes(upgraded_model) + +# Save the converted model +onnx.save(inferred_model, 'upgraded_model.onnx') +``` + +For a full list of currently supported operators, please check +[here](https://github.com/tracel-ai/burn/blob/main/crates/burn-import/SUPPORTED-ONNX-OPS.md) + +To see how to use this for generating burn graphs, see +[here](crates/burn-import/src/onnx/to_burn.rs). diff --git a/crates/onnx-ir/src/from_onnx.rs b/crates/onnx-ir/src/from_onnx.rs index f97a0b0faa..29e59b50ae 100644 --- a/crates/onnx-ir/src/from_onnx.rs +++ b/crates/onnx-ir/src/from_onnx.rs @@ -5,6 +5,7 @@ use std::{ }; use crate::node_remap::remap_node_type; +use crate::util::verify_opsets; use super::{ coalesce::coalesce, @@ -18,7 +19,7 @@ use super::rank_inference::rank_inference; use protobuf::Message; -const LIFT_CONSTANTS_FOR_NODE_TYPES: [NodeType; 15] = [ +const LIFT_CONSTANTS_FOR_NODE_TYPES: [NodeType; 16] = [ NodeType::BatchNormalization, NodeType::Clip, NodeType::Conv1d, @@ -26,16 +27,20 @@ const LIFT_CONSTANTS_FOR_NODE_TYPES: [NodeType; 15] = [ NodeType::Dropout, NodeType::Expand, NodeType::OneHot, + NodeType::ReduceSum, NodeType::Reshape, NodeType::Resize, - NodeType::Unsqueeze, - NodeType::ReduceSum, NodeType::Slice, - NodeType::Squeeze, NodeType::Split, + NodeType::Squeeze, + NodeType::TopK, NodeType::Trilu, + NodeType::Unsqueeze, ]; +/// Minimum required ONNX opset version +pub const MIN_OPSET_VERSION: i64 = 16; + #[derive(Debug, Clone)] pub(crate) enum IOEntry { In(usize), @@ -331,29 +336,44 @@ impl OnnxGraphBuilder { } } -/// Open an onnx file and convert it to a Graph (intermediate representation) +/// Parses an ONNX model file and converts it to an intermediate representation. +/// +/// This function reads an ONNX model from the specified path, validates its opset version, +/// and transforms it into our internal graph representation for further processing. 
diff --git a/crates/onnx-ir/src/from_onnx.rs b/crates/onnx-ir/src/from_onnx.rs
index f97a0b0faa..29e59b50ae 100644
--- a/crates/onnx-ir/src/from_onnx.rs
+++ b/crates/onnx-ir/src/from_onnx.rs
@@ -5,6 +5,7 @@ use std::{
 };
 
 use crate::node_remap::remap_node_type;
+use crate::util::verify_opsets;
 
 use super::{
     coalesce::coalesce,
@@ -18,7 +19,7 @@ use super::rank_inference::rank_inference;
 
 use protobuf::Message;
 
-const LIFT_CONSTANTS_FOR_NODE_TYPES: [NodeType; 15] = [
+const LIFT_CONSTANTS_FOR_NODE_TYPES: [NodeType; 16] = [
     NodeType::BatchNormalization,
     NodeType::Clip,
     NodeType::Conv1d,
@@ -26,16 +27,20 @@ const LIFT_CONSTANTS_FOR_NODE_TYPES: [NodeType; 15] = [
     NodeType::Dropout,
     NodeType::Expand,
     NodeType::OneHot,
+    NodeType::ReduceSum,
     NodeType::Reshape,
     NodeType::Resize,
-    NodeType::Unsqueeze,
-    NodeType::ReduceSum,
     NodeType::Slice,
-    NodeType::Squeeze,
     NodeType::Split,
+    NodeType::Squeeze,
+    NodeType::TopK,
     NodeType::Trilu,
+    NodeType::Unsqueeze,
 ];
 
+/// Minimum required ONNX opset version
+pub const MIN_OPSET_VERSION: i64 = 16;
+
 #[derive(Debug, Clone)]
 pub(crate) enum IOEntry {
     In(usize),
@@ -331,29 +336,44 @@ impl OnnxGraphBuilder {
     }
 }
 
-/// Open an onnx file and convert it to a Graph (intermediate representation)
+/// Parses an ONNX model file and converts it to an intermediate representation.
+///
+/// This function reads an ONNX model from the specified path, validates its opset version,
+/// and transforms it into our internal graph representation for further processing.
 ///
 /// # Arguments
 ///
-/// * `onnx_path` - Path to the onnx file
+/// * `onnx_path` - Path to the ONNX model file
 ///
 /// # Returns
 ///
-/// * `OnnxGraph` - The graph representation of the onnx file
+/// * `OnnxGraph` - The internal graph representation of the ONNX model
 ///
 /// # Panics
 ///
-/// * If the file cannot be opened
-/// * If the file cannot be parsed
-/// * If the nodes are not topologically sorted
+/// * If the file cannot be opened or read
+/// * If the ONNX model cannot be parsed
+/// * If the model uses an unsupported opset version (must be >= MIN_OPSET_VERSION)
+/// * If the nodes in the graph are not topologically sorted
 pub fn parse_onnx(onnx_path: &Path) -> OnnxGraph {
     log::info!("Parsing ONNX file: {}", onnx_path.display());
 
     // Open the file
-    let mut file = File::open(onnx_path).expect("Unable to open file");
+    let mut file = File::open(onnx_path)
+        .unwrap_or_else(|_| panic!("Unable to open file: {}", onnx_path.display()));
     let onnx_model: ModelProto =
         Message::parse_from_reader(&mut file).expect("Unable to parse ONNX file");
 
+    // Check opset versions - must be >= MIN_OPSET_VERSION
+    if !verify_opsets(&onnx_model.opset_import, MIN_OPSET_VERSION) {
+        panic!(
+            "Unsupported ONNX opset version. This implementation requires opset {} or higher. \
+             Please upgrade your model using the ONNX version converter. \
+             See documentation (https://burn.dev/burn-book/import/onnx-model.html) for details.",
+            MIN_OPSET_VERSION
+        );
+    }
+
     // ONNX nodes must be topologically sorted per spec:
     // https://github.com/onnx/onnx/blob/main/docs/IR.md#graphs
     debug_assert!(
@@ -369,6 +389,20 @@ pub fn parse_onnx(onnx_path: &Path) -> OnnxGraph {
     );
     log::debug!("Number of outputs: {:?}", onnx_model.graph.output.len());
 
+    // Debug information about opset versions
+    for opset in &onnx_model.opset_import {
+        log::debug!(
+            "Opset domain: {:?}, version: {:?}",
+            if opset.domain.is_empty() {
+                ""
+            } else {
+                &opset.domain
+            },
+            opset.version
+        );
+    }
+
     let builder = OnnxGraphBuilder::default();
     let graph = builder.build(&onnx_model);
 
diff --git a/crates/onnx-ir/src/util.rs b/crates/onnx-ir/src/util.rs
index 04de5add0e..6843b3c9be 100644
--- a/crates/onnx-ir/src/util.rs
+++ b/crates/onnx-ir/src/util.rs
@@ -1,4 +1,5 @@
 use crate::ir::{ArgType, Node};
+use crate::protos::OperatorSetIdProto;
 
 pub fn shape_config(curr: &Node) -> (usize, usize) {
     if curr.inputs.len() != 1 {
@@ -37,3 +38,46 @@ pub fn shape_config(curr: &Node) -> (usize, usize) {
 
     (start_dim as usize, end_dim as usize)
 }
+
+/// Check whether the provided operator set version is supported.
+///
+/// # Arguments
+///
+/// * `opset` - The operator set to check
+/// * `min_version` - The minimum supported version
+///
+/// # Returns
+///
+/// * `bool` - True if the opset version is supported, false otherwise
+///
+/// # Panics
+///
+/// * If the domain is not the default (empty) ONNX domain
+pub fn check_opset_version(opset: &OperatorSetIdProto, min_version: i64) -> bool {
+    // For now, only the empty domain (standard ONNX operators) is supported
+    if !opset.domain.is_empty() {
+        panic!("Only the standard ONNX domain is supported");
+    }
+
+    // Return true if the opset version is greater than or equal to min_version
+    opset.version >= min_version
+}
+
+/// Verify that all operator sets in a model are supported.
+///
+/// # Arguments
+///
+/// * `opsets` - The operator sets to check
+/// * `min_version` - The minimum supported version
+///
+/// # Returns
+///
+/// * `bool` - True if all opset versions are supported, false otherwise
+pub fn verify_opsets(opsets: &[OperatorSetIdProto], min_version: i64) -> bool {
+    for opset in opsets {
+        if !check_opset_version(opset, min_version) {
+            return false;
+        }
+    }
+    true
+}
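To make the intended behaviour of the new opset helpers concrete, here is a small test sketch. It is not part of the patch; it assumes the generated `OperatorSetIdProto` provides a `new()` constructor and the plain `domain`/`version` fields that `check_opset_version` already reads above.

```rust
#[cfg(test)]
mod opset_version_tests {
    use super::{check_opset_version, verify_opsets};
    use crate::protos::OperatorSetIdProto;

    /// Builds a standard-domain (empty string) opset entry with the given version.
    /// Assumes the generated proto type exposes `new()` and a public `version` field.
    fn opset(version: i64) -> OperatorSetIdProto {
        let mut entry = OperatorSetIdProto::new();
        entry.version = version;
        entry
    }

    #[test]
    fn accepts_opset_16_and_newer() {
        assert!(check_opset_version(&opset(16), 16));
        assert!(verify_opsets(&[opset(16), opset(18)], 16));
    }

    #[test]
    fn rejects_opsets_older_than_16() {
        assert!(!check_opset_version(&opset(13), 16));
        // A single outdated opset import is enough to fail the whole model.
        assert!(!verify_opsets(&[opset(16), opset(13)], 16));
    }
}
```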
diff --git a/examples/image-classification-web/build.rs b/examples/image-classification-web/build.rs
index 10d8f9d03c..6816bb8f8c 100644
--- a/examples/image-classification-web/build.rs
+++ b/examples/image-classification-web/build.rs
@@ -10,7 +10,7 @@ use burn_import::{burn::graph::RecordType, onnx::ModelGen};
 const LABEL_SOURCE_FILE: &str = "src/model/label.txt";
 const LABEL_DEST_FILE: &str = "model/label.rs";
 
-const INPUT_ONNX_FILE: &str = "src/model/squeezenet1.onnx";
+const INPUT_ONNX_FILE: &str = "src/model/squeezenet1_opset16.onnx";
 const OUT_DIR: &str = "model/";
 
 fn main() {
diff --git a/examples/image-classification-web/src/model/squeezenet.rs b/examples/image-classification-web/src/model/squeezenet.rs
index d796ae629f..a7918080a1 100644
--- a/examples/image-classification-web/src/model/squeezenet.rs
+++ b/examples/image-classification-web/src/model/squeezenet.rs
@@ -1,6 +1,6 @@
 // Generated model from squeezenet1.onnx
 mod internal_model {
-    include!(concat!(env!("OUT_DIR"), "/model/squeezenet1.rs"));
+    include!(concat!(env!("OUT_DIR"), "/model/squeezenet1_opset16.rs"));
 }
 
 pub use internal_model::*;
diff --git a/examples/image-classification-web/src/model/squeezenet1_opset16.onnx b/examples/image-classification-web/src/model/squeezenet1_opset16.onnx
new file mode 100644
index 0000000000..6e2e02917d
Binary files /dev/null and b/examples/image-classification-web/src/model/squeezenet1_opset16.onnx differ