Skip to content

Commit 435bee8

Browse files
committed
autopep8
1 parent a0f156a commit 435bee8

File tree

1 file changed

+40
-40
lines changed

1 file changed

+40
-40
lines changed

captum/attr/_core/latent_shift.py

Lines changed: 40 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,16 @@ class LatentShift(GradientAttribution):
2020
the possible adversarial examples to remain in the data space by
2121
adjusting the latent space of the autoencoder using dy/dz instead of
2222
dy/dx in order to change the classifier's prediction.
23-
23+
2424
This class implements a search strategy to determine the lambda needed to
2525
change the prediction of the classifier by a specific amount as well as
2626
the code to generate a video and construct a heatmap representing the
2727
image changes for viewing as an image.
28-
28+
2929
Publication:
30-
Cohen, J. P., et al. Gifsplanation via Latent Shift: A Simple
31-
Autoencoder Approach to Counterfactual Generation for Chest
32-
X-rays. Medical Imaging with Deep Learning.
30+
Cohen, J. P., et al. Gifsplanation via Latent Shift: A Simple
31+
Autoencoder Approach to Counterfactual Generation for Chest
32+
X-rays. Medical Imaging with Deep Learning.
3333
https://arxiv.org/abs/2102.09475
3434
"""
3535

@@ -44,7 +44,7 @@ def __init__(self, forward_func: Callable, autoencoder) -> None:
4444
"""
4545
GradientAttribution.__init__(self, forward_func)
4646
self.ae = autoencoder
47-
47+
4848
# check if ae has encode and decode
4949
assert hasattr(self.ae, 'encode')
5050
assert hasattr(self.ae, 'decode')
@@ -109,13 +109,13 @@ def attribute(
109109
Returns:
110110
dict containing the following keys:
111111
generated_images: A list of images generated at each step along
112-
the dydz vector from the smallest lambda to the largest. By
113-
default the smallest lambda represents the counterfactual
112+
the dydz vector from the smallest lambda to the largest. By
113+
default the smallest lambda represents the counterfactual
114114
image and the largest lambda is 0 (representing no change).
115115
lambdas: A list of the lambda values for each generated image.
116-
preds: A list of the predictions of the model for each generated
116+
preds: A list of the predictions of the model for each generated
117117
image.
118-
heatmap: A heatmap indicating the pixels which change in the
118+
heatmap: A heatmap indicating the pixels which change in the
119119
video sequence of images.
120120
121121
@@ -124,13 +124,13 @@ def attribute(
124124
>>> # Load classifier and autoencoder
125125
>>> model = classifiers.FaceAttribute()
126126
>>> ae = autoencoders.Transformer(weights="faceshq")
127-
>>>
127+
>>>
128128
>>> # Load image
129129
>>> x = torch.randn(1, 3, 1024, 1024)
130-
>>>
130+
>>>
131131
>>> # Defining Latent Shift module
132132
>>> attr = captum.attr.LatentShift(model, ae)
133-
>>>
133+
>>>
134134
>>> # Computes counterfactual for class 3.
135135
>>> output = attr.attribute(x, target=3)
136136
@@ -140,7 +140,7 @@ def attribute(
140140
x_lambda0 = self.ae.decode(z)
141141
pred = torch.sigmoid(self.forward_func(x_lambda0))[:, target]
142142
dzdxp = torch.autograd.grad(pred, z)[0]
143-
143+
144144
# Cache so we can reuse at sweep stage
145145
cache = {}
146146

@@ -149,14 +149,14 @@ def compute_shift(lambdax):
149149
if lambdax not in cache:
150150
x_lambdax = self.ae.decode(z + dzdxp * lambdax).detach()
151151
pred1 = torch.sigmoid(self.forward_func(x_lambdax))[:, target]
152-
pred1 = pred1.detach().cpu().numpy()
152+
pred1 = pred1.detach().cpu().numpy()
153153
cache[lambdax] = x_lambdax, pred1
154154
if verbose:
155155
print(f'Shift: {lambdax} , Prediction: {pred1}')
156156
return cache[lambdax]
157-
157+
158158
_, initial_pred = compute_shift(0)
159-
159+
160160
if fix_range:
161161
lbound, rbound = fix_range
162162
else:
@@ -166,7 +166,7 @@ def compute_shift(lambdax):
166166
while True:
167167
x_lambdax, cur_pred = compute_shift(lbound)
168168
pixel_diff = torch.abs(x_lambda0 - x_lambdax).sum().detach()
169-
169+
170170
# If we stop decreasing the prediction
171171
if last_pred < cur_pred:
172172
break
@@ -182,7 +182,7 @@ def compute_shift(lambdax):
182182
# If we move too far we will distort the image
183183
if pixel_diff > search_max_pixel_diff:
184184
break
185-
185+
186186
last_pred = cur_pred
187187
lbound = lbound - search_step_size + lbound // 10
188188

@@ -191,22 +191,22 @@ def compute_shift(lambdax):
191191

192192
if verbose:
193193
print('Selected bounds: ', lbound, rbound)
194-
194+
195195
# Sweep over the range of lambda values to create a sequence
196196
lambdas = np.arange(
197197
lbound,
198198
rbound,
199199
np.abs((lbound - rbound) / lambda_sweep_steps)
200200
)
201-
201+
202202
preds = []
203203
generated_images = []
204-
204+
205205
for lam in lambdas:
206206
x_lambdax, pred = compute_shift(lam)
207207
generated_images.append(x_lambdax.cpu().numpy())
208208
preds.append(pred)
209-
209+
210210
params = {}
211211
params['generated_images'] = generated_images
212212
params['lambdas'] = lambdas
@@ -219,21 +219,21 @@ def compute_shift(lambdax):
219219
heatmap = np.max(
220220
np.abs(x_lambda0[0][0] - generated_images[0][0]), 0
221221
)
222-
222+
223223
elif heatmap_method == 'mean':
224224
# Average difference between 0 and other lambda frames
225225
heatmap = np.mean(
226226
np.abs(x_lambda0[0][0] - generated_images[0][0]), 0
227227
)
228-
228+
229229
elif heatmap_method == 'mm':
230230
# Difference between first and last frames
231231
heatmap = np.abs(
232232
generated_images[0][0][0] - generated_images[-1][0][0]
233233
)
234-
234+
235235
elif heatmap_method == 'int':
236-
# Average per frame differences
236+
# Average per frame differences
237237
image_changes = []
238238
for i in range(len(generated_images) - 1):
239239
image_changes.append(np.abs(
@@ -242,11 +242,11 @@ def compute_shift(lambdax):
242242
heatmap = np.mean(image_changes, 0)
243243
else:
244244
raise Exception('Unknown heatmap_method for 2d image')
245-
245+
246246
params["heatmap"] = heatmap
247-
247+
248248
return params
249-
249+
250250
@log_usage()
251251
def generate_video(
252252
self,
@@ -275,24 +275,24 @@ def generate_video(
275275
The filename of the video if show=False, otherwise it will
276276
return a video to show in a jupyter notebook.
277277
"""
278-
278+
279279
if not target_filename:
280280
target_filename = f'video-{params["target"]}'
281-
281+
282282
if os.path.exists(target_filename + ".mp4"):
283-
os.remove(target_filename + ".mp4")
284-
285-
shutil.rmtree(temp_path, ignore_errors=True)
283+
os.remove(target_filename + ".mp4")
284+
285+
shutil.rmtree(temp_path, ignore_errors=True)
286286
os.mkdir(temp_path)
287-
287+
288288
imgs = [h.transpose(0, 2, 3, 1) for h in params["generated_images"]]
289289

290290
# Add reversed so we have an animation cycle
291291
towrite = list(reversed(imgs)) + list(imgs)
292292
ys = list(reversed(params['preds'])) + list(params['preds'])
293-
293+
294294
for idx, img in enumerate(towrite):
295-
295+
296296
px = 1 / plt.rcParams['figure.dpi']
297297
full_frame(img[0].shape[0] * px, img[0].shape[1] * px)
298298
plt.imshow(img[0], interpolation='none')
@@ -310,7 +310,7 @@ def generate_video(
310310
ha='right', va='bottom',
311311
transform=plt.gca().transAxes
312312
)
313-
313+
314314
plt.savefig(
315315
f'{temp_path}/image-{idx}.png',
316316
bbox_inches='tight',
@@ -346,7 +346,7 @@ def generate_video(
346346
)
347347
else:
348348
return target_filename + ".mp4"
349-
349+
350350

351351
def full_frame(width=None, height=None):
352352
"""Setup matplotlib so we can write to the entire canvas"""

0 commit comments

Comments
 (0)