|
1 | 1 | from __future__ import annotations
|
2 | 2 |
|
3 | 3 | import logging
|
4 |
| -from typing import Iterator |
| 4 | +from pathlib import Path |
| 5 | +from typing import Any, Iterator |
5 | 6 |
|
6 | 7 | import geopandas as gpd
|
7 | 8 | import pandas as pd
|
8 |
| -import xarray as xr |
| 9 | +from anndata import AnnData |
9 | 10 | from datatree import DataTree
|
10 | 11 | from spatialdata import SpatialData
|
11 | 12 | from spatialdata.models import SpatialElement
|
12 | 13 | from spatialdata.transformations import Identity, get_transformation, set_transformation
|
13 | 14 | from xarray import DataArray
|
14 | 15 |
|
15 |
| -from ._constants import SopaAttrs, SopaKeys |
| 16 | +from ._constants import SopaAttrs, SopaFiles, SopaKeys |
16 | 17 |
|
17 | 18 | log = logging.getLogger(__name__)
|
18 | 19 |
|
@@ -112,14 +113,14 @@ def get_intensities(sdata: SpatialData) -> pd.DataFrame | None:
|
112 | 113 | return adata.to_df()
|
113 | 114 |
|
114 | 115 |
|
115 |
| -def iter_scales(image: DataTree) -> Iterator[xr.DataArray]: |
| 116 | +def iter_scales(image: DataTree) -> Iterator[DataArray]: |
116 | 117 | """Iterates through all the scales of a `DataTree`
|
117 | 118 |
|
118 | 119 | Args:
|
119 | 120 | image: a `DataTree`
|
120 | 121 |
|
121 | 122 | Yields:
|
122 |
| - Each scale (as a `xr.DataArray`) |
| 123 | + Each scale (as a `DataArray`) |
123 | 124 | """
|
124 | 125 | assert isinstance(image, DataTree), f"Multiscale iteration is reserved for type DataTree. Found {type(image)}"
|
125 | 126 |
|
@@ -154,22 +155,42 @@ def get_spatial_element(
|
154 | 155 | if len(element_dict) == 1:
|
155 | 156 | key = next(iter(element_dict.keys()))
|
156 | 157 |
|
157 |
| - assert valid_attr is None or element_dict[key].attrs.get( |
| 158 | + assert valid_attr is None or _get_spatialdata_attrs(element_dict[key]).get( |
158 | 159 | valid_attr, True
|
159 | 160 | ), f"Element {key} is not valid for the attribute {valid_attr}."
|
160 | 161 |
|
161 | 162 | return _return_element(element_dict, key, return_key, as_spatial_image)
|
162 | 163 |
|
163 | 164 | assert valid_attr is not None, "Multiple elements found. Provide an element key."
|
164 | 165 |
|
165 |
| - keys = [key for key, element in element_dict.items() if element.attrs.get(valid_attr)] |
| 166 | + keys = [key for key, element in element_dict.items() if _get_spatialdata_attrs(element).get(valid_attr)] |
166 | 167 |
|
167 | 168 | assert len(keys) > 0, f"No element with the attribute {valid_attr}. Provide an element key."
|
168 | 169 | assert len(keys) == 1, f"Multiple valid elements found: {keys}. Provide an element key."
|
169 | 170 |
|
170 | 171 | return _return_element(element_dict, keys[0], return_key, as_spatial_image)
|
171 | 172 |
|
172 | 173 |
|
| 174 | +def _get_spatialdata_attrs(element: SpatialElement) -> dict[str, Any]: |
| 175 | + if isinstance(element, DataTree): |
| 176 | + element = next(iter(element["scale0"].values())) |
| 177 | + return element.attrs.get("spatialdata_attrs", {}) |
| 178 | + |
| 179 | + |
| 180 | +def _update_spatialdata_attrs(element: SpatialElement, attrs: dict): |
| 181 | + if isinstance(element, DataTree): |
| 182 | + for image_scale in iter_scales(element): |
| 183 | + _update_spatialdata_attrs(image_scale, attrs) |
| 184 | + return |
| 185 | + |
| 186 | + old_attrs = element.uns if isinstance(element, AnnData) else element.attrs |
| 187 | + |
| 188 | + if "spatialdata_attrs" not in old_attrs: |
| 189 | + old_attrs["spatialdata_attrs"] = {} |
| 190 | + |
| 191 | + old_attrs["spatialdata_attrs"].update(attrs) |
| 192 | + |
| 193 | + |
173 | 194 | def get_spatial_image(
|
174 | 195 | sdata: SpatialData,
|
175 | 196 | key: str | None = None,
|
@@ -205,3 +226,9 @@ def _return_element(
|
205 | 226 | element = next(iter(element["scale0"].values()))
|
206 | 227 |
|
207 | 228 | return (key, element) if return_key else element
|
| 229 | + |
| 230 | + |
| 231 | +def get_cache_dir(sdata: SpatialData) -> Path: |
| 232 | + assert sdata.is_backed(), "SpatialData not saved on-disk. Save the object, or provide a cache directory." |
| 233 | + |
| 234 | + return sdata.path / SopaFiles.SOPA_CACHE_DIR |
0 commit comments