Skip to content

Commit dc614e1

Browse files
committed
Take review comments into account
1 parent 9384474 commit dc614e1

File tree

1 file changed

+31
-13
lines changed

1 file changed

+31
-13
lines changed

yt/frontends/rockstar/data_structures.py

Lines changed: 31 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import glob
22
import os
33
from functools import cached_property
4-
from typing import List, Optional
4+
from typing import Any, List, Optional
55

66
import numpy as np
77

@@ -11,6 +11,7 @@
1111
from yt.geometry.particle_geometry_handler import ParticleIndex
1212
from yt.utilities import fortran_utils as fpu
1313
from yt.utilities.cosmology import Cosmology
14+
from yt.utilities.exceptions import YTFieldNotFound
1415

1516
from .definitions import header_dt
1617
from .fields import RockstarFieldInfo
@@ -20,7 +21,7 @@ class RockstarBinaryFile(HaloCatalogFile):
2021
header: dict
2122
_position_offset: int
2223
_member_offset: int
23-
_Npart: np.array
24+
_Npart: "np.ndarray[Any, np.dtype[np.int64]]"
2425
_ids_halos: List[int]
2526
_file_size: int
2627

@@ -46,7 +47,9 @@ def __init__(self, ds, io, filename, file_id, range):
4647

4748
super().__init__(ds, io, filename, file_id, range)
4849

49-
def _read_member(self, ihalo: int) -> Optional[np.array]:
50+
def _read_member(
51+
self, ihalo: int
52+
) -> Optional["np.ndarray[Any, np.dtype[np.int64]]"]:
5053
if ihalo not in self._ids_halos:
5154
return None
5255

@@ -59,7 +62,7 @@ def _read_member(self, ihalo: int) -> Optional[np.array]:
5962
ids = np.fromfile(f, dtype=np.int64, count=self._Npart[ind_halo])
6063
return ids
6164

62-
def _read_particle_positions(self, ptype, f=None):
65+
def _read_particle_positions(self, ptype: str, f=None):
6366
"""
6467
Read all particle positions in this file.
6568
"""
@@ -166,32 +169,47 @@ def _is_valid(cls, filename, *args, **kwargs):
166169
return True
167170
return False
168171

169-
def halo(self, halo_id, ptype="DM"):
172+
def halo(self, ptype, particle_identifier):
170173
return RockstarHaloContainer(
171-
halo_id,
172174
ptype,
175+
particle_identifier,
173176
parent_ds=None,
174177
halo_ds=self,
175178
)
176179

177180

178181
class RockstarHaloContainer:
179-
def __init__(self, ptype, particle_identifier, parent_ds, halo_ds):
180-
# if ptype not in parent_ds.particle_types_raw:
181-
# raise RuntimeError(
182-
# f'Possible halo types are {parent_ds.particle_types_raw}, supplied "{ptype}".'
183-
# )
182+
def __init__(self, ptype, particle_identifier, *, parent_ds, halo_ds):
183+
if ptype not in halo_ds.particle_types_raw:
184+
raise RuntimeError(
185+
f'Possible halo types are {halo_ds.particle_types_raw}, supplied "{ptype}".'
186+
)
184187

185188
self.ds = parent_ds
186189
self.halo_ds = halo_ds
187190
self.ptype = ptype
188191
self.particle_identifier = particle_identifier
189192

190193
def __repr__(self):
191-
return "%s_%s_%09d" % (self.ds, self.ptype, self.particle_identifier)
194+
return "%s_%s_%09d" % (self.halo_ds, self.ptype, self.particle_identifier)
192195

193196
def __getitem__(self, key):
194-
return self.region[key]
197+
if isinstance(key, tuple):
198+
ptype, field = key
199+
else:
200+
ptype = self.ptype
201+
field = key
202+
203+
data = {
204+
"mass": self.mass,
205+
"position": self.position,
206+
"velocity": self.velocity,
207+
"member_ids": self.member_ids,
208+
}
209+
if ptype == "halos" and field in data:
210+
return data[field]
211+
212+
raise YTFieldNotFound((ptype, field), dataset=self.ds)
195213

196214
@cached_property
197215
def ihalo(self):

0 commit comments

Comments
 (0)