diff --git a/src/squidpy/gr/_ppatterns.py b/src/squidpy/gr/_ppatterns.py index 187f1156..6e3230be 100644 --- a/src/squidpy/gr/_ppatterns.py +++ b/src/squidpy/gr/_ppatterns.py @@ -3,6 +3,7 @@ from __future__ import annotations from collections.abc import Iterable, Sequence +from importlib.util import find_spec from itertools import chain from typing import TYPE_CHECKING, Any, Literal @@ -10,7 +11,7 @@ import numpy as np import pandas as pd from anndata import AnnData -from numba import njit +from numba import njit, prange from numpy.random import default_rng from scanpy import logging as logg from scanpy.get import _get_obs_rep @@ -266,85 +267,120 @@ def _score_helper( return score_perms -@njit( - ft[:, :, :](tt(it[:], 2), ft[:, :], it[:], ft[:], bl), - parallel=False, - fastmath=True, -) +@njit(parallel=True, fastmath=True, cache=True) def _occur_count( - clust: tuple[NDArrayA, NDArrayA], - pw_dist: NDArrayA, - labs_unique: NDArrayA, - interval: NDArrayA, - same_split: bool, + spatial_x: NDArrayA, spatial_y: NDArrayA, thresholds: NDArrayA, label_idx: NDArrayA, n: int, k: int, l_val: int ) -> NDArrayA: - num = labs_unique.shape[0] - out = np.zeros((num, num, interval.shape[0] - 1), dtype=ft) - - for idx in range(interval.shape[0] - 1): - co_occur = np.zeros((num, num), dtype=ft) - probs_con = np.zeros((num, num), dtype=ft) - - thres_max = interval[idx + 1] - clust_x, clust_y = clust - - # Modified to compute co-occurrence probability ratio over increasing radii sizes as opposed to discrete interval bins - # Need pw_dist > 0 to avoid counting a cell with itself as co-occurrence - idx_x, idx_y = np.nonzero((pw_dist <= thres_max) & (pw_dist > 0)) - x = clust_x[idx_x] - y = clust_y[idx_y] - # Treat computing co-occurrence using the same split and different splits differently - # Pairwise distance matrix for between the same split is symmetric and therefore only needs to be counted once - for i, j in zip(x, y): # noqa: B905 # cannot use strict=False because of numba - co_occur[i, j] += 1 - if not same_split: - co_occur[j, i] += 1 - - # Prevent divison by zero errors when we have low cell counts/small intervals - probs_matrix = co_occur / np.sum(co_occur) if np.sum(co_occur) != 0 else np.zeros((num, num), dtype=ft) - probs = np.sum(probs_matrix, axis=0) - - for c in labs_unique: - probs_conditional = ( - co_occur[c] / np.sum(co_occur[c]) if np.sum(co_occur[c]) != 0 else np.zeros(num, dtype=ft) - ) - probs_con[c, :] = np.zeros(num, dtype=ft) - for i in range(num): - if probs[i] == 0: - probs_con[c, i] = 0 - else: - probs_con[c, i] = probs_conditional[i] / probs[i] + # Allocate a 2D array to store a flat local result per point. + k2 = k * k + local_results = np.zeros((n, l_val * k2), dtype=np.int32) - out[:, :, idx] = probs_con + for i in prange(n): + local_counts: NDArrayA = np.zeros(l_val * k2, dtype=np.int32) - return out + for j in range(n): + if i == j: + continue + dx = spatial_x[i] - spatial_x[j] + dy = spatial_y[i] - spatial_y[j] + d2 = dx * dx + dy * dy + pair = label_idx[i] * k + label_idx[j] # fixed in r–loop + base = pair * l_val # first cell for that pair -def _co_occurrence_helper( - idx_splits: Iterable[tuple[int, int]], - spatial_splits: Sequence[NDArrayA], - labs_splits: Sequence[NDArrayA], - labs_unique: NDArrayA, - interval: NDArrayA, - queue: SigQueue | None = None, -) -> pd.DataFrame: - out_lst = [] - for t in idx_splits: - idx_x, idx_y = t - labs_x = labs_splits[idx_x] - labs_y = labs_splits[idx_y] - dist = pairwise_distances(spatial_splits[idx_x], spatial_splits[idx_y]) + for r in range(l_val): + if d2 <= thresholds[r]: + local_counts[base + r] += 1 - out = _occur_count((labs_x, labs_y), dist, labs_unique, interval, idx_x == idx_y) - out_lst.append(out) + local_results[i] = local_counts - if queue is not None: - queue.put(Signal.UPDATE) + # reduction and reshape stay the same + result_flat = local_results.sum(axis=0) + result: NDArrayA = result_flat.reshape(k, k, l_val).copy() + + return result - if queue is not None: - queue.put(Signal.FINISH) - return out_lst +@njit(parallel=True, fastmath=True, cache=True) +def _co_occurrence_helper(v_x: NDArrayA, v_y: NDArrayA, v_radium: NDArrayA, labs: NDArrayA) -> NDArrayA: + """ + Fast co-occurrence probability computation using the new numba-accelerated counting. + + Parameters + ---------- + v_x : np.ndarray, float64 + x–coordinates. + v_y : np.ndarray, float64 + y–coordinates. + v_radium : np.ndarray, float64 + Distance thresholds (in ascending order). + labs : np.ndarray + Cluster labels (as integers). + + Returns + ------- + occ_prob : np.ndarray + A 3D array of shape (k, k, len(v_radium)-1) containing the co-occurrence probabilities. + labs_unique : np.ndarray + Array of unique labels. + """ + n = len(v_x) + labs_unique = np.unique(labs) + k = len(labs_unique) + # l_val is the number of bins; here we assume the thresholds come from v_radium[1:]. + l_val = len(v_radium) - 1 + # Compute squared thresholds from the interval (skip the first value) + thresholds = (v_radium[1:]) ** 2 + + # Compute cco-occurence ounts. + counts = _occur_count(v_x, v_y, thresholds, labs, n, k, l_val) + + # Compute co-occurrence probabilities for each threshold bin. + occ_prob = np.empty((k, k, l_val), dtype=np.float64) + for r in prange(l_val): + co_occur = counts[:, :, r].astype(np.float64) + + # Compute the total count for this threshold. + total = 0.0 + for i in range(k): + for j in range(k): + total += co_occur[i, j] + + # Compute the normalized probability matrix. + probs_matrix = np.zeros((k, k), dtype=np.float64) + if total != 0.0: + for i in range(k): + for j in range(k): + probs_matrix[i, j] = co_occur[i, j] / total + + probs = np.zeros(k, dtype=np.float32) + for j in range(k): + s = 0.0 + for i in range(k): + s += probs_matrix[i, j] + probs[j] = s + + # Compute conditional probabilities. + probs_con = np.zeros((k, k), dtype=np.float32) + for c in range(k): + row_sum = 0.0 + for j in range(k): + row_sum += co_occur[c, j] + for i in range(k): + cond = 0.0 + if row_sum != 0.0: + cond = co_occur[c, i] / row_sum + if probs[i] == 0.0: + probs_con[c, i] = 0.0 + else: + probs_con[c, i] = cond / probs[i] + + # Transpose to match (k, k, interval). + for i in range(k): + for c in range(k): + occ_prob[i, c, r] = probs_con[c, i] + + return occ_prob @d.dedent @@ -387,6 +423,7 @@ def co_occurrence( - :attr:`anndata.AnnData.uns` ``['{cluster_key}_co_occurrence']['interval']`` - the distance thresholds computed at ``interval``. """ + if isinstance(adata, SpatialData): adata = adata.table _assert_categorical_obs(adata, key=cluster_key) @@ -394,11 +431,8 @@ def co_occurrence( spatial = adata.obsm[spatial_key].astype(fp) original_clust = adata.obs[cluster_key] - - # annotate cluster idx clust_map = {v: i for i, v in enumerate(original_clust.cat.categories.values)} labs = np.array([clust_map[c] for c in original_clust], dtype=ip) - labs_unique = np.array(list(clust_map.values()), dtype=ip) # create intervals thresholds if isinstance(interval, int): @@ -409,57 +443,21 @@ def co_occurrence( if len(interval) <= 1: raise ValueError(f"Expected interval to be of length `>= 2`, found `{len(interval)}`.") - n_obs = spatial.shape[0] - if n_splits is None: - size_arr = (n_obs**2 * spatial.itemsize) / 1024 / 1024 # calc expected mem usage - n_splits = 1 - if size_arr > 2000: - while (n_obs / n_splits) > 2048: - n_splits += 1 - logg.warning( - f"`n_splits` was automatically set to `{n_splits}` to " - f"prevent `{n_obs}x{n_obs}` distance matrix from being created" - ) - n_splits = int(max(min(n_splits, n_obs), 1)) - - # split array and labels - spatial_splits = tuple(s for s in np.array_split(spatial, n_splits, axis=0) if len(s)) - labs_splits = tuple(s for s in np.array_split(labs, n_splits, axis=0) if len(s)) - # create idx array including unique combinations and self-comparison - x, y = np.triu_indices_from(np.empty((n_splits, n_splits))) - idx_splits = list(zip(x, y, strict=False)) + spatial_x = spatial[:, 0] + spatial_y = spatial[:, 1] - n_jobs = _get_n_cores(n_jobs) + # Compute co-occurrence probabilities using the fast numba routine. + out = _co_occurrence_helper(spatial_x, spatial_y, interval, labs) start = logg.info( - f"Calculating co-occurrence probabilities for `{len(interval)}` intervals " - f"`{len(idx_splits)}` split combinations using `{n_jobs}` core(s)" - ) - - out_lst = parallelize( - _co_occurrence_helper, - collection=idx_splits, - extractor=chain.from_iterable, - n_jobs=n_jobs, - backend=backend, - show_progress_bar=show_progress_bar, - )( - spatial_splits=spatial_splits, - labs_splits=labs_splits, - labs_unique=labs_unique, - interval=interval, + f"Calculating co-occurrence probabilities for `{len(interval)}` intervals using `{n_jobs}` core(s) and `{n_splits}` splits" ) - out = list(out_lst)[0] if len(idx_splits) == 1 else sum(list(out_lst)) / len(idx_splits) if copy: logg.info("Finish", time=start) return out, interval _save_data( - adata, - attr="uns", - key=Key.uns.co_occurrence(cluster_key), - data={"occ": out, "interval": interval}, - time=start, + adata, attr="uns", key=Key.uns.co_occurrence(cluster_key), data={"occ": out, "interval": interval}, time=start )