Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 37 additions & 8 deletions opencosmo/analysis/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,6 @@
from .yt_utils import create_yt_dataset
from .yt_viz import (
ProjectionPlot, SlicePlot, ParticleProjectionPlot,
ProfilePlot, PhasePlot,
visualize_halo, halo_projection_array,
)

__all__ = [
# ruff: noqa
__all__ = []
yt_tools = [
"create_yt_dataset",
"ProjectionPlot",
"SlicePlot",
Expand All @@ -15,3 +10,37 @@
"visualize_halo",
"halo_projection_array",
]


try:
from .yt_utils import create_yt_dataset
from .yt_viz import (
ParticleProjectionPlot,
PhasePlot,
ProfilePlot,
ProjectionPlot,
SlicePlot,
halo_projection_array,
visualize_halo,
)

__all__.extend(yt_tools)

except ImportError: # User has not installed yt tools
pass


"""
Right now, we have only have two analysis modules so we can handle them directly. In the
future we will need to implement a more robust system that handles things automatically.
"""


def __getattr__(name):
if name in yt_tools:
raise ImportError(
"You tried to import one of the OpenCosmo YT tools, but your python "
"environment does not have the necessary dependencies. You can do install "
"them with `pip install opencosmo[analysis]`"
)
raise ImportError(f"Cannot import name '{name}' from opencosmo.analysis")
17 changes: 17 additions & 0 deletions opencosmo/analysis/diffsky.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from collections import namedtuple
from typing import NamedTuple, Type, TypeVar

from opencosmo import Dataset

DIFFMAH_INPUT = namedtuple(
"DIFFMAH_INPUT", ["logm0", "logtc", "early_index", "late_index", "t_peak"]
)

T = TypeVar("T", bound=NamedTuple)


def make_named_tuple(dataset: Dataset, input_tuple: Type[T]) -> T:
required_columns = input_tuple._fields
data = dataset.select(required_columns).data
output = {c: data[c].value for c in required_columns}
return input_tuple(**output) # type: ignore
Loading
Loading