From 02d8199a8a603717a215e3e608c7cb3488e38277 Mon Sep 17 00:00:00 2001 From: Christian Steinmeyer Date: Wed, 28 Jun 2023 09:51:12 +0200 Subject: [PATCH] add original weight name as prefix for mask and threshold weights If there are multiple prunable weights in a wrapped layer, without this change, the wrapped model cannot be saved to h5 because of duplicate dataset names (the layer has multiple weights called "mask" and "threshold"). --- .../python/core/sparsity/keras/pruning_wrapper.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py b/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py index 4845abcd7..2b98c6209 100644 --- a/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py +++ b/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py @@ -233,14 +233,14 @@ def build(self, input_shape): # For each of the prunable weights, add mask and threshold variables for weight in self.prunable_weights: mask = self.add_weight( - 'mask', + weight.name + '_mask', shape=weight.shape, initializer=tf.keras.initializers.get('ones'), dtype=weight.dtype, trainable=False, aggregation=tf.VariableAggregation.MEAN) threshold = self.add_weight( - 'threshold', + weight.name + '_threshold', shape=[], initializer=tf.keras.initializers.get('zeros'), dtype=weight.dtype,