|
6 | 6 | import xarray as xr |
7 | 7 | from scipy import signal |
8 | 8 |
|
9 | | -from movement.utils.logging import log_error, log_to_attrs |
| 9 | +from movement.kinematics import compute_displacement |
| 10 | +from movement.utils.logging import log_to_attrs |
10 | 11 | from movement.utils.reports import report_nan_values |
| 12 | +from movement.utils.vector import compute_norm |
11 | 13 |
|
12 | 14 |
|
13 | 15 | @log_to_attrs |
@@ -60,6 +62,92 @@ def filter_by_confidence( |
60 | 62 | return data_filtered |
61 | 63 |
|
62 | 64 |
|
| 65 | +@log_to_attrs |
| 66 | +def filter_by_displacement( |
| 67 | + position: xr.DataArray, |
| 68 | + threshold: float = 10.0, |
| 69 | + direction: Literal["forward", "bidirectional"] = "forward", |
| 70 | + print_report: bool = False, |
| 71 | +) -> xr.DataArray: |
| 72 | + """Filter data points based on displacement magnitude. |
| 73 | +
|
| 74 | + Frames in the ``position`` array that exceed a displacement magnitude |
| 75 | + threshold are set to NaN. |
| 76 | +
|
| 77 | + Two modes are supported via the ``direction`` parameter: |
| 78 | + - "forward" (default): A point at time ``t`` is set to NaN if it has |
| 79 | + moved more than the ``threshold`` Euclidean distance from the same |
| 80 | + point at time ``t-1`` (i.e., if ``|pos(t) - pos(t-1)| > threshold``). |
| 81 | + - "bidirectional": A point at time ``t`` is set to NaN only if BOTH the |
| 82 | + displacement from ``t-1`` to ``t`` AND the displacement from ``t`` to |
| 83 | + ``t+1`` exceed the threshold (i.e., if |
| 84 | + ``|pos(t) - pos(t-1)| > threshold`` AND |
| 85 | + ``|pos(t+1) - pos(t)| > threshold``). This corresponds to the |
| 86 | + Stage 1 outlier detection described by the user. |
| 87 | +
|
| 88 | + Parameters |
| 89 | + ---------- |
| 90 | + position : xarray.DataArray |
| 91 | + The input data containing position information, with ``time`` |
| 92 | + and ``space`` (in Cartesian coordinates) as required dimensions. |
| 93 | + threshold : float, optional |
| 94 | + The maximum Euclidean distance allowed for displacement. |
| 95 | + Defaults to 10.0. |
| 96 | + direction : Literal["forward", "bidirectional"], optional |
| 97 | + The directionality of the displacement check. Defaults to "forward". |
| 98 | + print_report : bool, optional |
| 99 | + Whether to print a report of the number of NaN values before and after |
| 100 | + filtering. Defaults to False. |
| 101 | +
|
| 102 | + Returns |
| 103 | + ------- |
| 104 | + xr.DataArray |
| 105 | + The filtered position array, where points exceeding the displacement |
| 106 | + threshold condition have been set to NaN. |
| 107 | +
|
| 108 | + See Also |
| 109 | + -------- |
| 110 | + movement.kinematics.compute_displacement: |
| 111 | + The function used to compute an array of displacement vectors. |
| 112 | + movement.utils.vector.compute_norm: |
| 113 | + The function used to compute distance as the magnitude of |
| 114 | + displacement vectors. |
| 115 | +
|
| 116 | + """ |
| 117 | + if not isinstance(position, xr.DataArray): |
| 118 | + raise TypeError("Input 'position' must be an xarray.DataArray.") |
| 119 | + |
| 120 | + # Calculate forward displacement magnitude: |
| 121 | + # norm at time t = |pos(t) - pos(t-1)| |
| 122 | + displacement_fwd = compute_displacement(position) |
| 123 | + mag_fwd = compute_norm(displacement_fwd) |
| 124 | + |
| 125 | + if direction == "forward": |
| 126 | + # Uni-directional: Keep if magnitude from t-1 to t is below threshold |
| 127 | + condition = mag_fwd < threshold |
| 128 | + elif direction == "bidirectional": |
| 129 | + # Bi-directional: Keep unless BOTH jump in and jump out are large. |
| 130 | + # Equivalent to: Keep if jump in is small OR jump out is small. |
| 131 | + # Calculate backward magnitude: |
| 132 | + # norm at t+1 related to t = |pos(t+1) - pos(t)| |
| 133 | + # mag_bwd[t] = mag_fwd[t+1] |
| 134 | + mag_bwd = mag_fwd.shift(time=-1, fill_value=0) |
| 135 | + condition = (mag_fwd < threshold) | (mag_bwd < threshold) |
| 136 | + else: |
| 137 | + raise ValueError( |
| 138 | + f"Invalid direction: {direction}. " |
| 139 | + f"Must be 'forward' or 'bidirectional'." |
| 140 | + ) |
| 141 | + |
| 142 | + position_filtered = position.where(condition) |
| 143 | + |
| 144 | + if print_report: |
| 145 | + print(report_nan_values(position, "input")) |
| 146 | + print(report_nan_values(position_filtered, "output")) |
| 147 | + |
| 148 | + return position_filtered |
| 149 | + |
| 150 | + |
63 | 151 | @log_to_attrs |
64 | 152 | def interpolate_over_time( |
65 | 153 | data: xr.DataArray, |
@@ -177,25 +265,35 @@ def rolling_filter( |
177 | 265 |
|
178 | 266 | """ |
179 | 267 | half_window = window // 2 |
180 | | - data_windows = data.pad( # Pad the edges to avoid NaNs |
| 268 | + # Pad the edges to avoid NaNs before applying rolling window |
| 269 | + # Transpose ensures padding happens correctly regardless of dim order |
| 270 | + padded_data = data.transpose("time", ...).pad( |
181 | 271 | time=half_window, mode="reflect" |
182 | | - ).rolling( # Take rolling windows across time |
| 272 | + ) |
| 273 | + # Apply rolling window across time on padded data |
| 274 | + data_windows = padded_data.rolling( |
183 | 275 | time=window, center=True, min_periods=min_periods |
184 | 276 | ) |
185 | 277 |
|
186 | 278 | # Compute the statistic over each window |
187 | 279 | allowed_statistics = ["mean", "median", "max", "min"] |
188 | 280 | if statistic not in allowed_statistics: |
189 | | - raise log_error( |
190 | | - ValueError, |
| 281 | + raise ValueError( |
191 | 282 | f"Invalid statistic '{statistic}'. " |
192 | | - f"Must be one of {allowed_statistics}.", |
193 | | - ) |
| 283 | + f"Must be one of {allowed_statistics}." |
| 284 | + ) # <-- Corrected: Added closing parenthesis |
194 | 285 |
|
195 | 286 | data_rolled = getattr(data_windows, statistic)(skipna=True) |
196 | 287 |
|
197 | | - # Remove the padded edges |
198 | | - data_rolled = data_rolled.isel(time=slice(half_window, -half_window)) |
| 288 | + # Remove the padded edges by slicing |
| 289 | + # Ensure the slice matches the original time dimension size |
| 290 | + original_time_size = data.sizes["time"] |
| 291 | + data_rolled = data_rolled.isel( |
| 292 | + time=slice(half_window, half_window + original_time_size) |
| 293 | + ) |
| 294 | + |
| 295 | + # Transpose back to original dimension order |
| 296 | + data_rolled = data_rolled.transpose(*data.dims) |
199 | 297 |
|
200 | 298 | # Optional: Print NaN report |
201 | 299 | if print_report: |
@@ -256,15 +354,20 @@ def savgol_filter( |
256 | 354 |
|
257 | 355 | """ |
258 | 356 | if "axis" in kwargs: |
259 | | - raise log_error( |
260 | | - ValueError, "The 'axis' argument may not be overridden." |
261 | | - ) |
| 357 | + raise ValueError("The 'axis' argument may not be overridden.") |
262 | 358 | data_smoothed = data.copy() |
| 359 | + # Find the axis index corresponding to the 'time' dimension |
| 360 | + try: |
| 361 | + time_axis = data.dims.index("time") |
| 362 | + except ValueError as e: |
| 363 | + raise ValueError("Input data must have a 'time' dimension.") from e |
| 364 | + |
| 365 | + # Apply savgol_filter along the identified time axis |
263 | 366 | data_smoothed.values = signal.savgol_filter( |
264 | | - data, |
| 367 | + data.values, # Pass numpy array to savgol_filter |
265 | 368 | window, |
266 | 369 | polyorder, |
267 | | - axis=0, |
| 370 | + axis=time_axis, |
268 | 371 | **kwargs, |
269 | 372 | ) |
270 | 373 | if print_report: |
|
0 commit comments