#!/usr/bin/env python3

-from typing import Callable, Tuple, List
-
import os
import shutil
-import torch
-import numpy as np
-from torch import Tensor
-from captum.attr._utils.attribution import GradientAttribution
-from captum.log import log_usage
import subprocess
+from typing import Callable, List, Tuple
+
import matplotlib
import matplotlib.pyplot as plt
+import numpy as np
+import torch
+from captum.attr._utils.attribution import GradientAttribution
+from captum.log import log_usage
+from torch import Tensor


class LatentShift(GradientAttribution):
@@ -46,8 +46,8 @@ def __init__(self, forward_func: Callable, autoencoder) -> None:
        self.ae = autoencoder

        # check if ae has encode and decode
-        assert hasattr(self.ae, 'encode')
-        assert hasattr(self.ae, 'decode')
+        assert hasattr(self.ae, "encode")
+        assert hasattr(self.ae, "decode")

    @log_usage()
    def attribute(
@@ -60,7 +60,7 @@ def attribute(
        search_max_steps: int = 3000,
        search_max_pixel_diff: float = 5000.0,
        lambda_sweep_steps: int = 10,
-        heatmap_method: str = 'int',
+        heatmap_method: str = "int",
        verbose: bool = True,
    ) -> dict:
        r"""
@@ -152,7 +152,7 @@ def compute_shift(lambdax):
            pred1 = pred1.detach().cpu().numpy()
            cache[lambdax] = x_lambdax, pred1
            if verbose:
-                print(f'Shift: {lambdax}, Prediction: {pred1}')
+                print(f"Shift: {lambdax}, Prediction: {pred1}")
            return cache[lambdax]

        _, initial_pred = compute_shift(0)
@@ -190,13 +190,11 @@ def compute_shift(lambdax):
            rbound = 0

        if verbose:
-            print('Selected bounds: ', lbound, rbound)
+            print("Selected bounds: ", lbound, rbound)

        # Sweep over the range of lambda values to create a sequence
        lambdas = np.arange(
-            lbound,
-            rbound,
-            np.abs((lbound - rbound) / lambda_sweep_steps)
+            lbound, rbound, np.abs((lbound - rbound) / lambda_sweep_steps)
        ).tolist()

        preds = []
@@ -208,39 +206,33 @@ def compute_shift(lambdax):
            preds.append(float(pred))

        params = {}
-        params['generated_images'] = generated_images
-        params['lambdas'] = lambdas
-        params['preds'] = preds
+        params["generated_images"] = generated_images
+        params["lambdas"] = lambdas
+        params["preds"] = preds

        x_lambda0 = x_lambda0.detach().cpu().numpy()
-        if heatmap_method == 'max':
+        if heatmap_method == "max":
            # Max difference from lambda 0 frame
-            heatmap = np.max(
-                np.abs(x_lambda0[0][0] - generated_images[0][0]), 0
-            )
+            heatmap = np.max(np.abs(x_lambda0[0][0] - generated_images[0][0]), 0)

-        elif heatmap_method == 'mean':
+        elif heatmap_method == "mean":
            # Average difference between 0 and other lambda frames
-            heatmap = np.mean(
-                np.abs(x_lambda0[0][0] - generated_images[0][0]), 0
-            )
+            heatmap = np.mean(np.abs(x_lambda0[0][0] - generated_images[0][0]), 0)

-        elif heatmap_method == 'mm':
+        elif heatmap_method == "mm":
            # Difference between first and last frames
-            heatmap = np.abs(
-                generated_images[0][0][0] - generated_images[-1][0][0]
-            )
+            heatmap = np.abs(generated_images[0][0][0] - generated_images[-1][0][0])

-        elif heatmap_method == 'int':
+        elif heatmap_method == "int":
            # Average per frame differences
            image_changes = []
            for i in range(len(generated_images) - 1):
-                image_changes.append(np.abs(
-                    generated_images[i][0][0] - generated_images[i + 1][0][0]
-                ))
+                image_changes.append(
+                    np.abs(generated_images[i][0][0] - generated_images[i + 1][0][0])
+                )
            heatmap = np.mean(image_changes, 0)
        else:
-            raise Exception('Unknown heatmap_method for 2d image')
+            raise Exception("Unknown heatmap_method for 2d image")

        params["heatmap"] = heatmap

@@ -259,20 +251,20 @@ def generate_video(
    ):
        """Generate a video from the generated images.

-    Args:
-        params: The dict returned from the call to `attribute`.
-        target_filename: The filename to write the video to. `.mp4` will
-            be added to the end of the string.
-        watermark: To add the probability output and the name of the
-            method.
-        ffmpeg_path: The path to call `ffmpeg`
-        temp_path: A temp path to write images.
-        show: To try and show the video in a jupyter notebook.
-        verbose: True to print debug text
-
-    Returns:
-        The filename of the video if show=False, otherwise it will
-            return a video to show in a jupyter notebook.
+        Args:
+            params: The dict returned from the call to `attribute`.
+            target_filename: The filename to write the video to. `.mp4` will
+                be added to the end of the string.
+            watermark: To add the probability output and the name of the
+                method.
+            ffmpeg_path: The path to call `ffmpeg`
+            temp_path: A temp path to write images.
+            show: To try and show the video in a jupyter notebook.
+            verbose: True to print debug text
+
+        Returns:
+            The filename of the video if show=False, otherwise it will
+                return a video to show in a jupyter notebook.
        """

        if os.path.exists(target_filename + ".mp4"):
@@ -285,42 +277,50 @@ def generate_video(

        # Add reversed so we have an animation cycle
        towrite = list(reversed(imgs)) + list(imgs)
-        ys = list(reversed(params['preds'])) + list(params['preds'])
+        ys = list(reversed(params["preds"])) + list(params["preds"])

        for idx, img in enumerate(towrite):

-            px = 1 / plt.rcParams['figure.dpi']
+            px = 1 / plt.rcParams["figure.dpi"]
            full_frame(img[0].shape[0] * px, img[0].shape[1] * px)
-            plt.imshow(img[0], interpolation='none')
+            plt.imshow(img[0], interpolation="none")

            if watermark:
                # Write prob output in upper left
                plt.text(
-                    0.05, 0.95, f"{float(ys[idx]):1.1f}",
-                    ha='left', va='top',
-                    transform=plt.gca().transAxes
+                    0.05,
+                    0.95,
+                    f"{float(ys[idx]):1.1f}",
+                    ha="left",
+                    va="top",
+                    transform=plt.gca().transAxes,
                )
                # Write method name in lower right
                plt.text(
-                    0.96, 0.1, 'gifsplanation',
-                    ha='right', va='bottom',
-                    transform=plt.gca().transAxes
+                    0.96,
+                    0.1,
+                    "gifsplanation",
+                    ha="right",
+                    va="bottom",
+                    transform=plt.gca().transAxes,
                )

            plt.savefig(
-                f'{temp_path}/image-{idx}.png',
-                bbox_inches='tight',
+                f"{temp_path}/image-{idx}.png",
+                bbox_inches="tight",
                pad_inches=0,
-                transparent=False
+                transparent=False,
            )
            plt.close()

        # Command for ffmpeg to generate an mp4
-        cmd = f"{ffmpeg_path} -loglevel quiet -stats -y " \
-              f"-i {temp_path}/image-%d.png " \
-              f"-c:v libx264 -vf scale=-2:{imgs[0][0].shape[0]} " \
-              f"-profile:v baseline -level 3.0 -pix_fmt yuv420p " \
-              f"'{target_filename}.mp4'"
+        cmd = (
+            f"{ffmpeg_path} -loglevel quiet -stats -y "
+            f"-i {temp_path}/image-%d.png "
+            f"-c:v libx264 -vf scale=-2:{imgs[0][0].shape[0]} "
+            f"-profile:v baseline -level 3.0 -pix_fmt yuv420p "
+            f"'{target_filename}.mp4'"
+        )

        if verbose:
            print(cmd)
@@ -331,23 +331,23 @@ def generate_video(
        if show:
            # If we in a jupyter notebook then show the video.
            from IPython.core.display import Video
+
            try:
-                return Video(target_filename + ".mp4",
-                             html_attributes="controls loop autoplay muted",
-                             embed=True,
-                             )
+                return Video(
+                    target_filename + ".mp4",
+                    html_attributes="controls loop autoplay muted",
+                    embed=True,
+                )
            except TypeError:
-                return Video(target_filename + ".mp4",
-                             embed=True
-                             )
+                return Video(target_filename + ".mp4", embed=True)
        else:
            return target_filename + ".mp4"


def full_frame(width=None, height=None):
    """Setup matplotlib so we can write to the entire canvas"""

-    matplotlib.rcParams['savefig.pad_inches'] = 0
+    matplotlib.rcParams["savefig.pad_inches"] = 0
    figsize = None if width is None else (width, height)
    plt.figure(figsize=figsize)
    ax = plt.axes([0, 0, 1, 1], frameon=False)
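Since the diff above only reformats the file, the heatmap logic it touches is easy to skim past. The snippet below is a minimal, standalone sketch of the "int" and "mm" branches of heatmap_method from the attribute hunk, using random NumPy arrays in place of the decoded frames stored in params["generated_images"]; the (1, 1, 64, 64) frame shape and the 10-step sweep are assumptions made only for illustration.

import numpy as np

# Stand-in for params["generated_images"]: one decoded frame per lambda value.
# The (1, 1, 64, 64) shape and 10-frame sweep are assumed for this sketch.
generated_images = [np.random.rand(1, 1, 64, 64) for _ in range(10)]

# "int" heatmap: mean absolute difference between consecutive frames,
# mirroring the loop in the diff above.
image_changes = [
    np.abs(generated_images[i][0][0] - generated_images[i + 1][0][0])
    for i in range(len(generated_images) - 1)
]
heatmap_int = np.mean(image_changes, 0)

# "mm" heatmap: absolute difference between the first and last frames.
heatmap_mm = np.abs(generated_images[0][0][0] - generated_images[-1][0][0])

print(heatmap_int.shape, heatmap_mm.shape)  # (64, 64) (64, 64)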