|
7 | 7 | import cv2
|
8 | 8 | import numpy as np
|
9 | 9 | from tqdm import tqdm
|
10 |
| -from typing import List, Optional, Union |
| 10 | +from typing import List, Optional, Union, Tuple, Any |
11 | 11 | import shapely
|
12 | 12 | import pyproj
|
13 | 13 | import rasterio
|
@@ -3520,3 +3520,168 @@ def geotiff_to_jpg_batch(input_folder: str, output_folder: str = None) -> str:
|
3520 | 3520 | geotiff_to_jpg(geotiff_path, output_path)
|
3521 | 3521 |
|
3522 | 3522 | return output_folder
|
| 3523 | + |
| 3524 | + |
def region_groups(
    image: Union[str, "xr.DataArray", np.ndarray],
    connectivity: int = 1,
    min_size: int = 10,
    max_size: Optional[int] = None,
    threshold: Optional[int] = None,
    properties: Optional[List[str]] = None,
    out_csv: Optional[str] = None,
    out_vector: Optional[str] = None,
    out_image: Optional[str] = None,
    **kwargs: Any,
) -> Union[Tuple[np.ndarray, "pd.DataFrame"], Tuple["xr.DataArray", "pd.DataFrame"]]:
    """
    Segment regions in an image and filter them based on size.

    The image is labeled into connected components, components outside the
    ``[min_size, max_size]`` range are reset to zero, zero-valued patches
    smaller than ``threshold`` are filled with the most common neighboring
    label, and the result is relabeled and measured.

    Args:
        image (Union[str, xr.DataArray, np.ndarray]): Input image, can be a file
            path, xarray DataArray, or numpy array. For a file path, band 1 is
            used. The hole-filling step indexes pixels as (row, col), so the
            squeezed image is expected to be 2D.
        connectivity (int, optional): Connectivity for labeling. Defaults to 1
            for 4-connectivity. Use 2 for 8-connectivity.
        min_size (int, optional): Minimum size of regions to keep. Defaults to 10.
        max_size (Optional[int], optional): Maximum size of regions to keep.
            Defaults to None.
        threshold (Optional[int], optional): Threshold for filling holes.
            Zero-valued connected patches with an area below this value are
            filled. Defaults to None, which is equal to min_size.
        properties (Optional[List[str]], optional): List of properties to measure.
            See https://scikit-image.org/docs/stable/api/skimage.measure.html#skimage.measure.regionprops
            Defaults to None, which measures a standard set of shape
            properties. "label" is always included because it is the join key
            for the vector output.
        out_csv (Optional[str], optional): Path to save the properties as a CSV file.
            Defaults to None.
        out_vector (Optional[str], optional): Path to save the vector file.
            Ignored when the input is a numpy array. Defaults to None.
        out_image (Optional[str], optional): Path to save the output image.
            Ignored when the input is a numpy array. Defaults to None.
        **kwargs (Any): Accepted for compatibility; not used by this function.

    Returns:
        Union[Tuple[np.ndarray, pd.DataFrame], Tuple[xr.DataArray, pd.DataFrame]]: Labeled image and properties DataFrame.
            When ``out_vector`` is given, the second element is the vectorized
            regions merged with the properties, sorted by label. An
            "elongation" column (major / minor axis length) is added whenever
            both axis-length properties were measured.

    Raises:
        ValueError: If ``image`` is not a file path, xarray DataArray, or
            numpy array.
    """
    # rioxarray is imported even for DataArray input: the import registers the
    # ``.rio`` accessor used further down to write the raster.
    import rioxarray as rxr
    import xarray as xr
    from skimage import measure
    import pandas as pd
    import scipy.ndimage as ndi

    if isinstance(image, str):
        ds = rxr.open_rasterio(image)
        da = ds.sel(band=1)
        array = da.values.squeeze()
    elif isinstance(image, xr.DataArray):
        da = image
        array = image.values.squeeze()
    elif isinstance(image, np.ndarray):
        array = image
    else:
        raise ValueError(
            "The input image must be a file path, xarray DataArray, or numpy array."
        )

    if threshold is None:
        threshold = min_size

    if properties is None:
        properties = [
            "label",
            "area",
            "area_bbox",
            "area_convex",
            "area_filled",
            "axis_major_length",
            "axis_minor_length",
            "eccentricity",
            "equivalent_diameter_area",
            "extent",
            "orientation",
            "perimeter",
            "solidity",
        ]
    else:
        # Work on a copy so the caller's list is never modified, and make sure
        # the join key needed by the vector merge is always measured.
        properties = list(properties)
        if "label" not in properties:
            properties.insert(0, "label")

    label_image = measure.label(array, connectivity=connectivity)

    # Only label and area are needed to decide which regions to drop, so the
    # (potentially expensive) user-requested properties are measured once, on
    # the final labeling, instead of on regions that are about to be removed.
    sizes = pd.DataFrame(
        measure.regionprops_table(label_image, properties=["label", "area"])
    )

    # Reject regions smaller than min_size and, if requested, larger than max_size
    rejected = sizes["area"] < min_size
    if max_size is not None:
        rejected |= sizes["area"] > max_size

    # Reset all rejected labels to zero in a single pass over the image
    rejected_labels = sizes.loc[rejected, "label"].to_numpy()
    label_image[np.isin(label_image, rejected_labels)] = 0

    # Find the background (holes) which are zeros
    holes = label_image == 0

    # Label the holes (connected components in the background)
    labeled_holes, _ = ndi.label(holes)

    # Loop through each hole and fill it if it is smaller than the threshold
    for hole in measure.regionprops(labeled_holes):
        if hole.area >= threshold:
            continue

        # Coordinates of the pixels belonging to the small hole
        coords = hole.coords

        # Collect one neighboring region ID per hole pixel (where available)
        neighbor_labels = []
        for row, col in coords:
            # 3x3 neighborhood around the hole pixel, clipped at the top/left
            # edge (slicing already clips at the bottom/right edge)
            window = label_image[max(0, row - 1) : row + 2, max(0, col - 1) : col + 2]
            # Exclude zeros (hole pixels) and keep the first region value
            nonzero = window[window != 0]
            if nonzero.size > 0:
                neighbor_labels.append(nonzero[0])

        if neighbor_labels:
            # Fill the hole with the most frequent neighboring label; ties
            # resolve deterministically to the smallest label.
            values, counts = np.unique(neighbor_labels, return_counts=True)
            label_image[coords[:, 0], coords[:, 1]] = values[np.argmax(counts)]

    # Relabel so that the surviving regions are numbered consecutively
    label_image, num_labels = measure.label(
        label_image, connectivity=connectivity, return_num=True
    )
    df = pd.DataFrame(measure.regionprops_table(label_image, properties=properties))

    # Elongation can only be derived when both axis lengths were measured
    if {"axis_major_length", "axis_minor_length"}.issubset(df.columns):
        df["elongation"] = df["axis_major_length"] / df["axis_minor_length"]

    # Smallest unsigned integer type able to hold every label
    if num_labels <= 255:
        dtype = "uint8"
    elif num_labels <= 65535:
        dtype = "uint16"
    else:
        dtype = "uint32"

    if out_csv is not None:
        df.to_csv(out_csv, index=False)

    if isinstance(image, np.ndarray):
        return label_image, df

    # The labels were computed on the squeezed array; restore any singleton
    # dimensions so the assignment matches the DataArray's shape.
    da.values = label_image.reshape(da.shape)

    # Vectorization reads the labeled raster from disk, so a raster is needed
    # even when the caller did not ask for one.
    # NOTE(review): assumes temp_file_path accepts any extension - confirm.
    raster_path = out_image
    if raster_path is None and out_vector is not None:
        raster_path = temp_file_path(".tif")

    if raster_path is not None:
        da.rio.to_raster(raster_path, dtype=dtype)

    if out_vector is not None:
        tmp_vector = temp_file_path(".gpkg")
        raster_to_vector(raster_path, tmp_vector)
        gdf = gpd.read_file(tmp_vector)
        gdf["label"] = gdf["value"].astype(int)
        gdf.drop(columns=["value"], inplace=True)
        gdf2 = pd.merge(gdf, df, on="label", how="left")
        # Sort before writing so the saved file and the returned frame agree
        gdf2.sort_values("label", inplace=True)
        gdf2.to_file(out_vector)
        df = gdf2

    return da, df
0 commit comments