Skip to content

Re-implementating co_occurrence() #975

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 24 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
2023996
perf implement rust co-occurrence statistics
wenjie1991 Mar 18, 2025
e51f546
misc: change rust-py deps
wenjie1991 Mar 18, 2025
a5b8226
doc: improve the documentation
wenjie1991 Mar 18, 2025
f7ff293
add python re-implementation
MDLDan Mar 19, 2025
9af4252
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 19, 2025
9457767
Clean the tests and dependencies
wenjie1991 Mar 19, 2025
0ec6985
Merge branch 'numba-co-occurrence' of github.com:wenjie1991/squidpy i…
wenjie1991 Mar 19, 2025
92f3da5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 19, 2025
ad674ad
Merge branch 'main' into numba-co-occurrence
wenjie1991 Mar 19, 2025
057decc
Merge branch 'main' into numba-co-occurrence
timtreis Mar 28, 2025
26200d3
Merge branch 'main' into numba-co-occurrence
wenjie1991 Mar 28, 2025
c2a57ac
Merge branch 'scverse:main' into numba-co-occurrence
wenjie1991 Jun 12, 2025
52a5fae
Optimize memory access pattern & cache kernel
wenjie1991 Jun 12, 2025
cee646a
jit the outer function and parallelize
wenjie1991 Jun 12, 2025
05dd724
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 12, 2025
c490d45
Fix Mypy checking Typing error
wenjie1991 Jun 12, 2025
7b72292
Merge branch 'numba-co-occurrence' of github.com:wenjie1991/squidpy i…
wenjie1991 Jun 12, 2025
d0884ca
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 12, 2025
6e5d1df
Disable the cast typing in jit
wenjie1991 Jun 12, 2025
63702d1
merge
wenjie1991 Jun 12, 2025
a62128b
Try fix typing check error by mypy
wenjie1991 Jun 12, 2025
0aab856
Try: fix typing check error by mypy
wenjie1991 Jun 14, 2025
9ac2693
Try: fix typing check error by mypy
wenjie1991 Jun 14, 2025
5303de0
Merge branch 'main' into numba-co-occurrence
wenjie1991 Jun 21, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
228 changes: 113 additions & 115 deletions src/squidpy/gr/_ppatterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@
from __future__ import annotations

from collections.abc import Iterable, Sequence
from importlib.util import find_spec
from itertools import chain
from typing import TYPE_CHECKING, Any, Literal

