From 09acb77de7a6f6d6c624e631e2f8fdc418163847 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Wed, 9 Jul 2025 16:20:05 -0500 Subject: [PATCH 1/5] Changes to importing analysis tools --- opencosmo/analysis/__init__.py | 44 +++++-- opencosmo/analysis/yt_viz.py | 232 ++++++++++++++++----------------- 2 files changed, 150 insertions(+), 126 deletions(-) diff --git a/opencosmo/analysis/__init__.py b/opencosmo/analysis/__init__.py index ff978d31..067e52ca 100644 --- a/opencosmo/analysis/__init__.py +++ b/opencosmo/analysis/__init__.py @@ -1,11 +1,6 @@ -from .yt_utils import create_yt_dataset -from .yt_viz import ( - ProjectionPlot, SlicePlot, ParticleProjectionPlot, - ProfilePlot, PhasePlot, - visualize_halo, halo_projection_array, -) - -__all__ = [ +# ruff: noqa +__all__ = [] +yt_tools = [ "create_yt_dataset", "ProjectionPlot", "SlicePlot", @@ -15,3 +10,36 @@ "visualize_halo", "halo_projection_array", ] + + +try: + from .yt_utils import create_yt_dataset + from .yt_viz import ( + ParticleProjectionPlot, + PhasePlot, + ProfilePlot, + ProjectionPlot, + SlicePlot, + halo_projection_array, + visualize_halo, + ) + + __all__.extend(yt_tools) + +except ImportError: # User has not installed yt tools + pass + +""" +Right now, we have only have two analysis modules so we can handle them directly. In the +future we will need to implement a more robust system that handles things automatically. +""" + + +def __getattr__(name): + if name in yt_tools: + raise ImportError( + "You tried to import one of the OpenCosmo YT tools, but your python " + "environment does not have the necessary dependencies. You can do install " + "them with `pip install opencosmo[analysis]`" + ) + raise ImportError(f"Cannot import name '{name}' from opencosmo.analysis") diff --git a/opencosmo/analysis/yt_viz.py b/opencosmo/analysis/yt_viz.py index 45d995f6..2b23cd50 100644 --- a/opencosmo/analysis/yt_viz.py +++ b/opencosmo/analysis/yt_viz.py @@ -1,19 +1,18 @@ -import numpy as np - -from matplotlib.colors import LogNorm # type: ignore -from matplotlib.figure import Figure # type: ignore -from mpl_toolkits.axes_grid1.anchored_artists import AnchoredSizeBar # type: ignore +from typing import Any, Dict, Optional, Tuple, Union -import yt # type: ignore -from unyt import unyt_quantity # type: ignore -from yt.visualization.base_plot_types import get_multi_plot # type: ignore -from yt.visualization.plot_window import PlotWindow, NormalPlot # type: ignore +import numpy as np +import yt # type: ignore +from matplotlib.colors import LogNorm # type: ignore +from matplotlib.figure import Figure # type: ignore +from mpl_toolkits.axes_grid1.anchored_artists import AnchoredSizeBar # type: ignore +from unyt import unyt_quantity # type: ignore +from yt.visualization.base_plot_types import get_multi_plot # type: ignore +from yt.visualization.plot_window import NormalPlot, PlotWindow # type: ignore import opencosmo as oc -from opencosmo.analysis import create_yt_dataset - -from typing import Any, Dict, Optional, Tuple, Union +from opencosmo.analysis import create_yt_dataset +# ruff: noqa: E501 def ParticleProjectionPlot(*args, **kwargs) -> NormalPlot: @@ -130,7 +129,7 @@ def visualize_halo( halo_id: int, data: oc.StructureCollection, length_scale: Optional[str] = "top left", - width: float = 4.0 + width: float = 4.0, ) -> Figure: """ Creates a figure showing particle projections of dark matter, stars, gas, and/or gas temperature @@ -138,9 +137,9 @@ def visualize_halo( create a horizontal arrangement with only the particles/fields that are present. Otherwise, creates a 2x2-panel figure. Each panel is an 800x800 pixel array. - To customize the arrangement of panels, fields, colormaps, etc., see + To customize the arrangement of panels, fields, colormaps, etc., see :func:`halo_projection_array`. - + Parameters ---------- @@ -181,75 +180,71 @@ def visualize_halo( "cmaps": [], } - ptypes = [key.removesuffix('_particles') - for key in data.keys() if key.endswith('_particles')] + ptypes = [ + key.removesuffix("_particles") + for key in data.keys() + if key.endswith("_particles") + ] any_supported = False if "dm" in ptypes: - any_supported = True - params["fields"].append(("dm", "particle_mass")) - params["weight_fields"].append(None) - params["zlims"].append(None) - params["labels"].append("Dark Matter") - params["cmaps"].append("gray") + any_supported = True + params["fields"].append(("dm", "particle_mass")) + params["weight_fields"].append(None) + params["zlims"].append(None) + params["labels"].append("Dark Matter") + params["cmaps"].append("gray") elif "gravity" in ptypes: - any_supported = True - params["fields"].append(("gravity", "particle_mass")) - params["weight_fields"].append(None) - params["zlims"].append(None) - params["labels"].append("Dark Matter") - params["cmaps"].append("gray") + any_supported = True + params["fields"].append(("gravity", "particle_mass")) + params["weight_fields"].append(None) + params["zlims"].append(None) + params["labels"].append("Dark Matter") + params["cmaps"].append("gray") if "star" in ptypes: - any_supported = True - params["fields"].append(("star", "particle_mass")) - params["weight_fields"].append(None) - params["zlims"].append(None) - params["labels"].append("Stars") - params["cmaps"].append("bone") + any_supported = True + params["fields"].append(("star", "particle_mass")) + params["weight_fields"].append(None) + params["zlims"].append(None) + params["labels"].append("Stars") + params["cmaps"].append("bone") if "gas" in ptypes: - any_supported = True - params["fields"].append(("gas", "particle_mass")) - params["weight_fields"].append(None) - params["zlims"].append(None) - params["labels"].append("Gas") - params["cmaps"].append("viridis") - # temperature field should always exist if gas - # particles are present - params["fields"].append(("gas", "temperature")) - params["weight_fields"].append(("gas", "density")) - params["zlims"].append((1e7, 1e8)) - params["labels"].append("Gas Temperature") - params["cmaps"].append("inferno") + any_supported = True + params["fields"].append(("gas", "particle_mass")) + params["weight_fields"].append(None) + params["zlims"].append(None) + params["labels"].append("Gas") + params["cmaps"].append("viridis") + # temperature field should always exist if gas + # particles are present + params["fields"].append(("gas", "temperature")) + params["weight_fields"].append(("gas", "density")) + params["zlims"].append((1e7, 1e8)) + params["labels"].append("Gas Temperature") + params["cmaps"].append("inferno") if not any_supported: raise RuntimeError( "No compatible particle types present in dataset for this function. " - "Possible options are \"dm\", \"gravity\", \"star\", and \"gas\"." + 'Possible options are "dm", "gravity", "star", and "gas".' ) - if len(params["fields"]) == 4: # if 4 fields, make a 2x2 figure - halo_ids = ( [halo_id, halo_id], [halo_id, halo_id] ) - params = { - key: (value[:2], value[2:]) - for key, value in params.items() - } + halo_ids = ([halo_id, halo_id], [halo_id, halo_id]) + params = {key: (value[:2], value[2:]) for key, value in params.items()} - else: + else: # otherwise, do 1xN - halo_ids = ( np.shape(params["fields"])[0]*[halo_id] ) - params = { - key: [value] - for key, value in params.items() - } - + halo_ids = np.shape(params["fields"])[0] * [halo_id] + params = {key: [value] for key, value in params.items()} - return halo_projection_array(halo_ids, data, params=params, - length_scale=length_scale, width=width) + return halo_projection_array( + halo_ids, data, params=params, length_scale=length_scale, width=width + ) def halo_projection_array( @@ -262,7 +257,7 @@ def halo_projection_array( params: Optional[Dict[str, Any]] = None, length_scale: Optional[str] = None, smooth_gas_fields: bool = False, - width: float = 6.0 + width: float = 6.0, ) -> Figure: """ Creates a multipanel figure of projections for different fields and/or halos. @@ -330,7 +325,6 @@ def halo_projection_array( A Matplotlib Figure object. """ - halo_ids = np.atleast_2d(halo_ids) # determine shape of figure @@ -340,36 +334,29 @@ def halo_projection_array( if weight_field is None: weight_field_ = np.full(fig_shape, None) else: - weight_field_ = np.reshape( - [weight_field for _ in range(np.prod(fig_shape))], - (fig_shape[0], fig_shape[1], 2) + weight_field_ = np.reshape( + [weight_field for _ in range(np.prod(fig_shape))], + (fig_shape[0], fig_shape[1], 2), ) if zlim is None: zlim_ = np.full(fig_shape, None) else: zlim_ = np.reshape( - [zlim for _ in range(np.prod(fig_shape))], - (fig_shape[0], fig_shape[1], 2) + [zlim for _ in range(np.prod(fig_shape))], (fig_shape[0], fig_shape[1], 2) ) default_params = { "fields": ( - np.reshape( [field for _ in range(np.prod(fig_shape))], - (fig_shape[0], fig_shape[1], 2) ) - ), - "weight_fields": ( - weight_field_ - ), - "zlims": ( - zlim_ - ), - "labels": ( - np.full(fig_shape, None) - ), - "cmaps": ( - np.full(fig_shape, cmap) + np.reshape( + [field for _ in range(np.prod(fig_shape))], + (fig_shape[0], fig_shape[1], 2), + ) ), + "weight_fields": (weight_field_), + "zlims": (zlim_), + "labels": (np.full(fig_shape, None)), + "cmaps": (np.full(fig_shape, cmap)), } # Override defaults with user-supplied params (if any) @@ -378,13 +365,12 @@ def halo_projection_array( fields = params.get("fields", default_params["fields"]) weight_fields = params.get("weight_fields", default_params["weight_fields"]) zlims = params.get("zlims", default_params["zlims"]) - labels= params.get("labels", default_params["labels"]) + labels = params.get("labels", default_params["labels"]) cmaps = params.get("cmaps", default_params["cmaps"]) nrow, ncol = fig_shape ilen, jlen = None, None - # define figure and axes fig, axes, cbars = get_multi_plot(fig_shape[1], fig_shape[0], cbar_padding=0) @@ -395,7 +381,6 @@ def halo_projection_array( for i in range(nrow): for j in range(ncol): - halo_id = halo_ids[i][j] # retrieve halo particle info if new halo @@ -403,35 +388,39 @@ def halo_projection_array( # retrieve properties of halo data_id = data.filter(oc.col("unique_tag") == halo_id) halo_data = next(iter(data_id.objects())) - + # load particles into yt ds = create_yt_dataset(halo_data) - halo_properties = halo_data['halo_properties'] + halo_properties = halo_data["halo_properties"] - Rh = unyt_quantity.from_astropy(halo_properties['sod_halo_radius']) + Rh = unyt_quantity.from_astropy(halo_properties["sod_halo_radius"]) - field, weight_field, zlim = tuple(fields[i][j]), weight_fields[i][j], zlims[i][j] + field, weight_field, zlim = ( + tuple(fields[i][j]), + weight_fields[i][j], + zlims[i][j], + ) if weight_field is not None: - weight_field = tuple(weight_field) + weight_field = tuple(weight_field) # type: ignore if zlim is not None: - zlim = tuple(zlim) + zlim = tuple(zlim) # type: ignore label = labels[i][j] if smooth_gas_fields and field[0] == "gas": - proj = ProjectionPlot(ds,'z',field, weight_field = weight_field) + proj = ProjectionPlot(ds, "z", field, weight_field=weight_field) else: - proj = ParticleProjectionPlot(ds,'z',field, weight_field = weight_field) + proj = ParticleProjectionPlot(ds, "z", field, weight_field=weight_field) - proj.set_background_color(field, color='black') - proj.set_width(width*Rh) + proj.set_background_color(field, color="black") + proj.set_width(width * Rh) - # fetch figure buffer (2D array of pixel values) + # fetch figure buffer (2D array of pixel values) # and re-plot on each panel with imshow frb = proj.frb - + ax = axes[i][j] if zlim is not None: @@ -439,8 +428,12 @@ def halo_projection_array( else: zmin, zmax = None, None - ax.imshow(frb[field], origin="lower", cmap=cmaps[i][j], - norm=LogNorm(vmin=zmin, vmax=zmax)) + ax.imshow( + frb[field], + origin="lower", + cmap=cmaps[i][j], + norm=LogNorm(vmin=zmin, vmax=zmax), + ) ax.set_facecolor("black") ax.xaxis.set_visible(False) @@ -450,13 +443,15 @@ def halo_projection_array( if label is not None: # add panel label ax.text( - 0.06, 0.94, + 0.06, + 0.94, label, transform=ax.transAxes, - ha='left', va='top', + ha="left", + va="top", fontsize=12, - fontfamily='DejaVu Serif', - color = "grey" + fontfamily="DejaVu Serif", + color="grey", ) if length_scale is not None: @@ -464,35 +459,36 @@ def halo_projection_array( case "top left": ilen, jlen = 0, 0 case "top right": - ilen, jlen = 0, ncol-1 + ilen, jlen = 0, ncol - 1 case "bottom left": - ilen, jlen = nrow-1, 0 + ilen, jlen = nrow - 1, 0 case "bottom right": - ilen, jlen = nrow-1, ncol-1 + ilen, jlen = nrow - 1, ncol - 1 case "all left": ilen, jlen = i, 0 case "all right": - ilen, jlen = i, ncol-1 + ilen, jlen = i, ncol - 1 case "all top": ilen, jlen = 0, j case "all bottom": - ilen, jlen = nrow-1, j + ilen, jlen = nrow - 1, j case "all": ilen, jlen = i, j - if (i==ilen and j==jlen): + if i == ilen and j == jlen: # add length scale, assuming # panel is 800 pixels wide scalebar = AnchoredSizeBar( ax.transData, - 800/(width*Rh.d), - '1 Mpc', - 'lower right', - pad=0.4, label_top=False, + 800 / (width * Rh.d), + "1 Mpc", + "lower right", + pad=0.4, + label_top=False, sep=10, - color='grey', + color="grey", frameon=False, - size_vertical=1, + size_vertical=1, ) ax.add_artist(scalebar) From 3c3d69a439a7abee9cb0f4006c934fff62d1f9fd Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Thu, 10 Jul 2025 09:39:37 -0500 Subject: [PATCH 2/5] Fix bug associated with conditional loading in lightcones --- opencosmo/analysis/__init__.py | 6 +++++ opencosmo/analysis/diffsky.py | 27 +++++++++++++++++++++ opencosmo/collection/io.py | 1 - opencosmo/collection/lightcone/io.py | 6 ++--- opencosmo/io/io.py | 2 +- test/parallel/test_lc_mpi.py | 3 --- test/test_diffsky.py | 35 ++++++++++++++++++++++------ 7 files changed, 65 insertions(+), 15 deletions(-) create mode 100644 opencosmo/analysis/diffsky.py diff --git a/opencosmo/analysis/__init__.py b/opencosmo/analysis/__init__.py index 067e52ca..b6152b5f 100644 --- a/opencosmo/analysis/__init__.py +++ b/opencosmo/analysis/__init__.py @@ -11,6 +11,7 @@ "halo_projection_array", ] +diffsky_tools = ["get_pop_mah"] try: from .yt_utils import create_yt_dataset @@ -29,6 +30,11 @@ except ImportError: # User has not installed yt tools pass +from .diffsky import get_pop_mah + +__all__.extend(diffsky_tools) + + """ Right now, we have only have two analysis modules so we can handle them directly. In the future we will need to implement a more robust system that handles things automatically. diff --git a/opencosmo/analysis/diffsky.py b/opencosmo/analysis/diffsky.py new file mode 100644 index 00000000..d7fc5fe1 --- /dev/null +++ b/opencosmo/analysis/diffsky.py @@ -0,0 +1,27 @@ +from collections import namedtuple +from typing import NamedTuple, Type, TypeVar + +import numpy as np +from diffmah import mah_halopop + +from opencosmo import Dataset + +DIFFMAH_INPUT = namedtuple( + "DIFFMAH_INPUT", ["logm0", "logtc", "early_index", "late_index", "t_peak"] +) + +T = TypeVar("T", bound=NamedTuple) + + +def make_named_tuple(dataset: Dataset, input_tuple: Type[T]) -> T: + required_columns = input_tuple._fields + data = dataset.select(required_columns).data + output = {c: data[c].value for c in required_columns} + return input_tuple(**output) # type: ignore + + +def get_pop_mah(dataset: Dataset, redshifts: np.ndarray): + mah_params = make_named_tuple(dataset, DIFFMAH_INPUT) + times = dataset.cosmology.age(redshifts).value + + return mah_halopop(mah_params, times, np.log10(dataset.cosmology.age(0).value)) diff --git a/opencosmo/collection/io.py b/opencosmo/collection/io.py index fd12b4b8..f04d4cf7 100644 --- a/opencosmo/collection/io.py +++ b/opencosmo/collection/io.py @@ -43,7 +43,6 @@ def open_collection( """ Open a file with multiple datasets. """ - print("opening collection") CollectionType = get_collection_type(handles) return CollectionType.open(handles, load_kwargs) diff --git a/opencosmo/collection/lightcone/io.py b/opencosmo/collection/lightcone/io.py index c55d8a7c..be2b2af7 100644 --- a/opencosmo/collection/lightcone/io.py +++ b/opencosmo/collection/lightcone/io.py @@ -40,12 +40,12 @@ def open_lightcone(files: list[Path], **load_kwargs): datasets = {} for file in files: - new_ds = oc.open(file) + new_ds = oc.open(file, **load_kwargs) if not isinstance(new_ds, Lightcone): raise ValueError("Didn't find a lightcone in a lightcone file!") for key, ds in new_ds.items(): - key = "_".join([ds.dtype, str(ds.header.file.step)]) - datasets[key] = ds + new_key = "_".join([key, str(ds.header.file.step)]) + datasets[new_key] = ds z_range = headers[0].lightcone.z_range return Lightcone(datasets, z_range) diff --git a/opencosmo/io/io.py b/opencosmo/io/io.py index 755485ed..60f3cf7f 100644 --- a/opencosmo/io/io.py +++ b/opencosmo/io/io.py @@ -84,7 +84,7 @@ def open( paths = [resolve_path(path, FileExistance.MUST_EXIST) for path in files] headers = [read_header(p) for p in paths] if all(h.file.is_lightcone for h in headers): - return oc.collection.open_lightcone(paths) + return oc.collection.open_lightcone(paths, **load_kwargs) return oc.open_linked_files(*paths, **load_kwargs) diff --git a/test/parallel/test_lc_mpi.py b/test/parallel/test_lc_mpi.py index e774a855..433d853c 100644 --- a/test/parallel/test_lc_mpi.py +++ b/test/parallel/test_lc_mpi.py @@ -109,13 +109,10 @@ def test_lc_collection_write( ): path = MPI.COMM_WORLD.bcast(tmp_path) ds = oc.open(haloproperties_601_path, haloproperties_600_path) - print(len(ds)) ds = ds.with_redshift_range(0.039, 0.0405) original_length = len(ds) oc.write(path / "lightcone.hdf5", ds) ds = oc.open(path / "lightcone.hdf5") - print(original_length) - print(len(ds)) data = ds.select("redshift").data parallel_assert(data.min() >= 0.039 and data.max() <= 0.0405) parallel_assert(len(data) == original_length) diff --git a/test/test_diffsky.py b/test/test_diffsky.py index bcbedf79..11c4cbae 100644 --- a/test/test_diffsky.py +++ b/test/test_diffsky.py @@ -1,30 +1,51 @@ +import matplotlib.pyplot as plt import numpy as np import pytest import opencosmo as oc +from opencosmo.analysis import get_pop_mah @pytest.fixture -def core_path(diffsky_path): +def core_path_487(diffsky_path): return diffsky_path / "lj_487.hdf5" -def test_comoving_to_physical(core_path): - cores = oc.open(core_path, synth_cores=True).select(["redshift_true", "x"]) +@pytest.fixture +def core_path_475(diffsky_path): + return diffsky_path / "lj_475.hdf5" + + +def test_comoving_to_physical(core_path_487): + cores = oc.open(core_path_487, synth_cores=True).select(["redshift_true", "x"]) data_physical = cores.with_units("physical").select(["redshift_true", "x"]).data data_comoving = cores.select(["redshift_true", "x"]).data a = 1 / (data_physical["redshift_true"] + 1) assert np.all(np.isclose(data_physical["x"], data_comoving["x"] * a)) -def test_comoving_to_scalefree(core_path): +def test_comoving_to_scalefree(core_path_487): with pytest.raises(oc.transformations.units.UnitError): - _ = oc.open(core_path, synth_cores=True).with_units("scalefree") + _ = oc.open(core_path_487, synth_cores=True).with_units("scalefree") -def test_comoving_to_unitless(core_path): - ds = oc.open(core_path, synth_cores=True) +def test_comoving_to_unitless(core_path_487): + ds = oc.open(core_path_487, synth_cores=True) data = ds.data data_unitless = ds.with_units("unitless").data for col in data.columns: assert np.all(data[col].value == data_unitless[col].value) + + +def test_mah_pop(core_path_487): + zs = np.linspace(0, 1, 10) + ds = oc.open(core_path_487, synth_cores=False) + c = get_pop_mah(ds, zs) + plt.plot(c[0], c[1]) + plt.savefig("test.png") + assert False + + +def test_open_multiple(core_path_487, core_path_475): + ds = oc.open(core_path_487, core_path_475, synth_cores=True) + assert len(ds.keys()) == 4 From ebcaadbf8f130c094554e1b0ce62cb135c676c12 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Thu, 10 Jul 2025 10:42:01 -0500 Subject: [PATCH 3/5] Concatenate lightcone ranges from multiple files --- opencosmo/collection/lightcone/io.py | 7 ++++-- opencosmo/dataset/dataset.py | 4 +++- opencosmo/spatial/builders.py | 2 +- opencosmo/spatial/check.py | 34 ++++++++++++++++------------ test/test_diffsky.py | 13 +++++++++++ 5 files changed, 41 insertions(+), 19 deletions(-) diff --git a/opencosmo/collection/lightcone/io.py b/opencosmo/collection/lightcone/io.py index be2b2af7..6a994032 100644 --- a/opencosmo/collection/lightcone/io.py +++ b/opencosmo/collection/lightcone/io.py @@ -47,8 +47,11 @@ def open_lightcone(files: list[Path], **load_kwargs): new_key = "_".join([key, str(ds.header.file.step)]) datasets[new_key] = ds - z_range = headers[0].lightcone.z_range - return Lightcone(datasets, z_range) + z_ranges = [h.lightcone.z_range for h in headers] + z_min = min(z[0] for z in z_ranges) + z_max = max(z[1] for z in z_ranges) + + return Lightcone(datasets, (z_min, z_max)) def open_lightcone_file(path: Path) -> dict[str, Dataset]: diff --git a/opencosmo/dataset/dataset.py b/opencosmo/dataset/dataset.py index c1ab2f24..db0ca9ff 100644 --- a/opencosmo/dataset/dataset.py +++ b/opencosmo/dataset/dataset.py @@ -240,7 +240,9 @@ def bound(self, region: Region, select_by: Optional[str] = None): self.__header, check_state, self.__tree, - ).with_units("scalefree") + ) + if not self.__header.file.is_lightcone: + check_dataset = check_dataset.with_units("scalefree") mask = check.check_containment(check_dataset, check_region, self.__header.file) new_intersects_index = intersects_index.mask(mask) diff --git a/opencosmo/spatial/builders.py b/opencosmo/spatial/builders.py index bc02d84a..fbb99f7a 100644 --- a/opencosmo/spatial/builders.py +++ b/opencosmo/spatial/builders.py @@ -106,6 +106,6 @@ def make_cone(center: Point2d | SkyCoord, radius: float | u.Quantity): coord = SkyCoord(*center) case _: raise ValueError("Invalid center for Cone region") - if isinstance(radius, float): + if isinstance(radius, (float, int)): radius = radius * u.deg return ConeRegion(coord, radius) diff --git a/opencosmo/spatial/check.py b/opencosmo/spatial/check.py index f57ef945..c1312806 100644 --- a/opencosmo/spatial/check.py +++ b/opencosmo/spatial/check.py @@ -35,6 +35,24 @@ def check_containment( return __check_containment_3d(ds, region, dtype) +def get_theta_phi_coordinates(dataset: "Dataset"): + coord_values = dataset.select(["theta", "phi"]).data + ra = coord_values["phi"] + dec = np.pi / 2 - coord_values["theta"] + + return SkyCoord(ra, dec, unit=u.rad) + + +def find_coordinates_2d(dataset: "Dataset"): + columns = set(dataset.columns) + if len(columns.intersection(set(["theta", "phi"]))) == 2: + return get_theta_phi_coordinates(dataset) + elif len(columns.intersection(set(["ra", "dec"]))) == 2: + data = dataset.select(["ra", "dec"]).data + return SkyCoord(data["ra"], data["dec"]) + raise ValueError("Dataset does not contain coordinates") + + def __check_containment_3d( ds: "Dataset", region: "Region", dtype: str, select_by: Optional[str] = None ): @@ -65,19 +83,5 @@ def __check_containment_3d( def __check_containment_2d( ds: "Dataset", region: "Region", dtype: str, select_by: Optional[str] = None ): - try: - allowed_coordinates = ALLOWED_COORDINATES_2D[dtype] - except KeyError: - allowed_coordinates = ALLOWED_COORDINATES_2D["default"] - cols = set(ds.columns) - if cols.intersection(allowed_coordinates) != allowed_coordinates: - raise ValueError( - "Unable to find the correct coordinate columns in this dataset!" - ) - - coord_values = ds.select(allowed_coordinates).data - ra = coord_values["phi"] - dec = np.pi / 2 - coord_values["theta"] - - coords = SkyCoord(ra, dec, unit=u.rad) + coords = find_coordinates_2d(ds) return region.contains(coords) diff --git a/test/test_diffsky.py b/test/test_diffsky.py index 11c4cbae..deb7c539 100644 --- a/test/test_diffsky.py +++ b/test/test_diffsky.py @@ -49,3 +49,16 @@ def test_mah_pop(core_path_487): def test_open_multiple(core_path_487, core_path_475): ds = oc.open(core_path_487, core_path_475, synth_cores=True) assert len(ds.keys()) == 4 + ds = oc.open(core_path_487, core_path_475) + assert len(ds.keys()) == 2 + z_range = ds.z_range + assert z_range[1] - z_range[0] > 0.05 + + +def test_cone_search(core_path_475, core_path_487): + center = (40, 67) + radius = 2 + ds = oc.open(core_path_487, core_path_475, synth_cores=True) + region = oc.make_cone(center, radius) + ds = ds.bound(region) + assert False From 8d40b63329339aa79dafa530c3ecbce4600139b3 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Thu, 10 Jul 2025 13:16:58 -0500 Subject: [PATCH 4/5] Some tests --- opencosmo/collection/lightcone/lightcone.py | 35 +++++++++++++++++++-- test/test_diffsky.py | 12 ++++++- 2 files changed, 44 insertions(+), 3 deletions(-) diff --git a/opencosmo/collection/lightcone/lightcone.py b/opencosmo/collection/lightcone/lightcone.py index 85643d62..8d50c9ba 100644 --- a/opencosmo/collection/lightcone/lightcone.py +++ b/opencosmo/collection/lightcone/lightcone.py @@ -5,6 +5,7 @@ import astropy.units as u # type: ignore import h5py import numpy as np +from astropy.coordinates import SkyCoord from astropy.cosmology import Cosmology # type: ignore from astropy.table import vstack # type: ignore @@ -31,8 +32,7 @@ def get_redshift_range(datasets: list[Dataset]): def is_in_range(dataset: Dataset, z_low: float, z_high: float): - step_zs = dataset.header.simulation.step_zs - z_range = (step_zs[dataset.header.file.step], step_zs[dataset.header.file.step - 1]) + z_range = dataset.header.lightcone.z_range if z_high < z_range[0] or z_low > z_range[1]: return False return True @@ -347,6 +347,37 @@ def bound(self, region: Region, select_by: Optional[str] = None): """ return self.__map("bound", region, select_by) + def cone_search(self, center: tuple | SkyCoord, radius: float | u.Quantity): + """ + Perform a search for objects within some angular distance of some + given point on the sky. This is a convinience function around + :py:meth`bound ` which is exactly + equivalent to + + .. code-block:: python + + region = oc.make_cone(center, radius) + ds = ds.bound(region) + + Parameters + ---------- + center: tuple | SkyCoord + The center of the region to search. If a tuple and no units are provided + assumed to be RA and Dec in degrees. + + radius: float | astropy.units.Quantity + The angular radius of the region to query. If no units are provided, + assumed to be degrees. + + Returns + ------- + new_lightcone: opencosmo.Lightcone + The rows in this lightcone that fall within the given region. + + """ + region = oc.make_cone(center, radius) + return self.bound(region) + def filter(self, *masks: Mask, **kwargs) -> Self: """ Filter the dataset based on some criteria. See :ref:`Querying Based on Column diff --git a/test/test_diffsky.py b/test/test_diffsky.py index deb7c539..09cbce1e 100644 --- a/test/test_diffsky.py +++ b/test/test_diffsky.py @@ -55,10 +55,20 @@ def test_open_multiple(core_path_487, core_path_475): assert z_range[1] - z_range[0] > 0.05 +def test_open_multiple_write(core_path_487, core_path_475, tmp_path): + ds = oc.open(core_path_487, core_path_475, synth_cores=True) + original_length = len(ds) + original_redshift_range = ds.z_range + output = tmp_path / "synth_gals.hdf5" + oc.write(output, ds) + ds = oc.open(output) + assert len(ds) == original_length + assert ds.z_range == original_redshift_range + + def test_cone_search(core_path_475, core_path_487): center = (40, 67) radius = 2 ds = oc.open(core_path_487, core_path_475, synth_cores=True) region = oc.make_cone(center, radius) ds = ds.bound(region) - assert False From 80df149a5e7c467d03cccdd502ca0dfef9ee24b5 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Thu, 10 Jul 2025 13:32:33 -0500 Subject: [PATCH 5/5] Fix some tests --- opencosmo/analysis/__init__.py | 5 ----- opencosmo/analysis/diffsky.py | 10 ---------- opencosmo/collection/lightcone/io.py | 10 +++++++--- opencosmo/collection/lightcone/lightcone.py | 4 +++- test/test_diffsky.py | 11 ----------- 5 files changed, 10 insertions(+), 30 deletions(-) diff --git a/opencosmo/analysis/__init__.py b/opencosmo/analysis/__init__.py index b6152b5f..ef76e930 100644 --- a/opencosmo/analysis/__init__.py +++ b/opencosmo/analysis/__init__.py @@ -11,7 +11,6 @@ "halo_projection_array", ] -diffsky_tools = ["get_pop_mah"] try: from .yt_utils import create_yt_dataset @@ -30,10 +29,6 @@ except ImportError: # User has not installed yt tools pass -from .diffsky import get_pop_mah - -__all__.extend(diffsky_tools) - """ Right now, we have only have two analysis modules so we can handle them directly. In the diff --git a/opencosmo/analysis/diffsky.py b/opencosmo/analysis/diffsky.py index d7fc5fe1..98b78105 100644 --- a/opencosmo/analysis/diffsky.py +++ b/opencosmo/analysis/diffsky.py @@ -1,9 +1,6 @@ from collections import namedtuple from typing import NamedTuple, Type, TypeVar -import numpy as np -from diffmah import mah_halopop - from opencosmo import Dataset DIFFMAH_INPUT = namedtuple( @@ -18,10 +15,3 @@ def make_named_tuple(dataset: Dataset, input_tuple: Type[T]) -> T: data = dataset.select(required_columns).data output = {c: data[c].value for c in required_columns} return input_tuple(**output) # type: ignore - - -def get_pop_mah(dataset: Dataset, redshifts: np.ndarray): - mah_params = make_named_tuple(dataset, DIFFMAH_INPUT) - times = dataset.cosmology.age(redshifts).value - - return mah_halopop(mah_params, times, np.log10(dataset.cosmology.age(0).value)) diff --git a/opencosmo/collection/lightcone/io.py b/opencosmo/collection/lightcone/io.py index 6a994032..8da3ad04 100644 --- a/opencosmo/collection/lightcone/io.py +++ b/opencosmo/collection/lightcone/io.py @@ -48,10 +48,14 @@ def open_lightcone(files: list[Path], **load_kwargs): datasets[new_key] = ds z_ranges = [h.lightcone.z_range for h in headers] - z_min = min(z[0] for z in z_ranges) - z_max = max(z[1] for z in z_ranges) + if z_ranges[0] is not None: + z_min = min(z[0] for z in z_ranges) + z_max = max(z[1] for z in z_ranges) + z_range = (z_min, z_max) + else: + z_range = None - return Lightcone(datasets, (z_min, z_max)) + return Lightcone(datasets, z_range) def open_lightcone_file(path: Path) -> dict[str, Dataset]: diff --git a/opencosmo/collection/lightcone/lightcone.py b/opencosmo/collection/lightcone/lightcone.py index 8d50c9ba..7d2a3577 100644 --- a/opencosmo/collection/lightcone/lightcone.py +++ b/opencosmo/collection/lightcone/lightcone.py @@ -5,7 +5,7 @@ import astropy.units as u # type: ignore import h5py import numpy as np -from astropy.coordinates import SkyCoord +from astropy.coordinates import SkyCoord # type: ignore from astropy.cosmology import Cosmology # type: ignore from astropy.table import vstack # type: ignore @@ -33,6 +33,8 @@ def get_redshift_range(datasets: list[Dataset]): def is_in_range(dataset: Dataset, z_low: float, z_high: float): z_range = dataset.header.lightcone.z_range + if z_range is None: + z_range = get_redshift_range([dataset]) if z_high < z_range[0] or z_low > z_range[1]: return False return True diff --git a/test/test_diffsky.py b/test/test_diffsky.py index 09cbce1e..099536d1 100644 --- a/test/test_diffsky.py +++ b/test/test_diffsky.py @@ -1,9 +1,7 @@ -import matplotlib.pyplot as plt import numpy as np import pytest import opencosmo as oc -from opencosmo.analysis import get_pop_mah @pytest.fixture @@ -37,15 +35,6 @@ def test_comoving_to_unitless(core_path_487): assert np.all(data[col].value == data_unitless[col].value) -def test_mah_pop(core_path_487): - zs = np.linspace(0, 1, 10) - ds = oc.open(core_path_487, synth_cores=False) - c = get_pop_mah(ds, zs) - plt.plot(c[0], c[1]) - plt.savefig("test.png") - assert False - - def test_open_multiple(core_path_487, core_path_475): ds = oc.open(core_path_487, core_path_475, synth_cores=True) assert len(ds.keys()) == 4