Skip to content

Fixed checking length for symbolic tensor + make Model2onnx agnostic to the type of model #66

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion mltu/preprocessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import matplotlib
import logging

from typing import Type

from . import Image
from mltu.annotations.audio import Audio

Expand All @@ -18,7 +20,7 @@

class ImageReader:
"""Read image from path and return image and label"""
def __init__(self, image_class: Image, log_level: int = logging.INFO, ) -> None:
def __init__(self, image_class: Type[Image], log_level: int = logging.INFO) -> None:
self.logger = logging.getLogger(self.__class__.__name__)
self.logger.setLevel(log_level)
self._image_class = image_class
Expand Down
44 changes: 20 additions & 24 deletions mltu/tensorflow/callbacks.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import tensorflow as tf
from pathlib import Path
from keras.callbacks import Callback

import logging
Expand All @@ -14,7 +15,7 @@ def __init__(
) -> None:
""" Converts the model to onnx format after training is finished.
Args:
saved_model_path (str): Path to the saved .h5 model.
saved_model_path (str): Path to the saved model.
metadata (dict, optional): Dictionary containing metadata to be added to the onnx model. Defaults to None.
save_on_epoch_end (bool, optional): Save the onnx model on every epoch end. Defaults to False.
"""
Expand All @@ -35,35 +36,30 @@ def __init__(

@staticmethod
def model2onnx(model: tf.keras.Model, onnx_model_path: str):
try:
import tf2onnx

# convert the model to onnx format
tf2onnx.convert.from_keras(model, output_path=onnx_model_path)
import tf2onnx

except Exception as e:
print(e)
# convert the model to onnx format
# NOTE: see here for more info https://github.com/keras-team/keras/issues/18430
input_signature = [tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype, name='digit')]
tf2onnx.convert.from_keras(model, input_signature=input_signature, opset=13, output_path=onnx_model_path)

@staticmethod
def include_metadata(onnx_model_path: str, metadata: dict=None):
try:
if metadata and isinstance(metadata, dict):
if metadata and isinstance(metadata, dict):

import onnx
# Load the ONNX model
onnx_model = onnx.load(onnx_model_path)
import onnx
# Load the ONNX model
onnx_model = onnx.load(onnx_model_path)

# Add the metadata dictionary to the model's metadata_props attribute
for key, value in metadata.items():
meta = onnx_model.metadata_props.add()
meta.key = key
meta.value = str(value)
# Add the metadata dictionary to the model's metadata_props attribute
for key, value in metadata.items():
meta = onnx_model.metadata_props.add()
meta.key = key
meta.value = str(value)

# Save the modified ONNX model
onnx.save(onnx_model, onnx_model_path)
# Save the modified ONNX model
onnx.save(onnx_model, onnx_model_path)

except Exception as e:
print(e)

def on_epoch_end(self, epoch: int, logs: dict=None):
""" Converts the model to onnx format on every epoch end. """
Expand All @@ -72,8 +68,8 @@ def on_epoch_end(self, epoch: int, logs: dict=None):

def on_train_end(self, logs=None):
""" Converts the model to onnx format after training is finished. """
self.model.load_weights(self.saved_model_path)
onnx_model_path = self.saved_model_path.replace(".h5", ".onnx")
self._model.load_weights(self.saved_model_path)
onnx_model_path = str(Path(self.saved_model_path).with_suffix('.onnx'))
self.model2onnx(self.model, onnx_model_path)
self.include_metadata(onnx_model_path, self.metadata)

Expand Down
2 changes: 1 addition & 1 deletion mltu/tensorflow/transformer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def update_state(self, y_true, y_pred, sample_weight=None):
self.cer_accumulator.assign_add(tf.reduce_sum(distance))

# Increment the batch_counter by the batch size
self.batch_counter.assign_add(len(y_true))
self.batch_counter.assign_add(y_true.shape[0])

def result(self):
""" Computes and returns the metric result.
Expand Down