Skip to content

Commit f6db4a5

Browse files
Lachlan Groselachlangrose
authored andcommitted
fix: allow data to be specified in create and add function.
1 parent 8a2ec05 commit f6db4a5

File tree

3 files changed

+50
-50
lines changed

3 files changed

+50
-50
lines changed

LoopStructural/modelling/core/geological_model.py

Lines changed: 46 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -125,9 +125,7 @@ def __init__(
125125
self.features = []
126126
self.feature_name_index = {}
127127
self._data = pd.DataFrame() # None
128-
129128

130-
131129
self.stratigraphic_column = None
132130

133131
self.tol = 1e-10 * np.max(self.bounding_box.maximum - self.bounding_box.origin)
@@ -179,8 +177,40 @@ def __str__(self):
179177
def _ipython_key_completions_(self):
180178
return self.feature_name_index.keys()
181179

182-
180+
def prepare_data(self, data: pd.DataFrame) -> pd.DataFrame:
181+
data = data.copy()
182+
data[['X', 'Y', 'Z']] = self.bounding_box.project(data[['X', 'Y', 'Z']].to_numpy())
183183

184+
if "type" in data:
185+
logger.warning("'type' is deprecated replace with 'feature_name' \n")
186+
data.rename(columns={"type": "feature_name"}, inplace=True)
187+
if "feature_name" not in data:
188+
logger.error("Data does not contain 'feature_name' column")
189+
raise BaseException("Cannot load data")
190+
for h in all_heading():
191+
if h not in data:
192+
data[h] = np.nan
193+
if h == "w":
194+
data[h] = 1.0
195+
if h == "coord":
196+
data[h] = 0
197+
if h == "polarity":
198+
data[h] = 1.0
199+
# LS wants polarity as -1 or 1, change 0 to -1
200+
data.loc[data["polarity"] == 0, "polarity"] = -1.0
201+
data.loc[np.isnan(data["w"]), "w"] = 1.0
202+
if "strike" in data and "dip" in data:
203+
logger.info("Converting strike and dip to vectors")
204+
mask = np.all(~np.isnan(data.loc[:, ["strike", "dip"]]), axis=1)
205+
data.loc[mask, gradient_vec_names()] = (
206+
strikedip2vector(data.loc[mask, "strike"], data.loc[mask, "dip"])
207+
* data.loc[mask, "polarity"].to_numpy()[:, None]
208+
)
209+
data.drop(["strike", "dip"], axis=1, inplace=True)
210+
data[['X', 'Y', 'Z', 'val', 'nx', 'ny', 'nz', 'gx', 'gy', 'gz', 'tx', 'ty', 'tz']] = data[
211+
['X', 'Y', 'Z', 'val', 'nx', 'ny', 'nz', 'gx', 'gy', 'gz', 'tx', 'ty', 'tz']
212+
].astype(float)
213+
return data
184214
@classmethod
185215
def from_processor(cls, processor):
186216
"""Builds a model from a :class:`LoopStructural.modelling.input.ProcessInputData` object
@@ -473,40 +503,9 @@ def data(self, data: pd.DataFrame):
473503
raise BaseException("Cannot load data")
474504
logger.info(f"Adding data to GeologicalModel with {len(data)} data points")
475505
self._data = data.copy()
476-
self._data[['X','Y','Z']] = self.bounding_box.project(self._data[['X','Y','Z']].to_numpy())
477-
478-
479-
if "type" in self._data:
480-
logger.warning("'type' is deprecated replace with 'feature_name' \n")
481-
self._data.rename(columns={"type": "feature_name"}, inplace=True)
482-
if "feature_name" not in self._data:
483-
logger.error("Data does not contain 'feature_name' column")
484-
raise BaseException("Cannot load data")
485-
for h in all_heading():
486-
if h not in self._data:
487-
self._data[h] = np.nan
488-
if h == "w":
489-
self._data[h] = 1.0
490-
if h == "coord":
491-
self._data[h] = 0
492-
if h == "polarity":
493-
self._data[h] = 1.0
494-
# LS wants polarity as -1 or 1, change 0 to -1
495-
self._data.loc[self._data["polarity"] == 0, "polarity"] = -1.0
496-
self._data.loc[np.isnan(self._data["w"]), "w"] = 1.0
497-
if "strike" in self._data and "dip" in self._data:
498-
logger.info("Converting strike and dip to vectors")
499-
mask = np.all(~np.isnan(self._data.loc[:, ["strike", "dip"]]), axis=1)
500-
self._data.loc[mask, gradient_vec_names()] = (
501-
strikedip2vector(self._data.loc[mask, "strike"], self._data.loc[mask, "dip"])
502-
* self._data.loc[mask, "polarity"].to_numpy()[:, None]
503-
)
504-
self._data.drop(["strike", "dip"], axis=1, inplace=True)
505-
self._data[['X', 'Y', 'Z', 'val', 'nx', 'ny', 'nz', 'gx', 'gy', 'gz', 'tx', 'ty', 'tz']] = (
506-
self._data[
507-
['X', 'Y', 'Z', 'val', 'nx', 'ny', 'nz', 'gx', 'gy', 'gz', 'tx', 'ty', 'tz']
508-
].astype(float)
509-
)
506+
# self._data[['X','Y','Z']] = self.bounding_box.project(self._data[['X','Y','Z']].to_numpy())
507+
508+
510509

511510
def set_model_data(self, data):
512511
logger.warning("deprecated method. Model data can now be set using the data attribute")
@@ -623,7 +622,7 @@ def create_and_add_foliation(
623622
if series_surface_data.shape[0] == 0:
624623
logger.warning("No data for {series_surface_data}, skipping")
625624
return
626-
series_builder.add_data_from_data_frame(series_surface_data)
625+
series_builder.add_data_from_data_frame(self.prepare_data(series_surface_data))
627626
self._add_faults(series_builder, features=faults)
628627

629628
# build feature
@@ -697,7 +696,7 @@ def create_and_add_fold_frame(
697696
if fold_frame_data.shape[0] == 0:
698697
logger.warning(f"No data for {fold_frame_name}, skipping")
699698
return
700-
fold_frame_builder.add_data_from_data_frame(fold_frame_data)
699+
fold_frame_builder.add_data_from_data_frame(self.prepare_data(fold_frame_data))
701700
self._add_faults(fold_frame_builder[0])
702701
self._add_faults(fold_frame_builder[1])
703702
self._add_faults(fold_frame_builder[2])
@@ -783,7 +782,10 @@ def create_and_add_folded_foliation(
783782
logger.warning(f"No data for {foliation_name}, skipping")
784783
return
785784
series_builder.add_data_from_data_frame(
786-
foliation_data )
785+
self.prepare_data(
786+
foliation_data
787+
)
788+
)
787789
self._add_faults(series_builder)
788790
# series_builder.add_data_to_interpolator(True)
789791
# build feature
@@ -878,7 +880,7 @@ def create_and_add_folded_fold_frame(
878880
)
879881
if fold_frame_data is None:
880882
fold_frame_data = self.data[self.data["feature_name"] == fold_frame_name]
881-
fold_frame_builder.add_data_from_data_frame(fold_frame_data)
883+
fold_frame_builder.add_data_from_data_frame(self.prepare_data(fold_frame_data))
882884

883885
for i in range(3):
884886
self._add_faults(fold_frame_builder[i])
@@ -1331,7 +1333,7 @@ def create_and_add_fault(
13311333
if fault_data.shape[0] == 0:
13321334
logger.warning(f"No data for {fault_name}, skipping")
13331335
return
1334-
1336+
13351337
self._add_faults(fault_frame_builder, features=faults)
13361338
# add data
13371339

@@ -1344,7 +1346,7 @@ def create_and_add_fault(
13441346
if intermediate_axis:
13451347
intermediate_axis = intermediate_axis
13461348
fault_frame_builder.create_data_from_geometry(
1347-
fault_frame_data=fault_data,
1349+
fault_frame_data=self.prepare_data(fault_data),
13481350
fault_center=fault_center,
13491351
fault_normal_vector=fault_normal_vector,
13501352
fault_slip_vector=fault_slip_vector,
@@ -1397,9 +1399,8 @@ def rescale(self, points: np.ndarray, *, inplace: bool = False) -> np.ndarray:
13971399
points : np.array((N,3),dtype=double)
13981400
13991401
"""
1400-
1402+
14011403
return self.bounding_box.reproject(points,inplace=inplace)
1402-
14031404

14041405
# TODO move scale to bounding box/transformer
14051406
def scale(self, points: np.ndarray, *, inplace: bool = False) -> np.ndarray:
@@ -1418,7 +1419,6 @@ def scale(self, points: np.ndarray, *, inplace: bool = False) -> np.ndarray:
14181419
14191420
"""
14201421
return self.bounding_box.project(np.array(points).astype(float),inplace=inplace)
1421-
14221422

14231423
def regular_grid(self, *, nsteps=None, shuffle=True, rescale=False, order="C"):
14241424
"""

tests/unit/modelling/intrusions/test_intrusions.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,9 @@ def test_intrusion_builder():
6565
model.data = data
6666
model.nsteps = [10, 10, 10]
6767

68-
intrusion_data = data[data["feature_name"] == "tabular_intrusion"]
69-
intrusion_frame_data = model.data[model.data["feature_name"] == "tabular_intrusion_frame"]
70-
68+
intrusion_data = model.prepare_data(data[data["feature_name"] == "tabular_intrusion"])
69+
intrusion_frame_data = model.prepare_data(model.data[model.data["feature_name"] == "tabular_intrusion_frame"])
70+
7171
conformable_feature = model.create_and_add_foliation("stratigraphy")
7272

7373
intrusion_frame_parameters = {

tests/unit/modelling/test_geological_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def test_rescale_model_data():
1717
model.set_model_data(data)
1818
# Check that the model data is rescaled to local coordinates
1919
expected = data[['X', 'Y', 'Z']].values - bb[None, 0, :]
20-
actual = model.data[['X', 'Y', 'Z']].values
20+
actual = model.prepare_data(model.data)[['X', 'Y', 'Z']].values
2121
assert np.allclose(actual, expected, atol=1e-6)
2222
def test_access_feature_model():
2323
data, bb = load_claudius()

0 commit comments

Comments
 (0)