Skip to content

Fix dtype and assign* in AutocastVariable. #1136

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 8, 2024
Merged

Conversation

copybara-service[bot]
Copy link

Fix dtype and assign* in AutocastVariable.

The dtype property would return to true dtype of the variable, instead of the dtype of the value that you get explicitly via .value() or implicitly by doing any operation.

This would cause seemingly correct things like this to fail with a dtype mismatch:

y = variable * tf.cast(x, variable.dtype)

Forcing users to write workarounds like:

v = variable.value()
y = variable * tf.cast(x, v.dtype)

Additionally, assign, assign_add, assign_sub expected the value to be of the true dtype, not the cast dtype.

This would cause seemingly correct things like this to fail with a dtype mismatch:

variable.assign(variable * factor)

(This is a common use case for non-trainable variables.)

Forcing users to write workarounds like:

variable.assign(tf.cast(variable * factor, variable.dtype))

This changes fixes these issues to make autocasting fully transparent:

  • dtype returns the cast dtype if applicable
  • assign* accept the cast dtype for the value if applicable

Note that this is consistent with how autocasting works in Keras 3.

@github-actions github-actions bot added the technique:pruning Regarding tfmot.sparsity.keras APIs and docs label Jul 1, 2024
@copybara-service copybara-service bot force-pushed the test_647135376 branch 6 times, most recently from af20f6f to 863e74d Compare July 8, 2024 22:04
The `dtype` property would return to true dtype of the variable, instead of the dtype of the value that you get explicitly via `.value()` or implicitly by doing any operation.

This would cause seemingly correct things like this to fail with a dtype mismatch:
```
y = variable * tf.cast(x, variable.dtype)
```

Forcing users to write workarounds like:
```
v = variable.value()
y = variable * tf.cast(x, v.dtype)
```

Additionally, `assign`, `assign_add`, `assign_sub` expected the value to be of the true dtype, not the cast dtype.

This would cause seemingly correct things like this to fail with a dtype mismatch:
```
variable.assign(variable * factor)
```
(This is a common use case for non-trainable variables.)

Forcing users to write workarounds like:
```
variable.assign(tf.cast(variable * factor, variable.dtype))
```

This changes fixes these issues to make autocasting fully transparent:
- `dtype` returns the cast dtype if applicable
- `assign*` accept the cast dtype for the value if applicable

Note that this is consistent with how autocasting works in Keras 3.

PiperOrigin-RevId: 650386711
@copybara-service copybara-service bot merged commit ed3f017 into master Jul 8, 2024
@copybara-service copybara-service bot deleted the test_647135376 branch July 8, 2024 22:06
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
technique:pruning Regarding tfmot.sparsity.keras APIs and docs
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant