@@ -20,16 +20,16 @@ class LatentShift(GradientAttribution):
20
20
the possible adversarial examples to remain in the data space by
21
21
adjusting the latent space of the autoencoder using dy/dz instead of
22
22
dy/dx in order to change the classifier's prediction.
23
-
23
+
24
24
This class implements a search strategy to determine the lambda needed to
25
25
change the prediction of the classifier by a specific amount as well as
26
26
the code to generate a video and construct a heatmap representing the
27
27
image changes for viewing as an image.
28
-
28
+
29
29
Publication:
30
- Cohen, J. P., et al. Gifsplanation via Latent Shift: A Simple
31
- Autoencoder Approach to Counterfactual Generation for Chest
32
- X-rays. Medical Imaging with Deep Learning.
30
+ Cohen, J. P., et al. Gifsplanation via Latent Shift: A Simple
31
+ Autoencoder Approach to Counterfactual Generation for Chest
32
+ X-rays. Medical Imaging with Deep Learning.
33
33
https://arxiv.org/abs/2102.09475
34
34
"""
35
35
@@ -44,7 +44,7 @@ def __init__(self, forward_func: Callable, autoencoder) -> None:
44
44
"""
45
45
GradientAttribution .__init__ (self , forward_func )
46
46
self .ae = autoencoder
47
-
47
+
48
48
# check if ae has encode and decode
49
49
assert hasattr (self .ae , 'encode' )
50
50
assert hasattr (self .ae , 'decode' )
@@ -109,13 +109,13 @@ def attribute(
109
109
Returns:
110
110
dict containing the follow keys:
111
111
generated_images: A list of images generated at each step along
112
- the dydz vector from the smallest lambda to the largest. By
113
- default the smallest lambda represents the counterfactual
112
+ the dydz vector from the smallest lambda to the largest. By
113
+ default the smallest lambda represents the counterfactual
114
114
image and the largest lambda is 0 (representing no change).
115
115
lambdas: A list of the lambda values for each generated image.
116
- preds: A list of the predictions of the model for each generated
116
+ preds: A list of the predictions of the model for each generated
117
117
image.
118
- heatmap: A heatmap indicating the pixels which change in the
118
+ heatmap: A heatmap indicating the pixels which change in the
119
119
video sequence of images.
120
120
121
121
@@ -124,13 +124,13 @@ def attribute(
124
124
>>> # Load classifier and autoencoder
125
125
>>> model = classifiers.FaceAttribute()
126
126
>>> ae = autoencoders.Transformer(weights="faceshq")
127
- >>>
127
+ >>>
128
128
>>> # Load image
129
129
>>> x = torch.randn(1, 3, 1024, 1024)
130
- >>>
130
+ >>>
131
131
>>> # Defining Latent Shift module
132
132
>>> attr = captum.attr.LatentShift(model, ae)
133
- >>>
133
+ >>>
134
134
>>> # Computes counterfactual for class 3.
135
135
>>> output = attr.attribute(x, target=3)
136
136
@@ -140,7 +140,7 @@ def attribute(
140
140
x_lambda0 = self .ae .decode (z )
141
141
pred = torch .sigmoid (self .forward_func (x_lambda0 ))[:, target ]
142
142
dzdxp = torch .autograd .grad (pred , z )[0 ]
143
-
143
+
144
144
# Cache so we can reuse at sweep stage
145
145
cache = {}
146
146
@@ -149,14 +149,14 @@ def compute_shift(lambdax):
149
149
if lambdax not in cache :
150
150
x_lambdax = self .ae .decode (z + dzdxp * lambdax ).detach ()
151
151
pred1 = torch .sigmoid (self .forward_func (x_lambdax ))[:, target ]
152
- pred1 = pred1 .detach ().cpu ().numpy ()
152
+ pred1 = pred1 .detach ().cpu ().numpy ()
153
153
cache [lambdax ] = x_lambdax , pred1
154
154
if verbose :
155
155
print (f'Shift: { lambdax } , Prediction: { pred1 } ' )
156
156
return cache [lambdax ]
157
-
157
+
158
158
_ , initial_pred = compute_shift (0 )
159
-
159
+
160
160
if fix_range :
161
161
lbound , rbound = fix_range
162
162
else :
@@ -166,7 +166,7 @@ def compute_shift(lambdax):
166
166
while True :
167
167
x_lambdax , cur_pred = compute_shift (lbound )
168
168
pixel_diff = torch .abs (x_lambda0 - x_lambdax ).sum ().detach ()
169
-
169
+
170
170
# If we stop decreasing the prediction
171
171
if last_pred < cur_pred :
172
172
break
@@ -182,7 +182,7 @@ def compute_shift(lambdax):
182
182
# If we move too far we will distort the image
183
183
if pixel_diff > search_max_pixel_diff :
184
184
break
185
-
185
+
186
186
last_pred = cur_pred
187
187
lbound = lbound - search_step_size + lbound // 10
188
188
@@ -191,22 +191,22 @@ def compute_shift(lambdax):
191
191
192
192
if verbose :
193
193
print ('Selected bounds: ' , lbound , rbound )
194
-
194
+
195
195
# Sweep over the range of lambda values to create a sequence
196
196
lambdas = np .arange (
197
197
lbound ,
198
198
rbound ,
199
199
np .abs ((lbound - rbound ) / lambda_sweep_steps )
200
200
)
201
-
201
+
202
202
preds = []
203
203
generated_images = []
204
-
204
+
205
205
for lam in lambdas :
206
206
x_lambdax , pred = compute_shift (lam )
207
207
generated_images .append (x_lambdax .cpu ().numpy ())
208
208
preds .append (pred )
209
-
209
+
210
210
params = {}
211
211
params ['generated_images' ] = generated_images
212
212
params ['lambdas' ] = lambdas
@@ -219,21 +219,21 @@ def compute_shift(lambdax):
219
219
heatmap = np .max (
220
220
np .abs (x_lambda0 [0 ][0 ] - generated_images [0 ][0 ]), 0
221
221
)
222
-
222
+
223
223
elif heatmap_method == 'mean' :
224
224
# Average difference between 0 and other lambda frames
225
225
heatmap = np .mean (
226
226
np .abs (x_lambda0 [0 ][0 ] - generated_images [0 ][0 ]), 0
227
227
)
228
-
228
+
229
229
elif heatmap_method == 'mm' :
230
230
# Difference between first and last frames
231
231
heatmap = np .abs (
232
232
generated_images [0 ][0 ][0 ] - generated_images [- 1 ][0 ][0 ]
233
233
)
234
-
234
+
235
235
elif heatmap_method == 'int' :
236
- # Average per frame differences
236
+ # Average per frame differences
237
237
image_changes = []
238
238
for i in range (len (generated_images ) - 1 ):
239
239
image_changes .append (np .abs (
@@ -242,11 +242,11 @@ def compute_shift(lambdax):
242
242
heatmap = np .mean (image_changes , 0 )
243
243
else :
244
244
raise Exception ('Unknown heatmap_method for 2d image' )
245
-
245
+
246
246
params ["heatmap" ] = heatmap
247
-
247
+
248
248
return params
249
-
249
+
250
250
@log_usage ()
251
251
def generate_video (
252
252
self ,
@@ -275,24 +275,24 @@ def generate_video(
275
275
The filename of the video if show=False, otherwise it will
276
276
return a video to show in a jupyter notebook.
277
277
"""
278
-
278
+
279
279
if not target_filename :
280
280
target_filename = f'video-{ params ["target" ]} '
281
-
281
+
282
282
if os .path .exists (target_filename + ".mp4" ):
283
- os .remove (target_filename + ".mp4" )
284
-
285
- shutil .rmtree (temp_path , ignore_errors = True )
283
+ os .remove (target_filename + ".mp4" )
284
+
285
+ shutil .rmtree (temp_path , ignore_errors = True )
286
286
os .mkdir (temp_path )
287
-
287
+
288
288
imgs = [h .transpose (0 , 2 , 3 , 1 ) for h in params ["generated_images" ]]
289
289
290
290
# Add reversed so we have an animation cycle
291
291
towrite = list (reversed (imgs )) + list (imgs )
292
292
ys = list (reversed (params ['preds' ])) + list (params ['preds' ])
293
-
293
+
294
294
for idx , img in enumerate (towrite ):
295
-
295
+
296
296
px = 1 / plt .rcParams ['figure.dpi' ]
297
297
full_frame (img [0 ].shape [0 ] * px , img [0 ].shape [1 ] * px )
298
298
plt .imshow (img [0 ], interpolation = 'none' )
@@ -310,7 +310,7 @@ def generate_video(
310
310
ha = 'right' , va = 'bottom' ,
311
311
transform = plt .gca ().transAxes
312
312
)
313
-
313
+
314
314
plt .savefig (
315
315
f'{ temp_path } /image-{ idx } .png' ,
316
316
bbox_inches = 'tight' ,
@@ -346,7 +346,7 @@ def generate_video(
346
346
)
347
347
else :
348
348
return target_filename + ".mp4"
349
-
349
+
350
350
351
351
def full_frame (width = None , height = None ):
352
352
"""Setup matplotlib so we can write to the entire canvas"""
0 commit comments