Skip to content

Commit cebbaee

Browse files
committed
adding test for 3D data
1 parent d3887c6 commit cebbaee

File tree

1 file changed

+179
-0
lines changed

1 file changed

+179
-0
lines changed
Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
Created on Tue Aug 20 13:09:02 2024
4+
5+
@author: jpeacock
6+
"""
7+
8+
# =============================================================================
9+
# Imports
10+
# =============================================================================
11+
import unittest
12+
import numpy as np
13+
from mtpy import MTData
14+
15+
from mtpy_data import PROFILE_LIST
16+
from mtpy.modeling.simpeg.data_3d import Simpeg3DData
17+
from mtpy.modeling.simpeg.recipes.inversion_2d import Simpeg2D
18+
19+
# =============================================================================
20+
21+
22+
class TestSimpeg3DData(unittest.TestCase):
23+
@classmethod
24+
def setUpClass(self):
25+
self.md = MTData()
26+
self.md.add_station(
27+
[fn for fn in PROFILE_LIST if fn.name.startswith("16")]
28+
)
29+
# australian epsg
30+
self.md.utm_epsg = 4462
31+
32+
# interpolate onto a common period range
33+
self.new_periods = np.logspace(-5, 1, 10)
34+
self.md.interpolate(self.new_periods, inplace=True, bounds_error=False)
35+
36+
self.mt_df = self.md.to_dataframe()
37+
38+
def setUp(self):
39+
self.simpeg_data = Simpeg3DData(self.mt_df)
40+
41+
# def test_get_locations_fail(self):
42+
# df = self.md.to_dataframe()
43+
# df.profile_offset = 0
44+
# s = Simpeg3DData(df)
45+
# self.assertRaises(ValueError, getattr, s, "station_locations")
46+
47+
def test_station_locations(self):
48+
with self.subTest("shape"):
49+
self.assertEqual((6, 3), self.simpeg_data.station_locations.shape)
50+
51+
with self.subTest("elevation"):
52+
self.assertTrue(
53+
np.allclose(
54+
self.simpeg_data.station_locations[:, 1],
55+
np.array([210.0, 213.0, 212.0, 219.0, 214.0, 220.0]),
56+
)
57+
)
58+
59+
# def test_station_locations_no_elevation(self):
60+
# self.simpeg_data.include_elevation = False
61+
# with self.subTest("shape"):
62+
# self.assertEqual((6, 2), self.simpeg_data.station_locations.shape)
63+
64+
# with self.subTest("offset"):
65+
# self.assertTrue(
66+
# np.allclose(
67+
# self.simpeg_data.station_locations[:, 0],
68+
# np.array(
69+
# [
70+
# 0.0,
71+
# 479.36423899,
72+
# 1032.47570849,
73+
# 1526.02107079,
74+
# 2005.38361755,
75+
# 2501.76393224,
76+
# ]
77+
# ),
78+
# )
79+
# )
80+
# with self.subTest("elevation"):
81+
# self.assertTrue(
82+
# np.allclose(
83+
# self.simpeg_data.station_locations[:, 1],
84+
# np.zeros((6)),
85+
# )
86+
# )
87+
88+
def test_frequencies(self):
89+
self.assertTrue(
90+
np.allclose(1.0 / self.new_periods, self.simpeg_data.frequencies)
91+
)
92+
93+
94+
# def test_survey_te(self):
95+
# # simpeg sorts in order of lowest frequency to highest
96+
# with self.subTest("frequencies"):
97+
# self.assertTrue(
98+
# np.allclose(
99+
# 1.0 / self.new_periods[::-1],
100+
# self.simpeg_data.survey_te.frequencies,
101+
# )
102+
# )
103+
104+
# def test_survey_tm(self):
105+
# with self.subTest("frequencies"):
106+
# self.assertTrue(
107+
# np.allclose(
108+
# 1.0 / self.new_periods[::-1],
109+
# self.simpeg_data.survey_tm.frequencies,
110+
# )
111+
# )
112+
113+
# def test_te_observations(self):
114+
# with self.subTest("size"):
115+
# self.assertEqual(
116+
# self.simpeg_data.te_observations.size,
117+
# 2
118+
# * self.simpeg_data.n_frequencies
119+
# * self.simpeg_data.n_stations,
120+
# )
121+
122+
# def test_tm_observations(self):
123+
# with self.subTest("size"):
124+
# self.assertEqual(
125+
# self.simpeg_data.tm_observations.size,
126+
# 2
127+
# * self.simpeg_data.n_frequencies
128+
# * self.simpeg_data.n_stations,
129+
# )
130+
131+
# def test_te_data_errors(self):
132+
# with self.subTest("size"):
133+
# self.assertEqual(
134+
# self.simpeg_data.te_data_errors.size,
135+
# 2
136+
# * self.simpeg_data.n_frequencies
137+
# * self.simpeg_data.n_stations,
138+
# )
139+
140+
# def test_tm_data_errors(self):
141+
# with self.subTest("size"):
142+
# self.assertEqual(
143+
# self.simpeg_data.tm_data_errors.size,
144+
# 2
145+
# * self.simpeg_data.n_frequencies
146+
# * self.simpeg_data.n_stations,
147+
# )
148+
149+
150+
# class TestSimpeg2DRecipe(unittest.TestCase):
151+
# @classmethod
152+
# def setUpClass(self):
153+
# self.md = MTData()
154+
# self.md.add_station(
155+
# [fn for fn in PROFILE_LIST if fn.name.startswith("16")]
156+
# )
157+
# # australian epsg
158+
# self.md.utm_epsg = 4462
159+
160+
# # extract profile
161+
# self.profile = self.md.get_profile(
162+
# 149.15, -22.3257, 149.20, -22.3257, 1000
163+
# )
164+
# # interpolate onto a common period range
165+
# self.new_periods = np.logspace(-3, 0, 4)
166+
# self.profile.interpolate(
167+
# self.new_periods, inplace=True, bounds_error=False
168+
# )
169+
170+
# self.mt_df = self.profile.to_dataframe()
171+
172+
# self.simpeg_inversion = Simpeg2D(self.mt_df)
173+
174+
175+
# =============================================================================
176+
# run
177+
# =============================================================================
178+
if __name__ == "__main__":
179+
unittest.main()

0 commit comments

Comments
 (0)