Skip to content

Commit 2f618ff

Browse files
committed
make mypy happy
1 parent 435bee8 commit 2f618ff

File tree

1 file changed

+4
-8
lines changed

1 file changed

+4
-8
lines changed

captum/attr/_core/latent_shift.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#!/usr/bin/env python3
22

3-
from typing import Callable, Tuple
3+
from typing import Callable, Tuple, List
44

55
import os
66
import shutil
@@ -53,7 +53,7 @@ def __init__(self, forward_func: Callable, autoencoder) -> None:
5353
def attribute(
5454
self,
5555
inputs: Tensor,
56-
target: int = None,
56+
target: int,
5757
fix_range: Tuple = None,
5858
search_pred_diff: float = 0.8,
5959
search_step_size: float = 10.0,
@@ -197,21 +197,20 @@ def compute_shift(lambdax):
197197
lbound,
198198
rbound,
199199
np.abs((lbound - rbound) / lambda_sweep_steps)
200-
)
200+
).tolist()
201201

202202
preds = []
203203
generated_images = []
204204

205205
for lam in lambdas:
206206
x_lambdax, pred = compute_shift(lam)
207207
generated_images.append(x_lambdax.cpu().numpy())
208-
preds.append(pred)
208+
preds.append(float(pred))
209209

210210
params = {}
211211
params['generated_images'] = generated_images
212212
params['lambdas'] = lambdas
213213
params['preds'] = preds
214-
params['target'] = target
215214

216215
x_lambda0 = x_lambda0.detach().cpu().numpy()
217216
if heatmap_method == 'max':
@@ -276,9 +275,6 @@ def generate_video(
276275
return a video to show in a jupyter notebook.
277276
"""
278277

279-
if not target_filename:
280-
target_filename = f'video-{params["target"]}'
281-
282278
if os.path.exists(target_filename + ".mp4"):
283279
os.remove(target_filename + ".mp4")
284280

0 commit comments

Comments
 (0)