Skip to content

Commit f375e92

Browse files
authored
Add region_groups function (#334)
1 parent 2b4c015 commit f375e92

File tree

3 files changed

+223
-1
lines changed

3 files changed

+223
-1
lines changed

requirements.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,10 @@ patool
1111
pycocotools
1212
pyproj
1313
rasterio
14+
rioxarray
1415
sam2
16+
scikit-image
17+
scikit-learn
1518
segment-anything-hq
1619
segment-anything-py
1720
timm

samgeo/common.py

Lines changed: 166 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import cv2
88
import numpy as np
99
from tqdm import tqdm
10-
from typing import List, Optional, Union
10+
from typing import List, Optional, Union, Tuple, Any
1111
import shapely
1212
import pyproj
1313
import rasterio
@@ -3520,3 +3520,168 @@ def geotiff_to_jpg_batch(input_folder: str, output_folder: str = None) -> str:
35203520
geotiff_to_jpg(geotiff_path, output_path)
35213521

35223522
return output_folder
3523+
3524+
3525+
def region_groups(
    image: Union[str, "xr.DataArray", np.ndarray],
    connectivity: int = 1,
    min_size: int = 10,
    max_size: Optional[int] = None,
    threshold: Optional[int] = None,
    properties: Optional[List[str]] = None,
    out_csv: Optional[str] = None,
    out_vector: Optional[str] = None,
    out_image: Optional[str] = None,
    **kwargs: Any,
) -> Union[Tuple[np.ndarray, "pd.DataFrame"], Tuple["xr.DataArray", "pd.DataFrame"]]:
    """
    Segment regions in an image and filter them based on size.

    The image is labeled into connected regions, regions smaller than
    ``min_size`` (or larger than ``max_size``) are removed, background holes
    smaller than ``threshold`` are filled with the most common neighboring
    label, and the result is relabeled and measured.

    Args:
        image (Union[str, xr.DataArray, np.ndarray]): Input image, can be a file
            path, xarray DataArray, or numpy array. For a file path only band 1
            is used. A DataArray input is modified in place (its values are
            replaced by the labels) and returned.
        connectivity (int, optional): Connectivity for labeling. Defaults to 1
            for 4-connectivity. Use 2 for 8-connectivity.
        min_size (int, optional): Minimum size of regions to keep. Defaults to 10.
        max_size (Optional[int], optional): Maximum size of regions to keep.
            Defaults to None.
        threshold (Optional[int], optional): Threshold for filling holes.
            Defaults to None, which is equal to min_size.
        properties (Optional[List[str]], optional): List of properties to measure.
            See https://scikit-image.org/docs/stable/api/skimage.measure.html#skimage.measure.regionprops
            Defaults to None, which measures a standard set of shape properties.
            "label" and "area" are always measured because the size filter
            and the vector join need them.
        out_csv (Optional[str], optional): Path to save the properties as a CSV file.
            Defaults to None.
        out_vector (Optional[str], optional): Path to save the vector file.
            Defaults to None. Ignored when the input is a numpy array.
        out_image (Optional[str], optional): Path to save the output image.
            Defaults to None. Ignored when the input is a numpy array.
        **kwargs (Any): Accepted but currently unused.

    Returns:
        Union[Tuple[np.ndarray, pd.DataFrame], Tuple[xr.DataArray, pd.DataFrame]]:
            Labeled image and properties DataFrame. When ``out_vector`` is
            given, the second element is the merged GeoDataFrame instead.

    Raises:
        ValueError: If the input is not a file path, DataArray, or numpy array.
    """
    import rioxarray as rxr
    import xarray as xr
    from skimage import measure
    import pandas as pd
    import scipy.ndimage as ndi

    if isinstance(image, str):
        ds = rxr.open_rasterio(image)
        da = ds.sel(band=1)
        array = da.values.squeeze()
    elif isinstance(image, xr.DataArray):
        da = image
        array = image.values.squeeze()
    elif isinstance(image, np.ndarray):
        da = None
        array = image
    else:
        raise ValueError(
            "The input image must be a file path, xarray DataArray, or numpy array."
        )

    if threshold is None:
        threshold = min_size

    if properties is None:
        properties = [
            "label",
            "area",
            "area_bbox",
            "area_convex",
            "area_filled",
            "axis_major_length",
            "axis_minor_length",
            "eccentricity",
            "equivalent_diameter_area",
            "extent",
            "orientation",
            "perimeter",
            "solidity",
        ]
    else:
        # Copy so the caller's list is never modified, then make sure the
        # columns the size filter and the vector join rely on are measured.
        properties = list(properties)
        for required in ("label", "area"):
            if required not in properties:
                properties.append(required)

    label_image = measure.label(array, connectivity=connectivity)
    df = pd.DataFrame(measure.regionprops_table(label_image, properties=properties))

    # Collect the labels of every region outside the accepted size range and
    # clear them in a single pass over the image.
    rejected = [df.loc[df["area"] < min_size, "label"].to_numpy()]
    if max_size is not None:
        rejected.append(df.loc[df["area"] > max_size, "label"].to_numpy())
    label_image[np.isin(label_image, np.concatenate(rejected))] = 0

    # Background pixels (zeros) form the candidate holes; label them as
    # connected components so each hole can be measured on its own.
    labeled_holes, _ = ndi.label(label_image == 0)

    for hole in measure.regionprops(labeled_holes):
        if hole.area >= threshold:
            continue

        coords = hole.coords

        # For every hole pixel, record the first non-zero label found in its
        # 3x3 neighborhood (clipped at the image border).
        neighbor_labels = []
        for row, col in coords:
            window = label_image[max(0, row - 1) : row + 2, max(0, col - 1) : col + 2]
            nonzero = window[window != 0]
            if nonzero.size > 0:
                neighbor_labels.append(nonzero[0])

        if neighbor_labels:
            # Fill the hole with the most frequent neighboring label; ties go
            # to the smallest label because np.unique returns sorted values.
            values, counts = np.unique(neighbor_labels, return_counts=True)
            label_image[coords[:, 0], coords[:, 1]] = values[np.argmax(counts)]

    # Relabel so the surviving regions get consecutive ids, then re-measure.
    label_image, num_labels = measure.label(
        label_image, connectivity=connectivity, return_num=True
    )
    df = pd.DataFrame(measure.regionprops_table(label_image, properties=properties))

    # Elongation needs both axis lengths, which a custom property list may omit.
    if "axis_major_length" in df.columns and "axis_minor_length" in df.columns:
        df["elongation"] = df["axis_major_length"] / df["axis_minor_length"]

    # Pick the smallest unsigned integer type able to hold every label.
    dtype = "uint8"
    if num_labels > 255 and num_labels <= 65535:
        dtype = "uint16"
    elif num_labels > 65535:
        dtype = "uint32"

    if out_csv is not None:
        df.to_csv(out_csv, index=False)

    if da is None:
        return label_image, df

    # The labels were computed on the squeezed array; restore any singleton
    # dimensions so the assignment matches the DataArray's shape.
    da.values = label_image.reshape(da.shape)

    # Vectorization reads the labels from a raster on disk, so fall back to a
    # temporary raster when only a vector output was requested.
    raster_path = out_image
    if raster_path is None and out_vector is not None:
        raster_path = temp_file_path(".tif")
    if raster_path is not None:
        da.rio.to_raster(raster_path, dtype=dtype)

    if out_vector is not None:
        tmp_vector = temp_file_path(".gpkg")
        raster_to_vector(raster_path, tmp_vector)
        gdf = gpd.read_file(tmp_vector)
        gdf["label"] = gdf["value"].astype(int)
        gdf = gdf.drop(columns=["value"])
        # Sort before writing so the file and the returned table agree.
        merged = pd.merge(gdf, df, on="label", how="left").sort_values("label")
        merged.to_file(out_vector)
        df = merged

    return da, df

samgeo/samgeo2.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1514,3 +1514,57 @@ def raster_to_vector(self, raster, vector, simplify_tolerance=None, **kwargs):
15141514
common.raster_to_vector(
15151515
raster, vector, simplify_tolerance=simplify_tolerance, **kwargs
15161516
)
1517+
1518+
def region_groups(
    self,
    image: Union[str, "xr.DataArray", np.ndarray],
    connectivity: int = 1,
    min_size: int = 10,
    max_size: Optional[int] = None,
    threshold: Optional[int] = None,
    properties: Optional[List[str]] = None,
    out_csv: Optional[str] = None,
    out_vector: Optional[str] = None,
    out_image: Optional[str] = None,
    **kwargs: Any,
) -> Union[
    Tuple[np.ndarray, "pd.DataFrame"], Tuple["xr.DataArray", "pd.DataFrame"]
]:
    """
    Label the regions of an image and filter them by size.

    This method forwards every argument unchanged to
    ``common.region_groups`` and returns its result.

    Args:
        image (Union[str, xr.DataArray, np.ndarray]): Image to segment, given
            as a file path, an xarray DataArray, or a numpy array.
        connectivity (int, optional): Connectivity used for labeling; 1 means
            4-connectivity, 2 means 8-connectivity. Defaults to 1.
        min_size (int, optional): Smallest region size to keep. Defaults to 10.
        max_size (Optional[int], optional): Largest region size to keep.
            Defaults to None.
        threshold (Optional[int], optional): Size threshold for filling holes.
            Defaults to None, which is equal to min_size.
        properties (Optional[List[str]], optional): Region properties to measure.
            See https://scikit-image.org/docs/stable/api/skimage.measure.html#skimage.measure.regionprops
            Defaults to None.
        out_csv (Optional[str], optional): CSV path for the measured properties.
            Defaults to None.
        out_vector (Optional[str], optional): Path of the vector file to write.
            Defaults to None.
        out_image (Optional[str], optional): Path of the labeled image to write.
            Defaults to None.
        **kwargs (Any): Extra keyword arguments passed through as-is.

    Returns:
        Union[Tuple[np.ndarray, pd.DataFrame], Tuple[xr.DataArray, pd.DataFrame]]: Labeled image and properties DataFrame.
    """
    # Gather the named options once so the delegation stays a single call.
    options = {
        "connectivity": connectivity,
        "min_size": min_size,
        "max_size": max_size,
        "threshold": threshold,
        "properties": properties,
        "out_csv": out_csv,
        "out_vector": out_vector,
        "out_image": out_image,
    }
    return common.region_groups(image, **options, **kwargs)

0 commit comments

Comments
 (0)