Skip to content

Commit d8e2a30

Browse files
authored
Added clip to rgb_image_from_tensor
1 parent ca08e7e commit d8e2a30

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

pytorch_toolbelt/utils/torch_utils.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ def tensor_from_mask_image(mask: np.ndarray) -> torch.Tensor:
184184

185185

186186
def rgb_image_from_tensor(
187-
image: torch.Tensor, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), max_pixel_value=255.0, dtype=np.uint8
187+
image: torch.Tensor, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), min_pixel_value=0.0, max_pixel_value=255.0, dtype=np.uint8
188188
) -> np.ndarray:
189189
"""
190190
Convert numpy image (RGB, BGR, Grayscale, SAR, Mask image, etc.) to tensor
@@ -195,8 +195,9 @@ def rgb_image_from_tensor(
195195
image = np.moveaxis(to_numpy(image), 0, -1)
196196
mean = to_numpy(mean)
197197
std = to_numpy(std)
198-
rgb_image = (max_pixel_value * (image * std + mean)).astype(dtype)
199-
return rgb_image
198+
rgb_image = (max_pixel_value * (image * std + mean))
199+
rgb_image = np.clip(rgb_image, a_min=min_pixel_value, a_max=max_pixel_value)
200+
return rgb_image.astype(dtype)
200201

201202

202203
def mask_from_tensor(mask: torch.Tensor, squeeze_single_channel=False, dtype=None) -> np.ndarray:

0 commit comments

Comments
 (0)