Skip to content

Commit 02d92fa

Browse files
authored
Allow dicom files to work with process_image (#171)
* allow dicom files to work with process_image
1 parent 7587520 commit 02d92fa

File tree

2 files changed

+85
-57
lines changed

2 files changed

+85
-57
lines changed

scripts/process_image.py

Lines changed: 18 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
#!/usr/bin/env python
22
# coding: utf-8
33

4-
import os,sys
5-
sys.path.insert(0,"..")
4+
import os, sys
5+
6+
sys.path.insert(0, "..")
67
from glob import glob
78
import matplotlib.pyplot as plt
89
import numpy as np
@@ -17,33 +18,22 @@
1718
import torchxrayvision as xrv
1819

1920
parser = argparse.ArgumentParser()
20-
parser.add_argument('-f', type=str, default="", help='')
21-
parser.add_argument('img_path', type=str)
22-
parser.add_argument('-weights', type=str,default="densenet121-res224-all")
23-
parser.add_argument('-feats', default=False, help='', action='store_true')
24-
parser.add_argument('-cuda', default=False, help='', action='store_true')
25-
parser.add_argument('-resize', default=False, help='', action='store_true')
21+
parser.add_argument("-f", type=str, default="", help="")
22+
parser.add_argument("img_path", type=str)
23+
parser.add_argument("-weights", type=str, default="densenet121-res224-all")
24+
parser.add_argument("-feats", default=False, help="", action="store_true")
25+
parser.add_argument("-cuda", default=False, help="", action="store_true")
26+
parser.add_argument("-resize", default=False, help="", action="store_true")
2627

2728
cfg = parser.parse_args()
2829

29-
30-
img = skimage.io.imread(cfg.img_path)
31-
img = xrv.datasets.normalize(img, 255)
32-
33-
# Check that images are 2D arrays
34-
if len(img.shape) > 2:
35-
img = img[:, :, 0]
36-
if len(img.shape) < 2:
37-
print("error, dimension lower than 2 for image")
38-
39-
# Add color channel
40-
img = img[None, :, :]
41-
30+
img = xrv.utils.load_image(cfg.img_path)
4231

4332
# the models will resize the input to the correct size so this is optional.
4433
if cfg.resize:
45-
transform = torchvision.transforms.Compose([xrv.datasets.XRayCenterCrop(),
46-
xrv.datasets.XRayResizer(224)])
34+
transform = torchvision.transforms.Compose(
35+
[xrv.datasets.XRayCenterCrop(), xrv.datasets.XRayResizer(224)]
36+
)
4737
else:
4838
transform = torchvision.transforms.Compose([xrv.datasets.XRayCenterCrop()])
4939

@@ -58,19 +48,19 @@
5848
if cfg.cuda:
5949
img = img.cuda()
6050
model = model.cuda()
61-
51+
6252
if cfg.feats:
6353
feats = model.features(img)
6454
feats = F.relu(feats, inplace=True)
6555
feats = F.adaptive_avg_pool2d(feats, (1, 1))
6656
output["feats"] = list(feats.cpu().detach().numpy().reshape(-1))
6757

6858
preds = model(img).cpu()
69-
output["preds"] = dict(zip(xrv.datasets.default_pathologies,preds[0].detach().numpy()))
70-
59+
output["preds"] = dict(
60+
zip(xrv.datasets.default_pathologies, preds[0].detach().numpy())
61+
)
62+
7163
if cfg.feats:
7264
print(output)
7365
else:
7466
pprint.pprint(output)
75-
76-

torchxrayvision/utils.py

Lines changed: 67 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,18 @@
88
from os import PathLike
99
from numpy import ndarray
1010
import warnings
11-
from tqdm.autonotebook import tqdm
11+
from tqdm.auto import tqdm
1212

1313

1414
def get_cache_dir():
1515
return os.path.expanduser(os.path.join("~", ".torchxrayvision", "models_data/"))
1616

17+
1718
def in_notebook():
1819
try:
1920
from IPython import get_ipython
20-
if 'IPKernelApp' not in get_ipython().config: # pragma: no cover
21+
22+
if "IPKernelApp" not in get_ipython().config: # pragma: no cover
2123
return False
2224
except ImportError:
2325
return False
@@ -28,31 +30,37 @@ def in_notebook():
2830

2931
# from here https://sumit-ghosh.com/articles/python-download-progress-bar/
3032
def download(url: str, filename: str):
    """Download ``url`` into ``filename`` while drawing a progress bar on stdout.

    Args:
        url: Address fetched with a streaming ``requests.get``.
        filename: Destination path. It is opened in binary write mode, so an
            existing file at that path is overwritten.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
    """
    bar_width = 50

    # Issue the request and validate the status *before* touching the
    # destination: a failed request must neither truncate an existing file nor
    # store an HTTP error page under the target name. The context manager
    # releases the connection when we are done.
    with requests.get(url, stream=True) as response:
        response.raise_for_status()

        content_length = response.headers.get("content-length")
        total = int(content_length) if content_length is not None else 0

        with open(filename, "wb") as f:
            if total <= 0:
                # Size unknown (or reported as zero): nothing to scale the bar
                # against, so just write the whole body in one go.
                f.write(response.content)
            else:
                downloaded = 0
                # Use chunks of at least 1 MiB; for very large files use
                # 1/1000th of the announced size instead.
                chunk_size = max(total // 1000, 1024 * 1024)
                for data in response.iter_content(chunk_size=chunk_size):
                    downloaded += len(data)
                    f.write(data)
                    # Clamp so the bar never overruns if more bytes arrive
                    # than content-length announced (e.g. a compressed
                    # transfer that is decoded on the fly).
                    done = min(bar_width, bar_width * downloaded // total)
                    sys.stdout.write(
                        "\r[{}{}]".format("█" * done, "." * (bar_width - done))
                    )
                    sys.stdout.flush()
    sys.stdout.write("\n")
4751

4852

4953
def normalize(img, maxval, reshape=False):
5054
"""Scales images to be roughly [-1024 1024]."""
5155

5256
if img.max() > maxval:
53-
raise Exception("max image value ({}) higher than expected bound ({}).".format(img.max(), maxval))
57+
raise ValueError(
58+
"max image value ({}) higher than expected bound ({}).".format(
59+
img.max(), maxval
60+
)
61+
)
5462

55-
img = (2 * (img.astype(np.float32) / maxval) - 1.) * 1024
63+
img = (2 * (img.astype(np.float32) / maxval) - 1.0) * 1024
5664

5765
if reshape:
5866
# Check that images are 2D arrays
@@ -70,6 +78,13 @@ def normalize(img, maxval, reshape=False):
7078
def load_image(fname: str):
7179
"""Load an image from a file and normalize it between -1024 and 1024. Assumes 8-bits per pixel."""
7280

81+
with open(fname, "rb") as f:
82+
# Read the first 132 bytes (128 preamble + 4 for "DICM")
83+
header = f.read(132)
84+
# Check if the file is long enough and has "DICM" at position 128
85+
if len(header) >= 132 and header[128:132] == b"DICM":
86+
return read_xray_dcm(fname)[None, ...]
87+
7388
img = skimage.io.imread(fname)
7489
img = normalize(img, 255)
7590

@@ -85,8 +100,10 @@ def load_image(fname: str):
85100
return img
86101

87102

88-
def read_xray_dcm(path: PathLike, voi_lut: bool = False, fix_monochrome: bool = True) -> ndarray:
89-
"""read a dicom-like file and convert to numpy array
103+
def read_xray_dcm(
104+
path: PathLike, voi_lut: bool = False, fix_monochrome: bool = True
105+
) -> ndarray:
106+
"""read a dicom-like file and convert to numpy array
90107
91108
Args:
92109
path (PathLike): path to the dicom file
@@ -99,35 +116,43 @@ def read_xray_dcm(path: PathLike, voi_lut: bool = False, fix_monochrome: bool =
99116
try:
100117
import pydicom
101118
except ImportError:
102-
raise Exception("Missing Package Pydicom. Try installing it by running `pip install pydicom`.")
119+
raise ImportError(
120+
"Missing Package Pydicom. Try installing it by running `pip install pydicom`."
121+
)
103122

104123
# get the pixel array
105124
ds = pydicom.dcmread(path, force=True)
106125

107126
# we have not tested RGB, YBR_FULL, or YBR_FULL_422 yet.
108-
if ds.PhotometricInterpretation not in ['MONOCHROME1', 'MONOCHROME2']:
109-
raise NotImplementedError(f'PhotometricInterpretation `{ds.PhotometricInterpretation}` is not yet supported.')
127+
if ds.PhotometricInterpretation not in ["MONOCHROME1", "MONOCHROME2"]:
128+
raise NotImplementedError(
129+
f"PhotometricInterpretation `{ds.PhotometricInterpretation}` is not yet supported."
130+
)
110131
# get the max possible pixel value from DCM header
111-
max_possible_pixel_val = (2**ds.BitsStored - 1)
132+
max_possible_pixel_val = 2**ds.BitsStored - 1
112133

113134
data = ds.pixel_array
114-
135+
115136
# LUT for human friendly view
116137
if voi_lut:
117138
data = pydicom.pixel_data_handlers.util.apply_voi_lut(data, ds, index=0)
118139

119140
# `MONOCHROME1` have an inverted view; Bones are black; background is white
120141
# https://web.archive.org/web/20150920230923/http://www.mccauslandcenter.sc.edu/mricro/dicom/index.html
121142
if fix_monochrome and ds.PhotometricInterpretation == "MONOCHROME1":
122-
warnings.warn(f"Coverting MONOCHROME1 to MONOCHROME2 interpretation for file: {path}. Can be avoided by setting `fix_monochrome=False`")
143+
warnings.warn(
144+
f"Converting MONOCHROME1 to MONOCHROME2 interpretation for file: {path}. Can be avoided by setting `fix_monochrome=False`"
145+
)
123146
data = max_possible_pixel_val - data
124147

125148
# normalize data to [-1024, 1024]
126149
data = normalize(data, max_possible_pixel_val)
127150
return data
128151

129152

130-
def infer(model: torch.nn.Module, dataset: torch.utils.data.Dataset, threads=4, device='cpu'):
153+
def infer(
154+
model: torch.nn.Module, dataset: torch.utils.data.Dataset, threads=4, device="cpu"
155+
):
131156

132157
dl = torch.utils.data.DataLoader(
133158
dataset,
@@ -148,37 +173,50 @@ def infer(model: torch.nn.Module, dataset: torch.utils.data.Dataset, threads=4,
148173

149174
warning_log = {}
150175

176+
151177
def fix_resolution(x, resolution: int, model):
    """Check resolution of input and resize to match requested.

    Args:
        x: Image tensor. A 3-dimensional input gets a leading axis prepended
            so that it becomes 4-dimensional; axes 2 and 3 are treated as
            height and width.
        resolution: Side length (in pixels) that the output must have.
        model: Object whose ``hash`` is used as the key in ``warning_log`` so
            that the resize warning is printed only once per model.

    Returns:
        ``x`` unchanged (apart from the possible added leading axis) when it
        already is ``resolution`` x ``resolution``; otherwise a bilinear,
        antialiased resize of it to that size.

    Raises:
        ValueError: If the input has fewer than 3 dimensions or if its height
            and width differ.
    """
    if len(x.shape) < 3:
        # Fail with a clear message instead of an IndexError further down.
        raise ValueError(
            f"Expected an input with at least 3 dimensions but got shape {tuple(x.shape)}."
        )

    if len(x.shape) == 3:
        # Extend to be 4D
        x = x[None, ...]

    height, width = x.shape[2], x.shape[3]

    if height != width:
        # ValueError is still an Exception, so callers catching the old bare
        # Exception keep working.
        raise ValueError(
            f"Height and width of the image must be the same. Input: {height} != {width}. Perform a center crop first."
        )

    if height == resolution:
        # Square and already at the requested size: nothing to do.
        return x

    model_key = hash(model)
    if model_key not in warning_log:
        print(
            "Warning: Input size ({}x{}) is not the native resolution ({}x{}) for this model. A resize will be performed but this could impact performance.".format(
                height, width, resolution, resolution
            )
        )
        # Remember that this model was already warned about.
        warning_log[model_key] = True

    return torch.nn.functional.interpolate(
        x, size=(resolution, resolution), mode="bilinear", antialias=True
    )
167201

168202

169203
def warn_normalization(x):
    """Check normalization of input and warn if possibly wrong.

    When processing an image that may likely not have the correct
    normalization we can issue a warning. But running min and max on every
    image/batch is costly so we only do it on the first image/batch.

    The verdict is stored in ``warning_log["norm_correct"]``: ``True`` when
    the value range looks plausible for the [-1024, 1024] normalization and
    ``False`` when the warning was printed.

    Args:
        x: Input whose ``min()`` and ``max()`` are inspected.
    """
    # Only run this check on the first image so we don't hurt performance.
    # The guard must test the same key that is written below; testing a key
    # that is never set would repeat the check on every call and let later
    # batches overwrite the first verdict.
    if "norm_correct" in warning_log:
        return

    x_min = x.min()
    x_max = x.max()

    # Heuristic: flag ranges that do not reach +/-255 at both ends, and ranges
    # that go beyond +/-1025.
    range_too_small = bool(x_min > -255) or bool(x_max < 255)
    range_too_large = bool(x_min < -1025) or bool(x_max > 1025)

    if range_too_small or range_too_large:
        print(
            f"Warning: Input image does not appear to be normalized correctly. The input image has the range [{x_min:.2f},{x_max:.2f}] which doesn't seem to be in the [-1024,1024] range. This warning may be wrong though. Only the first image is tested and we are only using a heuristic in an attempt to save a user from using the wrong normalization."
        )
        warning_log["norm_correct"] = False
    else:
        warning_log["norm_correct"] = True

0 commit comments

Comments
 (0)