|
1 | 1 | #!/usr/bin/env python3
|
2 | 2 |
|
3 |
| -from typing import Callable, Tuple |
| 3 | +from typing import Callable, Tuple, List |
4 | 4 |
|
5 | 5 | import os
|
6 | 6 | import shutil
|
@@ -53,7 +53,7 @@ def __init__(self, forward_func: Callable, autoencoder) -> None:
|
53 | 53 | def attribute(
|
54 | 54 | self,
|
55 | 55 | inputs: Tensor,
|
56 |
| - target: int = None, |
| 56 | + target: int, |
57 | 57 | fix_range: Tuple = None,
|
58 | 58 | search_pred_diff: float = 0.8,
|
59 | 59 | search_step_size: float = 10.0,
|
@@ -197,21 +197,20 @@ def compute_shift(lambdax):
|
197 | 197 | lbound,
|
198 | 198 | rbound,
|
199 | 199 | np.abs((lbound - rbound) / lambda_sweep_steps)
|
200 |
| - ) |
| 200 | + ).tolist() |
201 | 201 |
|
202 | 202 | preds = []
|
203 | 203 | generated_images = []
|
204 | 204 |
|
205 | 205 | for lam in lambdas:
|
206 | 206 | x_lambdax, pred = compute_shift(lam)
|
207 | 207 | generated_images.append(x_lambdax.cpu().numpy())
|
208 |
| - preds.append(pred) |
| 208 | + preds.append(float(pred)) |
209 | 209 |
|
210 | 210 | params = {}
|
211 | 211 | params['generated_images'] = generated_images
|
212 | 212 | params['lambdas'] = lambdas
|
213 | 213 | params['preds'] = preds
|
214 |
| - params['target'] = target |
215 | 214 |
|
216 | 215 | x_lambda0 = x_lambda0.detach().cpu().numpy()
|
217 | 216 | if heatmap_method == 'max':
|
@@ -276,9 +275,6 @@ def generate_video(
|
276 | 275 | return a video to show in a jupyter notebook.
|
277 | 276 | """
|
278 | 277 |
|
279 |
| - if not target_filename: |
280 |
| - target_filename = f'video-{params["target"]}' |
281 |
| - |
282 | 278 | if os.path.exists(target_filename + ".mp4"):
|
283 | 279 | os.remove(target_filename + ".mp4")
|
284 | 280 |
|
|
0 commit comments