Skip to content

Commit 578e910

Browse files
committed
Fix some tests broken by the changes
1 parent b43493e commit 578e910

File tree

6 files changed

+25
-4
lines changed

6 files changed

+25
-4
lines changed

src/opencosmo/dataset/dataset.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -788,7 +788,9 @@ def collect(self) -> Dataset:
788788
789789
If working in an MPI context, all ranks will recieve the same data.
790790
"""
791-
new_handler = self.__handler.collect(self.__state.columns, self.__state.index)
791+
new_handler = self.__handler.collect(
792+
self.__state.unit_handlers.keys(), self.__state.index
793+
)
792794
new_index = ChunkedIndex.from_size(len(new_handler))
793795
new_state = self.__state.with_index(new_index)
794796
return Dataset(

src/opencosmo/dataset/state.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -388,6 +388,15 @@ def with_units(
388388
"""
389389

390390
convention_ = UnitConvention(convention)
391+
if (
392+
convention_ == UnitConvention.SCALEFREE
393+
and UnitConvention(self.header.file.unit_convention)
394+
!= UnitConvention.SCALEFREE
395+
):
396+
raise ValueError(
397+
f"Cannot convert units with convention {self.header.file.unit_convention} to convention scalefree"
398+
)
399+
391400
return DatasetState(
392401
self.__unit_applicators,
393402
self.__index,

src/opencosmo/header.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,7 @@ def read_header(
257257
f"Error: {e}"
258258
)
259259

260+
print(file_parameters.unit_convention)
260261
origin_parameter_models = origin.get_origin_parameters(file_parameters.origin)
261262
required_origin_params, optional_origin_params = read_origin_parameters(
262263
file, origin_parameter_models

src/opencosmo/parameters/file.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,17 @@ def parse_region(cls, data):
5959
def validate_is_lightcone(cls, value):
6060
return bool(value)
6161

62+
@field_validator("unit_convention", mode="before")
63+
def validate_convention(cls, value):
64+
if isinstance(value, str):
65+
return UnitConvention(value)
66+
return value
67+
6268
@field_serializer("unit_convention")
6369
def serialize_convention(self, value):
64-
return value.value
70+
if isinstance(value, UnitConvention):
71+
return value.value
72+
return value
6573

6674
@model_serializer(mode="wrap")
6775
def serialize_model(self, handle):

src/opencosmo/units/converters.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ def physical_to_scalefree(
183183
return add_littleh(new_value, cosmology, base_unit)
184184

185185

186-
def raise_convert_error(from_: UnitConvention, to_: UnitConvention, *args, **kwargs):
186+
def raise_convert_error(*args, from_: UnitConvention, to_: UnitConvention, **kwargs):
187187
raise ValueError(
188188
f"Units in convention {str(from_)} cannot be converted to units in convention {str(to_)}"
189189
)

test/test_diffsky.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def test_comoving_to_physical(core_path_487):
2323

2424

2525
def test_comoving_to_scalefree(core_path_487):
26-
with pytest.raises(oc.transformations.units.UnitError):
26+
with pytest.raises(ValueError):
2727
_ = oc.open(core_path_487, synth_cores=True).with_units("scalefree")
2828

2929

@@ -46,6 +46,7 @@ def test_filter_take(core_path_475, core_path_487):
4646

4747
def test_open_multiple_write(core_path_487, core_path_475, tmp_path):
4848
ds = oc.open(core_path_487, core_path_475, synth_cores=True)
49+
print(ds.header.file.unit_convention)
4950
original_length = len(ds)
5051
original_redshift_range = ds.z_range
5152
output = tmp_path / "synth_gals.hdf5"

0 commit comments

Comments
 (0)