Skip to content

Commit d0a541e

Browse files
major API simplification (WIP)
1 parent c501a76 commit d0a541e

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

51 files changed

+323
-476
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
### Changed
1313
- Import submodules in init (segmentation, io, utils)
14+
- API simplification in progress (new API + tutorial comming soon)
1415

1516
## [1.1.2] - 2024-07-24
1617

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "sopa"
3-
version = "1.1.2"
3+
version = "1.1.3"
44
description = "Spatial-omics pipeline and analysis"
55
documentation = "https://gustaveroussy.github.io/sopa"
66
homepage = "https://gustaveroussy.github.io/sopa"
@@ -76,7 +76,7 @@ testpaths = ["tests"]
7676
python_files = "test_*.py"
7777

7878
[tool.black]
79-
line-length = 100
79+
line-length = 120
8080
include = '\.pyi?$'
8181
exclude = '''
8282
/(

sopa/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,6 @@
1212
from . import utils
1313
from . import io
1414
from . import segmentation
15+
16+
from .segmentation import tissue_segmentation
17+
from ._sdata import get_spatial_image, get_spatial_element, to_intrinsic

sopa/_constants.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,13 @@ class SopaKeys:
3333
GEOMETRY_COUNT = "n_components"
3434

3535

36+
class SopaAttrs:
37+
CELL_SEGMENTATION = "for_cell_segmentation"
38+
TISSUE_SEGMENTATION = "for_tissue_segmentation"
39+
BINS_AGGREGATION = "for_bins_aggregation"
40+
GENE_COLUMN = "feature_key"
41+
42+
3643
VALID_DIMENSIONS = ("c", "y", "x")
3744
LOW_AVERAGE_COUNT = 0.01
3845
EPS = 1e-5

sopa/_sdata.py

Lines changed: 70 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from spatialdata.transformations import Identity, get_transformation, set_transformation
1313
from xarray import DataArray
1414

15-
from ._constants import SopaKeys
15+
from ._constants import SopaAttrs, SopaKeys
1616

1717
log = logging.getLogger(__name__)
1818

@@ -40,9 +40,7 @@ def get_boundaries(
4040
if res is not None:
4141
return res
4242

43-
error_message = (
44-
"sdata object has no valid segmentation boundary. Consider running Sopa segmentation first."
45-
)
43+
error_message = "sdata object has no valid segmentation boundary. Consider running Sopa segmentation first."
4644

4745
if not warn:
4846
raise ValueError(error_message)
@@ -51,17 +49,13 @@ def get_boundaries(
5149
return (None, None) if return_key else None
5250

5351

54-
def _try_get_boundaries(
55-
sdata: SpatialData, shapes_key: str, return_key: bool
56-
) -> gpd.GeoDataFrame | None:
52+
def _try_get_boundaries(sdata: SpatialData, shapes_key: str, return_key: bool) -> gpd.GeoDataFrame | None:
5753
"""Try to get a cell boundaries for a given `shapes_key`"""
5854
if shapes_key in sdata.shapes:
5955
return (shapes_key, sdata[shapes_key]) if return_key else sdata[shapes_key]
6056

6157

62-
def get_intrinsic_cs(
63-
sdata: SpatialData, element: SpatialElement | str, name: str | None = None
64-
) -> str:
58+
def get_intrinsic_cs(sdata: SpatialData, element: SpatialElement | str, name: str | None = None) -> str:
6559
"""Gets the name of the intrinsic coordinate system of an element
6660
6761
Args:
@@ -86,9 +80,7 @@ def get_intrinsic_cs(
8680
return name
8781

8882

89-
def to_intrinsic(
90-
sdata: SpatialData, element: SpatialElement | str, element_cs: SpatialElement | str
91-
) -> SpatialElement:
83+
def to_intrinsic(sdata: SpatialData, element: SpatialElement | str, element_cs: SpatialElement | str) -> SpatialElement:
9284
"""Transforms a `SpatialElement` into the intrinsic coordinate system of another `SpatialElement`
9385
9486
Args:
@@ -105,32 +97,6 @@ def to_intrinsic(
10597
return sdata.transform_element_to_coordinate_system(element, cs)
10698

10799

108-
def get_key(sdata: SpatialData, attr: str, key: str | None = None):
109-
if key is not None:
110-
return key
111-
112-
elements = getattr(sdata, attr)
113-
114-
if not len(elements):
115-
return None
116-
117-
assert (
118-
len(elements) == 1
119-
), f"Trying to get an element key of `sdata.{attr}`, but it contains multiple values and no dict key was provided"
120-
121-
return next(iter(elements.keys()))
122-
123-
124-
def get_element(sdata: SpatialData, attr: str, key: str | None = None):
125-
key = get_key(sdata, attr, key)
126-
return sdata[key] if key is not None else None
127-
128-
129-
def get_item(sdata: SpatialData, attr: str, key: str | None = None):
130-
key = get_key(sdata, attr, key)
131-
return key, sdata[key] if key is not None else None
132-
133-
134100
def get_intensities(sdata: SpatialData) -> pd.DataFrame | None:
135101
"""Gets the intensity dataframe of shape `n_obs x n_channels`"""
136102
assert SopaKeys.TABLE in sdata.tables, f"No '{SopaKeys.TABLE}' found in sdata.tables"
@@ -155,35 +121,87 @@ def iter_scales(image: DataTree) -> Iterator[xr.DataArray]:
155121
Yields:
156122
Each scale (as a `xr.DataArray`)
157123
"""
158-
assert isinstance(
159-
image, DataTree
160-
), f"Multiscale iteration is reserved for type DataTree. Found {type(image)}"
124+
assert isinstance(image, DataTree), f"Multiscale iteration is reserved for type DataTree. Found {type(image)}"
161125

162126
for scale in image:
163127
yield next(iter(image[scale].values()))
164128

165129

130+
def get_spatial_element(
131+
element_dict: dict[str, SpatialElement],
132+
key: str | None = None,
133+
valid_attr: str | None = None,
134+
return_key: bool = False,
135+
as_spatial_image: bool = False,
136+
) -> SpatialElement | tuple[str, SpatialElement]:
137+
"""Gets an element from a SpatialData object.
138+
139+
Args:
140+
sdata: SpatialData object.
141+
key: Optional element key. If `None`, returns the only element (if only one), or tries to find an element with `valid_attr`.
142+
return_key: Whether to also return the key of the element.
143+
valid_attr: Attribute that the element must have to be considered valid.
144+
as_spatial_image: Whether to return the element as a `SpatialImage` (if it is a `DataTree`)
145+
146+
Returns:
147+
If `return_key` is False, only the element is returned, else a tuple `(element_key, element)`
148+
"""
149+
assert len(element_dict), "No spatial element was found in the dict."
150+
151+
if key is not None:
152+
return _return_element(element_dict, key, return_key, as_spatial_image)
153+
154+
if len(element_dict) == 1:
155+
key = next(iter(element_dict.keys()))
156+
157+
assert valid_attr is None or element_dict[key].attrs.get(
158+
valid_attr, True
159+
), f"Element {key} is not valid for the attribute {valid_attr}."
160+
161+
return _return_element(element_dict, key, return_key, as_spatial_image)
162+
163+
assert valid_attr is not None, "Multiple elements found. Provide an element key."
164+
165+
keys = [key for key, element in element_dict.items() if element.attrs.get(valid_attr)]
166+
167+
assert len(keys) > 0, f"No element with the attribute {valid_attr}. Provide an element key."
168+
assert len(keys) == 1, f"Multiple valid elements found: {keys}. Provide an element key."
169+
170+
return _return_element(element_dict, keys[0], return_key, as_spatial_image)
171+
172+
166173
def get_spatial_image(
167-
sdata: SpatialData, key: str | None = None, return_key: bool = False
174+
sdata: SpatialData,
175+
key: str | None = None,
176+
return_key: bool = False,
177+
valid_attr: str = SopaAttrs.CELL_SEGMENTATION,
168178
) -> DataArray | tuple[str, DataArray]:
169179
"""Gets a DataArray from a SpatialData object (if the image has multiple scale, the `scale0` is returned)
170180
171181
Args:
172182
sdata: SpatialData object.
173-
key: Optional image key. If `None`, returns the only image (if only one), or raises an error.
183+
key: Optional image key. If `None`, returns the only image (if only one), or tries to find an image with `valid_attr`.
174184
return_key: Whether to also return the key of the image.
185+
valid_attr: Attribute that the image must have to be considered valid.
175186
176187
Returns:
177188
If `return_key` is False, only the image is returned, else a tuple `(image_key, image)`
178189
"""
179-
key = get_key(sdata, "images", key)
190+
return get_spatial_element(
191+
sdata.images,
192+
key=key,
193+
valid_attr=valid_attr,
194+
return_key=return_key,
195+
as_spatial_image=True,
196+
)
197+
180198

181-
assert key is not None, "One image in `sdata.images` is required"
199+
def _return_element(
200+
element_dict: dict[str, SpatialElement], key: str, return_key: bool, as_spatial_image: bool
201+
) -> SpatialElement | tuple[str, SpatialElement]:
202+
element = element_dict[key]
182203

183-
image = sdata.images[key]
184-
if isinstance(image, DataTree):
185-
image = next(iter(image["scale0"].values()))
204+
if as_spatial_image and isinstance(element, DataTree):
205+
element = next(iter(element["scale0"].values()))
186206

187-
if return_key:
188-
return key, image
189-
return image
207+
return (key, element) if return_key else element

sopa/annotation/tangram/run.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -124,9 +124,7 @@ def init_obsm(self, level: int):
124124
)
125125

126126
def get_hard_labels(self, df: pd.DataFrame) -> pd.Series:
127-
df = df.clip(
128-
df.quantile(1 - self.clip_percentile), df.quantile(self.clip_percentile), axis=1
129-
)
127+
df = df.clip(df.quantile(1 - self.clip_percentile), df.quantile(self.clip_percentile), axis=1)
130128
df = (df - df.min()) / (df.max() - df.min())
131129
return df.idxmax(1)
132130

@@ -138,9 +136,7 @@ def pp_adata(self, ad_sp_: AnnData, ad_sc_: AnnData, split: np.ndarray) -> AnnDa
138136
sc.pp.filter_genes(ad_sp_split, min_cells=1)
139137

140138
# Calculate uniform density prior as 1/number_of_spots
141-
ad_sp_split.obs["uniform_density"] = (
142-
np.ones(ad_sp_split.X.shape[0]) / ad_sp_split.X.shape[0]
143-
)
139+
ad_sp_split.obs["uniform_density"] = np.ones(ad_sp_split.X.shape[0]) / ad_sp_split.X.shape[0]
144140

145141
# Calculate rna_count_based density prior as % of rna molecule count
146142
rna_count_per_spot = np.array(ad_sp_split.X.sum(axis=1)).squeeze()
@@ -157,8 +153,7 @@ def pp_adata(self, ad_sp_: AnnData, ad_sc_: AnnData, split: np.ndarray) -> AnnDa
157153
)
158154

159155
selection = list(
160-
set(ad_sp_split.var_names[ad_sp_split.var.counts > 0])
161-
& set(ad_sc_.var_names[ad_sc_.var.counts > 0])
156+
set(ad_sp_split.var_names[ad_sp_split.var.counts > 0]) & set(ad_sc_.var_names[ad_sc_.var.counts > 0])
162157
)
163158

164159
assert len(

sopa/cli/annotate.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,7 @@
1313
def fluorescence(
1414
sdata_path: str = typer.Argument(help=SDATA_HELPER),
1515
marker_cell_dict: str = typer.Option(callback=ast.literal_eval),
16-
cell_type_key: str = typer.Option(
17-
"cell_type", help="Key added in `adata.obs` corresponding to the cell type"
18-
),
16+
cell_type_key: str = typer.Option("cell_type", help="Key added in `adata.obs` corresponding to the cell type"),
1917
):
2018
"""Simple annotation based on fluorescence, where each provided channel corresponds to one cell type.
2119
@@ -39,9 +37,7 @@ def fluorescence(
3937
@app_annotate.command()
4038
def tangram(
4139
sdata_path: str = typer.Argument(help=SDATA_HELPER),
42-
sc_reference_path: str = typer.Option(
43-
help="Path to the scRNAseq annotated reference, as a `.h5ad` file"
44-
),
40+
sc_reference_path: str = typer.Option(help="Path to the scRNAseq annotated reference, as a `.h5ad` file"),
4541
cell_type_key: str = typer.Option(help="Key of `adata_ref.obs` containing the cell-types"),
4642
reference_preprocessing: str = typer.Option(
4743
None,

sopa/cli/app.py

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,7 @@
2828
name="segmentation",
2929
help="Perform cell segmentation on patches. NB: for `baysor`, use directly the `baysor` command line.",
3030
)
31-
app.add_typer(
32-
app_resolve, name="resolve", help="Resolve the segmentation conflicts over patches overlaps"
33-
)
31+
app.add_typer(app_resolve, name="resolve", help="Resolve the segmentation conflicts over patches overlaps")
3432
app.add_typer(
3533
app_patchify,
3634
name="patchify",
@@ -77,9 +75,7 @@ def read(
7775

7876
io.standardize._check_can_write_zarr(sdata_path)
7977

80-
assert (
81-
technology is not None or config_path is not None
82-
), "Provide the argument `--technology` or `--config-path`"
78+
assert technology is not None or config_path is not None, "Provide the argument `--technology` or `--config-path`"
8379

8480
if config_path is not None:
8581
assert not kwargs, "Provide either a path to a config, or some kwargs, but not both"
@@ -114,9 +110,7 @@ def crop(
114110
None,
115111
help="List of channel names to be displayed. Optional if there are already only 1 or 3 channels",
116112
),
117-
scale_factor: float = typer.Option(
118-
10, help="Resize the image by this value (high value for a lower memory usage)"
119-
),
113+
scale_factor: float = typer.Option(10, help="Resize the image by this value (high value for a lower memory usage)"),
120114
margin_ratio: float = typer.Option(
121115
0.1, help="Ratio of the image margin on the display (compared to the image size)"
122116
),
@@ -163,16 +157,12 @@ def aggregate(
163157
None,
164158
help="Column of the transcript dataframe representing the gene names. If not provided, it will not compute transcript count",
165159
),
166-
average_intensities: bool = typer.Option(
167-
False, help="Whether to average the channel intensities inside each cell"
168-
),
160+
average_intensities: bool = typer.Option(False, help="Whether to average the channel intensities inside each cell"),
169161
expand_radius_ratio: float = typer.Option(
170162
default=0,
171163
help="Cells polygons will be expanded by `expand_radius_ratio * mean_radius` for channels averaging **only**. This help better aggregate boundary stainings",
172164
),
173-
min_transcripts: int = typer.Option(
174-
0, help="Cells with less transcript than this integer will be filtered"
175-
),
165+
min_transcripts: int = typer.Option(0, help="Cells with less transcript than this integer will be filtered"),
176166
min_intensity_ratio: float = typer.Option(
177167
0,
178168
help="Cells whose mean channel intensity is less than `min_intensity_ratio * quantile_90` will be filtered",

sopa/cli/check.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,7 @@ def _open_config(path: str) -> dict:
1616
with open(path, "r") as f:
1717
return yaml.safe_load(f)
1818
except:
19-
log.warn(
20-
f"Config file '{path}' could't be read. Make sure that the file exist and that it is a YAML file"
21-
)
19+
log.warn(f"Config file '{path}' could't be read. Make sure that the file exist and that it is a YAML file")
2220
return
2321

2422

@@ -99,8 +97,7 @@ def _check_dict(config: dict, d: dict, log, prefix: str = "config"):
9997
break
10098
else:
10199
display = "\n - ".join(
102-
element if isinstance(element, str) else " AND ".join(element)
103-
for element in values
100+
element if isinstance(element, str) else " AND ".join(element) for element in values
104101
)
105102
log.warn(f"One of these element must be in {prefix}['{key}']:\n - {display}")
106103

sopa/cli/explorer.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,7 @@
66

77
app_explorer = typer.Typer()
88

9-
PIXELSIZE_DEPRECATED = (
10-
"`pixelsize` is deprecated and will be removed in future versions. Use `pixel_size` instead."
11-
)
9+
PIXELSIZE_DEPRECATED = "`pixelsize` is deprecated and will be removed in future versions. Use `pixel_size` instead."
1210

1311

1412
@app_explorer.command()
@@ -18,9 +16,7 @@ def write(
1816
None,
1917
help="Path to a directory where Xenium Explorer's outputs will be saved. By default, writes to the same path as `sdata_path` but with the `.explorer` suffix",
2018
),
21-
gene_column: str = typer.Option(
22-
None, help="Column name of the points dataframe containing the gene names"
23-
),
19+
gene_column: str = typer.Option(None, help="Column name of the points dataframe containing the gene names"),
2420
shapes_key: str = typer.Option(
2521
None,
2622
help="Sdata key for the boundaries. By default, uses the baysor boundaires, else the cellpose boundaries",
@@ -134,6 +130,4 @@ def add_aligned(
134130
sdata = spatialdata.read_zarr(sdata_path)
135131
image = io.ome_tif(image_path, as_image=True)
136132

137-
align(
138-
sdata, image, transformation_matrix_path, overwrite=overwrite, image_key=original_image_key
139-
)
133+
align(sdata, image, transformation_matrix_path, overwrite=overwrite, image_key=original_image_key)

0 commit comments

Comments
 (0)