1
1
import glob
2
2
import os
3
3
from functools import cached_property
4
- from typing import List , Optional
4
+ from typing import Any , List , Optional
5
5
6
6
import numpy as np
7
7
11
11
from yt .geometry .particle_geometry_handler import ParticleIndex
12
12
from yt .utilities import fortran_utils as fpu
13
13
from yt .utilities .cosmology import Cosmology
14
+ from yt .utilities .exceptions import YTFieldNotFound
14
15
15
16
from .definitions import header_dt
16
17
from .fields import RockstarFieldInfo
@@ -20,7 +21,7 @@ class RockstarBinaryFile(HaloCatalogFile):
20
21
header : dict
21
22
_position_offset : int
22
23
_member_offset : int
23
- _Npart : np .array
24
+ _Npart : " np.ndarray[Any, np.dtype[np.int64]]"
24
25
_ids_halos : List [int ]
25
26
_file_size : int
26
27
@@ -46,7 +47,9 @@ def __init__(self, ds, io, filename, file_id, range):
46
47
47
48
super ().__init__ (ds , io , filename , file_id , range )
48
49
49
- def _read_member (self , ihalo : int ) -> Optional [np .array ]:
50
+ def _read_member (
51
+ self , ihalo : int
52
+ ) -> Optional ["np.ndarray[Any, np.dtype[np.int64]]" ]:
50
53
if ihalo not in self ._ids_halos :
51
54
return None
52
55
@@ -59,7 +62,7 @@ def _read_member(self, ihalo: int) -> Optional[np.array]:
59
62
ids = np .fromfile (f , dtype = np .int64 , count = self ._Npart [ind_halo ])
60
63
return ids
61
64
62
- def _read_particle_positions (self , ptype , f = None ):
65
+ def _read_particle_positions (self , ptype : str , f = None ):
63
66
"""
64
67
Read all particle positions in this file.
65
68
"""
@@ -166,32 +169,47 @@ def _is_valid(cls, filename, *args, **kwargs):
166
169
return True
167
170
return False
168
171
169
- def halo (self , halo_id , ptype = "DM" ):
172
+ def halo (self , ptype , particle_identifier ):
170
173
return RockstarHaloContainer (
171
- halo_id ,
172
174
ptype ,
175
+ particle_identifier ,
173
176
parent_ds = None ,
174
177
halo_ds = self ,
175
178
)
176
179
177
180
178
181
class RockstarHaloContainer :
179
- def __init__ (self , ptype , particle_identifier , parent_ds , halo_ds ):
180
- # if ptype not in parent_ds .particle_types_raw:
181
- # raise RuntimeError(
182
- # f'Possible halo types are {parent_ds .particle_types_raw}, supplied "{ptype}".'
183
- # )
182
+ def __init__ (self , ptype , particle_identifier , * , parent_ds , halo_ds ):
183
+ if ptype not in halo_ds .particle_types_raw :
184
+ raise RuntimeError (
185
+ f'Possible halo types are { halo_ds .particle_types_raw } , supplied "{ ptype } ".'
186
+ )
184
187
185
188
self .ds = parent_ds
186
189
self .halo_ds = halo_ds
187
190
self .ptype = ptype
188
191
self .particle_identifier = particle_identifier
189
192
190
193
def __repr__ (self ):
191
- return "%s_%s_%09d" % (self .ds , self .ptype , self .particle_identifier )
194
+ return "%s_%s_%09d" % (self .halo_ds , self .ptype , self .particle_identifier )
192
195
193
196
def __getitem__ (self , key ):
194
- return self .region [key ]
197
+ if isinstance (key , tuple ):
198
+ ptype , field = key
199
+ else :
200
+ ptype = self .ptype
201
+ field = key
202
+
203
+ data = {
204
+ "mass" : self .mass ,
205
+ "position" : self .position ,
206
+ "velocity" : self .velocity ,
207
+ "member_ids" : self .member_ids ,
208
+ }
209
+ if ptype == "halos" and field in data :
210
+ return data [field ]
211
+
212
+ raise YTFieldNotFound ((ptype , field ), dataset = self .ds )
195
213
196
214
@cached_property
197
215
def ihalo (self ):
0 commit comments