Skip to content

Commit f1e1855

Browse files
authored
Merge pull request #5159 from cphyc/bugfix/order-of-operation-matters
[BUGFIX] RAMSES: Use one registry per dataset
2 parents 2d7d250 + 7f8a0a3 commit f1e1855

File tree

2 files changed

+49
-17
lines changed

2 files changed

+49
-17
lines changed

yt/frontends/ramses/field_handlers.py

Lines changed: 27 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ def setup_handler(self, domain):
4242
self.ds = ds = domain.ds
4343
self.domain = domain
4444
self.domain_id = domain.domain_id
45+
4546
basename = os.path.abspath(ds.root_folder)
4647
iout = int(os.path.basename(ds.parameter_filename).split(".")[0].split("_")[1])
4748

@@ -166,6 +167,9 @@ def __init_subclass__(cls, *args, **kwargs):
166167
register_field_handler(cls)
167168

168169
cls._unique_registry = {}
170+
cls.parameters = {}
171+
cls.rt_parameters = {}
172+
cls._detected_field_list = {}
169173
return cls
170174

171175
def __init__(self, domain):
@@ -231,6 +235,10 @@ def level_count(self):
231235

232236
return self._level_count
233237

238+
@property
239+
def field_list(self):
240+
return self._detected_field_list[self.ds.unique_identifier]
241+
234242
@cached_property
235243
def offset(self):
236244
"""
@@ -242,7 +250,7 @@ def offset(self):
242250
It should be generic enough for most of the cases, but if the
243251
*structure* of your fluid file is non-canonical, change this.
244252
"""
245-
nvars = len(self.field_list)
253+
nvars = len(self._detected_field_list[self.ds.unique_identifier])
246254
with FortranFile(self.fname) as fd:
247255
# Skip headers
248256
nskip = len(self.attrs)
@@ -265,7 +273,7 @@ def offset(self):
265273
fd,
266274
min_level,
267275
self.domain.domain_id,
268-
self.parameters["nvar"],
276+
self.parameters[self.ds.unique_identifier]["nvar"],
269277
self.domain.amr_header,
270278
Nskip=nvars * 8,
271279
)
@@ -314,7 +322,7 @@ def detect_fields(cls, ds):
314322
attrs = cls.attrs
315323
with FortranFile(fname) as fd:
316324
hvals = fd.read_attrs(attrs)
317-
cls.parameters = hvals
325+
cls.parameters[ds.unique_identifier] = hvals
318326

319327
# Store some metadata
320328
ds.gamma = hvals["gamma"]
@@ -445,7 +453,9 @@ def detect_fields(cls, ds):
445453
count_extra += 1
446454
if count_extra > 0:
447455
mylog.debug("Detected %s extra fluid fields.", count_extra)
448-
cls.field_list = [(cls.ftype, e) for e in fields]
456+
cls._detected_field_list[ds.unique_identifier] = [
457+
(cls.ftype, e) for e in fields
458+
]
449459

450460
cls.set_detected_fields(ds, fields)
451461

@@ -476,9 +486,9 @@ def detect_fields(cls, ds):
476486
basedir = os.path.split(ds.parameter_filename)[0]
477487
fname = os.path.join(basedir, cls.fname.format(iout=iout, icpu=1))
478488
with FortranFile(fname) as fd:
479-
cls.parameters = fd.read_attrs(cls.attrs)
489+
cls.parameters[ds.unique_identifier] = fd.read_attrs(cls.attrs)
480490

481-
nvar = cls.parameters["nvar"]
491+
nvar = cls.parameters[ds.unique_identifier]["nvar"]
482492
ndim = ds.dimensionality
483493

484494
fields = cls.load_fields_from_yt_config()
@@ -497,7 +507,9 @@ def detect_fields(cls, ds):
497507
for i in range(nvar - ndetected):
498508
fields.append(f"var{i}")
499509