import numba.types as nt
import numpy as np
import pandas as pd
from anndata import AnnData
from numba import njit
from numba import njit, prange
from numpy.random import default_rng
from scanpy import logging as logg
from scanpy.get import _get_obs_rep
Expand Down Expand Up @@ -266,85 +267,120 @@ def _score_helper(
return score_perms


@njit(
ft[:, :, :](tt(it[:], 2), ft[:, :], it[:], ft[:], bl),
parallel=False,
fastmath=True,
)
@njit(parallel=True, fastmath=True, cache=True)
def _occur_count(
clust: tuple[NDArrayA, NDArrayA],
pw_dist: NDArrayA,
labs_unique: NDArrayA,
interval: NDArrayA,
same_split: bool,
spatial_x: NDArrayA, spatial_y: NDArrayA, thresholds: NDArrayA, label_idx: NDArrayA, n: int, k: int, l_val: int
) -> NDArrayA:
num = labs_unique.shape[0]
out = np.zeros((num, num, interval.shape[0] - 1), dtype=ft)

for idx in range(interval.shape[0] - 1):
co_occur = np.zeros((num, num), dtype=ft)
probs_con = np.zeros((num, num), dtype=ft)

thres_max = interval[idx + 1]
clust_x, clust_y = clust

# Modified to compute co-occurrence probability ratio over increasing radii sizes as opposed to discrete interval bins
# Need pw_dist > 0 to avoid counting a cell with itself as co-occurrence
idx_x, idx_y = np.nonzero((pw_dist <= thres_max) & (pw_dist > 0))
x = clust_x[idx_x]
y = clust_y[idx_y]
# Treat computing co-occurrence using the same split and different splits differently
# Pairwise distance matrix for between the same split is symmetric and therefore only needs to be counted once
for i, j in zip(x, y): # noqa: B905 # cannot use strict=False because of numba
co_occur[i, j] += 1
if not same_split:
co_occur[j, i] += 1

# Prevent divison by zero errors when we have low cell counts/small intervals
probs_matrix = co_occur / np.sum(co_occur) if np.sum(co_occur) != 0 else np.zeros((num, num), dtype=ft)
probs = np.sum(probs_matrix, axis=0)

for c in labs_unique:
probs_conditional = (
co_occur[c] / np.sum(co_occur[c]) if np.sum(co_occur[c]) != 0 else np.zeros(num, dtype=ft)
)
probs_con[c, :] = np.zeros(num, dtype=ft)
for i in range(num):
if probs[i] == 0:
probs_con[c, i] = 0
else:
probs_con[c, i] = probs_conditional[i] / probs[i]
# Allocate a 2D array to store a flat local result per point.
k2 = k * k
local_results = np.zeros((n, l_val * k2), dtype=np.int32)

out[:, :, idx] = probs_con
for i in prange(n):
local_counts: NDArrayA = np.zeros(l_val * k2, dtype=np.int32)

return out
for j in range(n):
if i == j:
continue
dx = spatial_x[i] - spatial_x[j]
dy = spatial_y[i] - spatial_y[j]
d2 = dx * dx + dy * dy

pair = label_idx[i] * k + label_idx[j] # fixed in r–loop
base = pair * l_val # first cell for that pair

def _co_occurrence_helper(
idx_splits: Iterable[tuple[int, int]],
spatial_splits: Sequence[NDArrayA],
labs_splits: Sequence[NDArrayA],
labs_unique: NDArrayA,
interval: NDArrayA,
queue: SigQueue | None = None,
) -> pd.DataFrame:
out_lst = []
for t in idx_splits:
idx_x, idx_y = t
labs_x = labs_splits[idx_x]
labs_y = labs_splits[idx_y]
dist = pairwise_distances(spatial_splits[idx_x], spatial_splits[idx_y])
for r in range(l_val):
if d2 <= thresholds[r]:
local_counts[base + r] += 1

out = _occur_count((labs_x, labs_y), dist, labs_unique, interval, idx_x == idx_y)
out_lst.append(out)
local_results[i] = local_counts

if queue is not None:
queue.put(Signal.UPDATE)
# reduction and reshape stay the same
result_flat = local_results.sum(axis=0)
result: NDArrayA = result_flat.reshape(k, k, l_val).copy()

return result

if queue is not None:
queue.put(Signal.FINISH)

return out_lst
@njit(parallel=True, fastmath=True, cache=True)
def _co_occurrence_helper(v_x: NDArrayA, v_y: NDArrayA, v_radium: NDArrayA, labs: NDArrayA) -> NDArrayA:
"""
Fast co-occurrence probability computation using the new numba-accelerated counting.

Parameters
----------
v_x : np.ndarray, float64
x–coordinates.
v_y : np.ndarray, float64
y–coordinates.
v_radium : np.ndarray, float64
Distance thresholds (in ascending order).
labs : np.ndarray
Cluster labels (as integers).

Returns
-------
occ_prob : np.ndarray
A 3D array of shape (k, k, len(v_radium)-1) containing the co-occurrence probabilities.
labs_unique : np.ndarray
Array of unique labels.
"""
n = len(v_x)
labs_unique = np.unique(labs)
k = len(labs_unique)
# l_val is the number of bins; here we assume the thresholds come from v_radium[1:].
l_val = len(v_radium) - 1
# Compute squared thresholds from the interval (skip the first value)
thresholds = (v_radium[1:]) ** 2

# Compute cco-occurence ounts.
counts = _occur_count(v_x, v_y, thresholds, labs, n, k, l_val)

# Compute co-occurrence probabilities for each threshold bin.
occ_prob = np.empty((k, k, l_val), dtype=np.float64)
for r in prange(l_val):
co_occur = counts[:, :, r].astype(np.float64)

# Compute the total count for this threshold.
total = 0.0
for i in range(k):
for j in range(k):
total += co_occur[i, j]

# Compute the normalized probability matrix.
probs_matrix = np.zeros((k, k), dtype=np.float64)
if total != 0.0:
for i in range(k):
for j in range(k):
probs_matrix[i, j] = co_occur[i, j] / total

probs = np.zeros(k, dtype=np.float32)
for j in range(k):
s = 0.0
for i in range(k):
s += probs_matrix[i, j]
probs[j] = s

# Compute conditional probabilities.
probs_con = np.zeros((k, k), dtype=np.float32)
for c in range(k):
row_sum = 0.0
for j in range(k):
row_sum += co_occur[c, j]
for i in range(k):
cond = 0.0
if row_sum != 0.0:
cond = co_occur[c, i] / row_sum
if probs[i] == 0.0:
probs_con[c, i] = 0.0
else:
probs_con[c, i] = cond / probs[i]

# Transpose to match (k, k, interval).
for i in range(k):
for c in range(k):
occ_prob[i, c, r] = probs_con[c, i]

return occ_prob


@d.dedent
Expand Down Expand Up @@ -387,18 +423,16 @@ def co_occurrence(
- :attr:`anndata.AnnData.uns` ``['{cluster_key}_co_occurrence']['interval']`` - the distance thresholds
computed at ``interval``.
"""

if isinstance(adata, SpatialData):
adata = adata.table
_assert_categorical_obs(adata, key=cluster_key)
_assert_spatial_basis(adata, key=spatial_key)

spatial = adata.obsm[spatial_key].astype(fp)
original_clust = adata.obs[cluster_key]

# annotate cluster idx
clust_map = {v: i for i, v in enumerate(original_clust.cat.categories.values)}
labs = np.array([clust_map[c] for c in original_clust], dtype=ip)
labs_unique = np.array(list(clust_map.values()), dtype=ip)

# create intervals thresholds
if isinstance(interval, int):
Expand All @@ -409,57 +443,21 @@ def co_occurrence(
if len(interval) <= 1:
raise ValueError(f"Expected interval to be of length `>= 2`, found `{len(interval)}`.")

n_obs = spatial.shape[0]
if n_splits is None:
size_arr = (n_obs**2 * spatial.itemsize) / 1024 / 1024 # calc expected mem usage
n_splits = 1
if size_arr > 2000:
while (n_obs / n_splits) > 2048:
n_splits += 1
logg.warning(
f"`n_splits` was automatically set to `{n_splits}` to "
f"prevent `{n_obs}x{n_obs}` distance matrix from being created"
)
n_splits = int(max(min(n_splits, n_obs), 1))

# split array and labels
spatial_splits = tuple(s for s in np.array_split(spatial, n_splits, axis=0) if len(s))
labs_splits = tuple(s for s in np.array_split(labs, n_splits, axis=0) if len(s))
# create idx array including unique combinations and self-comparison
x, y = np.triu_indices_from(np.empty((n_splits, n_splits)))
idx_splits = list(zip(x, y, strict=False))
spatial_x = spatial[:, 0]
spatial_y = spatial[:, 1]

n_jobs = _get_n_cores(n_jobs)
# Compute co-occurrence probabilities using the fast numba routine.
out = _co_occurrence_helper(spatial_x, spatial_y, interval, labs)
start = logg.info(
f"Calculating co-occurrence probabilities for `{len(interval)}` intervals "
f"`{len(idx_splits)}` split combinations using `{n_jobs}` core(s)"
)

out_lst = parallelize(
_co_occurrence_helper,
collection=idx_splits,
extractor=chain.from_iterable,
n_jobs=n_jobs,
backend=backend,
show_progress_bar=show_progress_bar,
)(
spatial_splits=spatial_splits,
labs_splits=labs_splits,
labs_unique=labs_unique,
interval=interval,
f"Calculating co-occurrence probabilities for `{len(interval)}` intervals using `{n_jobs}` core(s) and `{n_splits}` splits"
)
out = list(out_lst)[0] if len(idx_splits) == 1 else sum(list(out_lst)) / len(idx_splits)

if copy:
logg.info("Finish", time=start)
return out, interval

_save_data(
adata,
attr="uns",
key=Key.uns.co_occurrence(cluster_key),
data={"occ": out, "interval": interval},
time=start,
adata, attr="uns", key=Key.uns.co_occurrence(cluster_key), data={"occ": out, "interval": interval}, time=start
)


Expand Down
Loading