diff --git a/docs/source/user_guide/input_output.md b/docs/source/user_guide/input_output.md index b29cb088b..743a8e135 100644 --- a/docs/source/user_guide/input_output.md +++ b/docs/source/user_guide/input_output.md @@ -255,10 +255,23 @@ save_poses.to_dlc_file(ds, "/path/to/file.csv", split_individuals=True) (target-saving-bboxes-tracks)= ## Saving bounding box tracks -We currently do not provide explicit methods to export a movement bounding boxes dataset in a specific format. However, you can easily save the bounding box tracks to a .csv file using the standard Python library `csv`. +We currently support exporting a [movement bboxes dataset](target-poses-and-bboxes-dataset) as a [VIA tracks .csv file](via:docs/face_track_annotation.html), so that you can visualise and correct your bounding box tracks with the [VGG Image Annotator (VIA-2) software](via:via.html). -Here is an example of how you can save a bounding boxes dataset to a .csv file: +To export your bounding boxes dataset `ds`, you will need to import the {mod}`movement.io.save_bboxes` module: +```python +from movement.io import save_bboxes +``` + +Then you can save it as a VIA tracks .csv file: +```python +save_bboxes.to_via_tracks_file(ds, "/path/to/output/file.csv") +``` + +By default, the {func}`movement.io.save_bboxes.to_via_tracks_file` function extracts the track IDs from the trailing digits in the individuals' names, but you can instead derive them from the sorted list of individuals' names by passing `extract_track_id_from_individuals=False`. + +Alternatively, you can save the bounding box tracks to a .csv file with a custom header using the standard Python library `csv`. Below is an example of how you can do this: ```python # define name for output csv file filepath = "tracking_output.csv" @@ -279,7 +292,7 @@ with open(filepath, mode="w", newline="") as file: writer.writerow([frame, individual, x, y, width, height, confidence]) ``` -Alternatively, we can convert the `movement` bounding boxes dataset to a pandas DataFrame with the {meth}`xarray.DataArray.to_dataframe` method, wrangle the dataframe as required, and then apply the {meth}`pandas.DataFrame.to_csv` method to save the data as a .csv file. +Or if you prefer to work with `pandas`, you can convert the `movement` bounding boxes dataset to a `pandas` DataFrame with the {meth}`xarray.DataArray.to_dataframe` method, wrangle the dataframe as required, and then apply the {meth}`pandas.DataFrame.to_csv` method to save the data as a .csv file. (target-sample-data)= diff --git a/movement/io/load_bboxes.py b/movement/io/load_bboxes.py index 92034989d..226809dad 100644 --- a/movement/io/load_bboxes.py +++ b/movement/io/load_bboxes.py @@ -156,13 +156,13 @@ def from_file( ) -> xr.Dataset: """Create a ``movement`` bounding boxes dataset from a supported file. - At the moment, we only support VIA-tracks .csv files. + At the moment, we only support VIA tracks .csv files. Parameters ---------- file_path : pathlib.Path or str Path to the file containing the tracked bounding boxes. Currently - only VIA-tracks .csv files are supported. + only VIA tracks .csv files are supported. source_software : "VIA-tracks". The source software of the file. Currently only files from the VIA 2.0.12 annotator [1]_ ("VIA-tracks") are supported.
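For reference, here is a minimal sketch of the `pandas` route mentioned in the docs change above. It is illustrative only (the exact wrangling depends on which columns you need) and assumes a standard `movement` bboxes dataset `ds`:

```python
# Flatten the position array into a long-format dataframe;
# to_dataframe() uses the DataArray name ("position") as the value column
df = ds.position.to_dataframe().reset_index()

# Example wrangling: one row per frame and individual, with x/y as columns
df = df.pivot_table(
    index=["time", "individuals"], columns="space", values="position"
).reset_index()

# Save to a .csv file
df.to_csv("tracking_output.csv", index=False)
```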
diff --git a/movement/io/save_bboxes.py b/movement/io/save_bboxes.py new file mode 100644 index 000000000..8d83cbb95 --- /dev/null +++ b/movement/io/save_bboxes.py @@ -0,0 +1,543 @@ +"""Save bounding boxes data from ``movement`` to VIA tracks .csv format.""" + +import _csv +import csv +import json +from pathlib import Path + +import numpy as np +import xarray as xr + +from movement.io.utils import _validate_file_path +from movement.utils.logging import logger +from movement.validators.datasets import ValidBboxesDataset + + +def to_via_tracks_file( + ds: xr.Dataset, + file_path: str | Path, + extract_track_id_from_individuals: bool = True, + frame_n_digits: int | None = None, + image_file_prefix: str | None = None, + image_file_suffix: str = ".png", +) -> Path: + """Save a movement bounding boxes dataset to a VIA tracks .csv file. + + Parameters + ---------- + ds : xarray.Dataset + The movement bounding boxes dataset to export. + file_path : str or pathlib.Path + Path where the VIA tracks .csv file [1]_ will be saved. + extract_track_id_from_individuals : bool, optional + If True, extract track IDs from the numbers at the end of the + individuals' names (e.g. `mouse_1` -> track ID 1). If False, the + track IDs will be factorised from the list of sorted individuals' + names. Default is True. + frame_n_digits : int, optional + The number of digits to use to represent frame numbers in the image + filenames (including leading zeros). If None, the number of digits is + automatically determined from the largest frame number in the dataset, + plus one (to have at least one leading zero). Default is None. + image_file_prefix : str, optional + Prefix to apply to every image filename. It is prepended to the frame + number which is padded with leading zeros. If None or an empty string, + nothing will be prepended to the padded frame number. Default is None. + image_file_suffix : str, optional + Suffix holding the file extension, added to every image filename. + Strings with or without the dot are accepted. Default is '.png'. + + Returns + ------- + pathlib.Path + Path to the saved file. + + References + ---------- + .. [1] https://www.robots.ox.ac.uk/~vgg/software/via/docs/face_track_annotation.html + + Examples + -------- + Export a ``movement`` bounding boxes dataset as a VIA tracks .csv file, + deriving the track IDs from the individuals' names and assuming + the image files are PNG files: + + >>> from movement.io import save_bboxes + >>> save_bboxes.to_via_tracks_file(ds, "/path/to/output.csv") + + Export a ``movement`` bounding boxes dataset as a VIA tracks .csv file, + deriving the track IDs from the list of sorted individuals' names and + assuming the image files are JPG files: + + >>> from movement.io import save_bboxes + >>> save_bboxes.to_via_tracks_file( + ... ds, + ... "/path/to/output.csv", + ... extract_track_id_from_individuals=False, + ... image_file_suffix=".jpg", + ... ) + + Export a ``movement`` bounding boxes dataset as a VIA tracks .csv file, + with image filenames following the format ``frame-<frame_number>.jpg``: + + >>> from movement.io import save_bboxes + >>> save_bboxes.to_via_tracks_file( + ... ds, + ... "/path/to/output.csv", + ... image_file_prefix="frame-", + ... image_file_suffix=".jpg", + ...
) + + Export a ``movement`` bounding boxes dataset as a VIA tracks .csv file, + with frame numbers represented with 4 digits, including leading zeros + (i.e., image filenames would be ``0000.png``, ``0001.png``, etc.): + + >>> from movement.io import save_bboxes + >>> save_bboxes.to_via_tracks_file( + ... ds, + ... "/path/to/output.csv", + ... frame_n_digits=4, + ... ) + + """ + # Validate file path and dataset + file = _validate_file_path(file_path, expected_suffix=[".csv"]) + _validate_bboxes_dataset(ds) + + # Check the number of digits required to represent the frame numbers + frame_n_digits = _check_frame_required_digits( + ds=ds, frame_n_digits=frame_n_digits + ) + + # Define format string for image filenames + img_filename_template = _get_image_filename_template( + frame_n_digits=frame_n_digits, + image_file_prefix=image_file_prefix, + image_file_suffix=image_file_suffix, + ) + + # Map individuals' names to track IDs + map_individual_to_track_id = _get_map_individuals_to_track_ids( + ds.coords["individuals"].values, + extract_track_id_from_individuals, + ) + + # Write file + _write_via_tracks_csv( + ds, + file.path, + map_individual_to_track_id, + img_filename_template, + ) + + logger.info(f"Saved bounding boxes dataset to {file.path}.") + return file.path + + +def _validate_bboxes_dataset(ds: xr.Dataset) -> None: + """Verify the input dataset is a valid ``movement`` bboxes dataset. + + Parameters + ---------- + ds : xarray.Dataset + Dataset to validate. + + Raises + ------ + TypeError + If the input is not an xarray Dataset. + ValueError + If the dataset is missing required data variables or dimensions + for a valid ``movement`` bboxes dataset. + + """ + if not isinstance(ds, xr.Dataset): + raise logger.error( + TypeError(f"Expected an xarray Dataset, but got {type(ds)}.") + ) + + missing_vars = set(ValidBboxesDataset.VAR_NAMES) - set(ds.data_vars) + if missing_vars: + raise logger.error( + ValueError( + f"Missing required data variables: {sorted(missing_vars)}" + ) + ) # sort for a reproducible error message + + missing_dims = set(ValidBboxesDataset.DIM_NAMES) - set(ds.dims) + if missing_dims: + raise logger.error( + ValueError( + f"Missing required dimensions: {sorted(missing_dims)}" + ) + ) # sort for a reproducible error message + + +def _get_image_filename_template( + frame_n_digits: int, + image_file_prefix: str | None, + image_file_suffix: str, +) -> str: + """Compute a format string for the images' filenames. + + The filenames of the images in the VIA tracks .csv file are derived from + the frame numbers. Optionally, a prefix can be added to the frame number. + The suffix refers to the file extension of the image files. + + Parameters + ---------- + frame_n_digits : int + Number of digits used to represent the frame number, including any + leading zeros. + image_file_prefix : str | None + Prefix for each image filename, prepended to the frame number. If + None or an empty string, nothing will be prepended. + image_file_suffix : str + Suffix to add to each image filename to represent the file extension. + + Returns + ------- + str + Format string for the images' filenames.
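+
+    Examples
+    --------
+    An illustrative sketch of the expected output (added for clarity;
+    not part of the original docstring):
+
+    >>> template = _get_image_filename_template(4, "frame-", "png")
+    >>> template
+    'frame-{:04d}.png'
+    >>> template.format(7)
+    'frame-0007.png'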
+ + """ + # Add the dot to the file extension if required + if not image_file_suffix.startswith("."): + image_file_suffix = f".{image_file_suffix}" + + # Add the prefix if not None or not an empty string + image_file_prefix_modified = ( + f"{image_file_prefix}" if image_file_prefix else "" + ) + + # Define filename format string + return ( + f"{image_file_prefix_modified}" + f"{{:0{frame_n_digits}d}}" + f"{image_file_suffix}" + ) + + +def _check_frame_required_digits( + ds: xr.Dataset, + frame_n_digits: int | None, +) -> int: + """Check the number of digits to represent the frame number is valid. + + Parameters + ---------- + ds : xarray.Dataset + A movement dataset. + frame_n_digits : int | None + The proposed number of digits to use to represent the frame numbers + in the image filenames (including leading zeros). If None, the number + of digits is inferred based on the largest frame number in the dataset. + + Returns + ------- + int + The number of digits to use to represent the frame numbers in the + image filenames (including leading zeros). + + Raises + ------ + ValueError + If the proposed number of digits is not enough to represent all the + frame numbers. + + """ + # Compute minimum number of digits required to represent the + # largest frame number + if ds.time_unit == "seconds": + max_frame_number = max((ds.time.values * ds.fps).astype(int)) + else: + max_frame_number = max(ds.time.values) + min_required_digits = len(str(max_frame_number)) + + # If requested number of digits is None, infer automatically + if frame_n_digits is None: + return min_required_digits + 1 # pad with at least one zero + elif frame_n_digits < min_required_digits: + raise ValueError( + "The requested number of digits to represent the frame " + "number cannot be used to represent all the frame numbers." + f"Got {frame_n_digits}, but the maximum frame number has " + f"{min_required_digits} digits" + ) + else: + return frame_n_digits + + +def _get_map_individuals_to_track_ids( + list_individuals: list[str], + extract_track_id_from_individuals: bool, +) -> dict[str, int]: + """Compute a mapping of individuals' names to track IDs. + + Parameters + ---------- + list_individuals : list[str] + List of individuals' names. + extract_track_id_from_individuals : bool + If True, extract track ID from the last consecutive digits in + the individuals' names. If False, the track IDs will be factorised + from the sorted list of individuals' names. + + Returns + ------- + dict[str, int] + A dictionary mapping individuals' names to track IDs. + + """ + if extract_track_id_from_individuals: + # Extract track IDs from the individuals' names + map_individual_to_track_id = _get_track_id_from_individuals( + list_individuals + ) + else: + # Factorise track IDs from sorted individuals' names + list_individuals = sorted(list_individuals) + map_individual_to_track_id = { + individual: i for i, individual in enumerate(list_individuals) + } + + return map_individual_to_track_id + + +def _get_track_id_from_individuals( + list_individuals: list[str], +) -> dict[str, int]: + """Extract track IDs as the last digits in the individuals' names. + + Parameters + ---------- + list_individuals : list[str] + List of individuals' names. + + Returns + ------- + dict[str, int] + A dictionary mapping individuals' names to track IDs. + + Raises + ------ + ValueError + If a track ID is not found by looking at the last consecutive digits + in an individual's name, or if the extracted track IDs cannot be + uniquely mapped to the individuals' names. 
+ + """ + map_individual_to_track_id = {} + + for individual in list_individuals: + # Find the first non-digit character starting from the end + last_idx = len(individual) - 1 + first_non_digit_idx = last_idx + while ( + first_non_digit_idx >= 0 + and individual[first_non_digit_idx].isdigit() + ): + first_non_digit_idx -= 1 + + # Extract track ID from (first_non_digit_idx+1) until the end + if first_non_digit_idx < last_idx: + track_id = int(individual[first_non_digit_idx + 1 :]) + map_individual_to_track_id[individual] = track_id + else: + raise ValueError(f"Could not extract track ID from {individual}.") + + # Check that all individuals have a unique track ID + if len(set(map_individual_to_track_id.values())) != len( + set(list_individuals) + ): + raise ValueError( + "Could not extract a unique track ID for all individuals. " + f"Expected {len(set(list_individuals))} unique track IDs, " + f"but got {len(set(map_individual_to_track_id.values()))}." + ) + + return map_individual_to_track_id + + +def _write_via_tracks_csv( + ds: xr.Dataset, + file_path: str | Path, + map_individual_to_track_id: dict, + img_filename_template: str, +) -> None: + """Write a VIA tracks .csv file. + + Parameters + ---------- + ds : xarray.Dataset + A movement bounding boxes dataset. + file_path : str or pathlib.Path + Path where the VIA tracks .csv file will be saved. + map_individual_to_track_id : dict + Dictionary mapping individuals' names to track IDs. + img_filename_template : str + Format string for the images' filenames. + + """ + # Define VIA tracks .csv header + header = [ + "filename", + "file_size", + "file_attributes", + "region_count", + "region_id", + "region_shape_attributes", + "region_attributes", + ] + + # Get time values in frames + if ds.time_unit == "seconds": + time_in_frames = (ds.time.values * ds.fps).astype(int) + else: + time_in_frames = ds.time.values + + # Locate bboxes with null position or shape + null_position_or_shape = np.any(ds.position.isnull(), axis=1) | np.any( + ds.shape.isnull(), axis=1 + ) # (time, individuals) + + with open(file_path, "w", newline="") as f: + csv_writer = csv.writer(f) + csv_writer.writerow(header) + + # Loop through frames + for time_idx, time in enumerate(ds.time.values): + frame_number = time_in_frames[time_idx] + + # Compute region count for current frame + region_count = int(np.sum(~null_position_or_shape[time_idx, :])) + + # Initialise region ID for current frame + region_id = 0 + + # Loop through individuals + for indiv in ds.individuals.values: + # Get position and shape data + xy_data = ds.position.sel(time=time, individuals=indiv).values + wh_data = ds.shape.sel(time=time, individuals=indiv).values + + # If the position or shape data contain NaNs, do not write + # this bounding box to file + if np.isnan(xy_data).any() or np.isnan(wh_data).any(): + continue + + # Get confidence score + confidence = ds.confidence.sel( + time=time, individuals=indiv + ).values + if np.isnan(confidence): + confidence = None # pass as None if confidence is NaN + + # Get track IDs from individuals' names + track_id = map_individual_to_track_id[indiv] + + # Write row + _write_single_row( + csv_writer, + xy_data, + wh_data, + confidence, + track_id, + region_count, + region_id, + img_filename_template.format(frame_number), + image_size=None, + ) + + # Update region ID for this frame + region_id += 1 + + +def _write_single_row( + writer: "_csv._writer", # requires a string literal type annotation + xy_values: np.ndarray, + wh_values: np.ndarray, + confidence: float | 
None, + track_id: int, + region_count: int, + region_id: int, + img_filename: str, + image_size: int | None, +) -> tuple[str, int, str, int, int, str, str]: + """Write a single row of a VIA tracks .csv file and return it as a tuple. + + Parameters + ---------- + writer : csv.writer + CSV writer object. + xy_values : np.ndarray + Array with the x, y coordinates of the bounding box centroid. + wh_values : np.ndarray + Array with the width and height of the bounding box. + confidence : float | None + Confidence score for the bounding box detection. + track_id : int + Integer identifying a single track of bounding boxes across frames. + region_count : int + Total number of bounding boxes in the current frame. + region_id : int + Integer that identifies the bounding boxes in a frame starting from 0. + Note that it is the result of an enumeration, and it does not + necessarily match the track ID. + img_filename : str + Filename of the image file corresponding to the current frame. + image_size : int | None + Size of the image file in bytes. If None, the file size is set to 0. + + Returns + ------- + tuple[str, int, str, int, int, str, str] + A tuple with the data formatted for a single row in a VIA tracks + .csv file. + + Notes + ----- + The reference for the VIA tracks .csv file format is at + https://www.robots.ox.ac.uk/~vgg/software/via/docs/face_track_annotation.html + + """ + # Calculate top-left coordinates of bounding box + x_center, y_center = xy_values + width, height = wh_values + x_top_left = x_center - width / 2 + y_top_left = y_center - height / 2 + + # Define file attributes (placeholder value) + file_attributes = json.dumps({"shot": 0}) + + # Define region shape attributes + region_shape_attributes = json.dumps( + { + "name": "rect", + "x": float(x_top_left), + "y": float(y_top_left), + "width": float(width), + "height": float(height), + } + ) + + # Define region attributes + region_attributes_dict: dict[str, float | int] = {"track": int(track_id)} + if confidence is not None: + # convert to float to ensure it is json-serializable + region_attributes_dict["confidence"] = float(confidence) + region_attributes = json.dumps(region_attributes_dict) + + # Set image size + image_size = int(image_size) if image_size is not None else 0 + + # Define row data + row = ( + img_filename, + image_size, + file_attributes, + region_count, + region_id, + region_shape_attributes, + region_attributes, + ) + + writer.writerow(row) + + return row diff --git a/movement/io/save_poses.py b/movement/io/save_poses.py index e65bd481e..6b1964d76 100644 --- a/movement/io/save_poses.py +++ b/movement/io/save_poses.py @@ -8,9 +8,9 @@ import pandas as pd import xarray as xr +from movement.io.utils import _validate_file_path from movement.utils.logging import logger from movement.validators.datasets import ValidPosesDataset -from movement.validators.files import ValidFile def _ds_to_dlc_style_df( @@ -112,7 +112,7 @@ def to_dlc_style_df( to_dlc_file : Save dataset directly to a DeepLabCut-style .h5 or .csv file.
""" - _validate_dataset(ds) + _validate_poses_dataset(ds) scorer = ["movement"] bodyparts = ds.coords["keypoints"].data.tolist() coords = ds.coords["space"].data.tolist() + ["likelihood"] @@ -253,7 +253,7 @@ def to_lp_file( """ file = _validate_file_path(file_path=file_path, expected_suffix=[".csv"]) - _validate_dataset(ds) + _validate_poses_dataset(ds) to_dlc_file(ds, file.path, split_individuals=True) @@ -297,7 +297,7 @@ def to_sleap_analysis_file(ds: xr.Dataset, file_path: str | Path) -> None: """ file = _validate_file_path(file_path=file_path, expected_suffix=[".h5"]) - _validate_dataset(ds) + _validate_poses_dataset(ds) ds = _remove_unoccupied_tracks(ds) @@ -380,47 +380,8 @@ def _remove_unoccupied_tracks(ds: xr.Dataset): return ds.where(~all_nan, drop=True) -def _validate_file_path( - file_path: str | Path, expected_suffix: list[str] -) -> ValidFile: - """Validate the input file path. - - We check that the file has write permission and the expected suffix(es). - - Parameters - ---------- - file_path : pathlib.Path or str - Path to the file to validate. - expected_suffix : list of str - Expected suffix(es) for the file. - - Returns - ------- - ValidFile - The validated file. - - Raises - ------ - OSError - If the file cannot be written. - ValueError - If the file does not have the expected suffix. - - """ - try: - file = ValidFile( - file_path, - expected_permission="w", - expected_suffix=expected_suffix, - ) - except (OSError, ValueError) as error: - logger.error(error) - raise - return file - - -def _validate_dataset(ds: xr.Dataset) -> None: - """Validate the input as a proper ``movement`` dataset. +def _validate_poses_dataset(ds: xr.Dataset) -> None: + """Validate the input as a proper ``movement`` poses dataset. Parameters ---------- @@ -432,7 +393,8 @@ def _validate_dataset(ds: xr.Dataset) -> None: TypeError If the input is not an xarray Dataset. ValueError - If the dataset is missing required data variables or dimensions. + If the dataset is missing required data variables or dimensions + for a valid ``movement`` poses dataset. """ if not isinstance(ds, xr.Dataset): diff --git a/movement/io/utils.py b/movement/io/utils.py new file mode 100644 index 000000000..0aeace27a --- /dev/null +++ b/movement/io/utils.py @@ -0,0 +1,45 @@ +"""Functions shared across the ``movement`` IO module.""" + +from pathlib import Path + +from movement.utils.logging import logger +from movement.validators.files import ValidFile + + +def _validate_file_path( + file_path: str | Path, expected_suffix: list[str] +) -> ValidFile: + """Validate the input file path. + + We check that the file has write permission and the expected suffix(es). + + Parameters + ---------- + file_path : pathlib.Path or str + Path to the file to validate. + expected_suffix : list of str + Expected suffix(es) for the file. + + Returns + ------- + ValidFile + The validated file. + + Raises + ------ + OSError + If the file cannot be written. + ValueError + If the file does not have the expected suffix. 
+ + """ + try: + file = ValidFile( + file_path, + expected_permission="w", + expected_suffix=expected_suffix, + ) + except (OSError, ValueError) as error: + logger.error(error) + raise + return file diff --git a/tests/fixtures/datasets.py b/tests/fixtures/datasets.py index 14e5169e6..0c6f0e9fb 100644 --- a/tests/fixtures/datasets.py +++ b/tests/fixtures/datasets.py @@ -56,7 +56,7 @@ def valid_bboxes_arrays(): position[:, 1, i] = (-1) ** i * np.arange(n_frames) # build a valid array for constant bbox shape (60, 40) - constant_shape = (60, 40) # width, height in pixels + constant_shape = float(60), float(40) # width, height in pixels shape = np.tile(constant_shape, (n_frames, n_individuals, 1)).transpose( 0, 2, 1 ) @@ -82,6 +82,15 @@ def valid_bboxes_arrays(): def valid_bboxes_dataset(valid_bboxes_arrays): """Return a valid bboxes dataset for two individuals moving in uniform linear motion, with 5 frames with low confidence values and time in frames. + + It represents 2 individuals for 10 frames, in 2D space. + - Individual 0 moves along the x=y line from the origin. + - Individual 1 moves along the x=-y line line from the origin. + + All confidence values are set to 0.9 except the following which are set + to 0.1: + - Individual 0 at frames 2, 3, 4 + - Individual 1 at frames 2, 3 """ dim_names = ValidBboxesDataset.DIM_NAMES @@ -118,6 +127,7 @@ def valid_bboxes_dataset_in_seconds(valid_bboxes_dataset): """Return a valid bboxes dataset with time in seconds. The origin of time is assumed to be time = frame 0 = 0 seconds. + The time unit is set to "seconds" and the fps is set to 60. """ fps = 60 valid_bboxes_dataset["time"] = valid_bboxes_dataset.time / fps diff --git a/tests/test_unit/test_load_bboxes.py b/tests/test_unit/test_io/test_load_bboxes.py similarity index 100% rename from tests/test_unit/test_load_bboxes.py rename to tests/test_unit/test_io/test_load_bboxes.py diff --git a/tests/test_unit/test_load_poses.py b/tests/test_unit/test_io/test_load_poses.py similarity index 100% rename from tests/test_unit/test_load_poses.py rename to tests/test_unit/test_io/test_load_poses.py diff --git a/tests/test_unit/test_io/test_save_bboxes.py b/tests/test_unit/test_io/test_save_bboxes.py new file mode 100644 index 000000000..8be2124c0 --- /dev/null +++ b/tests/test_unit/test_io/test_save_bboxes.py @@ -0,0 +1,751 @@ +import json +from unittest.mock import Mock, patch + +import numpy as np +import pandas as pd +import pytest +import xarray as xr + +from movement.io import load_bboxes, save_bboxes +from movement.io.save_bboxes import ( + _get_map_individuals_to_track_ids, + _write_single_row, +) + + +@pytest.fixture +def mock_csv_writer(): + """Return a mock CSV writer object.""" + # Mock object + writer = Mock() + # Add writerow method to the mock object + writer.writerow = Mock() + return writer + + +@pytest.fixture +def valid_bboxes_dataset_min_frame_number_modified(valid_bboxes_dataset): + """Return a valid bbboxes dataset with data for 10 frames, + starting at frame number 333. + + `valid_bboxes_dataset` is a dataset with the time coordinate in + frames and data for 10 frames. + """ + return valid_bboxes_dataset.assign_coords( + time=valid_bboxes_dataset.time + 333 + ) + + +@pytest.fixture +def valid_bboxes_dataset_with_late_id0(valid_bboxes_dataset): + """Return a valid bboxes dataset with id_0 starting at time index 3. + + `valid_bboxes_dataset` represents two individuals moving in uniform + linear motion for 10 frames, with low confidence values and time in frames. 
+ """ + valid_bboxes_dataset.position.loc[ + {"individuals": "id_0", "time": [0, 1, 2]} + ] = np.nan + return valid_bboxes_dataset + + +@pytest.fixture +def valid_bboxes_dataset_individuals_modified(valid_bboxes_dataset): + """Return a valid bboxes dataset with individuals named "id_333" and + "id_444". + """ + valid_bboxes_dataset.assign_coords(individuals=["id_333", "id_444"]) + return valid_bboxes_dataset + + +@pytest.fixture +def valid_bboxes_dataset_confidence_all_nans(valid_bboxes_dataset): + """Return a valid bboxes dataset with all NaNs in + the confidence array. + """ + valid_bboxes_dataset["confidence"] = xr.DataArray( + data=np.nan, + dims=valid_bboxes_dataset.confidence.dims, + coords=valid_bboxes_dataset.confidence.coords, + ) + return valid_bboxes_dataset + + +@pytest.fixture +def valid_bboxes_dataset_confidence_some_nans(valid_bboxes_dataset): + """Return a valid bboxes dataset with some NaNs in + the confidence array. + + `valid_bboxes_dataset` represents two individuals moving in uniform + linear motion for 10 frames, with time in frames. The confidence values + for the first 3 frames for individual 0 are set to NaN. + """ + # Set first 3 frames for individual 0 to NaN + confidence_array = valid_bboxes_dataset.confidence.values + confidence_array[:3, 0] = np.nan + + valid_bboxes_dataset["confidence"] = xr.DataArray( + data=confidence_array, + dims=valid_bboxes_dataset.confidence.dims, + coords=valid_bboxes_dataset.confidence.coords, + ) + return valid_bboxes_dataset + + +def _get_min_required_digits_in_ds(ds): + """Return the minimum number of digits required to represent the + largest frame number in the input dataset. + """ + # Compute the maximum frame number + max_frame_number = max(ds.time.values) + if "seconds" in ds.time_unit: + max_frame_number = int(max_frame_number * ds.fps) + + # Return the minimum number of digits required to represent the + # largest frame number + return len(str(max_frame_number)) + + +@pytest.mark.parametrize( + "valid_dataset", + [ + "valid_bboxes_dataset", + "valid_bboxes_dataset_in_seconds", + "valid_bboxes_dataset_with_nan", # nans in position array + "valid_bboxes_dataset_with_late_id0", + ], +) +def test_to_via_tracks_file_valid_dataset( + valid_dataset, + tmp_path, + request, +): + """Test the VIA tracks .csv file with different valid bboxes datasets.""" + # Save VIA tracks .csv file + input_dataset = request.getfixturevalue(valid_dataset) + output_path = tmp_path / "test_valid_dataset.csv" + save_bboxes.to_via_tracks_file(input_dataset, output_path) + + # Check that the exported file is readable in movement + if input_dataset.time_unit == "seconds": + ds = load_bboxes.from_via_tracks_file( + output_path, fps=input_dataset.fps + ) + else: + ds = load_bboxes.from_via_tracks_file(output_path) + + # Check the dataset matches the original one. 
+ # If the position or shape data arrays contain NaNs, remove those + # data points from the original dataset before comparing (these bboxes + # are skipped when writing the VIA tracks .csv file) + null_position_or_shape = ( + input_dataset.position.isnull() | input_dataset.shape.isnull() + ) + input_dataset.shape.values[null_position_or_shape] = np.nan + input_dataset.position.values[null_position_or_shape] = np.nan + input_dataset.confidence.values[np.any(null_position_or_shape, axis=1)] = ( + np.nan + ) + xr.testing.assert_equal(ds, input_dataset) + + +@pytest.mark.parametrize( + "image_file_prefix", + [None, "test_video"], +) +@pytest.mark.parametrize( + "image_file_suffix", + [None, ".png", "png", ".jpg"], +) +def test_to_via_tracks_file_image_filename( + valid_bboxes_dataset, + image_file_prefix, + image_file_suffix, + tmp_path, +): + """Test the VIA tracks .csv export with different image file prefixes and + suffixes. + """ + # Prepare kwargs + kwargs = {"image_file_prefix": image_file_prefix} + if image_file_suffix is not None: + kwargs["image_file_suffix"] = image_file_suffix + + # Save VIA tracks .csv file + output_path = tmp_path / "test_valid_dataset.csv" + save_bboxes.to_via_tracks_file( + valid_bboxes_dataset, + output_path, + **kwargs, + ) + + # Check image file prefix is as expected + df = pd.read_csv(output_path) + if image_file_prefix is not None: + assert df["filename"].str.startswith(image_file_prefix).all() + else: + assert df["filename"].str.startswith("0").all() + + # Check image file suffix is as expected + if image_file_suffix is not None: + assert df["filename"].str.endswith(image_file_suffix).all() + else: + assert df["filename"].str.endswith(".png").all() + + +@pytest.mark.parametrize( + "valid_dataset, expected_confidence_nan_count", + [ + ("valid_bboxes_dataset", 0), + # all bboxes should have a confidence value + ("valid_bboxes_dataset_confidence_all_nans", 20), + # no bboxes should have a confidence value + ("valid_bboxes_dataset_confidence_some_nans", 3), + # some bboxes should have a confidence value + ], +) +def test_to_via_tracks_file_confidence( + valid_dataset, + expected_confidence_nan_count, + tmp_path, + request, +): + """Test that the VIA tracks .csv file is as expected when the confidence + array contains NaNs.
+ """ + # Save VIA tracks .csv file + input_dataset = request.getfixturevalue(valid_dataset) + output_path = tmp_path / "test_valid_dataset.csv" + save_bboxes.to_via_tracks_file(input_dataset, output_path) + + # Check that the input dataset has the expected number of NaNs in the + # confidence array + confidence_is_nan = input_dataset.confidence.isnull().values + assert np.sum(confidence_is_nan) == expected_confidence_nan_count + + # Check that the confidence values in the exported file match the dataset + df = pd.read_csv(output_path) + df["region_attributes"] = [ + json.loads(el) for el in df["region_attributes"] + ] + + # Check the "confidence" region attribute is present for + # as many rows as there are non-NaN confidence values + assert sum( + ["confidence" in row for row in df["region_attributes"]] + ) == np.sum(~confidence_is_nan) + + +@pytest.mark.parametrize( + "valid_dataset", + [ + "valid_bboxes_dataset", + # individuals: "id_0", "id_1" + "valid_bboxes_dataset_individuals_modified", + # individuals: "id_333", "id_444" + ], +) +@pytest.mark.parametrize( + "extract_track_id_from_individuals", + [True, False], +) +def test_to_via_tracks_file_extract_track_id_from_individuals( + valid_dataset, + extract_track_id_from_individuals, + tmp_path, + request, +): + """Test that the VIA tracks .csv file is as expected when extracting + track IDs from the individuals' names. + """ + # Save VIA tracks .csv file + output_path = tmp_path / "test_valid_dataset.csv" + input_dataset = request.getfixturevalue(valid_dataset) + save_bboxes.to_via_tracks_file( + input_dataset, + output_path, + extract_track_id_from_individuals=extract_track_id_from_individuals, + ) + + # Check track ID in relation to individuals' names + df = pd.read_csv(output_path) + df["region_attributes"] = [ + json.loads(el) for el in df["region_attributes"] + ] + set_unique_track_ids = set( + [int(row["track"]) for row in df["region_attributes"]] + ) + + # Note: we check if the sets of IDs is as expected, regardless of the order + if extract_track_id_from_individuals: + assert set_unique_track_ids == set( + [ + int(indiv.split("_")[1]) + for indiv in input_dataset.individuals.values + ] + ) + else: + assert set_unique_track_ids == {0, 1} + + +@pytest.mark.parametrize( + "valid_dataset", + [ + "valid_bboxes_dataset", + "valid_bboxes_dataset_with_nan", + "valid_bboxes_dataset_with_late_id0", + ], +) +def test_to_via_tracks_file_region_count_and_id( + valid_dataset, tmp_path, request +): + """Test that the region count and region ID are as expected.""" + # Save VIA tracks .csv file + output_path = tmp_path / "test_valid_dataset.csv" + input_dataset = request.getfixturevalue(valid_dataset) + save_bboxes.to_via_tracks_file(input_dataset, output_path) + + # Read output file as a dataframe + df = pd.read_csv(output_path) + + # Check that the region count matches the number of annotations + # per filename + df_bboxes_count = df["filename"].value_counts(sort=False) + map_filename_to_bboxes_count = { + filename: count + for filename, count in zip( + df_bboxes_count.index, + df_bboxes_count, + strict=True, + ) + } + assert all( + df["region_count"].values + == [map_filename_to_bboxes_count[fn] for fn in df["filename"]] + ) + + # Check that the region ID per filename ranges from 0 to the + # number of annotations per filename + assert all( + np.all( + df["region_id"].values[df["filename"] == fn] + == np.array(range(map_filename_to_bboxes_count[fn])) + ) + for fn in df["filename"] + ) + + +@pytest.mark.parametrize( + "invalid_dataset, 
expected_exception", + [ + ("not_a_dataset", TypeError), + ("empty_dataset", ValueError), + ("missing_var_bboxes_dataset", ValueError), + ("missing_two_vars_bboxes_dataset", ValueError), + ("missing_dim_bboxes_dataset", ValueError), + ("missing_two_dims_bboxes_dataset", ValueError), + ], +) +def test_to_via_tracks_file_invalid_dataset( + invalid_dataset, expected_exception, request, tmp_path +): + """Test that an invalid dataset raises an error.""" + with pytest.raises(expected_exception): + save_bboxes.to_via_tracks_file( + request.getfixturevalue(invalid_dataset), + tmp_path / "test_invalid_dataset.csv", + ) + + +@pytest.mark.parametrize( + "wrong_extension", + [ + ".mp4", + "", + ], +) +def test_to_via_tracks_file_invalid_file_path( + valid_bboxes_dataset, tmp_path, wrong_extension +): + """Test that file with wrong extension raises an error.""" + with pytest.raises(ValueError): + save_bboxes.to_via_tracks_file( + valid_bboxes_dataset, + tmp_path / f"test{wrong_extension}", + ) + + +@pytest.mark.parametrize( + "frame_n_digits", + [1, 100], + ids=["1_digit", "100_digits"], +) +@pytest.mark.parametrize( + "image_file_prefix, expected_prefix", + [ + (None, ""), + ("", ""), + ("test_video", "test_video"), + ("test_video_", "test_video_"), + ], + ids=["no_prefix", "empty_prefix", "prefix", "prefix_underscore"], +) +@pytest.mark.parametrize( + "image_file_suffix, expected_suffix", + [ + (".png", ".png"), + ("png", ".png"), + (".jpg", ".jpg"), + ], + ids=["png_extension", "png_no_dot", "jpg_extension"], +) +def test_get_image_filename_template( + frame_n_digits, + image_file_prefix, + expected_prefix, + image_file_suffix, + expected_suffix, +): + """Test that the image filename template is as expected.""" + expected_image_filename = ( + f"{expected_prefix}{{:0{frame_n_digits}d}}{expected_suffix}" + ) + assert ( + save_bboxes._get_image_filename_template( + frame_n_digits=frame_n_digits, + image_file_prefix=image_file_prefix, + image_file_suffix=image_file_suffix, + ) + == expected_image_filename + ) + + +@pytest.mark.parametrize( + "valid_dataset_str,", + [ + ("valid_bboxes_dataset"), + ("valid_bboxes_dataset_in_seconds"), + ("valid_bboxes_dataset_min_frame_number_modified"), + ], + ids=["min_2_digits", "min_2_digits_in_seconds", "min_3_digits"], +) +@pytest.mark.parametrize( + "frame_n_digits", + [None, 7], + ids=["auto", "user"], +) +def test_get_min_required_digits_in_ds( + valid_dataset_str, + frame_n_digits, + request, +): + """Test that the number of digits to represent the frame number is + computed as expected. + """ + ds = request.getfixturevalue(valid_dataset_str) + min_required_digits = _get_min_required_digits_in_ds(ds) + + # Compute expected number of digits in output + if frame_n_digits is None: + expected_out_digits = min_required_digits + 1 + else: + expected_out_digits = frame_n_digits + + # Check the number of digits to use in the output is as expected + assert ( + save_bboxes._check_frame_required_digits( + ds=ds, frame_n_digits=frame_n_digits + ) + == expected_out_digits + ) + + +@pytest.mark.parametrize( + "valid_dataset_str, requested_n_digits", + [ + ("valid_bboxes_dataset", 0), + ("valid_bboxes_dataset_min_frame_number_modified", 2), + ], + ids=["min_2_digits", "min_3_digits"], +) +def test_get_min_required_digits_in_ds_error( + valid_dataset_str, requested_n_digits, request +): + """Test that an error is raised if the requested number of digits is + not enough to represent all the frame numbers. 
+ """ + ds = request.getfixturevalue(valid_dataset_str) + min_required_digits = _get_min_required_digits_in_ds(ds) + + with pytest.raises(ValueError) as error: + save_bboxes._check_frame_required_digits( + ds=ds, frame_n_digits=requested_n_digits + ) + + assert str(error.value) == ( + "The requested number of digits to represent the frame " + "number cannot be used to represent all the frame numbers." + f"Got {requested_n_digits}, but the maximum frame number has " + f"{min_required_digits} digits" + ) + + +@pytest.mark.parametrize( + "list_individuals, expected_track_id", + [ + (["id1", "id2", "id3"], [1, 2, 3]), + (["id1", "id3", "id2"], [1, 3, 2]), + (["id-1", "id-2", "id-3"], [1, 2, 3]), + (["id_1", "id_2", "id_3"], [1, 2, 3]), + (["id101", "id2", "id333"], [101, 2, 333]), + (["mouse_0_id1", "mouse_0_id2"], [1, 2]), + ], + ids=[ + "sorted", + "unsorted", + "dashes", + "underscores", + "multiple_digits", + "middle_and_end_digits", + ], +) +def test_get_map_individuals_to_track_ids_from_individuals_names( + list_individuals, expected_track_id +): + """Test the mapping individuals to track IDs if the track ID is + extracted from the individuals' names. + """ + # Map individuals to track IDs + map_individual_to_track_id = _get_map_individuals_to_track_ids( + list_individuals, extract_track_id_from_individuals=True + ) + + # Check values are as expected + assert [ + map_individual_to_track_id[individual] + for individual in list_individuals + ] == expected_track_id + + +@pytest.mark.parametrize( + "list_individuals, expected_track_id", + [ + (["A", "B", "C"], [0, 1, 2]), + (["C", "B", "A"], [2, 1, 0]), + (["id99", "id88", "id77"], [2, 1, 0]), + ], + ids=["sorted", "unsorted", "should_ignore_digits"], +) +def test_get_map_individuals_to_track_ids_factorised( + list_individuals, expected_track_id +): + """Test the mapping individuals to track IDs if the track ID is + factorised from the sorted individuals' names. + """ + # Map individuals to track IDs + map_individual_to_track_id = _get_map_individuals_to_track_ids( + list_individuals, extract_track_id_from_individuals=False + ) + + # Check values are as expected + assert [ + map_individual_to_track_id[individual] + for individual in list_individuals + ] == expected_track_id + + +@pytest.mark.parametrize( + "list_individuals, expected_error_message", + [ + ( + ["mouse_1_id0", "mouse_2_id0"], + ( + "Could not extract a unique track ID for all individuals. " + "Expected 2 unique track IDs, but got 1." + ), + ), + ( + ["mouse_id1.0", "mouse_id2.0"], + ( + "Could not extract a unique track ID for all individuals. " + "Expected 2 unique track IDs, but got 1." + ), + ), + (["A_1", "B_2", "C"], "Could not extract track ID from C."), + ], + ids=["id_clash_1", "id_clash_2", "individuals_without_digits"], +) +def test_get_map_individuals_to_track_ids_error( + list_individuals, expected_error_message +): + """Test that the appropriate error is raised if extracting track IDs + from the individuals' names fails. 
+ """ + with pytest.raises(ValueError) as error: + _get_map_individuals_to_track_ids( + list_individuals, + extract_track_id_from_individuals=True, + ) + + # Check that the error message is as expected + assert str(error.value) == expected_error_message + + +@pytest.mark.parametrize( + "confidence", + [None, 0.5], + ids=["without_confidence", "with_confidence"], +) +@pytest.mark.parametrize( + "image_size", + [None, 100], + ids=["without_image_size", "with_image_size"], +) +@pytest.mark.parametrize( + "img_filename_template", + ["{:05d}.png", "{:03d}.jpg", "frame_{:03d}.jpg"], + ids=["png_extension", "jpg_extension", "frame_prefix"], +) +def test_write_single_row( + mock_csv_writer, + confidence, + image_size, + img_filename_template, +): + """Test writing a single row of the VIA tracks .csv file.""" + # Fixed input values + frame, track_id, region_count, region_id, xy_values, wh_values = ( + 1, + 0, + 88, + 0, + np.array([100, 200]), + np.array([50, 30]), + ) + + # Write single row of VIA tracks .csv file + with patch("csv.writer", return_value=mock_csv_writer): + row = _write_single_row( + writer=mock_csv_writer, + xy_values=xy_values, + wh_values=wh_values, + confidence=confidence, + track_id=track_id, + region_count=region_count, + region_id=region_id, + img_filename=img_filename_template.format(frame), + image_size=image_size, + ) + mock_csv_writer.writerow.assert_called_with(row) + + # Compute expected region shape attributes + expected_region_shape_attrs_dict = { + "name": "rect", + "x": float(xy_values[0] - wh_values[0] / 2), + "y": float(xy_values[1] - wh_values[1] / 2), + "width": float(wh_values[0]), + "height": float(wh_values[1]), + } + expected_region_shape_attributes = json.dumps( + expected_region_shape_attrs_dict + ) + + # Compute expected region attributes + expected_region_attributes_dict = { + "track": int(track_id), + } + if confidence is not None: + expected_region_attributes_dict["confidence"] = confidence + + expected_region_attributes = json.dumps(expected_region_attributes_dict) + + # Check values are as expected + assert row[0] == img_filename_template.format(frame) + assert row[1] == (image_size if image_size is not None else 0) + assert row[2] == '{"shot": 0}' # placeholder value + assert row[3] == region_count + assert row[4] == region_id + assert row[5] == expected_region_shape_attributes + assert row[6] == expected_region_attributes + + +def test_number_of_quotes_in_via_tracks_csv_file( + valid_bboxes_dataset, tmp_path +): + """Test the literal string for two lines of the VIA tracks .csv file. + + This is to verify that the quotes in the output VIA tracks .csv file are + as expected. Without the required double quotes, the file won't be + importable in the VIA annotation tool. 
+ + The VIA tracks .csv file format has: + - dictionary-like items wrapped in single double-quotes (") + - keys in these dictionaries wrapped in double double-quotes ("") + + See an example of the VIA tracks .csv file format at: + https://www.robots.ox.ac.uk/~vgg/software/via/docs/face_track_annotation.html + """ + # Save VIA tracks .csv file + output_path = tmp_path / "test_valid_dataset.csv" + save_bboxes.to_via_tracks_file(valid_bboxes_dataset, output_path) + + # Read text file + with open(output_path) as file: + lines = file.readlines() + + # Check a line with bbox id_0 + assert lines[1] == ( + "00.png," # filename + "0," # filesize + '"{""shot"": 0}",' # file attributes + "2," # region_count + "0," # region_id + '"{""name"": ""rect"", ' # region shape attributes + '""x"": -30.0, ""y"": -20.0, ""width"": 60.0, ""height"": 40.0}",' + '"{""track"": 0, ""confidence"": 0.9}"\n' # region attributes + ) + + # Check a line with bbox id_1 + assert lines[-1] == ( + "09.png," # filename + "0," # filesize + '"{""shot"": 0}",' # file attributes + "2," # region_count + "1," # region_id + '"{""name"": ""rect"", ' # region shape attributes + '""x"": -21.0, ""y"": -29.0, ""width"": 60.0, ""height"": 40.0}",' + '"{""track"": 1, ""confidence"": 0.9}"\n' # region attributes + ) + + +@pytest.mark.parametrize( + "via_file_path", + [ + pytest.DATA_PATHS.get("VIA_multiple-crabs_5-frames_labels.csv"), + pytest.DATA_PATHS.get("VIA_single-crab_MOCA-crab-1.csv"), + ], +) +def test_to_via_tracks_file_is_recoverable(via_file_path, tmp_path): + """Test that an exported VIA tracks .csv file can be loaded back into + a dataset that matches the original one. + """ + # Load a bboxes dataset from a VIA tracks .csv file + original_ds = load_bboxes.from_via_tracks_file( + via_file_path, use_frame_numbers_from_file=True + ) + + # Export the dataset + output_path = tmp_path / "test_via_file.csv" + save_bboxes.to_via_tracks_file( + original_ds, + output_path, + extract_track_id_from_individuals=True, + ) + + # Load the exported file + recovered_ds = load_bboxes.from_via_tracks_file( + output_path, use_frame_numbers_from_file=True + ) + + # Compare the original and recovered datasets + xr.testing.assert_equal(original_ds, recovered_ds) diff --git a/tests/test_unit/test_save_poses.py b/tests/test_unit/test_io/test_save_poses.py similarity index 100% rename from tests/test_unit/test_save_poses.py rename to tests/test_unit/test_io/test_save_poses.py diff --git a/tests/test_unit/test_io/test_utils.py b/tests/test_unit/test_io/test_utils.py new file mode 100644 index 000000000..98406b017 --- /dev/null +++ b/tests/test_unit/test_io/test_utils.py @@ -0,0 +1,117 @@ +"""Unit tests for the movement.io.utils module.""" + +import stat +from pathlib import Path + +import pytest + +from movement.io.utils import _validate_file_path +from movement.validators.files import ValidFile + + +@pytest.fixture +def sample_file_path(): + """Return a factory of file paths with a given file extension suffix.""" + + def _sample_file_path(tmp_path: Path, suffix: str): + """Return a path for a file under the pytest temporary directory + with the given file extension suffix.
+ """ + file_path = tmp_path / f"test.{suffix}" + return file_path + + return _sample_file_path + + +@pytest.mark.parametrize("suffix", [".txt", ".csv"]) +def test_validate_file_path_valid_file(sample_file_path, tmp_path, suffix): + """Test file path validation with a correct file.""" + file_path = sample_file_path(tmp_path, suffix) + validated_file = _validate_file_path(file_path, [suffix]) + + assert isinstance(validated_file, ValidFile) + assert validated_file.path == file_path + + +@pytest.mark.parametrize("suffix", [".txt", ".csv"]) +def test_validate_file_path_invalid_permission( + sample_file_path, tmp_path, suffix +): + """Test file path validation with a file that has invalid permissions. + + We use the following permissions: + - S_IRUSR: Read permission for owner + - S_IRGRP: Read permission for group + - S_IROTH: Read permission for others + """ + # Create a sample file with read-only permission + file_path = sample_file_path(tmp_path, suffix) + file_path.touch() + file_path.chmod(stat.S_IRUSR | stat.S_IRGRP | stat.S_IROTH) + + # Try to validate the file path + # (should raise an OSError since we require write permissions) + with pytest.raises(OSError): + _validate_file_path(file_path, [suffix]) + + +@pytest.mark.parametrize("suffix", [".txt", ".csv"]) +def test_validate_file_path_file_exists(sample_file_path, tmp_path, suffix): + """Test file path validation with a file that already exists. + + We use the following permissions to create a file with the right + permissions: + - S_IRUSR: Read permission for owner + - S_IWUSR: Write permission for owner + - S_IRGRP: Read permission for group + - S_IWGRP: Write permission for group + - S_IROTH: Read permission for others + - S_IWOTH: Write permission for others + + We include both read and write permissions because in real-world + scenarios it's very rare to have a file that is writable but not readable. + """ + # Create a sample file with read and write permissions + file_path = sample_file_path(tmp_path, suffix) + file_path.touch() + file_path.chmod( + stat.S_IRUSR + | stat.S_IWUSR + | stat.S_IRGRP + | stat.S_IWGRP + | stat.S_IROTH + | stat.S_IWOTH + ) + + # Try to validate the file path + # (should raise an OSError since the file already exists) + with pytest.raises(OSError): + _validate_file_path(file_path, [suffix]) + + +@pytest.mark.parametrize("invalid_suffix", [".foo", "", None]) +def test_validate_file_path_invalid_suffix( + sample_file_path, tmp_path, invalid_suffix +): + """Test file path validation with an invalid file suffix.""" + # Create a file path with an invalid suffix + file_path = sample_file_path(tmp_path, invalid_suffix) + + # Try to validate using a .txt suffix + with pytest.raises(ValueError): + _validate_file_path(file_path, [".txt"]) + + +@pytest.mark.parametrize("suffix", [".txt", ".csv"]) +def test_validate_file_path_multiple_suffixes( + sample_file_path, tmp_path, suffix +): + """Test file path validation with multiple valid suffixes.""" + # Create a valid txt file path + file_path = sample_file_path(tmp_path, suffix) + + # Validate using multiple valid suffixes + validated_file = _validate_file_path(file_path, [".txt", ".csv"]) + + assert isinstance(validated_file, ValidFile) + assert validated_file.path == file_path