500-
cls.field_list = [(cls.ftype, e) for e in fields]
510+
cls._detected_field_list[ds.unique_identifier] = [
511+
(cls.ftype, e) for e in fields
512+
]
501513

502514
cls.set_detected_fields(ds, fields)
503515

@@ -572,7 +584,7 @@ def read_rhs(cast):
572584
# Touchy part, we have to read the photon group properties
573585
mylog.debug("Not reading photon group properties")
574586

575-
cls.rt_parameters = rheader
587+
cls.rt_parameters[ds.unique_identifier] = rheader
576588

577589
ngroups = rheader["nGroups"]
578590

@@ -581,7 +593,7 @@ def read_rhs(cast):
581593
fname = os.path.join(basedir, cls.fname.format(iout=iout, icpu=1))
582594
fname_desc = os.path.join(basedir, cls.file_descriptor)
583595
with FortranFile(fname) as fd:
584-
cls.parameters = fd.read_attrs(cls.attrs)
596+
cls.parameters[ds.unique_identifier] = fd.read_attrs(cls.attrs)
585597

586598
ok = False
587599

@@ -615,16 +627,18 @@ def read_rhs(cast):
615627
for ng in range(ngroups):
616628
fields.extend([t % (ng + 1) for t in tmp])
617629

618-
cls.field_list = [(cls.ftype, e) for e in fields]
630+
cls._detected_field_list[ds.unique_identifier] = [
631+
(cls.ftype, e) for e in fields
632+
]
619633

620634
cls.set_detected_fields(ds, fields)
621635
return fields
622636

623637
@classmethod
624638
def get_rt_parameters(cls, ds):
625-
if cls.rt_parameters:
626-
return cls.rt_parameters
639+
if cls.rt_parameters[ds.unique_identifier]:
640+
return cls.rt_parameters[ds.unique_identifier]
627641

628642
# Call detect fields to get the rt_parameters
629643
cls.detect_fields(ds)
630-
return cls.rt_parameters
644+
return cls.rt_parameters[ds.unique_identifier]

yt/frontends/ramses/tests/test_outputs.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -390,8 +390,8 @@ def test_ramses_field_detection():
390390
fields_1 = set(DETECTED_FIELDS[ds1.unique_identifier]["ramses"])
391391

392392
# Check the right number of variables has been loaded
393-
assert P1["nvar"] == 10
394-
assert len(fields_1) == P1["nvar"]
393+
assert P1[ds1.unique_identifier]["nvar"] == 10
394+
assert len(fields_1) == P1[ds1.unique_identifier]["nvar"]
395395

396396
# Now load another dataset
397397
ds2 = yt.load(output_00080)
@@ -400,8 +400,8 @@ def test_ramses_field_detection():
400400
fields_2 = set(DETECTED_FIELDS[ds2.unique_identifier]["ramses"])
401401

402402
# Check the right number of variables has been loaded
403-
assert P2["nvar"] == 6
404-
assert len(fields_2) == P2["nvar"]
403+
assert P2[ds2.unique_identifier]["nvar"] == 6
404+
assert len(fields_2) == P2[ds2.unique_identifier]["nvar"]
405405

406406

407407
@requires_file(ramses_new_format)
@@ -794,3 +794,21 @@ def test_self_shielding_loading():
794794

795795
# Also make sure the difference is large for some cells
796796
assert (np.abs(diff) > 0.1).any()
797+
798+
799+
@requires_file(output_00080)
800+
@requires_file(ramses_mhd_128)
801+
def test_order_does_not_matter():
802+
for order in (1, 2):
803+
ds0 = yt.load(output_00080)
804+
ds1 = yt.load(ramses_mhd_128)
805+
806+
# This should not raise any exception
807+
if order == 1:
808+
_sp1 = ds1.all_data()
809+
sp0 = ds0.all_data()
810+
else:
811+
sp0 = ds0.all_data()
812+
_sp1 = ds1.all_data()
813+
814+
sp0["gas", "velocity_x"].max().to("km/s")

0 commit comments

Comments
 (0)