8
8
from os import PathLike
9
9
from numpy import ndarray
10
10
import warnings
11
- from tqdm .autonotebook import tqdm
11
+ from tqdm .auto import tqdm
12
12
13
13
14
14
def get_cache_dir ():
15
15
return os .path .expanduser (os .path .join ("~" , ".torchxrayvision" , "models_data/" ))
16
16
17
+
17
18
def in_notebook ():
18
19
try :
19
20
from IPython import get_ipython
20
- if 'IPKernelApp' not in get_ipython ().config : # pragma: no cover
21
+
22
+ if "IPKernelApp" not in get_ipython ().config : # pragma: no cover
21
23
return False
22
24
except ImportError :
23
25
return False
@@ -28,31 +30,37 @@ def in_notebook():
28
30
29
31
# from here https://sumit-ghosh.com/articles/python-download-progress-bar/
30
32
def download (url : str , filename : str ):
31
- with open (filename , 'wb' ) as f :
33
+ with open (filename , "wb" ) as f :
32
34
response = requests .get (url , stream = True )
33
- total = response .headers .get (' content-length' )
35
+ total = response .headers .get (" content-length" )
34
36
35
37
if total is None :
36
38
f .write (response .content )
37
39
else :
38
40
downloaded = 0
39
41
total = int (total )
40
- for data in response .iter_content (chunk_size = max (int (total / 1000 ), 1024 * 1024 )):
42
+ for data in response .iter_content (
43
+ chunk_size = max (int (total / 1000 ), 1024 * 1024 )
44
+ ):
41
45
downloaded += len (data )
42
46
f .write (data )
43
47
done = int (50 * downloaded / total )
44
- sys .stdout .write (' \r [{}{}]' .format ('█' * done , '.' * (50 - done )))
48
+ sys .stdout .write (" \r [{}{}]" .format ("█" * done , "." * (50 - done )))
45
49
sys .stdout .flush ()
46
- sys .stdout .write (' \n ' )
50
+ sys .stdout .write (" \n " )
47
51
48
52
49
53
def normalize (img , maxval , reshape = False ):
50
54
"""Scales images to be roughly [-1024 1024]."""
51
55
52
56
if img .max () > maxval :
53
- raise Exception ("max image value ({}) higher than expected bound ({})." .format (img .max (), maxval ))
57
+ raise ValueError (
58
+ "max image value ({}) higher than expected bound ({})." .format (
59
+ img .max (), maxval
60
+ )
61
+ )
54
62
55
- img = (2 * (img .astype (np .float32 ) / maxval ) - 1. ) * 1024
63
+ img = (2 * (img .astype (np .float32 ) / maxval ) - 1.0 ) * 1024
56
64
57
65
if reshape :
58
66
# Check that images are 2D arrays
@@ -70,6 +78,13 @@ def normalize(img, maxval, reshape=False):
70
78
def load_image (fname : str ):
71
79
"""Load an image from a file and normalize it between -1024 and 1024. Assumes 8-bits per pixel."""
72
80
81
+ with open (fname , "rb" ) as f :
82
+ # Read the first 132 bytes (128 preamble + 4 for "DICM")
83
+ header = f .read (132 )
84
+ # Check if the file is long enough and has "DICM" at position 128
85
+ if len (header ) >= 132 and header [128 :132 ] == b"DICM" :
86
+ return read_xray_dcm (fname )[None , ...]
87
+
73
88
img = skimage .io .imread (fname )
74
89
img = normalize (img , 255 )
75
90
@@ -85,8 +100,10 @@ def load_image(fname: str):
85
100
return img
86
101
87
102
88
- def read_xray_dcm (path : PathLike , voi_lut : bool = False , fix_monochrome : bool = True ) -> ndarray :
89
- """read a dicom-like file and convert to numpy array
103
+ def read_xray_dcm (
104
+ path : PathLike , voi_lut : bool = False , fix_monochrome : bool = True
105
+ ) -> ndarray :
106
+ """read a dicom-like file and convert to numpy array
90
107
91
108
Args:
92
109
path (PathLike): path to the dicom file
@@ -99,35 +116,43 @@ def read_xray_dcm(path: PathLike, voi_lut: bool = False, fix_monochrome: bool =
99
116
try :
100
117
import pydicom
101
118
except ImportError :
102
- raise Exception ("Missing Package Pydicom. Try installing it by running `pip install pydicom`." )
119
+ raise ImportError (
120
+ "Missing Package Pydicom. Try installing it by running `pip install pydicom`."
121
+ )
103
122
104
123
# get the pixel array
105
124
ds = pydicom .dcmread (path , force = True )
106
125
107
126
# we have not tested RGB, YBR_FULL, or YBR_FULL_422 yet.
108
- if ds .PhotometricInterpretation not in ['MONOCHROME1' , 'MONOCHROME2' ]:
109
- raise NotImplementedError (f'PhotometricInterpretation `{ ds .PhotometricInterpretation } ` is not yet supported.' )
127
+ if ds .PhotometricInterpretation not in ["MONOCHROME1" , "MONOCHROME2" ]:
128
+ raise NotImplementedError (
129
+ f"PhotometricInterpretation `{ ds .PhotometricInterpretation } ` is not yet supported."
130
+ )
110
131
# get the max possible pixel value from DCM header
111
- max_possible_pixel_val = ( 2 ** ds .BitsStored - 1 )
132
+ max_possible_pixel_val = 2 ** ds .BitsStored - 1
112
133
113
134
data = ds .pixel_array
114
-
135
+
115
136
# LUT for human friendly view
116
137
if voi_lut :
117
138
data = pydicom .pixel_data_handlers .util .apply_voi_lut (data , ds , index = 0 )
118
139
119
140
# `MONOCHROME1` have an inverted view; Bones are black; background is white
120
141
# https://web.archive.org/web/20150920230923/http://www.mccauslandcenter.sc.edu/mricro/dicom/index.html
121
142
if fix_monochrome and ds .PhotometricInterpretation == "MONOCHROME1" :
122
- warnings .warn (f"Coverting MONOCHROME1 to MONOCHROME2 interpretation for file: { path } . Can be avoided by setting `fix_monochrome=False`" )
143
+ warnings .warn (
144
+ f"Converting MONOCHROME1 to MONOCHROME2 interpretation for file: { path } . Can be avoided by setting `fix_monochrome=False`"
145
+ )
123
146
data = max_possible_pixel_val - data
124
147
125
148
# normalize data to [-1024, 1024]
126
149
data = normalize (data , max_possible_pixel_val )
127
150
return data
128
151
129
152
130
- def infer (model : torch .nn .Module , dataset : torch .utils .data .Dataset , threads = 4 , device = 'cpu' ):
153
+ def infer (
154
+ model : torch .nn .Module , dataset : torch .utils .data .Dataset , threads = 4 , device = "cpu"
155
+ ):
131
156
132
157
dl = torch .utils .data .DataLoader (
133
158
dataset ,
@@ -148,37 +173,50 @@ def infer(model: torch.nn.Module, dataset: torch.utils.data.Dataset, threads=4,
148
173
149
174
warning_log = {}
150
175
176
+
151
177
def fix_resolution (x , resolution : int , model ):
152
178
"""Check resolution of input and resize to match requested."""
153
179
154
180
if len (x .shape ) == 3 :
155
181
# Extend to be 4D
156
- x = x [None ,...]
182
+ x = x [None , ...]
157
183
158
184
if x .shape [2 ] != x .shape [3 ]:
159
- raise Exception (f"Height and width of the image must be the same. Input: { x .shape [2 ]} != { x .shape [3 ]} . Perform a center crop first." )
160
-
161
- if (x .shape [2 ] != resolution ) | (x .shape [3 ] != resolution ):
185
+ raise Exception (
186
+ f"Height and width of the image must be the same. Input: { x .shape [2 ]} != { x .shape [3 ]} . Perform a center crop first."
187
+ )
188
+
189
+ if (x .shape [2 ] != resolution ) or (x .shape [3 ] != resolution ):
162
190
if not hash (model ) in warning_log :
163
- print ("Warning: Input size ({}x{}) is not the native resolution ({}x{}) for this model. A resize will be performed but this could impact performance." .format (x .shape [2 ], x .shape [3 ], resolution , resolution ))
191
+ print (
192
+ "Warning: Input size ({}x{}) is not the native resolution ({}x{}) for this model. A resize will be performed but this could impact performance." .format (
193
+ x .shape [2 ], x .shape [3 ], resolution , resolution
194
+ )
195
+ )
164
196
warning_log [hash (model )] = True
165
- return torch .nn .functional .interpolate (x , size = (resolution , resolution ), mode = 'bilinear' , antialias = True )
197
+ return torch .nn .functional .interpolate (
198
+ x , size = (resolution , resolution ), mode = "bilinear" , antialias = True
199
+ )
166
200
return x
167
201
168
202
169
203
def warn_normalization (x ):
170
- """Check normalization of input and warn if possibly wrong. When
171
- processing an image that may likely not have the correct
172
- normalization we can issue a warning. But running min and max on
204
+ """Check normalization of input and warn if possibly wrong. When
205
+ processing an image that may likely not have the correct
206
+ normalization we can issue a warning. But running min and max on
173
207
every image/batch is costly so we only do it on the first image/batch.
174
208
"""
175
209
176
210
# Only run this check on the first image so we don't hurt performance.
177
211
if not "norm_check" in warning_log :
178
212
x_min = x .min ()
179
213
x_max = x .max ()
180
- if torch .logical_or (- 255 < x_min , x_max < 255 ) or torch .logical_or (x_min < - 1025 , 1025 < x_max ):
181
- print (f'Warning: Input image does not appear to be normalized correctly. The input image has the range [{ x_min :.2f} ,{ x_max :.2f} ] which doesn\' t seem to be in the [-1024,1024] range. This warning may be wrong though. Only the first image is tested and we are only using a heuristic in an attempt to save a user from using the wrong normalization.' )
214
+ if torch .logical_or (- 255 < x_min , x_max < 255 ) or torch .logical_or (
215
+ x_min < - 1025 , 1025 < x_max
216
+ ):
217
+ print (
218
+ f"Warning: Input image does not appear to be normalized correctly. The input image has the range [{ x_min :.2f} ,{ x_max :.2f} ] which doesn't seem to be in the [-1024,1024] range. This warning may be wrong though. Only the first image is tested and we are only using a heuristic in an attempt to save a user from using the wrong normalization."
219
+ )
182
220
warning_log ["norm_correct" ] = False
183
221
else :
184
222
warning_log ["norm_correct" ] = True
0 commit comments