Skip to content

Implement corrcoef function in keras.ops #21372

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jun 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions keras/api/_tf_keras/keras/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@
from keras.src.ops.numpy import conj as conj
from keras.src.ops.numpy import conjugate as conjugate
from keras.src.ops.numpy import copy as copy
from keras.src.ops.numpy import corrcoef as corrcoef
from keras.src.ops.numpy import correlate as correlate
from keras.src.ops.numpy import cos as cos
from keras.src.ops.numpy import cosh as cosh
Expand Down
1 change: 1 addition & 0 deletions keras/api/_tf_keras/keras/ops/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from keras.src.ops.numpy import conj as conj
from keras.src.ops.numpy import conjugate as conjugate
from keras.src.ops.numpy import copy as copy
from keras.src.ops.numpy import corrcoef as corrcoef
from keras.src.ops.numpy import correlate as correlate
from keras.src.ops.numpy import cos as cos
from keras.src.ops.numpy import cosh as cosh
Expand Down
1 change: 1 addition & 0 deletions keras/api/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@
from keras.src.ops.numpy import conj as conj
from keras.src.ops.numpy import conjugate as conjugate
from keras.src.ops.numpy import copy as copy
from keras.src.ops.numpy import corrcoef as corrcoef
from keras.src.ops.numpy import correlate as correlate
from keras.src.ops.numpy import cos as cos
from keras.src.ops.numpy import cosh as cosh
Expand Down
1 change: 1 addition & 0 deletions keras/api/ops/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from keras.src.ops.numpy import conj as conj
from keras.src.ops.numpy import conjugate as conjugate
from keras.src.ops.numpy import copy as copy
from keras.src.ops.numpy import corrcoef as corrcoef
from keras.src.ops.numpy import correlate as correlate
from keras.src.ops.numpy import cos as cos
from keras.src.ops.numpy import cosh as cosh
Expand Down
5 changes: 5 additions & 0 deletions keras/src/backend/jax/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1338,6 +1338,11 @@ def logical_xor(x1, x2):
return jnp.logical_xor(x1, x2)


def corrcoef(x):
x = convert_to_tensor(x)
return jnp.corrcoef(x)


def correlate(x1, x2, mode="valid"):
x1 = convert_to_tensor(x1)
x2 = convert_to_tensor(x2)
Expand Down
13 changes: 13 additions & 0 deletions keras/src/backend/numpy/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1256,6 +1256,19 @@ def logical_xor(x1, x2):
return np.logical_xor(x1, x2)


def corrcoef(x):
if x.dtype in ["int64", "float64"]:
dtype = "float64"
elif x.dtype in ["bfloat16", "float16"]:
dtype = x.dtype
else:
dtype = config.floatx()

x = convert_to_tensor(x)

return np.corrcoef(x).astype(dtype)


