|
| 1 | +from dataclasses import dataclass |
| 2 | + |
1 | 3 | import numpy as np
|
2 | 4 | from numpy.testing import (
|
3 | 5 | assert_almost_equal,
|
|
8 | 10 | )
|
9 | 11 |
|
10 | 12 | from yt import load
|
| 13 | +from yt.data_objects.static_output import Dataset |
11 | 14 | from yt.frontends.stream.fields import StreamFieldInfo
|
12 | 15 | from yt.testing import (
|
13 | 16 | assert_allclose_units,
|
@@ -59,61 +62,15 @@ def _strip_ftype(field):
|
59 | 62 | return field[1]
|
60 | 63 |
|
61 | 64 |
|
62 |
| -class TestFieldAccess: |
63 |
| - description = None |
64 |
| - |
65 |
| - def __init__(self, field_name, ds, nprocs): |
66 |
| - # Note this should be a field name |
67 |
| - self.field_name = field_name |
68 |
| - self.description = f"Accessing_{field_name}_{nprocs}" |
69 |
| - self.nprocs = nprocs |
70 |
| - self.ds = ds |
71 |
| - |
72 |
| - def __call__(self): |
73 |
| - field = self.ds._get_field_info(self.field_name) |
74 |
| - skip_grids = False |
75 |
| - needs_spatial = False |
76 |
| - for v in field.validators: |
77 |
| - if getattr(v, "ghost_zones", 0) > 0: |
78 |
| - skip_grids = True |
79 |
| - if hasattr(v, "ghost_zones"): |
80 |
| - needs_spatial = True |
81 |
| - |
82 |
| - ds = self.ds |
83 |
| - |
84 |
| - # This gives unequal sized grids as well as subgrids |
85 |
| - dd1 = ds.all_data() |
86 |
| - dd2 = ds.all_data() |
87 |
| - sp = get_params(ds) |
88 |
| - dd1.field_parameters.update(sp) |
89 |
| - dd2.field_parameters.update(sp) |
90 |
| - with np.errstate(all="ignore"): |
91 |
| - v1 = dd1[self.field_name] |
92 |
| - # No more conversion checking |
93 |
| - assert_equal(v1, dd1[self.field_name]) |
94 |
| - if not needs_spatial: |
95 |
| - with field.unit_registry(dd2): |
96 |
| - res = field._function(field, dd2) |
97 |
| - res = dd2.apply_units(res, field.units) |
98 |
| - assert_array_almost_equal_nulp(v1, res, 4) |
99 |
| - if not skip_grids: |
100 |
| - for g in ds.index.grids: |
101 |
| - g.field_parameters.update(sp) |
102 |
| - v1 = g[self.field_name] |
103 |
| - g.clear_data() |
104 |
| - g.field_parameters.update(sp) |
105 |
| - r1 = field._function(field, g) |
106 |
| - if field.sampling_type == "particle": |
107 |
| - assert_equal(v1.shape[0], g.NumberOfParticles) |
108 |
| - else: |
109 |
| - assert_array_equal(r1.shape, v1.shape) |
110 |
| - for ax in "xyz": |
111 |
| - assert_array_equal(g["index", ax].shape, v1.shape) |
112 |
| - with field.unit_registry(g): |
113 |
| - res = field._function(field, g) |
114 |
| - assert_array_equal(v1.shape, res.shape) |
115 |
| - res = g.apply_units(res, field.units) |
116 |
| - assert_array_almost_equal_nulp(v1, res, 4) |
| 65 | +@dataclass(slots=True, frozen=True) |
| 66 | +class FieldAccessTestCase: |
| 67 | + field_name: str |
| 68 | + ds: Dataset |
| 69 | + nprocs: int |
| 70 | + |
| 71 | + @property |
| 72 | + def description(self) -> str: |
| 73 | + return f"Accessing_{self.field_name}_{self.nprocs}" |
117 | 74 |
|
118 | 75 |
|
119 | 76 | def get_base_ds(nprocs):
|
@@ -188,7 +145,53 @@ def test_all_fields():
|
188 | 145 |
|
189 | 146 | for nprocs in [1, 4, 8]:
|
190 | 147 | test_all_fields.__name__ = f"{field}_{nprocs}"
|
191 |
| - yield TestFieldAccess(field, datasets[nprocs], nprocs) |
| 148 | + |
| 149 | + tc = FieldAccessTestCase(field, datasets[nprocs], nprocs) |
| 150 | + |
| 151 | + field = tc.ds._get_field_info(tc.field_name) |
| 152 | + skip_grids = False |
| 153 | + needs_spatial = False |
| 154 | + for v in field.validators: |
| 155 | + if getattr(v, "ghost_zones", 0) > 0: |
| 156 | + skip_grids = True |
| 157 | + if hasattr(v, "ghost_zones"): |
| 158 | + needs_spatial = True |
| 159 | + |
| 160 | + ds = tc.ds |
| 161 | + |
| 162 | + # This gives unequal sized grids as well as subgrids |
| 163 | + dd1 = ds.all_data() |
| 164 | + dd2 = ds.all_data() |
| 165 | + sp = get_params(ds) |
| 166 | + dd1.field_parameters.update(sp) |
| 167 | + dd2.field_parameters.update(sp) |
| 168 | + with np.errstate(all="ignore"): |
| 169 | + v1 = dd1[tc.field_name] |
| 170 | + # No more conversion checking |
| 171 | + assert_equal(v1, dd1[tc.field_name]) |
| 172 | + if not needs_spatial: |
| 173 | + with field.unit_registry(dd2): |
| 174 | + res = field._function(field, dd2) |
| 175 | + res = dd2.apply_units(res, field.units) |
| 176 | + assert_array_almost_equal_nulp(v1, res, 4) |
| 177 | + if not skip_grids: |
| 178 | + for g in ds.index.grids: |
| 179 | + g.field_parameters.update(sp) |
| 180 | + v1 = g[tc.field_name] |
| 181 | + g.clear_data() |
| 182 | + g.field_parameters.update(sp) |
| 183 | + r1 = field._function(field, g) |
| 184 | + if field.sampling_type == "particle": |
| 185 | + assert_equal(v1.shape[0], g.NumberOfParticles) |
| 186 | + else: |
| 187 | + assert_array_equal(r1.shape, v1.shape) |
| 188 | + for ax in "xyz": |
| 189 | + assert_array_equal(g["index", ax].shape, v1.shape) |
| 190 | + with field.unit_registry(g): |
| 191 | + res = field._function(field, g) |
| 192 | + assert_array_equal(v1.shape, res.shape) |
| 193 | + res = g.apply_units(res, field.units) |
| 194 | + assert_array_almost_equal_nulp(v1, res, 4) |
192 | 195 |
|
193 | 196 |
|
194 | 197 | def test_add_deposited_particle_field():
|
|
0 commit comments