diff --git a/mltu/preprocessors.py b/mltu/preprocessors.py index cb65ca1..acebe7d 100644 --- a/mltu/preprocessors.py +++ b/mltu/preprocessors.py @@ -6,6 +6,8 @@ import matplotlib import logging +from typing import Type + from . import Image from mltu.annotations.audio import Audio @@ -18,7 +20,7 @@ class ImageReader: """Read image from path and return image and label""" - def __init__(self, image_class: Image, log_level: int = logging.INFO, ) -> None: + def __init__(self, image_class: Type[Image], log_level: int = logging.INFO) -> None: self.logger = logging.getLogger(self.__class__.__name__) self.logger.setLevel(log_level) self._image_class = image_class diff --git a/mltu/tensorflow/callbacks.py b/mltu/tensorflow/callbacks.py index ea9479e..0c67e56 100644 --- a/mltu/tensorflow/callbacks.py +++ b/mltu/tensorflow/callbacks.py @@ -1,5 +1,6 @@ import os import tensorflow as tf +from pathlib import Path from keras.callbacks import Callback import logging @@ -14,7 +15,7 @@ def __init__( ) -> None: """ Converts the model to onnx format after training is finished. Args: - saved_model_path (str): Path to the saved .h5 model. + saved_model_path (str): Path to the saved model. metadata (dict, optional): Dictionary containing metadata to be added to the onnx model. Defaults to None. save_on_epoch_end (bool, optional): Save the onnx model on every epoch end. Defaults to False. """ @@ -35,35 +36,30 @@ def __init__( @staticmethod def model2onnx(model: tf.keras.Model, onnx_model_path: str): - try: - import tf2onnx - - # convert the model to onnx format - tf2onnx.convert.from_keras(model, output_path=onnx_model_path) + import tf2onnx - except Exception as e: - print(e) + # convert the model to onnx format + # NOTE: see here for more info https://github.com/keras-team/keras/issues/18430 + input_signature = [tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype, name='digit')] + tf2onnx.convert.from_keras(model, input_signature=input_signature, opset=13, output_path=onnx_model_path) @staticmethod def include_metadata(onnx_model_path: str, metadata: dict=None): - try: - if metadata and isinstance(metadata, dict): + if metadata and isinstance(metadata, dict): - import onnx - # Load the ONNX model - onnx_model = onnx.load(onnx_model_path) + import onnx + # Load the ONNX model + onnx_model = onnx.load(onnx_model_path) - # Add the metadata dictionary to the model's metadata_props attribute - for key, value in metadata.items(): - meta = onnx_model.metadata_props.add() - meta.key = key - meta.value = str(value) + # Add the metadata dictionary to the model's metadata_props attribute + for key, value in metadata.items(): + meta = onnx_model.metadata_props.add() + meta.key = key + meta.value = str(value) - # Save the modified ONNX model - onnx.save(onnx_model, onnx_model_path) + # Save the modified ONNX model + onnx.save(onnx_model, onnx_model_path) - except Exception as e: - print(e) def on_epoch_end(self, epoch: int, logs: dict=None): """ Converts the model to onnx format on every epoch end. """ @@ -72,8 +68,8 @@ def on_epoch_end(self, epoch: int, logs: dict=None): def on_train_end(self, logs=None): """ Converts the model to onnx format after training is finished. 
""" - self.model.load_weights(self.saved_model_path) - onnx_model_path = self.saved_model_path.replace(".h5", ".onnx") + self._model.load_weights(self.saved_model_path) + onnx_model_path = str(Path(self.saved_model_path).with_suffix('.onnx')) self.model2onnx(self.model, onnx_model_path) self.include_metadata(onnx_model_path, self.metadata) diff --git a/mltu/tensorflow/transformer/utils.py b/mltu/tensorflow/transformer/utils.py index 471f92e..461ae15 100644 --- a/mltu/tensorflow/transformer/utils.py +++ b/mltu/tensorflow/transformer/utils.py @@ -151,7 +151,7 @@ def update_state(self, y_true, y_pred, sample_weight=None): self.cer_accumulator.assign_add(tf.reduce_sum(distance)) # Increment the batch_counter by the batch size - self.batch_counter.assign_add(len(y_true)) + self.batch_counter.assign_add(y_true.shape[0]) def result(self): """ Computes and returns the metric result.