From 11b8203a0f99954735c075985624230eac327ae7 Mon Sep 17 00:00:00 2001
From: demoncoder-crypto
Date: Tue, 25 Mar 2025 15:50:04 -0400
Subject: [PATCH 1/2] Add track fusion module for combining multiple tracking
 sources

---
 examples/fuse_multiple_tracks.py      | 266 +++++
 movement/__init__.py                  |   3 +
 movement/track_fusion.py              | 664 +++++++++++++++++++++
 tests/test_unit/test_track_fusion.py  | 231 ++++++++
 4 files changed, 1164 insertions(+)
 create mode 100644 examples/fuse_multiple_tracks.py
 create mode 100644 movement/track_fusion.py
 create mode 100644 tests/test_unit/test_track_fusion.py

diff --git a/examples/fuse_multiple_tracks.py b/examples/fuse_multiple_tracks.py
new file mode 100644
index 000000000..8ac43d97f
--- /dev/null
+++ b/examples/fuse_multiple_tracks.py
@@ -0,0 +1,266 @@
+"""Fuse multiple tracking sources
+==============================
+
+Demonstrate how to combine tracking data from multiple sources to produce a more
+accurate trajectory. This is particularly useful in cases where different tracking
+methods may fail in different situations, such as with ID swaps.
+"""
+
+# %%
+# Imports
+# -------
+
+from matplotlib import pyplot as plt
+
+from movement import sample_data
+from movement.io import load_poses
+from movement.plots import plot_centroid_trajectory
+from movement.track_fusion import fuse_tracks
+
+# %%
+# Load sample datasets
+# --------------------
+# We'll load the DeepLabCut and SLEAP data for the same mouse in an EPM (Elevated Plus Maze)
+# experiment. The DeepLabCut data is considered more reliable, while the SLEAP data was
+# generated using a model trained on less data.
+
+# DeepLabCut data (considered more reliable)
+dlc_path = sample_data.fetch_dataset_paths("DLC_single-mouse_EPM.predictions.h5")["poses"]
+ds_dlc = load_poses.from_dlc_file(dlc_path, fps=30)
+
+# SLEAP data (considered less reliable)
+sleap_path = sample_data.fetch_dataset_paths("SLEAP_single-mouse_EPM.analysis.h5")["poses"]
+ds_sleap = load_poses.from_sleap_file(sleap_path, fps=30)
+
+# %%
+# Inspect the datasets
+# --------------------
+# Let's look at the available keypoints in each dataset.
+
+print("DeepLabCut keypoints:", ds_dlc.keypoints.values)
+print("SLEAP keypoints:", ds_sleap.keypoints.values)
+
+# %%
+# The two datasets might have different keypoints, so we'll focus on the centroid.
+# If "centroid" doesn't exist in one of the datasets, we would need to compute it
+# from other keypoints or choose a different keypoint common to both datasets.
+
+# %%
+# Visualize the tracking from the individual sources
+# --------------------------------------------------
+# First let's plot the centroid trajectory from both sources separately.
+
+# Create a figure with two subplots side by side
+fig, axes = plt.subplots(1, 2, figsize=(12, 5))
+
+# Plot DLC trajectory
+plot_centroid_trajectory(ds_dlc.position, ax=axes[0])
+axes[0].set_title('DeepLabCut Tracking')
+axes[0].invert_yaxis()  # Invert y-axis to match image coordinates
+
+# Plot SLEAP trajectory
+plot_centroid_trajectory(ds_sleap.position, ax=axes[1])
+axes[1].set_title('SLEAP Tracking')
+axes[1].invert_yaxis()
+
+fig.tight_layout()
+
+# %%
+# Fuse tracks using different methods
+# -----------------------------------
+# Now we'll combine the tracks using different fusion methods and compare the results.
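+
+# Before comparing all methods, here is a minimal sketch of the basic call:
+# a single fused track from the two sources, using the simplest method
+# ("mean"). The printed sizes are just a quick sanity check.
+fused_mean = fuse_tracks(
+    datasets=[ds_dlc, ds_sleap],
+    method="mean",
+    keypoint="centroid",
+)
+print(fused_mean.sizes)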
+
+# List of methods to try
+methods = ["mean", "median", "weighted", "reliability", "kalman"]
+
+# Create a figure with 8 subplots (4 rows, 2 columns):
+# two for the original tracks and five for the fused results
+fig, axes = plt.subplots(4, 2, figsize=(12, 20))
+axes = axes.flatten()
+
+# Plot the original tracks in the first two subplots
+plot_centroid_trajectory(ds_dlc.position, ax=axes[0])
+axes[0].set_title('Original: DeepLabCut')
+axes[0].invert_yaxis()
+
+plot_centroid_trajectory(ds_sleap.position, ax=axes[1])
+axes[1].set_title('Original: SLEAP')
+axes[1].invert_yaxis()
+
+# Fuse and plot the tracks with different methods
+for i, method in enumerate(methods, 2):
+    if i < len(axes):
+        # Set weights for weighted method (example: 0.7 for DLC, 0.3 for SLEAP)
+        weights = [0.7, 0.3] if method == "weighted" else None
+
+        # Fuse the tracks
+        fused_track = fuse_tracks(
+            datasets=[ds_dlc, ds_sleap],
+            method=method,
+            keypoint="centroid",
+            weights=weights,
+            print_report=True
+        )
+
+        # Plot the fused track
+        plot_centroid_trajectory(fused_track, ax=axes[i])
+        axes[i].set_title(f'Fused: {method.capitalize()}')
+        axes[i].invert_yaxis()
+
+# Hide the last, unused subplot
+axes[-1].set_axis_off()
+
+fig.tight_layout()
+
+# %%
+# Detailed Comparison: Kalman Filter Fusion
+# ------------------------------------------
+# Let's take a closer look at the Kalman filter method, which often provides
+# good results for trajectory data.
+
+# Create a new figure
+plt.figure(figsize=(10, 8))
+
+# Fuse tracks with Kalman filter
+kalman_fused = fuse_tracks(
+    datasets=[ds_dlc, ds_sleap],
+    method="kalman",
+    keypoint="centroid",
+    process_noise_scale=0.01,  # Controls smoothness of trajectory
+    measurement_noise_scales=[0.1, 0.3],  # Lower values for more reliable sources
+    print_report=True
+)
+
+# Plot trajectories from both sources and the fused result
+plt.plot(
+    ds_dlc.position.sel(keypoints="centroid", space="x"),
+    ds_dlc.position.sel(keypoints="centroid", space="y"),
+    'b.', alpha=0.5, label='DeepLabCut'
+)
+plt.plot(
+    ds_sleap.position.sel(keypoints="centroid", space="x"),
+    ds_sleap.position.sel(keypoints="centroid", space="y"),
+    'g.', alpha=0.5, label='SLEAP'
+)
+plt.plot(
+    kalman_fused.sel(space="x"),
+    kalman_fused.sel(space="y"),
+    'r-', linewidth=2, label='Kalman Fused'
+)
+
+plt.gca().invert_yaxis()
+plt.grid(True, alpha=0.3)
+plt.legend()
+plt.title('Comparison of Original Tracks and Kalman-Fused Track')
+plt.xlabel('X Position')
+plt.ylabel('Y Position')
+
+# %%
+# Temporal Analysis: Plotting Coordinate Values Over Time
+# --------------------------------------------------------
+# Let's look at how the x-coordinate values change over time for the different sources.
+
+# Create a new figure
+plt.figure(figsize=(12, 6))
+
+# Plot x-coordinate over time
+time_values = kalman_fused.time.values
+
+plt.plot(
+    time_values,
+    ds_dlc.position.sel(keypoints="centroid", space="x"),
+    'b-', alpha=0.5, label='DeepLabCut'
+)
+plt.plot(
+    time_values,
+    ds_sleap.position.sel(keypoints="centroid", space="x"),
+    'g-', alpha=0.5, label='SLEAP'
+)
+plt.plot(
+    time_values,
+    kalman_fused.sel(space="x"),
+    'r-', linewidth=2, label='Kalman Fused'
+)
+
+plt.grid(True, alpha=0.3)
+plt.legend()
+plt.title('X-Coordinate Values Over Time')
+plt.xlabel('Time')
+plt.ylabel('X Position')
+
+# %%
+# Multiple-Animal Tracking Example with Potential ID Swaps
+# ---------------------------------------------------------
+# Now let's look at a more complex example with multiple animals,
+# where ID swaps might be an issue.
+# For this, we'll use the SLEAP datasets for three mice.
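+
+# Before fusing, it can help to see where swaps are likely. The helper below
+# is an illustrative sketch (its name and threshold are our own, not part of
+# the movement API): it flags frames whose frame-to-frame displacement is
+# implausibly large, a common signature of an ID swap. Tune the threshold to
+# your arena size and frame rate, e.g. flag_large_jumps(ds_proofread.position)
+# once the data below is loaded.
+def flag_large_jumps(position, threshold=20.0):
+    """Return a boolean array marking frames with suspiciously large jumps."""
+    displacement = position.diff(dim="time")
+    distance = ((displacement**2).sum(dim="space")) ** 0.5
+    return distance > threshold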
+ +# Load the two SLEAP datasets with three mice +ds_proofread = sample_data.fetch_dataset( + "SLEAP_three-mice_Aeon_proofread.analysis.h5" +) +ds_mixed = sample_data.fetch_dataset( + "SLEAP_three-mice_Aeon_mixed-labels.analysis.h5" +) + +print("Proofread dataset individuals:", ds_proofread.individuals.values) +print("Mixed-labels dataset individuals:", ds_mixed.individuals.values) + +# %% +# For each individual in the dataset, fuse the tracks from both sources + +# Create a figure for comparing original and fused tracks +fig, axes = plt.subplots(2, 3, figsize=(15, 10)) + +# Flatten axes for easier iteration +axes = axes.flatten() + +# Plot the original tracks for each mouse in the first row +for i, individual in enumerate(ds_proofread.individuals.values): + if i < 3: # First row + # Plot original trajectory from proofread dataset (more reliable) + pos = ds_proofread.position.sel(individuals=individual) + plot_centroid_trajectory(pos, ax=axes[i]) + axes[i].set_title(f'Original: {individual}') + axes[i].invert_yaxis() + +# Fuse and plot the tracks for each mouse in the second row +for i, individual in enumerate(ds_proofread.individuals.values): + if i < 3: # We have 3 mice + # Get the individual datasets + individual_ds_proofread = ds_proofread.sel(individuals=individual) + individual_ds_mixed = ds_mixed.sel(individuals=individual) + + # Fuse the tracks with the Kalman filter (can be replaced with other methods) + fused_track = fuse_tracks( + datasets=[individual_ds_proofread, individual_ds_mixed], + method="kalman", + keypoint="centroid", + # More weight to the proofread dataset (considered more reliable) + measurement_noise_scales=[0.1, 0.3], + print_report=False + ) + + # Plot the fused track + plot_centroid_trajectory(fused_track, ax=axes[i+3]) + axes[i+3].set_title(f'Fused: {individual}') + axes[i+3].invert_yaxis() + +fig.tight_layout() + +# %% +# Conclusions +# ---------- +# We've demonstrated several methods for combining tracking data from multiple sources: +# +# 1. **Mean**: Simple averaging of all valid measurements. +# 2. **Median**: More robust to outliers than the mean. +# 3. **Weighted**: Weighted average based on source reliability. +# 4. **Reliability-based**: Selects the most reliable source at each time point. +# 5. **Kalman filter**: Probabilistic approach that models position and velocity. +# +# The Kalman filter often provides the best results as it can: +# - Handle noisy measurements from multiple sources +# - Model the dynamics of movement (position and velocity) +# - Provide smooth trajectories that follow physical constraints +# - Handle missing data and uncertainty in measurements +# +# For multi-animal tracking with potential ID swaps, track fusion can be particularly +# valuable. By combining information from different tracking methods that may fail in +# different situations, we can produce more accurate trajectories across time. \ No newline at end of file diff --git a/movement/__init__.py b/movement/__init__.py index bf5d4a2d2..1ec99216f 100644 --- a/movement/__init__.py +++ b/movement/__init__.py @@ -15,3 +15,6 @@ # initialize logger upon import configure_logging() + +# Make track fusion functionality available at the top level +from movement.track_fusion import fuse_tracks diff --git a/movement/track_fusion.py b/movement/track_fusion.py new file mode 100644 index 000000000..ae69db2c5 --- /dev/null +++ b/movement/track_fusion.py @@ -0,0 +1,664 @@ +"""Combine tracking data from multiple sources. 
+ +This module provides functions for combining tracking data from multiple +sources to produce more accurate trajectories. This is particularly useful +in cases where different tracking methods may fail in different situations, +such as in multi-animal tracking with ID swaps. +""" + +import logging +from enum import Enum, auto +from typing import Callable, Dict, List, Literal, Optional, Tuple, Union + +import numpy as np +import xarray as xr +from scipy.signal import medfilt + +from movement.filtering import interpolate_over_time, rolling_filter +from movement.utils.logging import log_error, log_to_attrs +from movement.utils.reports import report_nan_values +from movement.validators.arrays import validate_dims_coords + +logger = logging.getLogger(__name__) + + +class FusionMethod(Enum): + """Enumeration of available track fusion methods.""" + + MEAN = auto() + MEDIAN = auto() + WEIGHTED = auto() + RELIABILITY_BASED = auto() + KALMAN = auto() + + +@log_to_attrs +def align_datasets( + datasets: List[xr.Dataset], + keypoint: str = "centroid", + interpolate: bool = True, + max_gap: Optional[int] = 5, +) -> List[xr.DataArray]: + """Aligns multiple datasets to have the same time coordinates. + + Parameters + ---------- + datasets : list of xarray.Dataset + List of datasets containing position data to align. + keypoint : str, optional + The keypoint to extract from each dataset, by default "centroid". + interpolate : bool, optional + Whether to interpolate missing values after alignment, by default True. + max_gap : int, optional + Maximum size of gap to interpolate, by default 5. + + Returns + ------- + list of xarray.DataArray + List of aligned DataArrays containing only the specified keypoint position data. + + Notes + ----- + This function extracts the specified keypoint from each dataset, aligns them to + have the same time coordinates, and optionally interpolates missing values. + """ + if not datasets: + raise log_error(ValueError, "No datasets provided") + + # Extract the keypoint position data from each dataset + position_arrays = [] + for ds in datasets: + # Check if keypoint exists in this dataset + if "keypoints" in ds.dims and keypoint not in ds.keypoints.values: + available_keypoints = list(ds.keypoints.values) + raise log_error( + ValueError, + f"Keypoint '{keypoint}' not found in dataset. " + f"Available keypoints: {available_keypoints}", + ) + + # Extract position for this keypoint + if "keypoints" in ds.dims: + pos = ds.position.sel(keypoints=keypoint) + else: + # Handle datasets without keypoints dimension + pos = ds.position + + position_arrays.append(pos) + + # Get union of all time coordinates + all_times = sorted(set().union(*[set(arr.time.values) for arr in position_arrays])) + + # Reindex all arrays to the common time coordinate + aligned_arrays = [] + for arr in position_arrays: + reindexed = arr.reindex(time=all_times) + + # Optionally interpolate missing values + if interpolate: + reindexed = interpolate_over_time(reindexed, max_gap=max_gap) + + aligned_arrays.append(reindexed) + + return aligned_arrays + + +@log_to_attrs +def fuse_tracks_mean( + aligned_tracks: List[xr.DataArray], + print_report: bool = False, +) -> xr.DataArray: + """Fuse tracks by taking the mean across all sources. + + Parameters + ---------- + aligned_tracks : list of xarray.DataArray + List of aligned position DataArrays. + print_report : bool, optional + Whether to print a report on the number of NaNs in the result, by default False. 
+ + Returns + ------- + xarray.DataArray + Fused track with position values averaged across sources. + + Notes + ----- + This function computes the mean of all valid position values at each time point. + If all sources have NaN at a particular time point, the result will also be NaN. + """ + if not aligned_tracks: + raise log_error(ValueError, "No tracks provided") + + # Stack all tracks along a new 'source' dimension + stacked = xr.concat(aligned_tracks, dim="source") + + # Take the mean along the source dimension, ignoring NaNs + fused = stacked.mean(dim="source", skipna=True) + + if print_report: + print(report_nan_values(fused, "Fused track (mean)")) + + return fused + + +@log_to_attrs +def fuse_tracks_median( + aligned_tracks: List[xr.DataArray], + print_report: bool = False, +) -> xr.DataArray: + """Fuse tracks by taking the median across all sources. + + Parameters + ---------- + aligned_tracks : list of xarray.DataArray + List of aligned position DataArrays. + print_report : bool, optional + Whether to print a report on the number of NaNs in the result, by default False. + + Returns + ------- + xarray.DataArray + Fused track with position values being the median across sources. + + Notes + ----- + This function computes the median of all valid position values at each time point. + If all sources have NaN at a particular time point, the result will also be NaN. + This method is more robust to outliers than the mean method. + """ + if not aligned_tracks: + raise log_error(ValueError, "No tracks provided") + + # Stack all tracks along a new 'source' dimension + stacked = xr.concat(aligned_tracks, dim="source") + + # Take the median along the source dimension, ignoring NaNs + fused = stacked.median(dim="source", skipna=True) + + if print_report: + print(report_nan_values(fused, "Fused track (median)")) + + return fused + + +@log_to_attrs +def fuse_tracks_weighted( + aligned_tracks: List[xr.DataArray], + weights: List[float] = None, + confidence_arrays: List[xr.DataArray] = None, + print_report: bool = False, +) -> xr.DataArray: + """Fuse tracks using a weighted average. + + Parameters + ---------- + aligned_tracks : list of xarray.DataArray + List of aligned position DataArrays. + weights : list of float, optional + Static weights for each track source. Must sum to 1 if provided. + If not provided and confidence_arrays is also None, equal weights are used. + confidence_arrays : list of xarray.DataArray, optional + Dynamic confidence values for each track. Must match the shape of aligned_tracks. + If provided, these are used instead of static weights. + print_report : bool, optional + Whether to print a report on the number of NaNs in the result, by default False. + + Returns + ------- + xarray.DataArray + Fused track with position values weighted by the specified weights or confidence values. + + Notes + ----- + This function computes a weighted average of position values. Weights can be either: + - Static (one weight per source) + - Dynamic (confidence value for each position at each time point) + If both weights and confidence_arrays are provided, confidence_arrays takes precedence. 
+ """ + if not aligned_tracks: + raise log_error(ValueError, "No tracks provided") + + n_tracks = len(aligned_tracks) + + # Check and prepare weights + if weights is not None: + if len(weights) != n_tracks: + raise log_error( + ValueError, + f"Number of weights ({len(weights)}) does not match " + f"number of tracks ({n_tracks})" + ) + if abs(sum(weights) - 1.0) > 1e-10: + raise log_error( + ValueError, + f"Weights must sum to 1, got sum={sum(weights)}" + ) + else: + # Equal weights if nothing is provided + weights = [1.0 / n_tracks] * n_tracks + + # Use dynamic confidence arrays if provided + if confidence_arrays is not None: + if len(confidence_arrays) != n_tracks: + raise log_error( + ValueError, + f"Number of confidence arrays ({len(confidence_arrays)}) does not match " + f"number of tracks ({n_tracks})" + ) + + # Normalize confidence values per time point + # Stack all confidence arrays along a 'source' dimension + stacked_conf = xr.concat(confidence_arrays, dim="source") + + # Calculate sum of confidences at each time point + sum_conf = stacked_conf.sum(dim="source") + + # Handle zeros by replacing with equal weights + has_zeros = (sum_conf == 0) + norm_conf = stacked_conf / sum_conf + norm_conf = norm_conf.where(~has_zeros, 1.0 / n_tracks) + + # Apply confidence-weighted average + stacked_pos = xr.concat(aligned_tracks, dim="source") + weighted_pos = stacked_pos * norm_conf + fused = weighted_pos.sum(dim="source", skipna=True) + + else: + # Apply static weights + weighted_tracks = [track * weight for track, weight in zip(aligned_tracks, weights)] + + # Stack and sum along a new 'source' dimension + stacked = xr.concat(weighted_tracks, dim="source") + + # Calculate where all tracks are NaN + all_nan = xr.concat([track.isnull() for track in aligned_tracks], dim="source").all(dim="source") + + # Sum along source dimension, set result to NaN where all sources are NaN + fused = stacked.sum(dim="source", skipna=True).where(~all_nan) + + if print_report: + print(report_nan_values(fused, "Fused track (weighted average)")) + + return fused + + +@log_to_attrs +def fuse_tracks_reliability( + aligned_tracks: List[xr.DataArray], + reliability_metrics: List[float] = None, + window_size: int = 11, + print_report: bool = False, +) -> xr.DataArray: + """Fuse tracks by selecting the most reliable source at each time point. + + Parameters + ---------- + aligned_tracks : list of xarray.DataArray + List of aligned position DataArrays. + reliability_metrics : list of float, optional + Global reliability score for each source (higher is better). + If not provided, NaN count is used as an inverse reliability metric. + window_size : int, optional + Window size for filtering the selection of sources, by default 11. + Must be an odd number. + print_report : bool, optional + Whether to print a report on the number of NaNs in the result, by default False. + + Returns + ------- + xarray.DataArray + Fused track with position values taken from the most reliable source at each time. + + Notes + ----- + This function selects values from the most reliable source at each time point, + then applies a median filter to avoid rapid switching between sources, which + could create unrealistic jumps in the trajectory. 
+ """ + if not aligned_tracks: + raise log_error(ValueError, "No tracks provided") + + if window_size % 2 == 0: + raise log_error(ValueError, "Window size must be an odd number") + + n_tracks = len(aligned_tracks) + + # Determine track reliability if not provided + if reliability_metrics is None: + # Count NaNs in each track (fewer NaNs = more reliable) + nan_counts = [float(track.isnull().sum().values) for track in aligned_tracks] + total_values = float(aligned_tracks[0].size) + # Convert to a reliability score (inverse of NaN proportion) + reliability_metrics = [1.0 - (count / total_values) for count in nan_counts] + + # Stack all tracks along a new 'source' dimension + stacked = xr.concat(aligned_tracks, dim="source") + + # For each time point, create a selection array based on reliability and NaN status + time_points = stacked.time.values + selected_sources = np.zeros(len(time_points), dtype=int) + + # Loop through each time point + for i, t in enumerate(time_points): + values_at_t = [track.sel(time=t).values for track in aligned_tracks] + is_nan = [np.isnan(val).any() for val in values_at_t] + + # If all sources have NaN, pick the most reliable one anyway + if all(is_nan): + selected_sources[i] = np.argmax(reliability_metrics) + else: + # Filter out NaN sources + valid_indices = [idx for idx, nan_status in enumerate(is_nan) if not nan_status] + valid_reliability = [reliability_metrics[idx] for idx in valid_indices] + + # Select the most reliable valid source + best_valid_idx = valid_indices[np.argmax(valid_reliability)] + selected_sources[i] = best_valid_idx + + # Apply median filter to smooth source selection and avoid rapid switching + if window_size > 1 and len(time_points) > window_size: + selected_sources = medfilt(selected_sources, window_size) + + # Create the fused track by selecting values from the chosen source at each time + fused_data = np.zeros((len(time_points), stacked.sizes["space"])) + + for i, (t, source_idx) in enumerate(zip(time_points, selected_sources)): + fused_data[i] = stacked.sel(time=t, source=source_idx).values + + # Create a new DataArray with the fused data + fused = xr.DataArray( + data=fused_data, + dims=["time", "space"], + coords={ + "time": time_points, + "space": stacked.space.values + } + ) + + if print_report: + print(report_nan_values(fused, "Fused track (reliability-based)")) + + return fused + + +@log_to_attrs +def fuse_tracks_kalman( + aligned_tracks: List[xr.DataArray], + process_noise_scale: float = 0.01, + measurement_noise_scales: List[float] = None, + print_report: bool = False, +) -> xr.DataArray: + """Fuse tracks using a Kalman filter. + + Parameters + ---------- + aligned_tracks : list of xarray.DataArray + List of aligned position DataArrays. + process_noise_scale : float, optional + Scale factor for the process noise covariance, by default 0.01. + measurement_noise_scales : list of float, optional + Scale factors for measurement noise for each source. + Lower values indicate more reliable sources. Default is equal values. + print_report : bool, optional + Whether to print a report on the number of NaNs in the result, by default False. + + Returns + ------- + xarray.DataArray + Fused track with position values estimated by the Kalman filter. + + Notes + ----- + This function implements a simple Kalman filter for track fusion. The filter: + 1. Models position and velocity in a state vector + 2. Predicts the next state based on constant velocity assumptions + 3. Updates the prediction using measurements from all available sources + 4. 
Handles missing measurements (NaNs) by skipping the update step + + The Kalman filter is particularly effective for trajectory smoothing and + handling noisy measurements from multiple sources. + """ + if not aligned_tracks: + raise log_error(ValueError, "No tracks provided") + + n_tracks = len(aligned_tracks) + + # Set default measurement noise scales if not provided + if measurement_noise_scales is None: + measurement_noise_scales = [1.0] * n_tracks + + if len(measurement_noise_scales) != n_tracks: + raise log_error( + ValueError, + f"Number of measurement noise scales ({len(measurement_noise_scales)}) " + f"does not match number of tracks ({n_tracks})" + ) + + # Get the common time axis + time_points = aligned_tracks[0].time.values + n_timesteps = len(time_points) + + # Get the dimensionality of the space (2D or 3D) + n_dims = len(aligned_tracks[0].space.values) + + # Initialize state vector [x, y, vx, vy] or [x, y, z, vx, vy, vz] + state_dim = 2 * n_dims + state = np.zeros(state_dim) + + # Initialize state covariance matrix + state_cov = np.eye(state_dim) + + # Define transition matrix (constant velocity model) + dt = 1.0 # Assuming unit time steps + A = np.eye(state_dim) + for i in range(n_dims): + A[i, i + n_dims] = dt + + # Define process noise covariance + Q = np.eye(state_dim) * process_noise_scale + + # Define measurement matrix (extracts position from state) + H = np.zeros((n_dims, state_dim)) + for i in range(n_dims): + H[i, i] = 1.0 + + # Initialize storage for Kalman filter output + kalman_output = np.zeros((n_timesteps, n_dims)) + + # For the first time step, initialize with the average of available measurements + first_measurements = [] + for track in aligned_tracks: + pos = track.sel(time=time_points[0]).values + if not np.isnan(pos).any(): + first_measurements.append(pos) + + if first_measurements: + initial_pos = np.mean(first_measurements, axis=0) + state[:n_dims] = initial_pos + kalman_output[0] = initial_pos + + # Run Kalman filter + for t in range(1, n_timesteps): + # Prediction step + state = A @ state + state_cov = A @ state_cov @ A.T + Q + + # Update step - combine all available measurements + measurements = [] + R_list = [] # Measurement noise covariances + + for i, track in enumerate(aligned_tracks): + pos = track.sel(time=time_points[t]).values + if not np.isnan(pos).any(): + measurements.append(pos) + # Measurement noise covariance for this source + R = np.eye(n_dims) * measurement_noise_scales[i] + R_list.append(R) + + # Skip update if no measurements available + if not measurements: + kalman_output[t] = state[:n_dims] + continue + + # Apply update for each measurement + for z, R in zip(measurements, R_list): + y = z - H @ state # Measurement residual + S = H @ state_cov @ H.T + R # Residual covariance + K = state_cov @ H.T @ np.linalg.inv(S) # Kalman gain + state = state + K @ y # Updated state + state_cov = (np.eye(state_dim) - K @ H) @ state_cov # Updated covariance + + # Store the updated position + kalman_output[t] = state[:n_dims] + + # Create a new DataArray with the Kalman filter output + fused = xr.DataArray( + data=kalman_output, + dims=["time", "space"], + coords={ + "time": time_points, + "space": aligned_tracks[0].space.values + } + ) + + if print_report: + print(report_nan_values(fused, "Fused track (Kalman filter)")) + + return fused + + +@log_to_attrs +def fuse_tracks( + datasets: List[xr.Dataset], + method: Union[str, FusionMethod] = "kalman", + keypoint: str = "centroid", + interpolate_gaps: bool = True, + max_gap: int = 5, + weights: 
List[float] = None, + confidence_arrays: List[xr.DataArray] = None, + reliability_metrics: List[float] = None, + window_size: int = 11, + process_noise_scale: float = 0.01, + measurement_noise_scales: List[float] = None, + print_report: bool = False, +) -> xr.DataArray: + """Fuse tracks from multiple datasets using the specified method. + + Parameters + ---------- + datasets : list of xarray.Dataset + List of datasets containing position data to fuse. + method : str or FusionMethod, optional + Track fusion method to use, by default "kalman". Options are: + - "mean": Average position across all sources + - "median": Median position across all sources (robust to outliers) + - "weighted": Weighted average using static weights or confidence values + - "reliability": Select most reliable source at each time point + - "kalman": Apply Kalman filter to estimate the optimal trajectory + keypoint : str, optional + The keypoint to extract from each dataset, by default "centroid". + interpolate_gaps : bool, optional + Whether to interpolate missing values after alignment, by default True. + max_gap : int, optional + Maximum size of gap to interpolate, by default 5. + weights : list of float, optional + Static weights for each track source (used with "weighted" method). + confidence_arrays : list of xarray.DataArray, optional + Dynamic confidence values for each track (used with "weighted" method). + reliability_metrics : list of float, optional + Global reliability score for each source (used with "reliability" method). + window_size : int, optional + Window size for filtering source selection (used with "reliability" method). + process_noise_scale : float, optional + Scale factor for process noise (used with "kalman" method). + measurement_noise_scales : list of float, optional + Scale factors for measurement noise (used with "kalman" method). + print_report : bool, optional + Whether to print a report on the number of NaNs in the result, by default False. + + Returns + ------- + xarray.DataArray + Fused track with position values determined by the specified fusion method. + + Raises + ------ + ValueError + If an unsupported fusion method is specified or parameters are invalid. + + Notes + ----- + This function acts as a high-level interface to various track fusion methods, + automatically handling dataset alignment and applying the selected fusion algorithm. + """ + # Convert string method to enum if needed + if isinstance(method, str): + method_map = { + "mean": FusionMethod.MEAN, + "median": FusionMethod.MEDIAN, + "weighted": FusionMethod.WEIGHTED, + "reliability": FusionMethod.RELIABILITY_BASED, + "kalman": FusionMethod.KALMAN, + } + + if method.lower() not in method_map: + valid_methods = list(method_map.keys()) + raise log_error( + ValueError, + f"Unsupported fusion method: {method}. 
" + f"Valid methods are: {valid_methods}" + ) + + method = method_map[method.lower()] + + # Align datasets + aligned_tracks = align_datasets( + datasets=datasets, + keypoint=keypoint, + interpolate=interpolate_gaps, + max_gap=max_gap, + ) + + # Apply fusion method + if method == FusionMethod.MEAN: + return fuse_tracks_mean( + aligned_tracks=aligned_tracks, + print_report=print_report, + ) + + elif method == FusionMethod.MEDIAN: + return fuse_tracks_median( + aligned_tracks=aligned_tracks, + print_report=print_report, + ) + + elif method == FusionMethod.WEIGHTED: + return fuse_tracks_weighted( + aligned_tracks=aligned_tracks, + weights=weights, + confidence_arrays=confidence_arrays, + print_report=print_report, + ) + + elif method == FusionMethod.RELIABILITY_BASED: + return fuse_tracks_reliability( + aligned_tracks=aligned_tracks, + reliability_metrics=reliability_metrics, + window_size=window_size, + print_report=print_report, + ) + + elif method == FusionMethod.KALMAN: + return fuse_tracks_kalman( + aligned_tracks=aligned_tracks, + process_noise_scale=process_noise_scale, + measurement_noise_scales=measurement_noise_scales, + print_report=print_report, + ) + + else: + raise log_error( + ValueError, + f"Unsupported fusion method: {method}" + ) \ No newline at end of file diff --git a/tests/test_unit/test_track_fusion.py b/tests/test_unit/test_track_fusion.py new file mode 100644 index 000000000..da0ffa8c7 --- /dev/null +++ b/tests/test_unit/test_track_fusion.py @@ -0,0 +1,231 @@ +"""Tests for track fusion functions.""" + +import numpy as np +import pytest +import xarray as xr + +from movement.track_fusion import ( + align_datasets, + fuse_tracks, + fuse_tracks_kalman, + fuse_tracks_mean, + fuse_tracks_median, + fuse_tracks_reliability, + fuse_tracks_weighted, +) + + +@pytest.fixture +def mock_datasets(): + """Create mock datasets for testing track fusion.""" + # Create two simple datasets with different time points and some NaNs + # Dataset 1: More reliable (fewer NaNs) + time1 = np.arange(0, 10, 1) + pos1 = np.zeros((10, 1, 2)) + # Simple straight line with a slope + pos1[:, 0, 0] = np.arange(0, 10, 1) # x coordinate + pos1[:, 0, 1] = np.arange(0, 10, 1) # y coordinate + # Add some NaNs + pos1[3, 0, :] = np.nan + + # Dataset 2: Less reliable (more NaNs) + time2 = np.arange(0, 10, 1) + pos2 = np.zeros((10, 1, 2)) + # Similar trajectory but with some noise + pos2[:, 0, 0] = np.arange(0, 10, 1) + np.random.normal(0, 0.5, 10) + pos2[:, 0, 1] = np.arange(0, 10, 1) + np.random.normal(0, 0.5, 10) + # Add more NaNs + pos2[3, 0, :] = np.nan + pos2[7, 0, :] = np.nan + + # Create xarray datasets + ds1 = xr.Dataset( + data_vars={ + "position": (["time", "keypoints", "space"], pos1), + "confidence": (["time", "keypoints"], np.ones((10, 1))), + }, + coords={ + "time": time1, + "keypoints": ["centroid"], + "space": ["x", "y"], + "individuals": ["individual_0"], + }, + ) + + ds2 = xr.Dataset( + data_vars={ + "position": (["time", "keypoints", "space"], pos2), + "confidence": (["time", "keypoints"], np.ones((10, 1))), + }, + coords={ + "time": time2, + "keypoints": ["centroid"], + "space": ["x", "y"], + "individuals": ["individual_0"], + }, + ) + + return [ds1, ds2] + + +def test_align_datasets(mock_datasets): + """Test aligning datasets with different time points.""" + aligned = align_datasets(mock_datasets, interpolate=False) + + # Check that both arrays have the same time coordinates + assert aligned[0].time.equals(aligned[1].time) + + # Check that NaNs are preserved when interpolate=False + assert 
np.isnan(aligned[0].sel(time=3, space="x").values) + assert np.isnan(aligned[1].sel(time=3, space="x").values) + assert np.isnan(aligned[1].sel(time=7, space="x").values) + + # Test with interpolation + aligned_interp = align_datasets(mock_datasets, interpolate=True) + + # Check that NaNs are interpolated + assert not np.isnan(aligned_interp[0].sel(time=3, space="x").values) + assert not np.isnan(aligned_interp[1].sel(time=3, space="x").values) + assert not np.isnan(aligned_interp[1].sel(time=7, space="x").values) + + +def test_fuse_tracks_mean(mock_datasets): + """Test mean fusion method.""" + aligned = align_datasets(mock_datasets, interpolate=True) + fused = fuse_tracks_mean(aligned) + + # Check output dimensions + assert "source" not in fused.dims + assert "time" in fused.dims + assert "space" in fused.dims + + # Check that the fused track has all time points + assert len(fused.time) == 10 + + # No NaNs when both sources are interpolated + assert not np.isnan(fused).any() + + +def test_fuse_tracks_median(mock_datasets): + """Test median fusion method.""" + aligned = align_datasets(mock_datasets, interpolate=True) + fused = fuse_tracks_median(aligned) + + # Check output dimensions + assert "source" not in fused.dims + assert "time" in fused.dims + assert "space" in fused.dims + + # No NaNs when both sources are interpolated + assert not np.isnan(fused).any() + + +def test_fuse_tracks_weighted(mock_datasets): + """Test weighted fusion method.""" + aligned = align_datasets(mock_datasets, interpolate=True) + + # Test with static weights + weights = [0.7, 0.3] + fused = fuse_tracks_weighted(aligned, weights=weights) + + # Check output dimensions + assert "source" not in fused.dims + assert "time" in fused.dims + assert "space" in fused.dims + + # No NaNs when both sources are interpolated + assert not np.isnan(fused).any() + + # Test with invalid weights (sum != 1) + with pytest.raises(ValueError): + fuse_tracks_weighted(aligned, weights=[0.5, 0.2]) + + # Test with mismatched weights length + with pytest.raises(ValueError): + fuse_tracks_weighted(aligned, weights=[0.5, 0.3, 0.2]) + + +def test_fuse_tracks_reliability(mock_datasets): + """Test reliability-based fusion method.""" + aligned = align_datasets(mock_datasets, interpolate=False) # Keep NaNs for testing + + # Test with automatic reliability metrics + fused = fuse_tracks_reliability(aligned) + + # Check output dimensions + assert "source" not in fused.dims + assert "time" in fused.dims + assert "space" in fused.dims + + # Test with custom reliability metrics + reliability_metrics = [0.9, 0.5] # First source more reliable + fused = fuse_tracks_reliability(aligned, reliability_metrics=reliability_metrics) + + # Check that we still get a value for time point 7 where only source 1 has data + assert not np.isnan(fused.sel(time=7, space="x").values) + + # Test with invalid window size (even number) + with pytest.raises(ValueError): + fuse_tracks_reliability(aligned, window_size=10) + + +def test_fuse_tracks_kalman(mock_datasets): + """Test Kalman filter fusion method.""" + aligned = align_datasets(mock_datasets, interpolate=False) # Keep NaNs for testing + + # Test with default parameters + fused = fuse_tracks_kalman(aligned) + + # Check output dimensions + assert "source" not in fused.dims + assert "time" in fused.dims + assert "space" in fused.dims + + # Kalman filter should interpolate over missing values + assert not np.isnan(fused).any() + + # Test with custom parameters + fused = fuse_tracks_kalman( + aligned, + 
process_noise_scale=0.1, + measurement_noise_scales=[0.1, 0.5] + ) + + # Check that we get a smoother trajectory (less variance) + x_vals = fused.sel(space="x").values + diff = np.diff(x_vals) + assert np.std(diff) < 0.5 # Standard deviation of the differences should be low + + # Test with mismatched noise scales length + with pytest.raises(ValueError): + fuse_tracks_kalman(aligned, measurement_noise_scales=[0.1, 0.2, 0.3]) + + +def test_fuse_tracks_high_level(mock_datasets): + """Test the high-level fuse_tracks interface.""" + # Test each method through the high-level interface + methods = ["mean", "median", "weighted", "reliability", "kalman"] + + for method in methods: + fused = fuse_tracks( + datasets=mock_datasets, + method=method, + keypoint="centroid", + interpolate_gaps=True + ) + + # Check output dimensions + assert "time" in fused.dims + assert "space" in fused.dims + assert len(fused.space) == 2 + + # No NaNs when interpolation is used + assert not np.isnan(fused).any() + + # Test with invalid method + with pytest.raises(ValueError): + fuse_tracks(mock_datasets, method="invalid_method") + + # Test with non-existent keypoint + with pytest.raises(ValueError): + fuse_tracks(mock_datasets, keypoint="non_existent") \ No newline at end of file From 7c15b1acdaea7d536148c019cea8ba48ef72f952 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 25 Mar 2025 19:52:50 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- examples/fuse_multiple_tracks.py | 83 +++++---- movement/track_fusion.py | 256 ++++++++++++++------------- tests/test_unit/test_track_fusion.py | 84 +++++---- 3 files changed, 231 insertions(+), 192 deletions(-) diff --git a/examples/fuse_multiple_tracks.py b/examples/fuse_multiple_tracks.py index 8ac43d97f..de93dc36e 100644 --- a/examples/fuse_multiple_tracks.py +++ b/examples/fuse_multiple_tracks.py @@ -25,11 +25,15 @@ # generated using a model trained on less data. 
# DeepLabCut data (considered more reliable) -dlc_path = sample_data.fetch_dataset_paths("DLC_single-mouse_EPM.predictions.h5")["poses"] +dlc_path = sample_data.fetch_dataset_paths( + "DLC_single-mouse_EPM.predictions.h5" +)["poses"] ds_dlc = load_poses.from_dlc_file(dlc_path, fps=30) # SLEAP data (considered less reliable) -sleap_path = sample_data.fetch_dataset_paths("SLEAP_single-mouse_EPM.analysis.h5")["poses"] +sleap_path = sample_data.fetch_dataset_paths( + "SLEAP_single-mouse_EPM.analysis.h5" +)["poses"] ds_sleap = load_poses.from_sleap_file(sleap_path, fps=30) # %% @@ -55,12 +59,12 @@ # Plot DLC trajectory plot_centroid_trajectory(ds_dlc.position, ax=axes[0]) -axes[0].set_title('DeepLabCut Tracking') +axes[0].set_title("DeepLabCut Tracking") axes[0].invert_yaxis() # Invert y-axis to match image coordinates # Plot SLEAP trajectory plot_centroid_trajectory(ds_sleap.position, ax=axes[1]) -axes[1].set_title('SLEAP Tracking') +axes[1].set_title("SLEAP Tracking") axes[1].invert_yaxis() fig.tight_layout() @@ -79,11 +83,11 @@ # Plot the original tracks in the first two subplots plot_centroid_trajectory(ds_dlc.position, ax=axes[0]) -axes[0].set_title('Original: DeepLabCut') +axes[0].set_title("Original: DeepLabCut") axes[0].invert_yaxis() plot_centroid_trajectory(ds_sleap.position, ax=axes[1]) -axes[1].set_title('Original: SLEAP') +axes[1].set_title("Original: SLEAP") axes[1].invert_yaxis() # Fuse and plot the tracks with different methods @@ -91,19 +95,19 @@ if i < len(axes): # Set weights for weighted method (example: 0.7 for DLC, 0.3 for SLEAP) weights = [0.7, 0.3] if method == "weighted" else None - + # Fuse the tracks fused_track = fuse_tracks( datasets=[ds_dlc, ds_sleap], method=method, keypoint="centroid", weights=weights, - print_report=True + print_report=True, ) - + # Plot the fused track plot_centroid_trajectory(fused_track, ax=axes[i]) - axes[i].set_title(f'Fused: {method.capitalize()}') + axes[i].set_title(f"Fused: {method.capitalize()}") axes[i].invert_yaxis() fig.tight_layout() @@ -123,33 +127,42 @@ method="kalman", keypoint="centroid", process_noise_scale=0.01, # Controls smoothness of trajectory - measurement_noise_scales=[0.1, 0.3], # Lower values for more reliable sources - print_report=True + measurement_noise_scales=[ + 0.1, + 0.3, + ], # Lower values for more reliable sources + print_report=True, ) # Plot trajectories from both sources and the fused result plt.plot( ds_dlc.position.sel(keypoints="centroid", space="x"), ds_dlc.position.sel(keypoints="centroid", space="y"), - 'b.', alpha=0.5, label='DeepLabCut' + "b.", + alpha=0.5, + label="DeepLabCut", ) plt.plot( ds_sleap.position.sel(keypoints="centroid", space="x"), ds_sleap.position.sel(keypoints="centroid", space="y"), - 'g.', alpha=0.5, label='SLEAP' + "g.", + alpha=0.5, + label="SLEAP", ) plt.plot( kalman_fused.sel(space="x"), kalman_fused.sel(space="y"), - 'r-', linewidth=2, label='Kalman Fused' + "r-", + linewidth=2, + label="Kalman Fused", ) plt.gca().invert_yaxis() plt.grid(True, alpha=0.3) plt.legend() -plt.title('Comparison of Original Tracks and Kalman-Fused Track') -plt.xlabel('X Position') -plt.ylabel('Y Position') +plt.title("Comparison of Original Tracks and Kalman-Fused Track") +plt.xlabel("X Position") +plt.ylabel("Y Position") # %% # Temporal Analysis: Plotting Coordinate Values Over Time @@ -165,24 +178,30 @@ plt.plot( time_values, ds_dlc.position.sel(keypoints="centroid", space="x"), - 'b-', alpha=0.5, label='DeepLabCut' + "b-", + alpha=0.5, + label="DeepLabCut", ) plt.plot( time_values, 
ds_sleap.position.sel(keypoints="centroid", space="x"), - 'g-', alpha=0.5, label='SLEAP' + "g-", + alpha=0.5, + label="SLEAP", ) plt.plot( time_values, kalman_fused.sel(space="x"), - 'r-', linewidth=2, label='Kalman Fused' + "r-", + linewidth=2, + label="Kalman Fused", ) plt.grid(True, alpha=0.3) plt.legend() -plt.title('X-Coordinate Values Over Time') -plt.xlabel('Time') -plt.ylabel('X Position') +plt.title("X-Coordinate Values Over Time") +plt.xlabel("Time") +plt.ylabel("X Position") # %% # Multiple-Animal Tracking Example with Potential ID Swaps @@ -217,7 +236,7 @@ # Plot original trajectory from proofread dataset (more reliable) pos = ds_proofread.position.sel(individuals=individual) plot_centroid_trajectory(pos, ax=axes[i]) - axes[i].set_title(f'Original: {individual}') + axes[i].set_title(f"Original: {individual}") axes[i].invert_yaxis() # Fuse and plot the tracks for each mouse in the second row @@ -226,7 +245,7 @@ # Get the individual datasets individual_ds_proofread = ds_proofread.sel(individuals=individual) individual_ds_mixed = ds_mixed.sel(individuals=individual) - + # Fuse the tracks with the Kalman filter (can be replaced with other methods) fused_track = fuse_tracks( datasets=[individual_ds_proofread, individual_ds_mixed], @@ -234,13 +253,13 @@ keypoint="centroid", # More weight to the proofread dataset (considered more reliable) measurement_noise_scales=[0.1, 0.3], - print_report=False + print_report=False, ) - + # Plot the fused track - plot_centroid_trajectory(fused_track, ax=axes[i+3]) - axes[i+3].set_title(f'Fused: {individual}') - axes[i+3].invert_yaxis() + plot_centroid_trajectory(fused_track, ax=axes[i + 3]) + axes[i + 3].set_title(f"Fused: {individual}") + axes[i + 3].invert_yaxis() fig.tight_layout() @@ -263,4 +282,4 @@ # # For multi-animal tracking with potential ID swaps, track fusion can be particularly # valuable. By combining information from different tracking methods that may fail in -# different situations, we can produce more accurate trajectories across time. \ No newline at end of file +# different situations, we can produce more accurate trajectories across time. diff --git a/movement/track_fusion.py b/movement/track_fusion.py index ae69db2c5..1fe7f8786 100644 --- a/movement/track_fusion.py +++ b/movement/track_fusion.py @@ -8,16 +8,14 @@ import logging from enum import Enum, auto -from typing import Callable, Dict, List, Literal, Optional, Tuple, Union import numpy as np import xarray as xr from scipy.signal import medfilt -from movement.filtering import interpolate_over_time, rolling_filter +from movement.filtering import interpolate_over_time from movement.utils.logging import log_error, log_to_attrs from movement.utils.reports import report_nan_values -from movement.validators.arrays import validate_dims_coords logger = logging.getLogger(__name__) @@ -34,11 +32,11 @@ class FusionMethod(Enum): @log_to_attrs def align_datasets( - datasets: List[xr.Dataset], + datasets: list[xr.Dataset], keypoint: str = "centroid", interpolate: bool = True, - max_gap: Optional[int] = 5, -) -> List[xr.DataArray]: + max_gap: int | None = 5, +) -> list[xr.DataArray]: """Aligns multiple datasets to have the same time coordinates. Parameters @@ -61,6 +59,7 @@ def align_datasets( ----- This function extracts the specified keypoint from each dataset, aligns them to have the same time coordinates, and optionally interpolates missing values. 
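+
+    Examples
+    --------
+    A minimal sketch with two toy single-animal datasets whose time
+    coordinates only partly overlap (values are arbitrary):
+
+    >>> import numpy as np
+    >>> import xarray as xr
+    >>> def toy_ds(times):
+    ...     pos = np.ones((len(times), 1, 2))
+    ...     return xr.Dataset(
+    ...         {"position": (["time", "keypoints", "space"], pos)},
+    ...         coords={
+    ...             "time": times,
+    ...             "keypoints": ["centroid"],
+    ...             "space": ["x", "y"],
+    ...         },
+    ...     )
+    >>> tracks = align_datasets(
+    ...     [toy_ds([0, 1, 2]), toy_ds([1, 2, 3])], interpolate=False
+    ... )
+    >>> [int(track.time.size) for track in tracks]
+    [4, 4]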
+ """ if not datasets: raise log_error(ValueError, "No datasets provided") @@ -76,36 +75,38 @@ def align_datasets( f"Keypoint '{keypoint}' not found in dataset. " f"Available keypoints: {available_keypoints}", ) - + # Extract position for this keypoint if "keypoints" in ds.dims: pos = ds.position.sel(keypoints=keypoint) else: # Handle datasets without keypoints dimension pos = ds.position - + position_arrays.append(pos) # Get union of all time coordinates - all_times = sorted(set().union(*[set(arr.time.values) for arr in position_arrays])) - + all_times = sorted( + set().union(*[set(arr.time.values) for arr in position_arrays]) + ) + # Reindex all arrays to the common time coordinate aligned_arrays = [] for arr in position_arrays: reindexed = arr.reindex(time=all_times) - + # Optionally interpolate missing values if interpolate: reindexed = interpolate_over_time(reindexed, max_gap=max_gap) - + aligned_arrays.append(reindexed) - + return aligned_arrays @log_to_attrs def fuse_tracks_mean( - aligned_tracks: List[xr.DataArray], + aligned_tracks: list[xr.DataArray], print_report: bool = False, ) -> xr.DataArray: """Fuse tracks by taking the mean across all sources. @@ -126,25 +127,26 @@ def fuse_tracks_mean( ----- This function computes the mean of all valid position values at each time point. If all sources have NaN at a particular time point, the result will also be NaN. + """ if not aligned_tracks: raise log_error(ValueError, "No tracks provided") # Stack all tracks along a new 'source' dimension stacked = xr.concat(aligned_tracks, dim="source") - + # Take the mean along the source dimension, ignoring NaNs fused = stacked.mean(dim="source", skipna=True) - + if print_report: print(report_nan_values(fused, "Fused track (mean)")) - + return fused @log_to_attrs def fuse_tracks_median( - aligned_tracks: List[xr.DataArray], + aligned_tracks: list[xr.DataArray], print_report: bool = False, ) -> xr.DataArray: """Fuse tracks by taking the median across all sources. @@ -166,27 +168,28 @@ def fuse_tracks_median( This function computes the median of all valid position values at each time point. If all sources have NaN at a particular time point, the result will also be NaN. This method is more robust to outliers than the mean method. + """ if not aligned_tracks: raise log_error(ValueError, "No tracks provided") # Stack all tracks along a new 'source' dimension stacked = xr.concat(aligned_tracks, dim="source") - + # Take the median along the source dimension, ignoring NaNs fused = stacked.median(dim="source", skipna=True) - + if print_report: print(report_nan_values(fused, "Fused track (median)")) - + return fused @log_to_attrs def fuse_tracks_weighted( - aligned_tracks: List[xr.DataArray], - weights: List[float] = None, - confidence_arrays: List[xr.DataArray] = None, + aligned_tracks: list[xr.DataArray], + weights: list[float] = None, + confidence_arrays: list[xr.DataArray] = None, print_report: bool = False, ) -> xr.DataArray: """Fuse tracks using a weighted average. @@ -215,78 +218,83 @@ def fuse_tracks_weighted( - Static (one weight per source) - Dynamic (confidence value for each position at each time point) If both weights and confidence_arrays are provided, confidence_arrays takes precedence. 
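+
+    Examples
+    --------
+    A minimal sketch with two constant toy tracks (values chosen purely
+    for illustration):
+
+    >>> import numpy as np
+    >>> import xarray as xr
+    >>> coords = {"time": np.arange(3), "space": ["x", "y"]}
+    >>> track_a = xr.DataArray(
+    ...     np.ones((3, 2)), dims=["time", "space"], coords=coords
+    ... )
+    >>> track_b = xr.DataArray(
+    ...     np.zeros((3, 2)), dims=["time", "space"], coords=coords
+    ... )
+    >>> fused = fuse_tracks_weighted([track_a, track_b], weights=[0.8, 0.2])
+    >>> bool((fused == 0.8).all())  # 0.8 * 1.0 + 0.2 * 0.0
+    True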
+ """ if not aligned_tracks: raise log_error(ValueError, "No tracks provided") - + n_tracks = len(aligned_tracks) - + # Check and prepare weights if weights is not None: if len(weights) != n_tracks: raise log_error( ValueError, f"Number of weights ({len(weights)}) does not match " - f"number of tracks ({n_tracks})" + f"number of tracks ({n_tracks})", ) if abs(sum(weights) - 1.0) > 1e-10: raise log_error( - ValueError, - f"Weights must sum to 1, got sum={sum(weights)}" + ValueError, f"Weights must sum to 1, got sum={sum(weights)}" ) else: # Equal weights if nothing is provided weights = [1.0 / n_tracks] * n_tracks - + # Use dynamic confidence arrays if provided if confidence_arrays is not None: if len(confidence_arrays) != n_tracks: raise log_error( ValueError, f"Number of confidence arrays ({len(confidence_arrays)}) does not match " - f"number of tracks ({n_tracks})" + f"number of tracks ({n_tracks})", ) - + # Normalize confidence values per time point # Stack all confidence arrays along a 'source' dimension stacked_conf = xr.concat(confidence_arrays, dim="source") - + # Calculate sum of confidences at each time point sum_conf = stacked_conf.sum(dim="source") - + # Handle zeros by replacing with equal weights - has_zeros = (sum_conf == 0) + has_zeros = sum_conf == 0 norm_conf = stacked_conf / sum_conf norm_conf = norm_conf.where(~has_zeros, 1.0 / n_tracks) - + # Apply confidence-weighted average stacked_pos = xr.concat(aligned_tracks, dim="source") weighted_pos = stacked_pos * norm_conf fused = weighted_pos.sum(dim="source", skipna=True) - + else: # Apply static weights - weighted_tracks = [track * weight for track, weight in zip(aligned_tracks, weights)] - + weighted_tracks = [ + track * weight + for track, weight in zip(aligned_tracks, weights, strict=False) + ] + # Stack and sum along a new 'source' dimension stacked = xr.concat(weighted_tracks, dim="source") - + # Calculate where all tracks are NaN - all_nan = xr.concat([track.isnull() for track in aligned_tracks], dim="source").all(dim="source") - + all_nan = xr.concat( + [track.isnull() for track in aligned_tracks], dim="source" + ).all(dim="source") + # Sum along source dimension, set result to NaN where all sources are NaN fused = stacked.sum(dim="source", skipna=True).where(~all_nan) - + if print_report: print(report_nan_values(fused, "Fused track (weighted average)")) - + return fused @log_to_attrs def fuse_tracks_reliability( - aligned_tracks: List[xr.DataArray], - reliability_metrics: List[float] = None, + aligned_tracks: list[xr.DataArray], + reliability_metrics: list[float] = None, window_size: int = 11, print_report: bool = False, ) -> xr.DataArray: @@ -315,78 +323,86 @@ def fuse_tracks_reliability( This function selects values from the most reliable source at each time point, then applies a median filter to avoid rapid switching between sources, which could create unrealistic jumps in the trajectory. 
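+
+    Examples
+    --------
+    A minimal sketch where the second toy track has a dropped detection,
+    so the first (complete) source is preferred throughout:
+
+    >>> import numpy as np
+    >>> import xarray as xr
+    >>> coords = {"time": np.arange(3), "space": ["x", "y"]}
+    >>> good = xr.DataArray(
+    ...     np.ones((3, 2)), dims=["time", "space"], coords=coords
+    ... )
+    >>> noisy = good.copy(deep=True)
+    >>> noisy[1] = np.nan
+    >>> fused = fuse_tracks_reliability([good, noisy], window_size=3)
+    >>> bool(fused.isnull().any())
+    False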
+ """ if not aligned_tracks: raise log_error(ValueError, "No tracks provided") - + if window_size % 2 == 0: raise log_error(ValueError, "Window size must be an odd number") - + n_tracks = len(aligned_tracks) - + # Determine track reliability if not provided if reliability_metrics is None: # Count NaNs in each track (fewer NaNs = more reliable) - nan_counts = [float(track.isnull().sum().values) for track in aligned_tracks] + nan_counts = [ + float(track.isnull().sum().values) for track in aligned_tracks + ] total_values = float(aligned_tracks[0].size) # Convert to a reliability score (inverse of NaN proportion) - reliability_metrics = [1.0 - (count / total_values) for count in nan_counts] - + reliability_metrics = [ + 1.0 - (count / total_values) for count in nan_counts + ] + # Stack all tracks along a new 'source' dimension stacked = xr.concat(aligned_tracks, dim="source") - + # For each time point, create a selection array based on reliability and NaN status time_points = stacked.time.values selected_sources = np.zeros(len(time_points), dtype=int) - + # Loop through each time point for i, t in enumerate(time_points): values_at_t = [track.sel(time=t).values for track in aligned_tracks] is_nan = [np.isnan(val).any() for val in values_at_t] - + # If all sources have NaN, pick the most reliable one anyway if all(is_nan): selected_sources[i] = np.argmax(reliability_metrics) else: # Filter out NaN sources - valid_indices = [idx for idx, nan_status in enumerate(is_nan) if not nan_status] - valid_reliability = [reliability_metrics[idx] for idx in valid_indices] - + valid_indices = [ + idx for idx, nan_status in enumerate(is_nan) if not nan_status + ] + valid_reliability = [ + reliability_metrics[idx] for idx in valid_indices + ] + # Select the most reliable valid source best_valid_idx = valid_indices[np.argmax(valid_reliability)] selected_sources[i] = best_valid_idx - + # Apply median filter to smooth source selection and avoid rapid switching if window_size > 1 and len(time_points) > window_size: selected_sources = medfilt(selected_sources, window_size) - + # Create the fused track by selecting values from the chosen source at each time fused_data = np.zeros((len(time_points), stacked.sizes["space"])) - - for i, (t, source_idx) in enumerate(zip(time_points, selected_sources)): + + for i, (t, source_idx) in enumerate( + zip(time_points, selected_sources, strict=False) + ): fused_data[i] = stacked.sel(time=t, source=source_idx).values - + # Create a new DataArray with the fused data fused = xr.DataArray( data=fused_data, dims=["time", "space"], - coords={ - "time": time_points, - "space": stacked.space.values - } + coords={"time": time_points, "space": stacked.space.values}, ) - + if print_report: print(report_nan_values(fused, "Fused track (reliability-based)")) - + return fused @log_to_attrs def fuse_tracks_kalman( - aligned_tracks: List[xr.DataArray], + aligned_tracks: list[xr.DataArray], process_noise_scale: float = 0.01, - measurement_noise_scales: List[float] = None, + measurement_noise_scales: list[float] = None, print_report: bool = False, ) -> xr.DataArray: """Fuse tracks using a Kalman filter. @@ -415,79 +431,80 @@ def fuse_tracks_kalman( 2. Predicts the next state based on constant velocity assumptions 3. Updates the prediction using measurements from all available sources 4. Handles missing measurements (NaNs) by skipping the update step - + The Kalman filter is particularly effective for trajectory smoothing and handling noisy measurements from multiple sources. 
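+
+    Examples
+    --------
+    A minimal sketch fusing two toy tracks that move along a diagonal,
+    where the second source misses one detection:
+
+    >>> import numpy as np
+    >>> import xarray as xr
+    >>> coords = {"time": np.arange(4), "space": ["x", "y"]}
+    >>> xy = np.stack([np.arange(4.0), np.arange(4.0)], axis=1)
+    >>> track_a = xr.DataArray(xy, dims=["time", "space"], coords=coords)
+    >>> track_b = track_a.copy(deep=True)
+    >>> track_b[1] = np.nan  # one dropped detection
+    >>> fused = fuse_tracks_kalman(
+    ...     [track_a, track_b], measurement_noise_scales=[0.1, 0.5]
+    ... )
+    >>> fused.shape
+    (4, 2)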
+ """ if not aligned_tracks: raise log_error(ValueError, "No tracks provided") - + n_tracks = len(aligned_tracks) - + # Set default measurement noise scales if not provided if measurement_noise_scales is None: measurement_noise_scales = [1.0] * n_tracks - + if len(measurement_noise_scales) != n_tracks: raise log_error( ValueError, f"Number of measurement noise scales ({len(measurement_noise_scales)}) " - f"does not match number of tracks ({n_tracks})" + f"does not match number of tracks ({n_tracks})", ) - + # Get the common time axis time_points = aligned_tracks[0].time.values n_timesteps = len(time_points) - + # Get the dimensionality of the space (2D or 3D) n_dims = len(aligned_tracks[0].space.values) - + # Initialize state vector [x, y, vx, vy] or [x, y, z, vx, vy, vz] state_dim = 2 * n_dims state = np.zeros(state_dim) - + # Initialize state covariance matrix state_cov = np.eye(state_dim) - + # Define transition matrix (constant velocity model) dt = 1.0 # Assuming unit time steps A = np.eye(state_dim) for i in range(n_dims): A[i, i + n_dims] = dt - + # Define process noise covariance Q = np.eye(state_dim) * process_noise_scale - + # Define measurement matrix (extracts position from state) H = np.zeros((n_dims, state_dim)) for i in range(n_dims): H[i, i] = 1.0 - + # Initialize storage for Kalman filter output kalman_output = np.zeros((n_timesteps, n_dims)) - + # For the first time step, initialize with the average of available measurements first_measurements = [] for track in aligned_tracks: pos = track.sel(time=time_points[0]).values if not np.isnan(pos).any(): first_measurements.append(pos) - + if first_measurements: initial_pos = np.mean(first_measurements, axis=0) state[:n_dims] = initial_pos kalman_output[0] = initial_pos - + # Run Kalman filter for t in range(1, n_timesteps): # Prediction step state = A @ state state_cov = A @ state_cov @ A.T + Q - + # Update step - combine all available measurements measurements = [] R_list = [] # Measurement noise covariances - + for i, track in enumerate(aligned_tracks): pos = track.sel(time=time_points[t]).values if not np.isnan(pos).any(): @@ -495,52 +512,51 @@ def fuse_tracks_kalman( # Measurement noise covariance for this source R = np.eye(n_dims) * measurement_noise_scales[i] R_list.append(R) - + # Skip update if no measurements available if not measurements: kalman_output[t] = state[:n_dims] continue - + # Apply update for each measurement - for z, R in zip(measurements, R_list): + for z, R in zip(measurements, R_list, strict=False): y = z - H @ state # Measurement residual S = H @ state_cov @ H.T + R # Residual covariance K = state_cov @ H.T @ np.linalg.inv(S) # Kalman gain state = state + K @ y # Updated state - state_cov = (np.eye(state_dim) - K @ H) @ state_cov # Updated covariance - + state_cov = ( + np.eye(state_dim) - K @ H + ) @ state_cov # Updated covariance + # Store the updated position kalman_output[t] = state[:n_dims] - + # Create a new DataArray with the Kalman filter output fused = xr.DataArray( data=kalman_output, dims=["time", "space"], - coords={ - "time": time_points, - "space": aligned_tracks[0].space.values - } + coords={"time": time_points, "space": aligned_tracks[0].space.values}, ) - + if print_report: print(report_nan_values(fused, "Fused track (Kalman filter)")) - + return fused @log_to_attrs def fuse_tracks( - datasets: List[xr.Dataset], - method: Union[str, FusionMethod] = "kalman", + datasets: list[xr.Dataset], + method: str | FusionMethod = "kalman", keypoint: str = "centroid", interpolate_gaps: bool = True, 
     max_gap: int = 5,
-    weights: List[float] = None,
-    confidence_arrays: List[xr.DataArray] = None,
-    reliability_metrics: List[float] = None,
+    weights: list[float] | None = None,
+    confidence_arrays: list[xr.DataArray] | None = None,
+    reliability_metrics: list[float] | None = None,
     window_size: int = 11,
     process_noise_scale: float = 0.01,
-    measurement_noise_scales: List[float] = None,
+    measurement_noise_scales: list[float] | None = None,
     print_report: bool = False,
 ) -> xr.DataArray:
     """Fuse tracks from multiple datasets using the specified method.
@@ -591,6 +607,7 @@ def fuse_tracks(
     -----
     This function acts as a high-level interface to various track fusion methods,
     automatically handling dataset alignment and applying the selected fusion algorithm.
+
     """
     # Convert string method to enum if needed
     if isinstance(method, str):
@@ -601,17 +618,17 @@ def fuse_tracks(
             "reliability": FusionMethod.RELIABILITY_BASED,
             "kalman": FusionMethod.KALMAN,
         }
-    
+
         if method.lower() not in method_map:
             valid_methods = list(method_map.keys())
             raise log_error(
                 ValueError,
                 f"Unsupported fusion method: {method}. "
-                f"Valid methods are: {valid_methods}"
+                f"Valid methods are: {valid_methods}",
             )
-    
+
         method = method_map[method.lower()]
-    
+
     # Align datasets
     aligned_tracks = align_datasets(
         datasets=datasets,
@@ -619,20 +636,20 @@ def fuse_tracks(
         interpolate=interpolate_gaps,
         max_gap=max_gap,
     )
-    
+
     # Apply fusion method
     if method == FusionMethod.MEAN:
         return fuse_tracks_mean(
             aligned_tracks=aligned_tracks,
             print_report=print_report,
         )
-    
+
     elif method == FusionMethod.MEDIAN:
         return fuse_tracks_median(
             aligned_tracks=aligned_tracks,
             print_report=print_report,
         )
-    
+
     elif method == FusionMethod.WEIGHTED:
         return fuse_tracks_weighted(
             aligned_tracks=aligned_tracks,
@@ -640,7 +657,7 @@ def fuse_tracks(
             confidence_arrays=confidence_arrays,
             print_report=print_report,
         )
-    
+
     elif method == FusionMethod.RELIABILITY_BASED:
         return fuse_tracks_reliability(
             aligned_tracks=aligned_tracks,
@@ -648,7 +665,7 @@ def fuse_tracks(
             window_size=window_size,
             print_report=print_report,
         )
-    
+
     elif method == FusionMethod.KALMAN:
         return fuse_tracks_kalman(
             aligned_tracks=aligned_tracks,
@@ -656,9 +673,6 @@ def fuse_tracks(
             measurement_noise_scales=measurement_noise_scales,
             print_report=print_report,
         )
-    
+
     else:
-        raise log_error(
-            ValueError,
-            f"Unsupported fusion method: {method}"
-        )
\ No newline at end of file
+        raise log_error(ValueError, f"Unsupported fusion method: {method}")
diff --git a/tests/test_unit/test_track_fusion.py b/tests/test_unit/test_track_fusion.py
index da0ffa8c7..db383fb6b 100644
--- a/tests/test_unit/test_track_fusion.py
+++ b/tests/test_unit/test_track_fusion.py
@@ -71,18 +71,18 @@ def mock_datasets():
 def test_align_datasets(mock_datasets):
     """Test aligning datasets with different time points."""
     aligned = align_datasets(mock_datasets, interpolate=False)
-    
+
     # Check that both arrays have the same time coordinates
     assert aligned[0].time.equals(aligned[1].time)
-    
+
     # Check that NaNs are preserved when interpolate=False
     assert np.isnan(aligned[0].sel(time=3, space="x").values)
     assert np.isnan(aligned[1].sel(time=3, space="x").values)
     assert np.isnan(aligned[1].sel(time=7, space="x").values)
-    
+
     # Test with interpolation
     aligned_interp = align_datasets(mock_datasets, interpolate=True)
-    
+
     # Check that NaNs are interpolated
     assert not np.isnan(aligned_interp[0].sel(time=3, space="x").values)
     assert not np.isnan(aligned_interp[1].sel(time=3, space="x").values)
@@ -93,15 +93,15 @@ def test_fuse_tracks_mean(mock_datasets):
     """Test mean fusion method."""
     aligned = align_datasets(mock_datasets, interpolate=True)
     fused = fuse_tracks_mean(aligned)
-    
+
     # Check output dimensions
     assert "source" not in fused.dims
     assert "time" in fused.dims
     assert "space" in fused.dims
-    
+
     # Check that the fused track has all time points
     assert len(fused.time) == 10
-    
+
     # No NaNs when both sources are interpolated
     assert not np.isnan(fused).any()
 
@@ -110,12 +110,12 @@ def test_fuse_tracks_median(mock_datasets):
     """Test median fusion method."""
     aligned = align_datasets(mock_datasets, interpolate=True)
     fused = fuse_tracks_median(aligned)
-    
+
     # Check output dimensions
     assert "source" not in fused.dims
     assert "time" in fused.dims
     assert "space" in fused.dims
-    
+
     # No NaNs when both sources are interpolated
     assert not np.isnan(fused).any()
 
@@ -123,23 +123,23 @@ def test_fuse_tracks_median(mock_datasets):
 def test_fuse_tracks_weighted(mock_datasets):
     """Test weighted fusion method."""
     aligned = align_datasets(mock_datasets, interpolate=True)
-    
+
     # Test with static weights
     weights = [0.7, 0.3]
     fused = fuse_tracks_weighted(aligned, weights=weights)
-    
+
     # Check output dimensions
     assert "source" not in fused.dims
     assert "time" in fused.dims
     assert "space" in fused.dims
-    
+
     # No NaNs when both sources are interpolated
     assert not np.isnan(fused).any()
-    
+
     # Test with invalid weights (sum != 1)
     with pytest.raises(ValueError):
         fuse_tracks_weighted(aligned, weights=[0.5, 0.2])
-    
+
     # Test with mismatched weights length
     with pytest.raises(ValueError):
         fuse_tracks_weighted(aligned, weights=[0.5, 0.3, 0.2])
@@ -147,23 +147,27 @@ def test_fuse_tracks_weighted(mock_datasets):
 def test_fuse_tracks_reliability(mock_datasets):
     """Test reliability-based fusion method."""
-    aligned = align_datasets(mock_datasets, interpolate=False)  # Keep NaNs for testing
-    
+    aligned = align_datasets(
+        mock_datasets, interpolate=False
+    )  # Keep NaNs for testing
+
     # Test with automatic reliability metrics
     fused = fuse_tracks_reliability(aligned)
-    
+
     # Check output dimensions
     assert "source" not in fused.dims
     assert "time" in fused.dims
     assert "space" in fused.dims
-    
+
     # Test with custom reliability metrics
     reliability_metrics = [0.9, 0.5]  # First source more reliable
-    fused = fuse_tracks_reliability(aligned, reliability_metrics=reliability_metrics)
-    
+    fused = fuse_tracks_reliability(
+        aligned, reliability_metrics=reliability_metrics
+    )
+
     # Check that we still get a value for time point 7 where only source 1 has data
     assert not np.isnan(fused.sel(time=7, space="x").values)
-    
+
     # Test with invalid window size (even number)
     with pytest.raises(ValueError):
         fuse_tracks_reliability(aligned, window_size=10)
 
@@ -171,31 +175,33 @@ def test_fuse_tracks_reliability(mock_datasets):
 def test_fuse_tracks_kalman(mock_datasets):
     """Test Kalman filter fusion method."""
-    aligned = align_datasets(mock_datasets, interpolate=False)  # Keep NaNs for testing
-    
+    aligned = align_datasets(
+        mock_datasets, interpolate=False
+    )  # Keep NaNs for testing
+
     # Test with default parameters
     fused = fuse_tracks_kalman(aligned)
-    
+
     # Check output dimensions
     assert "source" not in fused.dims
     assert "time" in fused.dims
     assert "space" in fused.dims
-    
+
     # Kalman filter should interpolate over missing values
     assert not np.isnan(fused).any()
-    
+
     # Test with custom parameters
     fused = fuse_tracks_kalman(
-        aligned,
-        process_noise_scale=0.1,
-        measurement_noise_scales=[0.1, 0.5]
+        aligned, process_noise_scale=0.1, measurement_noise_scales=[0.1, 0.5]
     )
-    
+
     # Check that we get a smoother trajectory (less variance)
     x_vals = fused.sel(space="x").values
     diff = np.diff(x_vals)
-    assert np.std(diff) < 0.5  # Standard deviation of the differences should be low
-    
+    assert (
+        np.std(diff) < 0.5
+    )  # Standard deviation of the differences should be low
+
     # Test with mismatched noise scales length
     with pytest.raises(ValueError):
         fuse_tracks_kalman(aligned, measurement_noise_scales=[0.1, 0.2, 0.3])
 
@@ -205,27 +211,27 @@ def test_fuse_tracks_high_level(mock_datasets):
     """Test the high-level fuse_tracks interface."""
     # Test each method through the high-level interface
     methods = ["mean", "median", "weighted", "reliability", "kalman"]
-    
+
     for method in methods:
         fused = fuse_tracks(
             datasets=mock_datasets,
             method=method,
             keypoint="centroid",
-            interpolate_gaps=True
+            interpolate_gaps=True,
         )
-    
+
         # Check output dimensions
         assert "time" in fused.dims
         assert "space" in fused.dims
         assert len(fused.space) == 2
-    
+
         # No NaNs when interpolation is used
         assert not np.isnan(fused).any()
-    
+
     # Test with invalid method
     with pytest.raises(ValueError):
         fuse_tracks(mock_datasets, method="invalid_method")
-    
+
     # Test with non-existent keypoint
     with pytest.raises(ValueError):
-        fuse_tracks(mock_datasets, keypoint="non_existent")
\ No newline at end of file
+        fuse_tracks(mock_datasets, keypoint="non_existent")