1
1
import glob
2
2
import os
3
3
from functools import cached_property
4
- from typing import List , Optional
4
+ from typing import Any , List , Optional
5
5
6
6
import numpy as np
7
7
11
11
from yt .geometry .particle_geometry_handler import ParticleIndex
12
12
from yt .utilities import fortran_utils as fpu
13
13
from yt .utilities .cosmology import Cosmology
14
+ from yt .utilities .exceptions import YTFieldNotFound
14
15
15
16
from .definitions import header_dt
16
17
from .fields import RockstarFieldInfo
@@ -20,7 +21,7 @@ class RockstarBinaryFile(HaloCatalogFile):
20
21
header : dict
21
22
_position_offset : int
22
23
_member_offset : int
23
- _Npart : np .array
24
+ _Npart : " np.ndarray[Any, np.dtype[np.int64]]"
24
25
_ids_halos : List [int ]
25
26
_file_size : int
26
27
@@ -46,7 +47,9 @@ def __init__(self, ds, io, filename, file_id, range):
46
47
47
48
super ().__init__ (ds , io , filename , file_id , range )
48
49
49
- def _read_member (self , ihalo : int ) -> Optional [np .array ]:
50
+ def _read_member (
51
+ self , ihalo : int
52
+ ) -> Optional ["np.ndarray[Any, np.dtype[np.int64]]" ]:
50
53
if ihalo not in self ._ids_halos :
51
54
return None
52
55
@@ -59,7 +62,7 @@ def _read_member(self, ihalo: int) -> Optional[np.array]:
59
62
ids = np .fromfile (f , dtype = np .int64 , count = self ._Npart [ind_halo ])
60
63
return ids
61
64
62
- def _read_particle_positions (self , ptype , f = None ):
65
+ def _read_particle_positions (self , ptype : str , f = None ):
63
66
"""
64
67
Read all particle positions in this file.
65
68
"""
@@ -168,32 +171,47 @@ def _is_valid(cls, filename: str, *args, **kwargs) -> bool:
168
171
else :
169
172
return header ["magic" ] == 18077126535843729616
170
173
171
- def halo (self , halo_id , ptype = "DM" ):
174
+ def halo (self , ptype , particle_identifier ):
172
175
return RockstarHaloContainer (
173
- halo_id ,
174
176
ptype ,
177
+ particle_identifier ,
175
178
parent_ds = None ,
176
179
halo_ds = self ,
177
180
)
178
181
179
182
180
183
class RockstarHaloContainer :
181
- def __init__ (self , ptype , particle_identifier , parent_ds , halo_ds ):
182
- # if ptype not in parent_ds .particle_types_raw:
183
- # raise RuntimeError(
184
- # f'Possible halo types are {parent_ds .particle_types_raw}, supplied "{ptype}".'
185
- # )
184
+ def __init__ (self , ptype , particle_identifier , * , parent_ds , halo_ds ):
185
+ if ptype not in halo_ds .particle_types_raw :
186
+ raise RuntimeError (
187
+ f'Possible halo types are { halo_ds .particle_types_raw } , supplied "{ ptype } ".'
188
+ )
186
189
187
190
self .ds = parent_ds
188
191
self .halo_ds = halo_ds
189
192
self .ptype = ptype
190
193
self .particle_identifier = particle_identifier
191
194
192
195
def __repr__ (self ):
193
- return "%s_%s_%09d" % (self .ds , self .ptype , self .particle_identifier )
196
+ return "%s_%s_%09d" % (self .halo_ds , self .ptype , self .particle_identifier )
194
197
195
198
def __getitem__ (self , key ):
196
- return self .region [key ]
199
+ if isinstance (key , tuple ):
200
+ ptype , field = key
201
+ else :
202
+ ptype = self .ptype
203
+ field = key
204
+
205
+ data = {
206
+ "mass" : self .mass ,
207
+ "position" : self .position ,
208
+ "velocity" : self .velocity ,
209
+ "member_ids" : self .member_ids ,
210
+ }
211
+ if ptype == "halos" and field in data :
212
+ return data [field ]
213
+
214
+ raise YTFieldNotFound ((ptype , field ), dataset = self .ds )
197
215
198
216
@cached_property
199
217
def ihalo (self ):
0 commit comments