Skip to content

to_onnx returns ONNXProgram #20811

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 33 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
bde9614
feat: return `ONNXProgram` when exporting with dynamo=True.
GdoongMathew May 11, 2025
a966c0b
test: add to_onnx(dynamo=True) unittests.
GdoongMathew May 11, 2025
e7342e3
fix: add ignore filter in pyproject.toml
GdoongMathew May 11, 2025
3ee3ea9
fix: change the return type annotation of `to_onnx`.
GdoongMathew May 21, 2025
bc81215
test: add parametrized `dynamo` to test `test_if_inference_output_is_…
GdoongMathew May 21, 2025
236f1a0
test: add difference check in `test_model_return_type`.
GdoongMathew May 21, 2025
019125d
fix: fix unittest.
GdoongMathew May 21, 2025
9f5e604
Merge branch 'master' into feat/dynamo_export_onnx
GdoongMathew May 30, 2025
791d777
deps: bump typing_extension for onnxscript.
GdoongMathew Jun 2, 2025
453e63f
Merge branch 'master' into feat/dynamo_export_onnx
GdoongMathew Jun 2, 2025
acdf3c1
deps: bump typing_extension for onnxscript.
GdoongMathew Jun 2, 2025
e046d27
deps: bump onnxscript upper bound.
GdoongMathew Jun 3, 2025
a0a7d1f
test: add test `test_model_onnx_export_missing_onnxscript`.
GdoongMathew Jun 5, 2025
7aae865
Merge branch 'master' into feat/dynamo_export_onnx
GdoongMathew Jun 6, 2025
aca9fd1
revert typing-extension bump.
GdoongMathew Jun 7, 2025
1396f35
lower the min_torch version in unittest.
GdoongMathew Jun 7, 2025
8f050ea
feat: enable ONNXProgram export on torch 2.5.0
GdoongMathew Jun 16, 2025
3938c73
Merge branch 'master' into feat/dynamo_export_onnx
GdoongMathew Jun 16, 2025
ce3e6b7
extensions
Borda Jun 16, 2025
c31a3f6
Merge branch 'master' into feat/dynamo_export_onnx
Borda Jun 16, 2025
40b1449
Merge branch 'master' into feat/dynamo_export_onnx
Borda Jun 16, 2025
a470fe8
ds
Borda Jun 18, 2025
9e4a494
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 18, 2025
b08e465
Merge branch 'master' into feat/dynamo_export_onnx
Borda Jun 18, 2025
67af423
dep: test fixing pydantic version.
GdoongMathew Jun 18, 2025
0e4cb80
Revert "dep: test fixing pydantic version."
GdoongMathew Jun 18, 2025
b26072d
dep: add serve deps.
GdoongMathew Jun 18, 2025
d1b8597
ci: test.
GdoongMathew Jun 18, 2025
ec638c3
Merge branch 'master' into feat/dynamo_export_onnx
GdoongMathew Jun 26, 2025
aa951fd
update onnxscript upperbound.
GdoongMathew Jun 26, 2025
9491953
align with ce3e6b7
GdoongMathew Jun 28, 2025
819d3c8
Merge branch 'master' into feat/dynamo_export_onnx
GdoongMathew Jul 6, 2025
f014639
Merge branch 'master' into feat/dynamo_export_onnx
GdoongMathew Jul 12, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ markers = [
]
filterwarnings = [
"error::FutureWarning",
"ignore::FutureWarning:onnxscript", # Temporary ignore until onnxscript is updated
]
xfail_strict = true
junit_duration_report = "call"
1 change: 1 addition & 0 deletions requirements/pytorch/test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ scikit-learn >0.22.1, <1.7.0
numpy >=1.17.2, <1.27.0
onnx >=1.12.0, <1.18.0
onnxruntime >=1.12.0, <1.21.0
onnxscript >=0.2.2, <0.2.6
psutil <7.0.1 # for `DeviceStatsMonitor`
pandas >1.0, <2.3.0 # needed in benchmarks
fastapi # for `ServableModuleValidator` # not setting version as re-defined in App
Expand Down
20 changes: 17 additions & 3 deletions src/lightning/pytorch/core/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,10 @@

if TYPE_CHECKING:
from torch.distributed.device_mesh import DeviceMesh
from torch.onnx import ONNXProgram

_ONNX_AVAILABLE = RequirementCache("onnx")
_ONNXSCRIPT_AVAILABLE = RequirementCache("onnxscript")

warning_cache = WarningCache()
log = logging.getLogger(__name__)
Expand Down Expand Up @@ -1360,12 +1362,18 @@ def _verify_is_manual_optimization(self, fn_name: str) -> None:
)

@torch.no_grad()
def to_onnx(self, file_path: Union[str, Path, BytesIO], input_sample: Optional[Any] = None, **kwargs: Any) -> None:
def to_onnx(
self,
file_path: Union[str, Path, BytesIO, None] = None,
input_sample: Optional[Any] = None,
**kwargs: Any,
) -> Union["ONNXProgram", None]:
"""Saves the model in ONNX format.

Args:
file_path: The path of the file the onnx model should be saved to.
file_path: The path of the file the onnx model should be saved to. Default: None (no file saved).
input_sample: An input for tracing. Default: None (Use self.example_input_array)

**kwargs: Will be passed to torch.onnx.export function.

Example::
Expand All @@ -1386,6 +1394,11 @@ def forward(self, x):
if not _ONNX_AVAILABLE:
raise ModuleNotFoundError(f"`{type(self).__name__}.to_onnx()` requires `onnx` to be installed.")

if kwargs.get("dynamo", False) and not _ONNXSCRIPT_AVAILABLE:
raise ModuleNotFoundError(
f"`{type(self).__name__}.to_onnx(dynamo=True)` requires `onnxscript` to be installed."
)

mode = self.training

if input_sample is None:
Expand All @@ -1402,8 +1415,9 @@ def forward(self, x):
file_path = str(file_path) if isinstance(file_path, Path) else file_path
# PyTorch (2.5) declares file_path to be str | PathLike[Any] | None, but
# BytesIO does work, too.
torch.onnx.export(self, input_sample, file_path, **kwargs) # type: ignore
ret = torch.onnx.export(self, input_sample, file_path, **kwargs) # type: ignore
self.train(mode)
return ret

@torch.no_grad()
def to_torchscript(
Expand Down
7 changes: 6 additions & 1 deletion src/lightning/pytorch/utilities/testing/_runif.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from lightning.fabric.utilities.testing import _runif_reasons as fabric_run_if
from lightning.pytorch.accelerators.cpu import _PSUTIL_AVAILABLE
from lightning.pytorch.callbacks.progress.rich_progress import _RICH_AVAILABLE
from lightning.pytorch.core.module import _ONNX_AVAILABLE
from lightning.pytorch.core.module import _ONNX_AVAILABLE, _ONNXSCRIPT_AVAILABLE
from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE

_SKLEARN_AVAILABLE = RequirementCache("scikit-learn")
Expand All @@ -42,6 +42,7 @@ def _runif_reasons(
psutil: bool = False,
sklearn: bool = False,
onnx: bool = False,
onnxscript: bool = False,
) -> tuple[list[str], dict[str, bool]]:
"""Construct reasons for pytest skipif.

Expand All @@ -64,6 +65,7 @@ def _runif_reasons(
psutil: Require that psutil is installed.
sklearn: Require that scikit-learn is installed.
onnx: Require that onnx is installed.
onnxscript: Require that onnxscript is installed.

"""

Expand Down Expand Up @@ -96,4 +98,7 @@ def _runif_reasons(
if onnx and not _ONNX_AVAILABLE:
reasons.append("onnx")

if onnxscript and not _ONNXSCRIPT_AVAILABLE:
reasons.append("onnxscript")

return reasons, kwargs
10 changes: 10 additions & 0 deletions tests/tests_pytorch/models/test_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,3 +167,13 @@ def to_numpy(tensor):

# compare ONNX Runtime and PyTorch results
assert np.allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05)


@RunIf(onnx=True, min_torch="2.7.0", dynamo=True, onnxscript=True)
def test_model_return_type():
    """Exporting with ``dynamo=True`` should hand back a ``torch.onnx.ONNXProgram``."""
    boring_model = BoringModel()
    boring_model.example_input_array = torch.randn((1, 32))
    boring_model.eval()

    exported = boring_model.to_onnx(dynamo=True)
    assert isinstance(exported, torch.onnx.ONNXProgram)
Loading