Skip to content

Commit 7c7c155

Browse files
Storing patches embeddings as an AnnData object instead of images
1 parent 1aa4a43 commit 7c7c155

File tree

4 files changed

+31
-39
lines changed

4 files changed

+31
-39
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@
44
- Added H-optimus-0 model for H&E patches embeddings @stergioc (#208)
55
- Can provide `qv_threshold` argument to the Xenium reader to filter low quality transcripts @callum-jpg (#210)
66

7+
### Changed
8+
- Storing patches embeddings as an `AnnData` object instead of images
9+
710
### Fixed
811
- Right sorting scales for WSI reader with openslide backend @stergioc (#209)
912
- When a polygon cannot be simplified (for the Xenium Explorer), convert it to a circle (#206)

sopa/patches/_inference.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
import numpy as np
55
import torch
6-
from spatialdata.transformations import Scale, Sequence, get_transformation
76
from xarray import DataArray, DataTree
87

98
from . import models
@@ -76,13 +75,6 @@ def infer_bboxes(self, bboxes: np.ndarray) -> torch.Tensor:
7675

7776
return embedding.cpu() # shape (B, output_dim)
7877

79-
def get_patches_transformations(self, patch_overlap: float) -> dict[str, Sequence]:
80-
image_transformations = get_transformation(self.image, get_all=True)
81-
82-
patch_step = self.patch_width - patch_overlap
83-
to_image = Sequence([Scale([patch_step, patch_step], axes=("x", "y"))])
84-
return {cs: to_image.compose_with(t) for cs, t in image_transformations.items()}
85-
8678

8779
def _get_extraction_parameters(
8880
image: DataArray | DataTree, level: int | None, magnification: int | None

sopa/patches/cluster.py

Lines changed: 12 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,12 @@
11
from typing import Callable
22

3-
import anndata
4-
import numpy as np
53
import scanpy as sc
4+
from anndata import AnnData
65
from spatialdata import SpatialData
7-
from xarray import DataArray
86

9-
from sopa._constants import SopaKeys
107

11-
12-
def leiden_clustering(X: np.ndarray, flavor: str = "igraph", **kwargs):
13-
adata = anndata.AnnData(X=X)
8+
def leiden_clustering(adata: AnnData, flavor: str = "igraph", **kwargs):
9+
adata = adata.copy()
1410
sc.pp.pca(adata)
1511
sc.pp.neighbors(adata)
1612
sc.tl.leiden(adata, flavor=flavor, **kwargs)
@@ -23,8 +19,8 @@ def leiden_clustering(X: np.ndarray, flavor: str = "igraph", **kwargs):
2319

2420

2521
def cluster_embeddings(
26-
sdata: SpatialData,
27-
element: DataArray | str,
22+
sdata: SpatialData | None,
23+
element: AnnData | str,
2824
method: Callable | str = "leiden",
2925
key_added: str = "cluster",
3026
**method_kwargs: str,
@@ -35,23 +31,18 @@ def cluster_embeddings(
3531
The clusters are added to the `key_added` column of the "inference_patches" shapes (`key_added='cluster'` by default).
3632
3733
Args:
38-
sdata: A `SpatialData` object
39-
element: The `DataArray` containing the embeddings, or the name of the element
40-
method: Callable that takes as an input an array of size `(n_patches x embedding_size)` and returns an array of clusters of size `n_patches`, or an available method name (`leiden`)
41-
key_added: The key containing the clusters to be added to the patches `GeoDataFrame`
34+
sdata: A `SpatialData` object. Can be `None` if element is an `AnnData` object.
35+
element: The `AnnData` containing the embeddings, or the name of the element
36+
method: Callable that takes as an AnnData object and returns an array of clusters of size `n_obs`, or an available method name (`leiden`)
37+
key_added: The key containing the clusters to be added to the `element.obs`
4238
method_kwargs: kwargs provided to the method callable
4339
"""
4440
if isinstance(element, str):
45-
element: DataArray = sdata.images[element]
41+
element: AnnData = sdata.tables[element]
4642

4743
if isinstance(method, str):
4844
assert method in METHODS_DICT, f"Method {method} is not available. Use one of: {', '.join(METHODS_DICT.keys())}"
4945
method = METHODS_DICT[method]
5046

51-
gdf_patches = sdata[SopaKeys.EMBEDDINGS_PATCHES]
52-
53-
ilocs = np.array(list(gdf_patches.ilocs))
54-
embeddings = element.compute().data[:, ilocs[:, 1], ilocs[:, 0]].T
55-
56-
gdf_patches[key_added] = method(embeddings, **method_kwargs)
57-
gdf_patches[key_added] = gdf_patches[key_added].astype("category")
47+
element.obs[key_added] = method(element, **method_kwargs)
48+
element.obs[key_added] = element.obs[key_added].astype("category")

sopa/patches/infer.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
import logging
22
from typing import Callable
33

4-
import numpy as np
54
import tqdm
5+
from anndata import AnnData
66
from spatialdata import SpatialData
7-
from spatialdata.models import Image2DModel
7+
from spatialdata.models import TableModel
88
from xarray import DataArray, DataTree
99

1010
from .._constants import SopaAttrs, SopaKeys
@@ -76,17 +76,23 @@ def compute_embeddings(
7676
if len(predictions.shape) == 1:
7777
predictions = torch.unsqueeze(predictions, 1)
7878

79-
output_image = np.zeros((predictions.shape[1], *patches.shape), dtype=np.float32)
80-
for (loc_x, loc_y), pred in zip(patches.ilocs, predictions):
81-
output_image[:, loc_y, loc_x] = pred
79+
patches.add_shapes(key_added=SopaKeys.EMBEDDINGS_PATCHES)
8280

83-
output_image = DataArray(output_image, dims=("c", "y", "x"))
84-
output_image = Image2DModel.parse(output_image, transformations=infer.get_patches_transformations(patch_overlap))
81+
gdf = sdata[SopaKeys.EMBEDDINGS_PATCHES]
8582

86-
key_added = key_added or f"{infer.model_str}_embeddings"
87-
add_spatial_element(sdata, key_added, output_image)
83+
adata = AnnData(predictions.numpy())
84+
adata.obs["region"] = SopaKeys.EMBEDDINGS_PATCHES
85+
adata.obs["instance"] = gdf.index.values
86+
adata = TableModel.parse(
87+
adata,
88+
region=SopaKeys.EMBEDDINGS_PATCHES,
89+
region_key="region",
90+
instance_key="instance",
91+
)
92+
adata.obsm["spatial"] = gdf.centroid.get_coordinates().values
8893

89-
patches.add_shapes(key_added=SopaKeys.EMBEDDINGS_PATCHES)
94+
key_added = key_added or f"{infer.model_str}_embeddings"
95+
add_spatial_element(sdata, key_added, adata)
9096

9197

9298
def _get_image_for_inference(sdata: SpatialData, image_key: str | None = None) -> DataArray | DataTree:

0 commit comments

Comments
 (0)