Description
Describe the problem.
The API documentation of tf.keras.losses.SparseCategoricalCrossentropy says that the reduction parameter can be None, but the implementation does not accept the Python value None. It only accepts the string 'none'.
Describe the current behavior.
Constructing the loss with reduction=None raises a ValueError when I run my test cases, which I don't think it should.
(redirected here from tensorflow/tensorflow#89246).
Describe the expected behavior.
Passing None instead of 'none' should not raise a ValueError, as the API documentation lists None as a valid value.
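The expected behavior amounts to treating None as an alias for 'none'. Below is a minimal sketch of such a check, assuming the accepted keys are the four listed in the ValueError from the traceback. normalize_reduction and VALID_REDUCTIONS are hypothetical names used for illustration, not Keras APIs.

```python
from typing import Optional

# Reduction keys listed in the ValueError message from this report (assumption:
# these are the only keys the installed Keras version accepts).
VALID_REDUCTIONS = ("auto", "none", "sum", "sum_over_batch_size")


def normalize_reduction(reduction: Optional[str]) -> str:
    """Hypothetical helper: map Python None to the string 'none', then validate.

    This mirrors the behavior the report expects; it is not part of Keras.
    """
    if reduction is None:
        reduction = "none"
    if reduction not in VALID_REDUCTIONS:
        raise ValueError(
            f"Invalid Reduction Key: {reduction!r}. "
            f"Expected keys are {VALID_REDUCTIONS}"
        )
    return reduction


print(normalize_reduction(None))   # none
print(normalize_reduction("sum"))  # sum
```

As a workaround on the version shown in the traceback, passing the string reduction='none' directly, which appears among the expected keys, should avoid the error.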
- Do you want to contribute a PR? (yes/no): no
- If yes, please read this page for instructions
- Briefly describe your candidate solution (if contributing):
Standalone code to reproduce the issue.
import unittest
import numpy as np
import tensorflow as tf

class SparseCategoricalCrossentropyTest(unittest.TestCase):
    def test_reduction_none(self):
        # Test with reduction set to None (documented as a valid value)
        y_true = np.array([0, 1, 2])
        y_pred = np.array([[0.9, 0.05, 0.05],
                           [0.05, 0.9, 0.05],
                           [0.05, 0.05, 0.9]])
        loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(reduction=None)
        loss = loss_fn(y_true, y_pred).numpy()
        expected_loss = -np.log([0.9, 0.9, 0.9])
        np.testing.assert_almost_equal(loss, expected_loss, decimal=5)

if __name__ == "__main__":
    unittest.main()
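For reference, expected_loss in the test is what an unreduced ('none') loss should return: the negative log of the probability assigned to each sample's true class. A NumPy-only sketch of that computation, independent of Keras, assuming the rows of y_pred are already valid probabilities and ignoring any clipping Keras may apply internally:

```python
import numpy as np

# Same inputs as the test: integer class labels and per-class probabilities.
y_true = np.array([0, 1, 2])
y_pred = np.array([[0.9, 0.05, 0.05],
                   [0.05, 0.9, 0.05],
                   [0.05, 0.05, 0.9]])

# Unreduced sparse categorical crossentropy: for each sample i, pick the
# predicted probability of its true class y_true[i] and take -log of it.
per_sample = -np.log(y_pred[np.arange(len(y_true)), y_true])
print(per_sample)  # each entry equals -log(0.9), roughly 0.10536
```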
Source code / logs.
Traceback (most recent call last):
  File "/home/user/projects/api_guided_testgen/out/bug_detect_gpt4o/exec/basic_rag_apidoc/tf/tf.keras.losses.SparseCategoricalCrossentropy.py", line 49, in test_reduction_none
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(reduction=None)
  File "/home/user/anaconda3/lib/python3.8/site-packages/keras/losses.py", line 1026, in __init__
    super().__init__(
  File "/home/user/anaconda3/lib/python3.8/site-packages/keras/losses.py", line 262, in __init__
    super().__init__(reduction=reduction, name=name)
  File "/home/user/anaconda3/lib/python3.8/site-packages/keras/losses.py", line 93, in __init__
    losses_utils.ReductionV2.validate(reduction)
  File "/home/user/anaconda3/lib/python3.8/site-packages/keras/utils/losses_utils.py", line 88, in validate
    raise ValueError(
ValueError: Invalid Reduction Key: None. Expected keys are "('auto', 'none', 'sum', 'sum_over_batch_size')"