diff --git a/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py b/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py index d134640a3..02fa34a7e 100644 --- a/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py +++ b/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py @@ -233,11 +233,17 @@ def build(self, input_shape): # For each of the prunable weights, add mask and threshold variables for weight in self.prunable_weights: + # Under a mixed precision policy, variables report their "cast" dtype. + # However, we want to use the original dtype for mask and threshold. + if hasattr(weight, 'true_dtype'): + dtype = weight.true_dtype + else: + dtype = weight.dtype mask = self.add_weight( 'mask', shape=weight.shape, initializer=keras.initializers.get('ones'), - dtype=weight.dtype, + dtype=dtype, trainable=False, aggregation=tf.VariableAggregation.MEAN, ) @@ -245,7 +251,7 @@ def build(self, input_shape): 'threshold', shape=[], initializer=keras.initializers.get('zeros'), - dtype=weight.dtype, + dtype=dtype, trainable=False, aggregation=tf.VariableAggregation.MEAN, )