Skip to content

Commit e6cc6af

Browse files
feat: add bidirectional option to filter_by_displacement
1 parent d2cd4e6 commit e6cc6af

File tree

1 file changed

+91
-6
lines changed

1 file changed

+91
-6
lines changed

movement/filtering.py

Lines changed: 91 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@
66
import xarray as xr
77
from scipy import signal
88

9-
from movement.utils.logging import log_error, log_to_attrs
9+
from movement.kinematics import compute_displacement
10+
from movement.utils.logging import log_to_attrs
1011
from movement.utils.reports import report_nan_values
12+
from movement.utils.vector import compute_norm
1113

1214

1315
@log_to_attrs
@@ -60,6 +62,92 @@ def filter_by_confidence(
6062
return data_filtered
6163

6264

65+
@log_to_attrs
def filter_by_displacement(
    position: xr.DataArray,
    threshold: float = 10.0,
    direction: Literal["forward", "bidirectional"] = "forward",
    print_report: bool = False,
) -> xr.DataArray:
    """Filter data points based on displacement magnitude.

    Frames in the ``position`` array that exceed a displacement magnitude
    threshold are set to NaN.

    Two modes are supported via the ``direction`` parameter:

    - "forward" (default): A point at time ``t`` is set to NaN if it has
      moved more than the ``threshold`` Euclidean distance from the same
      point at time ``t-1`` (i.e., if ``|pos(t) - pos(t-1)| > threshold``).
    - "bidirectional": A point at time ``t`` is set to NaN only if BOTH the
      displacement from ``t-1`` to ``t`` AND the displacement from ``t`` to
      ``t+1`` exceed the threshold (i.e., if
      ``|pos(t) - pos(t-1)| > threshold`` AND
      ``|pos(t+1) - pos(t)| > threshold``). This targets isolated,
      single-frame jumps: a point that jumps away and immediately jumps
      back.

    Parameters
    ----------
    position : xr.DataArray
        The input data containing position information, with ``time``
        and ``space`` (in Cartesian coordinates) as required dimensions.
    threshold : float, optional
        The maximum Euclidean distance allowed for displacement.
        Displacements exactly equal to the threshold are kept.
        Defaults to 10.0.
    direction : Literal["forward", "bidirectional"], optional
        The directionality of the displacement check. Defaults to "forward".
    print_report : bool, optional
        Whether to print a report of the number of NaN values before and after
        filtering. Defaults to False.

    Returns
    -------
    xr.DataArray
        The filtered position array, where points exceeding the displacement
        threshold condition have been set to NaN.

    Raises
    ------
    TypeError
        If ``position`` is not an ``xarray.DataArray``.
    ValueError
        If ``direction`` is neither "forward" nor "bidirectional".

    Notes
    -----
    A displacement whose magnitude is NaN (e.g. because the point is
    missing in the neighbouring frame) does not count as exceeding the
    threshold, so a valid point adjacent to a gap is not discarded merely
    because its displacement cannot be measured. In "bidirectional" mode
    the last frame has no outgoing displacement and is therefore never
    removed.

    See Also
    --------
    movement.kinematics.compute_displacement:
        The function used to compute an array of displacement vectors.
    movement.utils.vector.compute_norm:
        The function used to compute distance as the magnitude of
        displacement vectors.

    """
    if not isinstance(position, xr.DataArray):
        raise TypeError("Input 'position' must be an xarray.DataArray.")
    # Validate the mode up front so an invalid call fails before any
    # computation is done.
    if direction not in ("forward", "bidirectional"):
        raise ValueError(
            f"Invalid direction: {direction}. "
            f"Must be 'forward' or 'bidirectional'."
        )

    # Incoming displacement magnitude: value at t is |pos(t) - pos(t-1)|
    mag_fwd = compute_norm(compute_displacement(position))

    # Build the mask from the documented "exceeds" condition rather than
    # from "is below": comparisons against NaN are False, so an
    # unmeasurable displacement is never treated as a violation, and a
    # displacement exactly equal to the threshold is retained.
    jump_in = mag_fwd > threshold

    if direction == "forward":
        exceeds = jump_in
    else:
        # Outgoing displacement magnitude: mag_bwd[t] = mag_fwd[t+1].
        # The last frame has no successor; it is filled with 0 so that it
        # has no outgoing jump.
        mag_bwd = mag_fwd.shift(time=-1, fill_value=0)
        exceeds = jump_in & (mag_bwd > threshold)

    # ``where`` keeps values where the condition is True and writes NaN
    # elsewhere, so pass the complement of the violation mask.
    position_filtered = position.where(~exceeds)

    if print_report:
        print(report_nan_values(position, "input"))
        print(report_nan_values(position_filtered, "output"))

    return position_filtered
149+
150+
63151
@log_to_attrs
64152
def interpolate_over_time(
65153
data: xr.DataArray,
@@ -186,8 +274,7 @@ def rolling_filter(
186274
# Compute the statistic over each window
187275
allowed_statistics = ["mean", "median", "max", "min"]
188276
if statistic not in allowed_statistics:
189-
raise log_error(
190-
ValueError,
277+
raise ValueError(
191278
f"Invalid statistic '{statistic}'. "
192279
f"Must be one of {allowed_statistics}.",
193280
)
@@ -256,9 +343,7 @@ def savgol_filter(
256343
257344
"""
258345
if "axis" in kwargs:
259-
raise log_error(
260-
ValueError, "The 'axis' argument may not be overridden."
261-
)
346+
raise ValueError("The 'axis' argument may not be overridden.")
262347
data_smoothed = data.copy()
263348
data_smoothed.values = signal.savgol_filter(
264349
data,

0 commit comments

Comments
 (0)