
Commit 3f9bbdd

ufmt format

1 parent 2f618ff commit 3f9bbdd
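For context: ufmt is Meta's Python formatter that runs usort (import sorting) followed by black (code formatting), which is why this commit only reorders imports, normalizes quotes to double quotes, and reflows call sites without changing behavior. It was presumably produced by an invocation along the lines of `ufmt format .` from the repository root; the exact command is an assumption, as only the commit message is recorded.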

File tree

3 files changed: +78 -79 lines changed

captum/attr/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -11,6 +11,7 @@
 from captum.attr._core.input_x_gradient import InputXGradient  # noqa
 from captum.attr._core.integrated_gradients import IntegratedGradients  # noqa
 from captum.attr._core.kernel_shap import KernelShap  # noqa
+from captum.attr._core.latent_shift import LatentShift  # noqa
 from captum.attr._core.layer.grad_cam import LayerGradCam  # noqa
 from captum.attr._core.layer.internal_influence import InternalInfluence  # noqa
 from captum.attr._core.layer.layer_activation import LayerActivation  # noqa
@@ -50,7 +51,6 @@
 from captum.attr._core.noise_tunnel import NoiseTunnel  # noqa
 from captum.attr._core.occlusion import Occlusion  # noqa
 from captum.attr._core.saliency import Saliency  # noqa
-from captum.attr._core.latent_shift import LatentShift  # noqa
 from captum.attr._core.shapley_value import ShapleyValues, ShapleyValueSampling  # noqa
 from captum.attr._models.base import (  # noqa
     configure_interpretable_embedding_layer,
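The import's new position is alphabetical (usort orders names within each import block), and moving it does not change the package's public surface. A quick sanity check, assuming captum is installed with this commit applied:

# The attribution class is re-exported from the package root, so the
# public import path is independent of its position in __init__.py.
from captum.attr import LatentShift

print(LatentShift)  # <class 'captum.attr._core.latent_shift.LatentShift'>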

captum/attr/_core/latent_shift.py

Lines changed: 74 additions & 74 deletions
@@ -1,17 +1,17 @@
 #!/usr/bin/env python3
 
-from typing import Callable, Tuple, List
-
 import os
 import shutil
-import torch
-import numpy as np
-from torch import Tensor
-from captum.attr._utils.attribution import GradientAttribution
-from captum.log import log_usage
 import subprocess
+from typing import Callable, List, Tuple
+
 import matplotlib
 import matplotlib.pyplot as plt
+import numpy as np
+import torch
+from captum.attr._utils.attribution import GradientAttribution
+from captum.log import log_usage
+from torch import Tensor
 
 
 class LatentShift(GradientAttribution):
@@ -46,8 +46,8 @@ def __init__(self, forward_func: Callable, autoencoder) -> None:
         self.ae = autoencoder
 
         # check if ae has encode and decode
-        assert hasattr(self.ae, 'encode')
-        assert hasattr(self.ae, 'decode')
+        assert hasattr(self.ae, "encode")
+        assert hasattr(self.ae, "decode")
 
     @log_usage()
     def attribute(
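The asserts above only duck-type the autoencoder: any object exposing encode and decode methods passes. A minimal sketch of an object satisfying the check (illustrative only; a real autoencoder would learn these mappings):

import torch

class NoOpAE(torch.nn.Module):
    """Illustrative stand-in: satisfies the encode/decode interface."""

    def encode(self, x):
        return x  # "latent" representation (identity here)

    def decode(self, z):
        return z  # reconstruction from the latent

ae = NoOpAE()
assert hasattr(ae, "encode") and hasattr(ae, "decode")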
@@ -60,7 +60,7 @@ def attribute(
         search_max_steps: int = 3000,
         search_max_pixel_diff: float = 5000.0,
         lambda_sweep_steps: int = 10,
-        heatmap_method: str = 'int',
+        heatmap_method: str = "int",
         verbose: bool = True,
     ) -> dict:
         r"""
@@ -152,7 +152,7 @@ def compute_shift(lambdax):
             pred1 = pred1.detach().cpu().numpy()
             cache[lambdax] = x_lambdax, pred1
             if verbose:
-                print(f'Shift: {lambdax} , Prediction: {pred1}')
+                print(f"Shift: {lambdax} , Prediction: {pred1}")
             return cache[lambdax]
 
         _, initial_pred = compute_shift(0)
@@ -190,13 +190,11 @@ def compute_shift(lambdax):
             rbound = 0
 
         if verbose:
-            print('Selected bounds: ', lbound, rbound)
+            print("Selected bounds: ", lbound, rbound)
 
         # Sweep over the range of lambda values to create a sequence
         lambdas = np.arange(
-            lbound,
-            rbound,
-            np.abs((lbound - rbound) / lambda_sweep_steps)
+            lbound, rbound, np.abs((lbound - rbound) / lambda_sweep_steps)
        ).tolist()
 
         preds = []
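As a worked example of the sweep arithmetic (values illustrative): if the search settles on lbound = -400 and rbound = 0 with lambda_sweep_steps = 10, the step size is |(-400 - 0) / 10| = 40, so np.arange yields exactly ten lambda values, -400 through -40:

import numpy as np

lbound, rbound, lambda_sweep_steps = -400, 0, 10  # illustrative values
lambdas = np.arange(
    lbound, rbound, np.abs((lbound - rbound) / lambda_sweep_steps)
).tolist()
assert len(lambdas) == 10
assert lambdas[0] == -400.0 and lambdas[-1] == -40.0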
@@ -208,39 +206,33 @@ def compute_shift(lambdax):
             preds.append(float(pred))
 
         params = {}
-        params['generated_images'] = generated_images
-        params['lambdas'] = lambdas
-        params['preds'] = preds
+        params["generated_images"] = generated_images
+        params["lambdas"] = lambdas
+        params["preds"] = preds
 
         x_lambda0 = x_lambda0.detach().cpu().numpy()
-        if heatmap_method == 'max':
+        if heatmap_method == "max":
             # Max difference from lambda 0 frame
-            heatmap = np.max(
-                np.abs(x_lambda0[0][0] - generated_images[0][0]), 0
-            )
+            heatmap = np.max(np.abs(x_lambda0[0][0] - generated_images[0][0]), 0)
 
-        elif heatmap_method == 'mean':
+        elif heatmap_method == "mean":
             # Average difference between 0 and other lambda frames
-            heatmap = np.mean(
-                np.abs(x_lambda0[0][0] - generated_images[0][0]), 0
-            )
+            heatmap = np.mean(np.abs(x_lambda0[0][0] - generated_images[0][0]), 0)
 
-        elif heatmap_method == 'mm':
+        elif heatmap_method == "mm":
             # Difference between first and last frames
-            heatmap = np.abs(
-                generated_images[0][0][0] - generated_images[-1][0][0]
-            )
+            heatmap = np.abs(generated_images[0][0][0] - generated_images[-1][0][0])
 
-        elif heatmap_method == 'int':
+        elif heatmap_method == "int":
             # Average per frame differences
             image_changes = []
             for i in range(len(generated_images) - 1):
-                image_changes.append(np.abs(
-                    generated_images[i][0][0] - generated_images[i + 1][0][0]
-                ))
+                image_changes.append(
+                    np.abs(generated_images[i][0][0] - generated_images[i + 1][0][0])
+                )
             heatmap = np.mean(image_changes, 0)
         else:
-            raise Exception('Unknown heatmap_method for 2d image')
+            raise Exception("Unknown heatmap_method for 2d image")
 
         params["heatmap"] = heatmap
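The four heatmap modes reduce the frame sequence differently: "max" and "mean" compare every frame against the lambda = 0 frame, "mm" compares only the two extreme frames, and the default "int" integrates the change between consecutive frames. A small self-contained check of the "int" branch on dummy frames (shapes illustrative):

import numpy as np

# Five fake generated frames shaped like (batch, channel, H, W).
frames = [np.random.rand(1, 1, 8, 8) for _ in range(5)]

# "int": mean of absolute per-frame differences, as in the branch above.
changes = [
    np.abs(frames[i][0][0] - frames[i + 1][0][0]) for i in range(len(frames) - 1)
]
heatmap = np.mean(changes, 0)
assert heatmap.shape == (8, 8)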

@@ -259,20 +251,20 @@ def generate_video(
     ):
         """Generate a video from the generated images.
 
-        Args:
-            params: The dict returned from the call to `attribute`.
-            target_filename: The filename to write the video to. `.mp4` will
-                be added to the end of the string.
-            watermark: To add the probability output and the name of the
-                method.
-            ffmpeg_path: The path to call `ffmpeg`
-            temp_path: A temp path to write images.
-            show: To try and show the video in a jupyter notebook.
-            verbose: True to print debug text
-
-        Returns:
-            The filename of the video if show=False, otherwise it will
-            return a video to show in a jupyter notebook.
+        Args:
+            params: The dict returned from the call to `attribute`.
+            target_filename: The filename to write the video to. `.mp4` will
+                be added to the end of the string.
+            watermark: To add the probability output and the name of the
+                method.
+            ffmpeg_path: The path to call `ffmpeg`
+            temp_path: A temp path to write images.
+            show: To try and show the video in a jupyter notebook.
+            verbose: True to print debug text
+
+        Returns:
+            The filename of the video if show=False, otherwise it will
+            return a video to show in a jupyter notebook.
         """
 
         if os.path.exists(target_filename + ".mp4"):
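A hedged usage sketch, assuming attr is a LatentShift instance and params is the dict returned by attr.attribute(...); argument names are taken from the docstring above, while the defaults for ffmpeg_path and temp_path are not shown in this diff:

# Writes counterfactual.mp4 and returns the filename when show=False;
# with show=True it instead returns an IPython Video for notebooks.
path = attr.generate_video(params, target_filename="counterfactual", watermark=True, show=False)
print(path)  # counterfactual.mp4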
@@ -285,42 +277,50 @@ def generate_video(
 
         # Add reversed so we have an animation cycle
         towrite = list(reversed(imgs)) + list(imgs)
-        ys = list(reversed(params['preds'])) + list(params['preds'])
+        ys = list(reversed(params["preds"])) + list(params["preds"])
 
         for idx, img in enumerate(towrite):
 
-            px = 1 / plt.rcParams['figure.dpi']
+            px = 1 / plt.rcParams["figure.dpi"]
             full_frame(img[0].shape[0] * px, img[0].shape[1] * px)
-            plt.imshow(img[0], interpolation='none')
+            plt.imshow(img[0], interpolation="none")
 
             if watermark:
                 # Write prob output in upper left
                 plt.text(
-                    0.05, 0.95, f"{float(ys[idx]):1.1f}",
-                    ha='left', va='top',
-                    transform=plt.gca().transAxes
+                    0.05,
+                    0.95,
+                    f"{float(ys[idx]):1.1f}",
+                    ha="left",
+                    va="top",
+                    transform=plt.gca().transAxes,
                 )
                 # Write method name in lower right
                 plt.text(
-                    0.96, 0.1, 'gifsplanation',
-                    ha='right', va='bottom',
-                    transform=plt.gca().transAxes
+                    0.96,
+                    0.1,
+                    "gifsplanation",
+                    ha="right",
+                    va="bottom",
+                    transform=plt.gca().transAxes,
                 )
 
             plt.savefig(
-                f'{temp_path}/image-{idx}.png',
-                bbox_inches='tight',
+                f"{temp_path}/image-{idx}.png",
+                bbox_inches="tight",
                 pad_inches=0,
-                transparent=False
+                transparent=False,
             )
             plt.close()
 
         # Command for ffmpeg to generate an mp4
-        cmd = f"{ffmpeg_path} -loglevel quiet -stats -y " \
-              f"-i {temp_path}/image-%d.png " \
-              f"-c:v libx264 -vf scale=-2:{imgs[0][0].shape[0]} " \
-              f"-profile:v baseline -level 3.0 -pix_fmt yuv420p " \
-              f"'{target_filename}.mp4'"
+        cmd = (
+            f"{ffmpeg_path} -loglevel quiet -stats -y "
+            f"-i {temp_path}/image-%d.png "
+            f"-c:v libx264 -vf scale=-2:{imgs[0][0].shape[0]} "
+            f"-profile:v baseline -level 3.0 -pix_fmt yuv420p "
+            f"'{target_filename}.mp4'"
+        )
 
         if verbose:
             print(cmd)
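For what the assembled command does: -i {temp_path}/image-%d.png consumes the numbered PNGs written in the loop above, -vf scale=-2:<height> rescales to the frame height while forcing an even width (libx264 with yuv420p requires even dimensions), and -profile:v baseline -level 3.0 -pix_fmt yuv420p keeps the mp4 playable in most browsers and players. -y overwrites any existing output, and -loglevel quiet -stats trims ffmpeg's console output to a single progress line.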
@@ -331,23 +331,23 @@ def generate_video(
         if show:
             # If we in a jupyter notebook then show the video.
             from IPython.core.display import Video
+
             try:
-                return Video(target_filename + ".mp4",
-                             html_attributes="controls loop autoplay muted",
-                             embed=True,
-                             )
+                return Video(
+                    target_filename + ".mp4",
+                    html_attributes="controls loop autoplay muted",
+                    embed=True,
+                )
             except TypeError:
-                return Video(target_filename + ".mp4",
-                             embed=True
-                             )
+                return Video(target_filename + ".mp4", embed=True)
         else:
             return target_filename + ".mp4"
 
 
 def full_frame(width=None, height=None):
     """Setup matplotlib so we can write to the entire canvas"""
 
-    matplotlib.rcParams['savefig.pad_inches'] = 0
+    matplotlib.rcParams["savefig.pad_inches"] = 0
     figsize = None if width is None else (width, height)
     plt.figure(figsize=figsize)
     ax = plt.axes([0, 0, 1, 1], frameon=False)
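full_frame is the small helper that makes the saved frames borderless: plt.axes([0, 0, 1, 1], frameon=False) stretches a single frameless axes over the entire figure, and zeroing savefig.pad_inches removes the remaining margin, so each PNG written by generate_video is exactly the rendered image with no padding for ffmpeg to pick up.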

tests/attr/test_latent_shift.py

Lines changed: 3 additions & 4 deletions
@@ -1,7 +1,7 @@
 #!/usr/bin/env python3
 
-import torch
 import captum
+import torch
 from tests.helpers.basic import BaseTest
 
 
@@ -23,7 +23,6 @@ def forward(self, x):
 
 
 class TinyAE(torch.nn.Module):
-
     def __init__(self):
         super(TinyAE, self).__init__()
 
@@ -55,5 +54,5 @@ def test_basic_setup(self):
         # Computes counterfactual for class 3.
         output = attr.attribute(x, target=3, lambda_sweep_steps=10)
 
-        assert 10 == len(output['generated_images'])
-        assert (1, 100) == output['generated_images'][0].shape
+        assert 10 == len(output["generated_images"])
+        assert (1, 100) == output["generated_images"][0].shape
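Putting the pieces together, a self-contained sketch in the spirit of this test. The classifier and autoencoder internals here are hypothetical; only the constructor, the encode/decode requirement, the attribute call, and the output keys are taken from the diff, so whether this exact sketch runs depends on internals not shown:

import torch
from captum.attr import LatentShift


class TinyAE(torch.nn.Module):
    # Hypothetical 100 -> 8 -> 100 autoencoder; LatentShift only
    # requires that encode() and decode() exist.
    def __init__(self):
        super().__init__()
        self.enc = torch.nn.Linear(100, 8)
        self.dec = torch.nn.Linear(8, 100)

    def encode(self, x):
        return self.enc(x)

    def decode(self, z):
        return self.dec(z)


class TinyClassifier(torch.nn.Module):
    # Hypothetical 4-class head so that target=3 is valid.
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(100, 4)

    def forward(self, x):
        return torch.sigmoid(self.fc(x))


attr = LatentShift(TinyClassifier(), TinyAE())
x = torch.randn(1, 100)
output = attr.attribute(x, target=3, lambda_sweep_steps=10)

# Keys populated by attribute(), per the diff above.
assert {"generated_images", "lambdas", "preds", "heatmap"} <= set(output)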
