diff --git a/tensorflow_model_optimization/python/core/sparsity/keras/prune_registry.py b/tensorflow_model_optimization/python/core/sparsity/keras/prune_registry.py index 36e203bc8..17c884497 100644 --- a/tensorflow_model_optimization/python/core/sparsity/keras/prune_registry.py +++ b/tensorflow_model_optimization/python/core/sparsity/keras/prune_registry.py @@ -19,12 +19,16 @@ from tensorflow_model_optimization.python.core.sparsity.keras import prunable_layer try: - from keras.engine import base_layer # pylint: disable=g-import-not-at-top + # OSS case. + import keras # pylint: disable=g-import-not-at-top + if hasattr(keras, 'src'): + # Path as seen in pip packages as of TF/Keras 2.13. + from keras.src.engine import base_layer # pylint: disable=g-import-not-at-top,g-importing-member + else: + from keras.engine import base_layer # pylint: disable=g-import-not-at-top,g-importing-member except ImportError: - # Path as seen in pip packages as of TF/Keras 2.13. - from keras.src.engine import base_layer # pylint: disable=g-import-not-at-top - -# TODO(b/139939526): move to public API. + # Internal case. + base_layer = tf._keras_internal.engine.base_layer # pylint: disable=protected-access layers = tf.keras.layers layers_compat_v1 = tf.compat.v1.keras.layers