Skip to content

Commit 57b26f7

Browse files
feat: add bidirectional option to filter_by_displacement
1 parent d2cd4e6 commit 57b26f7

File tree

1 file changed

+117
-14
lines changed

1 file changed

+117
-14
lines changed

movement/filtering.py

Lines changed: 117 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@
66
import xarray as xr
77
from scipy import signal
88

9-
from movement.utils.logging import log_error, log_to_attrs
9+
from movement.kinematics import compute_displacement
10+
from movement.utils.logging import log_to_attrs
1011
from movement.utils.reports import report_nan_values
12+
from movement.utils.vector import compute_norm
1113

1214

1315
@log_to_attrs
@@ -60,6 +62,92 @@ def filter_by_confidence(
6062
return data_filtered
6163

6264

65+
@log_to_attrs
def filter_by_displacement(
    position: xr.DataArray,
    threshold: float = 10.0,
    direction: Literal["forward", "bidirectional"] = "forward",
    print_report: bool = False,
) -> xr.DataArray:
    """Filter data points based on displacement magnitude.

    Points in the ``position`` array whose frame-to-frame displacement
    magnitude exceeds ``threshold`` are set to NaN.

    Two modes are supported via the ``direction`` parameter:

    - ``"forward"`` (default): a point at time ``t`` is set to NaN if
      ``|pos(t) - pos(t-1)| > threshold``.
    - ``"bidirectional"``: a point at time ``t`` is set to NaN only if
      BOTH ``|pos(t) - pos(t-1)| > threshold`` AND
      ``|pos(t+1) - pos(t)| > threshold``, i.e. the point jumps away
      from its predecessor and its successor jumps away from it again.
      The last frame has no outgoing displacement (it is filled with 0),
      so for a non-negative ``threshold`` it is never removed in this mode.

    Parameters
    ----------
    position : xarray.DataArray
        The input data containing position information, with ``time``
        and ``space`` (in Cartesian coordinates) as required dimensions.
    threshold : float, optional
        The maximum Euclidean distance allowed for displacement. A
        displacement exactly equal to ``threshold`` is kept.
        Defaults to 10.0.
    direction : Literal["forward", "bidirectional"], optional
        The directionality of the displacement check. Defaults to "forward".
    print_report : bool, optional
        Whether to print a report of the number of NaN values before and after
        filtering. Defaults to False.

    Returns
    -------
    xr.DataArray
        The filtered position array, where points exceeding the displacement
        threshold condition have been set to NaN.

    Raises
    ------
    TypeError
        If ``position`` is not an ``xarray.DataArray``.
    ValueError
        If ``direction`` is neither ``"forward"`` nor ``"bidirectional"``.

    Notes
    -----
    A displacement that cannot be computed because a neighbouring point is
    NaN is not treated as exceeding the threshold, so a valid point is
    never removed solely because its neighbour is missing. Points that are
    already NaN in the input remain NaN.

    See Also
    --------
    movement.kinematics.compute_displacement:
        The function used to compute an array of displacement vectors.
    movement.utils.vector.compute_norm:
        The function used to compute distance as the magnitude of
        displacement vectors.

    """
    if not isinstance(position, xr.DataArray):
        raise TypeError("Input 'position' must be an xarray.DataArray.")
    # Validate up front so an invalid mode fails before any computation.
    if direction not in ("forward", "bidirectional"):
        raise ValueError(
            f"Invalid direction: {direction}. "
            f"Must be 'forward' or 'bidirectional'."
        )

    # Incoming jump: dist_in[t] = |pos(t) - pos(t-1)|
    dist_in = compute_norm(compute_displacement(position))

    # Flag outliers with ``>`` instead of keeping inliers with ``<``:
    # comparisons against NaN are always False, so an undefined
    # displacement never flags a point, and a displacement equal to the
    # threshold is kept, as documented.
    is_outlier = dist_in > threshold

    if direction == "bidirectional":
        # Outgoing jump: dist_out[t] = dist_in[t+1] = |pos(t+1) - pos(t)|.
        # The final frame has no successor and is filled with 0.
        dist_out = dist_in.shift(time=-1, fill_value=0)
        is_outlier = is_outlier & (dist_out > threshold)

    # ``where`` keeps values where the mask is True and inserts NaN elsewhere.
    position_filtered = position.where(~is_outlier)

    if print_report:
        print(report_nan_values(position, "input"))
        print(report_nan_values(position_filtered, "output"))

    return position_filtered
149+
150+
63151
@log_to_attrs
64152
def interpolate_over_time(
65153
data: xr.DataArray,
@@ -177,25 +265,35 @@ def rolling_filter(
177265
178266
"""
179267
half_window = window // 2
180-
data_windows = data.pad( # Pad the edges to avoid NaNs
268+
# Pad the edges to avoid NaNs before applying rolling window
269+
# Transpose ensures padding happens correctly regardless of dim order
270+
padded_data = data.transpose("time", ...).pad(
181271
time=half_window, mode="reflect"
182-
).rolling( # Take rolling windows across time
272+
)
273+
# Apply rolling window across time on padded data
274+
data_windows = padded_data.rolling(
183275
time=window, center=True, min_periods=min_periods
184276
)
185277

186278
# Compute the statistic over each window
187279
allowed_statistics = ["mean", "median", "max", "min"]
188280
if statistic not in allowed_statistics:
189-
raise log_error(
190-
ValueError,
281+
raise ValueError(
191282
f"Invalid statistic '{statistic}'. "
192-
f"Must be one of {allowed_statistics}.",
193-
)
283+
f"Must be one of {allowed_statistics}."
284+
) # <-- Corrected: Added closing parenthesis
194285

195286
data_rolled = getattr(data_windows, statistic)(skipna=True)
196287

197-
# Remove the padded edges
198-
data_rolled = data_rolled.isel(time=slice(half_window, -half_window))
288+
# Remove the padded edges by slicing
289+
# Ensure the slice matches the original time dimension size
290+
original_time_size = data.sizes["time"]
291+
data_rolled = data_rolled.isel(
292+
time=slice(half_window, half_window + original_time_size)
293+
)
294+
295+
# Transpose back to original dimension order
296+
data_rolled = data_rolled.transpose(*data.dims)
199297

200298
# Optional: Print NaN report
201299
if print_report:
@@ -256,15 +354,20 @@ def savgol_filter(
256354
257355
"""
258356
if "axis" in kwargs:
259-
raise log_error(
260-
ValueError, "The 'axis' argument may not be overridden."
261-
)
357+
raise ValueError("The 'axis' argument may not be overridden.")
262358
data_smoothed = data.copy()
359+
# Find the axis index corresponding to the 'time' dimension
360+
try:
361+
time_axis = data.dims.index("time")
362+
except ValueError as e:
363+
raise ValueError("Input data must have a 'time' dimension.") from e
364+
365+
# Apply savgol_filter along the identified time axis
263366
data_smoothed.values = signal.savgol_filter(
264-
data,
367+
data.values, # Pass numpy array to savgol_filter
265368
window,
266369
polyorder,
267-
axis=0,
370+
axis=time_axis,
268371
**kwargs,
269372
)
270373
if print_report:

0 commit comments

Comments
 (0)