Skip to content

Commit 19f99b1

Browse files
Merge pull request #5216 from matthewturk/pytest_experiment
2 parents be6eda5 + 72d445d commit 19f99b1

File tree

3 files changed

+178
-173
lines changed

3 files changed

+178
-173
lines changed

nose_ignores.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,3 +50,4 @@
5050
--ignore-file=test_sph_pixelization_pytestonly\.py
5151
--ignore-file=test_time_series\.py
5252
--ignore-file=test_cf_radial_pytest\.py
53+
--ignore-file=test_data_containers\.py

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,6 @@ addopts = '''
306306
--ignore-glob='/*_nose.py'
307307
--ignore-glob='/*/yt/data_objects/level_sets/tests/test_clump_finding.py'
308308
--ignore-glob='/*/yt/data_objects/tests/test_connected_sets.py'
309-
--ignore-glob='/*/yt/data_objects/tests/test_data_containers.py'
310309
--ignore-glob='/*/yt/data_objects/tests/test_dataset_access.py'
311310
--ignore-glob='/*/yt/data_objects/tests/test_particle_filter.py'
312311
--ignore-glob='/*/yt/data_objects/tests/test_particle_trajectories.py'
Lines changed: 177 additions & 172 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,7 @@
11
import os
2-
import shutil
3-
import tempfile
4-
import unittest
52

63
import numpy as np
7-
from nose.tools import assert_raises
4+
import pytest
85
from numpy.testing import assert_array_equal, assert_equal
96

107
from yt.data_objects.data_containers import YTDataContainer
@@ -13,179 +10,187 @@
1310
fake_amr_ds,
1411
fake_particle_ds,
1512
fake_random_ds,
16-
requires_module,
13+
requires_module_pytest as requires_module,
1714
)
1815
from yt.utilities.exceptions import YTException, YTFieldNotFound
1916

2017

21-
class TestDataContainers(unittest.TestCase):
22-
@classmethod
23-
def setUpClass(cls):
24-
cls.tmpdir = tempfile.mkdtemp()
25-
cls.curdir = os.getcwd()
26-
os.chdir(cls.tmpdir)
18+
@pytest.fixture
19+
def temp_workdir(tmp_path):
20+
curdir = os.getcwd()
21+
os.chdir(tmp_path)
22+
yield tmp_path
23+
os.chdir(curdir)
2724

28-
@classmethod
29-
def tearDownClass(cls):
30-
os.chdir(cls.curdir)
31-
shutil.rmtree(cls.tmpdir)
3225

33-
def test_yt_data_container(self):
34-
# Test if ds could be None
35-
with assert_raises(RuntimeError) as err:
36-
YTDataContainer(None, None)
37-
desired = (
26+
@pytest.mark.usefixtures("temp_workdir")
27+
def test_yt_data_container():
28+
# Test if ds could be None
29+
with pytest.raises(
30+
RuntimeError,
31+
match=(
3832
"Error: ds must be set either through class"
3933
" type or parameter to the constructor"
40-
)
41-
assert_equal(str(err.exception), desired)
42-
43-
# Test if field_data key exists
44-
ds = fake_random_ds(5)
45-
proj = ds.proj(("gas", "density"), 0, data_source=ds.all_data())
46-
assert_equal("px" in proj.keys(), True)
47-
assert_equal("pz" in proj.keys(), False)
48-
49-
# Delete the key and check if exits
50-
del proj["px"]
51-
assert_equal("px" in proj.keys(), False)
52-
del proj["gas", "density"]
53-
assert_equal("density" in proj.keys(), False)
54-
55-
# Delete a non-existent field
56-
with assert_raises(YTFieldNotFound) as ex:
57-
del proj["p_mass"]
58-
desired = "Could not find field 'p_mass' in UniformGridData."
59-
assert_equal(str(ex.exception), desired)
60-
61-
def test_write_out(self):
62-
filename = "sphere.txt"
63-
ds = fake_random_ds(16, particles=10)
64-
sp = ds.sphere(ds.domain_center, 0.25)
65-
66-
sp.write_out(filename, fields=[("gas", "cell_volume")])
67-
68-
with open(filename) as file:
69-
file_row_1 = file.readline()
70-
file_row_2 = file.readline()
71-
file_row_2 = np.array(file_row_2.split("\t"), dtype=np.float64)
72-
sorted_keys = sorted(sp.field_data.keys())
73-
keys = [str(k) for k in sorted_keys]
74-
keys = "\t".join(["#"] + keys + ["\n"])
75-
data = [sp.field_data[k][0] for k in sorted_keys]
76-
77-
assert_equal(keys, file_row_1)
78-
assert_array_equal(data, file_row_2)
79-
80-
def test_invalid_write_out(self):
81-
filename = "sphere.txt"
82-
ds = fake_random_ds(16, particles=10)
83-
sp = ds.sphere(ds.domain_center, 0.25)
84-
85-
with assert_raises(YTException):
86-
sp.write_out(filename, fields=[("all", "particle_ones")])
87-
88-
@requires_module("pandas")
89-
def test_to_dataframe(self):
90-
fields = [("gas", "density"), ("gas", "velocity_z")]
91-
ds = fake_random_ds(6)
92-
dd = ds.all_data()
93-
df = dd.to_dataframe(fields)
94-
assert_array_equal(dd[fields[0]], df[fields[0][1]])
95-
assert_array_equal(dd[fields[1]], df[fields[1][1]])
96-
97-
@requires_module("astropy")
98-
def test_to_astropy_table(self):
99-
from yt.units.yt_array import YTArray
100-
101-
fields = [("gas", "density"), ("gas", "velocity_z")]
102-
ds = fake_random_ds(6)
103-
dd = ds.all_data()
104-
at1 = dd.to_astropy_table(fields)
105-
assert_array_equal(dd[fields[0]].d, at1[fields[0][1]].value)
106-
assert_array_equal(dd[fields[1]].d, at1[fields[1][1]].value)
107-
assert dd[fields[0]].units == YTArray.from_astropy(at1[fields[0][1]]).units
108-
assert dd[fields[1]].units == YTArray.from_astropy(at1[fields[1][1]]).units
109-
110-
def test_std(self):
111-
ds = fake_random_ds(3)
112-
ds.all_data().std(("gas", "density"), weight=("gas", "velocity_z"))
113-
114-
def test_to_frb(self):
115-
# Test cylindrical geometry
116-
fields = ["density", "cell_mass"]
117-
units = ["g/cm**3", "g"]
118-
ds = fake_amr_ds(
119-
fields=fields, units=units, geometry="cylindrical", particles=16**3
120-
)
121-
dd = ds.all_data()
122-
proj = ds.proj(
123-
("gas", "density"),
124-
weight_field=("gas", "cell_mass"),
125-
axis=1,
126-
data_source=dd,
127-
)
128-
frb = proj.to_frb((1.0, "unitary"), 64)
129-
assert_equal(frb.radius, (1.0, "unitary"))
130-
assert_equal(frb.buff_size, 64)
131-
132-
def test_extract_isocontours(self):
133-
# Test isocontour properties for AMRGridData
134-
fields = ["density", "cell_mass"]
135-
units = ["g/cm**3", "g"]
136-
ds = fake_amr_ds(fields=fields, units=units, particles=16**3)
137-
dd = ds.all_data()
138-
q = dd.quantities["WeightedAverageQuantity"]
139-
rho = q(("gas", "density"), weight=("gas", "cell_mass"))
140-
dd.extract_isocontours(("gas", "density"), rho, "triangles.obj", True)
141-
dd.calculate_isocontour_flux(
142-
("gas", "density"),
143-
rho,
144-
("index", "x"),
145-
("index", "y"),
146-
("index", "z"),
147-
("index", "dx"),
148-
)
149-
150-
# Test error in case of ParticleData
151-
ds = fake_particle_ds()
152-
dd = ds.all_data()
153-
q = dd.quantities["WeightedAverageQuantity"]
154-
rho = q(("all", "particle_velocity_x"), weight=("all", "particle_mass"))
155-
with assert_raises(NotImplementedError):
156-
dd.extract_isocontours("density", rho, sample_values="x")
157-
158-
def test_derived_field(self):
159-
# Test that derived field on filtered particles do not require
160-
# their parent field to be created
161-
ds = fake_particle_ds()
162-
dd = ds.all_data()
163-
dd.set_field_parameter("axis", 0)
164-
165-
@particle_filter(requires=["particle_mass"], filtered_type="io")
166-
def massive(pfilter, data):
167-
return data[pfilter.filtered_type, "particle_mass"].to("code_mass") > 0.5
168-
169-
ds.add_particle_filter("massive")
170-
171-
def fun(field, data):
172-
return data[field.name[0], "particle_mass"]
173-
174-
# Add the field to the massive particles
175-
ds.add_field(
176-
("massive", "test"),
177-
function=fun,
178-
sampling_type="particle",
179-
units="code_mass",
180-
)
181-
182-
expected_size = (dd["io", "particle_mass"].to("code_mass") > 0.5).sum()
183-
184-
fields_to_test = [f for f in ds.derived_field_list if f[0] == "massive"]
185-
186-
def test_this(fname):
187-
data = dd[fname]
188-
assert_equal(data.shape[0], expected_size)
189-
190-
for fname in fields_to_test:
191-
test_this(fname)
34+
),
35+
):
36+
YTDataContainer(None, None)
37+
38+
# Test if field_data key exists
39+
ds = fake_random_ds(5)
40+
proj = ds.proj(("gas", "density"), 0, data_source=ds.all_data())
41+
assert "px" in proj.keys()
42+
assert "pz" not in proj.keys()
43+
44+
# Delete the key and check if exits
45+
del proj["px"]
46+
assert "px" not in proj.keys()
47+
del proj["gas", "density"]
48+
assert "density" not in proj.keys()
49+
50+
# Delete a non-existent field
51+
with pytest.raises(
52+
YTFieldNotFound, match="Could not find field 'p_mass' in UniformGridData."
53+
):
54+
del proj["p_mass"]
55+
56+
57+
@pytest.mark.usefixtures("temp_workdir")
58+
def test_write_out():
59+
filename = "sphere.txt"
60+
ds = fake_random_ds(16, particles=10)
61+
sp = ds.sphere(ds.domain_center, 0.25)
62+
63+
sp.write_out(filename, fields=[("gas", "cell_volume")])
64+
65+
with open(filename) as file:
66+
file_row_1 = file.readline()
67+
file_row_2 = file.readline()
68+
file_row_2 = np.array(file_row_2.split("\t"), dtype=np.float64)
69+
sorted_keys = sorted(sp.field_data.keys())
70+
keys = [str(k) for k in sorted_keys]
71+
keys = "\t".join(["#"] + keys + ["\n"])
72+
data = [sp.field_data[k][0] for k in sorted_keys]
73+
74+
assert_equal(keys, file_row_1)
75+
assert_array_equal(data, file_row_2)
76+
77+
78+
@pytest.mark.usefixtures("temp_workdir")
79+
def test_invalid_write_out():
80+
filename = "sphere.txt"
81+
ds = fake_random_ds(16, particles=10)
82+
sp = ds.sphere(ds.domain_center, 0.25)
83+
84+
with pytest.raises(YTException):
85+
sp.write_out(filename, fields=[("all", "particle_ones")])
86+
87+
88+
@pytest.mark.usefixtures("temp_workdir")
89+
@requires_module("pandas")
90+
def test_to_dataframe():
91+
fields = [("gas", "density"), ("gas", "velocity_z")]
92+
ds = fake_random_ds(6)
93+
dd = ds.all_data()
94+
df = dd.to_dataframe(fields)
95+
assert_array_equal(dd[fields[0]], df[fields[0][1]])
96+
assert_array_equal(dd[fields[1]], df[fields[1][1]])
97+
98+
99+
@pytest.mark.usefixtures("temp_workdir")
100+
@requires_module("astropy")
101+
def test_to_astropy_table():
102+
from yt.units.yt_array import YTArray
103+
104+
fields = [("gas", "density"), ("gas", "velocity_z")]
105+
ds = fake_random_ds(6)
106+
dd = ds.all_data()
107+
at1 = dd.to_astropy_table(fields)
108+
assert_array_equal(dd[fields[0]].d, at1[fields[0][1]].value)
109+
assert_array_equal(dd[fields[1]].d, at1[fields[1][1]].value)
110+
assert dd[fields[0]].units == YTArray.from_astropy(at1[fields[0][1]]).units
111+
assert dd[fields[1]].units == YTArray.from_astropy(at1[fields[1][1]]).units
112+
113+
114+
def test_std():
115+
ds = fake_random_ds(3)
116+
ds.all_data().std(("gas", "density"), weight=("gas", "velocity_z"))
117+
118+
119+
def test_to_frb():
120+
# Test cylindrical geometry
121+
fields = ["density", "cell_mass"]
122+
units = ["g/cm**3", "g"]
123+
ds = fake_amr_ds(
124+
fields=fields, units=units, geometry="cylindrical", particles=16**3
125+
)
126+
dd = ds.all_data()
127+
proj = ds.proj(
128+
("gas", "density"),
129+
weight_field=("gas", "cell_mass"),
130+
axis=1,
131+
data_source=dd,
132+
)
133+
frb = proj.to_frb((1.0, "unitary"), 64)
134+
assert frb.radius == (1.0, "unitary")
135+
assert frb.buff_size == 64
136+
137+
138+
@pytest.mark.usefixtures("temp_workdir")
139+
def test_extract_isocontours():
140+
# Test isocontour properties for AMRGridData
141+
fields = ["density", "cell_mass"]
142+
units = ["g/cm**3", "g"]
143+
ds = fake_amr_ds(fields=fields, units=units, particles=16**3)
144+
dd = ds.all_data()
145+
q = dd.quantities["WeightedAverageQuantity"]
146+
rho = q(("gas", "density"), weight=("gas", "cell_mass"))
147+
dd.extract_isocontours(("gas", "density"), rho, "triangles.obj", True)
148+
dd.calculate_isocontour_flux(
149+
("gas", "density"),
150+
rho,
151+
("index", "x"),
152+
("index", "y"),
153+
("index", "z"),
154+
("index", "dx"),
155+
)
156+
157+
# Test error in case of ParticleData
158+
ds = fake_particle_ds()
159+
dd = ds.all_data()
160+
q = dd.quantities["WeightedAverageQuantity"]
161+
rho = q(("all", "particle_velocity_x"), weight=("all", "particle_mass"))
162+
with pytest.raises(NotImplementedError):
163+
dd.extract_isocontours("density", rho, sample_values="x")
164+
165+
166+
def test_derived_field():
167+
# Test that derived field on filtered particles do not require
168+
# their parent field to be created
169+
ds = fake_particle_ds()
170+
dd = ds.all_data()
171+
dd.set_field_parameter("axis", 0)
172+
173+
@particle_filter(requires=["particle_mass"], filtered_type="io")
174+
def massive(pfilter, data):
175+
return data[pfilter.filtered_type, "particle_mass"].to("code_mass") > 0.5
176+
177+
ds.add_particle_filter("massive")
178+
179+
def fun(field, data):
180+
return data[field.name[0], "particle_mass"]
181+
182+
# Add the field to the massive particles
183+
ds.add_field(
184+
("massive", "test"),
185+
function=fun,
186+
sampling_type="particle",
187+
units="code_mass",
188+
)
189+
190+
expected_size = (dd["io", "particle_mass"].to("code_mass") > 0.5).sum()
191+
192+
fields_to_test = [f for f in ds.derived_field_list if f[0] == "massive"]
193+
194+
for fname in fields_to_test:
195+
data = dd[fname]
196+
assert_equal(data.shape[0], expected_size)

0 commit comments

Comments
 (0)