From 5fc1320d719ab6d8f9e42b4f17b8389e0bdb49bc Mon Sep 17 00:00:00 2001 From: Qiusheng Wu Date: Thu, 17 Apr 2025 10:21:37 -0400 Subject: [PATCH 1/2] Add crs and transform for array input --- samgeo/common.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/samgeo/common.py b/samgeo/common.py index 41965dc0..0ebc9e02 100644 --- a/samgeo/common.py +++ b/samgeo/common.py @@ -1458,12 +1458,16 @@ def array_to_image( array = cv2.imread(array) array = cv2.cvtColor(array, cv2.COLOR_BGR2RGB) - if output.endswith(".tif") and source is not None: - with rasterio.open(source) as src: - crs = src.crs - transform = src.transform - if compress is None: - compress = src.compression + if output.endswith(".tif"): + if source is not None: + with rasterio.open(source) as src: + crs = src.crs + transform = src.transform + if compress is None: + compress = src.compression + else: + crs = kwargs.get("crs", None) + transform = kwargs.get("transform", None) # Determine the minimum and maximum values in the array From 5eb0a3f71af36a2aaa1a568280bde81c745c65c7 Mon Sep 17 00:00:00 2001 From: Qiusheng Wu Date: Thu, 17 Apr 2025 10:37:20 -0400 Subject: [PATCH 2/2] Update show_anns --- samgeo/samgeo.py | 10 +++++----- samgeo/samgeo2.py | 10 +++++----- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/samgeo/samgeo.py b/samgeo/samgeo.py index 9d0d7e01..34f0d92e 100644 --- a/samgeo/samgeo.py +++ b/samgeo/samgeo.py @@ -425,11 +425,11 @@ def show_anns( img[m] = color_mask ax.imshow(img) - if "dpi" not in kwargs: - kwargs["dpi"] = 100 + # if "dpi" not in kwargs: + # kwargs["dpi"] = 100 - if "bbox_inches" not in kwargs: - kwargs["bbox_inches"] = "tight" + # if "bbox_inches" not in kwargs: + # kwargs["bbox_inches"] = "tight" plt.axis(axis) @@ -442,7 +442,7 @@ def show_anns( ) else: array = self.annotations - array_to_image(array, output, self.source) + array_to_image(array, output, self.source, **kwargs) def set_image(self, image, image_format="RGB"): """Set the input image as a numpy array. diff --git a/samgeo/samgeo2.py b/samgeo/samgeo2.py index a0bd88ca..e6004af5 100644 --- a/samgeo/samgeo2.py +++ b/samgeo/samgeo2.py @@ -456,11 +456,11 @@ def show_anns( img[m] = color_mask ax.imshow(img) - if "dpi" not in kwargs: - kwargs["dpi"] = 100 + # if "dpi" not in kwargs: + # kwargs["dpi"] = 100 - if "bbox_inches" not in kwargs: - kwargs["bbox_inches"] = "tight" + # if "bbox_inches" not in kwargs: + # kwargs["bbox_inches"] = "tight" plt.axis(axis) @@ -473,7 +473,7 @@ def show_anns( ) else: array = self.annotations - common.array_to_image(array, output, self.source) + common.array_to_image(array, output, self.source, **kwargs) @torch.no_grad() def set_image(