Skip to content

Commit 5e70326

Browse files
committed
Take review comments into account
1 parent fd040d9 commit 5e70326

File tree

1 file changed

+31
-13
lines changed

1 file changed

+31
-13
lines changed

yt/frontends/rockstar/data_structures.py

Lines changed: 31 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import glob
22
import os
33
from functools import cached_property
4-
from typing import List, Optional
4+
from typing import Any, List, Optional
55

66
import numpy as np
77

@@ -11,6 +11,7 @@
1111
from yt.geometry.particle_geometry_handler import ParticleIndex
1212
from yt.utilities import fortran_utils as fpu
1313
from yt.utilities.cosmology import Cosmology
14+
from yt.utilities.exceptions import YTFieldNotFound
1415

1516
from .definitions import header_dt
1617
from .fields import RockstarFieldInfo
@@ -20,7 +21,7 @@ class RockstarBinaryFile(HaloCatalogFile):
2021
header: dict
2122
_position_offset: int
2223
_member_offset: int
23-
_Npart: np.array
24+
_Npart: "np.ndarray[Any, np.dtype[np.int64]]"
2425
_ids_halos: List[int]
2526
_file_size: int
2627

@@ -46,7 +47,9 @@ def __init__(self, ds, io, filename, file_id, range):
4647

4748
super().__init__(ds, io, filename, file_id, range)
4849

49-
def _read_member(self, ihalo: int) -> Optional[np.array]:
50+
def _read_member(
51+
self, ihalo: int
52+
) -> Optional["np.ndarray[Any, np.dtype[np.int64]]"]:
5053
if ihalo not in self._ids_halos:
5154
return None
5255

@@ -59,7 +62,7 @@ def _read_member(self, ihalo: int) -> Optional[np.array]:
5962
ids = np.fromfile(f, dtype=np.int64, count=self._Npart[ind_halo])
6063
return ids
6164

62-
def _read_particle_positions(self, ptype, f=None):
65+
def _read_particle_positions(self, ptype: str, f=None):
6366
"""
6467
Read all particle positions in this file.
6568
"""
@@ -168,32 +171,47 @@ def _is_valid(cls, filename: str, *args, **kwargs) -> bool:
168171
else:
169172
return header["magic"] == 18077126535843729616
170173

171-
def halo(self, halo_id, ptype="DM"):
174+
def halo(self, ptype, particle_identifier):
172175
return RockstarHaloContainer(
173-
halo_id,
174176
ptype,
177+
particle_identifier,
175178
parent_ds=None,
176179
halo_ds=self,
177180
)
178181

179182

180183
class RockstarHaloContainer:
181-
def __init__(self, ptype, particle_identifier, parent_ds, halo_ds):
182-
# if ptype not in parent_ds.particle_types_raw:
183-
# raise RuntimeError(
184-
# f'Possible halo types are {parent_ds.particle_types_raw}, supplied "{ptype}".'
185-
# )
184+
def __init__(self, ptype, particle_identifier, *, parent_ds, halo_ds):
185+
if ptype not in halo_ds.particle_types_raw:
186+
raise RuntimeError(
187+
f'Possible halo types are {halo_ds.particle_types_raw}, supplied "{ptype}".'
188+
)
186189

187190
self.ds = parent_ds
188191
self.halo_ds = halo_ds
189192
self.ptype = ptype
190193
self.particle_identifier = particle_identifier
191194

192195
def __repr__(self):
193-
return "%s_%s_%09d" % (self.ds, self.ptype, self.particle_identifier)
196+
return "%s_%s_%09d" % (self.halo_ds, self.ptype, self.particle_identifier)
194197

195198
def __getitem__(self, key):
196-
return self.region[key]
199+
if isinstance(key, tuple):
200+
ptype, field = key
201+
else:
202+
ptype = self.ptype
203+
field = key
204+
205+
data = {
206+
"mass": self.mass,
207+
"position": self.position,
208+
"velocity": self.velocity,
209+
"member_ids": self.member_ids,
210+
}
211+
if ptype == "halos" and field in data:
212+
return data[field]
213+
214+
raise YTFieldNotFound((ptype, field), dataset=self.ds)
197215

198216
@cached_property
199217
def ihalo(self):

0 commit comments

Comments
 (0)