def correlate(x1, x2, mode="valid"):
dtype = dtypes.result_type(
getattr(x1, "dtype", type(x1)),
Expand Down
3 changes: 3 additions & 0 deletions keras/src/backend/openvino/excluded_concrete_tests.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ NumpyDtypeTest::test_kaiser
NumpyDtypeTest::test_bitwise
NumpyDtypeTest::test_ceil
NumpyDtypeTest::test_concatenate
NumpyDtypeTest::test_corrcoef
NumpyDtypeTest::test_correlate
NumpyDtypeTest::test_cross
NumpyDtypeTest::test_cumprod
Expand Down Expand Up @@ -81,6 +82,7 @@ NumpyOneInputOpsCorrectnessTest::test_hanning
NumpyOneInputOpsCorrectnessTest::test_kaiser
NumpyOneInputOpsCorrectnessTest::test_bitwise_invert
NumpyOneInputOpsCorrectnessTest::test_conj
NumpyOneInputOpsCorrectnessTest::test_corrcoef
NumpyOneInputOpsCorrectnessTest::test_correlate
NumpyOneInputOpsCorrectnessTest::test_cumprod
NumpyOneInputOpsCorrectnessTest::test_diag
Expand Down Expand Up @@ -151,6 +153,7 @@ NumpyTwoInputOpsCorrectnessTest::test_where
NumpyOneInputOpsDynamicShapeTest::test_angle
NumpyOneInputOpsDynamicShapeTest::test_bartlett
NumpyOneInputOpsDynamicShapeTest::test_blackman
NumpyOneInputOpsDynamicShapeTest::test_corrcoef
NumpyOneInputOpsDynamicShapeTest::test_hamming
NumpyOneInputOpsDynamicShapeTest::test_hanning
NumpyOneInputOpsDynamicShapeTest::test_kaiser
Expand Down
6 changes: 6 additions & 0 deletions keras/src/backend/openvino/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1641,6 +1641,12 @@ def logical_xor(x1, x2):
return OpenVINOKerasTensor(ov_opset.logical_xor(x1, x2).output(0))


def corrcoef(x):
raise NotImplementedError(
"`corrcoef` is not supported with openvino backend"
)


def correlate(x1, x2, mode="valid"):
raise NotImplementedError(
"`correlate` is not supported with openvino backend"
Expand Down
32 changes: 32 additions & 0 deletions keras/src/backend/tensorflow/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2784,6 +2784,38 @@ def logical_xor(x1, x2):
return tf.math.logical_xor(x1, x2)


def corrcoef(x):
dtype = x.dtype
if dtype in ["bool", "int8", "int16", "int32", "uint8", "uint16", "uint32"]:
dtype = config.floatx()
x = convert_to_tensor(x, dtype)

if tf.rank(x) == 0:
return tf.constant(float("nan"), dtype=config.floatx())

mean = tf.reduce_mean(x, axis=-1, keepdims=True)
x_centered = x - mean

num_samples = tf.cast(tf.shape(x)[-1], x.dtype)
cov_matrix = tf.matmul(x_centered, x_centered, adjoint_b=True) / (
num_samples - 1
)

diag = tf.linalg.diag_part(cov_matrix)
stddev = tf.sqrt(tf.math.real(diag))

outer_std = tf.tensordot(stddev, stddev, axes=0)
outer_std = tf.cast(outer_std, cov_matrix.dtype)
correlation = cov_matrix / outer_std

correlation_clipped = tf.clip_by_value(tf.math.real(correlation), -1.0, 1.0)
if correlation.dtype.is_complex:
imag_clipped = tf.clip_by_value(tf.math.imag(correlation), -1.0, 1.0)
return tf.complex(correlation_clipped, imag_clipped)
else:
return correlation_clipped


def correlate(x1, x2, mode="valid"):
x1 = convert_to_tensor(x1)
x2 = convert_to_tensor(x2)
Expand Down
11 changes: 11 additions & 0 deletions keras/src/backend/torch/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1740,6 +1740,17 @@ def logical_xor(x1, x2):
return torch.logical_xor(x1, x2)


def corrcoef(x):
x = convert_to_tensor(x)

if standardize_dtype(x.dtype) == "bool":
x = cast(x, config.floatx())
elif standardize_dtype(x.dtype) == "int64":
x = cast(x, "float64")

return torch.corrcoef(x)


def correlate(x1, x2, mode="valid"):
x1 = convert_to_tensor(x1)
x2 = convert_to_tensor(x2)
Expand Down
29 changes: 29 additions & 0 deletions keras/src/ops/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -6888,6 +6888,35 @@ def logical_xor(x1, x2):
return backend.numpy.logical_xor(x1, x2)


class Corrcoef(Operation):
def call(self, x):
return backend.numpy.corrcoef(x)

def compute_output_spec(self, x):
dtype = backend.standardize_dtype(getattr(x, "dtype", backend.floatx()))
if dtype == "int64":
dtype = "float64"
else:
dtype = dtypes.result_type(dtype, float)
return KerasTensor(x.shape, dtype=dtype)


@keras_export(["keras.ops.corrcoef", "keras.ops.numpy.corrcoef"])
def corrcoef(x):
"""Compute the Pearson correlation coefficient matrix.

Args:
x: A 2D tensor of shape `(N, D)`, where N is the number of variables
and D is the number of observations.

Returns:
A tensor of shape `(N, N)` representing the correlation matrix.
"""
if any_symbolic_tensors((x,)):
return Corrcoef().symbolic_call(x)
return backend.numpy.corrcoef(x)


class Correlate(Operation):
def __init__(self, mode="valid"):
super().__init__()
Expand Down
25 changes: 25 additions & 0 deletions keras/src/ops/numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1313,6 +1313,10 @@ def test_copy(self):
x = KerasTensor((None, 3))
self.assertEqual(knp.copy(x).shape, (None, 3))

def test_corrcoef(self):
x = KerasTensor((3, None))
self.assertEqual(knp.corrcoef(x).shape, (3, None))

def test_cos(self):
x = KerasTensor((None, 3))
self.assertEqual(knp.cos(x).shape, (None, 3))
Expand Down Expand Up @@ -3838,6 +3842,11 @@ def test_copy(self):
self.assertAllClose(knp.copy(x), np.copy(x))
self.assertAllClose(knp.Copy()(x), np.copy(x))

def test_corrcoef(self):
x = np.array([[1, 2, 3], [3, 2, 1]])
self.assertAllClose(knp.corrcoef(x), np.corrcoef(x))
self.assertAllClose(knp.Corrcoef()(x), np.corrcoef(x))

def test_cos(self):
x = np.array([[1, 2, 3], [3, 2, 1]])
self.assertAllClose(knp.cos(x), np.cos(x))
Expand Down Expand Up @@ -6591,6 +6600,22 @@ def test_copy(self, dtype):
expected_dtype,
)

@parameterized.named_parameters(named_product(dtype=ALL_DTYPES))
def test_corrcoef(self, dtype):
import jax.numpy as jnp

x = knp.ones((2, 4), dtype=dtype)
x_jax = jnp.ones((2, 4), dtype=dtype)
expected_dtype = standardize_dtype(jnp.corrcoef(x_jax).dtype)

self.assertEqual(
standardize_dtype(knp.corrcoef(x).dtype), expected_dtype
)
self.assertEqual(
standardize_dtype(knp.Corrcoef().symbolic_call(x).dtype),
expected_dtype,
)

@parameterized.named_parameters(
named_product(dtypes=itertools.combinations(ALL_DTYPES, 2))
)
Expand Down