scale function (_get_mean_var) updated for dense arrays, speedup of up to ~4.65x #3280

Open: wants to merge 12 commits into base: main
24 changes: 11 additions & 13 deletions docs/release-notes/1.10.3.md
@@ -1,16 +1,14 @@
(v1.10.3)=
### 1.10.3 {small}`2024-09-17`
### 1.10.3 {small}`the future`

#### Bug fixes
```{rubric} Development features
```

- Prevent empty control gene set in {func}`~scanpy.tl.score_genes` {smaller}`M Müller` ({pr}`2875`)
- Fix `subset=True` of {func}`~scanpy.pp.highly_variable_genes` when `flavor` is `seurat` or `cell_ranger`, and `batch_key!=None` {smaller}`E Roellin` ({pr}`3042`)
- Add compatibility with {mod}`numpy` 2.0 {smaller}`P Angerer` {pr}`3065` and ({pr}`3115`)
- Fix `legend_loc` argument in {func}`scanpy.pl.embedding` not accepting matplotlib parameters {smaller}`P Angerer` ({pr}`3163`)
- Fix dispersion cutoff in {func}`~scanpy.pp.highly_variable_genes` in presence of `NaN`s {smaller}`P Angerer` ({pr}`3176`)
- Fix axis labeling for swapped axes in {func}`~scanpy.pl.rank_genes_groups_stacked_violin` {smaller}`Ilan Gold` ({pr}`3196`)
- Upper bound dask on account of {issue}`scverse/anndata#1579` {smaller}`Ilan Gold` ({pr}`3217`)
- The [fa2-modified][] package replaces [forceatlas2][] for the latter’s lack of maintenance {smaller}`A Alam` ({pr}`3220`)
```{rubric} Docs
```

[fa2-modified]: https://github.com/AminAlam/fa2_modified
[forceatlas2]: https://github.com/bhargavchippada/forceatlas2
```{rubric} Bug fixes
```

```{rubric} Performance
```
* Speed up _get_mean_var used in {func}`~scanpy.pp.scale` {pr}`3099` {smaller}`P Ashish & S Dicks`
42 changes: 39 additions & 3 deletions src/scanpy/preprocessing/_utils.py
@@ -33,17 +33,53 @@
def _get_mean_var(
    X: _SupportedArray, *, axis: Literal[0, 1] = 0
) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
    if isinstance(X, sparse.spmatrix):
        mean, var = sparse_mean_variance_axis(X, axis=axis)
    if isinstance(X, np.ndarray):
Member:

Instead of adding a second code path that handles `np.ndarray`, you should replace the existing one above:

@axis_mean.register(np.ndarray)
def _(X: np.ndarray, ...): ...
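
The `@axis_mean.register(np.ndarray)` syntax in the suggestion implies that `axis_mean` is a `functools.singledispatch` function, so a dense-array implementation can be registered per type instead of branching inside `_get_mean_var`. A minimal sketch of that pattern, assuming nothing about scanpy's real implementation (`axis_mean_demo` is a hypothetical stand-in, and its body is illustrative only):

```python
from functools import singledispatch

import numpy as np


@singledispatch
def axis_mean_demo(X, *, axis: int, dtype) -> np.ndarray:
    """Fallback for array types without a registered implementation."""
    # Hypothetical stand-in for a dispatching helper such as `axis_mean`.
    raise NotImplementedError(
        f"no axis_mean_demo implementation for {type(X).__name__}"
    )


@axis_mean_demo.register(np.ndarray)
def _(X: np.ndarray, *, axis: int, dtype) -> np.ndarray:
    # Dense arrays: delegate to NumPy's own reduction.
    return X.mean(axis=axis, dtype=dtype)


X = np.array([[1.0, 2.0], [3.0, 6.0]])
col_means = axis_mean_demo(X, axis=0, dtype=np.float64)
```

Dispatch happens on the type of the first positional argument, so callers keep a single entry point while each array type gets its own code path.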

        n_threads = numba.get_num_threads()
        mean, var = _compute_mean_var(X, axis=axis, n_threads=n_threads)
Contributor:

What happened to sparse?

Member:

I'm still trying to integrate all of this; it's a mess, though.

    else:
        mean = axis_mean(X, axis=axis, dtype=np.float64)
        mean_sq = axis_mean(elem_mul(X, X), axis=axis, dtype=np.float64)
        var = mean_sq - mean**2
        # enforce R convention (unbiased estimator) for variance
        if X.shape[axis] != 1:
            var *= X.shape[axis] / (X.shape[axis] - 1)

Contributor:

Please remove erroneous diffs.

    return mean, var
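
For reference, the `else` branch above computes the variance as E[x²] − E[x]² and then rescales by n / (n − 1) to obtain the unbiased estimator. A small NumPy-only sketch of that calculation, assuming a dense float input (`mean_var_reference` is a hypothetical name, and `X * X` stands in for `elem_mul`):

```python
import numpy as np


def mean_var_reference(X: np.ndarray, axis: int = 0):
    """Mean and unbiased variance via E[x^2] - E[x]^2 (illustrative only)."""
    mean = X.mean(axis=axis, dtype=np.float64)
    mean_sq = (X * X).mean(axis=axis, dtype=np.float64)
    var = mean_sq - mean**2  # biased (population) variance
    n = X.shape[axis]
    if n != 1:
        # Bessel's correction: rescale to the unbiased estimator.
        var *= n / (n - 1)
    return mean, var
```

On well-conditioned data this should agree with `X.var(axis=axis, ddof=1)` up to floating-point error, which gives a convenient baseline for checking any faster replacement.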

@numba.njit(cache=True, parallel=True)
def _compute_mean_var(
    X: _SupportedArray, axis: Literal[0, 1] = 0, n_threads=1
) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
    if axis == 0:
        axis_i = 1
        sums = np.zeros((n_threads, X.shape[axis_i]), dtype=np.float64)
        sums_squared = np.zeros((n_threads, X.shape[axis_i]), dtype=np.float64)
        mean = np.zeros(X.shape[axis_i], dtype=np.float64)
        var = np.zeros(X.shape[axis_i], dtype=np.float64)
        n = X.shape[axis]
        for i in numba.prange(n_threads):
            for r in range(i, n, n_threads):
                for c in range(X.shape[axis_i]):
                    value = X[r, c]
                    sums[i, c] += value
                    sums_squared[i, c] += value * value
        for c in numba.prange(X.shape[axis_i]):
            sum_ = sums[:, c].sum()
            mean[c] = sum_ / n
            var[c] = (sums_squared[:, c].sum() - sum_ * sum_ / n) / (n - 1)
    else:
        axis_i = 0
        mean = np.zeros(X.shape[axis_i], dtype=np.float64)
        var = np.zeros(X.shape[axis_i], dtype=np.float64)
Comment on lines +74 to +76
Member:

Pull this out of the `if` branch.

Member:

I mean these assignments. When you have two branches, and both start with the same 3 lines, just do those before the `if` statement instead.
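
One way to read this request: since `axis_i` is 1 when `axis == 0` and 0 otherwise, it can be written as `1 - axis`, and the two output buffers can then be allocated once before branching. A rough sketch, with `alloc_outputs` as a hypothetical helper that is not part of the PR:

```python
import numpy as np


def alloc_outputs(X: np.ndarray, axis: int):
    """Allocate the shared output buffers once, before any axis-specific branch."""
    axis_i = 1 - axis  # the axis that is kept, i.e. not reduced over
    mean = np.zeros(X.shape[axis_i], dtype=np.float64)
    var = np.zeros(X.shape[axis_i], dtype=np.float64)
    return axis_i, mean, var
```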

        for r in numba.prange(X.shape[0]):
            for c in range(X.shape[1]):
                value = X[r, c]
                mean[r] += value
                var[r] += value * value
        for c in numba.prange(X.shape[0]):
            mean[c] = mean[c] / X.shape[1]
            var[c] = (var[c] - mean[c] ** 2) / (X.shape[1] - 1)
Member:

There is no return statement.

Contributor Author:

I have updated the code; please check.

Member:

Well, the tests aren’t passing, so it still doesn’t seem to be working.
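
Two things stand out in the diff as shown, though the thread does not confirm either as the cause of the failing tests: `_compute_mean_var` never returns `mean, var`, and the `axis=1` branch subtracts `mean ** 2` from the raw sum of squares, where the sum-of-squares identity needs `n * mean ** 2` (equivalently `sum_ * sum_ / n`, as the `axis=0` branch does). The sketch below applies the same identity to both axes and returns the result. It is plain NumPy rather than numba, the strided chunks only imitate the per-thread accumulation, and `mean_var_sum_of_squares` and `n_chunks` are hypothetical names.

```python
import numpy as np


def mean_var_sum_of_squares(X, axis: int = 0, n_chunks: int = 4):
    """Mean and unbiased variance from chunked sums and sums of squares."""
    X = np.asarray(X, dtype=np.float64)
    n = X.shape[axis]          # number of values reduced per output entry
    n_out = X.shape[1 - axis]  # length of the outputs
    sums = np.zeros((n_chunks, n_out))
    sums_squared = np.zeros((n_chunks, n_out))
    for i in range(n_chunks):
        # Chunk i takes every n_chunks-th slice along the reduced axis,
        # mirroring `range(i, n, n_threads)` in the diff.
        block = X[i::n_chunks] if axis == 0 else X[:, i::n_chunks]
        sums[i] = block.sum(axis=axis)
        sums_squared[i] = (block * block).sum(axis=axis)
    total = sums.sum(axis=0)
    total_sq = sums_squared.sum(axis=0)
    mean = total / n
    if n > 1:
        # sum((x - mean)^2) == sum(x^2) - (sum(x))^2 / n
        var = (total_sq - total * total / n) / (n - 1)
    else:
        # A single observation: skip the n - 1 division, as the existing code does.
        var = np.zeros(n_out)
    return mean, var
```

Comparing the output against `X.mean(axis=axis)` and `X.var(axis=axis, ddof=1)` for both axes is a quick way to check a kernel like this before porting it to `numba.prange`.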



def sparse_mean_variance_axis(mtx: sparse.spmatrix, axis: int):
    """
@@ -158,4 +194,4 @@
    idx = sample_without_replacement(
        np.prod(dims), nsamp, random_state=random_state, method=method
    )
    return np.vstack(np.unravel_index(idx, dims)).T
    return np.vstack(np.unravel_index(idx, dims)).T

Check warning on line 197 in src/scanpy/preprocessing/_utils.py

Codecov / codecov/patch

src/scanpy/preprocessing/_utils.py#L197

Added line #L197 was not covered by tests