
Commit ed3f017

hertschuh authored and tensorflower-gardener committed
Fix dtype and assign* in AutocastVariable.
The `dtype` property would return the true dtype of the variable, instead of the dtype of the value that you get explicitly via `.value()` or implicitly by doing any operation. This would cause seemingly correct code like this to fail with a dtype mismatch:

```
y = variable * tf.cast(x, variable.dtype)
```

forcing users to write workarounds like:

```
v = variable.value()
y = variable * tf.cast(x, v.dtype)
```

Additionally, `assign`, `assign_add`, and `assign_sub` expected the value to be of the true dtype, not the cast dtype. This would cause seemingly correct code like this to fail with a dtype mismatch:

```
variable.assign(variable * factor)
```

(This is a common use case for non-trainable variables.) It forced users to write workarounds like:

```
variable.assign(tf.cast(variable * factor, variable.dtype))
```

This change fixes these issues to make autocasting fully transparent:

- `dtype` returns the cast dtype if applicable
- `assign*` accept the cast dtype for the value if applicable

Note that this is consistent with how autocasting works in Keras 3.

PiperOrigin-RevId: 650386711
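To illustrate the behavior the commit message describes, here is a minimal sketch (not taken from the commit; the layer and weight names are illustrative) of a `tf.keras` layer under the `mixed_float16` policy. With the fix, `self.scale.dtype` inside `call()` reports the cast dtype, so the cast-and-multiply pattern works without the `.value()` workaround:

```python
import tensorflow as tf
from tensorflow import keras

keras.mixed_precision.set_global_policy('mixed_float16')

class Scale(keras.layers.Layer):
  def build(self, input_shape):
    # Stored as float32 (the true dtype); read back as float16 inside call()
    # because the policy autocasts the variable.
    self.scale = self.add_weight(name='scale', shape=[], initializer='ones')

  def call(self, x):
    # With the fix, self.scale.dtype is the cast dtype (float16) here, so the
    # multiplication dtypes match without calling .value() first.
    return self.scale * tf.cast(x, self.scale.dtype)

y = Scale()(tf.ones([2, 2]))  # computed in float16 under the policy
```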
1 parent 909a2a4 commit ed3f017

File tree

1 file changed: +8 −2 lines changed


tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py

Lines changed: 8 additions & 2 deletions
```diff
@@ -233,19 +233,25 @@ def build(self, input_shape):
 
     # For each of the prunable weights, add mask and threshold variables
     for weight in self.prunable_weights:
+      # Under a mixed precision policy, variables report their "cast" dtype.
+      # However, we want to use the original dtype for mask and threshold.
+      if hasattr(weight, 'true_dtype'):
+        dtype = weight.true_dtype
+      else:
+        dtype = weight.dtype
       mask = self.add_weight(
           'mask',
           shape=weight.shape,
           initializer=keras.initializers.get('ones'),
-          dtype=weight.dtype,
+          dtype=dtype,
           trainable=False,
           aggregation=tf.VariableAggregation.MEAN,
       )
       threshold = self.add_weight(
           'threshold',
           shape=[],
           initializer=keras.initializers.get('zeros'),
-          dtype=weight.dtype,
+          dtype=dtype,
           trainable=False,
           aggregation=tf.VariableAggregation.MEAN,
       )
```
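A standalone sketch of the dtype-selection pattern in the diff above (the helper name `mask_dtype` is illustrative, not part of the library): an autocast variable exposes `true_dtype`, while a plain variable does not, so the code falls back to `weight.dtype`:

```python
import tensorflow as tf

def mask_dtype(weight):
  # Prefer the storage dtype when the weight is autocast under mixed precision;
  # a plain tf.Variable has no `true_dtype` attribute, so use its dtype as-is.
  return weight.true_dtype if hasattr(weight, 'true_dtype') else weight.dtype

v = tf.Variable(1.0)        # a plain float32 variable
print(mask_dtype(v))        # <dtype: 'float32'>
```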

0 commit comments
