From 97d37020ec810cd389a99ed5a0cca787c4a1fd63 Mon Sep 17 00:00:00 2001 From: ugeunpark Date: Wed, 28 May 2025 21:01:38 +0900 Subject: [PATCH 1/7] Add corrcoef for ops --- keras/src/backend/jax/numpy.py | 5 +++++ keras/src/backend/numpy/numpy.py | 5 +++++ keras/src/backend/tensorflow/numpy.py | 28 +++++++++++++++++++++++++++ keras/src/backend/torch/numpy.py | 5 +++++ 4 files changed, 43 insertions(+) diff --git a/keras/src/backend/jax/numpy.py b/keras/src/backend/jax/numpy.py index e2376a15462d..d49673d74c4a 100644 --- a/keras/src/backend/jax/numpy.py +++ b/keras/src/backend/jax/numpy.py @@ -1338,6 +1338,11 @@ def logical_xor(x1, x2): return jnp.logical_xor(x1, x2) +def corrcoef(x): + x = convert_to_tensor(x) + return jnp.corrcoef(x) + + def correlate(x1, x2, mode="valid"): x1 = convert_to_tensor(x1) x2 = convert_to_tensor(x2) diff --git a/keras/src/backend/numpy/numpy.py b/keras/src/backend/numpy/numpy.py index aa4a686050f5..ae7caae5a40b 100644 --- a/keras/src/backend/numpy/numpy.py +++ b/keras/src/backend/numpy/numpy.py @@ -1256,6 +1256,11 @@ def logical_xor(x1, x2): return np.logical_xor(x1, x2) +def corrcoef(x): + x = convert_to_tensor(x) + return np.corrcoef(x) + + def correlate(x1, x2, mode="valid"): dtype = dtypes.result_type( getattr(x1, "dtype", type(x1)), diff --git a/keras/src/backend/tensorflow/numpy.py b/keras/src/backend/tensorflow/numpy.py index 524ba1b499ec..df3a378fd77d 100644 --- a/keras/src/backend/tensorflow/numpy.py +++ b/keras/src/backend/tensorflow/numpy.py @@ -2784,6 +2784,34 @@ def logical_xor(x1, x2): return tf.math.logical_xor(x1, x2) +def corrcoef(x): + x = convert_to_tensor(x, dtype=config.floatx()) + + if tf.rank(x) == 0: + return tf.constant(float("nan"), dtype=config.floatx()) + + mean = tf.reduce_mean(x, axis=1, keepdims=True) + x_centered = x - mean + + cov_matrix = tf.matmul(x_centered, x_centered, transpose_b=True) + num_samples = tf.cast(tf.shape(x)[1], x.dtype) + cov_matrix /= num_samples - 1 + + diag = tf.linalg.diag_part(cov_matrix) + stddev = tf.sqrt(tf.math.real(diag)) + + outer_std = tf.tensordot(stddev, stddev, axes=0) + + correlation = cov_matrix / outer_std + correlation_clipped = tf.clip_by_value(tf.math.real(correlation), -1.0, 1.0) + + if tf.experimental.numpy.iscomplexobj(correlation): + imag_clipped = tf.clip_by_value(tf.math.imag(correlation), -1.0, 1.0) + return tf.complex(correlation_clipped, imag_clipped) + else: + return correlation_clipped + + def correlate(x1, x2, mode="valid"): x1 = convert_to_tensor(x1) x2 = convert_to_tensor(x2) diff --git a/keras/src/backend/torch/numpy.py b/keras/src/backend/torch/numpy.py index e92a6c9ee759..1020f92a9a99 100644 --- a/keras/src/backend/torch/numpy.py +++ b/keras/src/backend/torch/numpy.py @@ -1740,6 +1740,11 @@ def logical_xor(x1, x2): return torch.logical_xor(x1, x2) +def corrcoef(x): + x = convert_to_tensor(x) + return torch.corrcoef(x) + + def correlate(x1, x2, mode="valid"): x1 = convert_to_tensor(x1) x2 = convert_to_tensor(x2) From d4933dd15cb26ffdbeef4082fd58b1f8bca91c0e Mon Sep 17 00:00:00 2001 From: ugeunpark Date: Mon, 9 Jun 2025 20:39:04 +0900 Subject: [PATCH 2/7] Update method for complex case --- keras/src/backend/tensorflow/numpy.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/keras/src/backend/tensorflow/numpy.py b/keras/src/backend/tensorflow/numpy.py index df3a378fd77d..27ae82b1ba82 100644 --- a/keras/src/backend/tensorflow/numpy.py +++ b/keras/src/backend/tensorflow/numpy.py @@ -2785,7 +2785,7 @@ def logical_xor(x1, x2): def corrcoef(x): - x = convert_to_tensor(x, dtype=config.floatx()) + x = convert_to_tensor(x) if tf.rank(x) == 0: return tf.constant(float("nan"), dtype=config.floatx()) @@ -2793,19 +2793,20 @@ def corrcoef(x): mean = tf.reduce_mean(x, axis=1, keepdims=True) x_centered = x - mean - cov_matrix = tf.matmul(x_centered, x_centered, transpose_b=True) num_samples = tf.cast(tf.shape(x)[1], x.dtype) - cov_matrix /= num_samples - 1 + cov_matrix = tf.matmul(x_centered, x_centered, adjoint_b=True) / ( + num_samples - 1 + ) diag = tf.linalg.diag_part(cov_matrix) stddev = tf.sqrt(tf.math.real(diag)) outer_std = tf.tensordot(stddev, stddev, axes=0) - + outer_std = tf.cast(outer_std, cov_matrix.dtype) correlation = cov_matrix / outer_std - correlation_clipped = tf.clip_by_value(tf.math.real(correlation), -1.0, 1.0) - if tf.experimental.numpy.iscomplexobj(correlation): + correlation_clipped = tf.clip_by_value(tf.math.real(correlation), -1.0, 1.0) + if correlation.dtype.is_complex: imag_clipped = tf.clip_by_value(tf.math.imag(correlation), -1.0, 1.0) return tf.complex(correlation_clipped, imag_clipped) else: From 4266fc84d4b4fb335920267e8206f34b563b5fc6 Mon Sep 17 00:00:00 2001 From: ugeunpark Date: Mon, 9 Jun 2025 21:29:05 +0900 Subject: [PATCH 3/7] Add init.py for corrcoef --- keras/api/_tf_keras/keras/ops/__init__.py | 1 + .../api/_tf_keras/keras/ops/numpy/__init__.py | 1 + keras/api/ops/__init__.py | 1 + keras/api/ops/numpy/__init__.py | 1 + keras/src/ops/numpy.py | 29 +++++++++++++++++++ keras/src/ops/numpy_test.py | 20 +++++++++++++ 6 files changed, 53 insertions(+) diff --git a/keras/api/_tf_keras/keras/ops/__init__.py b/keras/api/_tf_keras/keras/ops/__init__.py index c5770a93c49d..6cc134044233 100644 --- a/keras/api/_tf_keras/keras/ops/__init__.py +++ b/keras/api/_tf_keras/keras/ops/__init__.py @@ -155,6 +155,7 @@ from keras.src.ops.numpy import conj as conj from keras.src.ops.numpy import conjugate as conjugate from keras.src.ops.numpy import copy as copy +from keras.src.ops.numpy import corrcoef as corrcoef from keras.src.ops.numpy import correlate as correlate from keras.src.ops.numpy import cos as cos from keras.src.ops.numpy import cosh as cosh diff --git a/keras/api/_tf_keras/keras/ops/numpy/__init__.py b/keras/api/_tf_keras/keras/ops/numpy/__init__.py index 966613cb28f7..ab4252afd016 100644 --- a/keras/api/_tf_keras/keras/ops/numpy/__init__.py +++ b/keras/api/_tf_keras/keras/ops/numpy/__init__.py @@ -44,6 +44,7 @@ from keras.src.ops.numpy import conj as conj from keras.src.ops.numpy import conjugate as conjugate from keras.src.ops.numpy import copy as copy +from keras.src.ops.numpy import corrcoef as corrcoef from keras.src.ops.numpy import correlate as correlate from keras.src.ops.numpy import cos as cos from keras.src.ops.numpy import cosh as cosh diff --git a/keras/api/ops/__init__.py b/keras/api/ops/__init__.py index c5770a93c49d..6cc134044233 100644 --- a/keras/api/ops/__init__.py +++ b/keras/api/ops/__init__.py @@ -155,6 +155,7 @@ from keras.src.ops.numpy import conj as conj from keras.src.ops.numpy import conjugate as conjugate from keras.src.ops.numpy import copy as copy +from keras.src.ops.numpy import corrcoef as corrcoef from keras.src.ops.numpy import correlate as correlate from keras.src.ops.numpy import cos as cos from keras.src.ops.numpy import cosh as cosh diff --git a/keras/api/ops/numpy/__init__.py b/keras/api/ops/numpy/__init__.py index 966613cb28f7..ab4252afd016 100644 --- a/keras/api/ops/numpy/__init__.py +++ b/keras/api/ops/numpy/__init__.py @@ -44,6 +44,7 @@ from keras.src.ops.numpy import conj as conj from keras.src.ops.numpy import conjugate as conjugate from keras.src.ops.numpy import copy as copy +from keras.src.ops.numpy import corrcoef as corrcoef from keras.src.ops.numpy import correlate as correlate from keras.src.ops.numpy import cos as cos from keras.src.ops.numpy import cosh as cosh diff --git a/keras/src/ops/numpy.py b/keras/src/ops/numpy.py index 1b3a8a64f3b0..5aae8de205ab 100644 --- a/keras/src/ops/numpy.py +++ b/keras/src/ops/numpy.py @@ -6888,6 +6888,35 @@ def logical_xor(x1, x2): return backend.numpy.logical_xor(x1, x2) +class Corrcoef(Operation): + def call(self, x): + return backend.numpy.corrcoef(x) + + def compute_output_spec(self, x): + dtype = backend.standardize_dtype(getattr(x, "dtype", backend.floatx())) + if dtype == "int64": + dtype = "float64" + else: + dtype = dtypes.result_type(dtype, float) + return KerasTensor(x.shape, dtype=dtype) + + +@keras_export(["keras.ops.corrcoef", "keras.ops.numpy.corrcoef"]) +def corrcoef(x): + """Compute the Pearson correlation coefficient matrix. + + Args: + x: A 2D tensor of shape (N, D), where N is the number of variables + and D is the number of observations. + + Returns: + A tensor of shape (N, N) representing the correlation matrix. + """ + if any_symbolic_tensors((x,)): + return Corrcoef().symbolic_call(x) + return backend.numpy.corrcoef(x) + + class Correlate(Operation): def __init__(self, mode="valid"): super().__init__() diff --git a/keras/src/ops/numpy_test.py b/keras/src/ops/numpy_test.py index 769dcdb4be31..2ee129c91f47 100644 --- a/keras/src/ops/numpy_test.py +++ b/keras/src/ops/numpy_test.py @@ -6591,6 +6591,26 @@ def test_copy(self, dtype): expected_dtype, ) + @parameterized.named_parameters( + named_product(dtypes=itertools.combinations(ALL_DTYPES, 2)) + ) + def test_corrcoef(self, dtypes): + import jax.numpy as jnp + + dtype1, dtype2 = dtypes + x1 = knp.ones((3,), dtype=dtype1) + x1_jax = jnp.ones((3,), dtype=dtype1) + expected_dtype = standardize_dtype(jnp.corrcoef(x1_jax).dtype) + + self.assertEqual( + standardize_dtype(knp.corrcoef(x1).dtype), expected_dtype + ) + + self.assertEqual( + standardize_dtype(knp.Corrcoef().symbolic_call(x1).dtype), + expected_dtype, + ) + @parameterized.named_parameters( named_product(dtypes=itertools.combinations(ALL_DTYPES, 2)) ) From 2a764b3b1836f7b35dfc3b4f104ffc39c716331d Mon Sep 17 00:00:00 2001 From: ugeunpark Date: Tue, 10 Jun 2025 21:02:47 +0900 Subject: [PATCH 4/7] Update code for test case --- keras/src/backend/numpy/numpy.py | 10 +++++++++- keras/src/backend/tensorflow/numpy.py | 5 ++++- keras/src/backend/torch/numpy.py | 6 ++++++ keras/src/ops/numpy_test.py | 27 ++++++++++++++++----------- 4 files changed, 35 insertions(+), 13 deletions(-) diff --git a/keras/src/backend/numpy/numpy.py b/keras/src/backend/numpy/numpy.py index ae7caae5a40b..04514fbc97c3 100644 --- a/keras/src/backend/numpy/numpy.py +++ b/keras/src/backend/numpy/numpy.py @@ -1257,8 +1257,16 @@ def logical_xor(x1, x2): def corrcoef(x): + if x.dtype in ["int64", "float64"]: + dtype = "float64" + elif x.dtype in ["bfloat16", "float16"]: + dtype = x.dtype + else: + dtype = config.floatx() + x = convert_to_tensor(x) - return np.corrcoef(x) + + return np.corrcoef(x).astype(dtype) def correlate(x1, x2, mode="valid"): diff --git a/keras/src/backend/tensorflow/numpy.py b/keras/src/backend/tensorflow/numpy.py index 27ae82b1ba82..77eff57c7bb2 100644 --- a/keras/src/backend/tensorflow/numpy.py +++ b/keras/src/backend/tensorflow/numpy.py @@ -2785,7 +2785,10 @@ def logical_xor(x1, x2): def corrcoef(x): - x = convert_to_tensor(x) + dtype = x.dtype + if dtype in ["bool", "int8", "int16", "int32", "uint8", "uint16", "uint32"]: + dtype = config.floatx() + x = convert_to_tensor(x, dtype) if tf.rank(x) == 0: return tf.constant(float("nan"), dtype=config.floatx()) diff --git a/keras/src/backend/torch/numpy.py b/keras/src/backend/torch/numpy.py index 1020f92a9a99..0f65d1b2a58a 100644 --- a/keras/src/backend/torch/numpy.py +++ b/keras/src/backend/torch/numpy.py @@ -1742,6 +1742,12 @@ def logical_xor(x1, x2): def corrcoef(x): x = convert_to_tensor(x) + + if standardize_dtype(x.dtype) == "bool": + x = cast(x, config.floatx()) + elif standardize_dtype(x.dtype) == "int64": + x = cast(x, "float64") + return torch.corrcoef(x) diff --git a/keras/src/ops/numpy_test.py b/keras/src/ops/numpy_test.py index 2ee129c91f47..569620d6daf5 100644 --- a/keras/src/ops/numpy_test.py +++ b/keras/src/ops/numpy_test.py @@ -1313,6 +1313,10 @@ def test_copy(self): x = KerasTensor((None, 3)) self.assertEqual(knp.copy(x).shape, (None, 3)) + def test_corrcoef(self): + x = KerasTensor((3, None)) + self.assertEqual(knp.corrcoef(x).shape, (3, None)) + def test_cos(self): x = KerasTensor((None, 3)) self.assertEqual(knp.cos(x).shape, (None, 3)) @@ -3838,6 +3842,11 @@ def test_copy(self): self.assertAllClose(knp.copy(x), np.copy(x)) self.assertAllClose(knp.Copy()(x), np.copy(x)) + def test_corrcoef(self): + x = np.array([[1, 2, 3], [3, 2, 1]]) + self.assertAllClose(knp.corrcoef(x), np.corrcoef(x)) + self.assertAllClose(knp.Corrcoef()(x), np.corrcoef(x)) + def test_cos(self): x = np.array([[1, 2, 3], [3, 2, 1]]) self.assertAllClose(knp.cos(x), np.cos(x)) @@ -6591,23 +6600,19 @@ def test_copy(self, dtype): expected_dtype, ) - @parameterized.named_parameters( - named_product(dtypes=itertools.combinations(ALL_DTYPES, 2)) - ) - def test_corrcoef(self, dtypes): + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_corrcoef(self, dtype): import jax.numpy as jnp - dtype1, dtype2 = dtypes - x1 = knp.ones((3,), dtype=dtype1) - x1_jax = jnp.ones((3,), dtype=dtype1) - expected_dtype = standardize_dtype(jnp.corrcoef(x1_jax).dtype) + x = knp.ones((2, 4), dtype=dtype) + x_jax = jnp.ones((2, 4), dtype=dtype) + expected_dtype = standardize_dtype(jnp.corrcoef(x_jax).dtype) self.assertEqual( - standardize_dtype(knp.corrcoef(x1).dtype), expected_dtype + standardize_dtype(knp.corrcoef(x).dtype), expected_dtype ) - self.assertEqual( - standardize_dtype(knp.Corrcoef().symbolic_call(x1).dtype), + standardize_dtype(knp.Corrcoef().symbolic_call(x).dtype), expected_dtype, ) From 16ebfd9a1d1a5408425dde10301c4f1cacb7cb4c Mon Sep 17 00:00:00 2001 From: ugeunpark Date: Tue, 10 Jun 2025 21:07:13 +0900 Subject: [PATCH 5/7] Update excluded_concrete_tests.txt for openvino --- keras/src/backend/openvino/excluded_concrete_tests.txt | 3 +++ keras/src/backend/openvino/numpy.py | 6 ++++++ 2 files changed, 9 insertions(+) diff --git a/keras/src/backend/openvino/excluded_concrete_tests.txt b/keras/src/backend/openvino/excluded_concrete_tests.txt index b76e17ac96cc..94a63d634a49 100644 --- a/keras/src/backend/openvino/excluded_concrete_tests.txt +++ b/keras/src/backend/openvino/excluded_concrete_tests.txt @@ -16,6 +16,7 @@ NumpyDtypeTest::test_kaiser NumpyDtypeTest::test_bitwise NumpyDtypeTest::test_ceil NumpyDtypeTest::test_concatenate +NumpyDtypeTest::test_corrcoef NumpyDtypeTest::test_correlate NumpyDtypeTest::test_cross NumpyDtypeTest::test_cumprod @@ -81,6 +82,7 @@ NumpyOneInputOpsCorrectnessTest::test_hanning NumpyOneInputOpsCorrectnessTest::test_kaiser NumpyOneInputOpsCorrectnessTest::test_bitwise_invert NumpyOneInputOpsCorrectnessTest::test_conj +NumpyOneInputOpsCorrectnessTest::test_corrcoef NumpyOneInputOpsCorrectnessTest::test_correlate NumpyOneInputOpsCorrectnessTest::test_cumprod NumpyOneInputOpsCorrectnessTest::test_diag @@ -151,6 +153,7 @@ NumpyTwoInputOpsCorrectnessTest::test_where NumpyOneInputOpsDynamicShapeTest::test_angle NumpyOneInputOpsDynamicShapeTest::test_bartlett NumpyOneInputOpsDynamicShapeTest::test_blackman +NumpyOneInputOpsDynamicShapeTest::test_corrcoef NumpyOneInputOpsDynamicShapeTest::test_hamming NumpyOneInputOpsDynamicShapeTest::test_hanning NumpyOneInputOpsDynamicShapeTest::test_kaiser diff --git a/keras/src/backend/openvino/numpy.py b/keras/src/backend/openvino/numpy.py index 516d6a0f1bd2..a56ef5f5d8e5 100644 --- a/keras/src/backend/openvino/numpy.py +++ b/keras/src/backend/openvino/numpy.py @@ -1641,6 +1641,12 @@ def logical_xor(x1, x2): return OpenVINOKerasTensor(ov_opset.logical_xor(x1, x2).output(0)) +def corrcoef(x): + raise NotImplementedError( + "`corrcoef` is not supported with openvino backend" + ) + + def correlate(x1, x2, mode="valid"): raise NotImplementedError( "`correlate` is not supported with openvino backend" From 9679f7adb71878b2283848dcf3c6611938e4f8c9 Mon Sep 17 00:00:00 2001 From: ugeunpark Date: Tue, 10 Jun 2025 21:10:16 +0900 Subject: [PATCH 6/7] Update axis for corrcoef on tf --- keras/src/backend/tensorflow/numpy.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/keras/src/backend/tensorflow/numpy.py b/keras/src/backend/tensorflow/numpy.py index 77eff57c7bb2..1af52569150a 100644 --- a/keras/src/backend/tensorflow/numpy.py +++ b/keras/src/backend/tensorflow/numpy.py @@ -2793,10 +2793,10 @@ def corrcoef(x): if tf.rank(x) == 0: return tf.constant(float("nan"), dtype=config.floatx()) - mean = tf.reduce_mean(x, axis=1, keepdims=True) + mean = tf.reduce_mean(x, axis=-1, keepdims=True) x_centered = x - mean - num_samples = tf.cast(tf.shape(x)[1], x.dtype) + num_samples = tf.cast(tf.shape(x)[-1], x.dtype) cov_matrix = tf.matmul(x_centered, x_centered, adjoint_b=True) / ( num_samples - 1 ) From 425d676ea551984d3a97080efe15ef2c95e40c9b Mon Sep 17 00:00:00 2001 From: ugeunpark Date: Fri, 13 Jun 2025 19:37:29 +0900 Subject: [PATCH 7/7] update docstrings --- keras/src/ops/numpy.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/keras/src/ops/numpy.py b/keras/src/ops/numpy.py index 5aae8de205ab..b35b86bb0a9f 100644 --- a/keras/src/ops/numpy.py +++ b/keras/src/ops/numpy.py @@ -6906,11 +6906,11 @@ def corrcoef(x): """Compute the Pearson correlation coefficient matrix. Args: - x: A 2D tensor of shape (N, D), where N is the number of variables + x: A 2D tensor of shape `(N, D)`, where N is the number of variables and D is the number of observations. Returns: - A tensor of shape (N, N) representing the correlation matrix. + A tensor of shape `(N, N)` representing the correlation matrix. """ if any_symbolic_tensors((x,)): return Corrcoef().symbolic_call(x)