Skip to content

Commit 86b2a1a

Browse files
committed
improving interpolation
1 parent 429a8f4 commit 86b2a1a

File tree

2 files changed

+10
-251
lines changed

2 files changed

+10
-251
lines changed

mtpy/core/mt.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -548,7 +548,6 @@ def interpolate(
548548
method="slinear",
549549
bounds_error=True,
550550
f_type="period",
551-
z_log_space=False,
552551
**kwargs,
553552
):
554553
"""Interpolate the impedance tensor onto different frequencies.
@@ -603,10 +602,7 @@ def interpolate(
603602

604603
new_m = self.clone_empty()
605604
if self.has_impedance():
606-
# new_m.Z = self.Z.interpolate(
607-
# new_period, method=method, log_space=z_log_space, **kwargs
608-
# )
609-
new_m.Z = self.Z.interpolate_improved(new_period, method=method, **kwargs)
605+
new_m.Z = self.Z.interpolate(new_period, method=method, **kwargs)
610606
if new_m.has_impedance():
611607
if np.all(np.isnan(new_m.Z.z)):
612608
self.logger.warning(
@@ -615,12 +611,7 @@ def interpolate(
615611
"See scipy.interpolate.interp1d for more information."
616612
)
617613
if self.has_tipper():
618-
# new_m.Tipper = self.Tipper.interpolate(
619-
# new_period, method=method, **kwargs
620-
# )
621-
new_m.Tipper = self.Tipper.interpolate_improved(
622-
new_period, method=method, **kwargs
623-
)
614+
new_m.Tipper = self.Tipper.interpolate(new_period, method=method, **kwargs)
624615
if new_m.has_tipper():
625616
if np.all(np.isnan(new_m.Tipper.tipper)):
626617
self.logger.warning(

mtpy/core/transfer_function/base.py

Lines changed: 8 additions & 240 deletions
Original file line numberDiff line numberDiff line change
@@ -595,7 +595,7 @@ def get_clockwise(coordinate_reference_frame):
595595
tb._dataset = ds
596596
return tb
597597

598-
def interpolate_improved(
598+
def interpolate(
599599
self,
600600
new_periods: np.ndarray,
601601
inplace: bool = False,
@@ -744,14 +744,6 @@ def interpolate_improved(
744744
)
745745
else:
746746
# Handle real data
747-
# interp_func = interpolate.interp1d(
748-
# valid_periods,
749-
# valid_data,
750-
# kind=method,
751-
# bounds_error=False,
752-
# fill_value="extrapolate" if extrapolate else np.nan,
753-
# **kwargs,
754-
# )
755747
interp_func = self._get_interpolator(
756748
valid_periods,
757749
valid_data,
@@ -837,8 +829,6 @@ def _get_interpolator(
837829
- "pchip"
838830
- "spline"
839831
- "akima"
840-
- "barycentric"
841-
- "krogh"
842832
- "polynomial"
843833
:param kwargs: additional arguments for interpolator
844834
:return: interpolation function
@@ -880,244 +870,22 @@ def _get_interpolator(
880870
if extrapolate:
881871
# Wrap with extrapolate=True
882872
return lambda xx: interp(xx, extrapolate=True)
883-
else:
884-
return interp
885-
elif method == "barycentric":
886-
# Barycentric doesn't support extrapolation, will need custom handling
887-
interp = interpolate.BarycentricInterpolator(x, y, **kwargs)
888-
if extrapolate:
889-
raise ValueError(
890-
"barycentric interpolation does not support extrapolation."
891-
)
892-
# # Create a function that extends beyond the bounds
893-
# def interp_func(xx):
894-
# mask = (xx < x.min()) | (xx > x.max())
895-
# result = np.empty_like(xx, dtype=float)
896-
# result[~mask] = interp(xx[~mask])
897-
898-
# # Extrapolate using the closest points
899-
# if np.any(xx < x.min()):
900-
# left_idx = xx < x.min()
901-
# slope = (y[1] - y[0]) / (x[1] - x[0])
902-
# result[left_idx] = y[0] + slope * (xx[left_idx] - x[0])
903-
904-
# if np.any(xx > x.max()):
905-
# right_idx = xx > x.max()
906-
# slope = (y[-1] - y[-2]) / (x[-1] - x[-2])
907-
# result[right_idx] = y[-1] + slope * (xx[right_idx] - x[-1])
908-
909-
# return result
910-
911-
return interp_func
912-
else:
913-
# Return original interpolator
914-
return interp
915-
elif method == "krogh":
916-
# Krogh doesn't support extrapolation directly
917-
interp = interpolate.KroghInterpolator(x, y, **kwargs)
918-
if extrapolate:
919-
raise ValueError("Krogh interpolation does not support extrapolation.")
920-
# # Create a function that extends beyond the bounds
921-
# def interp_func(xx):
922-
# mask = (xx < x.min()) | (xx > x.max())
923-
# result = np.empty_like(xx, dtype=float)
924-
# result[~mask] = interp(xx[~mask])
925-
926-
# # Extrapolate using the closest points
927-
# if np.any(xx < x.min()):
928-
# left_idx = xx < x.min()
929-
# slope = (y[1] - y[0]) / (x[1] - x[0])
930-
# result[left_idx] = y[0] + slope * (xx[left_idx] - x[0])
931-
932-
# if np.any(xx > x.max()):
933-
# right_idx = xx > x.max()
934-
# slope = (y[-1] - y[-2]) / (x[-1] - x[-2])
935-
# result[right_idx] = y[-1] + slope * (xx[right_idx] - x[-1])
936-
937-
# return result
938-
939-
# return interp_func
940-
else:
941-
# Return original interpolator
942-
return interp
873+
return interp
943874
elif method == "polynomial":
944875
# Use CubicSpline instead of polynomial for better handling of extrapolation
945876
return interpolate.CubicSpline(x, y, extrapolate=extrapolate, **kwargs)
946-
else:
877+
elif method in ["linear", "cubic", "nearest", "slinear"]:
947878
# Default to general interp1d for methods like linear, cubic, etc.
948879
fill_value = "extrapolate" if extrapolate else np.nan
949880
return interpolate.interp1d(
950881
x, y, kind=method, bounds_error=False, fill_value=fill_value, **kwargs
951882
)
952-
953-
def interpolate(
954-
self,
955-
new_periods,
956-
inplace=False,
957-
method="slinear",
958-
na_method="pchip",
959-
log_space=False,
960-
extrapolate=False,
961-
**kwargs,
962-
):
963-
"""Interpolate onto a new period range.
964-
965-
The way this works is that NaNs
966-
are first interpolated using method `na_method` along the original
967-
period map. This allows us to use xarray tools for interpolation. If we
968-
drop NaNs using xarray it drops each column or row that has a single
969-
NaN and removes way too much data. Therefore interpolating NaNs first
970-
keeps most of the data. Then a 1D interpolation is done for the
971-
`new_periods` using method `method`.
972-
973-
'pchip' seems to work best when using xr.interpolate_na
974-
975-
Set log_space=True if the object being interpolated is in log space,
976-
like impedance. It seems that functions that are naturally in log-space
977-
cause issues with the interpolators so taking the log of the function
978-
seems to produce better results.
979-
:param new_periods: New periods to interpolate on to.
980-
:type new_periods: np.ndarray, list
981-
:param inplace: Interpolate inplace, defaults to False.
982-
:type inplace: bool, optional
983-
:param method: Method for 1D linear interpolation options are
984-
["linear", "nearest", "zero", "slinear", "quadratic", "cubic"],, defaults to "slinear".
985-
:type method: string, optional
986-
:param na_method: Method to interpolate NaNs along original periods
987-
options are {"linear", "nearest", "zero", "slinear", "quadratic",
988-
"cubic", "polynomial", "barycentric", "krogh", "pchip", "spline",
989-
"akima"}, defaults to "pchip".
990-
:type na_method: string, optional
991-
:param log_space: Set to true if function is naturally logarithmic,, defaults to False.
992-
:type log_space: bool, optional
993-
:param extrapolate: Extrapolate past original period range, default is
994-
False. If set to True be careful cause the values are not great, defaults to False.
995-
:type extrapolate: bool, optional
996-
:param **kwargs: Keyword args passed to interpolation methods.
997-
:type **kwargs: dict
998-
:return: Interpolated object.
999-
:rtype: :class:`mtpy.core.transfer_fuction.base.TFBase`
1000-
"""
1001-
1002-
da_dict = {}
1003-
for key in self._dataset.data_vars:
1004-
# need to interpolate over nans first, if use dropna loose a lot
1005-
# of data. going to interpolate anyway. pchip seems to work best
1006-
# for interpolate na. If this doesn't work think about
1007-
# interpolating component by component.
1008-
if log_space:
1009-
da_drop_nan = np.log(self._dataset[key]).interpolate_na(
1010-
dim="period", method=na_method
1011-
)
1012-
da_dict[key] = np.exp(
1013-
da_drop_nan.interp(period=new_periods, method=method, kwargs=kwargs)
1014-
)
1015-
else:
1016-
da_drop_nan = self._dataset[key].interpolate_na(
1017-
dim="period", method=na_method
1018-
)
1019-
da_dict[key] = da_drop_nan.interp(
1020-
period=new_periods, method=method, kwargs=kwargs
1021-
)
1022-
1023-
# need to abide by the original data that has nans at the
1024-
# beginning and end of the data. interpolate_na will remove these
1025-
# and gives terrible values, so fill these back in with nans
1026-
if not extrapolate:
1027-
da_dict[key] = self._backfill_nans(self._dataset[key], da_dict[key])
1028-
1029-
ds = xr.Dataset(da_dict)
1030-
1031-
if inplace:
1032-
self._dataset = ds
1033883
else:
1034-
tb = self.copy()
1035-
tb._dataset = ds
1036-
return tb
1037-
1038-
@staticmethod
1039-
def _find_nans_index(data_array):
1040-
"""Find nans at beginning and end of xarray.
1041-
1042-
When you interpolate a
1043-
xarray.DataArray we interpolate nans, which removes the original nans
1044-
at the beginning and end of the array. We need to find these
1045-
indicies and period min and max so we can put them back in.
1046-
:param data_array: DESCRIPTION.
1047-
:type data_array: TYPE
1048-
:return: DESCRIPTION.
1049-
:rtype: TYPE
1050-
"""
1051-
1052-
index_list = []
1053-
for ch_in in data_array.input.data:
1054-
for ch_out in data_array.output.data:
1055-
index = np.where(
1056-
np.nan_to_num(data_array.loc[{"input": ch_in, "output": ch_out}])
1057-
== 0
1058-
)[0]
1059-
if len(index) > 0:
1060-
entry = {
1061-
"input": ch_in,
1062-
"output": ch_out,
1063-
"beginning": [],
1064-
"end": [],
1065-
"period_min": float(data_array.period.min()),
1066-
"period_max": float(data_array.period.max()),
1067-
}
1068-
if index[0] == 0:
1069-
entry["beginning"] = []
1070-
ii = 0
1071-
diff = 1
1072-
while diff == 1 and ii < len(index) - 1:
1073-
diff = index[ii + 1] - index[ii]
1074-
entry["beginning"].append(ii)
1075-
ii += 1
1076-
if len(entry["beginning"]) > 0:
1077-
entry["period_min"] = float(
1078-
data_array.period[entry["beginning"][-1] + 1]
1079-
)
1080-
1081-
if index[-1] == (data_array.shape[0] - 1):
1082-
entry["end"] = [-1]
1083-
ii = -2
1084-
diff = 1
1085-
while diff == 1 and abs(ii) < len(index):
1086-
diff = abs(index[ii - 1] - index[ii])
1087-
entry["end"].append(ii)
1088-
ii -= 1
1089-
entry["period_max"] = float(
1090-
data_array.period[entry["end"][-1] - 1]
1091-
)
1092-
index_list.append(entry)
1093-
1094-
return index_list
1095-
1096-
def _backfill_nans(self, original, interpolated):
1097-
"""Back fill with nans for extrapolated values.
1098-
:param original: Original data array.
1099-
:type original: xarray.DataArray
1100-
:param interpolated: Interpolated data array.
1101-
:type interpolated: xr.DataArray
1102-
:return: Nan's filled in from beginning and end of original data.
1103-
:rtype: xarray.DataArray
1104-
"""
1105-
1106-
original_index = self._find_nans_index(original)
1107-
for entry in original_index:
1108-
beginning = np.where(interpolated.period < entry["period_min"])
1109-
end = np.where(interpolated.period > entry["period_max"])
1110-
nan_index = np.append(beginning[0], end[0])
1111-
if "complex" in interpolated.dtype.name:
1112-
interpolated.loc[{"input": entry["input"], "output": entry["output"]}][
1113-
nan_index
1114-
] = (np.nan + 1j * np.nan)
1115-
else:
1116-
interpolated.loc[{"input": entry["input"], "output": entry["output"]}][
1117-
nan_index
1118-
] = np.nan
1119-
1120-
return interpolated
884+
raise ValueError(
885+
f"Interpolation method {method} is not supported. "
886+
"Supported methods are linear, cubic, nearest, slinear, "
887+
"pchip, spline, akima, polynomial."
888+
)
1121889

1122890
def to_xarray(self):
1123891
"""To an xarray dataset.

0 commit comments

Comments
 (0)