Skip to content

Commit e88c24e

Browse files
Remove xarray-schema; validate() are now class methods (#1066)
* remove xarray-schema; validate() are now class methods * small code cleanup * remove unused model instances; fix validation for table (should not modify the object) * fix test_set_table_annotates_spatialelement * remove deprecation warning for raster and table model __init__
1 parent 3cdf3d8 commit e88c24e

File tree

12 files changed

+180
-171
lines changed

12 files changed

+180
-171
lines changed

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,6 @@ dependencies = [
4848
"typing_extensions>=4.8.0",
4949
"universal_pathlib>=0.2.6",
5050
"xarray>=2024.10.0",
51-
"xarray-schema",
5251
"xarray-spatial>=0.3.5",
5352
"zarr>=3.0.0",
5453
]

src/spatialdata/_core/_deepcopy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def _(element: DataTree) -> DataTree:
7979
msi[key][variable].data = from_array(msi[key][variable].data)
8080
element[key][variable].data = from_array(element[key][variable].data)
8181
assert model in [Image2DModel, Image3DModel, Labels2DModel, Labels3DModel]
82-
model().validate(msi)
82+
model.validate(msi)
8383
return msi
8484

8585

src/spatialdata/_core/_elements.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -72,10 +72,10 @@ def __setitem__(self, key: str, value: Raster_T) -> None:
7272
raise TypeError(f"Unknown element type with schema: {schema!r}.")
7373
ndim = len(get_axes_names(value))
7474
if ndim == 3:
75-
Image2DModel().validate(value)
75+
Image2DModel.validate(value)
7676
super().__setitem__(key, value)
7777
elif ndim == 4:
78-
Image3DModel().validate(value)
78+
Image3DModel.validate(value)
7979
super().__setitem__(key, value)
8080
else:
8181
NotImplementedError("TODO: implement for ndim > 4.")
@@ -89,10 +89,10 @@ def __setitem__(self, key: str, value: Raster_T) -> None:
8989
raise TypeError(f"Unknown element type with schema: {schema!r}.")
9090
ndim = len(get_axes_names(value))
9191
if ndim == 2:
92-
Labels2DModel().validate(value)
92+
Labels2DModel.validate(value)
9393
super().__setitem__(key, value)
9494
elif ndim == 3:
95-
Labels3DModel().validate(value)
95+
Labels3DModel.validate(value)
9696
super().__setitem__(key, value)
9797
else:
9898
NotImplementedError("TODO: implement for ndim > 3.")
@@ -104,7 +104,7 @@ def __setitem__(self, key: str, value: GeoDataFrame) -> None:
104104
schema = get_model(value)
105105
if schema != ShapesModel:
106106
raise TypeError(f"Unknown element type with schema: {schema!r}.")
107-
ShapesModel().validate(value)
107+
ShapesModel.validate(value)
108108
super().__setitem__(key, value)
109109

110110

@@ -114,7 +114,7 @@ def __setitem__(self, key: str, value: DaskDataFrame) -> None:
114114
schema = get_model(value)
115115
if schema != PointsModel:
116116
raise TypeError(f"Unknown element type with schema: {schema!r}.")
117-
PointsModel().validate(value)
117+
PointsModel.validate(value)
118118
super().__setitem__(key, value)
119119

120120

@@ -124,5 +124,5 @@ def __setitem__(self, key: str, value: AnnData) -> None:
124124
schema = get_model(value)
125125
if schema != TableModel:
126126
raise TypeError(f"Unknown element type with schema: {schema!r}.")
127-
TableModel().validate(value)
127+
TableModel.validate(value)
128128
super().__setitem__(key, value)

src/spatialdata/_core/operations/rasterize.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -602,7 +602,7 @@ def rasterize_images_labels(
602602
set_transformation(transformed_data, sequence, target_coordinate_system)
603603

604604
transformed_data = compute_coordinates(transformed_data)
605-
schema().validate(transformed_data)
605+
schema.validate(transformed_data)
606606
return transformed_data
607607

608608

src/spatialdata/_core/operations/transform.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -333,7 +333,7 @@ def _(
333333
to_coordinate_system=to_coordinate_system,
334334
)
335335
transformed_data = compute_coordinates(transformed_data)
336-
schema().validate(transformed_data)
336+
schema.validate(transformed_data)
337337
return transformed_data
338338

339339

@@ -419,7 +419,7 @@ def _(
419419
to_coordinate_system=to_coordinate_system,
420420
)
421421
transformed_data = compute_coordinates(transformed_data)
422-
schema().validate(transformed_data)
422+
schema.validate(transformed_data)
423423
return transformed_data
424424

425425

src/spatialdata/_core/spatialdata.py

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -57,15 +57,6 @@
5757
SpatialDataFormatType,
5858
)
5959

60-
# schema for elements
61-
Label2D_s = Labels2DModel()
62-
Label3D_s = Labels3DModel()
63-
Image2D_s = Image2DModel()
64-
Image3D_s = Image3DModel()
65-
Shape_s = ShapesModel()
66-
Point_s = PointsModel()
67-
Table_s = TableModel()
68-
6960

7061
class SpatialData:
7162
"""
@@ -199,7 +190,7 @@ def validate_table_in_spatialdata(self, table: AnnData) -> None:
199190
UserWarning
200191
The dtypes of the instance key column in the table and the annotation target do not match.
201192
"""
202-
TableModel().validate(table)
193+
TableModel.validate(table)
203194
if TableModel.ATTRS_KEY in table.uns:
204195
region, _, instance_key = get_table_keys(table)
205196
region = region if isinstance(region, list) else [region]
@@ -349,8 +340,13 @@ def _set_table_annotation_target(
349340
ValueError
350341
If `instance_key` is not present in the `table.obs` columns.
351342
"""
352-
TableModel()._validate_set_region_key(table, region_key)
353-
TableModel()._validate_set_instance_key(table, instance_key)
343+
old_attrs = table.uns.get(TableModel.ATTRS_KEY)
344+
# _validate_set_region_key and _validate_set_instance_key will raise an error if table.uns[ATTRS_KEY] is None,
345+
# so let's initialize it here. Below it will be replaced with the actual metadata.
346+
if old_attrs is None:
347+
table.uns[TableModel.ATTRS_KEY] = {}
348+
TableModel._validate_set_region_key(table, region_key)
349+
TableModel._validate_set_instance_key(table, instance_key)
354350
attrs = {
355351
TableModel.REGION_KEY: region,
356352
TableModel.REGION_KEY_KEY: region_key,
@@ -393,8 +389,8 @@ def _change_table_annotation_target(
393389
attrs = table.uns[TableModel.ATTRS_KEY]
394390
table_region_key = region_key if region_key else attrs.get(TableModel.REGION_KEY_KEY)
395391

396-
TableModel()._validate_set_region_key(table, region_key)
397-
TableModel()._validate_set_instance_key(table, instance_key)
392+
TableModel._validate_set_region_key(table, region_key)
393+
TableModel._validate_set_instance_key(table, instance_key)
398394
check_target_region_column_symmetry(table, table_region_key, region)
399395
attrs[TableModel.REGION_KEY] = region
400396

@@ -1822,7 +1818,7 @@ def tables(self, tables: dict[str, AnnData]) -> None:
18221818
self._shared_keys = self._shared_keys - set(self._tables.keys())
18231819
self._tables = Tables(shared_keys=self._shared_keys)
18241820
for k, v in tables.items():
1825-
TableModel().validate(v)
1821+
TableModel.validate(v)
18261822
self._tables[k] = v
18271823

18281824
@staticmethod

src/spatialdata/_io/io_table.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def write_table(
5656
) -> None:
5757
if TableModel.ATTRS_KEY in table.uns:
5858
region, region_key, instance_key = get_table_keys(table)
59-
TableModel().validate(table)
59+
TableModel.validate(table)
6060
else:
6161
region, region_key, instance_key = (None, None, None)
6262

src/spatialdata/_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,7 @@ def _error_message_add_element() -> None:
308308

309309

310310
def _check_match_length_channels_c_dim(
311-
data: DaskArray | DataArray | DataTree, c_coords: str | list[str], dims: tuple[str]
311+
data: DaskArray | DataArray | DataTree, c_coords: str | list[str], dims: tuple[str, ...]
312312
) -> list[str]:
313313
"""
314314
Check whether channel names `c_coords` are of equal length to the `c` dimension of the data.

src/spatialdata/models/_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -401,7 +401,7 @@ def set_channel_names(element: DataArray | DataTree, channel_names: str | list[s
401401

402402
# get_model cannot be used due to circular import so get_axes_names is used instead
403403
if model in [Image2DModel, Image3DModel]:
404-
channel_names = _check_match_length_channels_c_dim(element, channel_names, model.dims.dims) # type: ignore[union-attr]
404+
channel_names = _check_match_length_channels_c_dim(element, channel_names, model.dims) # type: ignore[union-attr]
405405
if isinstance(element, DataArray):
406406
element = element.assign_coords(c=channel_names)
407407
else:

0 commit comments

Comments
 (0)